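# plant_diseases_classifier.py
#
# Trains a linear SVM (sklearn SVC) on descriptor/label .npz files from
# result_descriptors_and_labels/train and reports accuracy, a confusion
# matrix, precision, F1 and recall on result_descriptors_and_labels/test.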
import os
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay, precision_score, f1_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import sklearn
import matplotlib.pyplot as plt

# Each feature row holds 128 * nfeatures values (128 per descriptor).
# NOTE(review): 128-wide descriptors suggest SIFT; confirm against the
# script that produced the .npz files.
nfeatures = 500


def _load_split(directory, n_columns):
    """Load descriptor rows and label rows from every .npz file in *directory*.

    Files are visited in sorted name order. A file whose first stored array
    has exactly one column is treated as a label file; any other file is
    treated as a descriptor file. Entries not ending in ``.npz`` are skipped.

    NOTE(review): rows of the returned data matrix and label vector are only
    aligned if the sorted order of descriptor files matches the sorted order
    of label files; verify against the file-naming scheme.

    Args:
        directory: Folder containing the .npz files.
        n_columns: Expected width of a descriptor row (used for the empty
            seed array, so an empty folder still yields a well-shaped matrix).

    Returns:
        ``(data_matrix, label_vector, label_names)``: a float16 array of shape
        ``(rows, n_columns)``, a float16 array of shape ``(rows, 1)``, and a
        dict mapping the first label value of each label file (as a Python
        float) to the file-name prefix before ``'_label_'``.

    Raises:
        ValueError: if the number of descriptor rows and label rows differ.
    """
    # Collect the pieces and concatenate once; stacking inside the loop would
    # recopy the whole growing matrix for every file.
    data_parts = [np.empty((0, n_columns), dtype=np.float16)]
    label_parts = [np.empty((0, 1), dtype=np.float16)]
    label_names = {}

    for npz_name in sorted(os.listdir(directory)):
        if not npz_name.endswith('.npz'):
            continue
        npz_path = os.path.join(directory, npz_name)
        # The context manager closes the archive; indexing it reads the
        # array fully into memory, so it stays valid after closing.
        with np.load(npz_path) as npz_data:
            npz_arr = npz_data[list(npz_data.keys())[0]]

        if npz_arr.shape[1] == 1:
            labels = npz_arr.astype(np.float16)
            label_parts.append(labels)
            # NOTE(review): assumes every row of a label file carries the same
            # class value, so the first row names the whole file.
            label_names[float(labels[0, 0])] = npz_name.partition('_label_')[0]
        else:
            data_parts.append(npz_arr.astype(np.float16))

    data_matrix = np.concatenate(data_parts, axis=0)
    label_vector = np.concatenate(label_parts, axis=0)
    if data_matrix.shape[0] != label_vector.shape[0]:
        raise ValueError(
            f'{directory}: {data_matrix.shape[0]} descriptor rows but '
            f'{label_vector.shape[0]} label rows'
        )
    return data_matrix, label_vector, label_names


main_directory = 'result_descriptors_and_labels/train'
data_matrix, label_vector, label_dict = _load_split(main_directory, 128 * nfeatures)

# "Hier" (German: "here") progress markers, kept as in the original output.
print("Hier1")
train_classes = np.unique(label_vector)
num_classes = len(train_classes)
# 'balanced' weights each class inversely to its frequency in the training labels.
class_weights = compute_class_weight('balanced', classes=train_classes, y=label_vector.flatten())
class_weights_dict = dict(zip(train_classes, class_weights))
print("Hier2")
plant_diseases_svm = SVC(kernel='linear', class_weight=class_weights_dict)
plant_diseases_svm.fit(data_matrix, np.ravel(label_vector))
print("Hier3")

test_directory = 'result_descriptors_and_labels/test'
# Load the test split into its own arrays. Previously the training arrays
# were stacked here, which mixed the training set into the evaluation.
test_data_matrix, test_label_vector, test_label_dict = _load_split(test_directory, 128 * nfeatures)
# Also name classes that only appear in the test split; training names win on conflict.
label_dict = {**test_label_dict, **label_dict}

test_classes = plant_diseases_svm.predict(test_data_matrix)

accuracy = accuracy_score(test_label_vector, test_classes)
print(f'SVM Accuracy: {accuracy}')

cm = confusion_matrix(test_label_vector, test_classes)
ps = precision_score(test_label_vector, test_classes, average=None)
f1 = f1_score(test_label_vector, test_classes, average=None)
rc = recall_score(test_label_vector, test_classes, average=None)

print('confusion matrix:', cm , "\n \n",
      'precision:', ps ,"\n \n",
      'f1:', f1 , "\n \n",
      'recall:', rc , "\n \n")

# With labels=None, confusion_matrix orders rows/columns by the sorted union
# of true and predicted labels; build the display names in that same order
# instead of hard-coding two class names.
cm_classes = np.unique(np.concatenate((np.ravel(test_label_vector), np.ravel(test_classes))))
display_labels = [label_dict.get(float(c), str(c)) for c in cm_classes]

cmd = ConfusionMatrixDisplay(cm, display_labels=display_labels)

cmd.plot()
plt.show()