import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import os
import threading
import queue

# Fonts tried in order for plot text. SimHei (commonly available on Windows) stays
# first so Chinese titles/labels render as before where it exists; the following
# entries are common CJK fonts on macOS/Linux, and DejaVu Sans (Matplotlib's
# default) is the last resort so Latin text always has a usable font.
matplotlib.rcParams['font.sans-serif'] = [
    'SimHei',
    'Microsoft YaHei',
    'PingFang SC',
    'Noto Sans CJK SC',
    'WenQuanYi Micro Hei',
    'DejaVu Sans',
]
# Draw negative numbers with an ASCII hyphen-minus so they still display with
# CJK fonts that may lack the Unicode minus glyph Matplotlib uses by default.
matplotlib.rcParams['axes.unicode_minus'] = False
# ================== 1. 手写 MinMaxScaler (替代 sklearn) ==================
class MinMaxScaler:
    """Minimal column-wise min-max scaler (stand-in for sklearn's MinMaxScaler).

    ``transform`` maps each column linearly so that the minimum seen by ``fit``
    becomes 0 and the maximum becomes 1. Constant columns get a scale of 1, so
    they map to 0 instead of causing a division by zero.
    """

    def __init__(self):
        self.min_ = None    # per-column minimum learned by fit()
        self.scale_ = None  # per-column (max - min), with zero ranges replaced by 1

    def fit(self, X):
        """Learn the per-column minimum and range of ``X`` and return ``self``.

        ``X`` may be 1-D (a single feature) or 2-D (rows x features).
        """
        data = np.asarray(X, dtype=float)
        self.min_ = data.min(axis=0)
        value_range = data.max(axis=0) - self.min_
        # np.where (rather than boolean item assignment) also works for 1-D input,
        # where min/max reduce to NumPy scalars that do not support assignment.
        self.scale_ = np.where(value_range == 0, 1.0, value_range)
        return self

    def transform(self, X):
        """Scale ``X`` into [0, 1] using the statistics learned by ``fit``."""
        return (X - self.min_) / self.scale_

    def inverse_transform(self, X_scaled):
        """Map values produced by ``transform`` back to the original units."""
        return X_scaled * self.scale_ + self.min_


