añadido codigo para poder usar el modelo

This commit is contained in:
2025-07-21 16:57:43 -03:00
parent ef30f09426
commit 7311cdf6c6

131
drag_and_drop.py Normal file
View File

@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
from PIL import Image, ImageTk
import torchvision.transforms as transforms
from pathlib import Path
from tkinterdnd2 import TkinterDnD
import tkinter as tk
from tkinter import Label, Button, filedialog
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(128 * 8 * 8, 512)
self.fc2 = nn.Linear(512, 10)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.pool(self.relu(self.conv1(x)))
x = self.pool(self.relu(self.conv2(x)))
x = x.view(-1, 128 * 8 * 8)
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
model = Net()
try:
model.load_state_dict(torch.load('modelo_cifar10.pth', map_location='cpu'))
model.eval()
print("✅ Modelo cargado.")
except FileNotFoundError:
print("❌ No se encontró 'modelo_cifar10.pth'. Entrena primero.")
exit()
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
class ImageClassifierApp:
def __init__(self, root):
self.root = root
self.root.title("🎯 Clasificador de Imágenes - Arrastra una imagen")
self.root.geometry("500x600")
self.root.configure(bg="#f0f0f0")
self.label = Label(
root,
text="👇 Arrastra una imagen aquí",
bg="#f0f0f0",
fg="#333",
font=("Arial", 14),
)
self.label.pack(pady=20)
self.button = Button(
root,
text="🔍 Seleccionar imagen",
command=self.open_file,
bg="#4CAF50",
fg="white",
font=("Arial", 12),
padx=20
)
self.button.pack(pady=10)
self.image_label = Label(root, bg="#ddd", text="No hay imagen", width=30, height=15)
self.image_label.pack(pady=10)
self.result_label = Label(
root,
text="Predicción: ---",
bg="#f0f0f0",
font=("Arial", 16, "bold"),
fg="#005bb5"
)
self.result_label.pack(pady=20)
self.root.drop_target_register("DND_Files")
self.root.dnd_bind("<<Drop>>", self.on_drop)
def open_file(self):
file_path = filedialog.askopenfilename(
filetypes=[("Imágenes", "*.jpg *.jpeg *.png *.bmp *.webp")]
)
if file_path:
self.process_image(file_path)
def on_drop(self, event):
try:
file_paths = self.root.tk.splitlist(event.data)
file_path = str(file_paths[0]).strip("{}") # Limpia llaves si vienen
if Path(file_path).exists():
self.process_image(file_path)
else:
self.result_label.config(text="❌ Archivo no válido")
except Exception as e:
self.result_label.config(text=f"❌ Error: {e}")
def process_image(self, file_path):
try:
image = Image.open(file_path).convert("RGB")
img_display = image.resize((250, 250), Image.Resampling.LANCZOS)
photo = ImageTk.PhotoImage(img_display)
self.image_label.configure(image=photo, text="")
self.image_label.image = photo
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
_, predicted = torch.max(output, 1)
class_name = classes[predicted.item()]
self.result_label.config(text=f"🎯 {class_name}")
except Exception as e:
self.result_label.config(text=f"❌ Error: {str(e)}")
if __name__ == "__main__":
root = TkinterDnD.Tk()
app = ImageClassifierApp(root)
root.mainloop()