Mediapipe के साथ बिल्डिंग एप्लीकेशन

आज, कई सेवाएं अपने काम में तंत्रिका नेटवर्क मॉडल का उपयोग करती हैं। हालांकि, क्लाइंट डिवाइस के कम प्रदर्शन के कारण, ज्यादातर मामलों में, गणना सर्वर पर की जाती है। हालाँकि, हर साल स्मार्टफोन का प्रदर्शन बढ़ रहा है और अब क्लाइंट डिवाइस पर छोटे मॉडल लॉन्च करना संभव है। सवाल उठता है: यह कैसे करें? मॉडल को चलाने के अलावा, आपको डेटा को प्री-प्रोसेस और पोस्ट-प्रोसेस करने की आवश्यकता है। इसके अलावा, कम से कम दो प्लेटफ़ॉर्म हैं जहां इसे लागू करने की आवश्यकता है: एंड्रॉइड और आईओएस। मेडियापिप मशीन लर्निंग के पाइपलाइन (डेटा प्रीप्रोसेसिंग, एक मॉडल की शुरूआत, साथ ही मॉडल परिणाम के बाद के प्रसंस्करण) के लिए एक रूपरेखा है, जो ऊपर वर्णित समस्याओं को हल करने और मॉडल को चलाने के लिए क्रॉस-प्लेटफ़ॉर्म कोड के लेखन को सरल बनाने की अनुमति देता है।



सामग्री


  1. Bazel बिल्ड सिस्टम अवलोकन
  2. मेडियापाइप फ्रेमवर्क
    1. Mediapipe क्या है। इसमें क्या शामिल है और इसकी आवश्यकता क्यों है
    2. Mediapipe पर HelloWorld आवेदन
    3. Mediapipe का उपयोग करके एक मॉडल चलाना
  3. Mediapipe का उपयोग करके एक एप्लिकेशन बनाना
  4. निष्कर्ष

Bazel बिल्ड सिस्टम अवलोकन


Mediapipe असेंबली के लिए Bazel का उपयोग करता है। यह पहली बार था जब मैं इस बिल्ड सिस्टम में आया था, इसलिए मेडीपाइप के समानांतर मैंने इसे बज़ेल के साथ लगा लिया, विभिन्न स्रोतों से जानकारी एकत्र की और एक रैक पर कदम रखा। इसलिए, स्वयं मेदीपिप के बारे में बताने से पहले, मैं बेज़ेल असेंबली प्रणाली पर एक छोटा सा परिचय देना चाहूंगा, जिसका उपयोग मेडीपाइप और उसके आधार पर परियोजनाओं के निर्माण के लिए किया जाता है। Bazel Google का एक बिल्ड सिस्टम है, जो उनके आंतरिक ब्लेज़ बिल्ड सिस्टम का एक संशोधित संस्करण है। Bazel आपको एक सरल आदेश के साथ एक परियोजना में विभिन्न भाषाओं और रूपरेखाओं को संयोजित करने की अनुमति देता है:


bazel build TARGET

, ? , Bazel-. WORKSPACE, . , . , A, B, WORKSPACE A, B. - WORKSPACE , Mediapipe. Bazel , , , , , . WORKSPACE:


#      
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

#        c++
http_archive(
    name = "rules_cc",
    strip_prefix = "rules_cc-master",
    urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
)

#   googletest
http_archive(
    name = "com_google_googletest",
    strip_prefix = "googletest-master",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
)

(bazel rule). — Starlark. Bazel , , . , http_archive , http_archive C++ googletest.


WORKSPACE Bazel BUILD. BUILD , , . BUILD C++ :


cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

, , WORKSPACE http_archive. Bazel- main.cpp.


  • ./WORKSPACE
  • ./Source/BUILD
  • ./Source/main.cpp


bazel build //Source:HelloWorld

Source — , HelloWorld — .


cmake, ( ) ./bazel-bin.


, , . Bazel (libraries).


cc_library(
    name = "lib",
    hdrs = ["utils.h"],
    srcs = ["utils.cpp"],
    visibility = [
        "//visibility:public",
    ],
)

cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    deps=[":lib"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

BUILD , . , //


@some_repository//Module:lib

Bazel , . http_archive, git_repository, local_repository .


Mediapipe


Mediapipe


Mediapipe — . , ( ):


  1. (Calculators) — (). . C++ , CalculatorBase:
    • static Status GetContract(CalculatorContract*); — , , .
    • Status Open(CalculatorContext*); — . , , , .
    • Status Process(CalculatorContext*); — .
    • Status Close(CalculatorContext*); — .
  2. (Streams) . . , (input) (output). , , , .
  3. (Packet) — , . — , , , protobuf. timestamp — , . , , , , .

protobuf.


Mediapipe , . Linux, WSL, MacOS, Android, iOS. Mediapipe TensorFlow TFLite . , , .


, Mediapipe. , . Mediapipe, , - .


HelloWorld Mediapipe


, , N , .



.pbtxt , main.cpp.


├── hello-world
│   ├── BUILD
│   ├── graph.pbtxt
│   ├── main.cpp
│   ├── RepeatNTimesCalculator.cpp
│   └── RepeatNTimesCalculator.proto
└── WORKSPACE

:


input_stream: "in"
output_stream: "out"

node {
    calculator: "RepeatNTimesCalculator"
    input_stream: "in"
    output_stream: "OUTPUT_TAG:out"
    node_options: {
        [type.googleapis.com/mediapipe_demonstration.RepeatNTimesCalculatoOptions] {
            n: 3
        }
    }
}

RepeatNTimesCalculator . GetContract, , Open, , Process, .


class RepeatNTimesCalculator : public mediapipe::CalculatorBase {
public:
    static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
        //        ,  
        cc->Inputs().Get("", 0).Set<std::string>();
        //       OUTPUT_TAG,     
        cc->Outputs().Get("OUTPUT_TAG", 0).Set<std::string>();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Open(mediapipe::CalculatorContext* cc) final {
        //   ,    
        const auto& options = cc->Options<mediapipe_demonstration::RepeatNTimesCalculatoOptions>();
        // n -      
        n_ = options.n();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Process(mediapipe::CalculatorContext* cc) final {
        //     
        //           
        //       std::string
        auto txt = cc->Inputs().Index(0).Value().Get<std::string>();

        for (int i = 0; i < n_; ++i) {
            //      
            auto packet = mediapipe::MakePacket<std::string>(txt).At(cc->InputTimestamp() + i);
            //       OUTPUT_TAG   0
            cc->Outputs().Get("OUTPUT_TAG", 0).AddPacket(packet);
        }

        return mediapipe::OkStatus();
    }
private:
    int n_;
};
//    
REGISTER_CALCULATOR(RepeatNTimesCalculator);

.cpp , , REGISTER_CALCULATOR.


proto- . RepeatNTimesCalculatoOptions, , .


syntax = "proto2";
package mediapipe_demonstration;
import "mediapipe/framework/calculator_options.proto";
message RepeatNTimesCalculatoOptions {
  extend mediapipe.CalculatorOptions {
    optional RepeatNTimesCalculatoOptions ext = 350607623;
  }
  required int32 n = 2;
}

:


mediapipe::Status RunGraph() {
    //    
    std::ifstream file("./hello-world/graph.pbtxt");
    std::string graph_file_content;
    graph_file_content.assign(
        std::istreambuf_iterator<char>(file), 
        std::istreambuf_iterator<char>());
    mediapipe::CalculatorGraphConfig config = 
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_file_content);
    //  
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    //    
    ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller("out"));
    //  
    MP_RETURN_IF_ERROR(graph.StartRun({}));
    //        
    auto input_packet = mediapipe::MakePacket<std::string>("Hello!").At(mediapipe::Timestamp(0));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", input_packet));
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
    //    
    mediapipe::Packet packet;
    while (poller.Next(&packet)) {
        std::cout << packet.Get<std::string>() << std::endl;
    }
    return graph.WaitUntilDone();
}

, , BUILD , .


