Coverage for src/ensae_teaching_dl/examples/keras_mnist.py: 0%
76 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-25 02:07 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-25 02:07 +0200
1"""
2@file
3@brief Taken from `mnist_cnn.py <https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py>`_.
5Trains a simple convolution network on the :epkg:`MNIST` dataset.
7Gets to 99.25% test accuracy after 12 epochs
8(there is still a lot of margin for parameter tuning).
916 seconds per epoch on a GRID K520 GPU.
10"""
def keras_mnist_data():
    """
    Retrieves the :epkg:`MNIST` database for :epkg:`keras`.

    Each image gets a single channel axis, placed first or last depending on
    the backend image data format, and is converted to ``float32`` values
    divided by 255. The integer labels are one-hot encoded.

    @return         ``(X_train, Y_train), (X_test, Y_test)``
    """
    from keras.datasets import mnist
    from keras import backend as K
    try:
        from keras.utils import to_categorical
    except ImportError:
        # older version
        from keras.utils.np_utils import to_categorical

    # the data, shuffled and split between train and test sets
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # image size read from the data instead of being hard-coded
    img_rows, img_cols = X_train.shape[1], X_train.shape[2]

    try:
        imgord = K.image_data_format()
    except AttributeError:
        # older version
        try:
            imgord = K.common.image_dim_ordering()  # pylint: disable=E1101
        except AttributeError:
            # older version
            imgord = K.image_dim_ordering()  # pylint: disable=E1101

    # Old backends report 'th'/'tf', newer ones 'channels_first'/'channels_last';
    # both spellings of "channels first" must select the same layout.
    if imgord in ('th', 'channels_first'):
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    else:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)

    X_train = X_train.astype('float32') / 255
    X_test = X_test.astype('float32') / 255

    # convert class vectors to binary class matrices; the number of classes
    # must exceed the largest label of either split, even when some label
    # value never occurs
    nb_classes = int(max(y_train.max(), y_test.max())) + 1
    Y_train = to_categorical(y_train, nb_classes)
    Y_test = to_categorical(y_test, nb_classes)
    return (X_train, Y_train), (X_test, Y_test)
def keras_build_mnist_model(nb_classes, fLOG=None, img_rows=28, img_cols=28):
    """
    Builds a :epkg:`CNN` for :epkg:`MNIST` with :epkg:`keras`.

    The network stacks two 3x3 convolutions with 32 filters and ReLU
    activations, a 2x2 max-pooling, a dropout, a dense layer of 128 units
    and a softmax output layer. It is compiled with the categorical
    cross-entropy loss, the *adadelta* optimizer and the accuracy metric.

    @param      nb_classes      number of classes
    @param      fLOG            logging function, None to disable logging
    @param      img_rows        image height (28 for MNIST)
    @param      img_cols        image width (28 for MNIST)
    @return                     the model
    """
    from keras.models import Sequential
    from keras.layers import (
        Dense, Dropout, Activation, Flatten,
        Convolution2D, MaxPooling2D)
    from keras import backend as K

    try:
        imgord = K.image_data_format()
    except AttributeError:
        # older version
        try:
            imgord = K.common.image_dim_ordering()  # pylint: disable=E1101
        except AttributeError:
            # older version
            imgord = K.image_dim_ordering()  # pylint: disable=E1101

    nb_filters = 32
    pool_size = (2, 2)
    kernel_size = (3, 3)

    if fLOG:
        fLOG("[keras_build_mnist_model] K.image_dim_ordering()={0}".format(imgord))
    # 'th' (old backends) and 'channels_first' (newer ones) both mean
    # the channel axis comes first.
    if imgord in ('th', 'channels_first'):
        input_shape = (1, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, 1)

    def _conv2d(**kwargs):
        # Newer API: Conv2D(filters, kernel_size, strides, padding, ...).
        # The kernel is passed as one tuple so that its second dimension is
        # not taken as ``strides``.
        try:
            return Convolution2D(nb_filters, kernel_size, **kwargs)
        except TypeError:
            # older version: Convolution2D(nb_filter, nb_row, nb_col,
            # border_mode=...); the call above misses ``nb_col``.
            if 'padding' in kwargs:
                kwargs['border_mode'] = kwargs.pop('padding')
            return Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
                                 **kwargs)

    model = Sequential()
    model.add(_conv2d(padding='valid', input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(_conv2d())
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=pool_size))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(128))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])
    return model
def keras_fit(model, X_train, Y_train, X_test, Y_test, batch_size=128,
              nb_classes=None, epochs=12, fLOG=None):
    """
    Fits a :epkg:`keras` model.

    ``model.fit`` is called exactly once, with ``verbose=1`` and the test
    set as validation data. The number of epochs is passed as ``epochs``,
    or as ``nb_epoch`` when the signature of ``model.fit`` only accepts
    the old name.

    @param      model           :epkg:`keras` model
    @param      X_train         training features
    @param      Y_train         training target
    @param      X_test          test features
    @param      Y_test          test target
    @param      batch_size      batch size
    @param      nb_classes      number of classes, only used for logging,
                                defaults to ``Y_train.shape[1]``
    @param      epochs          number of iterations
    @param      fLOG            logging function
    @return                     model (fitted in place)
    """
    import inspect

    if nb_classes is None:
        nb_classes = Y_train.shape[1]
    if fLOG:
        fLOG("[keras_fit] nb_classes=%d" % nb_classes)

    # The keyword is chosen from the signature rather than by retrying after
    # a failure: a retry would hide the real error and restart training.
    try:
        fit_params = inspect.signature(model.fit).parameters
    except (TypeError, ValueError):
        fit_params = {}
    if 'nb_epoch' in fit_params and 'epochs' not in fit_params:
        # older version
        epoch_kwargs = {'nb_epoch': epochs}
    else:
        epoch_kwargs = {'epochs': epochs}

    model.fit(X_train, Y_train, batch_size=batch_size, verbose=1,
              validation_data=(X_test, Y_test), **epoch_kwargs)
    return model
def keras_predict(model, X_test, Y_test):
    """
    Evaluates a :epkg:`keras` model on a test set.

    Despite its name, the function returns no predictions. It silently
    calls ``model.evaluate`` (``verbose=0``) and returns its result,
    which with Keras is the loss followed by the metrics given to
    ``compile``.

    @param      model       :epkg:`keras` model
    @param      X_test      test features
    @param      Y_test      test target
    @return                 score returned by ``model.evaluate``
    """
    return model.evaluate(X_test, Y_test, verbose=0)