|

楼主 |
发表于 2025-4-11 16:50:21
|
显示全部楼层
1. PartClassifier.hpp
#pragma once

#include <memory>
#include <string>

#include <opencv2/opencv.hpp>

/// Classifies parts in images with an ONNX model, either on a single
/// pre-cropped region or on every part candidate found in a full image.
///
/// The implementation is held through a std::shared_ptr, so copies of a
/// PartClassifier share the same underlying model session.
class PartClassifier {
public:
    /// @brief Loads the ONNX model stored at @p model_path.
    /// @param model_path  path to the .onnx file.
    /// @throws presumably Ort::Exception if the model cannot be loaded — confirm.
    explicit PartClassifier(const std::string& model_path);

    /// @brief Classifies one already-cropped region.
    /// @param roi  image of a single part (BGR expected); resized to 224x224 internally.
    /// @return one of "Front", "Back" or "Stacked".
    std::string classify(const cv::Mat& roi);

    /// @brief Finds part candidates in @p input and classifies each of them.
    ///
    /// Candidates are external contours of an inverted adaptive-mean threshold
    /// of the grayscale image. Contours with area < 500 px, or whose bounding box
    /// is narrower or shorter than 20 px, are ignored.
    ///
    /// @param input        source image (BGR expected).
    /// @param output       receives a copy of the input; annotated when @p draw_result is true.
    /// @param draw_result  draw a green box per part plus its label
    ///                     (red for "Stacked", cyan otherwise).
    void detectAndClassify(const cv::Mat& input, cv::Mat& output, bool draw_result = true);

private:
    struct Impl;
    std::shared_ptr<Impl> impl_;
};
✅ 2. PartClassifier.cpp (支持 ONNX 推理 + OpenCV ROI)
#include "PartClassifier.hpp"

#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
/// Private implementation: owns the ONNX Runtime environment and session and
/// runs inference on a single region of interest.
struct PartClassifier::Impl {
    // Fixed network input resolution; the input tensor is {1, 3, kHeight, kWidth} (NCHW).
    static constexpr int kWidth = 224;
    static constexpr int kHeight = 224;
    static constexpr int kChannels = 3;

    // Members are initialised in declaration order: `session` needs `env`.
    Ort::Env env;
    Ort::Session session;
    Ort::MemoryInfo memory_info;
    std::vector<int64_t> input_dims;
    // Tensor names of the exported graph — must match the ONNX model.
    const char* input_names[1] = {"input"};
    const char* output_names[1] = {"output"};
    // Class labels, indexed by position in the model's score vector.
    const std::vector<std::string> labels{"Front", "Back", "Stacked"};

    explicit Impl(const std::string& path)
        : env(ORT_LOGGING_LEVEL_WARNING, "part"),
          session(makeSession(env, path)),
          memory_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
          input_dims{1, kChannels, kHeight, kWidth} {}

    /// Runs the model on one region and returns the highest-scoring label.
    /// @param roi  1-channel, 3-channel (assumed BGR) or 4-channel (assumed BGRA) image.
    /// @return one of `labels`.
    /// @throws cv::Exception on an empty ROI, an unsupported channel count, or a
    ///         model output holding fewer scores than there are labels.
    std::string infer(const cv::Mat& roi) {
        CV_Assert(!roi.empty());

        // The network takes 3 channels. Convert grayscale/BGRA crops first; reading
        // them through a Vec3f view would run past the end of each row.
        cv::Mat bgr;
        switch (roi.channels()) {
        case 1:
            cv::cvtColor(roi, bgr, cv::COLOR_GRAY2BGR);
            break;
        case 3:
            bgr = roi;
            break;
        case 4:
            cv::cvtColor(roi, bgr, cv::COLOR_BGRA2BGR);
            break;
        default:
            CV_Error(cv::Error::StsBadArg, "PartClassifier: ROI must have 1, 3 or 4 channels");
        }

        // NOTE(review): pixels are scaled to [0, 1], kept in BGR order, with no
        // mean/std normalisation — confirm this matches the model's training pipeline.
        cv::Mat img;
        cv::resize(bgr, img, cv::Size(kWidth, kHeight));
        img.convertTo(img, CV_32F, 1.0 / 255);

        // HWC -> CHW: wrap each channel plane of the tensor buffer in a Mat header
        // so cv::split de-interleaves directly into the buffer.
        constexpr int plane_size = kWidth * kHeight;
        std::vector<float> input_tensor(static_cast<size_t>(kChannels) * plane_size);
        std::vector<cv::Mat> planes;
        planes.reserve(kChannels);
        for (int c = 0; c < kChannels; ++c)
            planes.emplace_back(kHeight, kWidth, CV_32FC1, input_tensor.data() + c * plane_size);
        cv::split(img, planes);
        // split() must have written into the provided headers, not reallocated them.
        for (int c = 0; c < kChannels; ++c)
            CV_Assert(planes[c].ptr<float>() == input_tensor.data() + c * plane_size);

        Ort::Value input_value = Ort::Value::CreateTensor<float>(
            memory_info, input_tensor.data(), input_tensor.size(), input_dims.data(), input_dims.size());
        auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input_value, 1, output_names, 1);