load("@mediapipe_repository//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
#     
proto_library(
    name = "repeat_n_times_calculator_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_proto",
    ],
)
#     
mediapipe_cc_proto_library(
    name = "repeat_n_times_calculator_cc_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    cc_deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":repeat_n_times_calculator_proto"],
)
#   .  ,    
cc_library(
    name = "repeat_n_times_calculator",
    srcs = ["RepeatNTimesCalculator.cpp"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":repeat_n_times_calculator_cc_proto",
        "@mediapipe_repository//mediapipe/framework:calculator_framework",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)
#    ,    .
cc_binary(
    name = "HelloMediapipe",
    srcs = ["main.cpp"],
    deps = [
        "repeat_n_times_calculator",
        "@mediapipe_repository//mediapipe/framework/port:logging",
        "@mediapipe_repository//mediapipe/framework/port:parse_text_proto",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
)

:


$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //hello-world:HelloMediapipe
...
INFO: Build completed successfully, 4 total actions
$ ./bazel-bin/hello-world/HelloMediapipe
Hello!
Hello!
Hello!

Mediapipe, ML .


Mediapipe


, Mediapipe tflite . . , ImageNet1k. , , . , .


├── inference
│   ├── android/src/main/java/com/com/mediapipe_demonstration/inference
│   │   ├── AndroidManifest.xml
│   │   ├── BUILD
│   │   ├── MainActivity.kt
│   │   └── res
│   │       └── ...
│   ├── BUILD
│   ├── desktop
│   │   ├── BUILD
│   │   └── main.cpp
│   ├── graph.pbtxt
│   ├── img.jpg
│   └── mobilenetv2_imagenet.tflite
└── WORKSPACE

:



input_stream: "in"
output_stream: "out"
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:in"
  output_stream: "IMAGE:transformed_input"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
    }
  }
}
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:transformed_input"
  output_stream: "TENSORS:image_tensor"
  node_options: {
      [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
        zero_center: false
      }
  }
}
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "inference/mobilenetv2_imagenet.tflite"
    }
  }
}
node {
  calculator: "TfLiteTensorsToFloatsCalculator"
  input_stream: "TENSORS:prediction_tensor"
  output_stream: "FLOATS:out"
}

4 :


  1. 224x224.
  2. [-1, 1] ( zero_center: false, [0, 1]), ImageFrame ( Mediapipe, , ) TfLiteTensor ( std::vector<TfLiteTensor> 1). TfLiteTensor ( GetContract), .
  3. .
  4. TfLiteTensor std::vector<float>. .

, .


. . , , .


//  
auto img_mat = cv::imread("./inference/img.jpg");
//    
auto input_frame = std::make_unique<mediapipe::ImageFrame>(
    mediapipe::ImageFormat::SRGB, img_mat.cols, img_mat.rows,
    mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
img_mat.copyTo(input_frame_mat);
auto frame = mediapipe::Adopt(input_frame.release()).At(mediapipe::Timestamp(0));
//    
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", frame));
MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
//       
mediapipe::Packet packet;
while (poller.Next(&packet)) {
    auto predictions = packet.Get<std::vector<float>>();
    int idx = std::max_element(predictions.begin(), predictions.end()) - predictions.begin();
    std::cout << idx << std::endl;
}

, :



$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //inference/desktop:Inference
...
INFO: Build completed successfully, 3 total actions
$ ./bazel-bin/inference/desktop/Inference
INFO: Initialized TensorFlow Lite runtime.
151

151 ImageNet "Chihuahua".


. Activity, ( onCreate) , Mediapipe. protobuf-. , . onActivityResult, . Mediapipe - ( ).


class MainActivity : AppCompatActivity() {
    val PICK_IMAGE = 1
    var mpGraph: Graph? = null
    var timestamp = 0L
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        this.setContentView(R.layout.activity_main)
        val outputTv = findViewById<TextView>(R.id.outputTv)
        val button = findViewById<Button>(R.id.selectButton)
        AndroidAssetUtil.initializeNativeAssetManager(this)
        //    
        //        
        val graph = Graph()
        assets.open("mobile_binary_graph.binarypb").use {
            val graphBytes = it.readBytes()
            graph.loadBinaryGraph(graphBytes)
        }
        //    
        graph.addPacketCallback("out") {
            val res = PacketGetter.getFloat32Vector(it)
            val label = res.indices.maxBy { i -> res[i] } ?: -1
            this@MainActivity.runOnUiThread {
                outputTv.text = label.toString()
            }
        }
        graph.startRunningGraph()
        //      
        button.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE)
        }
        mpGraph = graph
    }
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        if (requestCode == PICK_IMAGE) {
            //        
            val outputTv = findViewById<TextView>(R.id.outputTv)
            val imageView = findViewById<ImageView>(R.id.imageView)
            val uri = data?.data!!
            //    
            val graph = mpGraph!!
            val creator = AndroidPacketCreator(graph)
            val stream = contentResolver.openInputStream(uri)
            val bitmap = BitmapFactory.decodeStream(stream)
            imageView.setImageBitmap(bitmap)
            val packet = creator.createRgbImageFrame(bitmap)
            graph.addPacketToInputStream("in", packet, timestamp)
        }
    }
    companion object {
        init {
            //   mediapipe 
            System.loadLibrary("mediapipe_jni")
        }
    }
}

