import streamlit as st
|
import plotly.express as px
|
import plotly.graph_objects as go
|
import pandas as pd
|
import numpy as np
|
import joblib
|
import os
|
from datetime import datetime, timedelta
|
from app.services.extruder_service import ExtruderService
|
from app.services.main_process_service import MainProcessService
|
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
from sklearn.model_selection import train_test_split
|
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
|
|
# Steady-state detection helper (defined inline in this module rather than imported)
|
class SteadyStateDetector:
    """Flag steady-state stretches of a meter-weight time series."""

    def __init__(self):
        pass

    @staticmethod
    def _finalize_segment(segment, std_threshold):
        """Attach weight statistics and a confidence score to *segment* in place."""
        weights_array = np.array(segment['weights'])
        segment['mean_weight'] = np.mean(weights_array)
        segment['std_weight'] = np.std(weights_array)
        segment['min_weight'] = np.min(weights_array)
        segment['max_weight'] = np.max(weights_array)
        segment['fluctuation_range'] = (segment['std_weight'] / segment['mean_weight']) * 100

        # Confidence falls linearly with the relative fluctuation and is
        # clamped to the [50, 100] range.
        confidence = 100 - (segment['fluctuation_range'] / std_threshold) * 50
        segment['confidence'] = max(50, min(100, confidence))
        return segment

    def detect_steady_state(self, df, weight_col='米重', window_size=20, std_threshold=0.5, duration_threshold=60):
        """
        Mark steady-state rows and collect the steady segments.

        The frame is modified in place: ``time`` is converted to datetime and
        the columns ``rolling_std``, ``rolling_mean``, ``fluctuation_range``
        and ``is_steady`` are added. The same frame object is returned.

        :param df: data frame with a ``time`` column and the weight column
        :param weight_col: name of the weight column
        :param window_size: rolling-window length in rows (this equals seconds
            only when the data holds one row per second)
        :param std_threshold: upper bound for the rolling coefficient of
            variation, in percent (rolling std / rolling mean * 100)
        :param duration_threshold: minimum time span, in seconds, between the
            first and last row of a run for it to be reported as a segment
        :return: ``(df, steady_segments)``; each segment is a dict with the
            keys ``start_time``, ``start_idx``, ``weights``, ``end_time``,
            ``end_idx``, ``duration``, ``mean_weight``, ``std_weight``,
            ``min_weight``, ``max_weight``, ``fluctuation_range`` and
            ``confidence``. ``start_idx`` / ``end_idx`` are index labels of
            the first and last row of the run.
        """
        if df is None or df.empty:
            return df, []

        # Make sure the time column is datetime typed.
        df['time'] = pd.to_datetime(df['time'])

        # min_periods may not exceed an integer window, so clamp it; without
        # this a window smaller than 5 rows raises inside pandas.
        if isinstance(window_size, (int, np.integer)):
            min_periods = min(5, window_size)
        else:
            min_periods = 5

        # Rolling statistics of the weight signal.
        rolling = df[weight_col].rolling(window=window_size, min_periods=min_periods)
        df['rolling_std'] = rolling.std()
        df['rolling_mean'] = rolling.mean()

        # Relative fluctuation in percent. Rows without enough history have a
        # NaN statistic, which becomes 0 here and therefore passes the
        # threshold test below.
        df['fluctuation_range'] = (df['rolling_std'] / df['rolling_mean']) * 100
        df['fluctuation_range'] = df['fluctuation_range'].fillna(0)

        # Point-wise steady flag. No trend criterion is applied.
        df['is_steady'] = 0
        steady_condition = (
            (df['fluctuation_range'] < std_threshold) &
            (df[weight_col] >= 0.1)
        )
        df.loc[steady_condition.to_numpy(), 'is_steady'] = 1

        # Locate runs of consecutive steady rows by position, so the result
        # does not depend on the index being a gap-free 0..n-1 range (rows
        # dropped upstream leave gaps in the labels).
        flags = df['is_steady'].to_numpy()
        edges = np.flatnonzero(np.diff(np.concatenate(([0], flags, [0]))))
        run_starts = edges[::2]
        run_ends = edges[1::2] - 1  # inclusive position of the last steady row

        times = df['time'].reset_index(drop=True)
        weights = df[weight_col].to_numpy()

        steady_segments = []
        for start_pos, end_pos in zip(run_starts, run_ends):
            start_time = times.iloc[start_pos]
            end_time = times.iloc[end_pos]
            duration = (end_time - start_time).total_seconds()

            # Written as "not >=" so a NaN duration (NaT timestamps) is
            # rejected as well.
            if not duration >= duration_threshold:
                continue

            segment = {
                'start_time': start_time,
                'start_idx': df.index[start_pos],
                'weights': weights[start_pos:end_pos + 1].tolist(),
                'end_time': end_time,
                'end_idx': df.index[end_pos],
                'duration': duration,
            }
            steady_segments.append(self._finalize_segment(segment, std_threshold))

        # NOTE(review): every row of a reported segment already carries
        # is_steady == 1, and runs shorter than duration_threshold keep their
        # point-wise flag; only steady_segments is filtered by duration.
        # Confirm whether short runs should be reset to 0.
        return df, steady_segments
