Construindo aplicativos com o Mediapipe

Hoje, muitos serviços usam modelos de redes neurais em seu funcionamento. Porém, devido ao baixo desempenho dos dispositivos clientes, na maioria dos casos os cálculos são realizados no servidor. Ainda assim, o desempenho dos smartphones cresce a cada ano e já é possível executar modelos pequenos nos próprios dispositivos clientes. Surge a pergunta: como fazer isso? Além de executar o modelo, é preciso pré-processar e pós-processar os dados. Além disso, existem pelo menos duas plataformas em que isso precisa ser implementado: Android e iOS. O Mediapipe é um framework para executar pipelines de aprendizado de máquina (pré-processamento dos dados, inferência do modelo e pós-processamento dos resultados), o que permite resolver os problemas descritos acima e simplificar a escrita de código multiplataforma para executar modelos.



Conteúdo


  1. Visão geral do Bazel Build System
  2. Estrutura do Mediapipe
    1. O que é um Mediapipe. Em que consiste e por que é necessário
    2. Aplicativo HelloWorld no Mediapipe
    3. Executando um Modelo Usando o Mediapipe
  3. Criando um aplicativo usando o Mediapipe
  4. Conclusão

Visão geral do Bazel Build System


O Mediapipe usa o Bazel para compilação. Foi a primeira vez que me deparei com esse sistema de build, então, paralelamente ao Mediapipe, fui aprendendo o Bazel, coletando informações de várias fontes e tropeçando em vários erros pelo caminho. Portanto, antes de falar sobre o próprio Mediapipe, gostaria de fazer uma pequena introdução ao sistema de build Bazel, que é usado para compilar o Mediapipe e os projetos baseados nele. O Bazel é um sistema de build do Google: uma versão revisada, disponibilizada publicamente, do seu sistema interno de build, o Blaze. O Bazel permite combinar diferentes linguagens e frameworks em um mesmo projeto e compilá-lo com um comando simples:


bazel build TARGET

, ? , Bazel-. WORKSPACE, . , . , A, B, WORKSPACE A, B. - WORKSPACE , Mediapipe. Bazel , , , , , . WORKSPACE:


# Load the http_archive repository rule from Bazel's bundled tools.
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Fetch the C++ build rules (rules_cc) as an external repository.
# NOTE(review): the archive tracks the master branch and has no sha256, so the
# downloaded content is not pinned — consider pinning a commit and checksum.
http_archive(
    name = "rules_cc",
    strip_prefix = "rules_cc-master",
    urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
)

# Fetch googletest as an external repository named com_google_googletest.
# NOTE(review): same unpinned master.zip caveat as above.
http_archive(
    name = "com_google_googletest",
    strip_prefix = "googletest-master",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
)

(bazel rule). — Starlark. Bazel , , . , http_archive , http_archive C++ googletest.


WORKSPACE Bazel BUILD. BUILD , , . BUILD C++ :


# Executable target "HelloWorld" built from main.cpp.
# copts enables -Wall and -Wpedantic and turns warnings into errors (-Werror).
cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

, , WORKSPACE http_archive. Bazel- main.cpp.


  • ./WORKSPACE
  • ./Source/BUILD
  • ./Source/main.cpp


bazel build //Source:HelloWorld

Source — o pacote (o diretório que contém o arquivo BUILD), HelloWorld — o nome do alvo (target) declarado nesse BUILD.


cmake, ( ) ./bazel-bin.


, , . Bazel (libraries).


# Library target "lib" built from utils.h / utils.cpp.
# visibility is public, so targets in any package may depend on it.
cc_library(
    name = "lib",
    hdrs = ["utils.h"],
    srcs = ["utils.cpp"],
    visibility = [
        "//visibility:public",
    ],
)

