import streamlit as st
|
import plotly.express as px
|
import plotly.graph_objects as go
|
import pandas as pd
|
import numpy as np
|
import joblib
|
import os
|
from datetime import datetime
|
|
# 尝试导入torch,如果失败则禁用深度学习模型支持
|
try:
|
import torch
|
TORCH_AVAILABLE = True
|
except ImportError:
|
TORCH_AVAILABLE = False
|
|
# 页面函数定义
|
def show_metered_weight_prediction():
|
# 页面标题
|
st.title("米重统一预测")
|
|
# 初始化会话状态
|
if 'selected_model' not in st.session_state:
|
st.session_state['selected_model'] = None
|
|
# 创建模型目录(如果不存在)
|
model_dir = "saved_models"
|
os.makedirs(model_dir, exist_ok=True)
|
|
# 获取所有已保存的模型文件
|
model_files = [f for f in os.listdir(model_dir) if f.endswith('.joblib')]
|
model_files.sort(reverse=True) # 最新的模型排在前面
|
|
# 模型选择区域
|
with st.expander("📁 选择模型", expanded=True):
|
if not model_files:
|
st.warning("尚未保存任何模型,请先训练模型并保存。")
|
else:
|
# 模型选择下拉框
|
selected_model_file = st.selectbox(
|
"选择已保存的模型",
|
options=model_files,
|
help="选择要用于预测的模型文件"
|
)
|
|
# 加载并显示模型信息
|
if selected_model_file:
|
model_path = os.path.join(model_dir, selected_model_file)
|
model_info = joblib.load(model_path)
|
|
# 显示模型基本信息
|
st.subheader("📊 模型信息")
|
info_cols = st.columns(2)
|
|
with info_cols[0]:
|
st.metric("模型类型", model_info['model_type'])
|
st.metric("创建时间", model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'))
|
st.metric("使用稳态数据", "是" if model_info.get('use_steady_data', False) else "否")
|
|
with info_cols[1]:
|
st.metric("R² 得分", f"{model_info['r2_score']:.4f}")
|
st.metric("均方误差 (MSE)", f"{model_info['mse']:.6f}")
|
st.metric("均方根误差 (RMSE)", f"{model_info['rmse']:.6f}")
|
|
# 显示模型特征
|
st.write("🔑 模型使用的特征:")
|
st.code(", ".join(model_info['features']))
|
|
# 如果是深度学习模型,显示序列长度
|
if 'sequence_length' in model_info:
|
st.metric("序列长度", model_info['sequence_length'])
|
|
# 保存模型信息到会话状态
|
st.session_state['selected_model'] = model_info
|
st.session_state['selected_model_file'] = selected_model_file
|
|
# 预测功能区域
|
st.subheader("🔮 米重预测")
|
|
if st.session_state['selected_model']:
|
model_info = st.session_state['selected_model']
|
|
# 获取模型需要的特征
|
required_features = model_info['features']
|
|
# 创建预测表单
|
st.write("输入特征值进行米重预测:")
|
predict_cols = st.columns(2)
|
input_features = {}
|
|
# 显示输入表单
|
for i, feature in enumerate(required_features):
|
with predict_cols[i % 2]:
|
input_features[feature] = st.number_input(
|
f"{feature}",
|
key=f"pred_{feature}",
|
value=0.0,
|
step=0.0001,
|
format="%.4f"
|
)
|
|
# 预测按钮
|
if st.button("🚀 开始预测"):
|
try:
|
# 准备预测数据
|
input_df = pd.DataFrame([input_features])
|
|
# 根据模型类型执行不同的预测逻辑
|
predicted_weight = None
|
|
# 获取模型
|
model = model_info['model']
|
|
# 检查模型类型并执行预测
|
if model_info['model_type'] in ['LSTM', 'GRU', 'BiLSTM']:
|
# 深度学习模型预测
|
if not TORCH_AVAILABLE:
|
st.error("PyTorch 未安装,无法使用深度学习模型进行预测。")
|
return
|
|
# 数据标准化
|
scaler_X = model_info['scaler_X']
|
scaler_y = model_info['scaler_y']
|
input_scaled = scaler_X.transform(input_df)
|
|
# 获取序列长度
|
sequence_length = model_info['sequence_length']
|
|
# 为深度学习模型创建序列
|
input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
|
|
# 转换为PyTorch张量
|
import torch
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
|
|
# 预测
|
model.eval()
|
with torch.no_grad():
|
y_pred_scaled_tensor = model(input_tensor)
|
y_pred_scaled = y_pred_scaled_tensor.cpu().numpy().ravel()[0]
|
|
# 反归一化
|
predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
|
|
elif model_info['model_type'] in ['SVR', 'MLP']:
|
# 支持向量机或多层感知器预测
|
|
# 数据标准化
|
scaler_X = model_info['scaler_X']
|
scaler_y = model_info['scaler_y']
|
input_scaled = scaler_X.transform(input_df)
|
|
# 预测
|
y_pred_scaled = model.predict(input_scaled)[0]
|
|
# 反归一化
|
predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
|
|
else:
|
# 其他模型(如随机森林、梯度提升、线性回归等)
|
predicted_weight = model.predict(input_df)[0]
|
|
# 显示预测结果
|
st.success(f"预测米重: {predicted_weight:.4f} Kg/m")
|
|
|
except Exception as e:
|
st.error(f"预测失败: {str(e)}")
|
else:
|
st.warning("请先选择一个模型。")
|
|
# 模型管理区域
|
if model_files:
|
with st.expander("🗑️ 模型管理", expanded=False):
|
st.write("管理已保存的模型文件:")
|
|
# 显示所有模型文件
|
for model_file in model_files:
|
cols = st.columns([3, 1, 1])
|
cols[0].write(model_file)
|
|
# 查看模型信息按钮
|
if cols[1].button("查看", key=f"view_{model_file}", help="查看模型信息"):
|
model_path = os.path.join(model_dir, model_file)
|
model_info = joblib.load(model_path)
|
st.write("模型详细信息:")
|
st.json({
|
'model_type': model_info['model_type'],
|
'created_at': model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'),
|
'r2_score': f"{model_info['r2_score']:.4f}",
|
'mse': f"{model_info['mse']:.6f}",
|
'mae': f"{model_info['mae']:.6f}",
|
'rmse': f"{model_info['rmse']:.6f}",
|
'features': model_info['features'],
|
'use_steady_data': model_info.get('use_steady_data', False)
|
})
|
|
# 删除模型按钮
|
if cols[2].button("删除", key=f"delete_{model_file}", help="删除模型文件", type="primary"):
|
model_path = os.path.join(model_dir, model_file)
|
os.remove(model_path)
|
st.success(f"已删除模型: {model_file}")
|
st.rerun()
|
|
# 页面入口
|
if __name__ == "__main__":
|
show_metered_weight_prediction()
|