import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn

import os
import sys
import time

import yaml
import json

"""
Author: Jenö Faist, Paul Judis 
References: LFI-3 cnn.py 
"""

"""
This traines a Normal Resnet18 with the a Kegal dataset "Early detection of 3D printing issues" MIXED with a Dataset 
of a our on to Domain Shift the Dataset to fit our camera and printer.
Link to Original Dataset: https://www.kaggle.com/datasets/gauravduttakiit/early-detection-of-3d-printing-issues?select=train
Link to the Mixed Dataset: ??? TODO ???
The Mixed Dataset contains approximately 10% more images of labeled fail and correct prints from the printer of Paul Judis
"""

"""
This file trains the convolutional neural network with the selected dataset and hyperparameters. 
The model is saved in the TRAIN_MODELS folder and can be further trained from there or moved to 
COMPLETE_MODELS to save it.
"""

if __name__ == '__main__':

    """
    You Define all Data Paths here commonly you save the Model and Hyperparmeters in TRAIN_MODELS 
    """

    absolutepath = os.path.dirname(__file__)


    training_set_PATH = absolutepath+'/DATASETS/early_3D_Kegel_MODIFIED_SET/train'
    hyperparameters_PATH = absolutepath+'/TRAIN_MODELS/hyperparameters_MIX_R18.yaml'

    """
    All Save PATHS
    """
    model_save_PATH = absolutepath+'/TRAIN_MODELS/3D_DEC_MODEL_MIX_R18.pt'
    training_history_PATH = absolutepath+'/TRAIN_MODELS/3D_DEC_MODEL_TRAIN_HISTORY_MIX_R18.json'
    training_save_PATH = absolutepath+'/TRAIN_MODELS/3D_DEC_MODEL_TRAIN_SAVE_MIX_R18.json'


    """
    This are the References Printed in the Consol to insure that CUDA uses your GPU 
    and the DATA Paths are set Correctly
    """
   
    print("STARTING CNN TRAINING")
    print("---------------------------")
    print("Cuda Version:             " + torch.version.cuda)
    print("Cuda:                     "+str(torch.cuda.is_available()))
    print("GPU:                      "+str(torch.cuda.get_device_name()))
    print("Current Folderpath:       "+absolutepath)
    print("Model Saving in:          "+model_save_PATH)
    print("Training Save in:         "+training_save_PATH)
    print("Hyperparamtes in:         "+hyperparameters_PATH)
    print("Training History Save in: "+training_history_PATH)
    print("---------------------------")
  
    """
    Setting up Pytorch and setting the seed so that results can be recreated
    """
    torch.manual_seed(0)
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
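    # Note: even with a fixed seed, exact run-to-run reproducibility is not guaranteed
    # on GPU (cuDNN kernels and multi-worker data loading can be nondeterministic).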



    """
    Loading the Hyperparamters if Possible if not use Default Hyperparamters
    """
   
    print("Loading Hyperparamters") # CONSOLE
    

    # Default hyperparameters, used if the YAML file cannot be read
    data = {
        'default': {
            'batch_size': 64,
            'num_epochs': 10,
            'learning_rate': 0.001,
        }
    }
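    # The hyperparameter YAML file is expected to mirror the structure above
    # (a sketch based on the defaults; the actual values come from the file):
    #
    # default:
    #   batch_size: 64
    #   num_epochs: 10
    #   learning_rate: 0.001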
    try:
        with open(hyperparameters_PATH, "r") as stream:
            data = yaml.safe_load(stream)
            print("[!!!] Hyperparameters Loaded Successfully ! [!!!]")
    except (OSError, yaml.YAMLError):
        print("[?!?] Hyperparameters couldn't be loaded ! [?!?]")
        print("[!!!] Using Default Hyperparameters ! [!!!]")

    

    # Set hyperparameters
    num_epochs = data['default']['num_epochs']
    batch_size = data['default']['batch_size']
    learning_rate = data['default']['learning_rate']

    print("---------------------------") # CONSOLE



    """ 
    SETTING UP TRAINING DATA SET 
    Initialize transformations for data augmentation.
    For this Net we just use the RAW Images resized
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((256, 256)),
        ]
    )
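    # Note: since the model below starts from ImageNet-pretrained weights, inputs are
    # commonly also normalized with the ImageNet statistics, e.g.
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]);
    # this script deliberately trains on the raw resized images instead.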

    # Load the dataset with the transformations
    train_dataset = torchvision.datasets.ImageFolder(
        root=training_set_PATH,
        transform=transform
    )


    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
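    # ImageFolder assigns labels from the class subdirectory names in sorted order, so
    # the two class folders under train/ become labels 0 and 1. num_workers=8 spawns
    # worker processes for loading, which is why the script runs under the
    # if __name__ == '__main__' guard.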

    """
    In this Path you Define the Neural Network Model that you want be using.
    For this File: its default resnet18
    """

    model = torchvision.models.resnet18(weights='DEFAULT') # base model
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)
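    # Replace the final fully connected layer (1000 ImageNet classes) with a
    # 2-way head for the two classes: failed print vs. correct print.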


    """
    This Loades the not fulled trained Model if training was canceled,
    so that the Training of Model can be continued.
    """

    # Last training epoch and batch
    last_eb = [0, 0]
    # Lists to save the loss and the number of correct predictions over time for later plots
    training_history = [[], []]

    try:
        model.load_state_dict(torch.load(model_save_PATH), strict=False)
        # Keep the model in training mode: we are resuming training, not evaluating.
        model.train()

        with open(training_save_PATH, 'r') as f:
            last_eb = json.load(f)

        with open(training_history_PATH, 'r') as f:
            training_history = json.load(f)

        print("[!!!] Using previously trained model, starting from Epoch: "+str(last_eb[0])+" Batch: "+str(last_eb[1])+" [!!!]")
    except Exception:
        print("[!!!] No last training save found, starting new training with the base model [!!!]")


    ### MODEL TORCH MODIFIERS ###

    # Parallelize training across multiple GPUs
    # (currently disabled: causes errors, TODO)
    #model = torch.nn.DataParallel(model)

    # Set the model to run on the device
    model = model.to(device)

   
    """
    Here you are defining wich error function and optimizer you want be using
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
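    # CrossEntropyLoss expects raw logits, so the model needs no softmax layer;
    # Adam uses the learning rate loaded from the hyperparameter file.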




    ### CONSOLE PRINTS ####
    print("---------------------")
    print("Beginn Training")
    print("Starting with " + "Epoch:"+str(last_eb[0])+" Batch:"+str(last_eb[1]))
    print("---------------------")
    #######################

    # Time Variables for tracking
    start_time = time.time()
    current_time = time.time()
    last_time = 0

    

    """
    In here the Model gets Trainied using the specifications above.
    The trainings loop is in a try so if a KeyboardInterrupt happens argo the Programm gets closed
    the Model gets automaticly saved.
    """
    try:

        
        for epoch in range(last_eb[0], num_epochs):
            """
            If we are resuming a previously trained model, this part ensures that
            the batches already trained in the interrupted epoch are skipped.
            """

            list_dt = []

            count = last_eb[1]
            train_loader_iter = iter(train_loader)
            data_load_time_start = time.time()
            data_load_last_time = 0
            list_data_time = []
            for n in range(last_eb[1]):
                next(train_loader_iter)
                data_load_time = time.time() - data_load_time_start - data_load_last_time
                list_data_time.append(data_load_time)

                avg_dt = sum(list_data_time)/len(list_data_time)

                sys.stdout.write("\033[K")
                est_time = ((last_eb[1] - n)*avg_dt)/60**2
                print('Skipping last Trained Batches | Estimated Left Time: ' + '%.2f h ' % est_time + 'Batches to skip: '+str(last_eb[1] - n), end='\r')
                data_load_last_time = time.time() - data_load_time_start

            sys.stdout.write("\033[K")


            print("Training Epoch: "+str(epoch)) # Console




            # Iterate over the same iterator the skip loop consumed from,
            # so that already trained batches are actually skipped.
            for inputs, labels in train_loader_iter:

                """
                Train the Model with this Batch
                """
                # Move input and label tensors to the device
                
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero out the optimizer
                optimizer.zero_grad()

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Backward pass
                loss.backward()
                optimizer.step()


                """
                The Real Training Ends here.
                This Part is just for visuals so that you now how far the Model is trained 
                and how much time is left for this Epoch
                """
                current_time = time.time() - start_time
                dt = current_time - last_time
                list_dt.append(dt)

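                # training_history[0] stores the loss of every batch,
                # training_history[1] the number of correct predictions per batch.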
                training_history[0].append(loss.item())

                _, preds = torch.max(outputs, 1)
                training_history[1].append(torch.sum(preds == labels.data).item())

                avg_dt = sum(list_dt)/len(list_dt)
                avg_loss = sum(training_history[0])/len(training_history[0])



                batches_left = round((len(train_dataset) - count*batch_size)/batch_size)
                time_estimate = (batches_left*avg_dt)/60**2
                sys.stdout.write("\033[K")
                print('Training Model | Estimated Left Time for this Epoch: ' + '%.2f h' % time_estimate + ' | Current Average Loss: ' + '%.5f' % avg_loss + ' | Batches Left: '+str(batches_left), end='\r')
                last_time = current_time



                count += 1 

            """
            Every Epoch the Model gets Saved Plus the Training History containing all Loses and Accurarcys for every Batch
            """
            sys.stdout.write("\033[K") 


            print(f'Epoch {epoch+1}/{num_epochs} Done, Loss: {sum(training_history[0])/len(training_history[0]):.4f}')
            last_eb = [epoch+1, 0]
            with open(training_save_PATH, 'w') as f:
                json.dump(last_eb, f)
            torch.save(model.state_dict(), model_save_PATH)
            with open(training_history_PATH, 'w') as f:
                json.dump(training_history, f)
        
            
    except KeyboardInterrupt:

        """
        This part saves the model if the training is cancelled at any time,
        i.e. the program gets closed.
        """

        sys.stdout.write("\033[K")
        print("---------------------")
        print("[!!!] Training Interrupted, SAVING TRAINING [!!!]")
        # Save the exact position so training can resume from this epoch and batch
        # (the stale per-epoch value would lose the progress of the current epoch)
        last_eb = [epoch, count]
        with open(training_save_PATH, 'w') as f:
            json.dump(last_eb, f)
        torch.save(model.state_dict(), model_save_PATH)
        print("[!!!] Training Saved [!!!]")
        print("[!!!] Epoch: " + str(epoch) + " Batch: " + str(count) + " [!!!]")
        with open(training_history_PATH, 'w') as f:
            json.dump(training_history, f)

    """
    After Training save the model and Training Histroy
    """
    sys.stdout.write("\033[K")     
    print("---------------------")
    print(f'Finished Training, Loss: {sum(training_history[0])/len(training_history[0]):.4f}')
    torch.save(model.state_dict(), model_save_PATH)
    with open(training_history_PATH, 'w') as f:
            json.dump(training_history, f)