文章目录


使用 NCNN 在移动端部署深度学习模型

一、整体流程概览

(1)训练模型,使用各种你熟悉的框架我用的是 pytorch
(2)将 *.pth 转换成 onnx, 优化 onnx 模型
(3)使用转换工具转换成可供 ncnn 使用的模型
(4)编译 ncnn 框架,并编写 c 代码调用上一步转换的模型,得到模型的输出结果,封装成可供调用的类
(5)使用 JNIC
调用上一步 C++ 封装的类,提供出接口
(6)在安卓端编写 java 代码再次封装一次,供应用层调用

二、将 *.pth 转换成 onnx

使用 pytorch 自带的 torch.onnx 即可,需要 1.1 版本以上,这里有一点需要注意,torch 的 API 有些是 onnx 不支持的,如果转换的时候报错就把模型里的函数改成 onnx 支持的吧,有些文章里说这里可以设置 opset_version=12 来解决,但是这样的话在后面转换到 ncnn 或者 mnn 的时候造成转换失败,应该是 ncnn 还没支持到更高版本的 onnx 的原因。在最后输出之前有个 torch.randn () 函数,这里的参数格式是 [b,c, w,h] 这里也不是随便写的,b 固定是 1 了,你模型的输入通道是多少就写多少,后面的就是模型的输入,这里一旦固定了,后面在第 5 步的时候 c++ 里的输入也就固定了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# -*- coding:utf-8 -*-
# name: convert_onnx
# author: bqh
# datetime:2020/6/17 10:31
# =========================
import torch
def load_model(model, pretrained_path):
print('Loading pretrained model from {}'.format(pretrained_path))
pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
if "state_dict" in pretrained_dict.keys():
pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
else:
pretrained_dict = remove_prefix(pretrained_dict, 'module.')
check_keys(model, pretrained_dict)
model.load_state_dict(pretrained_dict, strict=False)
return model

output_onnx = '../weights/output.onnx'
raw_weights = '../weights/model.pth'

# load weight
net = you_net()
net = load_model(net, raw_weights)
net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)

input_names = ["input0"]
output_names = ["output0"]
inputs = torch.randn(1, 3, 300, 300).to(device)
torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
安装 onnx 简化工具
1
pip3 install onnx-simplifier onnxruntime
简化 onnx 模型

这一步一定要做,否则后面转 onnx 的时候会报错

1
python3 -m onnxsim model.onnx model_sim.onnx
三、编译 NCNN 框架

主要参考 ncnn 官网的教程即可,windows 下编译同上一篇的 MNN 的编译都差不多,只有一点需要说明,官网的教程上有 vulkan-sdk 的安装然后打开 - DNCNN_VULKAN=ON 编译选项。我一切照做后编译出来的 ncnn.lib 在运行 ncnn::Extractor ex = Net->create_extractor (); 这个函数后的所有操作之后,返回的时候就报堆栈溢出错误,包括加载官网给出的例子全部报错;后来不 cmake 的时候这个编译选项不打开编译出来的 ncnn.lib 就一切正常了。可能是自己的问题,也没去深究。反正能用就 OK 了。我把编译出来的 ncnn.lib ncnn.a 和 linux 下的 onnx2ncnn 工具都放在了我的网盘里,不想被编译折磨的就直接去下吧。如果编译遇见问题,也可以给我留言,哈哈~
说明:ncnnd.lib 是 windows 下的 debug 版本,ncnn.lib 是 release 版本,libncnn.a 是 linux 下的库文件,onnx2ncnn 是 linux 下的转换工具。
下载地址:NCNN 提取码:6cuc

四、C++ 调用和封装
说明

对于 vs 中 lib 库和 include 目录的配置就不赘述了,有不懂的之前的文章有提过,假定工程已经配置完成。大体的调用过程 NCNN 和 MNN 都差不多,先加载模型创建一个指向模型的指针,然后创建 session、创建用于处理输入的 tensor,将 input_tensor 送入 session,运行 session,最后得到网络的输出。如果对 C++ 比较熟悉的话,看着官网的教程比葫芦画瓢即可,只有一个地方需要说明就是对输出的获得。先看下我的代码和官网的代码再说为什么
我的输出