# Executable target "HelloWorld"; ":lib" refers to the library target declared
# above in this same BUILD file.
cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    deps=[":lib"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

BUILD , . , //


@some_repository//Module:lib

Bazel , . http_archive, git_repository, local_repository .


Mediapipe


Mediapipe


Mediapipe — . , ( ):


  1. (Calculators) — (). . C++ , CalculatorBase:
    • static Status GetContract(CalculatorContract*); — , , .
    • Status Open(CalculatorContext*); — . , , , .
    • Status Process(CalculatorContext*); — .
    • Status Close(CalculatorContext*); — .
  2. (Streams) . . , (input) (output). , , , .
  3. (Packet) — , . — , , , protobuf. timestamp — , . , , , , .

protobuf.


Mediapipe , . Linux, WSL, MacOS, Android, iOS. Mediapipe TensorFlow TFLite . , , .


, Mediapipe. , . Mediapipe, , - .


HelloWorld Mediapipe


, , N , .



.pbtxt , main.cpp.


├── hello-world
│   ├── BUILD
│   ├── graph.pbtxt
│   ├── main.cpp
│   ├── RepeatNTimesCalculator.cpp
│   └── RepeatNTimesCalculator.proto
└── WORKSPACE

:


# Graph-level streams: packets enter through "in" and leave through "out".
input_stream: "in"
output_stream: "out"

# Single node: RepeatNTimesCalculator reads the untagged stream "in" and
# publishes on its OUTPUT_TAG output, which is wired to the graph stream "out".
node {
    calculator: "RepeatNTimesCalculator"
    input_stream: "in"
    output_stream: "OUTPUT_TAG:out"
    node_options: {
        # Options message for the calculator; n is read by the calculator in Open().
        [type.googleapis.com/mediapipe_demonstration.RepeatNTimesCalculatoOptions] {
            n: 3
        }
    }
}

RepeatNTimesCalculator . GetContract, , Open, , Process, .


// Calculator that re-emits every incoming std::string packet n times on its
// OUTPUT_TAG output, where n comes from RepeatNTimesCalculatoOptions.
class RepeatNTimesCalculator : public mediapipe::CalculatorBase {
public:
    // Declares the calculator's streams and their payload types:
    // one untagged input (index 0) and one output tagged OUTPUT_TAG (index 0),
    // both carrying std::string.
    static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
        cc->Inputs().Get("", 0).Set<std::string>();
        cc->Outputs().Get("OUTPUT_TAG", 0).Set<std::string>();
        return mediapipe::OkStatus();
    }

    // Reads the node options and caches the repeat count.
    mediapipe::Status Open(mediapipe::CalculatorContext* cc) final {
        const auto& options = cc->Options<mediapipe_demonstration::RepeatNTimesCalculatoOptions>();
        n_ = options.n();
        return mediapipe::OkStatus();
    }

    // Emits n_ copies of the current input string. Copy i is stamped with
    // InputTimestamp() + i, so the copies carry distinct, increasing timestamps.
    // Nothing is emitted when n_ <= 0.
    mediapipe::Status Process(mediapipe::CalculatorContext* cc) final {
        // Bind by reference: the payload is only read here, and MakePacket
        // below makes its own copy for each output packet.
        const auto& txt = cc->Inputs().Index(0).Value().Get<std::string>();

        for (int i = 0; i < n_; ++i) {
            auto packet = mediapipe::MakePacket<std::string>(txt).At(cc->InputTimestamp() + i);
            cc->Outputs().Get("OUTPUT_TAG", 0).AddPacket(packet);
        }

        return mediapipe::OkStatus();
    }
private:
    // Repeat count; initialised so the member never holds an indeterminate
    // value before Open() assigns it.
    int n_ = 0;
};
// Register the calculator under its class name, the name used by the
// "calculator:" field of the graph config.
REGISTER_CALCULATOR(RepeatNTimesCalculator);

.cpp , , REGISTER_CALCULATOR.


proto- . RepeatNTimesCalculatoOptions, , .


