happyme531 committed
Commit 69499b6 · verified · 1 Parent(s): 3981667

Upload 12 files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ decoder_model.rknn filter=lfs diff=lfs merge=lfs -text
+ encoder_model.rknn filter=lfs diff=lfs merge=lfs -text
+ vision_encoder.rknn filter=lfs diff=lfs merge=lfs -text
convert.py ADDED
@@ -0,0 +1,381 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import numpy as np
+ from rknn.api import RKNN
+ from sys import exit
+
+ import onnx
+ import onnxscript
+
+ batch_size = 1
+ encoder_seq_len_list = [13]
+
+ decoder_seq_len = 1
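+ # NOTE: rknn-toolkit2 builds models for the fixed shape combinations registered
+ # via dynamic_input below, so every supported encoder sequence length must be
+ # listed here; the decoder is exported for single-token (seq_len = 1) steps.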
+
+ # Set the current working directory to the directory of this file
+ import os
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+ import subprocess
+ import select
+
+ def run_python_code(code):
+     # Start a subprocess to execute the given code
+     process = subprocess.Popen(
+         ['python', '-c', code],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+         text=True
+     )
+
+     # Stream the subprocess's stdout and stderr in real time
+     while True:
+         reads = [process.stdout.fileno(), process.stderr.fileno()]
+         ret = select.select(reads, [], [])
+
+         for fd in ret[0]:
+             if fd == process.stdout.fileno():
+                 output = process.stdout.readline()
+                 if output:
+                     print(output.strip())
+             if fd == process.stderr.fileno():
+                 err = process.stderr.readline()
+                 if err:
+                     print(f"Error: {err.strip()}")
+
+         if process.poll() is not None:
+             break
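+
+ # NOTE: this helper is never called below; it appears to be kept for running
+ # individual conversions in isolated subprocesses.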
+
+ def convert_decoder():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL = "decoder_model.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     DATASET = "dataset.txt"
+     QUANTIZE = False
+
+     # One dynamic-shape combination per encoder sequence length:
+     # [[batch_size, encoder_seq_len, 768], [batch_size, decoder_seq_len, 768]]
+     input_shapes = [[
+         [batch_size, encoder_seq_len, 768],
+         [batch_size, decoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588',
+                 optimization_level=3, dynamic_input=input_shapes)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL)
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+ def convert_decoder_2():
+     import onnx_graphsurgeon as gs
+     ONNX_MODEL = "decoder_model_merged.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+
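+     # Freeze the `use_cache_branch` selector to a constant True so the exported
+     # graph always takes the KV-cache path; input index 27 is specific to this
+     # decoder_model_merged.onnx export.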
+     graph = gs.import_onnx(onnx.load(ONNX_MODEL))
+     inp = graph.inputs[27]  # use_cache_branch
+     inp.to_constant(np.array([True], dtype=np.bool_))
+     onnx.save(gs.export_onnx(graph), "new_model.onnx")
+     ONNX_MODEL = "new_model.onnx"  # convert the patched graph below
+
+     np_true = np.array([True], dtype=np.bool_)
+     np.save("np_true.npy", np_true)
+
+     rknn = RKNN(verbose=True)
+
+     DATASET = "dataset.txt"
+     QUANTIZE = False
+
+     # One dynamic-shape combination per encoder sequence length:
+     # [[batch_size, encoder_seq_len, 768], [batch_size, decoder_seq_len, 768]]
+     input_shapes = [[
+         [batch_size, encoder_seq_len, 768],
+         [batch_size, decoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588',
+                 optimization_level=3, dynamic_input=input_shapes)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL)
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+ def convert_encoder():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL = "encoder_model.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     DATASET = "dataset.txt"
+     QUANTIZE = False
+
+     # [[batch_size, encoder_seq_len, 768], [batch_size, encoder_seq_len]] per sequence length
+     input_shapes = [[[batch_size, encoder_seq_len, 768], [batch_size, encoder_seq_len]] for encoder_seq_len in encoder_seq_len_list]
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588',
+                 optimization_level=3, dynamic_input=input_shapes)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL)
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+ def convert_vision():
+     global batch_size
+
+     ONNX_MODEL = "vision_encoder.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     DATASET = "dataset.txt"
+     QUANTIZE = False
+
+     ##### Build stage 1
+     rknn = RKNN(verbose=True)
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL,
+                          inputs=["pixel_values"],
+                          input_size_list=[[batch_size, 3, 64, 64]],
+                          )
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     print('--> Building model stage 1')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     print("Build stage 1 done")
+     del rknn
+
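+     # "check3_fuse_ops.onnx" is an intermediate graph that rknn-toolkit2 dumps
+     # into the working directory while building; the script relies on that side
+     # effect and rewrites the dumped graph before building again.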
+     intermediate_model = onnx.load("check3_fuse_ops.onnx")
+
+     # Fuse ops: collapse the Transpose-Reshape-Transpose-Reshape-Transpose chain
+     from onnxscript.rewriter import pattern
+     import onnx.numpy_helper as onh
+     import numpy as np
+
+     def tp_rs_tp_rs_tp_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+         i1 = op.Transpose(input1, perm=perm1)
+         i2 = op.Reshape(i1, shape2)
+         i3 = op.Transpose(i2, perm=perm3)
+         i4 = op.Reshape(i3, shape4)
+         i5 = op.Transpose(i4, perm=perm5)
+         return i5
+
+     def fused_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+         rs1_shape = op.Constant(value=onh.from_array(np.array([input1.shape[0] * 3, input1.shape[1] // 3, input1.shape[2], input1.shape[3]], dtype=np.int64)))
+         fi1 = op.Reshape(input1, rs1_shape)
+         fi2 = op.Transpose(fi1, perm=[0, 2, 1, 3])
+         elems = input1.shape[0] * input1.shape[1] * input1.shape[2] * input1.shape[3]
+         rs4_shape = op.Constant(value=onh.from_array(np.array([elems // (32 * 144), 32, 1, 144], dtype=np.int64)))
+         fi3 = op.Reshape(fi2, rs4_shape)
+         return fi3
+
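+     # NOTE: fused_pattern hardcodes constants (the x3 channel split and the
+     # 32 x 144 trailing shape) that appear specific to this vision encoder
+     # export; the rewrite is not shape-generic.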
+     rewrite_rule = pattern.RewriteRule(tp_rs_tp_rs_tp_pattern, fused_pattern)
+     rewrite_rule_set = pattern.RewriteRuleSet([rewrite_rule], commute=True)
+     fused_model = onnxscript.rewriter.rewrite(
+         intermediate_model,
+         pattern_rewrite_rules=rewrite_rule_set
+     )
+     onnx.save(fused_model, "vision_encoder_optimized.onnx")
+     ONNX_MODEL = "vision_encoder_optimized.onnx"
+     # RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     del intermediate_model
+     del fused_model
+
+     ##### Build stage 2: convert the rewritten ONNX graph
+     rknn = RKNN(verbose=True)
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL)
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model stage 2')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+     os.remove("vision_encoder_optimized.onnx")
+
+
+ def check_vision_model():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL = "vision_encoder.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     DATASET = "dataset.txt"
+     QUANTIZE = False
+     vision_size = (64, 64)  # assumed: the same 64x64 input size used in convert_vision()
+
+     # Pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL,
+                          inputs=["pixel_values"],
+                          input_size_list=[[batch_size, 3, vision_size[0], vision_size[1]]],
+                          )
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+     # Init runtime
+     print('--> Init runtime environment')
+     ret = rknn.init_runtime(target='rk3588')
+     if ret != 0:
+         print('Init runtime environment failed!')
+         exit(ret)
+     print('done')
+
+     # Precision check
+     print('--> Precision check')
+     ret = rknn.accuracy_analysis(inputs=["lena.png"], target='rk3588')
+     if ret != 0:
+         print('Precision check failed!')
+         exit(ret)
+     print('done')
+
+
+ import argparse
+
+ # Usage: python convert.py <decoder|encoder|vision|all> [--check]
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("model", type=str, help="Model to convert")
+     parser.add_argument("--check", action="store_true", help="Check model")
+     args = parser.parse_args()
+     if args.model == "decoder":
+         convert_decoder()
+     elif args.model == "encoder":
+         convert_encoder()
+     # elif args.model == "embed":  # embed is faster on CPU
+     #     convert_embed()
+     elif args.model == "vision":
+         if args.check:
+             check_vision_model()
+         else:
+             convert_vision()
+     elif args.model == "all":
+         convert_decoder()
+         convert_encoder()
+         # convert_embed()
+         convert_vision()
+     else:
+         print("Invalid model")
+         exit(1)
decoder_model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b197ed07d9fe1da03dcce93b1f5ebf3cee4b66e531c9703b2087fc53ca50acb
+ size 387818953
decoder_model.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:414974a11e9ef72012f77c16c3da8c633ebeb351e0fa78cc2dea5737d431ff1f
+ size 194928054
decoder_model_merged.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b4ff6b536f773955b3355c5f19fb8436100b10705cc2e520724d6feeb733ae5
+ size 388046167
embed_tokens.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a51f42510df4e723a5d50f9da43fccd9e59d4c507bb9b28960d6500e778ee3b0
+ size 157560107
encoder_model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:934f366c4d834a58f017e5373f01df0b2b9a333533c1e41032e7de8849c61a12
+ size 173409090
encoder_model.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d1af501f57d5e068d47554e5c08cb0a489d24c7102d433a85fde0413ac12d2f
+ size 86759918
image.png ADDED
run.py ADDED
@@ -0,0 +1,276 @@
+ from transformers import AutoProcessor
+ from PIL import Image
+ import numpy as np
+ import onnxruntime as ort
+ import time
+ import argparse
+ import random
+
+ # Use RKNN for some models
+ import ztu_somemodelruntime_rknnlite2 as rknnort
+ # Uncomment this to use ONNXRuntime for some models
+ # import onnxruntime as rknnort
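+ # (The shim resolves "model.onnx" to a sibling "model.rknn" when one exists and
+ # runs it on the NPU via RKNNLite; otherwise it falls back to onnxruntime.)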
+
+ # Set the current working directory to the directory of this file
+ import os
+
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+
+ def run(image_path, prompt, max_new_tokens, output_image_path, temperature, seed):
+     # Set seed for reproducibility
+     if seed is not None:
+         random.seed(seed)
+         np.random.seed(seed)
+
+     # Initialize the total inference time counter
+     total_time = 0
+
+     # Initialize RKNNLite-backed sessions (NPU)
+     vision_encoder = rknnort.InferenceSession(
+         "vision_encoder.onnx", providers=["CPUExecutionProvider"]
+     )
+     encoder = rknnort.InferenceSession(
+         "encoder_model.onnx", providers=["CPUExecutionProvider"]
+     )
+     decoder_prefill = rknnort.InferenceSession(
+         "decoder_model.onnx", providers=["CPUExecutionProvider"]
+     )
+
+     # These two stay on onnxruntime (CPU)
+     text_embed = ort.InferenceSession(
+         "embed_tokens.onnx", providers=["CPUExecutionProvider"]
+     )
+     decoder_decode = ort.InferenceSession(
+         "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
+     )
+
+     # 1. Prepare the processor
+     processor = AutoProcessor.from_pretrained(
+         "microsoft/Florence-2-base", trust_remote_code=True
+     )
+
+     # 2. Prepare the image
+     image = Image.open(image_path).convert("RGB")
+     original_image = image.copy()
+     original_size = image.size
+     # Resize the image to the 64x64 input the converted vision encoder expects
+     image = image.resize((64, 64))
+
+     # 3. Prepare the text
+     inputs = processor(
+         text=prompt, images=image, return_tensors="np", do_resize=False
+     )  # , padding="max_length", max_length=pad_to + 577, truncation=True)
+     for k, v in inputs.items():
+         print(k, v.shape)
+     # print(inputs)
+     # 4. Run the vision encoder (RKNN)
+     start_time = time.time()
+     image_features = vision_encoder.run(None, {"pixel_values": inputs["pixel_values"]})[0]
+     end_time = time.time()
+     vision_encoder_time = (end_time - start_time) * 1000
+     total_time += vision_encoder_time
+     print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
+     print(image_features.shape)
+     # np.save("image_features.npy", image_features)
+
+     # 5. Run the text embedding (onnxruntime)
+     start_time = time.time()
+     inputs_embeds = text_embed.run(None, {"input_ids": inputs["input_ids"]})[0]
+     end_time = time.time()
+     text_embed_time = (end_time - start_time) * 1000
+     total_time += text_embed_time
+     print(f"Text embed time: {text_embed_time:.2f} ms")
+     print(inputs_embeds.shape)
+     # print(inputs_embeds)
+
+     # 6. Concatenate image features and text embeddings
+     batch_size, image_token_length = image_features.shape[:-1]
+     image_attention_mask = np.ones((batch_size, image_token_length))
+     task_prefix_embeds = inputs_embeds
+     task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
+     # task_prefix_attention_mask = inputs["attention_mask"]
+     if len(task_prefix_attention_mask.shape) == 3:
+         task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+     inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
+     attention_mask = np.concatenate(
+         [image_attention_mask, task_prefix_attention_mask], axis=1
+     )
+
+     # 7. Run the encoder (RKNN)
+     start_time = time.time()
+     encoder_out = encoder.run(
+         None,
+         {
+             "inputs_embeds": inputs_embeds,
+             "attention_mask": attention_mask.astype(np.int64),
+         },
+     )
+     end_time = time.time()
+     encoder_time = (end_time - start_time) * 1000
+     total_time += encoder_time
+     print(f"Encoder time: {encoder_time:.2f} ms")
+     encoder_hidden_states = encoder_out[0]
+     print(encoder_hidden_states.shape)
+
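+     # The concatenated sequence length (image tokens + prompt tokens) must match
+     # one of the encoder_seq_len values the RKNN models were built for
+     # (encoder_seq_len_list = [13] in convert.py).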
+     # 8. Run the decoder prefill stage (RKNN)
+     start_time = time.time()
+     next_token = processor.tokenizer.bos_token_id
+     next_input_embeds = text_embed.run(None, {
+         "input_ids": np.array([[next_token]], dtype=np.int64)
+     })[0]
+     decoder_outs = decoder_prefill.run(
+         None,
+         {
+             "inputs_embeds": next_input_embeds,
+             "encoder_hidden_states": encoder_hidden_states,
+             # "encoder_attention_mask": attention_mask.astype(np.int64)
+         },
+     )
+     end_time = time.time()
+     decoder_prefill_time = (end_time - start_time) * 1000
+     total_time += decoder_prefill_time
+     print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
+     # for output in decoder_outs:
+     #     print(output.shape)
+
+     encoder_kv = decoder_outs[1:]
+
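+     # decoder_outs is [logits] followed by four KV tensors per layer, in the
+     # order decoder.key, decoder.value, encoder.key, encoder.value. The
+     # cross-attention (encoder.*) entries are computed once at prefill and
+     # reused unchanged on every decode step below.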
+     # 9. Run the decoder decode stage (autoregressive, onnxruntime)
+     generated_tokens = []
+     decoder_decode_total_time = 0
+     while len(generated_tokens) < max_new_tokens:
+         # Outputs from the previous step
+         logits = decoder_outs[0]
+         decoder_kv = decoder_outs[1:]
+
+         # Take the logits of the last token
+         next_token_logits = logits[:, -1, :]
+
+         if temperature == 0:
+             # Greedy decoding
+             next_token = np.argmax(next_token_logits, axis=-1)[0]
+         else:
+             # Temperature sampling
+             # Apply the temperature
+             next_token_logits /= temperature
+
+             # Subtract the max logit for numerical stability
+             next_token_logits -= np.max(next_token_logits)
+
+             # Softmax
+             probs = np.exp(next_token_logits) / np.sum(np.exp(next_token_logits))
+
+             # Sample from the probability distribution
+             next_token = np.random.choice(len(probs[0]), p=probs[0])
+
+         print("next_token: ", processor.decode([next_token]))
+         # Append the newly generated token to the result
+         generated_tokens.append(next_token)
+
+         # Stop once the end-of-sequence token is generated
+         if next_token == 2:  # </s>
+             break
+
+         # Prepare the input for the next step
+         start_time = time.time()
+         next_input_embeds = text_embed.run(
+             None, {"input_ids": np.array([[next_token]], dtype=np.int64)}
+         )[0]
+         end_time = time.time()
+         text_embed_time = (end_time - start_time) * 1000
+         decoder_decode_total_time += text_embed_time
+
+         # Run the decoder decode step with the cached KV tensors
+         start_time = time.time()
+         decoder_outs = decoder_decode.run(
+             None,
+             {
+                 "use_cache_branch": np.array([True], dtype=np.bool_),
+                 "inputs_embeds": next_input_embeds,
+                 "encoder_hidden_states": encoder_hidden_states,
+                 # "encoder_attention_mask": attention_mask.astype(np.int64),
+                 "past_key_values.0.decoder.key": decoder_kv[0],
+                 "past_key_values.0.decoder.value": decoder_kv[1],
+                 "past_key_values.0.encoder.key": encoder_kv[2],
+                 "past_key_values.0.encoder.value": encoder_kv[3],
+                 "past_key_values.1.decoder.key": decoder_kv[4],
+                 "past_key_values.1.decoder.value": decoder_kv[5],
+                 "past_key_values.1.encoder.key": encoder_kv[6],
+                 "past_key_values.1.encoder.value": encoder_kv[7],
+                 "past_key_values.2.decoder.key": decoder_kv[8],
+                 "past_key_values.2.decoder.value": decoder_kv[9],
+                 "past_key_values.2.encoder.key": encoder_kv[10],
+                 "past_key_values.2.encoder.value": encoder_kv[11],
+                 "past_key_values.3.decoder.key": decoder_kv[12],
+                 "past_key_values.3.decoder.value": decoder_kv[13],
+                 "past_key_values.3.encoder.key": encoder_kv[14],
+                 "past_key_values.3.encoder.value": encoder_kv[15],
+                 "past_key_values.4.decoder.key": decoder_kv[16],
+                 "past_key_values.4.decoder.value": decoder_kv[17],
+                 "past_key_values.4.encoder.key": encoder_kv[18],
+                 "past_key_values.4.encoder.value": encoder_kv[19],
+                 "past_key_values.5.decoder.key": decoder_kv[20],
+                 "past_key_values.5.decoder.value": decoder_kv[21],
+                 "past_key_values.5.encoder.key": encoder_kv[22],
+                 "past_key_values.5.encoder.value": encoder_kv[23],
+             },
+         )
+         end_time = time.time()
+         decoder_decode_time = (end_time - start_time) * 1000
+         decoder_decode_total_time += decoder_decode_time
+
+     total_time += decoder_decode_total_time
+     print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
+
+     # Convert the generated tokens to text
+     print("generated_tokens: ", generated_tokens)
+     generated_text = processor.batch_decode(
+         [generated_tokens], skip_special_tokens=False
+     )[0]
+     print("Generated Text:", generated_text)
+     parsed_answer = processor.post_process_generation(
+         generated_text,
+         task=prompt.split(">")[0].strip() + ">",
+         image_size=original_size,
+     )
+     print("Parsed Answer:", parsed_answer)
+
+     print(f"Total inference time: {total_time:.2f} ms")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+     parser.add_argument("image_path", type=str, help="Path to the input image.")
+     parser.add_argument(
+         "--max_new_tokens",
+         type=int,
+         default=512,
+         help="Maximum number of new tokens to generate.",
+     )
+     parser.add_argument(
+         "--output_image_path",
+         type=str,
+         default="result_image.jpg",
+         help="Path to save the output image with visualizations.",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0,
+         help="Temperature for sampling. Set to 0 for greedy decoding.",
+     )
+     parser.add_argument(
+         "--seed", type=int, default=None, help="Random seed for reproducibility."
+     )
+     args = parser.parse_args()
+     run(
+         args.image_path,
+         "<CAPTION>",
+         args.max_new_tokens,
+         args.output_image_path,
+         args.temperature,
+         args.seed,
+     )
vision_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:731e2a61276e681979ea5a6fca66da84e59877b045bf4b11299cf43f92817a2a
+ size 365965528
vision_encoder.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:184089ff44d372ab6e83f76ee077da9f213de9d5df9433365fe25f6b225b9183
+ size 191014658
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,569 @@
+ # Module-level constants and functions
+ from rknnlite.api import RKNNLite
+ import numpy as np
+ import os
+ import warnings
+ import logging
+ from typing import List, Dict, Union, Optional
+
+ try:
+     import onnxruntime as ort
+     HAS_ORT = True
+ except ImportError:
+     HAS_ORT = False
+     warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)
+
+ # Logging configuration
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
+ logger.setLevel(logging.ERROR)  # only errors by default
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logger.addHandler(handler)
+
+ # Mapping from ONNX Runtime log severity levels to Python logging levels
+ _LOGGING_LEVEL_MAP = {
+     0: logging.DEBUG,    # Verbose
+     1: logging.INFO,     # Info
+     2: logging.WARNING,  # Warning
+     3: logging.ERROR,    # Error
+     4: logging.CRITICAL  # Fatal
+ }
+
+ # Pick up the log level from the environment, if set
+ try:
+     env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
+     if env_log_level is not None:
+         log_level = int(env_log_level)
+         if log_level in _LOGGING_LEVEL_MAP:
+             logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
+             logger.info(f"Log level set from environment variable: {log_level}")
+         else:
+             logger.warning(f"Invalid ZTU_MODELRT_RKNNL2_LOG_LEVEL value: {log_level}, expected an integer between 0 and 4")
+ except ValueError:
+     logger.warning(f"Invalid ZTU_MODELRT_RKNNL2_LOG_LEVEL value: {env_log_level}, expected an integer between 0 and 4")
+
+
+ def set_default_logger_severity(level: int) -> None:
+     """
+     Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+
+     Args:
+         level: log level (0-4)
+     """
+     if level not in _LOGGING_LEVEL_MAP:
+         raise ValueError(f"Invalid log level: {level}, expected an integer between 0 and 4")
+     logger.setLevel(_LOGGING_LEVEL_MAP[level])
+
+ def set_default_logger_verbosity(level: int) -> None:
+     """
+     Sets the default logging verbosity level. To activate the verbose log,
+     you need to set the default logging severity to 0:Verbose level.
+
+     Args:
+         level: log level (0-4)
+     """
+     set_default_logger_severity(level)
+
+ # Mapping from RKNN tensor types to numpy dtypes
+ RKNN_DTYPE_MAP = {
+     0: np.float32,  # RKNN_TENSOR_FLOAT32
+     1: np.float16,  # RKNN_TENSOR_FLOAT16
+     2: np.int8,     # RKNN_TENSOR_INT8
+     3: np.uint8,    # RKNN_TENSOR_UINT8
+     4: np.int16,    # RKNN_TENSOR_INT16
+     5: np.uint16,   # RKNN_TENSOR_UINT16
+     6: np.int32,    # RKNN_TENSOR_INT32
+     7: np.uint32,   # RKNN_TENSOR_UINT32
+     8: np.int64,    # RKNN_TENSOR_INT64
+     9: bool,        # RKNN_TENSOR_BOOL
+     10: np.int8,    # RKNN_TENSOR_INT4 (represented as int8)
+ }
+
+ def get_available_providers() -> List[str]:
+     """
+     Get the list of available providers (a placeholder kept for interface compatibility)
+
+     Returns:
+         list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
+     """
+     return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
+
+
+ def get_device() -> str:
+     """
+     Get the current device
+
+     Returns:
+         str: the current device
+     """
+     return "RKNN2"
+
+ def get_version_info() -> Dict[str, str]:
+     """
+     Get version information
+
+     Returns:
+         dict: API and driver version information
+     """
+     runtime = RKNNLite()
+     version = runtime.get_sdk_version()
+     return {
+         "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
+         "driver_version": version.split('\n')[3].split(': ')[1]
+     }
+
+ class IOTensor:
+     """Wrapper for input/output tensor information"""
+     def __init__(self, name, shape, type=None):
+         self.name = name.decode() if isinstance(name, bytes) else name
+         self.shape = shape
+         self.type = type
+
+     def __str__(self):
+         return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
+
+ class SessionOptions:
+     """Session options"""
+     def __init__(self):
+         self.enable_profiling = False    # whether to enable profiling
+         self.intra_op_num_threads = 1    # number of NPU cores to use; maps to RKNNLite's core_mask
+         self.log_severity_level = -1     # alternative way to set the log level
+         self.log_verbosity_level = -1    # alternative way to set the log level
+
+
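+ # Typical usage (mirrors onnxruntime):
+ #     sess = InferenceSession("model.onnx")  # transparently loads model.rknn if present
+ #     outputs = sess.run(None, {"input": data})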
+ class InferenceSession:
+     """
+     RKNNLite runtime wrapper with an ONNX Runtime-style API
+     """
+
+     def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         processed_path = InferenceSession._process_model_path(model_path, sess_options)
+         if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
+             logger.info("Loading the model with ONNX Runtime")
+             if not HAS_ORT:
+                 raise RuntimeError("onnxruntime is not installed; cannot load ONNX models")
+             return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
+         else:
+             # Not an ONNX model: create an InferenceSession instance via the parent __new__
+             instance = super().__new__(cls)
+             # Remember the resolved path
+             instance._processed_path = processed_path
+             return instance
+
+     def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
+         """
+         Initialize the runtime and load the model
+
+         Args:
+             model_path: path to the model file (.rknn or .onnx)
+             sess_options: session options
+             **kwargs: other initialization arguments
+         """
+         options = sess_options or SessionOptions()
+
+         # Only use the log level from SessionOptions when the environment variable is not set
+         if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
+             if options.log_severity_level != -1:
+                 set_default_logger_severity(options.log_severity_level)
+             if options.log_verbosity_level != -1:
+                 set_default_logger_verbosity(options.log_verbosity_level)
+
+         # Use the path resolved in __new__
+         model_path = getattr(self, '_processed_path', model_path)
+         if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
+             # ONNX models are fully handled in __new__; avoid loading twice
+             return
+
+         # RKNN model loading and initialization
+         self.model_path = model_path
+         if not os.path.exists(self.model_path):
+             logger.error(f"Model file does not exist: {self.model_path}")
+             raise FileNotFoundError(f"Model file does not exist: {self.model_path}")
+
+         self.runtime = RKNNLite(verbose=options.enable_profiling)
+
+         logger.debug(f"Loading model: {self.model_path}")
+         ret = self.runtime.load_rknn(self.model_path)
+         if ret != 0:
+             logger.error(f"Failed to load RKNN model: {self.model_path}")
+             raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+         logger.debug("Model loaded successfully")
+
+         if options.intra_op_num_threads == 1:
+             core_mask = RKNNLite.NPU_CORE_AUTO
+         elif options.intra_op_num_threads == 2:
+             core_mask = RKNNLite.NPU_CORE_0_1
+         elif options.intra_op_num_threads == 3:
+             core_mask = RKNNLite.NPU_CORE_0_1_2
+         else:
+             raise ValueError(f"Invalid intra_op_num_threads value: {options.intra_op_num_threads}, must be 1, 2, or 3")
+
+         logger.debug("Initializing the runtime environment")
+         ret = self.runtime.init_runtime(core_mask=core_mask)
+         if ret != 0:
+             logger.error("Failed to initialize the runtime environment")
+             raise RuntimeError('Failed to initialize the runtime environment')
+         logger.debug("Runtime environment initialized successfully")
+
+         self._init_io_info()
+         self.options = options
+
+     def get_performance_info(self) -> Dict[str, float]:
+         """
+         Get performance information
+
+         Returns:
+             dict: performance information
+         """
+         if not self.options.enable_profiling:
+             raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")
+
+         perf = self.runtime.rknn_runtime.get_run_perf()
+         return {
+             "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+         }
+
+     def set_core_mask(self, core_mask: int) -> None:
+         """
+         Set the NPU core usage mode
+
+         Args:
+             core_mask: NPU core mask; use the NPU_CORE_* constants
+         """
+         ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+         if ret != 0:
+             raise RuntimeError("Failed to set the NPU core mode")
+
+     @staticmethod
+     def _process_model_path(model_path, sess_options):
+         """
+         Resolve the model path; supports both .onnx and .rknn files
+
+         Args:
+             model_path: path to the model file
+         """
+         # For ONNX files, check whether a matching RKNN model should be loaded instead
+         if model_path.lower().endswith('.onnx'):
+             logger.info("Detected an ONNX model file")
+
+             # Models listed in this environment variable skip the automatic RKNN substitution
+             skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
+             if skip_models:
+                 skip_list = [m.strip() for m in skip_models.split(',')]
+                 # Match on the file name only (without the path)
+                 model_name = os.path.basename(model_path)
+                 if model_name.lower() in [m.lower() for m in skip_list]:
+                     logger.info(f"Model {model_name} is in the skip list; using ONNX Runtime")
+                     return model_path
+
+             # Construct the RKNN file path
+             rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+             if os.path.exists(rknn_path):
+                 logger.info(f"Found a matching RKNN model; using it instead: {rknn_path}")
+                 return rknn_path
+             else:
+                 logger.info("No matching RKNN model found; using ONNX Runtime")
+                 return model_path
+
+         return model_path
+
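+     # RKNNLite reports 4-D tensor shapes in NHWC; converting them to NCHW keeps
+     # get_inputs() consistent with the layout of the original ONNX model.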
+     def _convert_nhwc_to_nchw(self, shape):
+         """Convert an NHWC-format shape to NCHW format"""
+         if len(shape) == 4:
+             # NHWC -> NCHW
+             n, h, w, c = shape
+             return [n, c, h, w]
+         return shape
+
+     def _init_io_info(self):
+         """Initialize the model's input/output metadata"""
+         runtime = self.runtime.rknn_runtime
+
+         # Number of inputs and outputs
+         n_input, n_output = runtime.get_in_out_num()
+
+         # Input info
+         self.input_tensors = []
+         for i in range(n_input):
+             attr = runtime.get_tensor_attr(i)
+             shape = [attr.dims[j] for j in range(attr.n_dims)]
+             # Convert 4-D inputs from NHWC to NCHW
+             shape = self._convert_nhwc_to_nchw(shape)
+             # Resolve the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.input_tensors.append(tensor)
+
+         # Output info
+         self.output_tensors = []
+         for i in range(n_output):
+             attr = runtime.get_tensor_attr(i, is_output=True)
+             shape = runtime.get_output_shape(i)
+             # Resolve the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.output_tensors.append(tensor)
+
+     def get_inputs(self):
+         """
+         Get the model's input metadata
+
+         Returns:
+             list: input tensor descriptions
+         """
+         return self.input_tensors
+
+     def get_outputs(self):
+         """
+         Get the model's output metadata
+
+         Returns:
+             list: output tensor descriptions
+         """
+         return self.output_tensors
+
+     def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+         """
+         Run model inference
+
+         Args:
+             output_names: list of output node names; selects which outputs to return
+             input_feed: input data, as a dict or a list
+             data_format: input data format, "nchw" or "nhwc"
+             **kwargs: other runtime arguments
+
+         Returns:
+             list: model outputs; only the named outputs when output_names is given
+         """
+         if input_feed is None:
+             logger.error("input_feed must not be None")
+             raise ValueError("input_feed must not be None")
+
+         # Prepare the input data
+         if isinstance(input_feed, dict):
+             # Dict: arrange the values in the model's input order
+             inputs = []
+             for tensor in self.input_tensors:
+                 if tensor.name not in input_feed:
+                     raise ValueError(f"Missing input: {tensor.name}")
+                 inputs.append(input_feed[tensor.name])
+         elif isinstance(input_feed, (list, tuple)):
+             # List: make sure the length matches
+             if len(input_feed) != len(self.input_tensors):
+                 raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+             inputs = list(input_feed)
+         else:
+             logger.error("input_feed must be a dict or a list")
+             raise ValueError("input_feed must be a dict or a list")
+
+         # Run inference
+         try:
+             logger.debug("Starting inference")
+             all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+
+             # Return all outputs when output_names is not specified
+             if output_names is None:
+                 return all_outputs
+
+             # Select the requested outputs
+             output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+             selected_outputs = []
+             for name in output_names:
+                 if name not in output_map:
+                     raise ValueError(f"Output node not found: {name}")
+                 selected_outputs.append(all_outputs[output_map[name]])
+
+             return selected_outputs
+
+         except Exception as e:
+             logger.error(f"Inference failed: {str(e)}")
+             raise RuntimeError(f"Inference failed: {str(e)}")
+
+     def close(self):
+         """
+         Close the session and release resources
+         """
+         if self.runtime is not None:
+             logger.info("Releasing runtime resources")
+             self.runtime.release()
+             self.runtime = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def end_profiling(self) -> Optional[str]:
+         """
+         Stub for ending profiling
+
+         Returns:
+             Optional[str]: None
+         """
+         warnings.warn("end_profiling() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return None
+
+     def get_profiling_start_time_ns(self) -> int:
+         """
+         Stub for getting the profiling start time
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_profiling_start_time_ns() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def get_modelmeta(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata
+
+         Returns:
+             Dict[str, str]: an empty dict
+         """
+         warnings.warn("get_modelmeta() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_options(self) -> SessionOptions:
+         """
+         Get the session options
+
+         Returns:
+             SessionOptions: the current session options
+         """
+         return self.options
+
+     def get_providers(self) -> List[str]:
+         """
+         Stub for getting the providers in use
+
+         Returns:
+             List[str]: ["CPUExecutionProvider"]
+         """
+         warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+         return ["CPUExecutionProvider"]
+
+     def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+         """
+         Stub for getting provider options
+
+         Returns:
+             Dict[str, Dict[str, str]]: an empty dict
+         """
+         warnings.warn("get_provider_options() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_config(self) -> Dict[str, str]:
+         """
+         Stub for getting the session config
+
+         Returns:
+             Dict[str, str]: an empty dict
+         """
+         warnings.warn("get_session_config() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_state(self) -> Dict[str, str]:
+         """
+         Stub for getting the session state
+
+         Returns:
+             Dict[str, str]: an empty dict
+         """
+         warnings.warn("get_session_state() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_session_config(self, config: Dict[str, str]) -> None:
+         """
+         Stub for setting the session config
+
+         Args:
+             config: session config dict
+         """
+         warnings.warn("set_session_config() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+
+     def get_memory_info(self) -> Dict[str, int]:
+         """
+         Stub for getting memory usage information
+
+         Returns:
+             Dict[str, int]: an empty dict
+         """
+         warnings.warn("get_memory_info() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_memory_pattern(self, enable: bool) -> None:
+         """
+         Stub for setting the memory pattern
+
+         Args:
+             enable: whether to enable the memory pattern
+         """
+         warnings.warn("set_memory_pattern() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+
+     def disable_memory_pattern(self) -> None:
+         """
+         Stub for disabling the memory pattern
+         """
+         warnings.warn("disable_memory_pattern() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+
+     def get_optimization_level(self) -> int:
+         """
+         Stub for getting the optimization level
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_optimization_level() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def set_optimization_level(self, level: int) -> None:
+         """
+         Stub for setting the optimization level
+
+         Args:
+             level: optimization level
+         """
+         warnings.warn("set_optimization_level() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+
+     def get_model_metadata(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata (a different interface from get_modelmeta)
+
+         Returns:
+             Dict[str, str]: an empty dict
+         """
+         warnings.warn("get_model_metadata() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_model_path(self) -> str:
+         """
+         Get the model path
+
+         Returns:
+             str: path to the model file
+         """
+         return self.model_path
+
+     def get_input_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting input type information
+
+         Returns:
+             List[Dict[str, str]]: an empty list
+         """
+         warnings.warn("get_input_type_info() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return []
+
+     def get_output_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting output type information
+
+         Returns:
+             List[Dict[str, str]]: an empty list
+         """
+         warnings.warn("get_output_type_info() is a stub and provides no functionality", RuntimeWarning, stacklevel=2)
+         return []