1
2
3
4
// run net and get output
ncnn::Mat out, out1;
ret = ex.extract("output0", out);
ex.extract("376", out1);

官网的例子输出

1
2
ncnn::Mat out;
ex.extract("detection_out", out);

辣么问题来了,我的 "output0" 和 "376"、官网的 “detection_out” 都哪里来的?有两个地方可以得到,最简单的方法,使用 MNN 框架下的转换工具,在转换完成的时候会给出模型的输入和输出名称,直接拷贝即可

1
2
3
4
5
6
7
8
9
10
11
>MNNConvert.exe -f ONNX --modelFile model.onnx --MNNModel slime.mnn  --bizCode biz

MNNConverter Version: 0.2.1.5git - MNN @ 2018

Start to Convert Other Model Format To MNN Model...
[17:49:58] :29: ONNX Model ir version: 6
Start to Optimize the MNN Net...
[17:49:58] :20: Inputs: input0
[17:49:58] :37: Outputs: output0, Type = Concat
[17:49:58] :37: Outputs: 376, Type = Softmax
Converted Done!

如果没有 MNN 的转换工具,在后面加载模型后单步跟一下,在 Net = new ncnn::Net () 变量中有个 blob 变量,在内存中查看一下,里面存的有模型的各个层的名称。代码中的 img_w,img_h 就是在第二步转换的时候你指定的 w,h。这里只写了核心调用函数,具体使用时还请自行添加一些辅助函数!

C++ 代码

detection.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#pragma once

#include <opencv2/opencv.hpp>
#include <string>
#include <stack>
#include "net.h"
#include <stdio.h>
#include <algorithm>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/opencv.hpp>
#include <fstream>
#include "omp.h"


struct bbox {
float x1;
float y1;
float x2;
float y2;
float s;
};

struct box {
float cx;
float cy;
float sx;
float sy;
};

struct ObjectInfo {
float x1; //bbox的left
float y1; //bbox的top
float x2; //bbox的right
float y2; //bbox的bottom
float prob; //置信度
};


class ObjectDetection
{
private:
float _nms = 0.4;
float _threshold = 0.6;
const float mean_vals[3] = { 104.f, 117.f, 123.f };
const float norm_vals[3] = { 1.0 / 104.0, 1.0 / 117.0, 1.0 / 123.0 };
cv::Mat img;
ncnn::Net *Net;
int img_w = 300;
int img_h = 300;
int numThread;
int detect_count = 0;
static inline bool cmp(bbox a, bbox b);
public:
ObjectDetection(std::string modelFolder, int num_thread);
~ObjectDetection();
int Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj);
};

detection.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include "Detection.h"
#include <cmath>

ObjectDetection::ObjectDetection(std::string modelFolder, int num_thread)
{
Net = new ncnn::Net();
std::string model_param = modelFolder + "Detect.param";
std::string model_bin = modelFolder + "Detect.bin";
int ret = Net->load_param(model_param.c_str());
ret = Net->load_model(model_bin.c_str());
numThread = num_thread;
}

ObjectDetection::~ObjectDetection()
{
if (Net != nullptr)
{
delete Net;
Net = nullptr;
}
}

int ObjectDetection::Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj)
{
int ret = -1;
ncnn::Mat in = ncnn::Mat::from_pixels_resize(inputImage, ncnn::Mat::PIXEL_BGR, inputw, inputh, img_w, img_h);
in.substract_mean_normalize(mean_vals, norm_vals);
ncnn::Extractor ex = Net->create_extractor();
ex.set_light_mode(true);
ret = ex.input("input0", in);
// run net and get output
ncnn::Mat out, out1;
// bbox的输出
ret = ex.extract("output0", out);
ex.extract("376", out1);

// get result
for (int i = 0; i < out.h; ++i)
{
// 得到网络的具体输出
const float *boxes = out.row(i);
const float *scores = out1.row(i);
// 执行你自己的操作
}
std::sort(total_box.begin(), total_box.end(), cmp);
NMS(total_box, _nms);
return 0;
}
五、 编写 JNI C++