syntax = "proto2";
package mediapipe_demonstration;
import "mediapipe/framework/calculator_options.proto";
// Options for RepeatNTimesCalculator.
// NOTE(review): the message name is missing an "r" ("Calculato"); the same
// spelling is used by the graph config and the C++ code, so renaming it would
// have to be done in all three places at once.
message RepeatNTimesCalculatoOptions {
  // Registers this message as an extension of mediapipe.CalculatorOptions
  // under extension field number 350607623.
  extend mediapipe.CalculatorOptions {
    optional RepeatNTimesCalculatoOptions ext = 350607623;
  }
  // Value read by the calculator via options.n(); required, so it must be set
  // in the graph config.
  required int32 n = 2;
}

:


// Loads the graph config from ./hello-world/graph.pbtxt, runs it with a single
// "Hello!" input packet and prints every string that arrives on the "out"
// stream.
//
// Returns a NotFound status if the config file cannot be opened, the first
// error reported by the graph otherwise, or the result of WaitUntilDone().
// ParseTextProtoOrDie (as its name says) does not return on a malformed config.
mediapipe::Status RunGraph() {
    static const char kGraphPath[] = "./hello-world/graph.pbtxt";

    // Read the whole text-proto config into memory.
    std::ifstream file(kGraphPath);
    if (!file.is_open()) {
        // Without this check a missing file yields an empty config and a
        // confusing failure later, when the "out" stream cannot be found.
        return mediapipe::Status(mediapipe::StatusCode::kNotFound,
                                 std::string("Cannot open graph config: ") + kGraphPath);
    }
    std::string graph_file_content;
    graph_file_content.assign(
        std::istreambuf_iterator<char>(file),
        std::istreambuf_iterator<char>());
    mediapipe::CalculatorGraphConfig config =
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_file_content);

    // Build the graph from the parsed config.
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));

    // Attach the poller before the run starts so no output packet is missed.
    ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller("out"));

    // Start the run with no side packets.
    MP_RETURN_IF_ERROR(graph.StartRun({}));

    // Send one packet at timestamp 0, then close the input so the graph can finish.
    auto input_packet = mediapipe::MakePacket<std::string>("Hello!").At(mediapipe::Timestamp(0));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", input_packet));
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));

    // Drain the output stream; Next() returns false once the stream is done.
    mediapipe::Packet packet;
    while (poller.Next(&packet)) {
        std::cout << packet.Get<std::string>() << std::endl;
    }
    return graph.WaitUntilDone();
}

, , BUILD , .


load("@mediapipe_repository//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
# Proto target for the calculator options message (RepeatNTimesCalculator.proto),
# depending on MediaPipe's calculator proto.
proto_library(
    name = "repeat_n_times_calculator_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_proto",
    ],
)
# C++ proto library for the same .proto file, built with MediaPipe's
# mediapipe_cc_proto_library macro.
mediapipe_cc_proto_library(
    name = "repeat_n_times_calculator_cc_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    cc_deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":repeat_n_times_calculator_proto"],
)
# Library containing the calculator implementation.
# NOTE(review): alwayslink = 1 is presumably required so the linker keeps the
# REGISTER_CALCULATOR registration even though nothing references the
# calculator's symbols directly — confirm.
cc_library(
    name = "repeat_n_times_calculator",
    srcs = ["RepeatNTimesCalculator.cpp"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":repeat_n_times_calculator_cc_proto",
        "@mediapipe_repository//mediapipe/framework:calculator_framework",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)
# Executable that runs the graph; it depends on the calculator library above
# plus MediaPipe's logging, text-proto parsing and status ports.
cc_binary(
    name = "HelloMediapipe",
    srcs = ["main.cpp"],
    deps = [
        "repeat_n_times_calculator",
        "@mediapipe_repository//mediapipe/framework/port:logging",
        "@mediapipe_repository//mediapipe/framework/port:parse_text_proto",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
)

:


$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //hello-world:HelloMediapipe
...
INFO: Build completed successfully, 4 total actions
$ ./bazel-bin/hello-world/HelloMediapipe
Hello!
Hello!
Hello!

Mediapipe, ML .


Mediapipe


, Mediapipe tflite . . , ImageNet1k. , , . , .


├── inference
│   ├── android/src/main/java/com/mediapipe_demonstration/inference
│   │   ├── AndroidManifest.xml
│   │   ├── BUILD
│   │   ├── MainActivity.kt
│   │   └── res
│   │       └── ...
│   ├── BUILD
│   ├── desktop
│   │   ├── BUILD
│   │   └── main.cpp
│   ├── graph.pbtxt
│   ├── img.jpg
│   └── mobilenetv2_imagenet.tflite
└── WORKSPACE

:



# Graph-level streams: the image enters through "in", the prediction vector
# leaves through "out".
input_stream: "in"
output_stream: "out"
# Step 1: resize the input image to 224x224.
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:in"
  output_stream: "IMAGE:transformed_input"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
    }
  }
}
# Step 2: convert the resized image into TfLite tensors.
# zero_center: false selects the [0, 1] value range instead of [-1, 1].
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:transformed_input"
  output_stream: "TENSORS:image_tensor"
  node_options: {
      [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
        zero_center: false
      }
  }
}
# Step 3: run the TFLite model found at model_path on the input tensors.
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "inference/mobilenetv2_imagenet.tflite"
    }
  }
}
# Step 4: unpack the prediction tensor into a vector of floats on "out".
node {
  calculator: "TfLiteTensorsToFloatsCalculator"
  input_stream: "TENSORS:prediction_tensor"
  output_stream: "FLOATS:out"
}