|
|
# Try to import the deep-learning library (PyTorch); the page degrades gracefully without it
|
# Feature flag read by the page below; flipped to True only when PyTorch imports.
use_deep_learning = False

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim

    use_deep_learning = True

    # Prefer the GPU when one is available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"使用设备: {device}")

    # PyTorch model definitions. All three share the same head: take the
    # recurrent output at the last time step, then Linear -> ReLU -> Dropout
    # -> Linear producing a single value.
    class LSTMModel(nn.Module):
        """Stacked LSTM followed by a two-layer regression head."""

        def __init__(self, input_dim, hidden_dim=64, num_layers=2):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=input_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
            )
            self.fc1 = nn.Linear(hidden_dim, 32)
            self.dropout = nn.Dropout(0.2)
            self.fc2 = nn.Linear(32, 1)

        def forward(self, x):
            seq_out, _ = self.lstm(x)
            last_step = seq_out[:, -1, :]
            hidden = self.dropout(torch.relu(self.fc1(last_step)))
            return self.fc2(hidden)

    class GRUModel(nn.Module):
        """Stacked GRU followed by a two-layer regression head."""

        def __init__(self, input_dim, hidden_dim=64, num_layers=2):
            super().__init__()
            self.gru = nn.GRU(
                input_size=input_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
            )
            self.fc1 = nn.Linear(hidden_dim, 32)
            self.dropout = nn.Dropout(0.2)
            self.fc2 = nn.Linear(32, 1)

        def forward(self, x):
            seq_out, _ = self.gru(x)
            last_step = seq_out[:, -1, :]
            hidden = self.dropout(torch.relu(self.fc1(last_step)))
            return self.fc2(hidden)

    class BiLSTMModel(nn.Module):
        """Stacked bidirectional LSTM followed by a two-layer regression head."""

        def __init__(self, input_dim, hidden_dim=64, num_layers=2):
            super().__init__()
            self.bilstm = nn.LSTM(
                input_size=input_dim,
                hidden_size=hidden_dim,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=True,
            )
            # Both directions are concatenated, hence twice the hidden width.
            self.fc1 = nn.Linear(hidden_dim * 2, 32)
            self.dropout = nn.Dropout(0.2)
            self.fc2 = nn.Linear(32, 1)

        def forward(self, x):
            seq_out, _ = self.bilstm(x)
            last_step = seq_out[:, -1, :]
            hidden = self.dropout(torch.relu(self.fc1(last_step)))
            return self.fc2(hidden)

    st.success(f"使用设备: {device}")

except ImportError:
    st.warning("未检测到PyTorch,深度学习模型将不可用。请安装pytorch以使用LSTM/GRU模型。")
