# Importações necessárias
import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tkinter as tk
from tkinter import ttk, messagebox
from PIL import Image, ImageTk
import threading
import time

# Classes do dataset Garbage Classification v2
CLASSES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
IMG_SIZE = 224  # Tamanho padrão para MobileNetV2


class ReciclagemApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Ctrl+Eco - Classificador de Lixo")
        self.root.geometry("800x600")

        # Variáveis
        self.modelo = None
        self.webcam_ativa = False
        self.cap = None
        self.resultado = tk.StringVar(value="Aguardando classificação...")

        # Interface principal
        self.criar_interface()

        # Tentar carregar modelo pré-treinado
        self.carregar_modelo()

    def criar_interface(self):
        # Frame principal
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.pack(fill=tk.BOTH, expand=True)

        # Área de título
        titulo = ttk.Label(main_frame, text="Ctrl+Eco - Classificador de Materiais Recicláveis",
                           font=("Arial", 16, "bold"))
        titulo.pack(pady=10)

        # Frame para webcam
        webcam_frame = ttk.LabelFrame(main_frame, text="Visualização da Webcam")
        webcam_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.area_video = ttk.Label(webcam_frame)
        self.area_video.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Frame para resultado
        resultado_frame = ttk.LabelFrame(main_frame, text="Resultado da Classificação")
        resultado_frame.pack(fill=tk.X, padx=10, pady=10)

        resultado_label = ttk.Label(resultado_frame, textvariable=self.resultado,
                                    font=("Arial", 20))
        resultado_label.pack(pady=20)

        # Frame para botões
        botoes_frame = ttk.Frame(main_frame)
        botoes_frame.pack(fill=tk.X, padx=10, pady=10)

        # Botões
        self.btn_iniciar = ttk.Button(botoes_frame, text="Iniciar Webcam",
                                      command=self.iniciar_webcam)
        self.btn_iniciar.pack(side=tk.LEFT, padx=5)

        self.btn_parar = ttk.Button(botoes_frame, text="Parar Webcam",
                                    command=self.parar_webcam)
        self.btn_parar.pack(side=tk.LEFT, padx=5)

        self.btn_treinar = ttk.Button(botoes_frame, text="Treinar Modelo",
                                      command=self.treinar_modelo)
        self.btn_treinar.pack(side=tk.RIGHT, padx=5)

    def carregar_modelo(self):
        """Tenta carregar um modelo pré-treinado"""
        try:
            self.modelo = load_model('modelo_reciclagem.h5')
            messagebox.showinfo("Modelo Carregado", "Modelo pré-treinado carregado com sucesso!")
        except:
            messagebox.showinfo("Modelo não encontrado",
                                "Nenhum modelo pré-treinado encontrado. Use o botão 'Treinar Modelo' para criar um.")

    def treinar_modelo(self):
        """Treina o modelo usando o dataset Garbage Classification v2"""
        # Verificar se o dataset está disponível
        if not os.path.exists('dataset'):
            messagebox.showerror("Erro",
                                 "Dataset não encontrado! Por favor, baixe o dataset 'Garbage Classification v2' do Kaggle e extraia na pasta 'dataset'.")
            return

        # Iniciar treinamento em thread separada
        threading.Thread(target=self._treinar).start()

    def _treinar(self):
        try:
            # Desabilitar botões durante o treinamento
            self.btn_treinar.config(state='disabled')
            self.resultado.set("Treinando modelo... Por favor, aguarde.")

            # Preparar geradores de dados
            datagen = ImageDataGenerator(
                rescale=1. / 255,
                validation_split=0.2,  # 20% para validação
                rotation_range=20,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                horizontal_flip=True
            )

            # Carregar dados de treinamento
            train_generator = datagen.flow_from_directory(
                'dataset',
                target_size=(IMG_SIZE, IMG_SIZE),
                batch_size=32,
                class_mode='categorical',
                subset='training'
            )

            # Carregar dados de validação
            validation_generator = datagen.flow_from_directory(
                'dataset',
                target_size=(IMG_SIZE, IMG_SIZE),
                batch_size=32,
                class_mode='categorical',
                subset='validation'
            )

            # Criar modelo usando MobileNetV2 (transfer learning)
            base_model = MobileNetV2(weights='imagenet', include_top=False,
                                     input_shape=(IMG_SIZE, IMG_SIZE, 3))
            base_model.trainable = False  # Congelar camadas pré-treinadas

            modelo = Sequential([
                base_model,
                GlobalAveragePooling2D(),
                Dense(128, activation='relu'),
                Dropout(0.2),
                Dense(len(CLASSES), activation='softmax')
            ])

            # Compilar modelo
            modelo.compile(
                optimizer='adam',
                loss='categorical_crossentropy',
                metrics=['accuracy']
            )

            # Treinar modelo (poucas épocas para ser mais rápido)
            modelo.fit(
                train_generator,
                steps_per_epoch=train_generator.samples // 32,
                validation_data=validation_generator,
                validation_steps=validation_generator.samples // 32,
                epochs=5  # Reduzido para ser mais rápido
            )

            # Salvar modelo
            modelo.save('modelo_reciclagem.h5')
            self.modelo = modelo

            # Atualizar interface
            self.resultado.set("Treinamento concluído! Pronto para classificar.")
            messagebox.showinfo("Sucesso", "Modelo treinado com sucesso!")

        except Exception as e:
            messagebox.showerror("Erro", f"Erro durante o treinamento: {str(e)}")
            self.resultado.set("Erro no treinamento. Verifique o console.")
            print(f"Erro detalhado: {str(e)}")

        finally:
            # Reabilitar botões
            self.btn_treinar.config(state='normal')

    def iniciar_webcam(self):
        """Inicia a captura da webcam e classificação"""
        # Verificar se o modelo está carregado
        if self.modelo is None:
            messagebox.showerror("Erro", "Nenhum modelo carregado! Por favor, treine o modelo primeiro.")
            return

        # Iniciar webcam
        self.webcam_ativa = True
        self.cap = cv2.VideoCapture(0)

        if not self.cap.isOpened():
            messagebox.showerror("Erro", "Não foi possível acessar a webcam!")
            self.webcam_ativa = False
            return

        # Iniciar thread para captura de frames
        threading.Thread(target=self.processar_webcam).start()

        # Atualizar interface
        self.btn_iniciar.config(state='disabled')
        self.btn_parar.config(state='normal')

    def processar_webcam(self):
        """Processa frames da webcam e faz classificação"""
        while self.webcam_ativa:
            ret, frame = self.cap.read()
            if not ret:
                break

            # Processar frame para classificação
            img_processada = cv2.resize(frame, (IMG_SIZE, IMG_SIZE))
            img_processada = img_processada / 255.0
            img_processada = np.expand_dims(img_processada, axis=0)

            # Fazer previsão
            predicao = self.modelo.predict(img_processada)
            classe_idx = np.argmax(predicao[0])
            confianca = predicao[0][classe_idx] * 100

            # Obter nome da classe
            classe = CLASSES[classe_idx]

            # Adicionar texto ao frame
            texto = f"{classe}: {confianca:.1f}%"
            cv2.putText(frame, texto, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (0, 255, 0), 2)

            # Converter para formato Tkinter
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(frame_rgb)
            img = ImageTk.PhotoImage(image=img)

            # Atualizar interface (thread-safe)
            self.root.after(0, self._atualizar_interface, img, texto)

            # Pequena pausa
            time.sleep(0.03)

    def _atualizar_interface(self, img, texto):
        """Atualiza a interface com a nova imagem e resultado"""
        self.area_video.configure(image=img)
        self.area_video.image = img  # Manter referência
        self.resultado.set(texto)

    def parar_webcam(self):
        """Para a captura da webcam"""
        self.webcam_ativa = False
        if self.cap is not None:
            self.cap.release()
            self.cap = None

        # Atualizar interface
        self.btn_iniciar.config(state='normal')
        self.btn_parar.config(state='disabled')
        self.resultado.set("Webcam desativada")
        self.area_video.configure(image='')