4 :


  1. 224x224.
  2. [-1, 1] ( zero_center: false, [0, 1]), ImageFrame ( Mediapipe, , ) TfLiteTensor ( std::vector<TfLiteTensor> 1). TfLiteTensor ( GetContract), .
  3. .
  4. TfLiteTensor std::vector<float>. .

, .


. . , , .


//  
auto img_mat = cv::imread("./inference/img.jpg");
//    
auto input_frame = std::make_unique<mediapipe::ImageFrame>(
    mediapipe::ImageFormat::SRGB, img_mat.cols, img_mat.rows,
    mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
img_mat.copyTo(input_frame_mat);
auto frame = mediapipe::Adopt(input_frame.release()).At(mediapipe::Timestamp(0));
//    
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", frame));
MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
//       
mediapipe::Packet packet;
while (poller.Next(&packet)) {
    auto predictions = packet.Get<std::vector<float>>();
    int idx = std::max_element(predictions.begin(), predictions.end()) - predictions.begin();
    std::cout << idx << std::endl;
}

, :



$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //inference/desktop:Inference
...
INFO: Build completed successfully, 3 total actions
$ ./bazel-bin/inference/desktop/Inference
INFO: Initialized TensorFlow Lite runtime.
151

A classe 151 do ImageNet corresponde a "Chihuahua".


. Activity, ( onCreate) , Mediapipe. protobuf-. , . onActivityResult, . Mediapipe - ( ).


/**
 * Single-screen activity: lets the user pick an image, feeds it to the
 * MediaPipe graph and shows the index of the highest-scoring class.
 */
class MainActivity : AppCompatActivity() {
    // Request code used to recognise the image picker's result.
    val PICK_IMAGE = 1
    // Graph created in onCreate; null until then.
    var mpGraph: Graph? = null
    // Timestamp for the next input packet; advanced after every image so that
    // successive packets never reuse a timestamp.
    var timestamp = 0L

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        this.setContentView(R.layout.activity_main)
        val outputTv = findViewById<TextView>(R.id.outputTv)
        val button = findViewById<Button>(R.id.selectButton)
        AndroidAssetUtil.initializeNativeAssetManager(this)
        // Load the binary graph description from the app assets.
        val graph = Graph()
        assets.open("mobile_binary_graph.binarypb").use {
            val graphBytes = it.readBytes()
            graph.loadBinaryGraph(graphBytes)
        }
        // Handle graph output: take the index of the largest score and show it.
        // The text view is updated via runOnUiThread.
        graph.addPacketCallback("out") {
            val res = PacketGetter.getFloat32Vector(it)
            val label = res.indices.maxBy { i -> res[i] } ?: -1
            this@MainActivity.runOnUiThread {
                outputTv.text = label.toString()
            }
        }
        graph.startRunningGraph()
        // Open the system image picker when the button is pressed.
        button.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE)
        }
        mpGraph = graph
    }

    /**
     * Receives the picked image, displays it and sends it to the graph's "in"
     * stream. Does nothing when the picker was cancelled, returned no URI, or
     * the image could not be opened or decoded.
     */
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        // A cancelled picker delivers a non-OK result and possibly a null
        // intent; bail out instead of dereferencing it.
        if (requestCode != PICK_IMAGE || resultCode != RESULT_OK) {
            return
        }
        val uri = data?.data ?: return
        val graph = mpGraph ?: return

        // Decode the image; use {} closes the stream even if decoding throws.
        val bitmap = contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream)
        } ?: return

        findViewById<ImageView>(R.id.imageView).setImageBitmap(bitmap)

        // Wrap the bitmap in a packet and send it with a fresh timestamp.
        val packet = AndroidPacketCreator(graph).createRgbImageFrame(bitmap)
        try {
            graph.addPacketToInputStream("in", packet, timestamp++)
        } finally {
            // Release the packet handle once it has been handed to the graph.
            packet.release()
        }
    }

    companion object {
        init {
            // Load the native MediaPipe JNI library before any graph is created.
            System.loadLibrary("mediapipe_jni")
        }
    }
}

