Creación de aplicaciones con Mediapipe

Hoy en día, muchos servicios utilizan modelos de redes neuronales. Debido al bajo rendimiento de los dispositivos cliente, en la mayoría de los casos los cálculos se realizan en el servidor. Sin embargo, el rendimiento de los teléfonos inteligentes crece cada año y ya es posible ejecutar modelos pequeños en los propios dispositivos de los clientes. Surge la pregunta: ¿cómo hacerlo? Además de ejecutar el modelo, hay que preprocesar y postprocesar los datos, y hay al menos dos plataformas en las que esto debe implementarse: Android e iOS. Mediapipe es un framework para ejecutar canalizaciones de aprendizaje automático (preprocesamiento de los datos, inferencia del modelo y postprocesamiento de sus resultados) que permite resolver los problemas descritos y simplifica la escritura de código multiplataforma para ejecutar modelos.



Contenido


  1. Descripción del sistema de construcción Bazel
  2. Marco Mediapipe
    1. ¿Qué es Mediapipe? En qué consiste y por qué es necesario
    2. Aplicación HelloWorld en Mediapipe
    3. Ejecutar un modelo con Mediapipe
  3. Crear una aplicación usando Mediapipe
  4. Conclusión

Descripción del sistema de construcción Bazel


Mediapipe usa Bazel como sistema de compilación. Era la primera vez que me encontraba con este sistema, así que fui descubriendo Bazel en paralelo con Mediapipe, recopilando información de varias fuentes y tropezando con más de un problema por el camino. Por eso, antes de hablar de Mediapipe, me gustaría hacer una pequeña introducción a Bazel, que se utiliza para compilar Mediapipe y los proyectos basados en él. Bazel es el sistema de compilación de código abierto de Google, una versión revisada y publicada de su sistema interno Blaze. Bazel permite combinar diferentes lenguajes y frameworks en un mismo proyecto y compilarlo con un único comando:


bazel build TARGET

, ? , Bazel-. WORKSPACE, . , . , A, B, WORKSPACE A, B. - WORKSPACE , Mediapipe. Bazel , , , , , . WORKSPACE:


# Bring the http_archive repository rule into scope.
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# External repository with the Bazel build rules for C++.
# NOTE(review): the archive tracks the `master` branch and has no sha256, so
# the downloaded content can change between builds.
http_archive(
    name = "rules_cc",
    urls = ["https://github.com/bazelbuild/rules_cc/archive/master.zip"],
    strip_prefix = "rules_cc-master",
)

# External repository with the googletest sources (also tracks `master`).
http_archive(
    name = "com_google_googletest",
    urls = ["https://github.com/google/googletest/archive/master.zip"],
    strip_prefix = "googletest-master",
)

(bazel rule). — Starlark. Bazel , , . , http_archive , http_archive C++ googletest.


WORKSPACE Bazel BUILD. BUILD , , . BUILD C++ :


# Executable target built from main.cpp; warnings are enabled and treated as
# errors.
cc_binary(
    name = "HelloWorld",
    srcs = ["main.cpp"],
    copts = [
        "-Wall",
        "-Wpedantic",
        "-Werror",
    ],
)

, , WORKSPACE http_archive. Bazel- main.cpp.


  • ./WORKSPACE
  • ./Source/BUILD
  • ./Source/main.cpp


bazel build //Source:HelloWorld

Source — el paquete (el directorio que contiene el archivo BUILD), HelloWorld — el nombre del objetivo que se compila.


cmake, ( ) ./bazel-bin.


, , . Bazel (libraries).


# Library target; public visibility lets targets from any package depend on it.
cc_library(
    name = "lib",
    srcs = ["utils.cpp"],
    hdrs = ["utils.h"],
    visibility = ["//visibility:public"],
)

# Executable that links against the library declared in this same BUILD file.
cc_binary(
    name = "HelloWorld",
    srcs = ["main.cpp"],
    copts = [
        "-Wall",
        "-Wpedantic",
        "-Werror",
    ],
    deps = [":lib"],
)

