# EfficientNet.py (committed by s87425; exported from a GitLab file view).
# Despite the file name, this script fine-tunes a torchvision ResNet18 image classifier.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
from torchvision.models.resnet import resnet18, ResNet18_Weights
from torchvision.datasets import ImageFolder
from sklearn.metrics import confusion_matrix, classification_report

class LimitedImageFolder(Dataset):
    """An ``ImageFolder`` view that keeps at most ``limit_per_class`` images per class.

    The first ``limit_per_class`` images of each class (in ``ImageFolder``'s own
    ordering) are kept, and every kept image is exposed exactly once, so
    ``len(dataset)`` is the real number of kept images. Classes with fewer images
    than the limit are NOT padded with repeats: repeated copies of one image could
    otherwise end up in both the train and the test split after ``random_split``.

    Args:
        root: Dataset directory in the layout ``torchvision.datasets.ImageFolder``
            expects (one subdirectory per class).
        transform: Optional transform passed through to ``ImageFolder``.
        limit_per_class: Maximum number of images kept per class.
    """

    def __init__(self, root, transform=None, limit_per_class=10000):
        self.root = root
        self.transform = transform
        self.limit_per_class = limit_per_class

        # The wrapped ImageFolder does the loading/transforming; this class only
        # decides which of its indices are exposed.
        self.image_folder = ImageFolder(root=self.root, transform=self.transform)
        # class label -> list of kept ImageFolder indices for that class
        self.class_indices = self._limit_per_class()
        # All kept ImageFolder indices in ascending order; dataset index i maps to
        # self._kept_indices[i].
        self._kept_indices = self._flatten_indices()

    def _limit_per_class(self):
        """Return ``{class_label: [ImageFolder index, ...]}``, at most ``limit_per_class`` per label."""
        class_indices = {}
        for i, (_, class_label) in enumerate(self.image_folder.imgs):
            kept = class_indices.setdefault(class_label, [])
            if len(kept) < self.limit_per_class:
                kept.append(i)
        return class_indices

    def _flatten_indices(self):
        """Return every index stored in ``self.class_indices`` as one ascending list."""
        return sorted(i for indices in self.class_indices.values() for i in indices)

    def __getitem__(self, index):
        """Return the ``ImageFolder`` item for the ``index``-th kept image.

        Raises:
            ValueError: if no class was found at all.
            IndexError: if ``index`` is outside ``range(-len(self), len(self))``.
        """
        if not self.class_indices:
            raise ValueError("class_indices is empty!")
        # Plain list indexing: raises IndexError when out of range and supports
        # negative indices.
        return self.image_folder[self._kept_indices[index]]

    def __len__(self):
        """Return the number of kept images (not ``num_classes * limit_per_class``)."""
        return len(self._kept_indices)



# Data preparation
# ImageNet channel mean/std: the ImageNet-pretrained ResNet18 weights loaded below
# expect inputs normalized with these statistics, not with (0.5, 0.5, 0.5).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # resize every image to a fixed 256x256
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Dataset root directory (one subfolder per class, as ImageFolder expects).
data_dir = 'Plant_leave_diseases_dataset_without_augmentation'
# NOTE(review): limit_per_class=5 keeps only 5 images per class - looks like a
# smoke-test setting; raise it for a real training run.
dataset = LimitedImageFolder(root=data_dir, transform=img_transform, limit_per_class=5)

# 80/20 train/test split; the seeded generator makes the split reproducible.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model: ImageNet-pretrained ResNet18 (despite the file name, not an EfficientNet).
model = resnet18(weights=ResNet18_Weights.DEFAULT)
# One output per class folder found by ImageFolder instead of a hard-coded 39.
num_classes = len(dataset.image_folder.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Fine-tune only the classification head (fc) and the last residual stage (layer4)
# that feeds it; all earlier layers stay frozen. The previous loop only unfroze
# parameters named *after* fc, and fc is resnet18's last module, so layer4 was never
# actually trained.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(("layer4", "fc"))

# Optimizer and loss function; only the unfrozen parameters are optimized.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training
num_epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Fixed label set, so every epoch's report and confusion matrix has the same
# rows/columns even when some classes are missing from the small test split.
class_labels = list(range(model.fc.out_features))

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Evaluation on the held-out split
    model.eval()
    predictions = []
    ground_truth = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            predictions.extend(outputs.argmax(dim=1).cpu().tolist())
            ground_truth.extend(labels.tolist())

    # Per-epoch classification report and confusion matrix; zero_division=0
    # reports 0 for classes with no predictions/support instead of warning.
    print(f'Epoch {epoch + 1}/{num_epochs}')
    print(classification_report(ground_truth, predictions, labels=class_labels, zero_division=0))
    print(confusion_matrix(ground_truth, predictions, labels=class_labels))

# Save the fine-tuned weights.
# NOTE(review): the file name says "efficientnet" but these are ResNet18 weights;
# the path is kept unchanged so existing loaders keep working.
torch.save(model.state_dict(), 'efficientnet_finetuned.pth')