This repository has been archived on 2023-07-17. You can view files and clone it, but cannot push or open issues or pull requests.
bl_mcu_sdk/components/TinyMaix/examples/mnist/mnist_train.ipynb

348 lines
12 KiB
Plaintext
Raw Normal View History

2022-11-18 16:05:54 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from keras.datasets import mnist\n",
"from tensorflow.python.keras.backend import set_session\n",
"from tensorflow.python.keras.models import load_model\n",
"from tensorflow.keras.models import Model, load_model, Sequential\n",
"from tensorflow.keras.layers import Input,Conv2D, Dense, MaxPooling2D, Softmax, Activation, BatchNormalization, Flatten, Dropout, DepthwiseConv2D\n",
"from tensorflow.keras.layers import MaxPool2D, AvgPool2D, AveragePooling2D, GlobalAveragePooling2D,ZeroPadding2D,Input,Embedding,PReLU,Reshape\n",
"from keras.callbacks import ModelCheckpoint\n",
"from keras.callbacks import TensorBoard\n",
"from keras.utils import np_utils\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"import keras.backend as K\n",
"import tensorflow as tf\n",
"import time\n",
"\n",
"from os import environ\n",
"environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Fetch the MNIST train/test splits through keras.datasets.mnist.\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Append a trailing channel axis of size 1 and divide the pixel values by 255.\n",
"x_train = x_train.reshape((*x_train.shape[:3], 1)) / 255\n",
"x_test = x_test.reshape((*x_test.shape[:3], 1)) / 255\n",
"\n",
"# One-hot encode the labels over the 10 digit classes.\n",
"y_train = np_utils.to_categorical(y_train, num_classes=10)\n",
"y_test = np_utils.to_categorical(y_test, num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_valid\n",
"def init_model(dim0):\n",
"    \"\"\"Build the 'valid'-padded MNIST CNN.\n",
"\n",
"    Three stride-2 3x3 convolutions (dim0, dim0*2, dim0*4 filters), each followed by\n",
"    batch normalization and ReLU, then global average pooling and a 10-way softmax.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    # Stage 0 declares the 28x28x1 input shape.\n",
"    net.add(Conv2D(dim0, (3, 3), strides=(2, 2), padding='valid', input_shape=(28, 28, 1), name='ftr0'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"\n",
"    net.add(Conv2D(dim0 * 2, (3, 3), strides=(2, 2), padding='valid', name='ftr1'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"\n",
"    # The last stage leaves its BatchNormalization/Activation layers auto-named.\n",
"    net.add(Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='valid', name='ftr2'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu'))\n",
"\n",
"    # Classifier head.\n",
"    net.add(GlobalAveragePooling2D(name='GAP'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 4\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_same\n",
"def init_model(dim0):\n",
"    \"\"\"Build the 'same'-padded MNIST CNN.\n",
"\n",
"    Three stride-2 3x3 convolutions (dim0, dim0*2, dim0*4 filters), each followed by\n",
"    batch normalization and ReLU, then global average pooling and a 10-way softmax.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    # Stage 0 declares the 28x28x1 input shape.\n",
"    net.add(Conv2D(dim0, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), name='ftr0'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"\n",
"    net.add(Conv2D(dim0 * 2, (3, 3), strides=(2, 2), padding='same', name='ftr1'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"\n",
"    # The last stage leaves its BatchNormalization/Activation layers auto-named.\n",
"    net.add(Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='same', name='ftr2'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu'))\n",
"\n",
"    # Classifier head.\n",
"    net.add(GlobalAveragePooling2D(name='GAP'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 4\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_fc\n",
"def init_model(dim0):\n",
"    \"\"\"Build the fully-connected MNIST baseline: flatten, Dense(10), softmax.\n",
"\n",
"    dim0 is accepted for signature parity with the other variants but is not used.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"    net.add(Flatten(input_shape=(28, 28, 1), name='flatten'))\n",
"    net.add(Dense(10, name='fc3'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 4\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_dw\n",
"def init_model(dim0):\n",
"    \"\"\"Build the depthwise-convolution MNIST CNN.\n",
"\n",
"    Two groups of (stride-2 3x3 Conv2D -> BN -> ReLU -> 3x3 DepthwiseConv2D -> BN -> ReLU)\n",
"    with dim0 and dim0*4 conv filters; the second depthwise layer uses depth_multiplier=2.\n",
"    Global average pooling and a 10-way softmax form the head.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    # Group 0: regular conv, then depthwise conv.\n",
"    net.add(Conv2D(dim0, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), name='ftr0a'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"    # Feature map should be 14x14 here (28x28 input, stride 2, 'same' padding).\n",
"    net.add(DepthwiseConv2D((3, 3), padding='same', name='ftr0b'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu', name='relu00'))\n",
"\n",
"    # Group 1: wider conv, then depthwise conv that doubles the channel count.\n",
"    net.add(Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='same', name='ftr1a'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"    net.add(DepthwiseConv2D((3, 3), padding='same', depth_multiplier=2, name='ftr1b'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu', name='relu11'))\n",
"\n",
"    # Classifier head.\n",
"    net.add(GlobalAveragePooling2D(name='GAP'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 4\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_rect\n",
"def init_model(dim0):\n",
"    \"\"\"Build the rectangular-kernel MNIST CNN.\n",
"\n",
"    Same layout as the 'same'-padded variant, but the first two convolutions use\n",
"    non-square kernels (15x3 and 7x3); the third stays 3x3.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    net.add(Conv2D(dim0, (15, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), name='ftr0'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"\n",
"    net.add(Conv2D(dim0 * 2, (7, 3), strides=(2, 2), padding='same', name='ftr1'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"\n",
"    # Only the BatchNormalization of this stage is left auto-named.\n",
"    net.add(Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='same', name='ftr2'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu', name='relu2'))\n",
"\n",
"    # Classifier head.\n",
"    net.add(GlobalAveragePooling2D(name='GAP'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 4\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_nnom\n",
"def init_model(dim0):\n",
"    \"\"\"Build the NNoM-style MNIST CNN.\n",
"\n",
"    Three stride-2 3x3 'same' convolutions (dim0, dim0*2, dim0*4 filters) with batch\n",
"    normalization and ReLU, then a 4x4 'valid' convolution with dim0*8 filters whose\n",
"    output is reshaped to a flat vector and fed to a 10-way softmax classifier.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    net.add(Conv2D(dim0, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), name='ftr0'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"\n",
"    net.add(Conv2D(dim0 * 2, (3, 3), strides=(2, 2), padding='same', name='ftr1'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"\n",
"    net.add(Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='same', name='ftr2'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu', name='relu2'))\n",
"\n",
"    net.add(Conv2D(dim0 * 8, (4, 4), padding='valid', name='ftr3'))\n",
"    # Flatten ftr3's output. The length follows its filter count rather than a\n",
"    # fixed 96, which only matched dim0 == 12.\n",
"    net.add(Reshape((dim0 * 8,), name='reshape'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 12\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist_arduino\n",
"def init_model(dim0):\n",
"    \"\"\"Build the smallest MNIST CNN variant (the cell below instantiates it with DIM0 = 1).\n",
"\n",
"    Three stride-2 3x3 'valid' convolutions (dim0, dim0*3, dim0*6 filters), each followed\n",
"    by batch normalization and ReLU, then global average pooling and a 10-way softmax.\n",
"    \"\"\"\n",
"    net = Sequential()\n",
"\n",
"    net.add(Conv2D(dim0, (3, 3), strides=(2, 2), padding='valid', input_shape=(28, 28, 1), name='ftr0'))\n",
"    net.add(BatchNormalization(name='bn0'))\n",
"    net.add(Activation('relu', name='relu0'))\n",
"\n",
"    net.add(Conv2D(dim0 * 3, (3, 3), strides=(2, 2), padding='valid', name='ftr1'))\n",
"    net.add(BatchNormalization(name='bn1'))\n",
"    net.add(Activation('relu', name='relu1'))\n",
"\n",
"    # The last stage leaves its BatchNormalization/Activation layers auto-named.\n",
"    net.add(Conv2D(dim0 * 6, (3, 3), strides=(2, 2), padding='valid', name='ftr2'))\n",
"    net.add(BatchNormalization())\n",
"    net.add(Activation('relu'))\n",
"\n",
"    # Classifier head.\n",
"    net.add(GlobalAveragePooling2D(name='GAP'))\n",
"    net.add(Dense(10, name='fc1'))\n",
"    net.add(Activation('softmax', name='sm'))\n",
"    return net\n",
"\n",
"DIM0 = 1\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#mnist resnet\n",
"def init_model(dim0):\n",
"    \"\"\"Build a small residual MNIST CNN with the Keras functional API.\n",
"\n",
"    Two stride-2 conv/BN/ReLU stages (dim0 and dim0*2 filters) are followed by a\n",
"    residual branch (3x3 conv + BN, dim0*2 filters) that is added back onto its own\n",
"    input, a stride-2 'valid' conv with dim0*4 filters, and a 10-way softmax head.\n",
"    \"\"\"\n",
"    inputs = Input(shape=(28, 28, 1))\n",
"\n",
"    x = Conv2D(dim0, (3, 3), strides=(2, 2), padding='same', name='ftr0')(inputs)\n",
"    x = BatchNormalization(name='bn0')(x)\n",
"    x = Activation('relu', name='relu0')(x)\n",
"\n",
"    x = Conv2D(dim0 * 2, (3, 3), strides=(2, 2), padding='same', name='ftr1')(x)\n",
"    x = BatchNormalization(name='bn1')(x)\n",
"    x = Activation('relu', name='relu1')(x)\n",
"    res = x  # shortcut branch\n",
"\n",
"    # Residual branch keeps the dim0*2 channel count so it can be summed with the shortcut.\n",
"    x = Conv2D(dim0 * 2, (3, 3), padding='same', name='ftr2')(x)\n",
"    x = BatchNormalization(name='bn2')(x)\n",
"\n",
"    # Use an explicit, named Add layer for the skip connection instead of the bare\n",
"    # '+' operator, so the sum appears as a regular Keras layer in the model.\n",
"    x = tf.keras.layers.Add(name='add')([res, x])\n",
"\n",
"    x = Conv2D(dim0 * 4, (3, 3), strides=(2, 2), padding='valid', name='ftr3')(x)\n",
"    x = Flatten(name='reshape')(x)\n",
"    x = Dense(10, name='fc1')(x)\n",
"\n",
"    x = Activation('softmax', name='sm')(x)\n",
"    model = Model(inputs=inputs, outputs=x)\n",
"\n",
"    return model\n",
"\n",
"DIM0 = 12\n",
"model = init_model(DIM0)\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 20\n",
"\n",
"model.compile(\n",
"    optimizer='adam',\n",
"    loss='categorical_crossentropy',\n",
"    metrics=['categorical_accuracy'],\n",
")\n",
"\n",
"# Note: the test split doubles as the validation set here.\n",
"H = model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=64,\n",
"    epochs=EPOCHS,\n",
"    shuffle=True,\n",
"    validation_data=(x_test, y_test),\n",
"    verbose=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# Write whichever model is currently bound to `model` to disk.\n",
"# The file name assumes the ResNet cell was the last init_model cell run.\n",
"h5_file = \"mnist_resnet.h5\"\n",
"model.save(h5_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print test image #1 as 28 rows of comma-separated integers (pixel value * 255).\n",
"data = x_test[1]\n",
"for row in range(28):\n",
"    for col in range(28):\n",
"        # Round before int(): the earlier /255 followed by *255 can land a hair below\n",
"        # the original integer, which a bare int() would truncate to one less.\n",
"        print(\"%3d,\" % int(round(data[row, col, 0] * 255)), end=\"\")\n",
"    # Terminate the current row.\n",
"    print(\"\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1/1 [==============================] - 0s 190ms/step\n"
]
},
{
"data": {
"text/plain": [
"array([9.6214655e-14, 3.2803200e-18, 1.0000000e+00, 3.8382069e-22,\n",
" 5.4997339e-24, 5.6444559e-32, 5.4293467e-14, 3.6996416e-21,\n",
" 3.2022371e-21, 1.7273456e-21], dtype=float32)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the model on test image #1 only; a leading axis of size 1 is added so the\n",
"# input is a one-sample batch, and [0] takes that sample's 10 class scores.\n",
"result = model.predict(np.expand_dims(x_test[1], axis=0))[0]\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tf29",
"language": "python",
"name": "py374_tf21"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}