This repository has been archived on 2023-07-17. You can view files and clone it, but cannot push or open issues or pull requests.
bl_mcu_sdk/components/ai/TinyMaix/src/tm_model.c

159 lines
6.4 KiB
C

/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tinymaix.h"
//load model
//mdl: model handle; bin: model bin buf; buf: main buf for middle output; cb: layer callback;
//in: return input mat, include buf addr; //you can ignore it if use static buf
tm_err_t TM_WEAK tm_load (tm_mdl_t* mdl, const uint8_t* bin, uint8_t*buf, tm_cb_t cb, tm_mat_t* in)
{
tm_mdlbin_t* mdl_bin = (tm_mdlbin_t*)bin;
if(mdl_bin->magic != TM_MDL_MAGIC) return TM_ERR_MAGIC; //FIXME: big-endian not compatible
if(mdl_bin->mdl_type != TM_MDL_TYPE) return TM_ERR_MDLTYPE;
mdl->b = mdl_bin;
mdl->cb = (void*)cb;
if(buf == NULL) {
mdl->buf = (uint8_t*)tm_malloc(mdl->b->buf_size);
if(mdl->buf == NULL) return TM_ERR_OOM;
mdl->main_alloc = 1;
} else {
mdl->buf = buf;
mdl->main_alloc = 0;
}
if(mdl->b->sub_size > 0) {
mdl->subbuf = (uint8_t*)tm_malloc(mdl->b->sub_size);
if(mdl->subbuf == NULL) return TM_ERR_OOM;
} else mdl->subbuf = NULL;
mdl->layer_i = 0;
mdl->layer_body = mdl->b->layers_body;
memcpy((void*)in, (void*)mdl->b->in_dims, sizeof(tm_mat_t));
in->data = (mtype_t*)mdl->buf; //input at 0 oft
return TM_OK;
}
//remove model
void TM_WEAK tm_unload(tm_mdl_t* mdl)
{
if(mdl->main_alloc) tm_free(mdl->buf);
return;
}
//preprocess data input
tm_err_t TM_WEAK tm_preprocess(tm_mdl_t* mdl, tm_pp_t pp_type, tm_mat_t* in, tm_mat_t* out)
{
tml_head_t* l0h = (tml_head_t*)mdl->b->layers_body;
sctype_t in_s = l0h->in_s;
zptype_t in_zp= l0h->in_zp;
int in_size = in->h*in->w*in->c;
switch(pp_type){
#if (TM_MDL_TYPE == TM_MDL_INT8)||(TM_MDL_TYPE == TM_MDL_INT16)
case TMPP_FP2INT:
for(int i=0; i<in_size; i++)
out->data[i] = (mtype_t)(in->dataf[i]/in_s + in_zp);
break;
case TMPP_UINT2INT:
for(int i=0; i<in_size; i++)
out->data[i] = ((mtype_t)(((uint8_t*)(in->data))[i]-128))<<UINT2INT_SHIFT;
break;
#else
case TMPP_UINT2FP01:
for(int i=0; i<in_size; i++)
out->data[i] = (((uint8_t*)(in->data))[i])/255.0;
break;
case TMPP_UINT2FPN11:
for(int i=0; i<in_size; i++)
out->data[i] = ((((uint8_t*)(in->data))[i])-128)/128.0;
break;
#endif
default: //don't do anything
out->data = in->data;
break;
}
return TM_OK;
}
//run model
//mdl: model handle; in: input mat; out: output mat
tm_err_t TM_WEAK tm_run(tm_mdl_t* mdl, tm_mat_t* in, tm_mat_t* out)
{
tm_mat_t _in, _in1, _out;
tm_err_t res = TM_OK;
int out_idx = 0;
memcpy((void*)&_in, (void*)in, sizeof(tm_mat_t));
mdl->layer_body = mdl->b->layers_body;
for(mdl->layer_i = 0; mdl->layer_i < mdl->b->layer_cnt; mdl->layer_i++){
tml_head_t* h = (tml_head_t*)(mdl->layer_body);
if(mdl->layer_i>0) {
_in.data = (mtype_t *)(mdl->buf + h->in_oft);
memcpy((void*)&_in, (void*)(h->in_dims), sizeof(uint16_t)*4);
}
_out.data = (mtype_t *)(mdl->buf + h->out_oft);
memcpy((void*)&_out, (void*)(h->out_dims), sizeof(uint16_t)*4);
switch(h->type){
case TML_CONV2D:
case TML_DWCONV2D:{
tml_conv2d_dw_t* l = (tml_conv2d_dw_t*)(mdl->layer_body);
res = tml_conv2d_dwconv2d(&_in, &_out, (wtype_t*)(mdl->layer_body + l->w_oft), (btype_t*)(mdl->layer_body + l->b_oft), \
l->kernel_w, l->kernel_h, l->stride_w, l->stride_h, l->dilation_w, l->dilation_h, \
l->act, l->pad[0], l->pad[1], l->pad[2], l->pad[3], l->depth_mul, \
(sctype_t*)(mdl->layer_body + l->ws_oft), h->in_s, h->in_zp, h->out_s, h->out_zp);
break;}
case TML_GAP: {
tml_gap_t* l = (tml_gap_t*)(mdl->layer_body);
res = tml_gap(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
break;}
case TML_FC: {
tml_fc_t* l = (tml_fc_t*)(mdl->layer_body);
res = tml_fc(&_in, &_out, (wtype_t*)(mdl->layer_body + l->w_oft), (btype_t*)(mdl->layer_body + l->b_oft), \
(sctype_t*)(mdl->layer_body + l->ws_oft), h->in_s, h->in_zp, h->out_s, h->out_zp);
break;}
case TML_SOFTMAX: {
tml_softmax_t* l = (tml_softmax_t*)(mdl->layer_body);
res = tml_softmax(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
break; }
case TML_RESHAPE: {
tml_reshape_t* l = (tml_reshape_t*)(mdl->layer_body);
res = tml_reshape(&_in, &_out, h->in_s, h->in_zp, h->out_s, h->out_zp);
break; }
case TML_ADD: {
tml_add_t* l = (tml_add_t*)(mdl->layer_body);
memcpy((void*)&_in1, (void*)(h->in_dims), sizeof(uint16_t)*4);
_in1.data = (mtype_t *)(mdl->buf + l->in_oft1);
res = tml_add(&_in, &_in1, &_out, h->in_s, h->in_zp, l->in_s1, l->in_zp1, h->out_s, h->out_zp);
break; }
default:
res = TM_ERR_LAYERTYPE;
break;
}
if(res != TM_OK) return res;
if(mdl->cb) ((tm_cb_t)mdl->cb)(mdl, h); //layer callback
if(h->is_out) {
memcpy((void*)(&out[out_idx]), (void*)(&(h->out_dims)), sizeof(uint16_t)*4);
if(mdl->b->out_deq == 0 || TM_MDL_TYPE == TM_MDL_FP32) //fp32 do not need deq
out[out_idx].data = (mtype_t*)(TML_GET_OUTPUT(mdl, h));
else {
int out_size = h->out_dims[1]*h->out_dims[2]*h->out_dims[3];
float* outf = (float*)(TM_ALIGN(TML_GET_OUTPUT(mdl, h) + out_size));
for(int i=0; i<out_size; i++) //do dequant
outf[i] = TML_DEQUANT(h, (TML_GET_OUTPUT(mdl, h))[i]);
out[out_idx].dataf = outf;
}
out_idx += 1;
}
mdl->layer_body += (h->size);
}
return TM_OK;
}