|
|
# CSS injected into the query expander. Kept flush-left so Markdown does not
# treat the lines as an indented code block.
_MDL_EXPANDER_CSS = """
<style>
/* 强制列容器换行 */
[data-testid="stExpander"] [data-testid="column"] {
flex: 1 1 120px !important;
min-width: 120px !important;
}
/* 针对日期输入框列稍微加宽一点 */
@media (min-width: 768px) {
[data-testid="stExpander"] [data-testid="column"]:nth-child(6),
[data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
flex: 2 1 180px !important;
min-width: 180px !important;
}
}
</style>
"""


def _mdl_has_any_rows(*frames):
    """Return True when at least one of *frames* is a non-empty data frame."""
    return any(frame is not None and not frame.empty for frame in frames)


def _mdl_integrate_data(df_extruder_full, df_main_speed, df_temp, time_offset):
    """Merge extruder, main-speed and temperature data onto the extruder time axis.

    Returns None when there is no extruder data; otherwise a frame with the
    columns renamed to their display names and rows without ``metered_weight``
    removed.
    """
    if df_extruder_full is None or df_extruder_full.empty:
        return None

    # NOTE(review): the same offset is added to every source, including the
    # frame that carries metered_weight, so the relative alignment between the
    # features and the weight does not change with the slider. Confirm whether
    # the weight timestamps were meant to stay unshifted.
    offset_delta = timedelta(minutes=time_offset)
    df_extruder_shifted = df_extruder_full.copy()
    df_extruder_shifted['time'] = df_extruder_shifted['time'] + offset_delta

    # Base data set: time, meter weight and the two extruder features.
    df_merged = df_extruder_shifted[['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']].copy()

    # Main process speed, matched to the nearest sample within one minute.
    if df_main_speed is not None and not df_main_speed.empty:
        df_main_speed_shifted = df_main_speed.copy()
        df_main_speed_shifted['time'] = df_main_speed_shifted['time'] + offset_delta
        df_main_speed_shifted = df_main_speed_shifted[['time', 'process_main_speed']]
        df_merged = pd.merge_asof(
            df_merged.sort_values('time'),
            df_main_speed_shifted.sort_values('time'),
            on='time',
            direction='nearest',
            tolerance=pd.Timedelta('1min')
        )

    # Temperature readings, matched the same way.
    if df_temp is not None and not df_temp.empty:
        df_temp_shifted = df_temp.copy()
        df_temp_shifted['time'] = df_temp_shifted['time'] + offset_delta
        temp_cols = ['time', 'nakata_extruder_screw_display_temp',
                     'nakata_extruder_rear_barrel_display_temp',
                     'nakata_extruder_front_barrel_display_temp',
                     'nakata_extruder_head_display_temp']
        df_temp_subset = df_temp_shifted[temp_cols].copy()
        df_merged = pd.merge_asof(
            df_merged.sort_values('time'),
            df_temp_subset.sort_values('time'),
            on='time',
            direction='nearest',
            tolerance=pd.Timedelta('1min')
        )

    # Rename to the display names used as model features.
    df_merged.rename(columns={
        'screw_speed_actual': '螺杆转速',
        'head_pressure': '机头压力',
        'process_main_speed': '流程主速',
        'nakata_extruder_screw_display_temp': '螺杆温度',
        'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
        'nakata_extruder_front_barrel_display_temp': '前机筒温度',
        'nakata_extruder_head_display_temp': '机头温度'
    }, inplace=True)

    # Rows without a weight reading cannot be used as targets.
    df_merged.dropna(subset=['metered_weight'], inplace=True)

    return df_merged


def _mdl_create_sequences(X, y, sequence_length):
    """Build sliding windows of *sequence_length* rows and the target that follows each."""
    n_windows = len(X) - sequence_length
    X_seq = [X[i:i + sequence_length] for i in range(n_windows)]
    y_seq = [y[i + sequence_length] for i in range(n_windows)]
    return np.array(X_seq), np.array(y_seq)