BUILD , . , //


@some_repository//Module:lib

Bazel , . http_archive, git_repository, local_repository .


Marco Mediapipe


¿Qué es Mediapipe? En qué consiste y por qué es necesario


Mediapipe — . , ( ):


  1. (Calculators) — (). . C++ , CalculatorBase:
    • static Status GetContract(CalculatorContract*); — , , .
    • Status Open(CalculatorContext*); — . , , , .
    • Status Process(CalculatorContext*); — .
    • Status Close(CalculatorContext*); — .
  2. (Streams) . . , (input) (output). , , , .
  3. (Packet) — , . — , , , protobuf. timestamp — , . , , , , .

protobuf.


Mediapipe , . Linux, WSL, MacOS, Android, iOS. Mediapipe TensorFlow TFLite . , , .


, Mediapipe. , . Mediapipe, , - .


Aplicación HelloWorld en Mediapipe


, , N , .



.pbtxt , main.cpp.


├── hello-world
│   ├── BUILD
│   ├── graph.pbtxt
│   ├── main.cpp
│   ├── RepeatNTimesCalculator.cpp
│   └── RepeatNTimesCalculator.proto
└── WORKSPACE

:


# Streams exposed by the graph to the host application.
input_stream: "in"
output_stream: "out"

# Single node: reads the graph input and writes to the graph output through
# the OUTPUT_TAG port, with the repeat count set to 3 in the node options.
node {
  calculator: "RepeatNTimesCalculator"
  input_stream: "in"
  output_stream: "OUTPUT_TAG:out"
  node_options: {
    [type.googleapis.com/mediapipe_demonstration.RepeatNTimesCalculatoOptions] {
      n: 3
    }
  }
}

RepeatNTimesCalculator . GetContract, , Open, , Process, .


// Calculator that re-emits every incoming string packet `n` times on the
// output stream tagged OUTPUT_TAG. `n` is read from the node options.
class RepeatNTimesCalculator : public mediapipe::CalculatorBase {
public:
    // Declares the calculator's ports: one untagged string input at index 0
    // and one string output tagged OUTPUT_TAG at index 0.
    static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) {
        cc->Inputs().Get("", 0).Set<std::string>();
        cc->Outputs().Get("OUTPUT_TAG", 0).Set<std::string>();
        return mediapipe::OkStatus();
    }

    // Reads the node options and caches the repeat count.
    mediapipe::Status Open(mediapipe::CalculatorContext* cc) final {
        const auto& options = cc->Options<mediapipe_demonstration::RepeatNTimesCalculatoOptions>();
        n_ = options.n();
        return mediapipe::OkStatus();
    }

    // Emits `n_` copies of the current input string.
    mediapipe::Status Process(mediapipe::CalculatorContext* cc) final {
        // Bind by reference: the payload is only read here, and MakePacket
        // below makes its own copy for every output packet.
        const auto& txt = cc->Inputs().Index(0).Value().Get<std::string>();

        for (int i = 0; i < n_; ++i) {
            // Offset the timestamp by `i` so that every emitted packet carries
            // a distinct, increasing timestamp.
            auto packet = mediapipe::MakePacket<std::string>(txt).At(cc->InputTimestamp() + i);
            cc->Outputs().Get("OUTPUT_TAG", 0).AddPacket(packet);
        }

        return mediapipe::OkStatus();
    }
private:
    // Number of repetitions; overwritten in Open(). Initialised so the member
    // never holds an indeterminate value.
    int n_ = 0;
};
// Makes the calculator available to graph configs under its class name.
REGISTER_CALCULATOR(RepeatNTimesCalculator);

.cpp , , REGISTER_CALCULATOR.


proto- . RepeatNTimesCalculatoOptions, , .