        // Refuse to read past the end of a score vector shorter than the label list.
        const size_t score_count = outputs.front().GetTensorTypeAndShapeInfo().GetElementCount();
        CV_Assert(score_count >= labels.size());
        const float* scores = outputs.front().GetTensorMutableData<float>();
        const auto best = std::max_element(scores, scores + labels.size()) - scores;
        return labels[static_cast<size_t>(best)];
    }

    /// Creates the inference session for the model at @p path.
    /// ORTCHAR_T is wchar_t on Windows and char elsewhere. std::filesystem::path::c_str()
    /// yields the native character type on each platform, so this compiles on both,
    /// which `std::string::c_str()` did not.
    static Ort::Session makeSession(Ort::Env& env, const std::string& path) {
        const std::filesystem::path model_path(path);
        const Ort::SessionOptions options;
        return Ort::Session(env, model_path.c_str(), options);
    }
};
// Builds the private implementation, whose constructor creates the ONNX
// Runtime session from model_path.
PartClassifier::PartClassifier(const std::string& model_path) {
    impl_ = std::make_shared<Impl>(model_path);
}
// Classifies a single pre-cropped region; all of the work is done by Impl::infer.
std::string PartClassifier::classify(const cv::Mat& roi) {
    Impl& engine = *impl_;
    return engine.infer(roi);
}
/// Detects part candidates in @p input and classifies each one.
///
/// Candidates are external contours of an inverted adaptive-mean threshold of
/// the grayscale image. Contours with area < 500 px, or a bounding box narrower
/// or shorter than 20 px, are skipped.
///
/// @param input        BGR (3-channel), grayscale (1-channel) or BGRA (4-channel) image.
///                     An empty input yields an empty @p output.
/// @param output       receives a 3-channel copy of the input, annotated when
///                     @p draw_result is true. May be the same Mat as @p input.
/// @param draw_result  draw a green bounding box per part plus its label
///                     (red for "Stacked", cyan otherwise).
void PartClassifier::detectAndClassify(const cv::Mat& input, cv::Mat& output, bool draw_result) {
    constexpr int kThreshBlockSize = 21;     // adaptiveThreshold neighbourhood size (odd)
    constexpr double kThreshOffset = 5;      // constant subtracted from the local mean
    constexpr double kMinContourArea = 500;  // px^2
    constexpr int kMinBoxSide = 20;          // px

    if (input.empty()) {
        output.release();
        return;
    }

    // Take headers on the source pixels *before* `output` is reassigned. If the
    // caller passes the same Mat as input and output, `input` would otherwise
    // start referring to the canvas being drawn on, and later ROI crops would
    // contain boxes and labels drawn for earlier parts.
    cv::Mat bgr;
    cv::Mat gray;
    switch (input.channels()) {
    case 1:
        gray = input;
        cv::cvtColor(input, bgr, cv::COLOR_GRAY2BGR);
        break;
    case 4:
        cv::cvtColor(input, bgr, cv::COLOR_BGRA2BGR);
        cv::cvtColor(bgr, gray, cv::COLOR_BGR2GRAY);
        break;
    default:
        bgr = input;
        cv::cvtColor(input, gray, cv::COLOR_BGR2GRAY);
        break;
    }

    cv::Mat bin;
    cv::adaptiveThreshold(gray, bin, 255, cv::ADAPTIVE_THRESH_MEAN_C,
                          cv::THRESH_BINARY_INV, kThreshBlockSize, kThreshOffset);
    std::vector<std::vector<cv::Point>> contours;
    cv::findContours(bin, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);

    output = bgr.clone();
    for (const auto& contour : contours) {
        if (cv::contourArea(contour) < kMinContourArea) continue;
        const cv::Rect bbox = cv::boundingRect(contour);
        if (bbox.width < kMinBoxSide || bbox.height < kMinBoxSide) continue;

        // No clone needed: infer() takes the ROI by const reference and resizes it into a new Mat.
        const std::string label = impl_->infer(bgr(bbox));
        if (draw_result) {
            const cv::Scalar text_color = (label == "Stacked") ? cv::Scalar(0, 0, 255)     // red (BGR)
                                                               : cv::Scalar(255, 255, 0);  // cyan (BGR)
            cv::rectangle(output, bbox, cv::Scalar(0, 255, 0), 2);
            cv::putText(output, label, bbox.tl(), cv::FONT_HERSHEY_SIMPLEX, 0.6, text_color, 2);
        }
    }
}
✅ 3. 示例程序 main.cpp
#include "PartClassifier.hpp"