在 Android Studio 中配置 NDK,具体配置网上有很多教程我就不啰嗦了,假定 android strdio 的 jni c 环境已经配置完成。源码中的函数名的格式是 jni c 要求的,必须这种格式,根据实际情况修改,函数名中的 "com_example_demokit_Detection" 对应到 java 的应用中就是 "com.example.demokit.Detection" 这样就很好理解了。
native-lib.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <jni.h>
#include <string>

#include "Detection.h"
#include <android/log.h>

extern "C" JNIEXPORT jlong JNICALL
Java_com_example_demokit_Detection_Create(JNIEnv *env, jobject instance, jstring path) {
char* _path;
_path = (char*)env->GetStringUTFChars(path,0);
Detection *phandle = new Detection(_path, 2);
return (jlong)phandle;
}

extern "C" JNIEXPORT jintArray JNICALL
Java_com_example_demokit_Detection_Detect(JNIEnv *env, jobject instance, jlong handle, jint campos, jint w, jint h, jbyteArray data_) {
Detection *gp = NULL;
if (handle)
gp = (Detection *)handle;
else
return nullptr;
jbyte *data = env->GetByteArrayElements(data_, NULL);

std::vector<ObjectInfo> objects;
gp->Detect((unsigned char*)data, w, h, objects);
env->ReleaseByteArrayElements(data_, data, 0);
jintArray jarr = env->NewIntArray(objects.size()*15+1);
jint *arr = env->GetIntArrayElements(jarr, NULL);
arr[0] = objects.size();
for (int i = 0; i < objects.size(); i++)
{
arr[i*5 + 1] = objects[i].x1;
arr[i*5 + 2] = objects[i].y1;
arr[i*5 + 3] = objects[i].x2;
arr[i*5 + 4] = objects[i].y2;
arr[i*5 + 5] = objects[i].prob;
}
env->ReleaseIntArrayElements(jarr, arr, 0);
return jarr;
}
六、java 调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package com.example.demokit;

public class Detection {
static {
System.loadLibrary("native-lib");
}
private long handle;
public Detection(String path){
handle = Create(path);
}
public int[] Detect(int w, int h, byte[] data){
return Detect(handle, w, h, data);
}
private native long Create(String path);
private native int[] Detect(long handle, int w, int h, byte[] data);
}
七、 应用层使用

在应用层就可以直接调用上面的 java 类啦,搞定~

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package com.example.demokit;
import android.graphics.Point;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DetectTool {
private Detection mDetection;
private static final int DATA_LENGTH = 5; // 矩形框坐标2个,每个具有x,y两个值;置信度1个;

public DetectTool(String dect_model_dir){
/**
* @dect_model_dir: 检测模型所在的目录路径
*/
mDetection = new Detection(dect_model_dir);
}

private ObjectInfo ArrayAnalysis(int[] src_array){
/**
* 对输入的数组进行解析,返回ObjectInfo对象
* @src_array: 具有DATA_LENGTH所示结构的数组
*/
ObjectInfo obj_info = new ObjectInfo();

Point[] pointFaceBox = new Point[2];
// face_bbox 坐标
for(int i = 0; i < 2; i++) {
Point point = new Point();
point.x = src_array[2*i];
point.y = src_array[2*i+1];
pointFaceBox[i] = point;
}
// 置信度
obj_info.setProb(src_array[4]);
return obj_info ;
}

public List<ObjectInfo> GetObjectInfo(int width, int height, byte[] data){
/**
* @width:图片宽度
* @height:图片高度
* @data:图片的字节流
*/
int[] obj= mDetection.Detect(width, height, data);
List<ObjectInfo> obj_list = new ArrayList<>();
int obj_count = obj[0];
for(int i = 0; i < obj_count ; i++){
int[] obj_array = Arrays.copyOfRange(obj, i*DATA_LENGTH + 1, (i + 1) * DATA_LENGTH+1);
ObjectInfo obj_info = this.ArrayAnalysis(obj_array);
obj_list.add(obj_info);
}
return obj_list;
}
}

————————————————
版权声明:本文为 CSDN 博主「zzubqh103」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_36810544/article/details/106911025

× 请我吃糖~
打赏二维码