TinyMaix是面向單片機(jī)的超輕量級(jí)的神經(jīng)網(wǎng)絡(luò)推理庫(kù),即TinyML推理庫(kù),可以讓你在任意單片機(jī)上運(yùn)行輕量級(jí)深度學(xué)習(xí)模型。
關(guān)鍵特性
- 核心代碼少于400行(tm_layers.c+tm_model.c+arch_cpu.h), 代碼段(.text)少于3KB
- 低內(nèi)存消耗,甚至Arduino ATmega328 (32KB Flash, 2KB Ram) 都能基于TinyMaix跑mnist(手寫(xiě)數(shù)字識(shí)別)
- 支持INT8/FP32/FP16模型,實(shí)驗(yàn)性地支持FP8模型,支持keras h5或tflite模型轉(zhuǎn)換
- 支持多種芯片架構(gòu)的專用指令優(yōu)化: ARM SIMD/NEON/MVEI,RV32P, RV64V
- 友好的用戶接口,只需要load/run模型~
- 支持全靜態(tài)的內(nèi)存配置(無(wú)需malloc)
- MaixHub 在線模型訓(xùn)練支持
項(xiàng)目地址:
https://github.com/sipeed/TinyMaix
實(shí)驗(yàn)了一把手寫(xiě)識(shí)別數(shù)字效果。效果還行。就是發(fā)現(xiàn)有時(shí)候7會(huì)識(shí)別成2了。
其中有個(gè)7識(shí)別成了2.
下面補(bǔ)充下如何移植。官方文檔內(nèi)有說(shuō)明如何移植(readme文檔),下面重新記錄一下。
TinyMaix的核心文件只有這5個(gè):tm_model.c, tm_layers.c, tinymaix.h, tm_port.h, arch_xxx.h
如果你使用沒(méi)有任何指令加速的普通單片機(jī),選擇 arch_cpu.h, 否則選擇對(duì)應(yīng)架構(gòu)的頭文件
然后你需要編輯tm_port.h,填寫(xiě)你需要的配置,所有配置宏后面都有注釋說(shuō)明
注意 TM_MAX_CSIZE,TM_MAX_KSIZE,TM_MAX_KCSIZE 會(huì)占用靜態(tài)緩存。
最后你只需要把他們放進(jìn)你的工程里編譯~
下面是我修改tm_port.h的配置,內(nèi)存分配函數(shù)和printf打印。有內(nèi)存分配所以注意堆空間設(shè)置大點(diǎn)。
/******************************* PORT CONFIG ************************************/
#define TM_ARCH TM_ARCH_CPU
#define TM_OPT_LEVEL TM_OPT0
#define TM_MDL_TYPE TM_MDL_INT8
#define TM_FASTSCALE (0) //enable if your chip don't have FPU, may speed up 1/3, but decrease accuracy
#define TM_LOCAL_MATH (0) //use local math func (like exp()) to avoid libm
#define TM_ENABLE_STAT (1) //enable mdl stat functions
#define TM_MAX_CSIZE (1000) //max channel num //used if INT8 mdl //cost TM_MAX_CSIZE*4 Byte
#define TM_MAX_KSIZE (5*5) //max kernel_size //cost TM_MAX_KSIZE*4 Byte
#define TM_MAX_KCSIZE (3*3*256) //max kernel_size*channels //cost TM_MAX_KSIZE*sizeof(mtype_t) Byte
#define TM_INLINE __attribute__((always_inline)) static inline
#define TM_WEAK __attribute__((weak))
#define tm_malloc(x) malloc(x)
#define tm_free(x) free(x)
#define TM_PRINTF(...) printf(__VA_ARGS__)
#define TM_DBG(...) TM_PRINTF("###L%d: ",__LINE__);TM_PRINTF(__VA_ARGS__);
#define TM_DBGL() TM_PRINTF("###L%drn",__LINE__);
下面是時(shí)間獲取配置。通過(guò)SysTick中斷1ms定時(shí)uwTick++計(jì)時(shí)。
/******************************* DBG TIME CONFIG ************************************/
extern volatile uint32_t uwTick;
#define TM_GET_US() ((uint32_t)uwTick)
#define TM_DBGT_INIT() uint32_t _start,_finish;float _time;_start=TM_GET_US();
#define TM_DBGT_START() _start=TM_GET_US();
#define TM_DBGT(x) {_finish=TM_GET_US();
_time = (float)(_finish-_start);
TM_PRINTF("===%s use %.3f msrn", (x), _time);
_start=TM_GET_US();}
如何使用 (API)
使用步驟-》1、加載模型-》2、輸入數(shù)據(jù)預(yù)處理-》3、運(yùn)行模型-》4、移除模型
加載模型
tm_err_t tm_load (tm_mdl_t* mdl, const uint8_t* bin, uint8_tbuf, tm_cb_t cb, tm_mat_t in);
mdl: 模型句柄;
bin: 模型bin內(nèi)容;
buf: 中間結(jié)果的主緩存;如果NULL,則內(nèi)部自動(dòng)malloc申請(qǐng);否則使用提供的緩存地址
cb: 網(wǎng)絡(luò)層回調(diào)函數(shù);
in: 返回輸入張量,包含輸入緩存地址 //可以忽略之,如果你使用自己的靜態(tài)輸入緩存
輸入數(shù)據(jù)預(yù)處理
tm_err_t tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out);
TMPP_FP2INT //用戶自己的浮點(diǎn)緩存轉(zhuǎn)換到int8緩存
TMPP_UINT2INT //典型uint8原地轉(zhuǎn)換到int8數(shù)據(jù);int16則需要額外緩存
TMPP_UINT2FP01 //uint8轉(zhuǎn)換到01的浮點(diǎn)數(shù) u8/255.0
TMPP_UINT2FPN11//uint8轉(zhuǎn)換到-11的浮點(diǎn)數(shù)
運(yùn)行模型
tm_err_t tm_run (tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out);
移除模型
void tm_unload(tm_mdl_t* mdl);
測(cè)試代碼流程:
int main_demo(void)
{
TM_DBGT_INIT();
TM_PRINTF("mnist demorn");
tm_mdl_t mdl;
for(int i=0; i<28*28; i++){
TM_PRINTF("%3d,", mnist_pic[i]);
if(i%28==27)TM_PRINTF("rn");
}
tm_mat_t in_uint8 = {3,28,28,1, {(mtype_t*)mnist_pic}};
tm_mat_t in = {3,28,28,1, {NULL}};
tm_mat_t outs[1];
tm_err_t res;
// tm_stat((tm_mdlbin_t*)mdl_data);
res = tm_load(&mdl, mdl_data, NULL, layer_cb, &in);
if(res != TM_OK) {
TM_PRINTF("tm model load err %drn", res);
return -1;
}
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
res = tm_preprocess(&mdl, TMPP_UINT2INT, &in_uint8, &in);
#else
res = tm_preprocess(&mdl, TMPP_UINT2FP01, &in_uint8, &in);
#endif
TM_DBGT_START();
res = tm_run(&mdl, &in, outs);
TM_DBGT("tm_run");
if(res==TM_OK) parse_output(outs);
else TM_PRINTF("tm run error: %drn", res);
tm_unload(&mdl);
return 0;
}
怎樣添加新平臺(tái)的加速代碼
對(duì)于新增平臺(tái),你只需要在src里添加arch_xxx.h文件并實(shí)現(xiàn)其中的函數(shù)即可,主要為以下幾個(gè)函數(shù)(重要性降序排列,不重要的函數(shù)可以直接拷貝純CPU運(yùn)算的函數(shù)):
a. TM_INLINE void tm_dot_prod(mtype_t* sptr, mtype_t* kptr,uint32_t size, sumtype_t* result)
實(shí)現(xiàn)平臺(tái)相關(guān)的點(diǎn)積函數(shù),可以使用MAC相關(guān)的加速指令加速。
b. TM_INLINE void tm_dot_prod_pack2(mtype_t* sptr, mtype_t* kptr, uint32_t size, sumtype_t* result)
實(shí)現(xiàn)平臺(tái)相關(guān)的雙通道點(diǎn)積函數(shù)。(僅提供到雙通道是因?yàn)橛行┬酒脚_(tái)的寄存器不足以支持更多通道的點(diǎn)積加速)
c. TM_INLINE void tm_postprocess_sum(int n, sumtype_t* sums, btype_t* bs, int act, mtype_t* outp, sctype_t* scales, sctype_t out_s, zptype_t out_zp)
實(shí)現(xiàn)平臺(tái)相關(guān)的批量后處理函數(shù),注意n為2的次冪。
d. TM_INLINE void tm_dot_prod_3x3x1(mtype_t* sptr, mtype_t* kptr, sumtype_t* result)
實(shí)現(xiàn)平臺(tái)相關(guān)的3x3點(diǎn)積加速
e. TM_INLINE void tm_dot_prod_gap_3x3x1(mtype_t* sptr, mtype_t* kptr, uint32_t* k_oft, sumtype_t* result)
實(shí)現(xiàn)平臺(tái)相關(guān)的3x3 gap的點(diǎn)積加速
工程文件:
lpc55s69_tinyml_s.zip (14.5 MB)