Skip to content
Snippets Groups Projects
Commit 733fbc9b authored by s47700's avatar s47700
Browse files

Upload New File

parent f614dad5
No related branches found
No related tags found
No related merge requests found
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import os
import sys
import numpy as np
import time
import yaml
import cv2
import matplotlib.pyplot as plt
"""
Author: Jenö Faist, Paul Judis
Refernces: LFI-3 cnn.py
"""
"""
This File for Testing a Trained Model with the Kegal "Early detection of 3D printing issues" Set
Link to Original Dataset: https://www.kaggle.com/datasets/gauravduttakiit/early-detection-of-3d-printing-issues?select=train
Wearning: The Test Dataset for what ever reason is not labeled
so pleas use this Modified Kegal Dataset with the Labeld Test Dataset Labeled by Hand.
Link to the Test Dataset Labeled: ??? TODO ???
Because the Test Dataset was hand labeled there is a MANUAL_MODE so that you can MANUAL specifice
by pressing "f" if the current image is a file print or not !
MANUAL_MODE is Perfered because probably the Test Dataset was labeled porley be hand.
"""
if __name__ == '__main__':
"""
Setting up from where to load the Dataset and Model !
"""
absolutepath = os.path.dirname(__file__)
test_set_PATH = absolutepath+'/DATASETS/early_3D_Kegel_SET/test'
model_save_PATH = absolutepath+'/COMPLETE_MODELS/3D_DEC_MODEL_MIXR18_18E_64B.pt'
MANUAL_MODE = True
#Consol
print("STARTING CNN TESTING")
print("--------------------")
print("Cuda Version: " + torch.version.cuda)
print("Cuda: "+str(torch.cuda.is_available()))
print("GPU: "+str(torch.cuda.get_device_name()))
print("MANUAL_MODE: "+str(MANUAL_MODE))
print("--------------------")
print("Loading Model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"""
Specifing which Model we using for testing the Data this!
THIS MUST BE CHANGED DEPENDING OF WHICH MODEL YOU ARE USING !!!
"""
model = torchvision.models.resnet18(weights='DEFAULT')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
"""
Loading Model and Dataset if possible else the programm stops
"""
try:
model.load_state_dict(torch.load(model_save_PATH),strict=False)
model.eval()
print("[✓] Model Loaded [✓]")
except:
print("[!?] No Model Found [?!]")
exit(1)
model.eval()
model = model.to(device)
print("Loading Test Dataset") # Console
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((256,256)),
])
test_dataset = None
try:
test_dataset = torchvision.datasets.ImageFolder(
root=test_set_PATH,
transform=transform
)
print("[✓] Test Dataset Loaded [✓]") # Console
except:
print("[!?] Test Dataset couldn't be loaded [?!]") # Console
exit(1)
print("--------------------") # Console
running_accrucay = 0
count = 0
data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
len_test_set = len(data_loader)
with torch.no_grad():
for batch_idx, (inputs, labels) in enumerate(data_loader):
count +=1
"""
Compute current Network Output for the current Image
"""
img = (inputs[0].squeeze()).numpy()
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
"""
MANUAL_MODE shows the current Image and you can press "f" to say if the current image is a failed print or
not (not when f not pressed), this is compared to the network output to specifice accurarcy!
ELSE normal labeling of the Test Dataset is used to calculate accuracy
"""
if(MANUAL_MODE):
norm_image = cv2.normalize(img, None, alpha = 0, beta = 255, norm_type = cv2.NORM_MINMAX, dtype = cv2.CV_64F)
norm_image = norm_image.astype(np.uint8)
norm_image = np.transpose(norm_image, (1, 2, 0))
cv2.imshow('TEST MODEL PRESS "F" TO MARK AS FAIL',norm_image)
k =cv2.waitKey(20)
FAILED_PRINT = False
if(k==ord("f")):
FAILED_PRINT = True
if(preds[0].item() == 1 and FAILED_PRINT):
running_accrucay +=1
elif(preds[0].item() == 0 and not FAILED_PRINT):
running_accrucay +=1
"""
JUST FOR VISUALS
"""
sym = "NO"
sym_2 = "NO"
if(preds[0].item() == 1):
sym = "YES"
if FAILED_PRINT:
sym_2 = "YES"
sys.stdout.write("\033[K")
#DEBUG
# " NETWORK VALUES: " + str(outputs[0][0].item())+","+str(outputs[0][1].item())
print("CNN FAIL DETECTED: " + sym + " USER FAIL DETECTED: " + sym_2 + " CURRENT ACCURACY: "+str(int(100*(running_accrucay/count)))+"%", end='\r')
else:
if(labels == preds[0].item()):
running_accrucay += 1
sys.stdout.write("\033[K")
print("Testing Image "+str(batch_idx)+"/"+str(len_test_set)+" CURRENT ACCURACY: "+str(int(100*(running_accrucay/count)))+"%", end='\r')
sys.stdout.write("\033[K")
print("TESTING FINISHED ACCURACY: "+str(int(100*(running_accrucay/count)))+"%", end='\r')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment