
CNN(Convolutional Neural Network)-CIFAR 100_ResNet50V2

2^7 2022. 7. 4. 09:52

CIFAR100 - Categorical Classification

 

1. Load the CIFAR100 Data Set

import tensorflow
from tensorflow.keras.datasets import cifar100

# label_mode = 'fine' gives the 100 fine-grained classes ('coarse' gives the 20 superclasses)
(X_train, y_train), (X_test, y_test) = cifar100.load_data(label_mode = 'fine')
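As a quick sanity check on what `cifar100.load_data(label_mode = 'fine')` returns, the sketch below validates the arrays: uint8 images of shape (N, 32, 32, 3) and integer labels of shape (N, 1) in the range 0 to 99. The helper name `check_cifar100_arrays` is hypothetical, and it is only a sketch of the checks one might run.

```python
import numpy as np

def check_cifar100_arrays(X, y, num_classes=100):
    """Hypothetical helper: validate CIFAR-100 arrays as returned by load_data()."""
    # Images: N samples of 32x32 RGB stored as unsigned 8-bit integers
    if X.ndim != 4 or X.shape[1:] != (32, 32, 3):
        raise ValueError(f"unexpected image shape {X.shape}")
    if X.dtype != np.uint8:
        raise ValueError(f"unexpected image dtype {X.dtype}")
    # Labels: one integer column per sample, not yet one-hot encoded
    if y.shape != (X.shape[0], 1):
        raise ValueError(f"unexpected label shape {y.shape}")
    if y.min() < 0 or y.max() >= num_classes:
        raise ValueError("label outside 0..num_classes-1")
    return X.shape[0]

# Usage with the real data, right after load_data():
# check_cifar100_arrays(X_train, y_train)  # 50000 samples
# check_cifar100_arrays(X_test, y_test)    # 10000 samples
```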

2. Data Preprocessing

2-1. Reshape and Normalization

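# load_data() already returns (N, 32, 32, 3) arrays, so reshape is effectively a no-op here
X_train = X_train.reshape((50000, 32, 32, 3))
X_test = X_test.reshape((10000, 32, 32, 3))
# Scale uint8 pixels from [0, 255] to [0, 1]; float32 uses half the memory of float (float64)
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255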
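The dtype chosen for the cast matters for memory: Python's `float` maps to NumPy float64 (8 bytes per value), while float32 uses 4 bytes and is Keras' default compute dtype anyway. The training tensor holds 50000 × 32 × 32 × 3 = 153,600,000 values, roughly 1.23 GB as float64 versus 0.61 GB as float32. A minimal NumPy sketch of the arithmetic and of the [0, 1] scaling (helper names are hypothetical):

```python
import numpy as np

def scaled_bytes(n_images, dtype):
    """Bytes needed for n_images CIFAR images (32x32x3) stored as dtype."""
    n_values = n_images * 32 * 32 * 3
    return n_values * np.dtype(dtype).itemsize

def normalize(images):
    """Scale uint8 pixels from [0, 255] to float32 values in [0, 1]."""
    return images.astype("float32") / 255.0

print(scaled_bytes(50000, "float64"))  # 1228800000
print(scaled_bytes(50000, "float32"))  # 614400000
```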

2-2. One Hot Encoding

from tensorflow.keras.utils import to_categorical

# Integer labels 0-99 -> one-hot vectors of length 100
y_train = to_categorical(y_train, num_classes = 100)
y_test = to_categorical(y_test, num_classes = 100)
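`to_categorical` turns each integer label into a row containing a single 1. The NumPy equivalent below is only an illustration of that transformation (the function name `one_hot` is hypothetical): it indexes rows of an identity matrix, and `argmax` recovers the original label.

```python
import numpy as np

def one_hot(labels, num_classes=100):
    """Convert integer labels of shape (N, 1) or (N,) to one-hot rows (N, num_classes)."""
    flat = np.asarray(labels).reshape(-1)  # (N, 1) -> (N,)
    return np.eye(num_classes, dtype="float32")[flat]

labels = np.array([[3], [0], [99]])
encoded = one_hot(labels)
print(encoded.shape)                    # (3, 100)
print(encoded.argmax(axis=1).tolist())  # [3, 0, 99]
```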

3. Import ResNet50V2 Model

3-1. conv_base

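from tensorflow.keras.applications import ResNet50V2

# Convolutional base pre-trained on ImageNet; include_top = False drops the
# 1000-class ImageNet classifier so a new CIFAR-100 head can be attached.
# 32 x 32 is the smallest input size Keras accepts for this model.
conv_base = ResNet50V2(weights = 'imagenet',
                       include_top = False,
                       input_shape = (32, 32, 3))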
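The ImageNet weights of ResNet50V2 expect inputs scaled to [-1, 1]; Keras provides `tensorflow.keras.applications.resnet_v2.preprocess_input` for this, which divides raw 0-255 pixels by 127.5 and subtracts 1. The NumPy sketch below mirrors that scaling (the name `resnet_v2_scale` is hypothetical) so it can be compared with the [0, 1] range produced by dividing by 255 in step 2-1. Since the base is fine-tuned rather than frozen here, how much the mismatch affects accuracy is something to test rather than assume.

```python
import numpy as np

def resnet_v2_scale(images):
    """Scale uint8 pixels from [0, 255] to [-1, 1] (pixels / 127.5 - 1),
    mirroring the scaling of resnet_v2.preprocess_input."""
    return images.astype("float32") / 127.5 - 1.0

pixels = np.array([0, 51, 255], dtype=np.uint8)
print(resnet_v2_scale(pixels))     # approximately [-1.0, -0.6, 1.0]
print(pixels.astype("float32") / 255)  # approximately [0.0, 0.2, 1.0]

# With TensorFlow installed, the library function can be applied to the RAW
# 0-255 arrays (not to arrays already divided by 255):
# from tensorflow.keras.applications.resnet_v2 import preprocess_input
# X_train = preprocess_input(X_train.astype("float32"))
```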

4. Keras Modeling

4-1. Model Define

from tensorflow.keras import models
from tensorflow.keras import layers

model = models.Sequential()
model.add(conv_base)  # not frozen, so the ImageNet weights are fine-tuned as well

model.add(layers.Flatten())                           # conv_base output (1, 1, 2048) -> 2048 features
model.add(layers.Dense(256, activation = 'relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(100, activation = 'softmax'))  # one probability per class
model.summary()
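The classifier head in `model.summary()` can be checked by hand. Assuming ResNet50V2's overall downsampling factor of 32 and its 2048 output channels, a 32 × 32 input leaves a 1 × 1 × 2048 feature map; `Flatten` and `Dropout` add no parameters, and each `Dense` layer has inputs × units weights plus one bias per unit. A small arithmetic sketch (helper names hypothetical, valid only for input sizes that are multiples of 32):

```python
def conv_base_output_shape(height, width, total_stride=32, channels=2048):
    """Output shape of a ResNet50-style base, assuming total stride 32.
    Only exact for input sizes that are multiples of the stride."""
    if height % total_stride or width % total_stride:
        raise ValueError("sketch only handles sizes divisible by the stride")
    return (height // total_stride, width // total_stride, channels)

def dense_params(inputs, units):
    """Weights (inputs x units) plus one bias per unit."""
    return inputs * units + units

h, w, c = conv_base_output_shape(32, 32)
flat = h * w * c                   # Flatten: 1 * 1 * 2048
hidden = dense_params(flat, 256)   # Dense(256)
output = dense_params(256, 100)    # Dense(100)
print(flat, hidden, output)        # 2048 524544 25700
```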


4-2. Model Compile

model.compile(loss = 'categorical_crossentropy',
              optimizer = 'adam',
              metrics = ['accuracy'])
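With one-hot targets and a softmax output, `categorical_crossentropy` for one sample is the negative log of the probability assigned to the true class, averaged over the batch. The NumPy sketch below illustrates this; clipping by a small epsilon (1e-7) to avoid log(0) is similar in spirit to what Keras does, but the exact library internals are an assumption here. A useful reference point: a uniform guess over 100 classes gives a loss of ln(100) ≈ 4.605, roughly where training should start.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)) over classes."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = np.array([[0.0, 1.0, 0.0]])
y_pred = np.array([[0.1, 0.8, 0.1]])
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.2231
```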

4-3. Model Fit

%%time

Hist_mnist = model.fit(X_train, y_train,
                       epochs = 60,
                       batch_size = 128,
                       validation_split = 0.2)
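`validation_split = 0.2` holds out the last 20% of the training arrays (Keras takes them from the end, before shuffling), so of the 50,000 training images about 40,000 are used for training and 10,000 for validation; at `batch_size = 128` that is 313 steps per epoch, the last one partial. A small sketch of the arithmetic (the function name is hypothetical):

```python
import math

def split_and_steps(n_samples, validation_split, batch_size):
    """Hold-out arithmetic for a validation_split-style split."""
    n_train = int(math.floor(n_samples * (1.0 - validation_split)))
    n_val = n_samples - n_train
    steps = math.ceil(n_train / batch_size)  # last batch may be partial
    return n_train, n_val, steps

print(split_and_steps(50000, 0.2, 128))  # (40000, 10000, 313)
```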


4-4. Visualizing Training Results

import matplotlib.pyplot as plt

epochs = range(1, len(Hist_mnist.history['loss']) + 1)

plt.figure(figsize = (9, 6))
plt.plot(epochs, Hist_mnist.history['loss'])
plt.plot(epochs, Hist_mnist.history['val_loss'])
# plt.ylim(0, 0.4)
plt.title('Training & Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(['Training Loss', 'Validation Loss'])
plt.grid()
plt.show()
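If the validation loss bottoms out and then rises while the training loss keeps falling, the model is overfitting, and later epochs of the 60 add little. The helper below (hypothetical) finds the epoch with the lowest value of a metric in a Keras-style history dict such as `Hist_mnist.history`; Keras also offers `tensorflow.keras.callbacks.EarlyStopping(monitor = 'val_loss', restore_best_weights = True)` to stop training automatically.

```python
def best_epoch(history, key="val_loss"):
    """Return (1-based epoch, value) with the lowest `key` in a history dict."""
    values = history[key]
    idx = min(range(len(values)), key=values.__getitem__)
    return idx + 1, values[idx]

# Usage with the real run:
# epoch, val = best_epoch(Hist_mnist.history)
# print(f"lowest val_loss {val:.4f} at epoch {epoch}")
```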


4-5. Model Evaluate

loss, accuracy = model.evaluate(X_test, y_test)

print('Loss = {:.5f}'.format(loss))
print('Accuracy = {:.5f}'.format(accuracy))
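With one-hot labels and `categorical_crossentropy`, the `accuracy` metric compares the index of the highest predicted probability with the index of the 1 in each label. The NumPy sketch below reproduces that on arrays such as `model.predict(X_test)` and `y_test`, and adds top-k accuracy, often reported for 100-class problems; both helper names are hypothetical. For reference, random guessing over 100 classes gives about 1% accuracy.

```python
import numpy as np

def categorical_accuracy(y_true, y_prob):
    """Fraction of rows where argmax of predictions equals argmax of one-hot labels."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_prob, axis=1)))

def top_k_accuracy(y_true, y_prob, k=5):
    """Fraction of rows whose true class is among the k most probable classes."""
    true_idx = np.argmax(y_true, axis=1)
    top_k = np.argsort(y_prob, axis=1)[:, -k:]  # indices of the k largest probabilities
    return float(np.mean(np.any(top_k == true_idx[:, None], axis=1)))

# Usage with the trained model:
# y_prob = model.predict(X_test)
# print(categorical_accuracy(y_test, y_prob), top_k_accuracy(y_test, y_prob, k=5))
```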