:


$ bazel-2.0.0 mobile-install --start_app -c opt --config=android_arm64 //inference/android/src/main/java/com/mediapipe_demonstration/inference:Inference

, :



, .


, github


Mediapipe


Mediapipe


. . , . , . , . :


  • , GPU, ;
  • .

, Mediapipe. , , , . . : , .


Mediapipe, . , ImageNet1k, Stanford Cars Dataset. . :



logit- NLL loss. "" mean pool.


Android , Mediapipe, Mediapipe , .


.
├── Android/src/main/java/com/kshmax/objectrecognition
│   └── objectrecognition
│       ├── AndroidManifest.xml
│       ├── BUILD
│       ├── MainActivity.kt
│       ├── Models.kt
│       ├── ObjectDetectionFrameProcessor.java
│       └── res
│           └── ...
├── Calculators
│   ├── BoundaryBoxCropCalculator.cpp
│   ├── BUILD
│   ├── ClickLocation.proto
│   ├── DetectionFilterCalculator.cpp
│   ├── DetectionFilterCalculator.proto
├── Graphs
│   ├── BUILD
│   └── ObjectRecognitionGraph.pbtxt
├── Models
│   ├── BUILD
│   └── car_vectorizer.tflite
├── Server
│   ├── labels.npy
│   ├── server.py
│   └── vectors.npy
└── WORKSPACE

Mediapipe :



, .


, . , , . : , . FlowLimiter . . AnnotationOverlay , . .


DetectionFilter, , . DetectionFilter, , , .


Process DetectionFilter :


// Forwards only those detections that carry at least one label id listed in
// pass_ids_. The filtered vector is emitted on the untagged output (index 0)
// with the timestamp of the current input; input order is preserved.
mediapipe::Status DetectionFilterCalculator::Process(mediapipe::CalculatorContext *cc) {
    const auto& incoming = cc->Inputs().Get("", 0).Get<std::vector<::mediapipe::Detection>>();

    // True when any label id of the detection equals any allowed id.
    const auto is_allowed = [this](const ::mediapipe::Detection& candidate) {
        for (int allowed_id : pass_ids_) {
            for (int candidate_label : candidate.label_id()) {
                if (allowed_id == candidate_label) {
                    return true;
                }
            }
        }
        return false;
    };

    std::vector<::mediapipe::Detection> kept;
    for (const auto& candidate : incoming) {
        if (is_allowed(candidate)) {
            // Each matching detection is appended exactly once.
            kept.push_back(candidate);
        }
    }

    cc->Outputs().Get("", 0).AddPacket(
        mediapipe::MakePacket<std::vector<mediapipe::Detection>>(kept).At(cc->InputTimestamp()));

    return mediapipe::OkStatus();
}

