# 반응형 ("responsive" — leftover marker from the blog template, not code)
def train_mnist():
    """Train a dense MNIST classifier with early stopping at 98% val accuracy.

    Loads MNIST, scales pixels to [0, 1], trains a Flatten -> Dense(512) x2
    -> Dense(10, softmax) network for up to 10 epochs, and stops early once
    validation accuracy exceeds 0.98.

    Returns:
        tuple: (list of epoch indices actually run,
                final training accuracy of the last epoch).
    """
    class MyCallback(tf.keras.callbacks.Callback):
        """Stop training once validation accuracy passes 0.98."""

        def on_epoch_end(self, epoch, logs=None):
            # Use logs=None (not a mutable {} default) and guard against a
            # missing metric: logs.get('val_accuracy') could return None,
            # and `None > 0.98` raises TypeError in Python 3.
            logs = logs or {}
            if logs.get('val_accuracy', 0.0) > 0.98:
                print('\nReached 98% accuracy so cancelling training!')
                self.model.stop_training = True

    my_callback = MyCallback()

    mnist = tf.keras.datasets.mnist
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Scale pixel intensities from [0, 255] to [0, 1].
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    # Labels are integer class ids, hence the *sparse* categorical loss.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    history = model.fit(X_train, y_train, epochs=10,
                        validation_data=(X_test, y_test),
                        callbacks=[my_callback])
    return history.epoch, history.history['accuracy'][-1]
# --- Confusion Matrix ---
from sklearn.metrics import confusion_matrix

# NOTE(review): the original computed y_test.argmax(axis=1), but MNIST's
# y_test is a 1-D array of integer class ids (the model was compiled with
# sparse_categorical_crossentropy), so argmax over it is wrong. Also,
# y_pred was never defined — derive it from the trained model's per-class
# probabilities. Assumes `model`, `X_test`, `y_test` are in scope from the
# training step above — TODO confirm (train_mnist keeps them local).
y_pred = model.predict(X_test).argmax(axis=1)
cm = confusion_matrix(y_test, y_pred)
cm

import seaborn as sns
# Cell values are integer counts, so format with 'd' rather than '.2f'.
sns.heatmap(cm, annot=True, fmt='d', cmap='RdPu')
# --- Accuracy graph ---
import matplotlib.pyplot as plt

# Plot training vs. validation accuracy per epoch.
# NOTE(review): `history` is not defined at module level — train_mnist()
# returns (epochs, final_accuracy), not the History object. Capture the
# History returned by model.fit before running this section — TODO confirm.
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy per epoch')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(['train accuracy', 'val accuracy'])
plt.show()
# 반응형 ("responsive" — leftover marker from the blog template, not code)