# ================== 2. 手写 RNN 核心算法 (纯 Numpy) ==================
class SimpleRNN:
    """Single-layer RNN (tanh hidden state, linear read-out) implemented in NumPy.

    Training is plain per-sequence SGD with element-wise gradient clipping.
    ``backward`` reuses the activations cached by the most recent ``forward``
    call, so the two must be called on the same sequence, in that order.
    """

    def __init__(self, input_size, hidden_size=16, output_size=1, learning_rate=0.01, seed=None):
        """Create the network with scaled Gaussian weights and zero biases.

        Args:
            input_size: number of features per time step.
            hidden_size: width of the hidden state.
            output_size: number of outputs per time step.
            learning_rate: SGD step size used by ``backward``.
            seed: optional seed for a private ``RandomState``; ``None`` draws from
                the global ``np.random`` state, as before.
        """
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.lr = learning_rate

        rng = np.random if seed is None else np.random.RandomState(seed)
        # Weights scaled by sqrt(1 / fan_in).
        self.W_xh = rng.randn(hidden_size, input_size) * np.sqrt(1 / input_size)
        self.W_hh = rng.randn(hidden_size, hidden_size) * np.sqrt(1 / hidden_size)
        self.W_hy = rng.randn(output_size, hidden_size) * np.sqrt(1 / hidden_size)

        self.b_h = np.zeros((hidden_size, 1))
        self.b_y = np.zeros((output_size, 1))

    def tanh(self, x):
        """Hidden-state activation."""
        return np.tanh(x)

    def tanh_derivative(self, x):
        """Derivative of tanh with respect to its PRE-activation input ``x``."""
        return 1 - np.tanh(x) ** 2

    def forward(self, x_sequence):
        """Run the network over ``x_sequence`` (one row of input_size features per step).

        The initial hidden state is zero. Caches and returns
        ``(y_pred, h_history)`` with shapes (T, output_size) and (T, hidden_size).
        """
        T = x_sequence.shape[0]
        h = np.zeros((self.hidden_size, 1))

        self.h_history = np.zeros((T, self.hidden_size))
        self.y_pred = np.zeros((T, self.output_size))

        for t in range(T):
            x_t = x_sequence[t].reshape(-1, 1)
            h_raw = np.dot(self.W_xh, x_t) + np.dot(self.W_hh, h) + self.b_h
            h = self.tanh(h_raw)
            y_t = np.dot(self.W_hy, h) + self.b_y

            self.h_history[t] = h.flatten()
            self.y_pred[t] = y_t.flatten()

        return self.y_pred, self.h_history

    def backward(self, x_sequence, y_true, max_grad=5):
        """Back-propagate through time and apply one clipped SGD step.

        Must be called right after ``forward(x_sequence)``. The loss is
        ``0.5 * sum((y_pred - y_true) ** 2)`` over the supervised steps.

        Args:
            x_sequence: the same sequence that was passed to ``forward``.
            y_true: either one target per step (``T * output_size`` values, e.g.
                shape (T, output_size) or (T,)), or a single target
                (``output_size`` values) for many-to-one training, in which case
                only the last step's output is supervised.
            max_grad: element-wise clipping bound applied to every gradient.

        Raises:
            ValueError: if ``y_true`` matches neither layout.
        """
        T = x_sequence.shape[0]
        targets = np.asarray(y_true, dtype=float).reshape(-1)
        if targets.size == self.output_size:
            # Many-to-one: only the final output is compared with the target,
            # consistent with a loss computed on y_pred[-1].
            dL_dy = np.zeros_like(self.y_pred)
            dL_dy[-1] = self.y_pred[-1] - targets
        else:
            dL_dy = self.y_pred - targets.reshape(T, self.output_size)

        dW_xh = np.zeros_like(self.W_xh)
        dW_hh = np.zeros_like(self.W_hh)
        dW_hy = np.zeros_like(self.W_hy)
        db_h = np.zeros_like(self.b_h)
        db_y = np.zeros_like(self.b_y)

        dh_next = np.zeros((self.hidden_size, 1))

        for t in reversed(range(T)):
            x_t = x_sequence[t].reshape(-1, 1)
            h_t = self.h_history[t].reshape(-1, 1)
            # forward starts from a zero hidden state, so the t == 0 dW_hh term is zero.
            h_prev = self.h_history[t - 1].reshape(-1, 1) if t > 0 else np.zeros_like(h_t)
            dy = dL_dy[t].reshape(-1, 1)

            dW_hy += np.dot(dy, h_t.T)
            db_y += dy
            dh = np.dot(self.W_hy.T, dy) + dh_next
            # h_t is already tanh(h_raw), so d tanh / d h_raw = 1 - h_t ** 2.
            # (tanh_derivative expects the pre-activation, not h_t.)
            dh_raw = dh * (1 - h_t ** 2)
            db_h += dh_raw
            dW_xh += np.dot(dh_raw, x_t.T)
            dW_hh += np.dot(dh_raw, h_prev.T)
            dh_next = np.dot(self.W_hh.T, dh_raw)

        for grad in (dW_xh, dW_hh, dW_hy, db_h, db_y):
            np.clip(grad, -max_grad, max_grad, out=grad)

        self.W_xh -= self.lr * dW_xh
        self.W_hh -= self.lr * dW_hh
        self.W_hy -= self.lr * dW_hy
        self.b_h -= self.lr * db_h
        self.b_y -= self.lr * db_y