# Função para preparar o dataset do Kaggle
def preparar_dataset_kaggle(caminho_zip, destino='dataset'):
    """
    Extrai e organiza o dataset Garbage Classification v2 do Kaggle

    Args:
        caminho_zip: Caminho para o arquivo ZIP baixado do Kaggle
        destino: Pasta onde o dataset será extraído
    """
    import zipfile
    import shutil

    # Criar pasta de destino
    os.makedirs(destino, exist_ok=True)

    # Extrair ZIP
    with zipfile.ZipFile(caminho_zip, 'r') as zip_ref:
        zip_ref.extractall('temp_dataset')

    # O dataset já está organizado por pastas, então só precisamos mover
    # da pasta temporária para o destino final
    for classe in CLASSES:
        os.makedirs(os.path.join(destino, classe), exist_ok=True)

        # Origem pode variar dependendo da estrutura do ZIP
        origem = os.path.join('temp_dataset', 'Garbage classification', classe)
        if not os.path.exists(origem):
            origem = os.path.join('temp_dataset', classe)

        # Copiar arquivos
        for arquivo in os.listdir(origem):
            if arquivo.endswith(('.jpg', '.jpeg', '.png')):
                shutil.copy(
                    os.path.join(origem, arquivo),
                    os.path.join(destino, classe, arquivo)
                )

    # Limpar pasta temporária
    shutil.rmtree('temp_dataset')
    print(f"Dataset preparado em {destino}")


# Função principal
def main():
    # Criar janela principal
    root = tk.Tk()
    app = ReciclagemApp(root)
    root.mainloop()


if __name__ == "__main__":
    main()