añadido el script para entrenar con pytorch
This commit is contained in:
107
entrenar.py
Normal file
107
entrenar.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||||
|
])
|
||||||
|
|
||||||
|
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
|
||||||
|
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
|
||||||
|
|
||||||
|
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
|
||||||
|
testloader = DataLoader(testset, batch_size=32, shuffle=False)
|
||||||
|
|
||||||
|
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||||||
|
|
||||||
|
class Net(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
|
||||||
|
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
||||||
|
self.pool = nn.MaxPool2d(2, 2)
|
||||||
|
self.fc1 = nn.Linear(128 * 8 * 8, 512)
|
||||||
|
self.fc2 = nn.Linear(512, 10)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.dropout = nn.Dropout(0.5)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.pool(self.relu(self.conv1(x))) # 32x32 -> 16x16
|
||||||
|
x = self.pool(self.relu(self.conv2(x))) # 16x16 -> 8x8
|
||||||
|
x = x.view(-1, 128 * 8 * 8) # Aplanar
|
||||||
|
x = self.relu(self.fc1(x))
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
|
||||||
|
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(net.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
net.to(device)
|
||||||
|
|
||||||
|
print("Entrenando...")
|
||||||
|
|
||||||
|
for epoch in range(10):
|
||||||
|
running_loss = 0.0
|
||||||
|
for i, data in enumerate(trainloader, 0):
|
||||||
|
inputs, labels = data
|
||||||
|
inputs, labels = inputs.to(device), labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
outputs = net(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
if i % 200 == 199:
|
||||||
|
print(f'Época [{epoch + 1}], Paso [{i + 1}], Pérdida: {running_loss / 200:.3f}')
|
||||||
|
running_loss = 0.0
|
||||||
|
|
||||||
|
print("Entrenamiento terminado.")
|
||||||
|
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
net.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for data in testloader:
|
||||||
|
images, labels = data
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
outputs = net(images)
|
||||||
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += (predicted == labels).sum().item()
|
||||||
|
|
||||||
|
print(f'Precisión del modelo en el conjunto de prueba: {100 * correct / total:.2f}%')
|
||||||
|
|
||||||
|
def imshow(img):
|
||||||
|
img = img / 2 + 0.5 # Desnormalizar
|
||||||
|
npimg = img.numpy()
|
||||||
|
plt.imshow(np.transpose(npimg, (1, 2, 0)))
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
dataiter = iter(testloader)
|
||||||
|
images, labels = next(dataiter)
|
||||||
|
images, labels = images.to(device), labels.to(device)
|
||||||
|
|
||||||
|
outputs = net(images)
|
||||||
|
_, predicted = torch.max(outputs, 1)
|
||||||
|
|
||||||
|
imshow(torchvision.utils.make_grid(images.cpu()))
|
||||||
|
print('Verdaderos: ', ' '.join(f'{classes[labels[j]]}' for j in range(4)))
|
||||||
|
print('Predichos: ', ' '.join(f'{classes[predicted[j]]}' for j in range(4)))
|
||||||
|
|
||||||
|
torch.save(net.state_dict(), 'modelo_cifar10.pth')
|
||||||
|
print("Modelo guardado como 'modelo_cifar10.pth'")
|
||||||
Reference in New Issue
Block a user