add onnxruntime inference (no post-processing)

This commit is contained in:
liangyuxuan
2025-12-18 18:52:49 +08:00
parent 4f8ae43239
commit 44b2cc43e0
3 changed files with 129 additions and 4 deletions

View File

@@ -0,0 +1,33 @@
#pragma once

#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>

namespace inference {

/// Thin wrapper around an ONNX Runtime session for a segmentation model.
/// Performs preprocessing and raw inference only — no post-processing yet.
class OnnxruntimeInterface {
public:
    /// Loads the model and preallocates the input tensor.
    /// @param model_path        filesystem path to the .onnx model
    /// @param input_tensor_size element count of the input tensor
    ///                          (defaults to 1x3x640x640 floats)
    OnnxruntimeInterface(const std::string &model_path, const int input_tensor_size = 1 * 3 * 640 * 640);

    /// Preprocesses img IN PLACE (resize to 640x640, convert to float in
    /// [0,1], repack HWC->CHW into the input buffer) and runs inference.
    void forward(cv::Mat &img);

private:
    // Declaration order matters: env must be constructed before session,
    // and both input_tensor and input_tensor_values must outlive any Run().
    Ort::Env env;
    Ort::Session session;
    Ort::SessionOptions session_options;
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::MemoryInfo mem_info;
    Ort::Value input_tensor;
    // Host-side backing buffer that input_tensor points into.
    std::vector<float> input_tensor_values;
    // Raw views into the outputs of the most recent Run().
    // NOTE(review): these point into Ort::Value buffers that are local to
    // segementation_interface — they dangle after it returns. Copy the data
    // out (or keep the output tensors alive as members) before using them.
    float *detections_data, *prototypes_data;
    std::vector<int64_t> detections_shape, prototypes_shape;

    // Runs the session and captures output pointers/shapes.
    // (sic: "segementation" typo kept — renaming would break callers.)
    void segementation_interface(const cv::Mat &input_image);
};

}

View File

