TensorFlow - MNIST cheat sheet

반응형

 

def train_mnist(*, epochs=10, target_accuracy=0.98):
    """Train a small dense MNIST classifier, stopping early on high val accuracy.

    Loads MNIST through ``tf.keras.datasets.mnist``, scales pixel values by
    1/255, and trains a Flatten -> Dense(512, relu) -> Dropout(0.2) ->
    Dense(512, relu) -> Dense(10, softmax) network with the Adam optimizer and
    sparse categorical cross-entropy. The test split is used as validation data.

    Args:
        epochs: Maximum number of training epochs (default 10).
        target_accuracy: Training stops at the end of an epoch whose
            ``val_accuracy`` exceeds this value (default 0.98).

    Returns:
        A tuple ``(history.epoch, final_train_accuracy)``: the list of epoch
        indices Keras recorded, and the training accuracy of the last epoch.
    """

    class _StopAtValAccuracy(tf.keras.callbacks.Callback):
        """Set ``stop_training`` once validation accuracy passes the target."""

        def on_epoch_end(self, epoch, logs=None):
            # Avoid a shared mutable default; treat a missing dict as empty.
            logs = logs or {}
            val_accuracy = logs.get('val_accuracy')
            # Without validation data the metric is absent; comparing None to
            # a float would raise TypeError, so only check when it exists.
            if val_accuracy is not None and val_accuracy > target_accuracy:
                print(f'\nReached {target_accuracy:.0%} accuracy so cancelling training!')
                self.model.stop_training = True

    stop_callback = _StopAtValAccuracy()

    mnist = tf.keras.datasets.mnist
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Scale pixel intensities into [0, 1].
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, 'relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(512, 'relu'),
        tf.keras.layers.Dense(10, 'softmax'),
    ])

    # Labels from load_data() are integer class ids, hence the *sparse* loss.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(
        X_train,
        y_train,
        epochs=epochs,
        validation_data=(X_test, y_test),
        callbacks=[stop_callback],
    )
    return history.epoch, history.history['accuracy'][-1]

 

+ Confusion Matrix

from sklearn.metrics import confusion_matrix
import seaborn as sns

# NOTE(review): assumes y_test / y_pred are NumPy arrays from earlier cells
# (e.g. y_pred = model.predict(X_test)). With sparse_categorical_crossentropy,
# y_test holds integer class ids (1-D), so argmax(axis=1) would fail on it;
# reduce with argmax only when an array is 2-D (one-hot labels or probabilities).
y_true_labels = y_test if y_test.ndim == 1 else y_test.argmax(axis=1)
y_pred_labels = y_pred if y_pred.ndim == 1 else y_pred.argmax(axis=1)

cm = confusion_matrix(y_true_labels, y_pred_labels)
cm  # shows the raw matrix as notebook cell output

# confusion_matrix returns integer counts, so annotate with 'd', not '.2f'.
sns.heatmap(cm, annot=True, fmt='d', cmap='RdPu')

 

+ Graph: training vs. validation accuracy

import matplotlib.pyplot as plt

# NOTE(review): assumes `history` is the History object returned by
# model.fit(...) in an earlier cell.
plt.plot(history.history['accuracy'], label='train accuracy')
# 'val_accuracy' is recorded only when fit() received validation data.
if 'val_accuracy' in history.history:
    plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()
반응형