DetectionFilterCalculator :


# Filters the detections on "filtered_detections" and publishes the ones that
# pass on "car_detections".
node {
    calculator: "DetectionFilterCalculator"
    input_stream: "filtered_detections"
    output_stream: "car_detections"
    node_options: {
        [type.googleapis.com/objectrecognition.DetectionFilterCalculatorOptions] {
            # Label id to let through.
            # NOTE(review): presumably 3 is the "car" class of the detector's
            # label map, given the output stream name — confirm.
            pass_id: 3
        }
    }
}

. , , .



ScreenTap, BoundaryBoxCrop. , ScreenTap FlowLimiter, , , . , , , , . FrameProcessor, ObjectDetectionFrameProcessor.


ScreenTap , , , " ". . , , , (, , [0, 1]).


BoundaryBoxCrop , . , DetectionFilter , FlowLimiter , GPU CPU. , , , .


// Crops the detection under the user's click out of the current frame.
//
// Inputs (all index 0): DETECTION (std::vector<Detection>), IMAGE (ImageFrame)
// and CLICK (a std::string holding a serialized ClickLocation).
// Output: the cropped ImageFrame on the untagged stream, stamped with the
// input timestamp. Nothing is emitted when any input is missing, when the
// click coordinates are the -1 "no click" sentinel, or when no detection
// contains the click.
//
// Returns InvalidArgument if the CLICK payload is not a valid ClickLocation.
mediapipe::Status BoundaryBoxCropCalculator::Process(mediapipe::CalculatorContext *cc) {
    auto& detections_packet = cc->Inputs().Get("DETECTION", 0);
    auto& frames_packet = cc->Inputs().Get("IMAGE", 0);
    auto& click_packet = cc->Inputs().Get("CLICK", 0);

    // All three inputs are needed for this timestamp; otherwise skip it.
    if (detections_packet.IsEmpty()
        || frames_packet.IsEmpty()
        || click_packet.IsEmpty()) {
        return mediapipe::OkStatus();
    }

    const std::vector<mediapipe::Detection>& detections =
            detections_packet.Get<std::vector<mediapipe::Detection>>();
    const mediapipe::ImageFrame& image_frame = frames_packet.Get<mediapipe::ImageFrame>();

    // The click arrives as a serialized protobuf inside a string packet.
    // Bind by reference to avoid copying the payload, and check the parse
    // result so that malformed data is reported instead of silently producing
    // arbitrary coordinates.
    const auto& click_location_str = click_packet.Get<std::string>();
    objectrecognition::ClickLocation click_location;
    if (!click_location.ParseFromString(click_location_str)) {
        return mediapipe::Status(mediapipe::StatusCode::kInvalidArgument,
                                 "CLICK packet does not hold a serialized ClickLocation");
    }
    // -1 in either coordinate means there was no click for this frame.
    if (click_location.x() == -1 || click_location.y() == -1) {
        return mediapipe::OkStatus();
    }

    // Look for a detection whose bounding box contains the click.
    absl::optional<mediapipe::Detection> detection = FindOverlappedDetection(click_location, detections);

    if (detection.has_value()) {
        std::unique_ptr<mediapipe::ImageFrame> cropped_image = CropImage(image_frame, detection.value());
        // Only forward a real crop; an unallocated frame would be useless to
        // downstream calculators.
        if (cropped_image != nullptr && !cropped_image->IsEmpty()) {
            // The output stream takes ownership of the released frame.
            cc->Outputs().Get("", 0).Add(cropped_image.release(), cc->InputTimestamp());
        }
    }

    return mediapipe::OkStatus();
}

