TensorFlow - CNN modeling cheat sheet

반응형

Library

import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D

 

Modeling

def CNN_model():
    """Train a small CNN on MNIST and report how training went.

    Loads MNIST via ``tf.keras.datasets.mnist``, scales pixel values to
    [0, 1], adds a trailing channel axis, and fits a
    Conv2D -> MaxPooling2D -> Flatten -> Dense -> Dense(softmax) model for up
    to 20 epochs, holding out 20% of the training images for validation.
    Training is stopped early once validation accuracy reaches 0.96.

    Returns:
        A tuple ``(epochs, final_accuracy)``: ``epochs`` is ``history.epoch``
        (the epoch indices that were actually run) and ``final_accuracy`` is
        the training accuracy recorded for the last epoch.
    """

    class myCallback(tf.keras.callbacks.Callback):
        """Stop training once validation accuracy reaches 96%."""

        def on_epoch_end(self, epoch, logs=None):
            # Avoid a shared mutable default; Keras may also pass None.
            logs = logs or {}
            val_accuracy = logs.get('val_accuracy')
            # 'val_accuracy' is only logged when validation data is used;
            # comparing None with a float would raise TypeError.
            if val_accuracy is not None and val_accuracy >= 0.96:
                print('\nReached 96% accuracy so cancelling training!')
                self.model.stop_training = True

    my_callback = myCallback()

    mnist = tf.keras.datasets.mnist
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Scale raw pixel intensities (0-255) to floats in [0, 1].
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    # Conv2D expects a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    # -1 lets NumPy infer N instead of hard-coding the split sizes, and the
    # test images must be reshaped from X_test (not X_train).
    # NOTE: the test split is prepared here but not used by this function.
    X_train = X_train.reshape(-1, 28, 28, 1)
    X_test = X_test.reshape(-1, 28, 28, 1)

    model = Sequential([
        Conv2D(filters=64, kernel_size=(3, 3), activation='relu',
               input_shape=(28, 28, 1)),
        MaxPooling2D(pool_size=(2, 2), strides=2),

        Flatten(),
        Dense(units=120, activation='relu'),
        Dense(units=10, activation='softmax'),
    ])

    # Keyword arguments: positional order differs between Keras versions
    # (in Keras 3 the third positional parameter is loss_weights, not metrics).
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(X_train, y_train, epochs=20, validation_split=0.2,
                        callbacks=[my_callback])
    return history.epoch, history.history['accuracy'][-1]
반응형