baoshiwei
2026-04-01 81b0ad0124847f083990d574dc8d20961ec6e713
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import joblib
import os
from datetime import datetime
 
# 尝试导入torch,如果失败则禁用深度学习模型支持
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
 
# 页面函数定义
def show_metered_weight_prediction():
    # 页面标题
    st.title("米重统一预测")
    
    # 初始化会话状态
    if 'selected_model' not in st.session_state:
        st.session_state['selected_model'] = None
    
    # 创建模型目录(如果不存在)
    model_dir = "saved_models"
    os.makedirs(model_dir, exist_ok=True)
    
    # 获取所有已保存的模型文件
    model_files = [f for f in os.listdir(model_dir) if f.endswith('.joblib')]
    model_files.sort(reverse=True)  # 最新的模型排在前面
    
    # 模型选择区域
    with st.expander("📁 选择模型", expanded=True):
        if not model_files:
            st.warning("尚未保存任何模型,请先训练模型并保存。")
        else:
            # 模型选择下拉框
            selected_model_file = st.selectbox(
                "选择已保存的模型",
                options=model_files,
                help="选择要用于预测的模型文件"
            )
            
            # 加载并显示模型信息
            if selected_model_file:
                model_path = os.path.join(model_dir, selected_model_file)
                model_info = joblib.load(model_path)
                
                # 显示模型基本信息
                st.subheader("📊 模型信息")
                info_cols = st.columns(2)
                
                with info_cols[0]:
                    st.metric("模型类型", model_info['model_type'])
                    st.metric("创建时间", model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'))
                    st.metric("使用稳态数据", "是" if model_info.get('use_steady_data', False) else "否")
                
                with info_cols[1]:
                    st.metric("R² 得分", f"{model_info['r2_score']:.4f}")
                    st.metric("均方误差 (MSE)", f"{model_info['mse']:.6f}")
                    st.metric("均方根误差 (RMSE)", f"{model_info['rmse']:.6f}")
                
                # 显示模型特征
                st.write("🔑 模型使用的特征:")
                st.code(", ".join(model_info['features']))
                
                # 如果是深度学习模型,显示序列长度
                if 'sequence_length' in model_info:
                    st.metric("序列长度", model_info['sequence_length'])
                
                # 保存模型信息到会话状态
                st.session_state['selected_model'] = model_info
                st.session_state['selected_model_file'] = selected_model_file
    
    # 预测功能区域
    st.subheader("🔮 米重预测")
    
    if st.session_state['selected_model']:
        model_info = st.session_state['selected_model']
        
        # 获取模型需要的特征
        required_features = model_info['features']
        
        # 创建预测表单
        st.write("输入特征值进行米重预测:")
        predict_cols = st.columns(2)
        input_features = {}
        
        # 显示输入表单
        for i, feature in enumerate(required_features):
            with predict_cols[i % 2]:
                input_features[feature] = st.number_input(
                    f"{feature}",
                    key=f"pred_{feature}",
                    value=0.0,
                    step=0.0001,
                    format="%.4f"
                )
        
        # 预测按钮
        if st.button("🚀 开始预测"):
            try:
                # 准备预测数据
                input_df = pd.DataFrame([input_features])
                
                # 根据模型类型执行不同的预测逻辑
                predicted_weight = None
                
                # 获取模型
                model = model_info['model']
                
                # 检查模型类型并执行预测
                if model_info['model_type'] in ['LSTM', 'GRU', 'BiLSTM']:
                    # 深度学习模型预测
                    if not TORCH_AVAILABLE:
                        st.error("PyTorch 未安装,无法使用深度学习模型进行预测。")
                        return
                    
                    # 数据标准化
                    scaler_X = model_info['scaler_X']
                    scaler_y = model_info['scaler_y']
                    input_scaled = scaler_X.transform(input_df)
                    
                    # 获取序列长度
                    sequence_length = model_info['sequence_length']
                    
                    # 为深度学习模型创建序列
                    input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
                    
                    # 转换为PyTorch张量
                    import torch
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
                    
                    # 预测
                    model.eval()
                    with torch.no_grad():
                        y_pred_scaled_tensor = model(input_tensor)
                        y_pred_scaled = y_pred_scaled_tensor.cpu().numpy().ravel()[0]
                        
                        # 反归一化
                        predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
                    
                elif model_info['model_type'] in ['SVR', 'MLP']:
                    # 支持向量机或多层感知器预测
                    
                    # 数据标准化
                    scaler_X = model_info['scaler_X']
                    scaler_y = model_info['scaler_y']
                    input_scaled = scaler_X.transform(input_df)
                    
                    # 预测
                    y_pred_scaled = model.predict(input_scaled)[0]
                    
                    # 反归一化
                    predicted_weight = scaler_y.inverse_transform(np.array([[y_pred_scaled]]))[0][0]
                    
                else:
                    # 其他模型(如随机森林、梯度提升、线性回归等)
                    predicted_weight = model.predict(input_df)[0]
                
                # 显示预测结果
                st.success(f"预测米重: {predicted_weight:.4f} Kg/m")
                
                
            except Exception as e:
                st.error(f"预测失败: {str(e)}")
    else:
        st.warning("请先选择一个模型。")
    
    # 模型管理区域
    if model_files:
        with st.expander("🗑️ 模型管理", expanded=False):
            st.write("管理已保存的模型文件:")
            
            # 显示所有模型文件
            for model_file in model_files:
                cols = st.columns([3, 1, 1])
                cols[0].write(model_file)
                
                # 查看模型信息按钮
                if cols[1].button("查看", key=f"view_{model_file}", help="查看模型信息"):
                    model_path = os.path.join(model_dir, model_file)
                    model_info = joblib.load(model_path)
                    st.write("模型详细信息:")
                    st.json({
                        'model_type': model_info['model_type'],
                        'created_at': model_info['created_at'].strftime('%Y-%m-%d %H:%M:%S'),
                        'r2_score': f"{model_info['r2_score']:.4f}",
                        'mse': f"{model_info['mse']:.6f}",
                        'mae': f"{model_info['mae']:.6f}",
                        'rmse': f"{model_info['rmse']:.6f}",
                        'features': model_info['features'],
                        'use_steady_data': model_info.get('use_steady_data', False)
                    })
                
                # 删除模型按钮
                if cols[2].button("删除", key=f"delete_{model_file}", help="删除模型文件", type="primary"):
                    model_path = os.path.join(model_dir, model_file)
                    os.remove(model_path)
                    st.success(f"已删除模型: {model_file}")
                    st.rerun()
 
# 页面入口
if __name__ == "__main__":
    show_metered_weight_prediction()