Upload 12 files
Browse files- .gitattributes +3 -0
- convert.py +381 -0
- decoder_model.onnx +3 -0
- decoder_model.rknn +3 -0
- decoder_model_merged.onnx +3 -0
- embed_tokens.onnx +3 -0
- encoder_model.onnx +3 -0
- encoder_model.rknn +3 -0
- image.png +0 -0
- run.py +276 -0
- vision_encoder.onnx +3 -0
- vision_encoder.rknn +3 -0
- ztu_somemodelruntime_rknnlite2.py +569 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
decoder_model.rknn filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
encoder_model.rknn filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
|
convert.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from rknn.api import RKNN
|
| 6 |
+
from math import exp
|
| 7 |
+
from sys import exit
|
| 8 |
+
|
| 9 |
+
import onnx
|
| 10 |
+
import onnxscript
|
| 11 |
+
|
| 12 |
+
batch_size = 1
|
| 13 |
+
encoder_seq_len_list = [13]
|
| 14 |
+
|
| 15 |
+
decoder_seq_len = 1
|
| 16 |
+
|
| 17 |
+
# set current directory to the directory of this file
|
| 18 |
+
import os
|
| 19 |
+
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
| 20 |
+
|
| 21 |
+
import subprocess
|
| 22 |
+
import select
|
| 23 |
+
|
| 24 |
+
def run_python_code(code):
    """Run *code* in a child Python interpreter, streaming its output live.

    stdout lines are echoed as-is; stderr lines are printed with an
    ``Error: `` prefix.  Blocks until the child process exits.

    Args:
        code: Python source passed to ``python -c``.
    """
    process = subprocess.Popen(
        ['python', '-c', code],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True
    )

    try:
        # Stream output while the child is alive.
        while True:
            reads = [process.stdout.fileno(), process.stderr.fileno()]
            ready, _, _ = select.select(reads, [], [])

            for fd in ready:
                if fd == process.stdout.fileno():
                    output = process.stdout.readline()
                    if output:
                        print(output.strip())
                if fd == process.stderr.fileno():
                    err = process.stderr.readline()
                    if err:
                        print(f"Error: {err.strip()}")

            if process.poll() is not None:
                break

        # Bug fix: the child can exit with lines still buffered in the
        # pipes; drain them so trailing output is not lost.
        for line in process.stdout.readlines():
            if line.strip():
                print(line.strip())
        for line in process.stderr.readlines():
            if line.strip():
                print(f"Error: {line.strip()}")
    finally:
        # Close the pipes and reap the child to avoid fd/zombie leaks.
        process.stdout.close()
        process.stderr.close()
        process.wait()
