Coverage for src/ensae_teaching_dl/examples/keras_mnist.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-25 02:07 +0200

1""" 

2@file 

3@brief Taken from `mnist_cnn.py <https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py>`_. 

4 

5Trains a simple convolution network on the :epkg:`MNIST` dataset. 

6 

7Gets to 99.25% test accuracy after 12 epochs 

8(there is still a lot of margin for parameter tuning). 

916 seconds per epoch on a GRID K520 GPU. 

10""" 

11 

12 

def keras_mnist_data():
    """
    Retrieves the :epkg:`MNIST` database for :epkg:`keras`.

    The images are reshaped to 4D tensors whose layout follows the
    image data format of the :epkg:`keras` backend, converted to
    ``float32`` and divided by 255. The labels are one-hot encoded.

    @return     ``(X_train, Y_train), (X_test, Y_test)``
    """
    from keras.datasets import mnist
    from keras import backend as K
    try:
        from keras.utils import to_categorical
    except ImportError:
        # older version: the helper only lives in np_utils
        from keras.utils.np_utils import to_categorical

    # the data, shuffled and split between train and test sets
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # image size is read from the data instead of being hard-coded
    img_rows, img_cols = X_train.shape[1], X_train.shape[2]

    try:
        imgord = K.image_data_format()
    except Exception:  # pylint: disable=W0703
        # older version
        try:
            imgord = K.common.image_dim_ordering()  # pylint: disable=E1101
        except Exception:  # pylint: disable=W0703
            # older version
            imgord = K.image_dim_ordering()  # pylint: disable=E1101

    # 'th' is what the legacy image_dim_ordering API returns,
    # 'channels_first' is what image_data_format returns for the
    # same layout: both must select the channel-first reshape.
    if imgord in ('th', 'channels_first'):
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
    else:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255
    X_test /= 255

    # convert class vectors to binary class matrices
    nb_classes = len(set(y_train))
    Y_train = to_categorical(y_train, nb_classes)
    Y_test = to_categorical(y_test, nb_classes)
    return (X_train, Y_train), (X_test, Y_test)

52 

53 

def keras_build_mnist_model(nb_classes, fLOG=None):
    """
    Builds a :epkg:`CNN` for :epkg:`MNIST` with :epkg:`keras`.

    The network stacks two 3x3 convolutions with 32 filters,
    a 2x2 max pooling, a dropout, a dense layer of 128 units,
    another dropout and a softmax output layer. It is compiled with
    the categorical cross-entropy loss and the *adadelta* optimizer.

    @param      nb_classes      number of classes
    @param      fLOG            logging function (optional, nothing
                                is logged when it is None)
    @return                     the model
    """
    from keras.models import Sequential
    from keras.layers import (
        Dense, Dropout, Activation, Flatten,
        Convolution2D, MaxPooling2D)
    from keras import backend as K

    try:
        imgord = K.image_data_format()
    except Exception:  # pylint: disable=W0703
        # older version
        try:
            imgord = K.common.image_dim_ordering()  # pylint: disable=E1101
        except Exception:  # pylint: disable=W0703
            # older version
            imgord = K.image_dim_ordering()  # pylint: disable=E1101

    model = Sequential()

    nb_filters = 32
    pool_size = (2, 2)
    kernel_size = (3, 3)
    img_rows, img_cols = 28, 28  # should be computed from the data

    # fLOG is optional: calling None would raise a TypeError
    if fLOG:
        fLOG("[keras_build_mnist_model] K.image_dim_ordering()={0}".format(imgord))
    # 'th' (legacy API) and 'channels_first' (current API) name the
    # same layout.
    if imgord in ('th', 'channels_first'):
        input_shape = (1, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, 1)

    def conv_layer(**kwargs):
        # Current signature is (filters, kernel_size, strides, ...):
        # passing the kernel height and width as two positional integers
        # would be read as kernel_size=3, strides=3 on versions without
        # the legacy shim, so the kernel size is given as one tuple.
        try:
            return Convolution2D(nb_filters, kernel_size, **kwargs)
        except TypeError:
            # older version: (nb_filter, nb_row, nb_col) and border_mode
            legacy = dict(kwargs)
            if 'padding' in legacy:
                legacy['border_mode'] = legacy.pop('padding')
            return Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
                                 **legacy)

    model.add(conv_layer(padding='valid', input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(conv_layer())
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=pool_size))
    model.add(Dropout(0.25))

    model.add(Flatten())
    model.add(Dense(128))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adadelta',
                  metrics=['accuracy'])
    return model

115 

116 

def keras_fit(model, X_train, Y_train, X_test, Y_test, batch_size=128,
              nb_classes=None, epochs=12, fLOG=None):
    """
    Fits a :epkg:`keras` model.

    @param      model       :epkg:`keras` model
    @param      X_train     training features
    @param      Y_train     training target
    @param      X_test      test features, used as validation data
    @param      Y_test      test target, used as validation data
    @param      batch_size  batch size
    @param      nb_classes  number of classes, only used for logging,
                            read from ``Y_train.shape[1]`` if None
    @param      epochs      number of iterations
    @param      fLOG        logging function
    @return                 model

    Any exception raised by the training itself is propagated,
    only the rejection of keyword *epochs* (``TypeError``) triggers
    the fallback on the older keyword *nb_epoch*.
    """
    if nb_classes is None:
        nb_classes = Y_train.shape[1]
    if fLOG:
        fLOG("[keras_fit] nb_classes=%d" % nb_classes)

    fit_args = dict(batch_size=batch_size, verbose=1,
                    validation_data=(X_test, Y_test))
    try:
        model.fit(X_train, Y_train, epochs=epochs, **fit_args)
    except TypeError:
        # older version: the number of iterations is called nb_epoch.
        # Catching Exception here would hide genuine training failures
        # behind a second, misleading call.
        model.fit(X_train, Y_train, nb_epoch=epochs, **fit_args)
    return model

146 

147 

def keras_predict(model, X_test, Y_test):
    """
    Evaluates a :epkg:`keras` model on a test set
    (calls ``model.evaluate`` silently, with ``verbose=0``).

    @param      model       :epkg:`keras` model
    @param      X_test      test features
    @param      Y_test      test target
    @return                 score, whatever ``model.evaluate`` returns
    """
    return model.evaluate(X_test, Y_test, verbose=0)

158 return score