:


$ bazel-2.0.0 mobile-install --start_app -c opt --config=android_arm64 //inference/android/src/main/java/com/mediapipe_demonstration/inference:Inference

, :



, .


, github


Mediapipe


Mediapipe


. . , . , . , . :


  • , GPU, ;
  • .

, Mediapipe. , , , . . : , .


Mediapipe, . , ImageNet1k, Stanford Cars Dataset. . :



logit- NLL loss. "" mean pool.


Android , Mediapipe, Mediapipe , .


.
├── Android/src/main/java/com/kshmax/objectrecognition
│   └── objectrecognition
│       ├── AndroidManifest.xml
│       ├── BUILD
│       ├── MainActivity.kt
│       ├── Models.kt
│       ├── ObjectDetectionFrameProcessor.java
│       └── res
│           └── ...
├── Calculators
│   ├── BoundaryBoxCropCalculator.cpp
│   ├── BUILD
│   ├── ClickLocation.proto
│   ├── DetectionFilterCalculator.cpp
│   ├── DetectionFilterCalculator.proto
├── Graphs
│   ├── BUILD
│   └── ObjectRecognitionGraph.pbtxt
├── Models
│   ├── BUILD
│   └── car_vectorizer.tflite
├── Server
│   ├── labels.npy
│   ├── server.py
│   └── vectors.npy
└── WORKSPACE

Mediapipe :



, .


, . , , . : , . FlowLimiter . . AnnotationOverlay , . .


DetectionFilter, , . DetectionFilter, , , .


Process DetectionFilter :


mediapipe::Status DetectionFilterCalculator::Process(mediapipe::CalculatorContext *cc) {
    const auto& input_detections = cc->Inputs().Get("", 0).Get<std::vector<::mediapipe::Detection>>();
    std::vector<::mediapipe::Detection> output_detections;
    for (const auto& input_detection : input_detections) {
        bool next_detection = false;

        for (int pass_id : pass_ids_) {
            for (int label_id : input_detection.label_id()) {
                if (pass_id == label_id) {
                    output_detections.push_back(input_detection);
                    next_detection = true;
                    break;
                }
            }
            if (next_detection) {
                break;
            }
        }
    }
    auto out_packet = mediapipe::MakePacket<std::vector<mediapipe::Detection>>(output_detections).At(cc->InputTimestamp());
    cc->Outputs().Get("", 0).AddPacket(out_packet);

    return mediapipe::OkStatus();
}

DetectionFilterCalculator :


node {
    calculator: "DetectionFilterCalculator"
    input_stream: "filtered_detections"
    output_stream: "car_detections"
    node_options: {
        [type.googleapis.com/objectrecognition.DetectionFilterCalculatorOptions] {
            pass_id: 3
        }
    }
}

. , , .



ScreenTap, BoundaryBoxCrop. , ScreenTap FlowLimiter, , , . , , , , . FrameProcessor, ObjectDetectionFrameProcessor.


ScreenTap , , , " ". . , , , (, , [0, 1]).


BoundaryBoxCrop , . , DetectionFilter , FlowLimiter , GPU CPU. , , , .


