"""Test a trained 3D-print failure classifier on a hand-labeled test set.

Author: Jenö Faist, Paul Judis
References: LFI-3 cnn.py

This file tests a trained model with the Kaggle "Early detection of 3D
printing issues" dataset.
Link to the original dataset:
https://www.kaggle.com/datasets/gauravduttakiit/early-detection-of-3d-printing-issues?select=train

Warning: for whatever reason the original test dataset is not labeled, so
please use this modified Kaggle dataset whose test split was labeled by hand.
Link to the labeled test dataset: ??? TODO ???

Because the test dataset was labeled by hand, there is a MANUAL_MODE in which
you label each image yourself by pressing "f" while it is shown if it is a
failed print. That label is then compared with the network output.
MANUAL_MODE is preferred because the hand-made labels of the test dataset are
probably of poor quality.
"""

import os
import sys
import time

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import yaml
import cv2
import matplotlib.pyplot as plt

# Where to load the dataset and model from.
# abspath() matters: with a relative __file__, dirname() returns "" and the
# old string concatenation turned the paths into "/DATASETS/..." (filesystem root).
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_SET_PATH = os.path.join(_BASE_DIR, 'DATASETS', 'early_3D_Kegel_SET', 'test')
MODEL_SAVE_PATH = os.path.join(_BASE_DIR, 'COMPLETE_MODELS', '3D_DEC_MODEL_MIXR18_18E_64B.pt')

# True: the user labels every image by pressing "f"; False: use the folder labels.
MANUAL_MODE = True
# How long (ms) each image is shown in MANUAL_MODE while waiting for a key press.
MANUAL_WAIT_MS = 20


def _build_model():
    """Create the network architecture that the saved weights belong to.

    THIS MUST BE CHANGED DEPENDING ON WHICH MODEL YOU ARE USING!
    """
    # ResNet-18 with torchvision's default pretrained weights and a new
    # 2-class output layer.
    model = torchvision.models.resnet18(weights='DEFAULT')
    model.fc = nn.Linear(model.fc.in_features, 2)
    return model


def _load_model(model_path, device):
    """Build the model, load its saved state dict and move it to *device*.

    Exits the program with status 1 if the weights cannot be loaded.
    """
    model = _build_model()
    try:
        # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
        state_dict = torch.load(model_path, map_location=device)
        result = model.load_state_dict(state_dict, strict=False)
    except Exception as err:  # top-level boundary: report the cause and stop
        print("[!?] No Model Found [?!]")
        print(f"     {err}")
        sys.exit(1)
    # strict=False silently skips non-matching keys; layers it skips keep their
    # initial weights, so make any mismatch visible.
    if result.missing_keys or result.unexpected_keys:
        print("[!] WARNING: saved weights do not fully match the model")
        print(f"    missing keys: {result.missing_keys}")
        print(f"    unexpected keys: {result.unexpected_keys}")
    print("[✓] Model Loaded [✓]")
    model.eval()
    return model.to(device)


def _load_test_dataset(test_set_path):
    """Load the test images as an ImageFolder dataset; exit with status 1 on failure."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ])
    try:
        test_dataset = torchvision.datasets.ImageFolder(root=test_set_path, transform=transform)
    except Exception as err:  # top-level boundary: report the cause and stop
        print("[!?] Test Dataset couldn't be loaded [?!]")  # Console
        print(f"     {err}")
        sys.exit(1)
    print("[✓] Test Dataset Loaded [✓]")  # Console
    return test_dataset


def _ask_user_if_failed(image_chw):
    """Show one test image and return True if the user pressed "f" or "F".

    image_chw is a channel-first (C, H, W) array as produced by ToTensor; the
    channels are RGB because ImageFolder's default loader converts to RGB.
    """
    image_hwc = np.ascontiguousarray(np.transpose(image_chw, (1, 2, 0)))
    # Stretch the value range to 0..255 so the image displays with full contrast.
    display = cv2.normalize(image_hwc, None, alpha=0, beta=255,
                            norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    # OpenCV windows expect BGR channel order.
    display = cv2.cvtColor(display, cv2.COLOR_RGB2BGR)
    cv2.imshow('TEST MODEL PRESS "F" TO MARK AS FAIL', display)
    # waitKey returns -1 when no key was pressed; masking keeps only the key code.
    key = cv2.waitKey(MANUAL_WAIT_MS) & 0xFF
    return key in (ord("f"), ord("F"))


def _percent(correct, count):
    """Return correct/count as a truncated whole-number percentage (0 if count is 0)."""
    return 100 * correct // count if count else 0


def main():
    """Load model and test set, classify every image and print the accuracy."""
    # Console
    print("STARTING CNN TESTING")
    print("--------------------")
    cuda_available = torch.cuda.is_available()
    # torch.version.cuda is None on CPU-only builds and get_device_name()
    # raises without a GPU, so both are guarded.
    print("Cuda Version: " + str(torch.version.cuda))
    print("Cuda: " + str(cuda_available))
    print("GPU: " + (torch.cuda.get_device_name() if cuda_available else "None"))
    print("MANUAL_MODE: " + str(MANUAL_MODE))
    print("--------------------")
    print("Loading Model...")
    device = torch.device("cuda" if cuda_available else "cpu")
    model = _load_model(MODEL_SAVE_PATH, device)

    print("Loading Test Dataset")  # Console
    test_dataset = _load_test_dataset(TEST_SET_PATH)
    print("--------------------")  # Console

    data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=0)
    total = len(data_loader)
    correct = 0
    count = 0
    try:
        with torch.no_grad():
            for inputs, labels in data_loader:
                count += 1
                # Network prediction for the current image (batch of one).
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                pred = preds[0].item()

                if MANUAL_MODE:
                    # The user's key press is the ground truth for this image.
                    failed_print = _ask_user_if_failed(inputs[0].numpy())
                    # Assumes class index 1 is the "failed print" class
                    # (ImageFolder numbers classes in sorted folder-name order)
                    # — confirm against the training set.
                    if (pred == 1) == failed_print:
                        correct += 1
                    # JUST FOR VISUALS
                    sym = "YES" if pred == 1 else "NO"
                    sym_2 = "YES" if failed_print else "NO"
                    sys.stdout.write("\033[K")  # clear the current console line
                    print("CNN FAIL DETECTED: " + sym + " USER FAIL DETECTED: " + sym_2
                          + " CURRENT ACCURACY: " + str(_percent(correct, count)) + "%",
                          end='\r')
                else:
                    # The folder label of the test dataset is the ground truth.
                    if pred == labels[0].item():
                        correct += 1
                    sys.stdout.write("\033[K")
                    print("Testing Image " + str(count) + "/" + str(total)
                          + " CURRENT ACCURACY: " + str(_percent(correct, count)) + "%",
                          end='\r')
    finally:
        if MANUAL_MODE:
            cv2.destroyAllWindows()

    sys.stdout.write("\033[K")
    # End with a newline so the shell prompt does not overwrite the result.
    print("TESTING FINISHED ACCURACY: " + str(_percent(correct, count)) + "%")


if __name__ == '__main__':
    main()