# ================== 3. GUI 应用逻辑 (多线程优化版) ==================
class RNNEnergyApp:
    """Tk GUI that loads building-energy data, trains a SimpleRNN and exports predictions.

    Long-running work (reading the workbook, training) runs on daemon worker
    threads. Tk widgets must only be touched from the main thread, so workers
    put ``{'action': ...}`` dicts on ``self.msg_queue`` and ``process_queue``
    (re-scheduled every 100 ms via ``root.after``) applies them.
    """

    # Rows per RNN input window, including the row whose target is predicted.
    N_STEPS = 7
    # Regression target column in the workbook.
    TARGET_COL = 'daily_energy_kwh'

    def __init__(self, root):
        self.root = root
        self.root.title("【RNN】智能楼宇能耗预测系统")
        self.root.geometry("1000x800")

        # Data and model state
        self.train_df = None
        self.test_df = None
        self.infer_df = None
        self.full_df = None
        self.rnn_model = None
        self.scaler_X = MinMaxScaler()
        self.scaler_y = MinMaxScaler()
        self.result_df = None

        # Thread-safe queue carrying messages from worker threads to the main thread
        self.msg_queue = queue.Queue()

        self.setup_ui()
        # Start polling the queue every 100 ms
        self.root.after(100, self.process_queue)

    def setup_ui(self):
        """Build the control buttons, the log pane and the embedded loss plot."""
        ctrl_frame = ttk.LabelFrame(self.root, text="RNN 控制台")
        ctrl_frame.pack(fill="x", padx=10, pady=5)

        # Keep button references so they can be disabled/enabled around long tasks
        self.btn_load = ttk.Button(ctrl_frame, text="1. 加载数据", command=self.start_load_data_thread)
        self.btn_load.pack(side="left", padx=10, pady=5)

        self.btn_train = ttk.Button(ctrl_frame, text="2. 训练 RNN 模型", command=self.start_train_thread, state=tk.DISABLED)
        self.btn_train.pack(side="left", padx=10, pady=5)

        self.btn_export = ttk.Button(ctrl_frame, text="3. 导出预测结果", command=self.export_results, state=tk.DISABLED)
        self.btn_export.pack(side="left", padx=10, pady=5)

        log_frame = ttk.LabelFrame(self.root, text="训练日志")
        log_frame.pack(fill="both", expand=True, padx=10, pady=5)
        self.log_text = tk.Text(log_frame, height=15, state="disabled")
        self.log_text.pack(fill="both", expand=True, padx=5, pady=5)

        plot_frame = ttk.LabelFrame(self.root, text="RNN 训练损失曲线")
        plot_frame.pack(fill="both", expand=True, padx=10, pady=5)
        self.fig, self.ax = plt.subplots(figsize=(6, 4))
        self.canvas = FigureCanvasTkAgg(self.fig, master=plot_frame)
        self.canvas.get_tk_widget().pack(fill="both", expand=True)

    def log(self, msg):
        """Append ``msg`` to the read-only log pane (main thread only)."""
        self.log_text.config(state="normal")
        self.log_text.insert("end", msg + "\n")
        self.log_text.see("end")
        self.log_text.config(state="disabled")

    def process_queue(self):
        """Apply all pending worker messages on the main thread, then re-arm the timer.

        Supported actions: ``log``, ``enable_button`` / ``disable_button`` (``btn``
        is 'load', 'train' or 'export'), ``plot_loss`` and ``finish_inference``.
        """
        buttons = {'load': self.btn_load, 'train': self.btn_train, 'export': self.btn_export}
        try:
            while True:
                try:
                    task = self.msg_queue.get_nowait()
                except queue.Empty:
                    break
                action = task['action']
                if action == 'log':
                    self.log(task['msg'])
                elif action in ('enable_button', 'disable_button'):
                    button = buttons.get(task['btn'])
                    if button is not None:
                        button.config(state=tk.NORMAL if action == 'enable_button' else tk.DISABLED)
                elif action == 'plot_loss':
                    self.update_plot(task['losses'])
                elif action == 'finish_inference':
                    self.result_df = task['result_df']
                    self.btn_export.config(state=tk.NORMAL)
        finally:
            # Re-arm even if handling a message raised; otherwise polling would stop
            # for good and the buttons would never be re-enabled.
            self.root.after(100, self.process_queue)

    def update_plot(self, losses):
        """Redraw the loss curve from the per-epoch ``losses`` (main thread only)."""
        self.ax.clear()
        self.ax.plot(losses, label='RNN 训练 Loss')
        self.ax.set_title("RNN 损失曲线")
        self.ax.set_xlabel("Epoch")
        self.ax.set_ylabel("MSE Loss")
        self.ax.legend()
        self.canvas.draw()

    # --- Button handlers (main thread) that start worker threads ---
    def start_load_data_thread(self):
        """Ask for the workbook on the main thread, then read it on a worker thread."""
        # Tk dialogs must run on the thread that owns the Tk interpreter, so the
        # file picker is shown here rather than inside the worker.
        file_path = filedialog.askopenfilename(
            title="请选择能耗数据文件",
            filetypes=[
                ("Excel files", "*.xlsx *.xls"),
                ("All files", "*.*"),
            ],
        )
        if not file_path:
            self.log("❌ 未选择任何文件，操作已取消。")
            return
        # Already on the main thread: disable directly so a double click cannot
        # start two loads, and training cannot read half-replaced data.
        self.btn_load.config(state=tk.DISABLED)
        self.btn_train.config(state=tk.DISABLED)
        threading.Thread(target=self.load_data_worker, args=(file_path,), daemon=True).start()

    def start_train_thread(self):
        """Start training on a worker thread; load and train stay disabled until it ends."""
        # Disabling "load" too prevents a reload from swapping the data mid-training.
        self.btn_train.config(state=tk.DISABLED)
        self.btn_load.config(state=tk.DISABLED)
        threading.Thread(target=self.train_rnn_worker, daemon=True).start()

    # --- Worker functions (background threads; UI updates only via msg_queue) ---
    def load_data_worker(self, file_path=None):
        """Read the train/test/inference sheets of ``file_path`` into ``self``.

        Both buttons are re-enabled at the end; "train" only if some data set is loaded.
        """
        try:
            if not file_path:
                self.msg_queue.put({'action': 'log', 'msg': "❌ 未选择任何文件，操作已取消。"})
                return

            self.msg_queue.put({'action': 'log', 'msg': f"正在后台读取文件: {os.path.basename(file_path)}..."})

            # A list of sheet names parses the workbook once and returns {name: DataFrame}.
            sheets = pd.read_excel(file_path, sheet_name=['train_data', 'test_data', 'inference_cases'])
            train_df = sheets['train_data']
            test_df = sheets['test_data']
            infer_df = sheets['inference_cases']
            full_df = pd.concat([train_df, test_df], ignore_index=True)

            # Publish all frames together, only after every sheet was read successfully.
            self.train_df, self.test_df, self.infer_df, self.full_df = train_df, test_df, infer_df, full_df

            self.msg_queue.put({'action': 'log', 'msg': f"✅ 数据加载完成！总样本数: {len(full_df)}"})
        except Exception as e:
            self.msg_queue.put({'action': 'log', 'msg': f"❌ 读取失败: {e}"})
        finally:
            self.msg_queue.put({'action': 'enable_button', 'btn': 'load'})
            # "train" was disabled when loading started; it is usable whenever some
            # data set (this one or an earlier one) is available.
            if self.full_df is not None:
                self.msg_queue.put({'action': 'enable_button', 'btn': 'train'})

    def train_rnn_worker(self):
        """Build sequences, train a fresh SimpleRNN and predict the inference cases.

        The loaded DataFrames are only read, never modified, so training can be
        repeated. Results reach the UI through a ``finish_inference`` message.
        """
        try:
            full_df, infer_df = self.full_df, self.infer_df
            if full_df is None or infer_df is None:
                self.msg_queue.put({'action': 'log', 'msg': "❌ 请先加载数据！"})
                return

            self.msg_queue.put({'action': 'log', 'msg': "正在准备时序数据..."})
            target_col = self.TARGET_COL
            n_steps = self.N_STEPS

            # Capture the inference ids before feature engineering drops the column.
            if 'sample_id' in infer_df.columns:
                infer_ids = infer_df['sample_id'].to_numpy()
            else:
                infer_ids = np.arange(len(infer_df))

            X_full_raw, X_infer_raw, _ = self._build_feature_matrices(full_df, infer_df, target_col)
            y_full_raw = full_df[target_col].to_numpy(dtype=float).reshape(-1, 1)

            self.scaler_X.fit(X_full_raw)
            self.scaler_y.fit(y_full_raw)
            X_full_scaled = self.scaler_X.transform(X_full_raw)
            y_full_scaled = self.scaler_y.transform(y_full_raw).ravel()
            X_infer_scaled = self.scaler_X.transform(X_infer_raw)

            # Each window ends on the row whose target it predicts, which is the
            # layout _predict_inference uses for the inference rows.
            X_seq, y_seq = self.create_sequences(X_full_scaled, y_full_scaled, n_steps, include_current=True)
            if len(X_seq) == 0:
                raise ValueError(f"样本数 ({len(full_df)}) 少于窗口长度 {n_steps}，无法构建序列")

            self.msg_queue.put({'action': 'log', 'msg': f"构建序列完成。序列形状: {X_seq.shape}"})

            self.rnn_model = SimpleRNN(input_size=X_seq.shape[2], hidden_size=20, learning_rate=0.01)
            self.msg_queue.put({'action': 'log', 'msg': "开始 RNN 训练 (后台运行中，界面不会卡死)..."})
            self._train_model(self.rnn_model, X_seq, y_seq, epochs=100)

            self.msg_queue.put({'action': 'log', 'msg': "✅ 训练完成！正在进行最终推理预测..."})
            predictions = self._predict_inference(self.rnn_model, X_full_scaled, X_infer_scaled, n_steps)

            result_df = pd.DataFrame({
                'case_id': infer_ids,
                'predicted_daily_energy_kwh': np.round(predictions, 2),
            })
            self.msg_queue.put({'action': 'log', 'msg': "🎉 预测完成！请点击导出按钮保存结果。"})
            self.msg_queue.put({'action': 'finish_inference', 'result_df': result_df})

        except Exception as e:
            self.msg_queue.put({'action': 'log', 'msg': f"❌ 发生异常: {str(e)}"})
        finally:
            self.msg_queue.put({'action': 'enable_button', 'btn': 'train'})
            self.msg_queue.put({'action': 'enable_button', 'btn': 'load'})

    @staticmethod
    def _build_feature_matrices(full_df, infer_df, target_col=TARGET_COL):
        """One-hot encode ``season_code`` and return aligned float feature matrices.

        Returns ``(X_full, X_infer, feature_names)``. Both matrices use exactly the
        column order of ``feature_names`` (the training layout): dummy columns
        missing from the inference rows are filled with 0, and columns that exist
        only in the inference rows are dropped. The input frames are not modified.
        """
        def encode(frame):
            frame = frame.drop(columns=[c for c in ('sample_id', target_col) if c in frame.columns])
            if 'season_code' in frame.columns:
                frame = pd.get_dummies(frame, columns=['season_code'], prefix='season', dtype=float)
            return frame

        full_x = encode(full_df)
        feature_names = list(full_x.columns)
        infer_x = encode(infer_df).reindex(columns=feature_names, fill_value=0)
        return full_x.to_numpy(dtype=float), infer_x.to_numpy(dtype=float), feature_names

    def _train_model(self, model, X_seq, y_seq, epochs):
        """Run per-sequence SGD for ``epochs`` epochs; return the mean squared error per epoch.

        Progress is logged and the loss curve re-plotted every 20 epochs and after the last one.
        """
        losses = []
        for epoch in range(epochs):
            total_loss = 0.0
            for i in np.random.permutation(len(X_seq)):
                y_pred, _ = model.forward(X_seq[i])
                # Many-to-one: the last step's output is the prediction.
                total_loss += float((y_pred[-1, 0] - y_seq[i]) ** 2)
                model.backward(X_seq[i], np.array([y_seq[i]]))

            avg_loss = total_loss / len(X_seq)
            losses.append(avg_loss)

            if (epoch + 1) % 20 == 0 or epoch + 1 == epochs:
                self.msg_queue.put({'action': 'log', 'msg': f"Epoch {epoch+1}/{epochs} - 平均 Loss: {avg_loss:.6f}"})
                # Send a snapshot: this list keeps growing on the worker thread while
                # the main thread plots it.
                self.msg_queue.put({'action': 'plot_loss', 'losses': list(losses)})
        return losses

    def _predict_inference(self, model, X_history, X_infer, n_steps):
        """Predict each inference row with a rolling window; return values in target units.

        The window is the last ``n_steps - 1`` rows seen so far (history first,
        then earlier inference rows) plus the row being predicted. This matches
        ``create_sequences(..., include_current=True)``.
        """
        keep = n_steps - 1
        window = X_history[max(len(X_history) - keep, 0):]
        preds_scaled = []
        for row in X_infer:
            seq = np.vstack([window, row.reshape(1, -1)])
            y_pred, _ = model.forward(seq)
            preds_scaled.append(y_pred[-1, 0])
            window = seq[max(len(seq) - keep, 0):]
        return self.scaler_y.inverse_transform(np.array(preds_scaled, dtype=float).reshape(-1, 1)).ravel()

    def create_sequences(self, X, y, n_steps, include_current=False):
        """Slice ``X``/``y`` into sliding windows of ``n_steps`` rows.

        With ``include_current=False`` (default), window ``X[i-n_steps:i]`` is paired
        with ``y[i]``: only preceding rows are used. With ``include_current=True``,
        window ``X[i-n_steps+1:i+1]`` ends on the row whose target ``y[i]`` it predicts.
        Returns ``(windows, targets)`` as NumPy arrays.
        """
        offset = 1 if include_current else 0
        ends = range(n_steps - offset, len(X))
        windows = [X[i - n_steps + offset:i + offset] for i in ends]
        targets = [y[i] for i in ends]
        return np.array(windows), np.array(targets)

    def export_results(self):
        """Ask for a path and write ``self.result_df`` as CSV, reporting success or failure."""
        if self.result_df is None:
            messagebox.showwarning("警告", "请先训练模型并生成预测结果。")
            return

        save_path = filedialog.asksaveasfilename(
            initialfile="RNN_能耗预测结果.csv",
            defaultextension=".csv",
            filetypes=[("CSV files", "*.csv")],
        )
        if not save_path:
            return

        try:
            self.result_df.to_csv(save_path, index=False)
        except OSError as e:
            # e.g. the file is open in Excel or the folder is not writable
            self.log(f"❌ 导出失败: {e}")
            messagebox.showerror("错误", f"结果导出失败: {e}")
            return

        self.log(f"结果已导出到: {save_path}")
        messagebox.showinfo("成功", "结果导出成功！")

if __name__ == "__main__":
    # Create the Tk main window, attach the application to it and enter the event loop.
    main_window = tk.Tk()
    app = RNNEnergyApp(main_window)
    main_window.mainloop()