|
| 50 |
+
|
| 51 |
+
def convert_decoder():
    """Convert decoder_model.onnx into an RKNN model for the rk3588 target.

    Exits the process (via ``exit(ret)``) if any toolkit step fails.
    """
    onnx_path = "decoder_model.onnx"
    rknn_path = onnx_path.replace(".onnx", ".rknn")
    dataset_path = "dataset.txt"
    do_quant = False

    # One dynamic-shape set per candidate encoder sequence length:
    # [inputs (batch, enc_len, 768), inputs (batch, dec_len, 768)].
    shape_sets = []
    for enc_len in encoder_seq_len_list:
        shape_sets.append([
            [batch_size, enc_len, 768],
            [batch_size, decoder_seq_len, 768],
        ])

    rknn = RKNN(verbose=True)

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3,
                dynamic_input=shape_sets)
    print('done')

    print('--> Loading model')
    ret = rknn.load_onnx(model=onnx_path)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model')
    ret = rknn.build(do_quantization=do_quant, dataset=dataset_path,
                     rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(rknn_path)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
|
| 94 |
+
|
| 95 |
+
def convert_decoder_2():
    """Freeze ``use_cache_branch`` in decoder_model_merged.onnx and convert
    the merged decoder to an RKNN model for the rk3588 target.

    Exits the process (via ``exit(ret)``) if any toolkit step fails.
    """
    import onnx_graphsurgeon as gs
    ONNX_MODEL = "decoder_model_merged.onnx"

    # Freeze the boolean `use_cache_branch` graph input to True so only the
    # KV-cache branch of the merged decoder survives.
    graph = gs.import_onnx(onnx.load(ONNX_MODEL))
    # NOTE(review): input index 27 is assumed to be `use_cache_branch`;
    # confirm this against the actual export's input list.
    inp = graph.inputs[27]  # use_cache_branch
    inp.to_constant(np.array([True], dtype=np.bool_))
    # Bug fix: removed a stray bare `ONNX_MODEL` expression statement here
    # (dead code with no effect).
    onnx.save(gs.export_onnx(graph), "new_model.onnx")

    # Constant input value kept on disk for later runtime feeding.
    np_true = np.array([True], dtype=np.bool_)
    np.save("np_true.npy", np_true)

    rknn = RKNN(verbose=True)

    RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
    DATASET = "dataset.txt"
    QUANTIZE = False

    # One dynamic-shape set per candidate encoder sequence length:
    # [batch_size, encoder_seq_len, 768] and [batch_size, decoder_seq_len, 768].
    input_shapes = [[
        [batch_size, encoder_seq_len, 768],
        [batch_size, decoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3,
                dynamic_input=input_shapes)
    print('done')

    # NOTE(review): the frozen graph was saved as "new_model.onnx" above,
    # but the ORIGINAL ONNX_MODEL is loaded here -- confirm whether
    # new_model.onnx should be converted instead.
    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         )
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
|
| 150 |
+
|
| 151 |
+
def convert_encoder():
    """Convert encoder_model.onnx into an RKNN model for the rk3588 target.

    Exits the process (via ``exit(ret)``) if any toolkit step fails.
    """
    onnx_path = "encoder_model.onnx"
    rknn_path = onnx_path.replace(".onnx", ".rknn")
    dataset_path = "dataset.txt"
    do_quant = False

    # Dynamic shapes: one (inputs_embeds, attention_mask) pair per
    # candidate encoder sequence length.
    dyn_shapes = [
        [[batch_size, n, 768], [batch_size, n]]
        for n in encoder_seq_len_list
    ]

    rknn = RKNN(verbose=True)

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3,
                dynamic_input=dyn_shapes)
    print('done')

    print('--> Loading model')
    ret = rknn.load_onnx(model=onnx_path)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model')
    ret = rknn.build(do_quantization=do_quant, dataset=dataset_path,
                     rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(rknn_path)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
|
| 189 |
+
|
| 190 |
+
def convert_vision():
    """Two-stage conversion of vision_encoder.onnx to RKNN for rk3588.

    Stage 1 builds the model once so the toolkit dumps its fused
    intermediate graph (check3_fuse_ops.onnx); that graph is rewritten to
    collapse a Transpose/Reshape chain, and stage 2 builds and exports the
    optimized model.  Exits the process on any toolkit failure.
    """
    ONNX_MODEL = "vision_encoder.onnx"
    RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
    DATASET = "dataset.txt"
    QUANTIZE = False
    # Cleanup: removed a `global batch_size` declaration (the function only
    # reads the module-level value, so the declaration had no effect) and
    # duplicated ONNX_MODEL/DATASET/QUANTIZE assignments and re-imports.

    ##### Build stage 1
    rknn = RKNN(verbose=True)

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3)
    print('done')

    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=["pixel_values"],
                         input_size_list=[[batch_size, 3, 64, 64]],
                         )
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model stage 1')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print("Build stage 1 done")
    del rknn

    # NOTE(review): check3_fuse_ops.onnx is assumed to be dumped by the
    # stage-1 build (verbose mode); the filename depends on toolkit
    # internals -- confirm it matches the installed rknn-toolkit2 version.
    intermidiate_model = onnx.load("check3_fuse_ops.onnx")

    # Fuse a Transpose-Reshape-Transpose-Reshape-Transpose chain into a
    # single Reshape-Transpose-Reshape the NPU handles better.
    from onnxscript.rewriter import pattern
    import onnx.numpy_helper as onh

    def tp_rs_tp_rs_tp_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
        # Pattern to match: five alternating Transpose/Reshape ops.
        i1 = op.Transpose(input1, perm=perm1)
        i2 = op.Reshape(i1, shape2)
        i3 = op.Transpose(i2, perm=perm3)
        i4 = op.Reshape(i3, shape4)
        i5 = op.Transpose(i4, perm=perm5)
        return i5

    def fused_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
        # Replacement: split the channel dim by 3 onto the batch axis, swap
        # the middle axes, then reshape to the fixed (N, 32, 1, 144) layout.
        rs1_shape = op.Constant(value=onh.from_array(np.array(
            [input1.shape[0] * 3, input1.shape[1] // 3, input1.shape[2], input1.shape[3]],
            dtype=np.int64)))
        fi1 = op.Reshape(input1, rs1_shape)
        fi2 = op.Transpose(fi1, perm=[0, 2, 1, 3])
        elems = input1.shape[0] * input1.shape[1] * input1.shape[2] * input1.shape[3]
        # Bug fix: use floor division -- true division produced a float that
        # was only implicitly truncated by the int64 cast.
        rs4_shape = op.Constant(value=onh.from_array(np.array(
            [elems // (32 * 144), 32, 1, 144], dtype=np.int64)))
        fi3 = op.Reshape(fi2, rs4_shape)
        return fi3

    rewrite_rule = pattern.RewriteRule(tp_rs_tp_rs_tp_pattern, fused_pattern)
    rewrite_rule_set = pattern.RewriteRuleSet([rewrite_rule], commute=True)
    fused_model = onnxscript.rewriter.rewrite(
        intermidiate_model,
        pattern_rewrite_rules=rewrite_rule_set
    )
    onnx.save(fused_model, "vision_encoder_optimized.onnx")
    ONNX_MODEL = "vision_encoder_optimized.onnx"
    # Free the (large) in-memory graphs before the second build.
    del intermidiate_model
    del fused_model

    ##### Build stage 2: convert the rewritten graph.
    rknn = RKNN(verbose=True)

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3)
    print('done')

    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model stage 2')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
    # Remove the temporary rewritten graph.
    os.remove("vision_encoder_optimized.onnx")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def check_vision_model():
    """Build vision_encoder.onnx for rk3588 and run on-target accuracy analysis.

    Exits the process (via ``exit(ret)``) if any toolkit step fails.
    """
    rknn = RKNN(verbose=True)

    ONNX_MODEL = "vision_encoder.onnx"
    RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
    DATASET = "dataset.txt"
    QUANTIZE = False
    # Bug fix: `vision_size` was referenced below but never defined anywhere
    # (guaranteed NameError).  Default to the 64x64 input used by
    # convert_vision(); NOTE(review): confirm this is the resolution that
    # should be accuracy-checked.
    vision_size = (64, 64)

    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel',
                target_platform='rk3588', optimization_level=3)
    print('done')

    print('--> Loading model')
    ret = rknn.load_onnx(model=ONNX_MODEL,
                         inputs=["pixel_values"],
                         input_size_list=[[batch_size, 3, vision_size[0], vision_size[1]]],
                         )
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    print('--> Export RKNN model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')

    # Run on the connected rk3588 target.
    print('--> Init runtime environment')
    ret = rknn.init_runtime(target='rk3588')
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    # Layer-by-layer simulator-vs-target comparison on a sample image.
    print('--> Precision check')
    ret = rknn.accuracy_analysis(inputs=["lena.png"], target='rk3588')
    if ret != 0:
        print('Precision check failed!')
        exit(ret)
    print('done')
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
import argparse
# Usage: python convert.py <decoder|encoder|vision|all> [--check]
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str, help="Model to convert")
    parser.add_argument("--check", action="store_true", help="Check model")
    args = parser.parse_args()

    target = args.model
    if target == "decoder":
        convert_decoder()
    elif target == "encoder":
        convert_encoder()
    # elif target == "embed":  # embed is faster with cpu
    #     convert_embed()
    elif target == "vision":
        # --check runs the on-target accuracy analysis instead of converting.
        if args.check:
            check_vision_model()
        else:
            convert_vision()
    elif target == "all":
        convert_decoder()
        convert_encoder()
        # convert_embed()
        convert_vision()
    else:
        print("Invalid model")
        exit(1)
|
decoder_model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b197ed07d9fe1da03dcce93b1f5ebf3cee4b66e531c9703b2087fc53ca50acb
|
| 3 |
+
size 387818953
|
decoder_model.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:414974a11e9ef72012f77c16c3da8c633ebeb351e0fa78cc2dea5737d431ff1f
|
| 3 |
+
size 194928054
|
decoder_model_merged.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b4ff6b536f773955b3355c5f19fb8436100b10705cc2e520724d6feeb733ae5
|
| 3 |
+
size 388046167
|
embed_tokens.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a51f42510df4e723a5d50f9da43fccd9e59d4c507bb9b28960d6500e778ee3b0
|
| 3 |
+
size 157560107
|
encoder_model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:934f366c4d834a58f017e5373f01df0b2b9a333533c1e41032e7de8849c61a12
|
| 3 |
+
size 173409090
|
encoder_model.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d1af501f57d5e068d47554e5c08cb0a489d24c7102d433a85fde0413ac12d2f
|
| 3 |
+
size 86759918
|
image.png
ADDED
|
|
run.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoProcessor
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
import onnxruntime as ort
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
# Use RKNN for some models
|
| 10 |
+
import ztu_somemodelruntime_rknnlite2 as rknnort
|
| 11 |
+
# Uncomment this to use ONNXRuntime for some models
|
| 12 |
+
# import onnxruntime as rknnort
|
| 13 |
+
|
| 14 |
+
# set current working directory to the directory of this file
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def run(image_path, prompt, max_new_tokens, output_image_path, temperature, seed):
    """Run the full Florence-2 pipeline (vision encoder -> text embed ->
    encoder -> decoder prefill -> autoregressive decode) on one image.

    Args:
        image_path: Path of the input image (opened via PIL, forced to RGB).
        prompt: Task prompt, e.g. "<CAPTION>"; the task tag before ">" is
            reused for post-processing.
        max_new_tokens: Upper bound on generated tokens.
        output_image_path: Currently unused in this function body.
        temperature: 0 selects greedy argmax decoding; otherwise softmax
            sampling at this temperature.
        seed: Optional RNG seed for reproducible sampling.
    """
    # set seed for reproducibility
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Accumulates wall-clock time across all stages, in milliseconds.
    total_time = 0

    # Initialize RKNNLite instances
    vision_encoder = rknnort.InferenceSession(
        "vision_encoder.onnx", providers=["CPUExecutionProvider"]
    )
    encoder = rknnort.InferenceSession(
        "encoder_model.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_prefill = rknnort.InferenceSession(
        "decoder_model.onnx", providers=["CPUExecutionProvider"]
    )

    # These two stay on plain onnxruntime (see module comments above).
    text_embed = ort.InferenceSession(
        "embed_tokens.onnx", providers=["CPUExecutionProvider"]
    )
    decoder_decode = ort.InferenceSession(
        "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
    )

    # 1. prepare inputs
    processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-base", trust_remote_code=True
    )

    # 2. prepare image
    image = Image.open(image_path).convert("RGB")
    original_image = image.copy()
    original_size = image.size
    # resize image to 64x64 (the size the converted vision encoder expects)
    image = image.resize((64, 64))
    # 3. prepare text

    inputs = processor(
        text=prompt, images=image, return_tensors="np", do_resize=False
    )  # , padding="max_length", max_length=pad_to + 577, truncation=True)
    for k, v in inputs.items():
        print(k, v.shape)
    # print(inputs)
    # 4. run vision encoder using RKNN
    start_time = time.time()
    image_features = vision_encoder.run(None, {"pixel_values": inputs["pixel_values"]})[
        0
    ]

    end_time = time.time()
    vision_encoder_time = (end_time - start_time) * 1000
    total_time += vision_encoder_time
    print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
    print(image_features.shape)
    # np.save("image_features.npy", image_features)

    # 5. run text embed using RKNN
    start_time = time.time()
    inputs_embeds = text_embed.run(None, {"input_ids": inputs["input_ids"]})[0]
    end_time = time.time()
    text_embed_time = (end_time - start_time) * 1000
    total_time += text_embed_time
    print(f"Text embed time: {text_embed_time:.2f} ms")
    print(inputs_embeds.shape)
    # print(inputs_embeds)

    # 6. concat image features and text embed
    # image_features is indexed (batch, tokens, hidden) here; the last axis
    # is dropped to size the attention masks.
    batch_size, image_token_length = image_features.shape[:-1]
    image_attention_mask = np.ones((batch_size, image_token_length))
    task_prefix_embeds = inputs_embeds
    task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
    # task_prefix_attention_mask = inputs["attention_mask"]
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
    # Image tokens come first, then the text-prompt tokens.
    inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
    attention_mask = np.concatenate(
        [image_attention_mask, task_prefix_attention_mask], axis=1
    )

    # 6. run encoder using RKNN
    start_time = time.time()
    encoder_out = encoder.run(
        None,
        {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask.astype(np.int64),
        },
    )
    end_time = time.time()
    encoder_time = (end_time - start_time) * 1000
    total_time += encoder_time
    print(f"Encoder time: {encoder_time:.2f} ms")
    encoder_hidden_states = encoder_out[0]
    print(encoder_hidden_states.shape)

    # 7. run decoder prefill stage using RKNN
    # The prefill consumes only the BOS token embedding plus the encoder
    # hidden states and emits the initial KV cache.
    start_time = time.time()
    next_token = processor.tokenizer.bos_token_id
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]
    decoder_outs = decoder_prefill.run(
        None,
        {
            "inputs_embeds": next_input_embeds,
            "encoder_hidden_states": encoder_hidden_states,
            # "encoder_attention_mask": attention_mask.astype(np.int64)
        },
    )
    end_time = time.time()
    decoder_prefill_time = (end_time - start_time) * 1000
    total_time += decoder_prefill_time
    print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
    # for output in decoder_outs:
    #     print(output.shape)

    # Everything after the logits is KV cache.  The encoder K/V never change
    # after prefill, so they are frozen here; decoder K/V are refreshed each
    # step below.  NOTE(review): the assumed output layout is
    # [logits, then per-layer (dec.key, dec.value, enc.key, enc.value)] --
    # confirm against the decoder model's output list.
    encoder_kv = decoder_outs[1:]

    # 8. run decoder decode stage(autoregressive) (using onnxruntime)
    generated_tokens = []
    decoder_decode_total_time = 0
    while generated_tokens.__len__() < max_new_tokens:
        # Grab the outputs of the previous step.
        logits = decoder_outs[0]
        decoder_kv = decoder_outs[1:]

        # Select the logits of the last generated position.
        next_token_logits = logits[:, -1, :]

        if temperature == 0:
            # Greedy decoding
            next_token = np.argmax(next_token_logits, axis=-1)[0]
        else:
            # Temperature sampling
            # Apply the temperature.
            next_token_logits /= temperature

            # Subtract the max from the logits for numerical stability.
            next_token_logits -= np.max(next_token_logits)

            # Compute softmax.
            probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))

            # Sample from the probability distribution.
            next_token = np.random.choice(len(probs[0]), p=probs[0])

        print("next_token: ", processor.decode([next_token]))
        # Append the newly generated token to the results.
        generated_tokens.append(next_token)

        # Stop once the end-of-sequence token is generated.
        if next_token == 2:  # </s>
            break

        # Prepare the input embedding for the next step.
        start_time = time.time()
        next_input_embeds = text_embed.run(
            None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
        )[0]
        end_time = time.time()
        text_embed_time = (end_time - start_time) * 1000
        decoder_decode_total_time += text_embed_time

        # Run one decode step of the merged decoder; decoder K/V come from
        # the previous step, encoder K/V stay fixed from the prefill.
        start_time = time.time()
        decoder_outs = decoder_decode.run(
            None,
            {
                "use_cache_branch": np.array([True], dtype=np.bool_),
                "inputs_embeds": next_input_embeds,
                "encoder_hidden_states": encoder_hidden_states,
                # "encoder_attention_mask": attention_mask.astype(np.int64),
                "past_key_values.0.decoder.key": decoder_kv[0],
                "past_key_values.0.decoder.value": decoder_kv[1],
                "past_key_values.0.encoder.key": encoder_kv[2],
                "past_key_values.0.encoder.value": encoder_kv[3],
                "past_key_values.1.decoder.key": decoder_kv[4],
                "past_key_values.1.decoder.value": decoder_kv[5],
                "past_key_values.1.encoder.key": encoder_kv[6],
                "past_key_values.1.encoder.value": encoder_kv[7],
                "past_key_values.2.decoder.key": decoder_kv[8],
                "past_key_values.2.decoder.value": decoder_kv[9],
                "past_key_values.2.encoder.key": encoder_kv[10],
                "past_key_values.2.encoder.value": encoder_kv[11],
                "past_key_values.3.decoder.key": decoder_kv[12],
                "past_key_values.3.decoder.value": decoder_kv[13],
                "past_key_values.3.encoder.key": encoder_kv[14],
                "past_key_values.3.encoder.value": encoder_kv[15],
                "past_key_values.4.decoder.key": decoder_kv[16],
                "past_key_values.4.decoder.value": decoder_kv[17],
                "past_key_values.4.encoder.key": encoder_kv[18],
                "past_key_values.4.encoder.value": encoder_kv[19],
                "past_key_values.5.decoder.key": decoder_kv[20],
                "past_key_values.5.decoder.value": decoder_kv[21],
                "past_key_values.5.encoder.key": encoder_kv[22],
                "past_key_values.5.encoder.value": encoder_kv[23],
            },
        )
        end_time = time.time()
        decoder_decode_time = (end_time - start_time) * 1000
        decoder_decode_total_time += decoder_decode_time

    total_time += decoder_decode_total_time
    print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")

    # Convert the generated tokens to text.
    print("generated_tokens: ", generated_tokens)
    generated_text = processor.batch_decode(
        [generated_tokens], skip_special_tokens=False
    )[0]
    print("Generated Text:", generated_text)
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt.split(">")[0].strip() + ">",
        image_size=original_size,
    )
    print("Parsed Answer:", parsed_answer)

    print(f"Total inference time: {total_time:.2f} ms")
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
    # Command-line entry point: caption a single image with the pipeline.
    cli = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    cli.add_argument("image_path", type=str, help="Path to the input image.")
    cli.add_argument(
        "--max_new_tokens", type=int, default=512,
        help="Maximum number of new tokens to generate.",
    )
    cli.add_argument(
        "--output_image_path", type=str, default="result_image.jpg",
        help="Path to save the output image with visualizations.",
    )
    cli.add_argument(
        "--temperature", type=float, default=0,
        help="Temperature for sampling. Set to 0 for greedy decoding.",
    )
    cli.add_argument(
        "--seed", type=int, default=None,
        help="Random seed for reproducibility.",
    )
    opts = cli.parse_args()

    # The task prompt is fixed to captioning.
    run(
        opts.image_path,
        "<CAPTION>",
        opts.max_new_tokens,
        opts.output_image_path,
        opts.temperature,
        opts.seed,
    )
