añadido codigo para poder usar el modelo
This commit is contained in:
131
drag_and_drop.py
Normal file
131
drag_and_drop.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image, ImageTk
|
||||
import torchvision.transforms as transforms
|
||||
from pathlib import Path
|
||||
|
||||
from tkinterdnd2 import TkinterDnD
|
||||
import tkinter as tk
|
||||
from tkinter import Label, Button, filedialog
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.fc1 = nn.Linear(128 * 8 * 8, 512)
|
||||
self.fc2 = nn.Linear(512, 10)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(self.relu(self.conv1(x)))
|
||||
x = self.pool(self.relu(self.conv2(x)))
|
||||
x = x.view(-1, 128 * 8 * 8)
|
||||
x = self.relu(self.fc1(x))
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
model = Net()
|
||||
try:
|
||||
model.load_state_dict(torch.load('modelo_cifar10.pth', map_location='cpu'))
|
||||
model.eval()
|
||||
print("✅ Modelo cargado.")
|
||||
except FileNotFoundError:
|
||||
print("❌ No se encontró 'modelo_cifar10.pth'. Entrena primero.")
|
||||
exit()
|
||||
|
||||
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((32, 32)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
class ImageClassifierApp:
|
||||
def __init__(self, root):
|
||||
self.root = root
|
||||
self.root.title("🎯 Clasificador de Imágenes - Arrastra una imagen")
|
||||
self.root.geometry("500x600")
|
||||
self.root.configure(bg="#f0f0f0")
|
||||
|
||||
self.label = Label(
|
||||
root,
|
||||
text="👇 Arrastra una imagen aquí",
|
||||
bg="#f0f0f0",
|
||||
fg="#333",
|
||||
font=("Arial", 14),
|
||||
)
|
||||
self.label.pack(pady=20)
|
||||
|
||||
self.button = Button(
|
||||
root,
|
||||
text="🔍 Seleccionar imagen",
|
||||
command=self.open_file,
|
||||
bg="#4CAF50",
|
||||
fg="white",
|
||||
font=("Arial", 12),
|
||||
padx=20
|
||||
)
|
||||
self.button.pack(pady=10)
|
||||
|
||||
self.image_label = Label(root, bg="#ddd", text="No hay imagen", width=30, height=15)
|
||||
self.image_label.pack(pady=10)
|
||||
|
||||
self.result_label = Label(
|
||||
root,
|
||||
text="Predicción: ---",
|
||||
bg="#f0f0f0",
|
||||
font=("Arial", 16, "bold"),
|
||||
fg="#005bb5"
|
||||
)
|
||||
self.result_label.pack(pady=20)
|
||||
|
||||
self.root.drop_target_register("DND_Files")
|
||||
self.root.dnd_bind("<<Drop>>", self.on_drop)
|
||||
|
||||
def open_file(self):
|
||||
file_path = filedialog.askopenfilename(
|
||||
filetypes=[("Imágenes", "*.jpg *.jpeg *.png *.bmp *.webp")]
|
||||
)
|
||||
if file_path:
|
||||
self.process_image(file_path)
|
||||
|
||||
def on_drop(self, event):
|
||||
try:
|
||||
file_paths = self.root.tk.splitlist(event.data)
|
||||
file_path = str(file_paths[0]).strip("{}") # Limpia llaves si vienen
|
||||
if Path(file_path).exists():
|
||||
self.process_image(file_path)
|
||||
else:
|
||||
self.result_label.config(text="❌ Archivo no válido")
|
||||
except Exception as e:
|
||||
self.result_label.config(text=f"❌ Error: {e}")
|
||||
|
||||
def process_image(self, file_path):
|
||||
try:
|
||||
image = Image.open(file_path).convert("RGB")
|
||||
|
||||
img_display = image.resize((250, 250), Image.Resampling.LANCZOS)
|
||||
photo = ImageTk.PhotoImage(img_display)
|
||||
self.image_label.configure(image=photo, text="")
|
||||
self.image_label.image = photo
|
||||
|
||||
input_tensor = transform(image).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
output = model(input_tensor)
|
||||
_, predicted = torch.max(output, 1)
|
||||
class_name = classes[predicted.item()]
|
||||
|
||||
self.result_label.config(text=f"🎯 {class_name}")
|
||||
|
||||
except Exception as e:
|
||||
self.result_label.config(text=f"❌ Error: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = TkinterDnD.Tk()
|
||||
app = ImageClassifierApp(root)
|
||||
root.mainloop()
|
||||
Reference in New Issue
Block a user