syntax = "proto2";
package mediapipe_demonstration;
import "mediapipe/framework/calculator_options.proto";
// Options message for RepeatNTimesCalculator. The name is referenced verbatim
// by the graph config and by the calculator's Open(), so it must stay in sync
// with both.
message RepeatNTimesCalculatoOptions {
  // Registers this message as an extension of mediapipe.CalculatorOptions.
  extend mediapipe.CalculatorOptions {
    optional RepeatNTimesCalculatoOptions ext = 350607623;
  }
  // Repeat count read by the calculator in Open(); must be set in the graph.
  required int32 n = 2;
}

:


// Loads the hello-world graph from ./hello-world/graph.pbtxt, feeds it a single
// "Hello!" packet and prints every packet produced on the "out" stream.
// Returns an error status if the config file cannot be opened or if any graph
// operation fails.
mediapipe::Status RunGraph() {
    // Read the textual graph config. A missing or unreadable file would
    // otherwise be parsed as an empty config, so fail early with a clear error.
    std::ifstream file("./hello-world/graph.pbtxt");
    if (!file.is_open()) {
        return mediapipe::NotFoundError("Cannot open ./hello-world/graph.pbtxt");
    }
    const std::string graph_file_content{
        std::istreambuf_iterator<char>(file),
        std::istreambuf_iterator<char>()};
    mediapipe::CalculatorGraphConfig config =
        mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(graph_file_content);
    // Build the graph from the parsed config.
    mediapipe::CalculatorGraph graph;
    MP_RETURN_IF_ERROR(graph.Initialize(config));
    // Attach a poller to the graph output; this is done before StartRun.
    ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller, graph.AddOutputStreamPoller("out"));
    // Start the graph.
    MP_RETURN_IF_ERROR(graph.StartRun({}));
    // Send one packet with timestamp 0, then close the input stream.
    auto input_packet = mediapipe::MakePacket<std::string>("Hello!").At(mediapipe::Timestamp(0));
    MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", input_packet));
    MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
    // Print output packets until the poller reports there are no more.
    mediapipe::Packet packet;
    while (poller.Next(&packet)) {
        std::cout << packet.Get<std::string>() << '\n';
    }
    return graph.WaitUntilDone();
}

, , BUILD , .


load("@mediapipe_repository//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")

# Proto definition of the calculator options.
proto_library(
    name = "repeat_n_times_calculator_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_proto",
    ],
)

# C++ library for the options proto.
mediapipe_cc_proto_library(
    name = "repeat_n_times_calculator_cc_proto",
    srcs = ["RepeatNTimesCalculator.proto"],
    cc_deps = [
        "@mediapipe_repository//mediapipe/framework:calculator_cc_proto",
    ],
    visibility = ["//visibility:public"],
    deps = [":repeat_n_times_calculator_proto"],
)

# The calculator itself. alwayslink keeps the object code in the final binary
# even though nothing references its symbols directly (the calculator is only
# registered through REGISTER_CALCULATOR).
cc_library(
    name = "repeat_n_times_calculator",
    srcs = ["RepeatNTimesCalculator.cpp"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":repeat_n_times_calculator_cc_proto",
        "@mediapipe_repository//mediapipe/framework:calculator_framework",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

# Executable that loads the graph and runs it with the calculator linked in.
cc_binary(
    name = "HelloMediapipe",
    srcs = ["main.cpp"],
    deps = [
        # Same-package label written with the leading ':' like the other local
        # labels in this file.
        ":repeat_n_times_calculator",
        "@mediapipe_repository//mediapipe/framework/port:logging",
        "@mediapipe_repository//mediapipe/framework/port:parse_text_proto",
        "@mediapipe_repository//mediapipe/framework/port:status",
    ],
)

:


$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //hello-world:HelloMediapipe
...
INFO: Build completed successfully, 4 total actions
$ ./bazel-bin/hello-world/HelloMediapipe
Hello!
Hello!
Hello!

Mediapipe, ML .


Ejecutar un modelo con Mediapipe


, Mediapipe tflite . . , ImageNet1k. , , . , .


├── inference
│   ├── android/src/main/java/com/mediapipe_demonstration/inference
│   │   ├── AndroidManifest.xml
│   │   ├── BUILD
│   │   ├── MainActivity.kt
│   │   └── res
│   │       └── ...
│   ├── BUILD
│   ├── desktop
│   │   ├── BUILD
│   │   └── main.cpp
│   ├── graph.pbtxt
│   ├── img.jpg
│   └── mobilenetv2_imagenet.tflite
└── WORKSPACE

:



# Graph input ("in") and graph output ("out").
input_stream: "in"
output_stream: "out"

# 1. Resize the incoming image to 224x224.
node {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:in"
  output_stream: "IMAGE:transformed_input"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      output_width: 224
      output_height: 224
    }
  }
}

