# USAGE
# python test_network.py --model <name> --dataset <path/to/test/images>
# e.g. python test_network.py --model my.model --dataset images/test
#
# import the necessary packages
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from keras import backend as K
import numpy as np
import argparse
import imutils
import cv2
import os

# Resize parameters: RESIZE must be the same value that was used in training.
DO_RESIZE = True
RESIZE = 28

# Class labels, listed in the same order as the model's output scores
# (the predicted index is used directly to look up a label here).
# For pg-santa-other:
# CLASS_LABELS = ["Other", "PG", "Santa"]
# For MNIST (the digits 0-9 as strings):
CLASS_LABELS = [str(digit) for digit in range(10)]

# Build the command-line interface. Both options are mandatory:
#   -m/--model    path of the trained model file to load
#   -d/--dataset  directory containing the images to classify
# The parsed namespace is converted to a plain dict for key-style access.
ap = argparse.ArgumentParser()
ap.add_argument(
    "-m", "--model",
    required=True,
    help="path to trained model",
)
ap.add_argument(
    "-d", "--dataset",
    required=True,
    help="path to test images",
)
args = vars(ap.parse_args())

# Load the trained convolutional neural network once; it is reused for every
# image in the test directory.
print("[INFO] loading network...")
model = load_model(args["model"])

print("[INFO] classifying...")
# os.listdir returns entries in arbitrary order; sort them so images are
# shown in a stable, predictable sequence.
for filename in sorted(os.listdir(args["dataset"])):
    path = os.path.join(args["dataset"], filename)

    # cv2.imread does not raise on files it cannot decode (stray metadata
    # files, subdirectories, corrupt images); it returns None instead.
    # Skip such entries rather than crashing on .copy() below.
    image = cv2.imread(path)
    if image is None:
        print("[WARN] skipping {}: not a readable image".format(path))
        continue
    orig = image.copy()

    # Pre-process the image for classification: optional resize to the
    # training resolution, scale pixel values by 1/255, convert to a Keras
    # array and add a leading batch axis of size 1.
    # NOTE(review): cv2.imread loads a 3-channel BGR image by default;
    # confirm this matches the channel layout the model was trained on
    # (e.g. a grayscale MNIST model may expect a single channel).
    if DO_RESIZE:
        image = cv2.resize(image, (RESIZE, RESIZE))
    image = image.astype("float") / 255.0
    image = img_to_array(image)
    image = np.expand_dims(image, axis=0)

    # Classify the input image; predict() returns one row of class scores
    # per batch item, and the batch holds a single image.
    prediction = model.predict(image)[0]

    # Winner class = index of the highest score (np.argmax returns the first
    # index on ties); its score is reported as a percentage.
    winnerclass = int(np.argmax(prediction))
    winnerprobability = float(prediction[winnerclass]) * 100

    # Build the label; the .2f format already rounds to two decimals.
    label = "{}: {:.2f}%".format(CLASS_LABELS[winnerclass], winnerprobability)

    # Draw the label on an enlarged copy of the original (unprocessed) image.
    output = imutils.resize(orig, width=600)
    cv2.putText(output, label, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # Show the output image and wait for a key press before the next one.
    cv2.imshow("Output", output)
    cv2.waitKey(0)

# Close any HighGUI windows opened above and release the Keras backend session.
cv2.destroyAllWindows()
K.clear_session()
