Building Applications with Mediapipe

Today, many services use neural network models. Because client devices have limited performance, in most cases the computations run on a server. Smartphones get faster every year, however, and it is now possible to run small models directly on client devices. The question is how to do it. Besides running the model itself, you need to preprocess the input data and postprocess the results, and this has to be implemented on at least two platforms: Android and iOS. Mediapipe is a framework for running machine learning pipelines (data preprocessing, model inference, and postprocessing of the model's results). It addresses the problems described above and simplifies writing cross-platform code for running models.



Contents


  1. Bazel Build System Overview
  2. Mediapipe Framework
    1. What Mediapipe is, what it consists of, and why it is needed
    2. HelloWorld application on Mediapipe
    3. Running a model using Mediapipe
  3. Creating an Application Using Mediapipe
  4. Conclusion

Bazel Build System Overview


Mediapipe uses Bazel as its build system. This was my first encounter with Bazel, so I learned it alongside Mediapipe, collecting information from various sources and making plenty of mistakes along the way. Therefore, before describing Mediapipe itself, I would like to give a short introduction to Bazel, which is used to build Mediapipe and projects based on it. Bazel is Google's build system, the open-source version of its internal build system Blaze. Bazel lets you combine different languages and frameworks in one project and build everything with a single command:


bazel build TARGET

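How does it work? A Bazel project is defined by a WORKSPACE file in its root directory, which declares the project's external dependencies. Bazel does not resolve these dependencies transitively: if the project depends on A, and A depends on B, the WORKSPACE has to declare both A and B. This is why WORKSPACE files can get quite long, as in Mediapipe. Bazel fetches the declared dependencies itself the first time they are needed. An example WORKSPACE: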


# Load the http_archive rule, which downloads and unpacks archives
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Download the Bazel rules for building C++
http_archive(
    name = "rules_cc",
    strip_prefix = "rules_cc-master",
    urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
)

# Download googletest
http_archive(
    name = "com_google_googletest",
    strip_prefix = "googletest-master",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
)

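Each entry in WORKSPACE is a call to a rule (bazel rule). Rules are written in Starlark, a Python-like language. Some rules are built into Bazel, others are loaded from .bzl files with load. In this example, the first line loads the http_archive rule, and http_archive is then used to download the C++ rules and googletest.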


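Besides WORKSPACE, Bazel uses BUILD files. A BUILD file turns a directory into a package and declares the targets in it: what to build, from which sources, and with which options. A BUILD file for a C++ program: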


cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

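Here cc_binary builds the executable HelloWorld from main.cpp with the given compiler flags. cc_binary is built into Bazel, so this example does not need the http_archive declarations from the WORKSPACE above. The layout of the project: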


  • ./WORKSPACE
  • ./Source/BUILD
  • ./Source/main.cpp


bazel build //Source:HelloWorld

Here Source is the package (the directory with the BUILD file) and HelloWorld is the target.


Unlike cmake, the build results (here, the executable) are placed in ./bazel-bin.


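Real projects rarely consist of a single file; code is split into modules. In Bazel, such modules are described as libraries: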


cc_library(
    name = "lib",
    hdrs = ["utils.h"],
    srcs = ["utils.cpp"],
    visibility = [
        "//visibility:public",
    ],
)

cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    deps=[":lib"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

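A BUILD file can contain several targets, and targets can depend on each other: ":lib" refers to a target in the same package, while "//package:target" refers to a target in another package. A target from an external repository is referenced by putting the repository name in front of //: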


@some_repository//Module:lib

Bazel can fetch external repositories in several ways, for example with http_archive, git_repository, or local_repository.
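To make the label syntax concrete, here is a small, hypothetical C++ helper (it is not part of Bazel) that splits a label into repository, package, and target. It is a sketch that covers only the forms used in this article, including the shorthand where "//Module" means "//Module:Module"; other shorthands, such as a bare "@repo", are not handled.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper type: the three parts of a Bazel label.
struct BazelLabel {
    std::string repository;  // empty for the main workspace
    std::string package;     // directory with the BUILD file; empty for ":lib"
    std::string target;      // rule name inside that BUILD file
};

// Splits labels like "@some_repository//Module:lib", "//Source:HelloWorld",
// ":lib" or "//a/b" into their parts.
BazelLabel ParseLabel(const std::string& label) {
    BazelLabel result;
    std::string rest = label;
    if (!rest.empty() && rest[0] == '@') {
        const auto slashes = rest.find("//");
        assert(slashes != std::string::npos && "this sketch needs @repo//...");
        result.repository = rest.substr(1, slashes - 1);
        rest = rest.substr(slashes);
    }
    if (rest.rfind("//", 0) == 0) {
        rest = rest.substr(2);  // drop the leading "//" of an absolute label
    }
    const auto colon = rest.find(':');
    if (colon == std::string::npos) {
        // "//a/b" is shorthand for "//a/b:b"
        result.package = rest;
        result.target = rest.substr(rest.find_last_of('/') + 1);
    } else {
        result.package = rest.substr(0, colon);
        result.target = rest.substr(colon + 1);
    }
    return result;
}
```

For example, ParseLabel("@some_repository//Module:lib") yields the repository some_repository, the package Module, and the target lib.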


Mediapipe Framework


What Mediapipe is, what it consists of, and why it is needed


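Mediapipe represents a pipeline as a graph. A graph is built from three main entities (described here in simplified form):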


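  1. Calculators: the nodes of the graph, each performing one processing step. In C++, a calculator is a class derived from CalculatorBase that implements:
    • static Status GetContract(CalculatorContract*) declares the calculator's inputs and outputs and the types of data they carry.
    • Status Open(CalculatorContext*) is called once when the graph starts. This is where the calculator is initialized, for example by reading its options.
    • Status Process(CalculatorContext*) is called for each new set of input packets and does the actual work.
    • Status Close(CalculatorContext*) is called once when processing finishes, to release resources.
  2. Streams: the connections between calculators. Each calculator has input and output streams; an output stream of one calculator feeds the input streams of others. A stream can be referred to by index or by a tag, as in IMAGE:in.
  3. Packets: the units of data that travel along streams. A packet can hold almost any type, from a number or a string to an image or a protobuf message. Every packet has a timestamp; timestamps within a stream must strictly increase, and Mediapipe uses them to synchronize packets arriving on different inputs of a calculator.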

The graph itself is described as a protobuf message, usually in protobuf text format (.pbtxt files).
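The relationship between calculators, packets, and timestamps can be illustrated with a toy model in plain C++. This is a sketch, not the Mediapipe API: ToyPacket, ToyCalculator, RunToyGraph, and CountingCalculator are hypothetical names that only mirror the Open/Process/Close lifecycle and the rule that timestamps on a stream must strictly increase.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

// Toy model, NOT the Mediapipe API: a packet is a value plus a timestamp.
struct ToyPacket {
    std::string value;
    std::int64_t timestamp;
};

// Same lifecycle shape as CalculatorBase: Open once, Process per input, Close once.
class ToyCalculator {
public:
    virtual ~ToyCalculator() = default;
    virtual void Open() {}
    virtual void Process(const ToyPacket& in, std::vector<ToyPacket>& out) = 0;
    virtual void Close() {}
};

// Hypothetical one-node "graph" runner. Like a real stream, it requires
// strictly increasing input timestamps.
std::vector<ToyPacket> RunToyGraph(ToyCalculator& calc,
                                   const std::vector<ToyPacket>& inputs) {
    std::vector<ToyPacket> outputs;
    calc.Open();
    std::int64_t last = std::numeric_limits<std::int64_t>::min();
    for (const ToyPacket& packet : inputs) {
        assert(packet.timestamp > last && "timestamps must strictly increase");
        last = packet.timestamp;
        calc.Process(packet, outputs);
    }
    calc.Close();
    return outputs;
}

// Example node: forwards every packet and counts lifecycle calls.
class CountingCalculator : public ToyCalculator {
public:
    int opened = 0, processed = 0, closed = 0;
    void Open() override { ++opened; }
    void Process(const ToyPacket& in, std::vector<ToyPacket>& out) override {
        ++processed;
        out.push_back(in);
    }
    void Close() override { ++closed; }
};
```

Running three packets through CountingCalculator calls Open and Close once each and Process three times, which is the order in which Mediapipe drives a real calculator.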


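Mediapipe is cross-platform: it runs on Linux, WSL, MacOS, Android, and iOS. It supports running TensorFlow and TFLite models and comes with many ready-made calculators and example graphs.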


, Mediapipe. , . Mediapipe, , - .


HelloWorld application on Mediapipe


As a first example, let's build a graph with a single calculator that takes a string and outputs it N times.



The graph is described in a .pbtxt file, and the code that runs it is in main.cpp. The project structure:


├── hello-world
│   ├── BUILD
│   ├── graph.pbtxt
│   ├── main.cpp
│   ├── RepeatNTimesCalculator.cpp
│   └── RepeatNTimesCalculator.proto
└── WORKSPACE

The graph (graph.pbtxt):


input_stream: "in"
output_stream: "out"

node {
    calculator: "RepeatNTimesCalculator"
    input_stream: "in"
    output_stream: "OUTPUT_TAG:out"
    node_options: {
        [type.googleapis.com/mediapipe_demonstration.RepeatNTimesCalculatoOptions] {
            n: 3
        }
    }
}

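The code of RepeatNTimesCalculator (RepeatNTimesCalculator.cpp): GetContract declares the types of the input and output streams, Open reads the options, and Process emits the input string n times.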


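#include <string>

#include "mediapipe/framework/calculator_framework.h"
#include "hello-world/RepeatNTimesCalculator.pb.h"

class RepeatNTimesCalculator : public mediapipe::CalculatorBase {
public:
    static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
        // The input stream has no tag: it is addressed by the empty tag and index 0
        // and carries std::string
        cc->Inputs().Get("", 0).Set<std::string>();
        // The output stream has the tag OUTPUT_TAG and also carries std::string
        cc->Outputs().Get("OUTPUT_TAG", 0).Set<std::string>();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Open(mediapipe::CalculatorContext* cc) final {
        // Read the options set in node_options of the graph config
        const auto& options = cc->Options<mediapipe_demonstration::RepeatNTimesCalculatoOptions>();
        // n: how many times to repeat each input string
        n_ = options.n();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Process(mediapipe::CalculatorContext* cc) final {
        // Take the current packet from input stream 0 (Value())
        // and extract its payload, which GetContract declared to be std::string
        const auto& txt = cc->Inputs().Index(0).Value().Get<std::string>();

        for (int i = 0; i < n_; ++i) {
            // Each copy gets its own timestamp: the input timestamp plus i
            auto packet = mediapipe::MakePacket<std::string>(txt).At(cc->InputTimestamp() + i);
            // Send the packet to the output stream with tag OUTPUT_TAG and index 0
            cc->Outputs().Get("OUTPUT_TAG", 0).AddPacket(packet);
        }

        return mediapipe::OkStatus();
    }
private:
    int n_ = 0;
};
// Register the calculator so that graphs can refer to it by name
REGISTER_CALCULATOR(RepeatNTimesCalculator);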

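The calculator has no header file: everything lives in the .cpp file, and graphs find the calculator by name thanks to REGISTER_CALCULATOR.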
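Because each output packet gets the timestamp InputTimestamp() + i, it is worth checking what happens with more than one input. The plain C++ sketch below (no Mediapipe; RepeatTimestamps and StrictlyIncreasing are hypothetical names) computes the output timestamps the calculator would produce and checks that they strictly increase, which Mediapipe requires for every stream. With a single input, as in this example, there is no problem; inputs whose timestamps are closer than n apart would produce clashing output timestamps.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Models the timestamps RepeatNTimesCalculator emits: an input at time t
// produces outputs at t, t + 1, ..., t + n - 1.
std::vector<std::int64_t> RepeatTimestamps(const std::vector<std::int64_t>& inputs, int n) {
    assert(n >= 0);
    std::vector<std::int64_t> out;
    for (std::int64_t t : inputs) {
        for (int i = 0; i < n; ++i) {
            out.push_back(t + i);
        }
    }
    return out;
}

// Mediapipe rejects a packet whose timestamp is not larger than the previous
// packet's timestamp on the same stream.
bool StrictlyIncreasing(const std::vector<std::int64_t>& ts) {
    for (std::size_t i = 1; i < ts.size(); ++i) {
        if (ts[i] <= ts[i - 1]) {
            return false;
        }
    }
    return true;
}
```

With n = 3, inputs at 0 and 3 give the valid sequence 0..5, while inputs at 0 and 1 give 0, 1, 2, 1, 2, 3, which a real output stream would reject.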


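The calculator's options are defined in a proto file as the RepeatNTimesCalculatoOptions message; this is the message used in node_options in the graph, where n is set to 3.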


syntax = "proto2";
package mediapipe_demonstration;
import "mediapipe/framework/calculator_options.proto";
message RepeatNTimesCalculatoOptions {
  extend mediapipe.CalculatorOptions {
    optional RepeatNTimesCalculatoOptions ext = 350607623;
  }
  required int32 n = 2;
}

The code that runs the graph (main.cpp):


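#include <fstream>
#include <iostream>
#include <iterator>
#include <string>

#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

mediapipe::Status RunGraph() {
    // Read the graph description (the path is relative to the workspace root)
    std::ifstream file("./hello-world/graph.pbtxt");
    std::string graph_file_content;
    graph_file_content.assign(
        std::istreambuf_iterator<char>(file), 
        std::istreambuf_iterator<char>());
    mediapipe::CalculatorGraphConfig config = 
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_file_content);
    // Create the graph and initialize it from the config
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    // Attach a poller to the "out" stream to read the results
    ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller("out"));
    // Start the graph
    MP_RETURN_IF_ERROR(graph.StartRun({}));
    // Send one packet with timestamp 0 to the "in" stream, then close the stream
    auto input_packet = mediapipe::MakePacket<std::string>("Hello!").At(mediapipe::Timestamp(0));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", input_packet));
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
    // Read the output packets until the graph is done
    mediapipe::Packet packet;
    while (poller.Next(&packet)) {
        std::cout << packet.Get<std::string>() << std::endl;
    }
    return graph.WaitUntilDone();
}

int main() {
    const mediapipe::Status status = RunGraph();
    if (!status.ok()) {
        std::cerr << status.ToString() << std::endl;
        return 1;
    }
    return 0;
}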

Finally, the BUILD file that describes how all of this is built:


load("@mediapipe_repository//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
# Proto library with the calculator's options
proto_library(
    name = "repeat_n_times_calculator_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_proto",
    ],
)
# C++ classes generated from the proto
mediapipe_cc_proto_library(
    name = "repeat_n_times_calculator_cc_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    cc_deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":repeat_n_times_calculator_proto"],
)
# The calculator itself. alwayslink = 1 keeps it in the binary even though
# nothing references it directly: it is only registered by REGISTER_CALCULATOR
cc_library(
    name = "repeat_n_times_calculator",
    srcs = ["RepeatNTimesCalculator.cpp"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":repeat_n_times_calculator_cc_proto",
        "@mediapipe_repository//mediapipe/framework:calculator_framework",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)
# The executable that runs the graph
cc_binary(
    name = "HelloMediapipe",
    srcs = ["main.cpp"],
    deps = [
        ":repeat_n_times_calculator",
        "@mediapipe_repository//mediapipe/framework/port:logging",
        "@mediapipe_repository//mediapipe/framework/port:parse_text_proto",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
)

Build and run:


$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //hello-world:HelloMediapipe
...
INFO: Build completed successfully, 4 total actions
$ ./bazel-bin/hello-world/HelloMediapipe
Hello!
Hello!
Hello!

That is all it takes to run a graph in Mediapipe. Next, let's run an ML model.


Running a model using Mediapipe


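Now let's run a tflite model with Mediapipe. We use MobileNetV2 trained on ImageNet1k and classify a single image, first on desktop and then on Android. The project structure: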


├── inference
│   ├── android/src/main/java/com/mediapipe_demonstration/inference
│   │   ├── AndroidManifest.xml
│   │   ├── BUILD
│   │   ├── MainActivity.kt
│   │   └── res
│   │       └── ...
│   ├── BUILD
│   ├── desktop
│   │   ├── BUILD
│   │   └── main.cpp
│   ├── graph.pbtxt
│   ├── img.jpg
│   └── mobilenetv2_imagenet.tflite
└── WORKSPACE

The graph (graph.pbtxt):

input_stream: "in"
output_stream: "out"
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:in"
  output_stream: "IMAGE:transformed_input"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
    }
  }
}
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:transformed_input"
  output_stream: "TENSORS:image_tensor"
  node_options: {
      [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
        zero_center: false
      }
  }
}
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "inference/mobilenetv2_imagenet.tflite"
    }
  }
}
node {
  calculator: "TfLiteTensorsToFloatsCalculator"
  input_stream: "TENSORS:prediction_tensor"
  output_stream: "FLOATS:out"
}

The graph consists of 4 calculators:


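  1. ImageTransformationCalculator resizes the input image to 224x224.
  2. TfLiteConverterCalculator normalizes the pixel values to [-1, 1] (with zero_center: false, as here, to [0, 1]) and converts the ImageFrame (Mediapipe's image type for CPU) into a TfLiteTensor (more precisely, a std::vector<TfLiteTensor> of size 1). Which kind of input it converts is determined by the stream tag (IMAGE here); the supported tags are listed in the calculator's GetContract.
  3. TfLiteInferenceCalculator runs the model.
  4. TfLiteTensorsToFloatsCalculator converts the output TfLiteTensor into std::vector<float>, with one score per class.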
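The normalization in step 2 can be sketched as a linear mapping of each 8-bit channel value. This is an illustration under that assumption; NormalizePixels is a hypothetical name, and the real calculator additionally handles tensor layout, quantized models, and GPU input.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Maps 8-bit channel values to floats the way the zero_center option
// describes: true -> [-1, 1], false -> [0, 1].
std::vector<float> NormalizePixels(const std::vector<std::uint8_t>& pixels, bool zero_center) {
    std::vector<float> out;
    out.reserve(pixels.size());
    for (std::uint8_t p : pixels) {
        if (zero_center) {
            out.push_back(p / 127.5f - 1.0f);  // 0 -> -1, 255 -> 1
        } else {
            out.push_back(p / 255.0f);         // 0 -> 0, 255 -> 1
        }
    }
    return out;
}
```

The choice matters because a model only works well on inputs in the range it was trained on: feeding [0, 1] to a model trained on [-1, 1] (or vice versa) silently degrades predictions without any error.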

Let's start with the desktop version.


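The graph is created and started the same way as in HelloWorld; what changes is the input and the output. The image is read with OpenCV and copied into an ImageFrame, which is sent into the graph; from the output we take the index of the class with the highest score: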


// Read the image. OpenCV loads it in BGR order, while ImageFormat::SRGB expects RGB
auto img_mat = cv::imread("./inference/img.jpg");
cv::cvtColor(img_mat, img_mat, cv::COLOR_BGR2RGB);
// Create an ImageFrame of the same size and copy the pixels into it
auto input_frame = std::make_unique<mediapipe::ImageFrame>(
    mediapipe::ImageFormat::SRGB, img_mat.cols, img_mat.rows,
    mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
img_mat.copyTo(input_frame_mat);
auto frame = mediapipe::Adopt(input_frame.release()).At(mediapipe::Timestamp(0));
// Send the frame into the graph and close the input stream
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", frame));
MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
// Read the predictions and print the index of the highest score
mediapipe::Packet packet;
while (poller.Next(&packet)) {
    const auto& predictions = packet.Get<std::vector<float>>();
    auto idx = std::distance(predictions.begin(),
                             std::max_element(predictions.begin(), predictions.end()));
    std::cout << idx << std::endl;
}

Build and run:



$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //inference/desktop:Inference
...
INFO: Build completed successfully, 3 total actions
$ ./bazel-bin/inference/desktop/Inference
INFO: Initialized TensorFlow Lite runtime.
151

Index 151 in ImageNet corresponds to the class "Chihuahua".
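The desktop example prints only the single best class. Since neighboring ImageNet classes (dog breeds, for example) are often close, it can help to look at the few best indices instead. A small hypothetical helper using only the standard library, as a sketch:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns the indices of the k largest scores, best first.
std::vector<std::size_t> TopK(const std::vector<float>& scores, std::size_t k) {
    std::vector<std::size_t> idx(scores.size());
    std::iota(idx.begin(), idx.end(), 0);  // 0, 1, 2, ...
    k = std::min(k, idx.size());
    // Only the first k positions need to be sorted
    std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end(),
                      [&scores](std::size_t a, std::size_t b) { return scores[a] > scores[b]; });
    idx.resize(k);
    return idx;
}
```

Applied to the predictions vector from the graph, TopK(predictions, 5) would give the five most likely class indices; partial_sort avoids sorting all class scores when only a handful are needed.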


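Now the Android version. The app consists of a single Activity. When it is created (in onCreate), we load the graph and start it. The graph is stored in the app's assets as a binary protobuf file. A button opens the image picker; the selected image arrives in onActivityResult, where it is displayed and sent into the graph. Mediapipe returns the result through a callback on its own thread, so the text view is updated via runOnUiThread.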


class MainActivity : AppCompatActivity() {
    val PICK_IMAGE = 1
    var mpGraph: Graph? = null
    // Mediapipe requires strictly increasing timestamps on each input stream
    var timestamp = 0L
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        this.setContentView(R.layout.activity_main)
        val outputTv = findViewById<TextView>(R.id.outputTv)
        val button = findViewById<Button>(R.id.selectButton)
        AndroidAssetUtil.initializeNativeAssetManager(this)
        // Load the graph from the assets: mobile_binary_graph.binarypb is
        // the binary protobuf version of graph.pbtxt produced by the build
        val graph = Graph()
        assets.open("mobile_binary_graph.binarypb").use {
            val graphBytes = it.readBytes()
            graph.loadBinaryGraph(graphBytes)
        }
        // Called by Mediapipe (on its own thread) for every packet in "out"
        graph.addPacketCallback("out") {
            val res = PacketGetter.getFloat32Vector(it)
            val label = res.indices.maxByOrNull { i -> res[i] } ?: -1
            this@MainActivity.runOnUiThread {
                outputTv.text = label.toString()
            }
        }
        graph.startRunningGraph()
        // The button opens the system image picker
        button.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE)
        }
        mpGraph = graph
    }
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == PICK_IMAGE && resultCode == RESULT_OK) {
            val imageView = findViewById<ImageView>(R.id.imageView)
            val uri = data?.data ?: return
            val graph = mpGraph ?: return
            // Decode the selected image (closing the stream) and show it
            val bitmap = contentResolver.openInputStream(uri).use {
                BitmapFactory.decodeStream(it)
            } ?: return
            imageView.setImageBitmap(bitmap)
            // Wrap the bitmap into an RGB ImageFrame packet and send it into the graph
            val packet = AndroidPacketCreator(graph).createRgbImageFrame(bitmap)
            graph.addPacketToInputStream("in", packet, timestamp)
            packet.release()
            // The next image must get a larger timestamp
            timestamp++
        }
    }
    companion object {
        init {
            // Load the native Mediapipe library
            System.loadLibrary("mediapipe_jni")
        }
    }
}

Build and install the app on a connected device:


$ bazel-2.0.0 mobile-install --start_app -c opt --config=android_arm64 //inference/android/src/main/java/com/mediapipe_demonstration/inference:Inference

Run the app and select an image:



The app shows the selected image and the index of the predicted class.


The full code of the examples is available on GitHub.


Creating an Application Using Mediapipe




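Now let's build a more realistic application: recognizing car models from the camera. The app shows the camera stream and highlights the cars it detects; when the user taps one of them, the app shows its model. The requirements:


  • car detection must run on the device in real time, using the GPU;
  • the model of the selected car is determined on a server.

Mediapipe's example apps already include real-time object detection on the GPU, so the detection part can be taken from there, and the rest of the app is built around it with our own calculators.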


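Mediapipe has no ready-made model for recognizing car models, so we trained our own: a network pretrained on ImageNet1k was fine-tuned on the Stanford Cars Dataset. Its architecture: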



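It is trained as a classifier, with NLL loss on the logits. The "vector" that describes a car is the output of the mean pool layer.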
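The architecture figure is not reproduced here, but the two operations named above are easy to write out. A minimal sketch, assuming the feature map is stored as positions by channels and that, as is usual, log-softmax is applied to the logits before NLL loss (together this is cross-entropy); MeanPool and NllFromLogits are hypothetical names that only illustrate the math, not the training code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Feature map stored as [positions][channels] (e.g. 7x7 positions flattened).
// Mean pooling averages over positions and yields one value per channel:
// this pooled vector is what describes the car.
std::vector<float> MeanPool(const std::vector<std::vector<float>>& features) {
    assert(!features.empty());
    std::vector<float> pooled(features[0].size(), 0.0f);
    for (const auto& position : features) {
        assert(position.size() == pooled.size());
        for (std::size_t c = 0; c < pooled.size(); ++c) {
            pooled[c] += position[c];
        }
    }
    for (float& v : pooled) {
        v /= static_cast<float>(features.size());
    }
    return pooled;
}

// NLL of log-softmax(logits) at the target class, computed stably via log-sum-exp:
// loss = log(sum_j exp(logit_j)) - logit_target.
float NllFromLogits(const std::vector<float>& logits, std::size_t target) {
    assert(target < logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (float l : logits) {
        sum += std::exp(l - max_logit);  // subtracting the max avoids overflow
    }
    const float log_sum_exp = max_logit + std::log(sum);
    return log_sum_exp - logits[target];
}
```

Training through the classification loss and then discarding the classifier head is what allows the pooled vector to be compared against stored vectors later, instead of being limited to a fixed set of classes.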


The Android part is based on Mediapipe's Android examples, with our own calculators added to the graph. The project structure:


.
├── Android/src/main/java/com/kshmax/objectrecognition
│   └── objectrecognition
│       ├── AndroidManifest.xml
│       ├── BUILD
│       ├── MainActivity.kt
│       ├── Models.kt
│       ├── ObjectDetectionFrameProcessor.java
│       └── res
│           └── ...
├── Calculators
│   ├── BoundaryBoxCropCalculator.cpp
│   ├── BUILD
│   ├── ClickLocation.proto
│   ├── DetectionFilterCalculator.cpp
│   └── DetectionFilterCalculator.proto
├── Graphs
│   ├── BUILD
│   └── ObjectRecognitionGraph.pbtxt
├── Models
│   ├── BUILD
│   └── car_vectorizer.tflite
├── Server
│   ├── labels.npy
│   ├── server.py
│   └── vectors.npy
└── WORKSPACE

The Mediapipe graph of the application:



, .


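Frames from the camera enter the graph, and the object detection part is taken from Mediapipe's object detection example. FlowLimiter throttles the input: while the graph is still processing a frame, new frames are dropped, so the app keeps up in real time. AnnotationOverlay draws the detected bounding boxes on the frame, and the result is shown on the screen.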


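We added DetectionFilter between detection and rendering. The detector finds objects of many classes, and DetectionFilter passes through only the detections whose label id is in the list from its options, so that only cars are shown and processed further.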


The Process method of DetectionFilter:


mediapipe::Status DetectionFilterCalculator::Process(mediapipe::CalculatorContext *cc) {
    // No detections at this timestamp: nothing to filter
    if (cc->Inputs().Get("", 0).IsEmpty()) {
        return mediapipe::OkStatus();
    }
    const auto& input_detections = cc->Inputs().Get("", 0).Get<std::vector<::mediapipe::Detection>>();
    std::vector<::mediapipe::Detection> output_detections;
    for (const auto& input_detection : input_detections) {
        // Keep the detection if any of its label ids is in pass_ids_ (read from the options)
        bool next_detection = false;

        for (int pass_id : pass_ids_) {
            for (int label_id : input_detection.label_id()) {
                if (pass_id == label_id) {
                    output_detections.push_back(input_detection);
                    next_detection = true;
                    break;
                }
            }
            if (next_detection) {
                break;
            }
        }
    }
    // Send the filtered detections with the same timestamp as the input
    auto out_packet = mediapipe::MakePacket<std::vector<mediapipe::Detection>>(output_detections).At(cc->InputTimestamp());
    cc->Outputs().Get("", 0).AddPacket(out_packet);

    return mediapipe::OkStatus();
}
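The nested loops in Process can be expressed more compactly with a hash set of allowed ids. A sketch on a simplified stand-in for mediapipe::Detection (ToyDetection and FilterDetections are hypothetical names) that keeps the same semantics: a detection passes if any of its label ids is in the allowed set.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <unordered_set>
#include <vector>

// Simplified stand-in for mediapipe::Detection: only the label ids matter here.
struct ToyDetection {
    std::vector<int> label_ids;
};

// Keeps detections that have at least one label id in pass_ids.
// Each lookup in the set is O(1) on average instead of a scan over pass_ids.
std::vector<ToyDetection> FilterDetections(const std::vector<ToyDetection>& input,
                                           const std::unordered_set<int>& pass_ids) {
    std::vector<ToyDetection> output;
    std::copy_if(input.begin(), input.end(), std::back_inserter(output),
                 [&pass_ids](const ToyDetection& d) {
                     return std::any_of(d.label_ids.begin(), d.label_ids.end(),
                                        [&pass_ids](int id) { return pass_ids.count(id) > 0; });
                 });
    return output;
}
```

In the calculator itself, the set would be filled once in Open from the pass_id options, so Process only does the lookups.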

How DetectionFilterCalculator is used in the graph:


node {
    calculator: "DetectionFilterCalculator"
    input_stream: "filtered_detections"
    output_stream: "car_detections"
    node_options: {
        [type.googleapis.com/objectrecognition.DetectionFilterCalculatorOptions] {
            pass_id: 3
        }
    }
}

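Here pass_id: 3 is the label id of "car" in the label map of Mediapipe's object detection model, so only cars pass the filter.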



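To recognize a car, two more nodes were added: ScreenTap and BoundaryBoxCrop. ScreenTap is an extra input stream of the graph that carries the user's taps. Mediapipe's standard FrameProcessor only sends camera frames into the graph, so we extended it as ObjectDetectionFrameProcessor, which also sends the tap location.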


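ScreenTap carries the coordinates of the tap; when there is no tap, the packet contains a "no click" value (-1). The location is sent as a serialized ClickLocation protobuf message. The coordinates are relative to the frame size (that is, in [0, 1]), just like the relative bounding boxes of the detections.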
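Whichever side performs the conversion, a tap given in view pixels has to become relative coordinates before it can be compared with relative_bounding_box. A sketch, assuming the camera frame fills the view exactly (no letterboxing or cropping); ToyClickLocation and ToRelative are hypothetical stand-ins for the ClickLocation message and the conversion code, which the article does not show.

```cpp
#include <cassert>

// Hypothetical mirror of ClickLocation: relative coordinates,
// with -1 meaning "no click" (the value BoundaryBoxCrop checks for).
struct ToyClickLocation {
    float x = -1.0f;
    float y = -1.0f;
};

// Converts a tap in view pixels into coordinates relative to the view size.
// Assumes the frame is shown unscaled relative to the view's aspect ratio.
ToyClickLocation ToRelative(float tap_x, float tap_y, int view_width, int view_height) {
    assert(view_width > 0 && view_height > 0);
    ToyClickLocation click;
    if (tap_x < 0 || tap_y < 0 || tap_x > view_width || tap_y > view_height) {
        return click;  // outside the view: keep the "no click" value
    }
    click.x = tap_x / static_cast<float>(view_width);
    click.y = tap_y / static_cast<float>(view_height);
    return click;
}
```

If the preview is letterboxed or center-cropped, the offset and scale of the visible frame would have to be applied before dividing, otherwise taps land on the wrong box.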


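BoundaryBoxCrop receives three inputs: the detections from DetectionFilter, the frame after FlowLimiter (converted from GPU to CPU memory), and the tap location. If the tap falls inside one of the bounding boxes, the calculator crops that region out of the frame and sends it on.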


mediapipe::Status BoundaryBoxCropCalculator::Process(mediapipe::CalculatorContext *cc) {
    auto& detections_packet = cc->Inputs().Get("DETECTION", 0);
    auto& frames_packet = cc->Inputs().Get("IMAGE", 0);
    auto& click_packet = cc->Inputs().Get("CLICK", 0);

    // All three inputs are needed: detections, the frame, and the click
    if (detections_packet.IsEmpty()
        || frames_packet.IsEmpty()
        || click_packet.IsEmpty()) {
        return mediapipe::OkStatus();
    }

    const std::vector<mediapipe::Detection>& detections =
            detections_packet.Get<std::vector<mediapipe::Detection>>();
    const mediapipe::ImageFrame& image_frame = frames_packet.Get<mediapipe::ImageFrame>();

    // The Java side sends the click location as a serialized protobuf message
    // (a std::string), so it has to be parsed here
    auto click_location_str = click_packet.Get<std::string>();
    objectrecognition::ClickLocation click_location;
    click_location.ParseFromString(click_location_str);
    // -1 means there was no click, so there is nothing to crop
    if (click_location.x() == -1 || click_location.y() == -1) {
        return mediapipe::OkStatus();
    }

    // Find the detection whose bounding box contains the click
    absl::optional<mediapipe::Detection> detection = FindOverlappedDetection(click_location, detections);

    if (detection.has_value()) {
        // Crop the bounding box out of the frame and send it to the output
        std::unique_ptr<mediapipe::ImageFrame> cropped_image = CropImage(image_frame, detection.value());
        if (cropped_image) {
            cc->Outputs().Get("", 0).Add(cropped_image.release(), cc->InputTimestamp());
        }
    }

    return mediapipe::OkStatus();
}

// Returns the first detection whose relative bounding box contains the click
absl::optional<mediapipe::Detection> BoundaryBoxCropCalculator::FindOverlappedDetection(
        const objectrecognition::ClickLocation& click_location,
        const std::vector<mediapipe::Detection>& detections) {
    for (const auto& input_detection : detections) {
        const auto& b_box = input_detection.location_data().relative_bounding_box();

        if (b_box.xmin() < click_location.x() && click_location.x() < (b_box.xmin() + b_box.width())
            && b_box.ymin() < click_location.y() && click_location.y() < (b_box.ymin() + b_box.height())) {
            return input_detection;
        }
    }

    return absl::nullopt;
}

// Copies the bounding box region of the frame into a new ImageFrame.
// Returns nullptr if the box lies entirely outside the frame.
std::unique_ptr<mediapipe::ImageFrame> BoundaryBoxCropCalculator::CropImage(
        const mediapipe::ImageFrame& image_frame,
        const mediapipe::Detection& detection) {
    const uint8* pixel_data = image_frame.PixelData();
    const auto& b_box = detection.location_data().relative_bounding_box();

    // Convert the relative box to pixels
    int height = static_cast<int>(b_box.height() * static_cast<float>(image_frame.Height()));
    int width = static_cast<int>(b_box.width() * static_cast<float>(image_frame.Width()));
    int xmin = static_cast<int>(b_box.xmin() * static_cast<float>(image_frame.Width()));
    int ymin = static_cast<int>(b_box.ymin() * static_cast<float>(image_frame.Height()));

    // Clip the box to the frame: boxes near the border can stick out of any side
    if (xmin < 0) {
        width += xmin;
        xmin = 0;
    }
    if (ymin < 0) {
        height += ymin;
        ymin = 0;
    }
    if (xmin + width > image_frame.Width()) {
        width = image_frame.Width() - xmin;
    }
    if (ymin + height > image_frame.Height()) {
        height = image_frame.Height() - ymin;
    }
    if (width <= 0 || height <= 0) {
        return nullptr;
    }

    // Copy the region row by row into a tightly packed buffer
    std::vector<uint8_t> pixels;
    pixels.reserve(height * width * image_frame.NumberOfChannels());
    for (int y = ymin; y < ymin + height; ++y) {
        int row_offset = y * image_frame.WidthStep();
        for (int x = xmin; x < xmin + width; ++x) {
            for (int ch = 0; ch < image_frame.NumberOfChannels(); ++ch) {
                pixels.push_back(pixel_data[row_offset + x * image_frame.NumberOfChannels() + ch]);
            }
        }
    }

    std::unique_ptr<mediapipe::ImageFrame> cropped_image = std::make_unique<mediapipe::ImageFrame>();
    cropped_image->CopyPixelData(image_frame.Format(), width, height, pixels.data(),
            mediapipe::ImageFrame::kDefaultAlignmentBoundary);

    return cropped_image;
}
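The border handling in CropImage is the part that is easiest to get wrong. A standalone sketch of the same idea (PixelRect and ToClippedPixelRect are hypothetical names) converts the relative box to pixels through its two corners and clips both corners to the image, so a box sticking out of any side, or lying entirely outside, is handled uniformly.

```cpp
#include <algorithm>
#include <cassert>

// Pixel rectangle; hypothetical helper type for this sketch.
struct PixelRect {
    int x;
    int y;
    int width;
    int height;
};

// Converts a relative box (xmin, ymin, w, h in image fractions) into a pixel
// rectangle clipped to the image. Width or height 0 means nothing is visible.
PixelRect ToClippedPixelRect(float xmin, float ymin, float w, float h,
                             int image_width, int image_height) {
    assert(image_width > 0 && image_height > 0);
    int x0 = static_cast<int>(xmin * image_width);
    int y0 = static_cast<int>(ymin * image_height);
    int x1 = static_cast<int>((xmin + w) * image_width);
    int y1 = static_cast<int>((ymin + h) * image_height);
    // Clip both corners to [0, size]
    x0 = std::max(0, std::min(x0, image_width));
    x1 = std::max(0, std::min(x1, image_width));
    y0 = std::max(0, std::min(y0, image_height));
    y1 = std::max(0, std::min(y1, image_height));
    return {x0, y0, std::max(0, x1 - x0), std::max(0, y1 - y0)};
}
```

Working with corners rather than an origin plus size avoids having to adjust the width separately for the left and right edges.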

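The cropped image is then converted into a TfLiteTensor and passed to the vectorizer model (car_vectorizer.tflite), which outputs the vector describing the car.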


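The vector is received on the Java side and sent to the server over HTTP. The server finds the "closest" vector among the stored ones (vectors.npy) and returns its label (labels.npy), and the app shows it on the screen.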
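The server code is not shown here. A minimal sketch of the lookup it performs, written in C++ like the rest of the examples and assuming a brute-force search by cosine similarity over all stored vectors (server.py may use a different metric); NearestLabel is a hypothetical name, and its vectors and labels arguments play the roles of vectors.npy and labels.npy.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// Returns the label of the stored vector with the highest cosine similarity
// to the query, scanning every stored vector.
std::string NearestLabel(const std::vector<float>& query,
                         const std::vector<std::vector<float>>& vectors,
                         const std::vector<std::string>& labels) {
    assert(!vectors.empty() && vectors.size() == labels.size());
    auto norm = [](const std::vector<float>& v) {
        float s = 0.0f;
        for (float x : v) s += x * x;
        return std::sqrt(s);
    };
    const float query_norm = norm(query);
    std::size_t best = 0;
    float best_similarity = -2.0f;  // cosine similarity is never below -1
    for (std::size_t i = 0; i < vectors.size(); ++i) {
        assert(vectors[i].size() == query.size());
        float dot = 0.0f;
        for (std::size_t j = 0; j < query.size(); ++j) dot += query[j] * vectors[i][j];
        // The small epsilon guards against division by zero for zero vectors
        const float similarity = dot / (query_norm * norm(vectors[i]) + 1e-12f);
        if (similarity > best_similarity) {
            best_similarity = similarity;
            best = i;
        }
    }
    return labels[best];
}
```

This scan is linear in the number of stored vectors, which is why approximate nearest neighbor search becomes attractive as the database grows.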


Now let's look at the results:



. ( camry hummer), ford focus , .. -, (Hyundai SantaFe).


4 , , .



, , , , . , , .



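Conclusion


In this article we looked at Mediapipe and used it to build an application that detects cars in the camera stream in real time and recognizes the model of the car the user taps. The UX also leaves room for improvement. The application is a proof-of-concept, and there is a lot to improve: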


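  • some calculators run on the CPU, so frames are copied from GPU to CPU memory; moving these calculators to the GPU would remove the GPU-CPU-GPU copies;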
  • , . ;
  • ;
  • ;
  • with a large database of vectors, the nearest-neighbor search on the server should use ANN (Approximate Nearest Neighbors).
