Commit 4aa4eb49 authored by s87425

Autoencoder fixed

parent 2d98304b
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
import os
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')
@@ -17,6 +18,35 @@ def to_img(x):
    x = x.view(x.size(0), 3, 256, 256)  # reshape to RGB image dimensions
    return x
class LimitedImageFolder(Dataset):
    """Wraps ImageFolder and keeps at most `limit_per_class` images per class.

    Note: indexing assumes every class actually provides `limit_per_class` images.
    """
    def __init__(self, root, transform=None, limit_per_class=10000):
        self.root = root
        self.transform = transform
        self.limit_per_class = limit_per_class
        self.image_folder = ImageFolder(root=self.root, transform=self.transform)
        self.class_indices = self._limit_per_class()

    def _limit_per_class(self):
        # Collect up to `limit_per_class` dataset indices for each class label.
        class_indices = {}
        for i, (image_path, class_label) in enumerate(self.image_folder.imgs):
            if class_label not in class_indices:
                class_indices[class_label] = []
            if len(class_indices[class_label]) < self.limit_per_class:
                class_indices[class_label].append(i)
        return class_indices

    def __getitem__(self, index):
        # Map the flat index back to (class, position within that class).
        original_index = self.class_indices[index // self.limit_per_class][index % self.limit_per_class]
        return self.image_folder[original_index]

    def __len__(self):
        return len(self.class_indices) * self.limit_per_class
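
# Sketch (not part of this commit): the class above assumes every class really has
# `limit_per_class` images; a flat index list avoids that assumption. The name
# FlatLimitedImageFolder and the 10-image default are illustrative only.
class FlatLimitedImageFolder(Dataset):
    def __init__(self, root, transform=None, limit_per_class=10):
        self.image_folder = ImageFolder(root=root, transform=transform)
        kept = {}
        self.indices = []
        for i, (_, class_label) in enumerate(self.image_folder.imgs):
            kept.setdefault(class_label, 0)
            if kept[class_label] < limit_per_class:
                kept[class_label] += 1
                self.indices.append(i)

    def __getitem__(self, index):
        return self.image_folder[self.indices[index]]

    def __len__(self):
        return len(self.indices)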
num_epochs = 100
batch_size = 128
learning_rate = 1e-3
@@ -29,8 +59,16 @@ img_transform = transforms.Compose([
# Dataset is loaded from an ImageFolder-style directory layout
data_dir = 'Plant_leave_diseases_dataset_without_augmentation'  # set the directory of your image dataset here
dataset = LimitedImageFolder(root=data_dir, transform=img_transform, limit_per_class=10)
# Split into training and test set
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoaders for the training and test set
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
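# Optional sketch (not in this commit): random_split also accepts a seeded generator,
# which makes the train/test membership reproducible between runs, e.g.
#   random_split(dataset, [train_size, test_size],
#                generator=torch.Generator().manual_seed(42))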
class autoencoder(nn.Module):
    def __init__(self):
@@ -53,6 +91,7 @@ class autoencoder(nn.Module):
        )
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
@@ -65,9 +104,9 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=
for epoch in range(num_epochs):
    total_loss = 0
    for data in train_dataloader:
        img, labels = data  # note: assumes the DataLoader also returns the labels
        img = img.to(device)
        output = model(img)
        loss = criterion(output, img)
        optimizer.zero_grad()
@@ -75,9 +114,50 @@ for epoch in range(num_epochs):
        optimizer.step()
        total_loss += loss.item()
    print('Autoencoder: epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, total_loss))
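    # Optional (restores the image dump dropped in this commit, so the ./dc_img directory
    # created above is still populated): save a reconstruction every 10 epochs as a visual check.
    if epoch % 10 == 0:
        pic = to_img(output.cpu().data)
        save_image(pic, './dc_img/image_{}.png'.format(epoch))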

# Apply the trained autoencoder to the training and test set and extract the latent space
model.eval()
with torch.no_grad():
    train_latent = []
    train_labels = []
    for data in train_dataloader:
        img, labels = data
        img = img.to(device)
        latent = model.encoder(img)
        train_latent.append(latent.cpu().numpy())
        train_labels.extend(labels.numpy())

    test_latent = []
    test_labels = []
    for data in test_dataloader:
        img, labels = data
        img = img.to(device)
        latent = model.encoder(img)
        test_latent.append(latent.cpu().numpy())
        test_labels.extend(labels.numpy())
unique_classes = torch.unique(torch.tensor(train_labels))
# Convert the latent-space data into tensors
train_latent = torch.cat([torch.from_numpy(latent) for latent in train_latent], dim=0)
test_latent = torch.cat([torch.from_numpy(latent) for latent in test_latent], dim=0)

# Flatten the latent-space data (optional) to one feature vector per image
train_latent = train_latent.view(train_latent.size(0), -1)
test_latent = test_latent.view(test_latent.size(0), -1)
# Compute class weights from the number of images per class
# ('balanced' weighs each class by n_samples / (n_classes * class_count))
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)

# Map each class label to its weight (zipping with np.unique keeps labels and weights
# aligned even if a class is missing from the training split)
class_weight_dict = {int(cls): weight for cls, weight in zip(np.unique(train_labels), class_weights)}
# Train an SVM on the latent-space data
svm_classifier = svm.SVC(class_weight=class_weight_dict)
svm_classifier.fit(train_latent, train_labels)

# Classify the test set with the trained SVM
predicted_labels = svm_classifier.predict(test_latent)

torch.save(model.state_dict(), './conv_autoencoder.pth')

# Compute the accuracy
accuracy = accuracy_score(test_labels, predicted_labels)
print(f'SVM Accuracy: {accuracy}')
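
# Optional sketch (not in this commit): a per-class breakdown is often more telling than
# plain accuracy on an imbalanced leaf-disease dataset; classification_report is a
# standard sklearn helper for this.
from sklearn.metrics import classification_report
print(classification_report(test_labels, predicted_labels))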