// Returns the first detection, in input order, whose relative bounding box
// strictly contains the click position, or absl::nullopt when none does.
// The click coordinates are compared directly against the box's relative
// xmin/ymin/width/height values, so both are expected to use the same
// coordinate space.
absl::optional<mediapipe::Detection> BoundaryBoxCropCalculator::FindOverlappedDetection(
        const objectrecognition::ClickLocation& click_location,
        const std::vector<mediapipe::Detection>& detections) {
    for (const auto& candidate : detections) {
        const auto& box = candidate.location_data().relative_bounding_box();

        // Strict inequalities: a click exactly on the border does not count.
        const bool inside_horizontally = box.xmin() < click_location.x()
                && click_location.x() < (box.xmin() + box.width());
        const bool inside_vertically = box.ymin() < click_location.y()
                && click_location.y() < (box.ymin() + box.height());

        if (!inside_horizontally || !inside_vertically) {
            continue;
        }
        return candidate;
    }

    return absl::nullopt;
}

// Copies the part of image_frame covered by the detection's relative bounding
// box into a new, tightly packed ImageFrame of the same format.
//
// The box is converted to pixel coordinates and clipped to the frame on all
// four sides, so a box that overhangs any edge never causes a read outside the
// source rows. If nothing of the box lies inside the frame (or it collapses to
// zero pixels after truncation), a default-constructed, unallocated ImageFrame
// is returned.
//
// NOTE(review): the copy indexes bytes as x * NumberOfChannels(), i.e. it
// assumes one byte per channel, as the original did — confirm for the formats
// that reach this calculator.
std::unique_ptr<mediapipe::ImageFrame> BoundaryBoxCropCalculator::CropImage(
        const mediapipe::ImageFrame& image_frame,
        const mediapipe::Detection& detection) {
    const uint8* pixel_data = image_frame.PixelData();
    const auto& b_box = detection.location_data().relative_bounding_box();

    const int frame_width = image_frame.Width();
    const int frame_height = image_frame.Height();
    const int channels = image_frame.NumberOfChannels();

    // Box corners in pixels: [xmin, xmax) x [ymin, ymax).
    int xmin = static_cast<int>(b_box.xmin() * static_cast<float>(frame_width));
    int ymin = static_cast<int>(b_box.ymin() * static_cast<float>(frame_height));
    int xmax = xmin + static_cast<int>(b_box.width() * static_cast<float>(frame_width));
    int ymax = ymin + static_cast<int>(b_box.height() * static_cast<float>(frame_height));

    // Clip each edge independently. Clamping width/height alone is not enough:
    // the far edge is xmin + width, which can still exceed the frame.
    if (xmin < 0) {
        xmin = 0;
    }
    if (ymin < 0) {
        ymin = 0;
    }
    if (xmax > frame_width) {
        xmax = frame_width;
    }
    if (ymax > frame_height) {
        ymax = frame_height;
    }

    const int width = xmax - xmin;
    const int height = ymax - ymin;

    std::unique_ptr<mediapipe::ImageFrame> cropped_image = std::make_unique<mediapipe::ImageFrame>();
    if (width <= 0 || height <= 0) {
        // Nothing to copy; hand back the unallocated frame.
        return cropped_image;
    }

    // Gather the selected pixels row by row into a contiguous buffer.
    std::vector<uint8_t> pixels;
    pixels.reserve(static_cast<std::size_t>(height) * static_cast<std::size_t>(width)
                   * static_cast<std::size_t>(channels));
    for (int y = ymin; y < ymax; ++y) {
        // WidthStep() is the byte distance between consecutive source rows.
        const uint8* row = pixel_data + static_cast<std::size_t>(y) * image_frame.WidthStep();
        pixels.insert(pixels.end(), row + xmin * channels, row + xmax * channels);
    }

    cropped_image->CopyPixelData(image_frame.Format(), width, height, pixels.data(),
            mediapipe::ImageFrame::kDefaultAlignmentBoundary);

    return cropped_image;
}

. TfLiteTensor .


(Java ) , HTTP . , "" . , ( ).


, — . :



. ( camry hummer), ford focus , .. -, (Hyundai SantaFe).


4 , , .



, , , , . , , .



, Mediapipe, , , , . real-time . UX, , . proof-of-concept, :


  • CPU, , , GPU. GPU CPU GPU;
  • , . ;
  • ;
  • ;
  • . ANN (Approximate Nearest Neighbors).

All Articles