def _mdl_render_steady_overview(df_analysis, steady_data):
    """Show the data-volume metrics and the steady / non-steady scatter chart."""
    total_data = len(df_analysis)
    steady_ratio = (steady_data / total_data * 100) if total_data > 0 else 0

    metrics_cols = st.columns(3)
    with metrics_cols[0]:
        st.metric("总数据量", total_data)
    with metrics_cols[1]:
        st.metric("稳态数据量", steady_data)
    with metrics_cols[2]:
        st.metric("稳态数据比例", f"{steady_ratio:.1f}%")

    st.markdown("---")
    st.subheader("📈 稳态数据分布")

    fig_steady = go.Figure()

    # Raw weight curve as a light background line.
    fig_steady.add_trace(go.Scatter(
        x=df_analysis['time'],
        y=df_analysis['米重'],
        name='原始米重',
        mode='lines',
        line=dict(color='lightgray', width=1)
    ))

    # Steady points in green.
    steady_data_points = df_analysis[df_analysis['is_steady'] == 1]
    fig_steady.add_trace(go.Scatter(
        x=steady_data_points['time'],
        y=steady_data_points['米重'],
        name='稳态米重',
        mode='markers',
        marker=dict(color='green', size=3, opacity=0.6)
    ))

    # Non-steady points in red.
    non_steady_data_points = df_analysis[df_analysis['is_steady'] == 0]
    fig_steady.add_trace(go.Scatter(
        x=non_steady_data_points['time'],
        y=non_steady_data_points['米重'],
        name='非稳态米重',
        mode='markers',
        marker=dict(color='red', size=3, opacity=0.6)
    ))

    fig_steady.update_layout(
        title="米重数据稳态分布",
        xaxis=dict(title="时间"),
        yaxis=dict(title="米重 (Kg/m)"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=500
    )

    # Same sizing argument as every other chart and button on this page.
    st.plotly_chart(fig_steady, width='stretch')


def _mdl_train_and_report(combined_clean, features, use_steady_data):
    """Train the selected recurrent model, display its evaluation and save it to disk."""
    X_clean = combined_clean[features]
    y_clean = combined_clean['米重']

    sequence_length = st.session_state['mdl_sequence_length']
    n_sequences = max(len(combined_clean) - sequence_length, 0)
    if n_sequences < 20:
        st.warning("序列数据量不足,无法进行有效的深度学习训练")
        return

    # Chronological 80/20 split over the sequences.
    train_size = int(n_sequences * 0.8)

    # Fit the scalers only on the rows the training sequences touch (inputs
    # and targets), so test-set statistics do not leak into the scaling.
    n_fit_rows = train_size + sequence_length
    y_values = y_clean.values.reshape(-1, 1)
    scaler_X = StandardScaler()
    scaler_y = MinMaxScaler()
    scaler_X.fit(X_clean.iloc[:n_fit_rows])
    scaler_y.fit(y_values[:n_fit_rows])
    X_scaled = scaler_X.transform(X_clean)
    y_scaled = scaler_y.transform(y_values).ravel()

    # NOTE(review): when non-steady rows were filtered out, neighbouring rows
    # here are not necessarily adjacent in time, so a window can span a gap.
    X_seq, y_seq = _mdl_create_sequences(X_scaled, y_scaled, sequence_length)

    X_train_seq, X_test_seq = X_seq[:train_size], X_seq[train_size:]
    y_train_seq, y_test_seq = y_seq[:train_size], y_seq[train_size:]

    X_train_tensor = torch.tensor(X_train_seq, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train_seq, dtype=torch.float32).unsqueeze(1).to(device)
    X_test_tensor = torch.tensor(X_test_seq, dtype=torch.float32).to(device)

    # Build the model; an unknown type is reported instead of leaving the
    # model variable undefined.
    model_type = st.session_state['mdl_model_type']
    model_classes = {'LSTM': LSTMModel, 'GRU': GRUModel, 'BiLSTM': BiLSTMModel}
    model_cls = model_classes.get(model_type)
    if model_cls is None:
        st.error(f"不支持的模型类型: {model_type}")
        return
    input_dim = X_scaled.shape[1]
    model = model_cls(input_dim).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Full-batch training: one optimizer step per epoch over all sequences.
    num_epochs = 50
    progress_bar = st.progress(0)
    status_text = st.empty()

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)

        loss.backward()
        optimizer.step()

        progress_bar.progress((epoch + 1) / num_epochs)
        status_text.text(f"训练中: 第 {epoch + 1}/{num_epochs} 轮, 损失: {loss.item():.6f}")

    # Predict on the held-out tail.
    model.eval()
    with torch.no_grad():
        y_pred_scaled = model(X_test_tensor).cpu().numpy().ravel()

    # Back to physical units.
    y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
    y_test_actual = scaler_y.inverse_transform(y_test_seq.reshape(-1, 1)).ravel()

    r2 = r2_score(y_test_actual, y_pred)
    mse = mean_squared_error(y_test_actual, y_pred)
    mae = mean_absolute_error(y_test_actual, y_pred)
    rmse = np.sqrt(mse)

    metrics_cols = st.columns(2)
    with metrics_cols[0]:
        st.metric("R² 得分", f"{r2:.4f}")
        st.metric("均方误差 (MSE)", f"{mse:.6f}")
    with metrics_cols[1]:
        st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
        st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")

    if use_steady_data:
        st.info("⚠️ 模型仅使用稳态数据进行训练,在非稳态工况下预测结果可能不准确")

    # --- Actual vs. predicted ---
    st.subheader("🔄 实际值与预测值对比")
    compare_df = pd.DataFrame({
        '实际值': y_test_actual,
        '预测值': y_pred
    })

    fig_compare = go.Figure()
    fig_compare.add_trace(go.Scatter(
        x=compare_df.index,
        y=compare_df['实际值'],
        name='实际值',
        mode='lines+markers',
        line=dict(color='blue', width=2)
    ))
    fig_compare.add_trace(go.Scatter(
        x=compare_df.index,
        y=compare_df['预测值'],
        name='预测值',
        mode='lines+markers',
        line=dict(color='red', width=2, dash='dash')
    ))
    fig_compare.update_layout(
        title=f'测试集: 实际米重 vs 预测米重 ({model_type})',
        xaxis=dict(title='样本索引'),
        yaxis=dict(title='米重 (Kg/m)'),
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        height=400
    )
    st.plotly_chart(fig_compare, width='stretch')

    # --- Residuals ---
    st.subheader("📉 残差分析")
    residuals = y_test_actual - y_pred

    fig_residual = go.Figure()
    fig_residual.add_trace(go.Scatter(
        x=y_pred,
        y=residuals,
        mode='markers',
        marker=dict(color='green', size=8, opacity=0.6)
    ))
    # Zero line for reference.
    fig_residual.add_shape(
        type="line",
        x0=y_pred.min(),
        y0=0,
        x1=y_pred.max(),
        y1=0,
        line=dict(color="red", width=2, dash="dash")
    )
    fig_residual.update_layout(
        title='残差图',
        xaxis=dict(title='预测值'),
        yaxis=dict(title='残差'),
        height=400
    )
    st.plotly_chart(fig_residual, width='stretch')

    # --- Persist the model ---
    # NOTE(review): this runs on every script rerun while cached data exists,
    # so each widget interaction retrains and writes a new file. Confirm
    # whether saving should be gated behind an explicit user action.
    st.subheader("💾 模型保存")

    model_dir = "saved_models"
    os.makedirs(model_dir, exist_ok=True)

    model_info = {
        'model': model,
        'features': features,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'model_type': model_type,
        'sequence_length': sequence_length,
        'created_at': datetime.now(),
        'r2_score': r2,
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'use_steady_data': use_steady_data
    }

    model_filename = f"deep_{model_type.lower()}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.joblib"
    model_path = os.path.join(model_dir, model_filename)

    # NOTE(review): joblib pickles the whole nn.Module (on whatever device it
    # lives on). Loading needs the same class definitions, and pickle files
    # must only be loaded from trusted sources.
    joblib.dump(model_info, model_path)

    st.success(f"模型已成功保存: {model_filename}")
    st.info(f"保存路径: {model_path}")