# 2. Convert the resized image into a TfLite tensor. With zero_center disabled
#    the values are scaled to [0, 1] rather than [-1, 1].
node {
  calculator: "TfLiteConverterCalculator"
  input_stream: "IMAGE:transformed_input"
  output_stream: "TENSORS:image_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
      zero_center: false
    }
  }
}

# 3. Run the model stored at model_path on the tensor.
node {
  calculator: "TfLiteInferenceCalculator"
  input_stream: "TENSORS:image_tensor"
  output_stream: "TENSORS:prediction_tensor"
  node_options: {
    [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
      model_path: "inference/mobilenetv2_imagenet.tflite"
    }
  }
}

# 4. Unpack the prediction tensor into a vector of floats on the graph output.
node {
  calculator: "TfLiteTensorsToFloatsCalculator"
  input_stream: "TENSORS:prediction_tensor"
  output_stream: "FLOATS:out"
}

4 :


  1. 224x224.
  2. [-1, 1] ( zero_center: false, [0, 1]), ImageFrame ( Mediapipe, , ) TfLiteTensor ( std::vector<TfLiteTensor> 1). TfLiteTensor ( GetContract), .
  3. .
  4. TfLiteTensor std::vector<float>. .

, .


. . , , .


//  
auto img_mat = cv::imread("./inference/img.jpg");
//    
auto input_frame = std::make_unique<mediapipe::ImageFrame>(
    mediapipe::ImageFormat::SRGB, img_mat.cols, img_mat.rows,
    mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
img_mat.copyTo(input_frame_mat);
auto frame = mediapipe::Adopt(input_frame.release()).At(mediapipe::Timestamp(0));
//    
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream("in", frame));
MP_RETURN_IF_ERROR(graph.CloseInputStream("in"));
//       
mediapipe::Packet packet;
while (poller.Next(&packet)) {
    auto predictions = packet.Get<std::vector<float>>();
    int idx = std::max_element(predictions.begin(), predictions.end()) - predictions.begin();
    std::cout << idx << std::endl;
}

, :



$ bazel-2.0.0 build --define MEDIAPIPE_DISABLE_GPU=1 //inference/desktop:Inference
...
INFO: Build completed successfully, 3 total actions
$ ./bazel-bin/inference/desktop/Inference
INFO: Initialized TensorFlow Lite runtime.
151

El índice 151 corresponde en ImageNet a la clase "Chihuahua".


. Activity, ( onCreate) , Mediapipe. protobuf-. , . onActivityResult, . Mediapipe - ( ).


/**
 * Lets the user pick an image, feeds it to the MediaPipe graph and shows the
 * index of the highest-scoring class in a TextView.
 */
