# USAGE
# python train_network.py --dataset <path/to/train/images> --epoch <epoch> --model <name>
# e.g. python train_network.py --dataset images --epoch 30 --model my.model
#
# Keep training images in the directories:
# <path/to/train/images>/class-0
# <path/to/train/images>/class-1
# ... etc

# import the necessary packages
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras import backend as K
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from nn.lenet import LeNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import random
import cv2
import os

def _positive_int(text):
    """Parse a command-line value as an integer >= 1.

    Used as the argparse ``type`` for --epoch so that a bad value is
    reported as a usage error immediately, instead of surfacing later as a
    ValueError or ZeroDivisionError after the whole dataset has been loaded.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "expected an integer, got {!r}".format(text))
    if value < 1:
        raise argparse.ArgumentTypeError(
            "expected an integer >= 1, got {}".format(value))
    return value


# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True, help="path to input dataset")
# --epoch is validated here; args["epoch"] is therefore already an int
ap.add_argument("-e", "--epoch", required=True, type=_positive_int,
    help="number of epoch")
ap.add_argument("-m", "--model", required=True, help="path to output model")
args = vars(ap.parse_args())

# --- Pre-processing configuration ---
# When DO_RESIZE is set, every image is resized to RESIZE x RESIZE pixels
DO_RESIZE = True
RESIZE = 28
# Number of colour channels: 3 for RGB, 1 for grayscale
IMG_DEPTH = 3

# --- Training hyper-parameters ---
INIT_LR = 1e-3  # initial learning rate
BS = 32         # batch size

# Epoch count comes from the command line
epochs = int(args["epoch"])

# Containers for the loaded images and their labels
print("[INFO] loading images...")
data, labels = [], []

# Collect every image path below the dataset directory in sorted order,
# then shuffle the list in place using a freshly seeded generator
imagePaths = sorted(paths.list_images(args["dataset"]))
random.seed()
random.shuffle(imagePaths)

# Read images with the channel count the model is configured for:
# a single channel when IMG_DEPTH is 1, otherwise 3-channel colour
# (IMREAD_COLOR is also cv2.imread's default).
imread_flag = cv2.IMREAD_GRAYSCALE if IMG_DEPTH == 1 else cv2.IMREAD_COLOR

# loop over the input images
for imagePath in imagePaths:
    # extract the class label from the image path first, so files that
    # end up being skipped are never decoded
    dirname = imagePath.split(os.path.sep)[-2]
    idx = dirname.find("-")
    if idx == -1:
        # File not in "class-x" directory, skip to another file
        continue

    # Label is the number in the directory name, after "-"
    try:
        label = int(dirname[idx+1:])
    except ValueError:
        # Directory contains "-" but no numeric class id after it
        print("[WARN] skipping {}: directory '{}' has no numeric class id".format(
            imagePath, dirname))
        continue

    # load the image and pre-process it; cv2.imread returns None rather
    # than raising when the file cannot be read or decoded
    image = cv2.imread(imagePath, imread_flag)
    if image is None:
        print("[WARN] skipping {}: unable to read image".format(imagePath))
        continue
    if DO_RESIZE:
        image = cv2.resize(image, (RESIZE, RESIZE))
    image = img_to_array(image)

    # Store image and label in lists
    data.append(image)
    labels.append(label)

# Fail early with a clear message when no usable image was loaded
if not labels:
    raise ValueError("no labelled images found under " + args["dataset"])

# scale the raw pixel intensities to the range [0, 1]
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)

# One-hot encoding indexes columns by label value, so negative labels
# cannot be represented
if labels.min() < 0:
    raise ValueError("class labels must be non-negative integers")

# Determine number of classes. The one-hot width must cover the largest
# label value, so it is derived from the maximum label rather than from the
# count of distinct labels; the two agree whenever labels are 0..n-1, and
# this also works when the class numbers have gaps or do not start at 0.
no_classes = int(labels.max()) + 1

# partition the data into training and validating splits using 75% of
# the data for training and the remaining 25% for validating
(train_data, valid_data, train_labels, valid_labels) = train_test_split(data,
    labels, test_size=0.25)

# convert the labels from integers to vectors
train_labels = to_categorical(train_labels, num_classes=no_classes)
valid_labels = to_categorical(valid_labels, num_classes=no_classes)

# Data augmentation: generator applying rotation, shifts, shear, zoom and
# horizontal flips to the training images
aug = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)

# Build the LeNet model for the configured input size and class count
print("[INFO] compiling model...")
model = LeNet.build(width=RESIZE, height=RESIZE, depth=IMG_DEPTH, classes=no_classes)
opt = Adam(lr=INIT_LR, decay=INIT_LR / epochs)

# Two classes use the binary loss, any other count the categorical loss
loss_name = "binary_crossentropy" if no_classes == 2 else "categorical_crossentropy"
model.compile(loss=loss_name, optimizer=opt, metrics=["accuracy"])

# train the network
print("[INFO] training network...")
# With fewer training samples than one batch the integer division yields 0
# steps, so no training step would run in an epoch; always run at least one.
steps_per_epoch = max(1, len(train_data) // BS)
H = model.fit_generator(aug.flow(train_data, train_labels, batch_size=BS),
    validation_data=(valid_data, valid_labels), steps_per_epoch=steps_per_epoch,
    epochs=epochs, verbose=1)

# save the model to disk
print("[INFO] saving model...")
model.save(args["model"])

# plot the training loss and accuracy
history = H.history
# The accuracy metric is recorded as "acc" by some Keras releases and as
# "accuracy" by others; use whichever key this run produced.
acc_key = "acc" if "acc" in history else "accuracy"
val_acc_key = "val_" + acc_key
# Take the x-axis length from the recorded history so it always matches
# the plotted series
epoch_axis = np.arange(0, len(history["loss"]))

plt.style.use("ggplot")
plt.figure()
plt.plot(epoch_axis, history["loss"], label="train_loss")
plt.plot(epoch_axis, history["val_loss"], label="val_loss")
plt.plot(epoch_axis, history[acc_key], label="train_acc")
plt.plot(epoch_axis, history[val_acc_key], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# The figure is written next to the model file, with ".png" appended
plt.savefig(args["model"]+".png")

# Clear the Keras backend session now that training and plotting are done
K.clear_session()
