import streamlit as st import plotly.express as px import plotly.graph_objects as go import pandas as pd import numpy as np import joblib import os from datetime import datetime # 尝试导入torch,如果失败则禁用深度学习模型支持 try: import torch TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False # 页面函数定义 def show_metered_weight_prediction(): # 页面标题 st.title("米重统一预测") # 初始化会话状态 if 'selected_model' not in st.session_state: st.session_state['selected_model'] = None # 创建模型目录(如果不存在) model_dir = "saved_models" os.makedirs(model_dir, exist_ok=True) # 获取所有已保存的模型文件 model_files = [f for f in os.listdir(model_dir) if f.endswith('.joblib')] model_files.sort(reverse=True) # 最新的模型排在前面 # 模型选择区域 with st.expander("📁 选择模型", expanded=True): if not model_files: st.warning("尚未保存任何模型,请先训练模型并保存。") else: # 模型选择下拉框 selected_model_file = st.selectbox( "选择已保存的模型", options=model_files, help="选择要用于预测的模型文件" ) # 加载并显示模型信息 if selected_model_file: model_path = os.path.join(model_dir, selected_model_file) model_info = joblib.load(model_path) # 显示模型基本信息 st.subheader("📊 模型信息") info_cols = st.columns(2) with info_cols[0]: st.metric("模型类型", model_info['model_type']) st.metric("创建时间", model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S')) st.metric("使用稳态数据", "是" if model_info.get('use_steady_data', False) else "否") with info_cols[1]: st.metric("R² 得分", f"{model_info['r2_score']:.4f}") st.metric("均方误差 (MSE)", f"{model_info['mse']:.6f}") st.metric("均方根误差 (RMSE)", f"{model_info['rmse']:.6f}") # 显示模型特征 st.write("🔑 模型使用的特征:") st.code(", ".join(model_info['features'])) # 如果是深度学习模型,显示序列长度 if 'sequence_length' in model_info: st.metric("序列长度", model_info['sequence_length']) # 保存模型信息到会话状态 st.session_state['selected_model'] = model_info st.session_state['selected_model_file'] = selected_model_file # 预测功能区域 st.subheader("🔮 米重预测") if st.session_state['selected_model']: model_info = st.session_state['selected_model'] # 获取模型需要的特征 required_features = model_info['features'] # 创建预测表单 st.write("输入特征值进行米重预测:") predict_cols = st.columns(2) input_features = {} # 显示输入表单 for i, feature in enumerate(required_features): with predict_cols[i % 2]: input_features[feature] = st.number_input( f"{feature}", key=f"pred_{feature}", value=0.0, step=0.0001, format="%.4f" ) # 预测按钮 if st.button("🚀 开始预测"): try: # 准备预测数据 input_df = pd.DataFrame([input_features]) # 根据模型类型执行不同的预测逻辑 predicted_weight = None # 获取模型 model = model_info['model'] # 检查模型类型并执行预测 if model_info['model_type'] in ['LSTM', 'GRU', 'BiLSTM']: # 深度学习模型预测 if not TORCH_AVAILABLE: st.error("PyTorch 未安装,无法使用深度学习模型进行预测。") return # 数据标准化 scaler_X = model_info['scaler_X'] scaler_y = model_info['scaler_y'] input_scaled = scaler_X.transform(input_df) # 获取序列长度 sequence_length = model_info['sequence_length'] # 为深度学习模型创建序列 input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1) # 转换为PyTorch张量 import torch device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device) # 预测 model.eval() with torch.no_grad(): y_pred_scaled_tensor = model(input_tensor) y_pred_scaled = y_pred_scaled_tensor.cpu().numpy().ravel()[0] # 反归一化 predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0] elif model_info['model_type'] in ['SVR', 'MLP']: # 支持向量机或多层感知器预测 # 数据标准化 scaler_X = model_info['scaler_X'] scaler_y = model_info['scaler_y'] input_scaled = scaler_X.transform(input_df) # 预测 y_pred_scaled = model.predict(input_scaled)[0] # 反归一化 predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0] else: # 其他模型(如随机森林、梯度提升、线性回归等) predicted_weight = model.predict(input_df)[0] # 显示预测结果 st.success(f"预测米重: {predicted_weight:.4f} Kg/m") except Exception as e: st.error(f"预测失败: {str(e)}") else: st.warning("请先选择一个模型。") # 模型管理区域 if model_files: with st.expander("🗑️ 模型管理", expanded=False): st.write("管理已保存的模型文件:") # 显示所有模型文件 for model_file in model_files: cols = st.columns([3, 1, 1]) cols[0].write(model_file) # 查看模型信息按钮 if cols[1].button("查看", key=f"view_{model_file}", help="查看模型信息"): model_path = os.path.join(model_dir, model_file) model_info = joblib.load(model_path) st.write("模型详细信息:") st.json({ 'model_type': model_info['model_type'], 'created_at': model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'), 'r2_score': f"{model_info['r2_score']:.4f}", 'mse': f"{model_info['mse']:.6f}", 'mae': f"{model_info['mae']:.6f}", 'rmse': f"{model_info['rmse']:.6f}", 'features': model_info['features'], 'use_steady_data': model_info.get('use_steady_data', False) }) # 删除模型按钮 if cols[2].button("删除", key=f"delete_{model_file}", help="删除模型文件", type="primary"): model_path = os.path.join(model_dir, model_file) os.remove(model_path) st.success(f"已删除模型: {model_file}") st.rerun() # 页面入口 if __name__ == "__main__": show_metered_weight_prediction()