@@ -54,8 +54,9 @@ int main() {
);
// 5、获取输出名称
Ort::AllocatedStringPtr output_name_ptr = session.GetOutputNameAllocated(0, allocator);
const char* output_name = output_name_ptr.get();
Ort::AllocatedStringPtr detection_name = session.GetOutputNameAllocated(0, allocator);
Ort::AllocatedStringPtr prototype_name = session.GetOutputNameAllocated(1, allocator);
const char* output_name[] = {detection_name.get(), prototype_name.get()};
// 6、推理
auto output_tensors = session.Run(
@@ -63,14 +64,17 @@ int main() {
&input_name,
&input_tensor,
1,
&output_name,
1
output_name,
2
);
// 7、输出信息
float* output_data = output_tensors[0].GetTensorMutableData<float>();
auto out_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
float* output_data_1 = output_tensors[1].GetTensorMutableData<float>();
auto out_shape_1 = output_tensors[1].GetTensorTypeAndShapeInfo().GetShape();
std::cout << "Output shape: [";
for (size_t i = 0; i < out_shape.size(); ++i) {
std::cout << out_shape[i];

View File

@@ -0,0 +1,88 @@
#include "vision_test/ultralytics/segementation_inference_onnxruntime.hpp"
// Loads the ONNX model and binds a preallocated host buffer to the fixed
// 1x3x640x640 input tensor.
inference::OnnxruntimeInterface::OnnxruntimeInterface(
    const std::string &model_path,
    const int input_tensor_size
) : env(ORT_LOGGING_LEVEL_WARNING, "segmentation_model"),
    // Initializer list follows member DECLARATION order (env, session,
    // session_options, mem_info): members are always initialized in
    // declaration order, so listing them that way avoids -Wreorder and
    // makes the real initialization sequence obvious.
    session(nullptr),
    session_options(),
    mem_info(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault))
{
    // CreateTensor below stores a pointer into this buffer, so it must stay
    // alive as long as input_tensor — both are members, so lifetimes match.
    input_tensor_values.resize(input_tensor_size);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    session = Ort::Session(env, model_path.c_str(), session_options);
    // NOTE(review): assumes input_tensor_size == 1*3*640*640; a caller
    // passing a different size would disagree with this shape — TODO confirm.
    std::vector<int64_t> input_shape = {1, 3, 640, 640};
    // The shape is copied by ORT at creation time, so a local vector is fine.
    input_tensor = Ort::Value::CreateTensor<float>(
        mem_info,
        input_tensor_values.data(),
        input_tensor_values.size(),
        input_shape.data(),
        input_shape.size()
    );
}
// Preprocesses img IN PLACE and runs the model on it.
// Steps: resize to the 640x640 network input, scale pixels to [0,1] floats,
// then repack from OpenCV's interleaved HWC layout into the planar CHW
// layout that the ONNX input tensor expects.
void inference::OnnxruntimeInterface::forward(cv::Mat &img) {
    cv::resize(img, img, cv::Size(640, 640));
    img.convertTo(img, CV_32F, 1.0 / 255.0);

    // Split the interleaved image into three single-channel planes and copy
    // each plane into its slot of the preallocated CHW buffer. Equivalent to
    // input_tensor_values[c*640*640 + h*640 + w] = img(h, w)[c].
    std::vector<cv::Mat> planes(3);
    cv::split(img, planes);
    const int plane_elems = 640 * 640;
    for (int c = 0; c < 3; ++c) {
        // Wrap the destination slot as a Mat header (no allocation) and let
        // OpenCV do the bulk copy.
        cv::Mat dst(640, 640, CV_32F, input_tensor_values.data() + c * plane_elems);
        planes[c].copyTo(dst);
    }

    // Hand off to the inference step.
    segementation_interface(img);
}
// Runs the session on the already-filled input_tensor and records raw
// pointers/shapes for the two model outputs (detections and mask prototypes).
void inference::OnnxruntimeInterface::segementation_interface(
    const cv::Mat &input_image
) {
    // The tensor data was already packed into input_tensor_values by
    // forward(); the image itself is unused here (kept in the signature for
    // future post-processing such as mask rescaling).
    (void)input_image;

    // 4. Input name — use the member allocator; a function-local
    //    AllocatorWithDefaultOptions would shadow it to no benefit.
    Ort::AllocatedStringPtr input_name_ptr = session.GetInputNameAllocated(0, allocator);
    const char* input_name = input_name_ptr.get();

    // 5. Output names: index 0 = detections, index 1 = mask prototypes.
    //    The AllocatedStringPtr objects must stay alive across Run(), since
    //    output_name holds raw views into them.
    Ort::AllocatedStringPtr detection_name = session.GetOutputNameAllocated(0, allocator);
    Ort::AllocatedStringPtr prototype_name = session.GetOutputNameAllocated(1, allocator);
    const char* output_name[] = {detection_name.get(), prototype_name.get()};

    // 6. Inference: one input tensor, two output tensors.
    auto output_tensors = session.Run(
        Ort::RunOptions{nullptr},
        &input_name,
        &input_tensor,
        1,
        output_name,
        2
    );

    // 7. Capture raw output pointers and shapes.
    // WARNING(review): output_tensors owns these buffers and is destroyed
    // when this function returns, so detections_data / prototypes_data
    // dangle immediately afterwards. Copy the data out here, or promote
    // output_tensors to a member, before any caller dereferences them.
    detections_data = output_tensors[0].GetTensorMutableData<float>();
    detections_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
    prototypes_data = output_tensors[1].GetTensorMutableData<float>();
    prototypes_shape = output_tensors[1].GetTensorTypeAndShapeInfo().GetShape();
}
// int main() {
// cv::Mat img = cv::imread("/home/lyx/ROS2/hivecore_part_test/src/hivecore_robot_vision/test/color_image.png");
// if (img.empty()) {
// std::cerr << "Failed to read image!" << std::endl;
// return -1;
// }
// inference::OnnxruntimeInterface model("/home/lyx/ROS2/hivecore_part_test/src/hivecore_robot_vision/vision_test/checkpoints/yolo11n-seg.onnx");
// model.forward(img);
// return 0;
// }