|
vision_encoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:731e2a61276e681979ea5a6fca66da84e59877b045bf4b11299cf43f92817a2a
|
| 3 |
+
size 365965528
|
vision_encoder.rknn
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:184089ff44d372ab6e83f76ee077da9f213de9d5df9433365fe25f6b225b9183
|
| 3 |
+
size 191014658
|
ztu_somemodelruntime_rknnlite2.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 模块级常量和函数
|
| 2 |
+
from rknnlite.api import RKNNLite
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import warnings
|
| 6 |
+
import logging
|
| 7 |
+
from typing import List, Dict, Union, Optional
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
HAS_ORT = True
|
| 12 |
+
except ImportError:
|
| 13 |
+
HAS_ORT = False
|
| 14 |
+
warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
|
| 15 |
+
|
| 16 |
+
# 配置日志
|
| 17 |
+
logger = logging.getLogger("somemodelruntime_rknnlite2")
|
| 18 |
+
logger.setLevel(logging.ERROR) # 默认只输出错误信息
|
| 19 |
+
if not logger.handlers:
|
| 20 |
+
handler = logging.StreamHandler()
|
| 21 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
| 22 |
+
logger.addHandler(handler)
|
| 23 |
+
|
| 24 |
+
# ONNX Runtime日志级别到Python logging级别的映射
|
| 25 |
+
_LOGGING_LEVEL_MAP = {
|
| 26 |
+
0: logging.DEBUG, # Verbose
|
| 27 |
+
1: logging.INFO, # Info
|
| 28 |
+
2: logging.WARNING, # Warning
|
| 29 |
+
3: logging.ERROR, # Error
|
| 30 |
+
4: logging.CRITICAL # Fatal
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# 检查环境变量中的日志级别设置
|
| 34 |
+
try:
|
| 35 |
+
env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
|
| 36 |
+
if env_log_level is not None:
|
| 37 |
+
log_level = int(env_log_level)
|
| 38 |
+
if log_level in _LOGGING_LEVEL_MAP:
|
| 39 |
+
logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
|
| 40 |
+
logger.info(f"从环境变量设置日志级别: {log_level}")
|
| 41 |
+
else:
|
| 42 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
|
| 43 |
+
except ValueError:
|
| 44 |
+
logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def set_default_logger_severity(level: int) -> None:
    """Set the default logging severity.

    Levels follow the ONNX Runtime convention:
    0: Verbose, 1: Info, 2: Warning, 3: Error, 4: Fatal.

    Args:
        level: severity level (0-4).

    Raises:
        ValueError: if `level` is not in the 0-4 range.
    """
    if level in _LOGGING_LEVEL_MAP:
        logger.setLevel(_LOGGING_LEVEL_MAP[level])
    else:
        raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
|
| 57 |
+
|
| 58 |
+
def set_default_logger_verbosity(level: int) -> None:
    """Set the default logging verbosity level.

    To activate verbose output, the default severity must also be set to
    0 (Verbose). This simply delegates to set_default_logger_severity.

    Args:
        level: verbosity level (0-4).
    """
    set_default_logger_severity(level)
|
| 67 |
+
|
| 68 |
+
# Mapping from RKNN tensor type codes to numpy dtypes.
RKNN_DTYPE_MAP = {
    0: np.float32, # RKNN_TENSOR_FLOAT32
    1: np.float16, # RKNN_TENSOR_FLOAT16
    2: np.int8, # RKNN_TENSOR_INT8
    3: np.uint8, # RKNN_TENSOR_UINT8
    4: np.int16, # RKNN_TENSOR_INT16
    5: np.uint16, # RKNN_TENSOR_UINT16
    6: np.int32, # RKNN_TENSOR_INT32
    7: np.uint32, # RKNN_TENSOR_UINT32
    8: np.int64, # RKNN_TENSOR_INT64
    9: bool, # RKNN_TENSOR_BOOL
    10: np.int8, # RKNN_TENSOR_INT4 (represented as int8; numpy has no 4-bit dtype)
}
|
| 82 |
+
|
| 83 |
+
def get_available_providers() -> List[str]:
    """Return the list of available execution providers.

    Placeholder kept for ONNX Runtime interface compatibility.

    Returns:
        list: always ["CPUExecutionProvider",
        "somemodelruntime_rknnlite2_ExecutionProvider"].
    """
    providers = [
        "CPUExecutionProvider",
        "somemodelruntime_rknnlite2_ExecutionProvider",
    ]
    return providers
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def get_device() -> str:
    """Return the name of the device this runtime targets.

    Returns:
        str: always "RKNN2".
    """
    device_name = "RKNN2"
    return device_name
|
| 101 |
+
|
| 102 |
+
def get_version_info() -> Dict[str, str]:
    """Return RKNN API and driver version strings.

    Returns:
        dict: {"api_version": ..., "driver_version": ...}
    """
    # NOTE(review): a throwaway RKNNLite instance is created and never
    # released here — confirm this does not leak runtime resources.
    runtime = RKNNLite()
    version = runtime.get_sdk_version()
    # Parses the multi-line SDK version banner by fixed line/field positions;
    # assumes line index 2 holds "...: <api_ver> ..." and line index 3 holds
    # "...: <driver_ver>" — TODO confirm against the installed rknnlite banner.
    return {
        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
        "driver_version": version.split('\n')[3].split(': ')[1]
    }
|
| 115 |
+
|
| 116 |
+
class IOTensor:
    """Lightweight record describing one model input or output tensor."""

    def __init__(self, name, shape, type=None):
        # RKNN attrs report names as bytes; normalize to str.
        if isinstance(name, bytes):
            name = name.decode()
        self.name = name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
|
| 125 |
+
|
| 126 |
+
class SessionOptions:
    """Session configuration, mirroring onnxruntime.SessionOptions."""

    def __init__(self):
        # Enable RKNN performance profiling (maps to RKNNLite(verbose=...)).
        self.enable_profiling = False
        # Thread count; mapped onto the RKNN NPU core mask (1, 2 or 3).
        self.intra_op_num_threads = 1
        # Alternative log-level knobs; -1 means "leave the logger unchanged".
        self.log_severity_level = -1
        self.log_verbosity_level = -1
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class InferenceSession:
|
| 136 |
+
"""
|
| 137 |
+
RKNNLite运行时封装类,API风格类似ONNX Runtime
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
    def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """Create a session, dispatching on the resolved model type.

        If the resolved path is an .onnx file, an onnxruntime.InferenceSession
        is returned instead of an instance of this class (so __init__ below is
        never called for it); otherwise a normal instance is created with the
        resolved path stashed for __init__.
        """
        processed_path = InferenceSession._process_model_path(model_path, sess_options)
        if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
            logger.info("使用ONNX Runtime加载模型")
            if not HAS_ORT:
                raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
            # NOTE(review): sess_options here is this module's SessionOptions,
            # not an onnxruntime.SessionOptions — confirm onnxruntime accepts it.
            return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
        else:
            # Not an ONNX model: create a regular InferenceSession instance.
            instance = super().__new__(cls)
            # Stash the resolved path for __init__.
            instance._processed_path = processed_path
            return instance
|
| 153 |
+
|
| 154 |
+
    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """Initialize the runtime and load the model.

        Args:
            model_path: path to the model file (.rknn or .onnx)
            sess_options: session options
            **kwargs: extra initialization parameters (unused here)

        Raises:
            FileNotFoundError: if the model file does not exist.
            RuntimeError: if loading the model or initializing the runtime fails.
            ValueError: if intra_op_num_threads is not 1, 2 or 3.
        """
        options = sess_options or SessionOptions()

        # Only honor SessionOptions log levels when the env var is not set
        # (the environment variable takes precedence).
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)

        # Use the path resolved in __new__ (may have been rewritten to .rknn).
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # ONNX path: __new__ returned an ort session, avoid double-loading.
            return

        # --- RKNN model loading and runtime initialization ---
        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"模型文件不存在: {self.model_path}")
            raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

        # enable_profiling maps to RKNNLite's verbose mode.
        self.runtime = RKNNLite(verbose=options.enable_profiling)

        logger.debug(f"正在加载模型: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"加载RKNN模型失败: {self.model_path}")
            raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
        logger.debug("模型加载成功")

        # Map intra_op_num_threads onto an NPU core mask (1 => auto).
        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")

        logger.debug("正在初始化运行时环境")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("初始化运行时环境失败")
            raise RuntimeError('初始化运行时环境失败')
        logger.debug("运行时环境初始化成功")

        # Cache input/output tensor metadata and keep the options around.
        self._init_io_info()
        self.options = options
|
| 212 |
+
|
| 213 |
+
def get_performance_info(self) -> Dict[str, float]:
|
| 214 |
+
"""
|
| 215 |
+
获取性能信息
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
dict: 包含性能信息的字典
|
| 219 |
+
"""
|
| 220 |
+
if not self.options.perf_debug:
|
| 221 |
+
raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True")
|
| 222 |
+
|
| 223 |
+
perf = self.runtime.rknn_runtime.get_run_perf()
|
| 224 |
+
return {
|
| 225 |
+
"run_duration": perf.run_duration / 1000.0 # 转换为毫秒
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
def set_core_mask(self, core_mask: int) -> None:
|
| 229 |
+
"""
|
| 230 |
+
设置NPU核心使用模式
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
core_mask: NPU核心掩码,使用NPU_CORE_*常量
|
| 234 |
+
"""
|
| 235 |
+
ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
|
| 236 |
+
if ret != 0:
|
| 237 |
+
raise RuntimeError("设置NPU核心模式失败")
|
| 238 |
+
|
| 239 |
+
@staticmethod
|
| 240 |
+
def _process_model_path(model_path, sess_options):
|
| 241 |
+
"""
|
| 242 |
+
处理模型路径,支持.onnx和.rknn文件
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
model_path: 模型文件路径
|
| 246 |
+
"""
|
| 247 |
+
# 如果是ONNX文件,检查是否需要自动加载RKNN
|
| 248 |
+
if model_path.lower().endswith('.onnx'):
|
| 249 |
+
logger.info("检测到ONNX模型文件")
|
| 250 |
+
|
| 251 |
+
# 获取需要跳过自动加载的模型列表
|
| 252 |
+
skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
|
| 253 |
+
if skip_models:
|
| 254 |
+
skip_list = [m.strip() for m in skip_models.split(',')]
|
| 255 |
+
# 获取模型文件名(不含路径)用于匹配
|
| 256 |
+
model_name = os.path.basename(model_path)
|
| 257 |
+
if model_name.lower() in [m.lower() for m in skip_list]:
|
| 258 |
+
logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime")
|
| 259 |
+
return model_path
|
| 260 |
+
|
| 261 |
+
# 构造RKNN文件路径
|
| 262 |
+
rknn_path = os.path.splitext(model_path)[0] + '.rknn'
|
| 263 |
+
if os.path.exists(rknn_path):
|
| 264 |
+
logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}")
|
| 265 |
+
return rknn_path
|
| 266 |
+
else:
|
| 267 |
+
logger.info("未找到对应的RKNN模型,将使用ONNX Runtime")
|
| 268 |
+
return model_path
|
| 269 |
+
|
| 270 |
+
return model_path
|
| 271 |
+
|
| 272 |
+
def _convert_nhwc_to_nchw(self, shape):
|
| 273 |
+
"""将NHWC格式的shape转换为NCHW格式"""
|
| 274 |
+
if len(shape) == 4:
|
| 275 |
+
# NHWC -> NCHW
|
| 276 |
+
n, h, w, c = shape
|
| 277 |
+
return [n, c, h, w]
|
| 278 |
+
return shape
|
| 279 |
+
|
| 280 |
+
    def _init_io_info(self):
        """Query the RKNN runtime for input/output tensor metadata.

        Populates self.input_tensors and self.output_tensors with IOTensor
        records (name, shape, numpy dtype).
        """
        runtime = self.runtime.rknn_runtime

        # Number of inputs and outputs.
        n_input, n_output = runtime.get_in_out_num()

        # Input metadata.
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            # Convert 4-D input shapes from NHWC to NCHW — assumes RKNN reports
            # NHWC while callers expect ONNX-style NCHW; TODO confirm.
            shape = self._convert_nhwc_to_nchw(shape)
            # Map the RKNN type code to a numpy dtype (None if unknown).
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.input_tensors.append(tensor)

        # Output metadata.
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            # Map the RKNN type code to a numpy dtype (None if unknown).
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.output_tensors.append(tensor)
|
| 308 |
+
|
| 309 |
+
def get_inputs(self):
|
| 310 |
+
"""
|
| 311 |
+
获取模型输入信息
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
list: 包含输入信息的列表
|
| 315 |
+
"""
|
| 316 |
+
return self.input_tensors
|
| 317 |
+
|
| 318 |
+
def get_outputs(self):
|
| 319 |
+
"""
|
| 320 |
+
获取模型输出信息
|
| 321 |
+
|
| 322 |
+
Returns:
|
| 323 |
+
list: 包含输出信息的列表
|
| 324 |
+
"""
|
| 325 |
+
return self.output_tensors
|
| 326 |
+
|
| 327 |
+
def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
|
| 328 |
+
"""
|
| 329 |
+
执行模型推理
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
output_names: 输出节点名称列表,指定需要返回哪些输出
|
| 333 |
+
input_feed: 输入数据字典或列表
|
| 334 |
+
data_format: 输入数据格式,"nchw"或"nhwc"
|
| 335 |
+
**kwargs: 其他运行时参数
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
list: 模型输出结果列表,如果指定了output_names则只返回指定的输出
|
| 339 |
+
"""
|
| 340 |
+
if input_feed is None:
|
| 341 |
+
logger.error("input_feed不能为None")
|
| 342 |
+
raise ValueError("input_feed不能为None")
|
| 343 |
+
|
| 344 |
+
# 准备输入数据
|
| 345 |
+
if isinstance(input_feed, dict):
|
| 346 |
+
# 如果是字典,按照模型输入顺序排列
|
| 347 |
+
inputs = []
|
| 348 |
+
input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
|
| 349 |
+
for tensor in self.input_tensors:
|
| 350 |
+
if tensor.name not in input_feed:
|
| 351 |
+
raise ValueError(f"缺少输入: {tensor.name}")
|
| 352 |
+
inputs.append(input_feed[tensor.name])
|
| 353 |
+
elif isinstance(input_feed, (list, tuple)):
|
| 354 |
+
# 如果是列表,确保长度匹配
|
| 355 |
+
if len(input_feed) != len(self.input_tensors):
|
| 356 |
+
raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
|
| 357 |
+
inputs = list(input_feed)
|
| 358 |
+
else:
|
| 359 |
+
logger.error("input_feed必须是字典或列表类型")
|
| 360 |
+
raise ValueError("input_feed必须是字典或列表类型")
|
| 361 |
+
|
| 362 |
+
# 执行推理
|
| 363 |
+
try:
|
| 364 |
+
logger.debug("开始执行推理")
|
| 365 |
+
all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
|
| 366 |
+
|
| 367 |
+
# 如果没有指定output_names,返回所有输出
|
| 368 |
+
if output_names is None:
|
| 369 |
+
return all_outputs
|
| 370 |
+
|
| 371 |
+
# 获取指定的输出
|
| 372 |
+
output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
|
| 373 |
+
selected_outputs = []
|
| 374 |
+
for name in output_names:
|
| 375 |
+
if name not in output_map:
|
| 376 |
+
raise ValueError(f"未找到输出节点: {name}")
|
| 377 |
+
selected_outputs.append(all_outputs[output_map[name]])
|
| 378 |
+
|
| 379 |
+
return selected_outputs
|
| 380 |
+
|
| 381 |
+
except Exception as e:
|
| 382 |
+
logger.error(f"推理执行失败: {str(e)}")
|
| 383 |
+
raise RuntimeError(f"推理执行失败: {str(e)}")
|
| 384 |
+
|
| 385 |
+
def close(self):
|
| 386 |
+
"""
|
| 387 |
+
关闭会话,释放资源
|
| 388 |
+
"""
|
| 389 |
+
if self.runtime is not None:
|
| 390 |
+
logger.info("正在释放运行时资源")
|
| 391 |
+
self.runtime.release()
|
| 392 |
+
self.runtime = None
|
| 393 |
+
|
| 394 |
+
def __enter__(self):
|
| 395 |
+
return self
|
| 396 |
+
|
| 397 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 398 |
+
self.close()
|
| 399 |
+
|
| 400 |
+
def end_profiling(self) -> Optional[str]:
|
| 401 |
+
"""
|
| 402 |
+
结束性能分析的存根方法
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Optional[str]: None
|
| 406 |
+
"""
|
| 407 |
+
warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 408 |
+
return None
|
| 409 |
+
|
| 410 |
+
def get_profiling_start_time_ns(self) -> int:
|
| 411 |
+
"""
|
| 412 |
+
获取性能分析开始时间的存根方法
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
int: 0
|
| 416 |
+
"""
|
| 417 |
+
warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 418 |
+
return 0
|
| 419 |
+
|
| 420 |
+
def get_modelmeta(self) -> Dict[str, str]:
|
| 421 |
+
"""
|
| 422 |
+
获取模型元数据的存根方法
|
| 423 |
+
|
| 424 |
+
Returns:
|
| 425 |
+
Dict[str, str]: 空字典
|
| 426 |
+
"""
|
| 427 |
+
warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 428 |
+
return {}
|
| 429 |
+
|
| 430 |
+
def get_session_options(self) -> SessionOptions:
|
| 431 |
+
"""
|
| 432 |
+
获取会话选项
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
SessionOptions: 当前会话选项
|
| 436 |
+
"""
|
| 437 |
+
return self.options
|
| 438 |
+
|
| 439 |
+
def get_providers(self) -> List[str]:
|
| 440 |
+
"""
|
| 441 |
+
获取当前使用的providers的存根方法
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
List[str]: ["CPUExecutionProvider"]
|
| 445 |
+
"""
|
| 446 |
+
warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
|
| 447 |
+
return ["CPUExecutionProvider"]
|
| 448 |
+
|
| 449 |
+
def get_provider_options(self) -> Dict[str, Dict[str, str]]:
|
| 450 |
+
"""
|
| 451 |
+
获取provider选项的存根方法
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
Dict[str, Dict[str, str]]: 空字典
|
| 455 |
+
"""
|
| 456 |
+
warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 457 |
+
return {}
|
| 458 |
+
|
| 459 |
+
def get_session_config(self) -> Dict[str, str]:
|
| 460 |
+
"""
|
| 461 |
+
获取会话配置的存根方法
|
| 462 |
+
|
| 463 |
+
Returns:
|
| 464 |
+
Dict[str, str]: 空字典
|
| 465 |
+
"""
|
| 466 |
+
warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 467 |
+
return {}
|
| 468 |
+
|
| 469 |
+
def get_session_state(self) -> Dict[str, str]:
|
| 470 |
+
"""
|
| 471 |
+
获取会话状态的存根方法
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
Dict[str, str]: 空字典
|
| 475 |
+
"""
|
| 476 |
+
warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 477 |
+
return {}
|
| 478 |
+
|
| 479 |
+
def set_session_config(self, config: Dict[str, str]) -> None:
|
| 480 |
+
"""
|
| 481 |
+
设置会话配置的存根方法
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
config: 会话配置字典
|
| 485 |
+
"""
|
| 486 |
+
warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 487 |
+
|
| 488 |
+
def get_memory_info(self) -> Dict[str, int]:
|
| 489 |
+
"""
|
| 490 |
+
获取内存使用信息的存根方法
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
Dict[str, int]: 空字典
|
| 494 |
+
"""
|
| 495 |
+
warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 496 |
+
return {}
|
| 497 |
+
|
| 498 |
+
def set_memory_pattern(self, enable: bool) -> None:
|
| 499 |
+
"""
|
| 500 |
+
设置内存模式的存根方法
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
enable: 是否启用内存模式
|
| 504 |
+
"""
|
| 505 |
+
warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 506 |
+
|
| 507 |
+
def disable_memory_pattern(self) -> None:
|
| 508 |
+
"""
|
| 509 |
+
禁用内存模式的存根方法
|
| 510 |
+
"""
|
| 511 |
+
warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 512 |
+
|
| 513 |
+
def get_optimization_level(self) -> int:
|
| 514 |
+
"""
|
| 515 |
+
获取优化级别的存根方法
|
| 516 |
+
|
| 517 |
+
Returns:
|
| 518 |
+
int: 0
|
| 519 |
+
"""
|
| 520 |
+
warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 521 |
+
return 0
|
| 522 |
+
|
| 523 |
+
def set_optimization_level(self, level: int) -> None:
|
| 524 |
+
"""
|
| 525 |
+
设置优化级别的存根方法
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
level: 优化级别
|
| 529 |
+
"""
|
| 530 |
+
warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 531 |
+
|
| 532 |
+
def get_model_metadata(self) -> Dict[str, str]:
|
| 533 |
+
"""
|
| 534 |
+
获取模型元数据的存根方法(与get_modelmeta不同的接口)
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
Dict[str, str]: 空字典
|
| 538 |
+
"""
|
| 539 |
+
warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 540 |
+
return {}
|
| 541 |
+
|
| 542 |
+
def get_model_path(self) -> str:
|
| 543 |
+
"""
|
| 544 |
+
获取模型路径
|
| 545 |
+
|
| 546 |
+
Returns:
|
| 547 |
+
str: 模型文件路径
|
| 548 |
+
"""
|
| 549 |
+
return self.model_path
|
| 550 |
+
|
| 551 |
+
def get_input_type_info(self) -> List[Dict[str, str]]:
|
| 552 |
+
"""
|
| 553 |
+
获取输入类型信息的存根方法
|
| 554 |
+
|
| 555 |
+
Returns:
|
| 556 |
+
List[Dict[str, str]]: 空列表
|
| 557 |
+
"""
|
| 558 |
+
warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 559 |
+
return []
|
| 560 |
+
|
| 561 |
+
def get_output_type_info(self) -> List[Dict[str, str]]:
|
| 562 |
+
"""
|
| 563 |
+
获取输出类型信息的存根方法
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
List[Dict[str, str]]: 空列表
|
| 567 |
+
"""
|
| 568 |
+
warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
|
| 569 |
+
return []
|