Skip to content
Snippets Groups Projects
Commit 2d98304b authored by s87425's avatar s87425
Browse files

Autoencoder

parent cc505742
No related branches found
No related tags found
No related merge requests found
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
import os
if not os.path.exists('./dc_img'):
os.mkdir('./dc_img')
def to_img(x):
x = 0.5 * (x + 1)
x = x.clamp(0, 1)
x = x.view(x.size(0), 3, 256, 256) # Anpassung der Form für RGB-Bilder
return x
num_epochs = 100
batch_size = 128
learning_rate = 1e-3
img_transform = transforms.Compose([
transforms.Resize((256, 256)), # Anpassung der Größe
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Anpassung der Normalisierung für RGB
])
# Anpassung des Datasets auf ImageFolder
data_dir = 'Plant_leave_diseases_dataset_without_augmentation' # Setze das Verzeichnis deines Bild-Datasets hier ein
dataset = ImageFolder(root=data_dir, transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
class autoencoder(nn.Module):
def __init__(self):
super(autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 16, 3, stride=3, padding=1), # Anpassung für 3 Kanäle
nn.ReLU(True),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(16, 8, 3, stride=2, padding=1),
nn.ReLU(True),
nn.MaxPool2d(2, stride=1)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(8, 16, 3, stride=2),
nn.ReLU(True),
nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),
nn.ReLU(True),
nn.ConvTranspose2d(8, 3, 2, stride=2, padding=1), # Anpassung für 3 Kanäle
nn.Tanh()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
for epoch in range(num_epochs):
total_loss = 0
for data in dataloader:
img, _ = data
img = Variable(img).to(device)
output = model(img)
loss = criterion(output, img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.data
print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, num_epochs, total_loss))
if epoch % 10 == 0:
pic = to_img(output.cpu().data)
save_image(pic, './dc_img/image_{}.png'.format(epoch))
torch.save(model.state_dict(), './conv_autoencoder.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment