132 lines
4.0 KiB
Python
132 lines
4.0 KiB
Python
import sys
import tkinter as tk
from pathlib import Path
from tkinter import Label, Button, filedialog

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageTk
from tkinterdnd2 import TkinterDnD
|
|
|
|
class Net(nn.Module):
    """Small CNN classifier for 3-channel 32x32 images with 10 output classes.

    Architecture: two stages of 3x3 convolution (padding=1) + ReLU + 2x2
    max-pooling (3 -> 64 -> 128 channels), then a 512-unit fully connected
    layer with dropout and a 10-way output layer.  The input spatial size is
    fixed by ``fc1``: 32x32 halved twice by ``pool`` gives 8x8, so ``fc1``
    expects ``128 * 8 * 8`` features per sample.

    The attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) are the
    ``state_dict`` keys, so they must not change or saved checkpoints will no
    longer load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, 10)``.

        Args:
            x: Batched image tensor of shape ``(batch, 3, 32, 32)``.

        Raises:
            RuntimeError: If the spatial size is not 32x32, because the
                flattened feature count then does not match ``fc1``.
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten everything except the batch dimension.  A
        # ``view(-1, 128 * 8 * 8)`` would instead silently turn surplus
        # features from a wrongly sized input into extra "batch" rows.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
|
|
|
|
# Build the network and load the trained weights from a file in the current
# working directory; the app cannot run without them.
model = Net()
try:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a saved state_dict needs, and refuses arbitrary pickled
    # objects from a tampered checkpoint.
    state_dict = torch.load('modelo_cifar10.pth', map_location='cpu', weights_only=True)
except FileNotFoundError:
    print("❌ No se encontró 'modelo_cifar10.pth'. Entrena primero.")
    # Non-zero status signals failure; sys.exit is always available, unlike
    # the site-module helper exit() (missing under `python -S` / frozen apps).
    sys.exit(1)
model.load_state_dict(state_dict)
# Switch to inference mode so the Dropout layer is disabled.
model.eval()
print("✅ Modelo cargado.")
|
|
|
|
# Human-readable labels; position i is the name for class index i, which is
# how process_image looks up the model's argmax prediction.
classes = (
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)
|
|
|
|
# Inference-time preprocessing applied to each PIL image before it is fed to
# the model.  The 32x32 resize matches the input size Net.fc1 is built for
# (128 * 8 * 8 features after two 2x2 poolings); Normalize subtracts 0.5 and
# divides by 0.5 on each of the three channels.  Presumably this mirrors the
# training pipeline that produced modelo_cifar10.pth -- confirm against the
# training script.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
|
|
|
|
class ImageClassifierApp:
    """Tk window that classifies an image picked via file dialog or drag-and-drop.

    The chosen image is shown as a preview, passed through the module-level
    ``transform`` and ``model``, and the predicted label from ``classes`` is
    displayed underneath.

    Args:
        root: The Tk root window.  It must support ``drop_target_register``
            and ``dnd_bind`` (i.e. a ``TkinterDnD.Tk`` instance), since both
            are called on it here.
    """

    # (width, height) in pixels of the preview image.
    PREVIEW_SIZE = (250, 250)

    def __init__(self, root):
        self.root = root
        self.root.title("🎯 Clasificador de Imágenes - Arrastra una imagen")
        self.root.geometry("500x600")
        self.root.configure(bg="#f0f0f0")

        self.label = Label(
            root,
            text="👇 Arrastra una imagen aquí",
            bg="#f0f0f0",
            fg="#333",
            font=("Arial", 14),
        )
        self.label.pack(pady=20)

        self.button = Button(
            root,
            text="🔍 Seleccionar imagen",
            command=self.open_file,
            bg="#4CAF50",
            fg="white",
            font=("Arial", 12),
            padx=20,
        )
        self.button.pack(pady=10)

        # Text placeholder until an image is loaded.  While a Label shows only
        # text, width/height are in characters; once it shows an image they
        # are in pixels, so process_image resets them.
        self.image_label = Label(root, bg="#ddd", text="No hay imagen", width=30, height=15)
        self.image_label.pack(pady=10)

        self.result_label = Label(
            root,
            text="Predicción: ---",
            bg="#f0f0f0",
            font=("Arial", 16, "bold"),
            fg="#005bb5",
        )
        self.result_label.pack(pady=20)

        # Accept files dropped anywhere on the window.
        self.root.drop_target_register("DND_Files")
        self.root.dnd_bind("<<Drop>>", self.on_drop)

    def open_file(self):
        """Let the user pick an image file and classify it; do nothing on cancel."""
        file_path = filedialog.askopenfilename(
            filetypes=[("Imágenes", "*.jpg *.jpeg *.png *.bmp *.webp")]
        )
        # An empty string means the dialog was cancelled.
        if file_path:
            self.process_image(file_path)

    def on_drop(self, event):
        """Classify the first file of a drag-and-drop event.

        Empty drops and paths that are not regular files are reported as
        invalid; any other error is shown in the result label.
        """
        try:
            # event.data is a Tcl list string; splitlist parses it into paths.
            file_paths = self.root.tk.splitlist(event.data)
            if not file_paths:
                self.result_label.config(text="❌ Archivo no válido")
                return
            file_path = str(file_paths[0]).strip("{}")  # Strip leftover braces, if any
            if Path(file_path).is_file():
                self.process_image(file_path)
            else:
                self.result_label.config(text="❌ Archivo no válido")
        except Exception as e:
            # UI callback boundary: surface the error instead of crashing Tk.
            self.result_label.config(text=f"❌ Error: {e}")

    def process_image(self, file_path):
        """Show a preview of ``file_path`` and display the model's predicted class.

        Errors (unreadable file, non-image, model failure) are reported in the
        result label rather than raised.
        """
        try:
            # Close the file handle deterministically; convert() returns an
            # independent in-memory image that outlives the `with` block.
            with Image.open(file_path) as source:
                image = source.convert("RGB")

            img_display = image.resize(self.PREVIEW_SIZE, Image.Resampling.LANCZOS)
            photo = ImageTk.PhotoImage(img_display)
            # With an image, Label width/height are pixels, not characters:
            # keeping the placeholder's width=30/height=15 would crop the
            # preview to a 30x15-pixel sliver.
            self.image_label.configure(
                image=photo,
                text="",
                width=img_display.width,
                height=img_display.height,
            )
            # Keep a Python reference so the PhotoImage is not garbage-collected
            # (which would blank the preview).
            self.image_label.image = photo

            input_tensor = transform(image).unsqueeze(0)
            with torch.no_grad():
                output = model(input_tensor)
                _, predicted = torch.max(output, 1)
                class_name = classes[predicted.item()]

            self.result_label.config(text=f"🎯 {class_name}")
        except Exception as e:
            # UI callback boundary: surface the error instead of crashing Tk.
            self.result_label.config(text=f"❌ Error: {str(e)}")
|
|
|
|
if __name__ == "__main__":
    # The window comes from TkinterDnD rather than plain tkinter, because
    # ImageClassifierApp calls drop_target_register()/dnd_bind() on it.
    window = TkinterDnD.Tk()
    app = ImageClassifierApp(window)
    window.mainloop()
|