1+ import math
2+ import numpy as np
3+ from sklearn .datasets import load_digits
4+ import tensorflow as tf
5+ from tensorflow .keras import layers
6+ from tinymlgen import port
7+
8+
9+ def get_data ():
10+ np .random .seed (1337 )
11+ x_values , y_values = load_digits (return_X_y = True )
12+ x_values /= x_values .max ()
13+ # reshape to (8 x 8 x 1)
14+ x_values = x_values .reshape ((len (x_values ), 8 , 8 , 1 ))
15+
16+ # split into train, validation, test
17+ TRAIN_SPLIT = int (0.6 * len (x_values ))
18+ TEST_SPLIT = int (0.2 * len (x_values ) + TRAIN_SPLIT )
19+ x_train , x_test , x_validate = np .split (x_values , [TRAIN_SPLIT , TEST_SPLIT ])
20+ y_train , y_test , y_validate = np .split (y_values , [TRAIN_SPLIT , TEST_SPLIT ])
21+
22+ return x_train , x_test , x_validate , y_train , y_test , y_validate
23+
24+ def get_model ():
25+ x_train , x_test , x_validate , y_train , y_test , y_validate = get_data ()
26+
27+ # create a CNN
28+ model = tf .keras .Sequential ()
29+ model .add (layers .Conv2D (8 , (3 , 3 ), activation = 'relu' , input_shape = (8 , 8 , 1 )))
30+ # model.add(layers.MaxPooling2D((2, 2)))
31+ # model.add(layers.Conv2D(64, (3, 3), activation='relu'))
32+ # model.add(layers.MaxPooling2D((2, 2)))
33+ # model.add(layers.Conv2D(64, (3, 3), activation='relu'))
34+ model .add (layers .Flatten ())
35+ # model.add(layers.Dense(16, activation='relu'))
36+ model .add (layers .Dense (len (np .unique (y_train ))))
37+
38+ model .compile (optimizer = 'adam' , loss = tf .keras .losses .SparseCategoricalCrossentropy (from_logits = True ), metrics = ['accuracy' ])
39+ model .fit (x_train , y_train , epochs = 50 , batch_size = 16 ,
40+ validation_data = (x_validate , y_validate ))
41+ return Exp
42+
43+
44+ def test_model (model , x_test , y_test ):
45+ x_test = (x_test / x_test .max ()).reshape ((len (x_test ), 8 , 8 , 1 ))
46+ y_pred = model .predict (x_test ).argmax (axis = 1 )
47+ print ('ACCURACY' , (y_pred == y_test ).sum () / len (y_test ))
48+ exit ()
49+
50+
51+ if __name__ == '__main__' :
52+ model , x_test , y_test = get_model ()
53+ test_model (model , x_test , y_test )
54+ c_code = port (model , variable_name = 'digits_model' , pretty_print = True )
55+ print (c_code )
0 commit comments