pikapython/package/PikaNN/PikaNN_net.c
2023-07-20 16:12:33 +08:00

245 lines
11 KiB
C

#include "PikaNN_net.h"
#include <sys/time.h>
#include <time.h>
#include "PikaNN_common.h"
#if TM_MDL_TYPE == TM_MDL_INT8
#include "./TinyMaix/tools/tmdl/mnist_valid_q.h"
#elif TM_MDL_TYPE == TM_MDL_FP32
#include "./TinyMaix/tools/tmdl/mnist_valid_f.h"
#elif TM_MDL_TYPE == TM_MDL_FP16
#include "./TinyMaix/tools/tmdl/mnist_valid_fp16.h"
#elif TM_MDL_TYPE == TM_MDL_FP8_143
#include "./TinyMaix/tools/tmdl/mnist_fp8_143.h"
#elif TM_MDL_TYPE == TM_MDL_FP8_152
#include "./TinyMaix/tools/tmdl/mnist_fp8_152.h"
#endif
// Sample input image for the MNIST demo: 28 * 28 bytes, one uint8_t per pixel.
// PikaNN_net___init__ prints it 28 values per line and PikaNN_net_run feeds it
// to the model as a 28 x 28 x 1 uint8 matrix.
// Two alternative images are provided. The `#if 1` below compiles the first
// one, so the image in the `#else` branch is currently unused; change the
// condition to 0 to switch samples.
// NOTE(review): presumably each image is a handwritten digit taken from the
// MNIST data set -- confirm which digits they depict.
#if 1
uint8_t mnist_pic[28 * 28] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 116, 125, 171, 255, 255, 150, 93, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 169, 253, 253, 253, 253, 253, 253, 218, 30, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 169, 253,
253, 253, 213, 142, 176, 253, 253, 122, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 52, 250, 253, 210, 32,
12, 0, 6, 206, 253, 140, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 77, 251, 210, 25, 0, 0, 0,
122, 248, 253, 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 31, 18, 0, 0, 0, 0, 209, 253,
253, 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 117, 247, 253, 198, 10,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 76, 247, 253, 231, 63, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 128, 253, 253, 144, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 176, 246, 253, 159, 12, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25,
234, 253, 233, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 198, 253, 253,
141, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 248, 253, 189, 12, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 19, 200, 253, 253, 141, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 134, 253, 253, 173, 12, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 248, 253, 253, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
248, 253, 253, 43, 20, 20, 20, 20, 5, 0, 5, 20, 20, 37, 150,
150, 150, 147, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 248, 253,
253, 253, 253, 253, 253, 253, 168, 143, 166, 253, 253, 253, 253, 253, 253,
253, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 174, 253, 253, 253,
253, 253, 253, 253, 253, 253, 253, 253, 249, 247, 247, 169, 117, 117, 57,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 118, 123, 123, 123, 166,
253, 253, 253, 155, 123, 123, 41, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0,
};
#else
// Alternative sample image; not compiled while the condition above is 1.
uint8_t mnist_pic[28 * 28] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36,
56, 137, 201, 199, 95, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 152, 234, 254, 254,
254, 254, 254, 250, 211, 151, 6, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 46, 153, 240, 254, 254, 227, 166, 133, 251,
200, 254, 229, 225, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 153, 234, 254, 254, 187, 142, 8, 0, 0, 191, 40, 198,
246, 223, 253, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8,
126, 253, 254, 233, 128, 11, 0, 0, 0, 0, 210, 43, 70, 254, 254,
254, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 243, 254,
228, 54, 0, 0, 0, 0, 3, 32, 116, 225, 242, 254, 255, 162, 5,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 75, 240, 254, 223, 109,
138, 178, 178, 169, 210, 251, 231, 254, 254, 254, 232, 38, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 9, 175, 244, 253, 255, 254, 254,
251, 254, 254, 254, 254, 254, 252, 171, 25, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 16, 136, 195, 176, 146, 153, 200,
254, 254, 254, 254, 150, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 162, 254, 254,
241, 99, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 118, 250, 254, 254, 90, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 100, 242, 254, 254, 211, 7, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 54, 241, 254, 254, 242, 59, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 131, 254, 254, 244, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 249,
254, 254, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 228, 254, 254, 208,
8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 78, 255, 254, 254, 66, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 209, 254, 254, 137, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 227, 255, 233, 25, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 113, 255, 108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0,
};
#endif
// Build-time switch for the per-layer dump in layer_cb(). Defaults to 0, which
// matches the previous behaviour (the dump code sat behind an unconditional
// early return and never ran). Define as 1 to print every layer's output.
#ifndef PIKANN_LAYER_DUMP
#define PIKANN_LAYER_DUMP 0
#endif

/*
 * Per-layer callback passed to tm_load(); optionally dumps a layer's output.
 *
 * mdl: model being run (mdl->layer_i is printed as the layer index).
 * lh:  header of the current layer; out_dims[1..3] are used as the height,
 *      width and channel count of the output tensor.
 *
 * Always returns TM_OK. Prints nothing unless PIKANN_LAYER_DUMP is non-zero.
 */
tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh) {
#if PIKANN_LAYER_DUMP
    int h = lh->out_dims[1];
    int w = lh->out_dims[2];
    int ch = lh->out_dims[3];
    mtype_t* output = TML_GET_OUTPUT(mdl, lh);

    TM_PRINTF("Layer %d callback ========\n", mdl->layer_i);
    for (int y = 0; y < h; y++) {
        TM_PRINTF("[");
        for (int x = 0; x < w; x++) {
            TM_PRINTF("[");
            for (int c = 0; c < ch; c++) {
                // Element (y, x, c) lives at (y * w + x) * ch + c.
#if TM_MDL_TYPE == TM_MDL_FP32
                TM_PRINTF("%.3f,", output[(y * w + x) * ch + c]);
#else
                // Non-FP32 models: convert the stored value before printing.
                TM_PRINTF("%.3f,",
                          TML_DEQUANT(lh, output[(y * w + x) * ch + c]));
#endif
            }
            TM_PRINTF("],");
        }
        TM_PRINTF("],\n");
    }
    TM_PRINTF("\n");
#else
    // Dump disabled: the arguments are intentionally unused.
    (void)mdl;
    (void)lh;
#endif
    return TM_OK;
}
/*
 * Print the class scores of the first output tensor and report the best one.
 *
 * outs: output tensors filled by tm_run(); only outs[0] is read, through its
 *       float view (dataf), which must hold at least 10 values.
 *
 * The winner is the first index holding the largest score. The search is
 * seeded with element 0, so a valid class is reported even when no score is
 * strictly positive (the previous zero-seeded search reported class -1 then).
 */
void parse_output(tm_mat_t* outs) {
    enum { CLASS_COUNT = 10 };  // the demo model scores ten classes
    const float* scores = outs[0].dataf;
    int best = 0;

    for (int i = 0; i < CLASS_COUNT; i++) {
        // Use TM_PRINTF like the rest of the file so all output goes through
        // the same port-defined print hook.
        TM_PRINTF("%d: %.3f\n", i, scores[i]);
        if (scores[i] > scores[best]) {
            best = i;
        }
    }
    TM_PRINTF("### Predict output is: Number %d, prob %.3f\n", best,
              scores[best]);
}
/*
 * Constructor of the PikaNN net object.
 *
 * Prints a banner and the sample image (28 values per line), then stores an
 * empty model struct on the object under the key "mdl"; the other methods of
 * this file fetch it with obj_getStruct(self, "mdl").
 */
void PikaNN_net___init__(PikaObj* self) {
    TM_PRINTF("mnist demo\n");

    for (int i = 0; i < 28 * 28; i++) {
        TM_PRINTF("%3d,", mnist_pic[i]);
        if (i % 28 == 27) {
            TM_PRINTF("\n");  // end of one 28-pixel row
        }
    }

    // Zero-initialise before storing: the struct is copied into the object, so
    // an uninitialised local would persist indeterminate bytes as the model.
    tm_mdl_t mdl = {0};
    obj_setStruct(self, "mdl", mdl);
}
/*
 * Load the built-in model (mdl_data) into the model struct stored on `self`.
 *
 * Prints the model statistics via tm_stat(), then calls tm_load() with
 * layer_cb as the layer callback and a 28 x 28 x 1 input matrix descriptor.
 * On success the descriptor is stored on the object under the key "in"; on
 * failure the error code is printed and "in" is left unset.
 */
void PikaNN_net_load(PikaObj* self) {
    tm_mdl_t* model = obj_getStruct(self, "mdl");
    // Input descriptor: 3 dims, 28 x 28 x 1, data pointer not set here.
    tm_mat_t input = {3, 28, 28, 1, {NULL}};

    tm_stat((tm_mdlbin_t*)mdl_data);

    tm_err_t status = tm_load(model, mdl_data, NULL, layer_cb, &input);
    if (status == TM_OK) {
        obj_setStruct(self, "in", input);
    } else {
        TM_PRINTF("tm model load err %d\n", status);
    }
}
/*
 * Run the loaded model on the built-in sample image and print the result.
 *
 * Steps: fetch the "mdl" and "in" structs stored on `self`, preprocess
 * mnist_pic into the model input, time tm_run(), print the elapsed time and
 * hand the outputs to parse_output().
 *
 * Every failure is reported through TM_PRINTF and ends the call early.
 */
void PikaNN_net_run(PikaObj* self) {
    tm_mdl_t* mdl = obj_getStruct(self, "mdl");
    tm_mat_t* in = obj_getStruct(self, "in");
    // "in" is only stored when PikaNN_net_load() succeeds, so it can be
    // missing here. NOTE(review): assumes obj_getStruct() yields NULL for an
    // absent key -- confirm against the PikaPython API.
    if (mdl == NULL || in == NULL) {
        TM_PRINTF("tm run error: model not loaded\n");
        return;
    }

    // View of the raw sample image as a 28 x 28 x 1 matrix.
    tm_mat_t in_uint8 = {3, 28, 28, 1, {(mtype_t*)mnist_pic}};
    tm_err_t res;
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
    res = tm_preprocess(mdl, TMPP_UINT2INT, &in_uint8, in);
#else
    res = tm_preprocess(mdl, TMPP_UINT2FP01, &in_uint8, in);
#endif
    // Previously this result was overwritten by tm_run() without being read.
    if (res != TM_OK) {
        TM_PRINTF("tm preprocess error: %d\n", res);
        return;
    }

    tm_mat_t outs[1];
    // Time only the inference itself.
    uint32_t start_us = TM_GET_US();
    res = tm_run(mdl, in, outs);
    uint32_t finish_us = TM_GET_US();

    float elapsed_ms = (float)(finish_us - start_us) / 1000.0f;
    TM_PRINTF("===%s use %.3f ms\n", "tm_run", elapsed_ms);

    if (res == TM_OK) {
        parse_output(outs);
    } else {
        TM_PRINTF("tm run error: %d\n", res);
    }
}
/*
 * Unload the model: passes the struct stored on `self` under the key "mdl"
 * to tm_unload().
 */
void PikaNN_net_unload(PikaObj* self) {
    tm_unload(obj_getStruct(self, "mdl"));
}