2022-09-08 22:48:00 +08:00

62 lines
1.5 KiB
C

#include "PikaNN.h"
#include "PikaNN_common.h"
extern uint8_t mnist_pic[28 * 28];
extern const uint8_t mdl_data[2408];
tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh);
void parse_output(tm_mat_t* outs);
void _lm_test(void) {
TM_PRINTF("mnist demo\n");
tm_mdl_t mdl;
for (int i = 0; i < 28 * 28; i++) {
TM_PRINTF("%3d,", mnist_pic[i]);
if (i % 28 == 27)
TM_PRINTF("\n");
}
tm_mat_t in = {3, 28, 28, 1, {NULL}};
tm_err_t res;
tm_stat((tm_mdlbin_t*)mdl_data);
res = tm_load(&mdl, mdl_data, NULL, layer_cb, &in);
if (res != TM_OK) {
TM_PRINTF("tm model load err %d\n", res);
return;
}
uint32_t _start, _finish;
float _time;
_start = TM_GET_US();
tm_mat_t in_uint8 = {3, 28, 28, 1, {(mtype_t*)mnist_pic}};
tm_mat_t outs[1];
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
res = tm_preprocess(&mdl, TMPP_UINT2INT, &in_uint8, &in);
#else
res = tm_preprocess(mdl, TMPP_UINT2FP01, &in_uint8, in);
#endif
// TM_DBGT_START();
_start = TM_GET_US();
res = tm_run(&mdl, &in, outs);
// TM_DBGT("tm_run");
_finish = TM_GET_US();
_time = (float)(_finish - _start) / 1000.0;
TM_PRINTF("===%s use %.3f ms\n", "tm_run", _time);
//_start = TM_GET_US();
if (res == TM_OK)
parse_output(outs);
else
TM_PRINTF("tm run error: %d\n", res);
tm_unload(&mdl);
return;
}
void PikaNN_test(PikaObj* self) {
_lm_test();
}