文章目录
使用 NCNN 在移动端部署深度学习模型
一、整体流程概览
(1)训练模型,使用各种你熟悉的框架我用的是 pytorch
(2)将 *.pth
转换成 onnx, 优化 onnx 模型
(3)使用转换工具转换成可供 ncnn 使用的模型
(4)编译 ncnn 框架,并编写 c 代码调用上一步转换的模型,得到模型的输出结果,封装成可供调用的类
(5)使用 JNIC 调用上一步 C++ 封装的类,提供出接口
(6)在安卓端编写 java 代码再次封装一次,供应用层调用
二、将 *.pth
转换成 onnx
使用 pytorch 自带的 torch.onnx 即可,需要 1.1 版本以上,这里有一点需要注意,torch 的 API 有些是 onnx 不支持的,如果转换的时候报错就把模型里的函数改成 onnx 支持的吧,有些文章里说这里可以设置 opset_version=12 来解决,但是这样的话在后面转换到 ncnn 或者 mnn 的时候造成转换失败,应该是 ncnn 还没支持到更高版本的 onnx 的原因。在最后输出之前有个 torch.randn () 函数,这里的参数格式是 [b,c, w,h] 这里也不是随便写的,b 固定是 1 了,你模型的输入通道是多少就写多少,后面的就是模型的输入,这里一旦固定了,后面在第 5 步的时候 c++ 里的输入也就固定了
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 import torchdef load_model (model, pretrained_path ): print ('Loading pretrained model from {}' .format (pretrained_path)) pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage) if "state_dict" in pretrained_dict.keys(): pretrained_dict = remove_prefix(pretrained_dict['state_dict' ], 'module.' ) else : pretrained_dict = remove_prefix(pretrained_dict, 'module.' ) check_keys(model, pretrained_dict) model.load_state_dict(pretrained_dict, strict=False ) return model output_onnx = '../weights/output.onnx' raw_weights = '../weights/model.pth' net = you_net() net = load_model(net, raw_weights) net.eval () print ('Finished loading model!' )device = torch.device("cuda" ) net = net.to(device) input_names = ["input0" ] output_names = ["output0" ] inputs = torch.randn(1 , 3 , 300 , 300 ).to(device) torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True , verbose=False ,keep_initializers_as_inputs=True , input_names=input_names, output_names=output_names)
安装 onnx 简化工具
1 pip3 install onnx-simplifier onnxruntime
简化 onnx 模型
这一步一定要做,否则后面转 onnx 的时候会报错
1 python3 -m onnxsim model.onnx model_sim.onnx
三、编译 NCNN 框架
主要参考 ncnn 官网的教程即可,windows 下编译同上一篇的 MNN 的编译都差不多,只有一点需要说明,官网的教程上有 vulkan-sdk 的安装然后打开 - DNCNN_VULKAN=ON 编译选项。我一切照做后编译出来的 ncnn.lib 在运行 ncnn::Extractor ex = Net->create_extractor (); 这个函数后的所有操作之后,返回的时候就报堆栈溢出错误,包括加载官网给出的例子全部报错;后来不 cmake 的时候这个编译选项不打开编译出来的 ncnn.lib 就一切正常了。可能是自己的问题,也没去深究。反正能用就 OK 了。我把编译出来的 ncnn.lib ncnn.a 和 linux 下的 onnx2ncnn 工具都放在了我的网盘里,不想被编译折磨的就直接去下吧。如果编译遇见问题,也可以给我留言,哈哈~
说明:ncnnd.lib 是 windows 下的 debug 版本,ncnn.lib 是 release 版本,libncnn.a 是 linux 下的库文件,onnx2ncnn 是 linux 下的转换工具。
下载地址:NCNN 提取码:6cuc
四、C++ 调用和封装
说明
对于 vs 中 lib 库和 include 目录的配置就不赘述了,有不懂的之前的文章有提过,假定工程已经配置完成。大体的调用过程 NCNN 和 MNN 都差不多,先加载模型创建一个指向模型的指针,然后创建 session、创建用于处理输入的 tensor,将 input_tensor 送入 session,运行 session,最后得到网络的输出。如果对 C++ 比较熟悉的话,看着官网的教程比葫芦画瓢即可,只有一个地方需要说明就是对输出的获得。先看下我的代码和官网的代码再说为什么
我的输出 :
1 2 3 4 // run net and get output ncnn::Mat out, out1; ret = ex.extract("output0" , out); ex.extract("376" , out1);
官网的例子输出 :
1 2 ncnn::Mat out; ex.extract("detection_out" , out);
辣么问题来了,我的 "output0" 和 "376"、官网的 “detection_out” 都哪里来的?有两个地方可以得到,最简单的方法,使用 MNN 框架下的转换工具,在转换完成的时候会给出模型的输入和输出名称,直接拷贝即可
1 2 3 4 5 6 7 8 9 10 11 >MNNConvert.exe -f ONNX --modelFile model.onnx --MNNModel slime.mnn --bizCode biz MNNConverter Version: 0.2.1.5git - MNN @ 2018 Start to Convert Other Model Format To MNN Model... [17:49:58] :29: ONNX Model ir version: 6 Start to Optimize the MNN Net... [17:49:58] :20: Inputs: input0 [17:49:58] :37: Outputs: output0, Type = Concat [17:49:58] :37: Outputs: 376, Type = Softmax Converted Done!
如果没有 MNN 的转换工具,在后面加载模型后单步跟一下,在 Net = new ncnn::Net () 变量中有个 blob 变量,在内存中查看一下,里面存的有模型的各个层的名称。代码中的 img_w,img_h 就是在第二步转换的时候你指定的 w,h。这里只写了核心调用函数,具体使用时还请自行添加一些辅助函数!
C++ 代码
detection.h
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 struct bbox { float x1; float y1; float x2; float y2; float s; }; struct box { float cx; float cy; float sx; float sy; }; struct ObjectInfo { float x1; //bbox的left float y1; //bbox的top float x2; //bbox的right float y2; //bbox的bottom float prob; //置信度 }; class ObjectDetection { private: float _nms = 0.4 ; float _threshold = 0.6 ; const float mean_vals[3 ] = { 104. f, 117. f, 123. f }; const float norm_vals[3 ] = { 1.0 / 104.0 , 1.0 / 117.0 , 1.0 / 123.0 }; cv::Mat img; ncnn::Net *Net; int img_w = 300 ; int img_h = 300 ; int numThread; int detect_count = 0 ; static inline bool cmp(bbox a, bbox b); public: ObjectDetection(std::string modelFolder, int num_thread); ~ObjectDetection(); int Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj); };
detection.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 ObjectDetection::ObjectDetection(std::string modelFolder, int num_thread) { Net = new ncnn::Net(); std::string model_param = modelFolder + "Detect.param" ; std::string model_bin = modelFolder + "Detect.bin" ; int ret = Net->load_param(model_param.c_str()); ret = Net->load_model(model_bin.c_str()); numThread = num_thread; } ObjectDetection::~ObjectDetection() { if (Net != nullptr) { delete Net; Net = nullptr; } } int ObjectDetection::Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj){ int ret = -1 ; ncnn::Mat in = ncnn::Mat::from_pixels_resize(inputImage, ncnn::Mat::PIXEL_BGR, inputw, inputh, img_w, img_h); in .substract_mean_normalize(mean_vals, norm_vals); ncnn::Extractor ex = Net->create_extractor(); ex.set_light_mode(true); ret = ex.input ("input0" , in ); // run net and get output ncnn::Mat out, out1; // bbox的输出 ret = ex.extract("output0" , out); ex.extract("376" , out1); // get result for (int i = 0 ; i < out.h; ++i) { // 得到网络的具体输出 const float *boxes = out.row(i); const float *scores = out1.row(i); // 执行你自己的操作 } std::sort(total_box.begin(), total_box.end(), cmp); NMS(total_box, _nms); return 0 ; }
五、 编写 JNI C++
在 Android Studio 中配置 NDK,具体配置网上有很多教程我就不啰嗦了,假定 android strdio 的 jni c 环境已经配置完成。源码中的函数名的格式是 jni c 要求的,必须这种格式,根据实际情况修改,函数名中的 "com_example_demokit_Detection" 对应到 java 的应用中就是 "com.example.demokit.Detection" 这样就很好理解了。
native-lib.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 extern "C" JNIEXPORT jlong JNICALL Java_com_example_demokit_Detection_Create(JNIEnv *env, jobject instance, jstring path) { char* _path; _path = (char*)env->GetStringUTFChars(path,0); Detection *phandle = new Detection(_path, 2 ); return (jlong)phandle; } extern "C" JNIEXPORT jintArray JNICALL Java_com_example_demokit_Detection_Detect(JNIEnv *env, jobject instance, jlong handle, jint campos, jint w, jint h, jbyteArray data_) { Detection *gp = NULL; if (handle) gp = (Detection *)handle; else return nullptr; jbyte *data = env->GetByteArrayElements(data_, NULL); std::vector<ObjectInfo> objects; gp->Detect((unsigned char*)data, w, h, objects); env->ReleaseByteArrayElements(data_, data, 0); jintArray jarr = env->NewIntArray(objects.size()*15+1); jint *arr = env->GetIntArrayElements(jarr, NULL); arr[0 ] = objects.size(); for (int i = 0 ; i < objects.size(); i++) { arr[i*5 + 1 ] = objects[i].x1; arr[i*5 + 2 ] = objects[i].y1; arr[i*5 + 3 ] = objects[i].x2; arr[i*5 + 4 ] = objects[i].y2; arr[i*5 + 5 ] = objects[i].prob; } env->ReleaseIntArrayElements(jarr, arr, 0); return jarr; }
六、java 调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 package com.example.demokit; public class Detection { static { System.loadLibrary("native-lib" ); } private long handle; public Detection(String path){ handle = Create(path); } public int [] Detect(int w, int h, byte[] data){ return Detect(handle, w, h, data); } private native long Create(String path); private native int [] Detect(long handle, int w, int h, byte[] data); }
七、 应用层使用
在应用层就可以直接调用上面的 java 类啦,搞定~
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 package com.example.demokit; import android.graphics.Point;import java.util.ArrayList;import java.util.Arrays;import java.util.List ;public class DetectTool { private Detection mDetection; private static final int DATA_LENGTH = 5 ; // 矩形框坐标2 个,每个具有x,y两个值;置信度1 个; public DetectTool(String dect_model_dir){ /** * @dect_model_dir: 检测模型所在的目录路径 */ mDetection = new Detection(dect_model_dir); } private ObjectInfo ArrayAnalysis(int [] src_array){ /** * 对输入的数组进行解析,返回ObjectInfo对象 * @src_array: 具有DATA_LENGTH所示结构的数组 */ ObjectInfo obj_info = new ObjectInfo(); Point[] pointFaceBox = new Point[2 ]; // face_bbox 坐标 for (int i = 0 ; i < 2 ; i++) { Point point = new Point(); point.x = src_array[2 *i]; point.y = src_array[2 *i+1 ]; pointFaceBox[i] = point; } // 置信度 obj_info.setProb(src_array[4 ]); return obj_info ; } public List <ObjectInfo> GetObjectInfo(int width, int height, byte[] data){ /** * @width:图片宽度 * @height:图片高度 * @data:图片的字节流 */ int [] obj= mDetection.Detect(width, height, data); List <ObjectInfo> obj_list = new ArrayList<>(); int obj_count = obj[0 ]; for (int i = 0 ; i < obj_count ; i++){ int [] obj_array = Arrays.copyOfRange(obj, i*DATA_LENGTH + 1 , (i + 1 ) * DATA_LENGTH+1 ); ObjectInfo obj_info = this.ArrayAnalysis(obj_array); obj_list.add(obj_info); } return obj_list; } }
————————————————
版权声明:本文为 CSDN 博主「zzubqh103」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_36810544/article/details/106911025