class MainActivity : AppCompatActivity() {
    val PICK_IMAGE = 1
    var mpGraph: Graph? = null
    // Timestamp of the next packet sent to the graph; advanced after every
    // packet so that consecutive images never share a timestamp.
    var timestamp = 0L
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        this.setContentView(R.layout.activity_main)
        val outputTv = findViewById<TextView>(R.id.outputTv)
        val button = findViewById<Button>(R.id.selectButton)
        AndroidAssetUtil.initializeNativeAssetManager(this)
        // Load the binary graph description from the application assets.
        val graph = Graph()
        assets.open("mobile_binary_graph.binarypb").use {
            val graphBytes = it.readBytes()
            graph.loadBinaryGraph(graphBytes)
        }
        // Callback for the graph output: pick the index with the highest score
        // and display it on the UI thread.
        graph.addPacketCallback("out") {
            val res = PacketGetter.getFloat32Vector(it)
            val label = res.indices.maxBy { i -> res[i] } ?: -1
            this@MainActivity.runOnUiThread {
                outputTv.text = label.toString()
            }
        }
        graph.startRunningGraph()
        // The button opens the system picker restricted to images.
        button.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(Intent.createChooser(intent, "Select Picture"), PICK_IMAGE)
        }
        mpGraph = graph
    }
    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        // Ignore other requests and cancelled picks; a cancelled pick carries
        // no data and must not crash the activity.
        if (requestCode != PICK_IMAGE || resultCode != RESULT_OK) {
            return
        }
        val uri = data?.data ?: return
        val graph = mpGraph ?: return
        val imageView = findViewById<ImageView>(R.id.imageView)
        // Decode the selected image; `use` closes the stream afterwards.
        val bitmap = contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream)
        } ?: return
        imageView.setImageBitmap(bitmap)
        // Wrap the bitmap in a packet and send it to the graph input.
        val packet = AndroidPacketCreator(graph).createRgbImageFrame(bitmap)
        try {
            graph.addPacketToInputStream("in", packet, timestamp++)
        } finally {
            // Drop this side's reference to the packet once it has been handed
            // to the graph.
            packet.release()
        }
    }
    companion object {
        init {
            // Load the native MediaPipe JNI library.
            System.loadLibrary("mediapipe_jni")
        }
    }
}

:


$ bazel-2.0.0 mobile-install --start_app -c opt --config=android_arm64 //inference/android/src/main/java/com/mediapipe_demonstration/inference:Inference

, :



, .


, github


Crear una aplicación usando Mediapipe


Mediapipe


. . , . , . , . :


  • , GPU, ;
  • .

, Mediapipe. , , , . . : , .


Mediapipe, . , ImageNet1k, Stanford Cars Dataset. . :



logit- NLL loss. "" mean pool.


Android , Mediapipe, Mediapipe , .


.
├── Android/src/main/java/com/kshmax/objectrecognition
│   └── objectrecognition
│       ├── AndroidManifest.xml
│       ├── BUILD
│       ├── MainActivity.kt
│       ├── Models.kt
│       ├── ObjectDetectionFrameProcessor.java
│       └── res
│           └── ...
├── Calculators
│   ├── BoundaryBoxCropCalculator.cpp
│   ├── BUILD
│   ├── ClickLocation.proto
│   ├── DetectionFilterCalculator.cpp
│   ├── DetectionFilterCalculator.proto
├── Graphs
│   ├── BUILD
│   └── ObjectRecognitionGraph.pbtxt
├── Models
│   ├── BUILD
│   └── car_vectorizer.tflite
├── Server
│   ├── labels.npy
│   ├── server.py
│   └── vectors.npy
└── WORKSPACE

Mediapipe :



, .


, . , , . : , . FlowLimiter . . AnnotationOverlay , . .


DetectionFilter, , . DetectionFilter, , , .


Process DetectionFilter :


