mirror of
https://gitee.com/Lyon1998/pikapython.git
synced 2025-01-22 17:12:55 +08:00
62 lines
1.5 KiB
C
62 lines
1.5 KiB
C
|
#include "PikaNN.h"
|
||
|
#include "PikaNN_common.h"
|
||
|
|
||
|
/* 28*28-byte image buffer defined in another translation unit; _lm_test()
 * prints it as 28 values per row and wraps it as the input matrix. */
extern uint8_t mnist_pic[28 * 28];
|
||
|
/* 2408-byte read-only model blob defined elsewhere; _lm_test() hands it to
 * tm_stat() and tm_load(). */
extern const uint8_t mdl_data[2408];
|
||
|
/* Callback passed to tm_load() by _lm_test(); defined elsewhere.
 * NOTE(review): presumably invoked per layer during inference — confirm. */
tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh);
|
||
|
/* Consumes the output matrix array produced by tm_run(); defined elsewhere. */
void parse_output(tm_mat_t* outs);
|
||
|
|
||
|
void _lm_test(void) {
|
||
|
TM_PRINTF("mnist demo\n");
|
||
|
tm_mdl_t mdl;
|
||
|
|
||
|
for (int i = 0; i < 28 * 28; i++) {
|
||
|
TM_PRINTF("%3d,", mnist_pic[i]);
|
||
|
if (i % 28 == 27)
|
||
|
TM_PRINTF("\n");
|
||
|
}
|
||
|
|
||
|
tm_mat_t in = {3, 28, 28, 1, {NULL}};
|
||
|
tm_err_t res;
|
||
|
tm_stat((tm_mdlbin_t*)mdl_data);
|
||
|
|
||
|
res = tm_load(&mdl, mdl_data, NULL, layer_cb, &in);
|
||
|
if (res != TM_OK) {
|
||
|
TM_PRINTF("tm model load err %d\n", res);
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
uint32_t _start, _finish;
|
||
|
float _time;
|
||
|
_start = TM_GET_US();
|
||
|
|
||
|
tm_mat_t in_uint8 = {3, 28, 28, 1, {(mtype_t*)mnist_pic}};
|
||
|
tm_mat_t outs[1];
|
||
|
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
|
||
|
res = tm_preprocess(&mdl, TMPP_UINT2INT, &in_uint8, &in);
|
||
|
#else
|
||
|
res = tm_preprocess(mdl, TMPP_UINT2FP01, &in_uint8, in);
|
||
|
#endif
|
||
|
// TM_DBGT_START();
|
||
|
_start = TM_GET_US();
|
||
|
res = tm_run(&mdl, &in, outs);
|
||
|
// TM_DBGT("tm_run");
|
||
|
|
||
|
_finish = TM_GET_US();
|
||
|
_time = (float)(_finish - _start) / 1000.0;
|
||
|
TM_PRINTF("===%s use %.3f ms\n", "tm_run", _time);
|
||
|
//_start = TM_GET_US();
|
||
|
if (res == TM_OK)
|
||
|
parse_output(outs);
|
||
|
else
|
||
|
TM_PRINTF("tm run error: %d\n", res);
|
||
|
|
||
|
tm_unload(&mdl);
|
||
|
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
/**
 * Entry point that takes a PikaObj and runs the MNIST demo.
 *
 * @param self  Object the call was made on; not used by this function.
 */
void PikaNN_test(PikaObj* self) {
    (void)self; /* the demo needs no per-object state */
    _lm_test();
}
|