mediapipe::Status BoundaryBoxCropCalculator::Process(mediapipe::CalculatorContext *cc) {
    auto& detections_packet = cc->Inputs().Get("DETECTION", 0);
    auto& frames_packet = cc->Inputs().Get("IMAGE", 0);
    auto& click_packet = cc->Inputs().Get("CLICK", 0);

    if (detections_packet.IsEmpty()
        || frames_packet.IsEmpty()
        || click_packet.IsEmpty()) {
        return mediapipe::OkStatus();
    }

    const std::vector<mediapipe::Detection>& detections =
            detections_packet.Get<std::vector<mediapipe::Detection>>();
    const mediapipe::ImageFrame& image_frame = frames_packet.Get<mediapipe::ImageFrame>();

    // Java    protobuf  ,     .
    auto click_location_str = click_packet.Get<std::string>();
    objectrecognition::ClickLocation click_location;
    click_location.ParseFromString(click_location_str);
    //   ,    
    if (click_location.x() == -1 || click_location.y() == -1) {
        return mediapipe::OkStatus();
    }

    //      
    absl::optional<mediapipe::Detection> detection = FindOverlappedDetection(click_location, detections);

    if (detection.has_value()) {
        //     ,     
        std::unique_ptr<mediapipe::ImageFrame> cropped_image = CropImage(image_frame, detection.value());
        cc->Outputs().Get("", 0).Add(cropped_image.release(), cc->InputTimestamp());
    }

    return mediapipe::OkStatus();
}

absl::optional<mediapipe::Detection> BoundaryBoxCropCalculator::FindOverlappedDetection(
        const objectrecognition::ClickLocation& click_location,
        const std::vector<mediapipe::Detection>& detections) {
    for (const auto& input_detection : detections) {
        const auto& b_box = input_detection.location_data().relative_bounding_box();

        if (b_box.xmin() < click_location.x() && click_location.x() < (b_box.xmin() + b_box.width())
            && b_box.ymin() < click_location.y() && click_location.y() < (b_box.ymin() + b_box.height())) {
            return input_detection;
        }
    }

    return absl::nullopt;
}

std::unique_ptr<mediapipe::ImageFrame> BoundaryBoxCropCalculator::CropImage(
        const mediapipe::ImageFrame& image_frame,
        const mediapipe::Detection& detection) {
    const uint8* pixel_data = image_frame.PixelData();
    const auto& b_box = detection.location_data().relative_bounding_box();

    int height = static_cast<int>(b_box.height() * static_cast<float>(image_frame.Height()));
    int width = static_cast<int>(b_box.width() * static_cast<float>(image_frame.Width()));
    int xmin = static_cast<int>(b_box.xmin() * static_cast<float>(image_frame.Width()));
    int ymin = static_cast<int>(b_box.ymin() * static_cast<float>(image_frame.Height()));

    if (xmin < 0) {
        width += xmin;
        xmin = 0;
    }
    if (ymin < 0) {
        height += ymin;
        ymin = 0;
    }
    if (width > image_frame.Width()) {
        width = image_frame.Width();
    }
    if (height > image_frame.Height()) {
        height = image_frame.Height();
    }

    std::vector<uint8_t> pixels;
    pixels.reserve(height * width * image_frame.NumberOfChannels());
    for (int y = ymin; y < ymin + height; ++y) {
        int row_offset = y * image_frame.WidthStep();
        for (int x = xmin; x < xmin + width; ++x) {
            for (int ch = 0; ch < image_frame.NumberOfChannels(); ++ch) {
                pixels.push_back(pixel_data[row_offset + x * image_frame.NumberOfChannels() + ch]);
            }
        }
    }

    std::unique_ptr<mediapipe::ImageFrame> cropped_image = std::make_unique<mediapipe::ImageFrame>();
    cropped_image->CopyPixelData(image_frame.Format(), width, height, pixels.data(),
            mediapipe::ImageFrame::kDefaultAlignmentBoundary);

    return cropped_image;
}

. TfLiteTensor .


(Java ) , HTTP . , "" . , ( ).


, — . :



. ( camry hummer), ford focus , .. -, (Hyundai SantaFe).


4 , , .



, , , , . , , .



, Mediapipe, , , , . real-time . UX, , . proof-of-concept, :


  • CPU, , , GPU. GPU CPU GPU;
  • , . ;
  • ;
  • ;
  • . ANN (Approximate Nearest Neighbors).

All Articles