// Forwards only those detections that carry at least one label id listed in
// pass_ids_. The filtered vector is emitted on the untagged output with the
// timestamp of the input.
mediapipe::Status DetectionFilterCalculator::Process(mediapipe::CalculatorContext *cc) {
    // True when any label id of `candidate` matches an entry of pass_ids_.
    const auto is_allowed = [this](const ::mediapipe::Detection& candidate) {
        for (int label : candidate.label_id()) {
            for (int allowed : pass_ids_) {
                if (label == allowed) {
                    return true;
                }
            }
        }
        return false;
    };

    const auto& incoming = cc->Inputs().Get("", 0).Get<std::vector<::mediapipe::Detection>>();

    std::vector<::mediapipe::Detection> kept;
    for (const auto& candidate : incoming) {
        if (is_allowed(candidate)) {
            kept.push_back(candidate);
        }
    }

    cc->Outputs().Get("", 0).AddPacket(
        mediapipe::MakePacket<std::vector<mediapipe::Detection>>(kept).At(cc->InputTimestamp()));

    return mediapipe::OkStatus();
}

DetectionFilterCalculator :


# Keeps only the detections whose label id equals the configured pass_id (3).
node {
  calculator: "DetectionFilterCalculator"
  input_stream: "filtered_detections"
  output_stream: "car_detections"
  node_options: {
    [type.googleapis.com/objectrecognition.DetectionFilterCalculatorOptions] {
      pass_id: 3
    }
  }
}

. , , .



ScreenTap, BoundaryBoxCrop. , ScreenTap FlowLimiter, , , . , , , , . FrameProcessor, ObjectDetectionFrameProcessor.


ScreenTap , , , " ". . , , , (, , [0, 1]).


BoundaryBoxCrop , . , DetectionFilter , FlowLimiter , GPU CPU. , , , .


// Crops the region of the detection that contains the user's click out of the
// current frame and emits it on the untagged output. Nothing is emitted when an
// input is missing, when the click message cannot be parsed, when the click
// holds the -1 sentinel, or when no detection contains the click.
mediapipe::Status BoundaryBoxCropCalculator::Process(mediapipe::CalculatorContext *cc) {
    auto& detections_packet = cc->Inputs().Get("DETECTION", 0);
    auto& frames_packet = cc->Inputs().Get("IMAGE", 0);
    auto& click_packet = cc->Inputs().Get("CLICK", 0);

    // All three inputs are needed for this timestamp.
    if (detections_packet.IsEmpty()
        || frames_packet.IsEmpty()
        || click_packet.IsEmpty()) {
        return mediapipe::OkStatus();
    }

    const std::vector<mediapipe::Detection>& detections =
            detections_packet.Get<std::vector<mediapipe::Detection>>();
    const mediapipe::ImageFrame& image_frame = frames_packet.Get<mediapipe::ImageFrame>();

    // The click arrives as a serialized ClickLocation protobuf inside a string
    // packet. Bind by reference to avoid copying the payload.
    const std::string& click_location_str = click_packet.Get<std::string>();
    objectrecognition::ClickLocation click_location;
    if (!click_location.ParseFromString(click_location_str)) {
        // A message that fails to parse holds no trustworthy coordinates, so
        // skip this frame instead of matching against them.
        return mediapipe::OkStatus();
    }
    // A coordinate of -1 means there is no click to handle for this frame.
    if (click_location.x() == -1 || click_location.y() == -1) {
        return mediapipe::OkStatus();
    }

    // Look for a detection whose bounding box contains the click.
    const absl::optional<mediapipe::Detection> detection =
            FindOverlappedDetection(click_location, detections);

    if (detection.has_value()) {
        // Crop the detection out of the frame; the output stream takes
        // ownership of the released frame.
        std::unique_ptr<mediapipe::ImageFrame> cropped_image = CropImage(image_frame, *detection);
        cc->Outputs().Get("", 0).Add(cropped_image.release(), cc->InputTimestamp());
    }

    return mediapipe::OkStatus();
}

