Building applications with Mediapipe

Today, many services rely on neural network models. Because client devices lack computing power, the computations are in most cases performed on the server. Smartphone performance grows every year, however, and it is now possible to run small models directly on client devices. This raises the question of how to do it. Besides running the model itself, you have to pre-process and post-process the data, and all of this has to be implemented on at least two platforms: Android and iOS. Mediapipe is a framework for running machine learning pipelines (data pre-processing, model inference, and post-processing of the model's results). It addresses the problems described above and simplifies writing cross-platform code that runs models.



Contents


  1. Introduction to the Bazel build system
  2. The Mediapipe framework
    1. What Mediapipe is, what it consists of, and why it is needed
    2. A HelloWorld application on Mediapipe
    3. Running a model with Mediapipe
  3. Building an application with Mediapipe
  4. Conclusion

Introduction to the Bazel build system


Mediapipe uses Bazel for building. This was my first encounter with this build system, so I learned Bazel alongside Mediapipe. I gathered information from various sources and learned a few things the hard way. So before talking about Mediapipe itself, I would like to give a short introduction to Bazel, which is used to build Mediapipe and the projects based on it. Bazel is Google's build system, the open-source version of Blaze, their internal build system. Bazel lets you combine different languages and frameworks in one project and build everything with a single command:


bazel build TARGET

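How does Bazel know where to get the dependencies? Every Bazel project has a WORKSPACE file at its root. This file marks the root of the workspace and declares the external dependencies. Bazel does not resolve transitive external dependencies for you: if the project depends on A, and A depends on B, then both A and B must be declared in the project's WORKSPACE. As a result, some WORKSPACE files grow quite large, Mediapipe's among them. An example WORKSPACE: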


# Load the http_archive rule
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# Rules for building C++ code
http_archive(
    name = "rules_cc",
    strip_prefix = "rules_cc-master",
    urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
)

# The googletest library
http_archive(
    name = "com_google_googletest",
    strip_prefix = "googletest-master",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
)

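The load function imports a rule (bazel rule) from a .bzl file. Rules are written in Starlark, a Python-like language. Bazel comes with built-in rules, and more can be loaded from external repositories. In the example above, load imports http_archive, which is then used to download the C++ rules and googletest.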


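Besides WORKSPACE, a Bazel project uses BUILD files. A BUILD file describes the targets of its directory (its package): what to build, from which sources, and with which dependencies. A BUILD file for a C++ program: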


cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

With the dependencies declared in WORKSPACE via http_archive, a minimal Bazel project with main.cpp looks like this:


  • ./WORKSPACE
  • ./Source/BUILD
  • ./Source/main.cpp


bazel build //Source:HelloWorld

Here Source is the package (the directory that contains the BUILD file), and HelloWorld is the target name.


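Unlike cmake, where you choose the build directory yourself, Bazel places the build outputs (for example, the executable) in ./bazel-bin.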


Larger projects are split into several modules. In Bazel these modules are libraries (libraries):


cc_library(
    name = "lib",
    hdrs = ["utils.h"],
    srcs = ["utils.cpp"],
    visibility = [
        "//visibility:public",
    ],
)

cc_binary(
    name="HelloWorld",
    srcs=["main.cpp"],
    deps=[":lib"],
    copts=["-Wall", "-Wpedantic", "-Werror"]
)

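A target from the same BUILD file is referenced with a leading colon, as in ":lib" above. A target from another package is referenced by its full label, which starts with //. A target from an external repository also needs the repository name with an @ prefix: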


@some_repository//Module:lib

Bazel supports several ways of declaring external repositories: http_archive, git_repository, local_repository, and others.


The Mediapipe framework


What Mediapipe is, what it consists of, and why it is needed


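Mediapipe is a framework for running machine learning pipelines. A pipeline is described as a graph made of the following components: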


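  1. Calculators (Calculators) are the nodes of the graph, the units that process data. Each calculator is a C++ class that inherits from CalculatorBase and implements the following methods:
    • static Status GetContract(CalculatorContract*); declares the calculator's inputs, outputs, and their types. It is called when the graph is initialized.
    • Status Open(CalculatorContext*); is called once when the graph starts. Options are read and resources are initialized here.
    • Status Process(CalculatorContext*); is called for every set of input packets.
    • Status Close(CalculatorContext*); is called once when the graph finishes.
  2. Streams (Streams) connect calculators and carry packets between them. Each calculator has input (input) and output (output) streams, addressed by an optional tag and an index.
  3. Packets (Packet) are the units of data that travel along streams. A packet can hold data of any type: a number, a string, an image, a protobuf message. Each packet has a timestamp. Within a stream, timestamps must strictly increase, and calculators use them to match packets that belong together.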
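To make the calculator lifecycle and the timestamp rule concrete, here is a toy model in plain C++. It is not the Mediapipe API. The hypothetical `ToyStream` rejects packets whose timestamps do not strictly increase, as Mediapipe streams do. The hypothetical `ToyUppercaseCalculator` follows the Open / Process / Close order of CalculatorBase. All names are illustrative assumptions.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical packet: a payload plus a timestamp (illustration only).
struct ToyPacket {
    std::string payload;
    int64_t timestamp;
};

// Hypothetical stream: timestamps must strictly increase, mirroring
// the ordering rule of Mediapipe streams.
class ToyStream {
public:
    void Add(ToyPacket packet) {
        if (!packets_.empty() && packet.timestamp <= packets_.back().timestamp) {
            throw std::invalid_argument("timestamps must strictly increase");
        }
        packets_.push_back(std::move(packet));
    }
    const std::vector<ToyPacket>& Packets() const { return packets_; }

private:
    std::vector<ToyPacket> packets_;
};

// Hypothetical calculator with the same lifecycle as CalculatorBase:
// Open once, Process once per input packet, Close once at the end.
class ToyUppercaseCalculator {
public:
    void Open() { opened_ = true; }

    void Process(const ToyPacket& in, ToyStream& out) {
        if (!opened_) throw std::logic_error("Process called before Open");
        std::string upper = in.payload;
        for (char& c : upper) {
            if (c >= 'a' && c <= 'z') c = static_cast<char>(c - 'a' + 'A');
        }
        // The output reuses the input timestamp, as many calculators do.
        out.Add({upper, in.timestamp});
    }

    void Close() { opened_ = false; }

private:
    bool opened_ = false;
};
```

In a real graph, the framework calls these methods itself and routes the packets between calculators. Here the caller does both by hand.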

The graph itself is described by a protobuf message (CalculatorGraphConfig), usually written in text format.


Mediapipe is cross-platform: it runs on Linux, WSL, macOS, Android, and iOS. Out of the box, Mediapipe also supports inference of TensorFlow and TFLite models.


, Mediapipe. , . Mediapipe, , - .


A HelloWorld application on Mediapipe


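As a first example, let's build a graph with a single calculator that receives a string and outputs it N times.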



The graph is described in a .pbtxt file, which main.cpp loads and runs. The project structure:


├── hello-world
│   ├── BUILD
│   ├── graph.pbtxt
│   ├── main.cpp
│   ├── RepeatNTimesCalculator.cpp
│   └── RepeatNTimesCalculator.proto
└── WORKSPACE

The graph (graph.pbtxt):


input_stream: "in"
output_stream: "out"

node {
    calculator: "RepeatNTimesCalculator"
    input_stream: "in"
    output_stream: "OUTPUT_TAG:out"
    node_options: {
        [type.googleapis.com/mediapipe_demonstration.RepeatNTimesCalculatoOptions] {
            n: 3
        }
    }
}

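The RepeatNTimesCalculator calculator works as follows. GetContract declares the types of the input and output streams. Open reads the value of n from the options. Process emits each input string n times.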


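#include <string>

#include "hello-world/RepeatNTimesCalculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

class RepeatNTimesCalculator : public mediapipe::CalculatorBase {
public:
    static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
        // Input stream without a tag, index 0, carries std::string
        cc->Inputs().Get("", 0).Set<std::string>();
        // Output stream with tag OUTPUT_TAG, index 0, also carries std::string
        cc->Outputs().Get("OUTPUT_TAG", 0).Set<std::string>();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Open(mediapipe::CalculatorContext* cc) final {
        // Read the calculator options declared in the graph config
        const auto& options = cc->Options<mediapipe_demonstration::RepeatNTimesCalculatoOptions>();
        // n is the number of times each input string is repeated
        n_ = options.n();
        return mediapipe::OkStatus();
    }

    mediapipe::Status Process(mediapipe::CalculatorContext* cc) final {
        // Take the packet from input stream 0 (the untagged one);
        // GetContract guarantees that it holds a std::string
        const auto& txt = cc->Inputs().Index(0).Value().Get<std::string>();

        for (int i = 0; i < n_; ++i) {
            // Create a packet with a copy of the string at timestamp InputTimestamp() + i
            auto packet = mediapipe::MakePacket<std::string>(txt).At(cc->InputTimestamp() + i);
            // Send it to the output stream with tag OUTPUT_TAG, index 0
            cc->Outputs().Get("OUTPUT_TAG", 0).AddPacket(packet);
        }

        return mediapipe::OkStatus();
    }
private:
    int n_ = 0;
};
// Register the calculator so that the graph can find it by name
REGISTER_CALCULATOR(RepeatNTimesCalculator);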

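The calculator lives entirely in the .cpp file, without a header. It must be registered with the REGISTER_CALCULATOR macro so that the graph can find it by the name used in the config.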
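RepeatNTimesCalculator emits its packets at `InputTimestamp() + i`, and that has a consequence. Output timestamps on a stream must strictly increase. Suppose a second input arrives at timestamp 1 with n = 3: its outputs at 1 and 2 would collide with the outputs already produced for timestamp 0. The plain C++ sketch below uses hypothetical helper names and no Mediapipe code. It computes the output timestamps for a sequence of input timestamps and checks their order. Under these assumptions, inputs must be spaced at least n apart.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Output timestamps of a RepeatNTimes-style calculator:
// each input timestamp t yields t, t + 1, ..., t + n - 1.
std::vector<int64_t> RepeatTimestamps(const std::vector<int64_t>& inputs, int n) {
    std::vector<int64_t> outputs;
    for (int64_t t : inputs) {
        for (int i = 0; i < n; ++i) {
            outputs.push_back(t + i);
        }
    }
    return outputs;
}

// True if every timestamp is strictly greater than the previous one,
// which is the ordering a stream requires.
bool StrictlyIncreasing(const std::vector<int64_t>& timestamps) {
    for (std::size_t i = 1; i < timestamps.size(); ++i) {
        if (timestamps[i] <= timestamps[i - 1]) return false;
    }
    return true;
}
```

This does not affect main.cpp, which sends a single "Hello!" packet at Timestamp(0). It only matters once several inputs go through the same graph run.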


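The proto file describes the calculator's options: the RepeatNTimesCalculatoOptions message, which extends CalculatorOptions and contains the field n.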


syntax = "proto2";
package mediapipe_demonstration;
import "mediapipe/framework/calculator_options.proto";
message RepeatNTimesCalculatoOptions {
  extend mediapipe.CalculatorOptions {
    optional RepeatNTimesCalculatoOptions ext = 350607623;
  }
  required int32 n = 2;
}

main.cpp, which loads the graph from graph.pbtxt and runs it:


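#include <fstream>
#include <iostream>
#include <iterator>
#include <string>

#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"

mediapipe::Status RunGraph() {
    // Read the graph description from the file
    std::ifstream file("./hello-world/graph.pbtxt");
    std::string graph_file_content;
    graph_file_content.assign(
        std::istreambuf_iterator<char>(file),
        std::istreambuf_iterator<char>());
    mediapipe::CalculatorGraphConfig config =
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_file_content);
    // Create and initialize the graph
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    // Attach a poller to the output stream before the graph starts
    ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller("out"));
    // Start the graph
    MP_RETURN_IF_ERROR(graph.StartRun({}));
    // Send one packet with timestamp 0 to the input stream, then close it
    auto input_packet = mediapipe::MakePacket<std::string>("Hello!").At(mediapipe::Timestamp(0));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", input_packet));
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
    // Read all packets from the output stream until it is closed
    mediapipe::Packet packet;
    while (poller.Next(&packet)) {
        std::cout << packet.Get<std::string>() << std::endl;
    }
    return graph.WaitUntilDone();
}

int main(int argc, char** argv) {
    mediapipe::Status status = RunGraph();
    if (!status.ok()) {
        std::cerr << status.ToString() << std::endl;
        return 1;
    }
    return 0;
}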

Finally, the BUILD file that describes how to build all of this:


load("@mediapipe_repository//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
# Protobuf library with the calculator options
proto_library(
    name = "repeat_n_times_calculator_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_proto",
    ],
)
# C++ classes generated from the proto file
mediapipe_cc_proto_library(
    name = "repeat_n_times_calculator_cc_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    cc_deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":repeat_n_times_calculator_proto"],
)
# The calculator. alwayslink = 1 keeps the linker from dropping the REGISTER_CALCULATOR code
cc_library(
    name = "repeat_n_times_calculator",
    srcs = ["RepeatNTimesCalculator.cpp"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":repeat_n_times_calculator_cc_proto",
        "@mediapipe_repository//mediapipe/framework:calculator_framework",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)
# The executable that loads and runs the graph
cc_binary(
    name = "HelloMediapipe",
    srcs = ["main.cpp"],
    deps = [
        ":repeat_n_times_calculator",
        "@mediapipe_repository//mediapipe/framework/port:logging",
        "@mediapipe_repository//mediapipe/framework/port:parse_text_proto",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
)

Build and run:


$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //hello-world:HelloMediapipe
...
INFO: Build completed successfully, 4 total actions
$ ./bazel-bin/hello-world/HelloMediapipe
Hello!
Hello!
Hello!

That covers the basics of Mediapipe. Next, let's run an ML model.


Running a model with Mediapipe


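Now let's run a tflite model with Mediapipe. As an example, we take an image classifier (MobileNetV2) trained on ImageNet1k. The graph receives an image and returns the class scores. We will run it both on desktop and on Android. The project structure: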


├── inference
│   ├── android/src/main/java/com/mediapipe_demonstration/inference
│   │   ├── AndroidManifest.xml
│   │   ├── BUILD
│   │   ├── MainActivity.kt
│   │   └── res
│   │       └── ...
│   ├── BUILD
│   ├── desktop
│   │   ├── BUILD
│   │   └── main.cpp
│   ├── graph.pbtxt
│   ├── img.jpg
│   └── mobilenetv2_imagenet.tflite
└── WORKSPACE

The graph (graph.pbtxt):



input_stream: "in"
output_stream: "out"
node: {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:in"
  output_stream: "IMAGE:transformed_input"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
    }
  }
}
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:transformed_input"
  output_stream: "TENSORS:image_tensor"
  node_options: {
      [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
        zero_center: false
      }
  }
}
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "inference/mobilenetv2_imagenet.tflite"
    }
  }
}
node {
  calculator: "TfLiteTensorsToFloatsCalculator"
  input_stream: "TENSORS:prediction_tensor"
  output_stream: "FLOATS:out"
}

The graph consists of 4 calculators:


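  1. ImageTransformationCalculator resizes the input image to 224x224.
  2. TfLiteConverterCalculator normalizes the pixel values to [-1, 1] (or to [0, 1] with zero_center: false). It also converts the ImageFrame (Mediapipe's image type) into a TfLiteTensor, more precisely a std::vector<TfLiteTensor> of size 1. The conversion is needed because the inference calculator accepts TfLiteTensor, as declared in its GetContract.
  3. TfLiteInferenceCalculator runs the model.
  4. TfLiteTensorsToFloatsCalculator converts the output TfLiteTensor into a std::vector<float> with the class scores.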
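The normalization in step 2 can be written as a formula. As far as we understand TfLiteConverterCalculator, with zero_center enabled a byte value v becomes v / 127.5 - 1, which lies in [-1, 1]. Otherwise it becomes v / 255, which lies in [0, 1]. The sketch below applies that arithmetic to an interleaved 8-bit buffer in plain C++. The function name is hypothetical, and the exact formula inside Mediapipe should be checked against its source.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Converts interleaved 8-bit pixels (e.g. a 224x224x3 RGB image) into floats,
// keeping the same height-width-channel order as the input buffer.
std::vector<float> NormalizePixels(const std::vector<uint8_t>& pixels, bool zero_center) {
    std::vector<float> tensor;
    tensor.reserve(pixels.size());
    for (uint8_t v : pixels) {
        if (zero_center) {
            tensor.push_back(static_cast<float>(v) / 127.5f - 1.0f);  // range [-1, 1]
        } else {
            tensor.push_back(static_cast<float>(v) / 255.0f);         // range [0, 1]
        }
    }
    return tensor;
}
```

The choice between the two ranges is not arbitrary: it has to match the preprocessing the model was trained with, or the scores become unreliable.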

, .


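The desktop main.cpp follows the same pattern as HelloWorld: read the graph, add a poller, start the graph. Only the input and output handling differs: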


// Read the image; OpenCV loads it in BGR channel order
auto img_mat = cv::imread("./inference/img.jpg");
// ImageFormat::SRGB expects RGB, so swap the channels
cv::cvtColor(img_mat, img_mat, cv::COLOR_BGR2RGB);
// Copy the pixels into a Mediapipe ImageFrame
auto input_frame = std::make_unique<mediapipe::ImageFrame>(
    mediapipe::ImageFormat::SRGB, img_mat.cols, img_mat.rows,
    mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
img_mat.copyTo(input_frame_mat);
auto frame = mediapipe::Adopt(input_frame.release()).At(mediapipe::Timestamp(0));
// Send the frame to the graph and close the input stream
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", frame));
MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
// Read the scores and print the index of the most likely class
mediapipe::Packet packet;
while (poller.Next(&packet)) {
    const auto& predictions = packet.Get<std::vector<float>>();
    int idx = std::max_element(predictions.begin(), predictions.end()) - predictions.begin();
    std::cout << idx << std::endl;
}

Build and run:



$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //inference/desktop:Inference
...
INFO: Build completed successfully, 3 total actions
$ ./bazel-bin/inference/desktop/Inference
INFO: Initialized TensorFlow Lite runtime.
151

Class 151 in ImageNet is "Chihuahua".


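Now the Android version. The app has a single Activity. When it is created (in onCreate), it loads the graph and starts Mediapipe. On Android, the graph is loaded from a binary protobuf file. A button opens the image picker, and the selected image arrives in onActivityResult, where it is sent to the graph. Mediapipe returns the result through a callback on its own thread, so the UI update goes through runOnUiThread.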


class MainActivity : AppCompatActivity() {
    val PICK_IMAGE = 1
    var mpGraph: Graph? = null
    var timestamp = 0L
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        this.setContentView(R.layout.activity_main)
        val outputTv = findViewById<TextView>(R.id.outputTv)
        val button = findViewById<Button>(R.id.selectButton)
        AndroidAssetUtil.initializeNativeAssetManager(this)
        // Load the graph from the binary protobuf file
        // that is bundled with the app's assets
        val graph = Graph()
        assets.open("mobile_binary_graph.binarypb").use {
            val graphBytes = it.readBytes()
            graph.loadBinaryGraph(graphBytes)
        }
        // Callback for the "out" stream: pick the class with the highest score
        graph.addPacketCallback("out") {
            val res = PacketGetter.getFloat32Vector(it)
            val label = res.indices.maxBy { i -> res[i] } ?: -1
            // The callback runs on a Mediapipe thread, so update the UI on the main thread
            this@MainActivity.runOnUiThread {
                outputTv.text = label.toString()
            }
        }
        graph.startRunningGraph()
        // The button opens the system image picker
        button.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE)
        }
        mpGraph = graph
    }
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == PICK_IMAGE && resultCode == RESULT_OK) {
            val imageView = findViewById<ImageView>(R.id.imageView)
            // The picker returns the selected image as a URI
            val uri = data?.data ?: return
            val graph = mpGraph ?: return
            // Decode the image and show it on the screen
            val bitmap = contentResolver.openInputStream(uri).use {
                BitmapFactory.decodeStream(it)
            } ?: return
            imageView.setImageBitmap(bitmap)
            // Wrap the bitmap into a packet and send it to the "in" stream;
            // timestamps must increase, so every image gets a new one
            val creator = AndroidPacketCreator(graph)
            val packet = creator.createRgbImageFrame(bitmap)
            graph.addPacketToInputStream("in", packet, timestamp++)
            // The graph keeps its own reference, so the Java-side packet can be released
            packet.release()
        }
    }
    companion object {
        init {
            // Load the native Mediapipe library
            System.loadLibrary("mediapipe_jni")
        }
    }
}

Build the app and install it on a connected device:


$ bazel-2.0.0 mobile-install --start_app -c opt --config=android_arm64 //inference/android/src/main/java/com/mediapipe_demonstration/inference:Inference

The result:



, .


The full source code is available on GitHub.


Building an application with Mediapipe


Mediapipe


. . , . , . , . :


  • , GPU, ;
  • .

, Mediapipe. , , , . . : , .


To recognize car models, we use a network pretrained on ImageNet1k and fine-tuned on the Stanford Cars Dataset. Its architecture:



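The network was trained on the logits with NLL loss. In the application, the "head" is removed, and the output of the mean pool layer serves as the vector that describes the car.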
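The mean pool mentioned above (global average pooling) averages each channel of the last feature map over all spatial positions. The result is a fixed-length vector regardless of the image size. A plain C++ sketch follows. The height-width-channel memory layout and the function name are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Global average pooling: averages an H x W x C feature map (HWC order)
// over H and W, producing one value per channel.
std::vector<float> MeanPool(const std::vector<float>& features,
                            std::size_t height, std::size_t width, std::size_t channels) {
    if (height * width == 0 || features.size() != height * width * channels) {
        throw std::invalid_argument("feature map size does not match H*W*C");
    }
    std::vector<float> pooled(channels, 0.0f);
    for (std::size_t i = 0; i < height * width; ++i) {
        for (std::size_t c = 0; c < channels; ++c) {
            pooled[c] += features[i * channels + c];
        }
    }
    for (float& value : pooled) {
        value /= static_cast<float>(height * width);
    }
    return pooled;
}
```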


The project consists of an Android app that uses Mediapipe, our own Mediapipe calculators, the graph, the model, and a server:


.
├── Android/src/main/java/com/kshmax/objectrecognition
│   └── objectrecognition
│       ├── AndroidManifest.xml
│       ├── BUILD
│       ├── MainActivity.kt
│       ├── Models.kt
│       ├── ObjectDetectionFrameProcessor.java
│       └── res
│           └── ...
├── Calculators
│   ├── BoundaryBoxCropCalculator.cpp
│   ├── BUILD
│   ├── ClickLocation.proto
│   ├── DetectionFilterCalculator.cpp
│   └── DetectionFilterCalculator.proto
├── Graphs
│   ├── BUILD
│   └── ObjectRecognitionGraph.pbtxt
├── Models
│   ├── BUILD
│   └── car_vectorizer.tflite
├── Server
│   ├── labels.npy
│   ├── server.py
│   └── vectors.npy
└── WORKSPACE

The Mediapipe graph:



, .


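FlowLimiter throttles the incoming camera frames: a new frame enters the graph only when the previous one has been processed, and otherwise it is dropped. AnnotationOverlay draws the detections on top of the frame that is shown on the screen.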
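A flow limiter can be modeled in a few lines. A frame enters the graph only while fewer than a set number of frames are in flight, and is dropped otherwise. The count goes down when the graph signals that a frame has been fully processed. In Mediapipe graphs, this signal typically comes through a back edge from the graph output into FlowLimiter. The class below is a hypothetical simplification, not Mediapipe's FlowLimiterCalculator.

```cpp
#include <cassert>

// Simplified flow limiter: admits a frame only while fewer than
// max_in_flight frames are being processed; other frames are dropped.
class ToyFlowLimiter {
public:
    explicit ToyFlowLimiter(int max_in_flight) : max_in_flight_(max_in_flight) {}

    // Returns true if the frame may enter the graph, false if it is dropped.
    bool Admit() {
        if (in_flight_ >= max_in_flight_) {
            ++dropped_;
            return false;
        }
        ++in_flight_;
        return true;
    }

    // Called when the graph reports that a frame has been fully processed.
    void Finished() {
        if (in_flight_ > 0) --in_flight_;
    }

    int Dropped() const { return dropped_; }

private:
    int max_in_flight_;
    int in_flight_ = 0;
    int dropped_ = 0;
};
```

With one frame in flight, a slow model makes the preview skip frames instead of building up latency.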


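DetectionFilter is a calculator we wrote ourselves. It keeps only the detections whose label id is in a given list and drops the rest. In our graph, DetectionFilter lets through only cars.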


The Process method of DetectionFilter:


mediapipe::Status DetectionFilterCalculator::Process(mediapipe::CalculatorContext *cc) {
    const auto& input_detections = cc->Inputs().Get("", 0).Get<std::vector<::mediapipe::Detection>>();
    std::vector<::mediapipe::Detection> output_detections;
    for (const auto& input_detection : input_detections) {
        bool next_detection = false;

        for (int pass_id : pass_ids_) {
            for (int label_id : input_detection.label_id()) {
                if (pass_id == label_id) {
                    output_detections.push_back(input_detection);
                    next_detection = true;
                    break;
                }
            }
            if (next_detection) {
                break;
            }
        }
    }
    auto out_packet = mediapipe::MakePacket<std::vector<mediapipe::Detection>>(output_detections).At(cc->InputTimestamp());
    cc->Outputs().Get("", 0).AddPacket(out_packet);

    return mediapipe::OkStatus();
}
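The nested loops in Process check whether any label of a detection belongs to pass_ids_. The same filtering can be expressed with a hash set and std::any_of. So that the sketch compiles without Mediapipe, it uses a hypothetical `ToyDetection` struct with a `label_ids` vector instead of mediapipe::Detection.

```cpp
#include <algorithm>
#include <cassert>
#include <unordered_set>
#include <vector>

// Stand-in for mediapipe::Detection: only the label ids matter here.
struct ToyDetection {
    std::vector<int> label_ids;
};

// Keeps the detections that have at least one label in pass_ids.
std::vector<ToyDetection> FilterDetections(const std::vector<ToyDetection>& detections,
                                           const std::unordered_set<int>& pass_ids) {
    std::vector<ToyDetection> kept;
    for (const auto& detection : detections) {
        bool passes = std::any_of(detection.label_ids.begin(), detection.label_ids.end(),
                                  [&](int id) { return pass_ids.count(id) > 0; });
        if (passes) kept.push_back(detection);
    }
    return kept;
}
```

The set lookup makes each label check constant time on average, although with one or two pass ids the original loops are just as fast.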

DetectionFilterCalculator in the graph:


node {
    calculator: "DetectionFilterCalculator"
    input_stream: "filtered_detections"
    output_stream: "car_detections"
    node_options: {
        [type.googleapis.com/objectrecognition.DetectionFilterCalculatorOptions] {
            pass_id: 3
        }
    }
}

. , , .



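Two more elements handle the user's tap: the ScreenTap input stream and the BoundaryBoxCrop calculator. ScreenTap also passes through FlowLimiter. To send taps into the graph, we extended the standard FrameProcessor in ObjectDetectionFrameProcessor.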


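ScreenTap carries the coordinates of the tap as a serialized ClickLocation message. When the user has not tapped, it carries an "empty" location with coordinates -1. The coordinates are relative to the frame size, that is, normalized to [0, 1].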
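Converting a tap into relative coordinates takes one division per axis: divide the pixel position by the view size. The value -1 stands for "no tap", the sentinel that BoundaryBoxCrop checks for. The C++ sketch below uses a hypothetical `TapLocation` struct in place of the ClickLocation protobuf. In the app, this step would happen on the Java/Kotlin side before the message is serialized. The sketch also assumes that the camera preview fills the whole view without letterboxing.

```cpp
#include <cassert>

// Stand-in for the ClickLocation message: relative coordinates in [0, 1],
// or -1 when there was no tap.
struct TapLocation {
    float x;
    float y;
};

// Converts a tap in view pixels to relative coordinates.
// Assumes the camera preview covers the whole view (no letterboxing).
TapLocation ToRelative(float tap_x, float tap_y, float view_width, float view_height) {
    if (view_width <= 0.0f || view_height <= 0.0f ||
        tap_x < 0.0f || tap_y < 0.0f || tap_x > view_width || tap_y > view_height) {
        return {-1.0f, -1.0f};  // treat invalid input as "no tap"
    }
    return {tap_x / view_width, tap_y / view_height};
}
```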


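BoundaryBoxCrop takes the detections, the current frame, and the tap, and crops out the detected object the user tapped on. The detections come from DetectionFilter. The frame comes from FlowLimiter after being moved from GPU to CPU memory. If the tap does not fall inside any detection, nothing is output.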


mediapipe::Status BoundaryBoxCropCalculator::Process(mediapipe::CalculatorContext *cc) {
    auto& detections_packet = cc->Inputs().Get("DETECTION", 0);
    auto& frames_packet = cc->Inputs().Get("IMAGE", 0);
    auto& click_packet = cc->Inputs().Get("CLICK", 0);

    if (detections_packet.IsEmpty()
        || frames_packet.IsEmpty()
        || click_packet.IsEmpty()) {
        return mediapipe::OkStatus();
    }

    const std::vector<mediapipe::Detection>& detections =
            detections_packet.Get<std::vector<mediapipe::Detection>>();
    const mediapipe::ImageFrame& image_frame = frames_packet.Get<mediapipe::ImageFrame>();

    // The Java side sends the ClickLocation protobuf serialized into a string, so parse it here
    auto click_location_str = click_packet.Get<std::string>();
    objectrecognition::ClickLocation click_location;
    click_location.ParseFromString(click_location_str);
    // Coordinates of -1 mean there was no tap, so there is nothing to do
    if (click_location.x() == -1 || click_location.y() == -1) {
        return mediapipe::OkStatus();
    }

    // Find the detection whose bounding box contains the tap
    absl::optional<mediapipe::Detection> detection = FindOverlappedDetection(click_location, detections);

    if (detection.has_value()) {
        // Crop the image to that detection and send the crop to the output stream
        std::unique_ptr<mediapipe::ImageFrame> cropped_image = CropImage(image_frame, detection.value());
        cc->Outputs().Get("", 0).Add(cropped_image.release(), cc->InputTimestamp());
    }

    return mediapipe::OkStatus();
}

absl::optional<mediapipe::Detection> BoundaryBoxCropCalculator::FindOverlappedDetection(
        const objectrecognition::ClickLocation& click_location,
        const std::vector<mediapipe::Detection>& detections) {
    for (const auto& input_detection : detections) {
        const auto& b_box = input_detection.location_data().relative_bounding_box();

        if (b_box.xmin() < click_location.x() && click_location.x() < (b_box.xmin() + b_box.width())
            && b_box.ymin() < click_location.y() && click_location.y() < (b_box.ymin() + b_box.height())) {
            return input_detection;
        }
    }

    return absl::nullopt;
}

std::unique_ptr<mediapipe::ImageFrame> BoundaryBoxCropCalculator::CropImage(
        const mediapipe::ImageFrame& image_frame,
        const mediapipe::Detection& detection) {
    const uint8* pixel_data = image_frame.PixelData();
    const auto& b_box = detection.location_data().relative_bounding_box();

    int height = static_cast<int>(b_box.height() * static_cast<float>(image_frame.Height()));
    int width = static_cast<int>(b_box.width() * static_cast<float>(image_frame.Width()));
    int xmin = static_cast<int>(b_box.xmin() * static_cast<float>(image_frame.Width()));
    int ymin = static_cast<int>(b_box.ymin() * static_cast<float>(image_frame.Height()));

    if (xmin < 0) {
        width += xmin;
        xmin = 0;
    }
    if (ymin < 0) {
        height += ymin;
        ymin = 0;
    }
    // Clip the crop to the right and bottom edges of the frame
    width = std::min(width, image_frame.Width() - xmin);
    height = std::min(height, image_frame.Height() - ymin);

    std::vector<uint8_t> pixels;
    pixels.reserve(height * width * image_frame.NumberOfChannels());
    for (int y = ymin; y < ymin + height; ++y) {
        int row_offset = y * image_frame.WidthStep();
        for (int x = xmin; x < xmin + width; ++x) {
            for (int ch = 0; ch < image_frame.NumberOfChannels(); ++ch) {
                pixels.push_back(pixel_data[row_offset + x * image_frame.NumberOfChannels() + ch]);
            }
        }
    }

    std::unique_ptr<mediapipe::ImageFrame> cropped_image = std::make_unique<mediapipe::ImageFrame>();
    cropped_image->CopyPixelData(image_frame.Format(), width, height, pixels.data(),
            mediapipe::ImageFrame::kDefaultAlignmentBoundary);

    return cropped_image;
}
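The pixel-rectangle arithmetic in CropImage is easy to get wrong at the image borders. Consider a box that starts inside the frame but extends past its right or bottom edge. It has to be clipped against `Width() - xmin`, not just against `Width()`, or the copy loop reads outside the frame. One way to avoid that is to clip both corners of the box, as in the standalone sketch below. The `PixelRect` struct and function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>

struct PixelRect {
    int x;
    int y;
    int width;
    int height;
};

int ClampToRange(int value, int low, int high) {
    return std::max(low, std::min(value, high));
}

// Converts a relative box (xmin, ymin, width, height in [0, 1] units) into a
// pixel rectangle clipped to an image_width x image_height frame.
PixelRect ToClippedPixelRect(float rel_xmin, float rel_ymin, float rel_width, float rel_height,
                             int image_width, int image_height) {
    int x0 = static_cast<int>(rel_xmin * image_width);
    int y0 = static_cast<int>(rel_ymin * image_height);
    int x1 = static_cast<int>((rel_xmin + rel_width) * image_width);
    int y1 = static_cast<int>((rel_ymin + rel_height) * image_height);
    // Clip both corners to the frame, then derive the size from them.
    x0 = ClampToRange(x0, 0, image_width);
    y0 = ClampToRange(y0, 0, image_height);
    x1 = ClampToRange(x1, 0, image_width);
    y1 = ClampToRange(y1, 0, image_height);
    return {x0, y0, x1 - x0, y1 - y0};
}
```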

Next, the cropped image is converted into a TfLiteTensor and passed to the model.


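The resulting vector is received on the Java side and sent to the server over HTTP. The server finds the "nearest" vector among the stored ones and returns the corresponding label.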
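On the server side, this is a nearest-neighbor lookup. Compare the vector sent by the phone with the stored vectors (vectors.npy) and return the label of the closest one (labels.npy). The project's server is written in Python; to keep one language here, the brute-force version is sketched in C++. Cosine similarity as the metric is our assumption, since the text does not say which distance the server uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Cosine similarity between two vectors of equal length.
float CosineSimilarity(const std::vector<float>& a, const std::vector<float>& b) {
    if (a.size() != b.size()) throw std::invalid_argument("vector sizes differ");
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    if (norm_a == 0.0f || norm_b == 0.0f) return 0.0f;
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b));
}

// Brute-force search: returns the label of the stored vector most similar to query.
std::string NearestLabel(const std::vector<float>& query,
                         const std::vector<std::vector<float>>& vectors,
                         const std::vector<std::string>& labels) {
    if (vectors.empty() || vectors.size() != labels.size()) {
        throw std::invalid_argument("need one label per stored vector");
    }
    std::size_t best = 0;
    float best_score = CosineSimilarity(query, vectors[0]);
    for (std::size_t i = 1; i < vectors.size(); ++i) {
        float score = CosineSimilarity(query, vectors[i]);
        if (score > best_score) {
            best_score = score;
            best = i;
        }
    }
    return labels[best];
}
```

This scan is linear in the number of stored vectors. For large collections, an approximate nearest neighbor index avoids comparing against every vector.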


, — . :



. ( camry hummer), ford focus , .. -, (Hyundai SantaFe).


4 , , .



, , , , . , , .



Conclusion


Using Mediapipe, we built an application that recognizes car models in real time. It is a proof-of-concept, and there is still room for improvement:


  • some computations run on the CPU; ideally everything would run on the GPU, avoiding the transfers between GPU and CPU memory;
  • , . ;
  • ;
  • ;
  • speed up the vector search on the server with ANN (Approximate Nearest Neighbors).