#include <exception>
#include <iostream>

// Demo: loads the model, detects and classifies every part in part_test.jpg and
// shows the annotated result. Both files are opened relative to the current
// working directory.
int main() {
    try {
        PartClassifier classifier("part_classifier.onnx");
        cv::Mat image = cv::imread("part_test.jpg");
        if (image.empty()) {
            // "Failed to read the image!"
            std::cerr << "❌ 图像读取失败!" << std::endl;
            return -1;
        }
        cv::Mat result;
        classifier.detectAndClassify(image, result);
        cv::imshow("Result", result);
        cv::waitKey(0);
    } catch (const std::exception& e) {
        // Model-loading (ONNX Runtime) and OpenCV errors both derive from
        // std::exception; report them instead of letting std::terminate abort.
        std::cerr << "error: " << e.what() << '\n';
        return -1;
    }
    return 0;
}
📦 4. CMakeLists.txt
cmake_minimum_required(VERSION 3.10)
project(PartClassifierDemo LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(OpenCV REQUIRED)

# ONNX Runtime. Releases that install a CMake package config export the imported
# target onnxruntime::onnxruntime, so link that instead of relying on an
# ONNXRUNTIME_INCLUDE_DIRS variable. Prebuilt archives without a config are
# located through ONNXRUNTIME_ROOT, e.g.
#   cmake -DONNXRUNTIME_ROOT=/path/to/onnxruntime-<platform>-<version> ..
find_package(onnxruntime CONFIG QUIET)
if(NOT TARGET onnxruntime::onnxruntime)
  set(ONNXRUNTIME_ROOT "" CACHE PATH "Root of an ONNX Runtime install or extracted release archive")
  find_path(ONNXRUNTIME_INCLUDE_DIR onnxruntime_cxx_api.h
            HINTS "${ONNXRUNTIME_ROOT}"
            PATH_SUFFIXES include include/onnxruntime)
  find_library(ONNXRUNTIME_LIBRARY NAMES onnxruntime
               HINTS "${ONNXRUNTIME_ROOT}"
               PATH_SUFFIXES lib lib64)
  if(NOT ONNXRUNTIME_INCLUDE_DIR OR NOT ONNXRUNTIME_LIBRARY)
    message(FATAL_ERROR "ONNX Runtime not found: set ONNXRUNTIME_ROOT to its install directory")
  endif()
  add_library(onnxruntime::onnxruntime UNKNOWN IMPORTED)
  set_target_properties(onnxruntime::onnxruntime PROPERTIES
    IMPORTED_LOCATION "${ONNXRUNTIME_LIBRARY}"
    INTERFACE_INCLUDE_DIRECTORIES "${ONNXRUNTIME_INCLUDE_DIR}")
endif()

add_executable(demo main.cpp PartClassifier.cpp)
target_include_directories(demo PRIVATE ${OpenCV_INCLUDE_DIRS})
target_link_libraries(demo PRIVATE ${OpenCV_LIBS} onnxruntime::onnxruntime)
✅ 构建 & 运行
# -p: do not fail when build/ already exists from a previous run.
mkdir -p build && cd build
cmake ..
# Generator-agnostic (works with Makefiles, Ninja, Visual Studio, ...).
cmake --build .
# The demo opens part_classifier.onnx and part_test.jpg relative to the current
# directory, so copy both into build/ (or run the binary from their folder).
./demo
🧠 可复用 API 一览
PartClassifier classifier("model.onnx");
// 直接裁剪一个 ROI 推理
std::string label = classifier.classify(cv::Mat roi);
// 自动识别图像中多个零件并标注
classifier.detectAndClassify(cv::Mat input, cv::Mat& output); |
|