// Returns the first detection whose relative bounding box strictly contains
// the click position, or absl::nullopt when no box contains it.
absl::optional<mediapipe::Detection> BoundaryBoxCropCalculator::FindOverlappedDetection(
        const objectrecognition::ClickLocation& click_location,
        const std::vector<mediapipe::Detection>& detections) {
    for (const auto& candidate : detections) {
        const auto& box = candidate.location_data().relative_bounding_box();

        // Horizontal test first; points on the box edge do not count.
        const bool inside_x = box.xmin() < click_location.x()
            && click_location.x() < (box.xmin() + box.width());
        if (!inside_x) {
            continue;
        }

        const bool inside_y = box.ymin() < click_location.y()
            && click_location.y() < (box.ymin() + box.height());
        if (inside_y) {
            return candidate;
        }
    }

    return absl::nullopt;
}

// Copies the pixels covered by the detection's relative bounding box out of
// `image_frame` into a newly allocated ImageFrame of the same format.
// The box is clamped to the frame first, so a box that sticks out on any side
// (or lies completely outside) never causes reads outside the frame; a box
// with no visible area yields a zero-sized pixel range.
// NOTE(review): the row arithmetic assumes one byte per channel — confirm for
// the image formats this calculator receives.
std::unique_ptr<mediapipe::ImageFrame> BoundaryBoxCropCalculator::CropImage(
        const mediapipe::ImageFrame& image_frame,
        const mediapipe::Detection& detection) {
    const auto& b_box = detection.location_data().relative_bounding_box();
    const int frame_width = image_frame.Width();
    const int frame_height = image_frame.Height();
    const int channels = image_frame.NumberOfChannels();

    // Convert the relative box to pixel coordinates.
    int height = static_cast<int>(b_box.height() * static_cast<float>(frame_height));
    int width = static_cast<int>(b_box.width() * static_cast<float>(frame_width));
    int xmin = static_cast<int>(b_box.xmin() * static_cast<float>(frame_width));
    int ymin = static_cast<int>(b_box.ymin() * static_cast<float>(frame_height));

    // A negative origin is moved to 0 and the extent shrinks by the same amount.
    if (xmin < 0) {
        width += xmin;
        xmin = 0;
    }
    if (ymin < 0) {
        height += ymin;
        ymin = 0;
    }
    // An origin past the far edge leaves nothing to crop.
    if (xmin > frame_width) {
        xmin = frame_width;
    }
    if (ymin > frame_height) {
        ymin = frame_height;
    }
    // The extent is limited by what remains to the right of / below the origin,
    // not by the full frame size.
    if (width > frame_width - xmin) {
        width = frame_width - xmin;
    }
    if (height > frame_height - ymin) {
        height = frame_height - ymin;
    }
    if (width < 0) {
        width = 0;
    }
    if (height < 0) {
        height = 0;
    }

    // Gather the crop row by row into a tightly packed buffer. Source rows are
    // WidthStep() bytes apart.
    std::vector<uint8_t> pixels;
    pixels.reserve(static_cast<size_t>(height) * width * channels);
    const uint8* pixel_data = image_frame.PixelData();
    for (int y = ymin; y < ymin + height; ++y) {
        const uint8* row_begin = pixel_data + y * image_frame.WidthStep() + xmin * channels;
        pixels.insert(pixels.end(), row_begin, row_begin + width * channels);
    }

    std::unique_ptr<mediapipe::ImageFrame> cropped_image = std::make_unique<mediapipe::ImageFrame>();
    cropped_image->CopyPixelData(image_frame.Format(), width, height, pixels.data(),
            mediapipe::ImageFrame::kDefaultAlignmentBoundary);

    return cropped_image;
}

. TfLiteTensor .


(Java ) , HTTP . , "" . , ( ).


, — . :



. ( camry hummer), ford focus , .. -, (Hyundai SantaFe).


4 , , .



, , , , . , , .



, Mediapipe, , , , . real-time . UX, , . proof-of-concept, :


  • CPU, , , GPU. GPU CPU GPU;
  • , . ;
  • ;
  • ;
  • . ANN (Approximate Nearest Neighbors).

All Articles