def show_metered_weight_deep_learning():
    """Render the meter-weight deep-learning prediction page.

    The page lets the user pick a date range and model settings, loads
    extruder / main-process / temperature data through the two services,
    caches it in ``st.session_state``, detects steady-state rows, trains the
    selected recurrent model on the merged data and shows evaluation charts,
    a data preview and a CSV export. Takes no arguments and returns None.
    """
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()

    st.title("米重深度学习预测")

    # Seed session-state defaults once; existing values are left untouched.
    session_defaults = {
        'mdl_start_date': datetime.now().date() - timedelta(days=7),
        'mdl_end_date': datetime.now().date(),
        'mdl_quick_select': "最近7天",
        'mdl_model_type': 'LSTM',
        'mdl_sequence_length': 10,
        'mdl_time_offset': 0,
        'mdl_product_variety': 'all',
        'mdl_filter_transient': True,
    }
    for state_key, default_value in session_defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value

    # Feature columns fed to the model.
    default_features = ['螺杆转速', '机头压力', '流程主速', '螺杆温度',
                        '后机筒温度', '前机筒温度', '机头温度']

    # Quick-select label -> number of days to look back.
    quick_ranges = {"今天": 0, "最近3天": 3, "最近7天": 7, "最近30天": 30}

    def update_dates(qs):
        """Store the quick-select choice and, for preset ranges, the matching dates."""
        st.session_state['mdl_quick_select'] = qs
        days_back = quick_ranges.get(qs)
        if days_back is None:
            return  # "自定义" keeps whatever dates are currently set
        today = datetime.now().date()
        st.session_state['mdl_start_date'] = today - timedelta(days=days_back)
        st.session_state['mdl_end_date'] = today

    def on_date_change():
        """A manual date edit switches the quick-select state to custom."""
        st.session_state['mdl_quick_select'] = "自定义"

    # NOTE(review): the original indentation was lost; the advanced and
    # steady-state settings are assumed to live inside this expander.
    with st.expander("🔍 查询配置", expanded=True):
        st.markdown(_MDL_EXPANDER_CSS, unsafe_allow_html=True)

        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])

        options = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
        for col, option in zip(cols, options):
            with col:
                # Highlight the currently selected range.
                button_type = "primary" if st.session_state['mdl_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_mdl_{option}", width='stretch', type=button_type):
                    update_dates(option)
                    st.rerun()

        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="mdl_start_date",
                on_change=on_date_change
            )

        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="mdl_end_date",
                on_change=on_date_change
            )

        with cols[7]:
            query_button = st.button("🚀 开始分析", key="mdl_query", width='stretch')

        # Advanced settings
        st.markdown("---")
        advanced_cols = st.columns(2)

        with advanced_cols[0]:
            st.write("🤖 **模型配置**")
            if use_deep_learning:
                st.selectbox(
                    "模型类型",
                    options=['LSTM', 'GRU', 'BiLSTM'],
                    key="mdl_model_type",
                    help="选择用于预测的深度学习模型类型"
                )

                # The initial value comes from the session-state entry seeded
                # above; passing value= as well makes Streamlit warn about a
                # duplicated default.
                st.slider(
                    "序列长度",
                    min_value=5,
                    max_value=30,
                    step=1,
                    help="用于深度学习模型的时间序列长度",
                    key="mdl_sequence_length"
                )
            else:
                st.warning("未检测到PyTorch,无法使用深度学习模型")

        with advanced_cols[1]:
            st.write("⏱️ **时间延迟配置**")
            # Initial value is taken from session state (see above).
            st.slider(
                "挤出数据向后偏移 (分钟)",
                min_value=0,
                max_value=60,
                step=1,
                help="由于胎面从挤出到称重需要时间,将挤出机数据向后移动,使其与米重数据在时间轴上对齐。偏移量会影响预测准确性。",
                key="mdl_time_offset"
            )

        # Steady-state detection settings
        st.markdown("---")
        steady_cols = st.columns(3)
        with steady_cols[0]:
            st.write("⚖️ **稳态识别配置**")
            use_steady_data = st.checkbox(
                "仅使用稳态数据进行训练",
                value=True,
                key="mdl_use_steady_data",
                help="启用后,只使用米重稳态时段的数据进行模型训练和预测"
            )

        with steady_cols[1]:
            st.write("📏 **稳态参数**")
            steady_window = st.slider(
                "滑动窗口大小 (秒)",
                min_value=5,
                max_value=60,
                value=20,
                step=5,
                key="mdl_steady_window",
                help="用于稳态识别的滑动窗口大小"
            )

        with steady_cols[2]:
            st.write("📊 **稳态阈值**")
            steady_threshold = st.slider(
                "波动阈值 (%)",
                min_value=0.1,
                max_value=2.0,
                value=0.5,
                step=0.1,
                key="mdl_steady_threshold",
                help="稳态识别的波动范围阈值"
            )

    # Whole-day bounds for the query.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())

    # Fetch and cache on demand.
    if query_button:
        with st.spinner("正在获取数据..."):
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)

            if not _mdl_has_any_rows(df_extruder_full, df_main_speed, df_temp):
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                return

            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt

    # Analysis runs whenever cached data is present (also on later reruns).
    if all(key in st.session_state for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp']):
        with st.spinner("正在分析数据..."):
            df_extruder_full = st.session_state['cached_extruder_full']
            df_main_speed = st.session_state['cached_main_speed']
            df_temp = st.session_state['cached_temp']

            if not _mdl_has_any_rows(df_extruder_full, df_main_speed, df_temp):
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                return

            df_analysis = _mdl_integrate_data(
                df_extruder_full, df_main_speed, df_temp, st.session_state['mdl_time_offset']
            )

            if df_analysis is None or df_analysis.empty:
                st.warning("数据整合失败,请检查数据质量或调整时间范围。")
                return

            df_analysis.rename(columns={'metered_weight': '米重'}, inplace=True)

            # Steady-state detection adds the is_steady column.
            steady_detector = SteadyStateDetector()
            df_analysis, steady_segments = steady_detector.detect_steady_state(
                df_analysis,
                weight_col='米重',
                window_size=steady_window,
                std_threshold=steady_threshold
            )

            st.subheader("📊 深度学习预测分析")

            if use_deep_learning:
                missing_features = [f for f in default_features if f not in df_analysis.columns]
                if missing_features:
                    st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
                else:
                    combined = df_analysis[default_features + ['米重', 'is_steady']].copy()

                    # Optionally keep only the steady rows.
                    if use_steady_data:
                        combined = combined[combined['is_steady'] == 1]
                        st.info(f"已过滤非稳态数据,使用 {len(combined)} 条稳态数据进行训练")

                    combined_clean = combined.dropna()

                    if len(combined_clean) < 30:
                        st.warning("数据量不足,无法进行有效的预测分析")
                        if use_steady_data:
                            st.info("建议:尝试调整稳态识别参数或禁用'仅使用稳态数据'选项")
                    else:
                        _mdl_render_steady_overview(df_analysis, len(combined_clean))
                        _mdl_train_and_report(combined_clean, default_features, use_steady_data)
            else:
                st.warning("未检测到PyTorch,无法使用深度学习预测功能。请确保已正确安装PyTorch库。")

            # --- Data preview ---
            st.subheader("🔍 数据预览")
            st.dataframe(df_analysis.head(20), width='stretch')

            # --- Export ---
            st.subheader("💾 导出数据")
            csv = df_analysis.to_csv(index=False)
            st.download_button(
                label="导出整合后的数据 (CSV)",
                data=csv,
                file_name=f"metered_weight_deep_learning_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                mime="text/csv",
                help="点击按钮导出整合后的米重分析数据"
            )

    else:
        # Nothing cached yet: ask the user to run a query.
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
|