ONNX Model Construction and Checking

Reference blog: https://zhuanlan.zhihu.com/p/516920606

1 Construct ValueInfoProto objects that describe tensor information
import onnx 
from onnx import helper
from onnx import TensorProto

a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])
2 Construct operator nodes (NodeProto)
mul = helper.make_node('Mul', ['a', 'x'], ['c']) 
add = helper.make_node('Add', ['c', 'b'], ['output'])
3 Construct the computation graph (GraphProto)
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output]) 
4 Wrap the computation graph in a model

helper.make_model wraps the computation graph (GraphProto) into a model (ModelProto).

model = helper.make_model(graph) 
5 Check, print, and save the model
onnx.checker.check_model(model) 
print(model)
onnx.save(model, 'linear_func.onnx')

Complete code for building the model with the ONNX Python API

import onnx 
from onnx import helper
from onnx import TensorProto

# input and output
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10])
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10])

# Mul
mul = helper.make_node('Mul', ['a', 'x'], ['c'])

# Add
add = helper.make_node('Add', ['c', 'b'], ['output'])

# graph and model
graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output])
model = helper.make_model(graph)

# check, print, and save the model
onnx.checker.check_model(model)
print(model)
onnx.save(model, 'linear_func.onnx')