import streamlit as st
|
import plotly.express as px
|
import plotly.graph_objects as go
|
import pandas as pd
|
import numpy as np
|
from datetime import datetime, timedelta
|
from app.services.extruder_service import ExtruderService
|
from app.services.main_process_service import MainProcessService
|
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
from sklearn.model_selection import train_test_split
|
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
|
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
|
from sklearn.svm import SVR
|
from sklearn.neural_network import MLPRegressor
|
|
|
|
|
# Name of the prediction target once 'metered_weight' has been renamed.
_TARGET_COLUMN = '米重'

# Fixed model inputs (the page does not let the user choose features).
_DEFAULT_FEATURES = ['螺杆转速', '机头压力', '流程主速', '螺杆温度',
                     '后机筒温度', '前机筒温度', '机头温度']

# Temperature columns taken from the temperature-control frame.
_TEMPERATURE_COLUMNS = [
    'nakata_extruder_screw_display_temp',
    'nakata_extruder_rear_barrel_display_temp',
    'nakata_extruder_front_barrel_display_temp',
    'nakata_extruder_head_display_temp',
]

# Raw column name -> display name used by the rest of the page.
_RENAMED_COLUMNS = {
    'metered_weight': _TARGET_COLUMN,
    'screw_speed_actual': '螺杆转速',
    'head_pressure': '机头压力',
    'process_main_speed': '流程主速',
    'nakata_extruder_screw_display_temp': '螺杆温度',
    'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
    'nakata_extruder_front_barrel_display_temp': '前机筒温度',
    'nakata_extruder_head_display_temp': '机头温度',
}

# Quick-range buttons, in display order, and how many days back each one goes.
# "自定义" has no entry in the day table: it only changes the highlighted button.
_QUICK_RANGE_OPTIONS = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
_QUICK_RANGE_DAYS = {"今天": 0, "最近3天": 3, "最近7天": 7, "最近30天": 30}

# Session-state keys holding the last query result and its time range.
_CACHE_KEYS = ['cached_extruder_full', 'cached_main_speed', 'cached_temp']
_QUERY_RANGE_KEYS = ['last_query_start', 'last_query_end']

_MODEL_OPTIONS = ['RandomForest', 'GradientBoosting', 'SVR', 'MLP']
# Models trained on standardized inputs and a min-max scaled target.
_SCALED_MODELS = ('SVR', 'MLP')
# Models that expose `feature_importances_`.
_TREE_MODELS = ('RandomForest', 'GradientBoosting')
# Constructors are wrapped in lambdas so a fresh estimator is built per run.
_MODEL_FACTORIES = {
    'RandomForest': lambda: RandomForestRegressor(n_estimators=100, random_state=42),
    'GradientBoosting': lambda: GradientBoostingRegressor(n_estimators=100, random_state=42),
    'SVR': lambda: SVR(kernel='rbf', C=1.0, gamma='scale'),
    'MLP': lambda: MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42),
}

# Minimum number of complete rows required before a model is trained.
_MIN_TRAINING_ROWS = 30

# CSS that lets the columns inside the expander wrap on narrow screens and
# gives the two date-input columns (6th and 7th) extra width on wide screens.
_QUERY_PANEL_CSS = """
<style>
/* Force the column containers to wrap */
[data-testid="stExpander"] [data-testid="column"] {
    flex: 1 1 120px !important;
    min-width: 120px !important;
}
/* Make the date-input columns a little wider */
@media (min-width: 768px) {
    [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
    [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
        flex: 2 1 180px !important;
        min-width: 180px !important;
    }
}
</style>
"""


def _has_rows(df):
    """Return True when *df* is a non-empty frame (i.e. not None and not empty)."""
    return df is not None and not df.empty


def _init_session_defaults():
    """Seed the page's session-state keys on first load, keeping existing values."""
    today = datetime.now().date()
    defaults = {
        'ma_start_date': today - timedelta(days=7),
        'ma_end_date': today,
        'ma_quick_select': "最近7天",
        'ma_model_type': 'RandomForest',
        'ma_sequence_length': 10,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _apply_quick_range(qs):
    """Record the chosen quick-range option and update the stored date range.

    Options without an entry in ``_QUICK_RANGE_DAYS`` (e.g. "自定义") only
    change the highlighted button; the stored dates are left as they are.
    """
    st.session_state['ma_quick_select'] = qs
    days = _QUICK_RANGE_DAYS.get(qs)
    if days is None:
        return
    today = datetime.now().date()
    st.session_state['ma_start_date'] = today - timedelta(days=days)
    st.session_state['ma_end_date'] = today


def _mark_custom_range():
    """Date-input ``on_change`` callback: a manual edit switches to "自定义"."""
    st.session_state['ma_quick_select'] = "自定义"


def _render_query_panel():
    """Render the query/model configuration expander.

    Returns:
        Tuple ``(start_date, end_date, query_button, model_type)`` holding the
        two date-input values, whether the analyse button was clicked in this
        run, and the selected model type.
    """
    with st.expander("🔍 查询配置", expanded=True):
        st.markdown(_QUERY_PANEL_CSS, unsafe_allow_html=True)

        # Five quick-range buttons, two date inputs, one analyse button.
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])

        for col, option in zip(cols, _QUICK_RANGE_OPTIONS):
            with col:
                # Highlight the currently active quick range.
                is_active = st.session_state['ma_quick_select'] == option
                button_type = "primary" if is_active else "secondary"
                if st.button(option, key=f"btn_ma_{option}", width='stretch', type=button_type):
                    # The date widgets are created further down, so their
                    # session-state keys may still be written in this run.
                    _apply_quick_range(option)
                    st.rerun()

        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="ma_start_date",
                on_change=_mark_custom_range,
            )

        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="ma_end_date",
                on_change=_mark_custom_range,
            )

        with cols[7]:
            query_button = st.button("🚀 开始分析", key="ma_query", width='stretch')

        # Model configuration.
        st.markdown("---")
        st.write("🤖 **模型配置**")
        model_cols = st.columns(2)

        with model_cols[0]:
            model_type = st.selectbox(
                "模型类型",
                options=_MODEL_OPTIONS,
                key="ma_model_type",
                help="选择用于预测的模型类型",
            )

    return start_date, end_date, query_button, model_type


def _integrate_data(df_extruder_full, df_main_speed, df_temp):
    """Join the three source frames on 'time' and rename columns for display.

    The extruder frame is the base. Main-speed and temperature readings are
    attached with ``pd.merge_asof`` (nearest timestamp, 1 minute tolerance)
    when their frames are non-empty. Rows without a metered weight are dropped.

    Returns:
        The merged frame with display column names, or None when the extruder
        frame is None or empty.
    """
    if not _has_rows(df_extruder_full):
        return None

    df_merged = df_extruder_full[
        ['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']
    ].copy()

    if _has_rows(df_main_speed):
        df_merged = pd.merge_asof(
            df_merged.sort_values('time'),
            df_main_speed[['time', 'process_main_speed']].sort_values('time'),
            on='time',
            direction='nearest',
            tolerance=pd.Timedelta('1min'),
        )

    if _has_rows(df_temp):
        df_merged = pd.merge_asof(
            df_merged.sort_values('time'),
            df_temp[['time'] + _TEMPERATURE_COLUMNS].sort_values('time'),
            on='time',
            direction='nearest',
            tolerance=pd.Timedelta('1min'),
        )

    # Drop rows without a target value before switching to display names.
    df_merged = df_merged.dropna(subset=['metered_weight'])
    return df_merged.rename(columns=_RENAMED_COLUMNS)


def _overlay_axis(title, color, side, position=None):
    """Build the layout dict for a colored y-axis overlaid on the primary one.

    When *position* is given the axis is detached (``anchor='free'``) and
    placed at that fraction of the plot width.
    """
    axis = dict(
        title=title,
        title_font=dict(color=color),
        tickfont=dict(color=color),
        overlaying='y',
        side=side,
    )
    if position is not None:
        axis.update(anchor='free', position=position)
    return axis


def _build_trend_figure(df_extruder_full, df_main_speed, df_temp):
    """Build the multi-axis line chart of the raw (un-merged) source series."""
    fig_trend = go.Figure()

    def add_line(df, column, name, color, width, yaxis=None):
        # Each series gets its own y-axis so differing scales stay readable.
        trace_kwargs = dict(
            x=df['time'],
            y=df[column],
            name=name,
            mode='lines',
            line=dict(color=color, width=width),
        )
        if yaxis is not None:
            trace_kwargs['yaxis'] = yaxis
        fig_trend.add_trace(go.Scatter(**trace_kwargs))

    if _has_rows(df_extruder_full):
        add_line(df_extruder_full, 'metered_weight', '米重 (Kg/m)', 'blue', 2)
        add_line(df_extruder_full, 'screw_speed_actual', '螺杆转速 (RPM)', 'green', 1.5, 'y2')
        add_line(df_extruder_full, 'head_pressure', '机头压力', 'orange', 1.5, 'y3')

    if _has_rows(df_main_speed):
        add_line(df_main_speed, 'process_main_speed', '流程主速 (M/Min)', 'red', 1.5, 'y4')

    if _has_rows(df_temp):
        # Only the screw temperature is drawn on the trend chart.
        add_line(df_temp, 'nakata_extruder_screw_display_temp', '螺杆温度 (°C)', 'purple', 1, 'y5')

    fig_trend.update_layout(
        title='原始数据趋势',
        xaxis=dict(
            title='时间',
            rangeslider=dict(visible=True),
            type='date',
        ),
        yaxis=dict(
            title='米重 (Kg/m)',
            title_font=dict(color='blue'),
            tickfont=dict(color='blue'),
        ),
        yaxis2=_overlay_axis('螺杆转速 (RPM)', 'green', 'right'),
        yaxis3=_overlay_axis('机头压力', 'orange', 'right', position=0.85),
        yaxis4=_overlay_axis('流程主速 (M/Min)', 'red', 'right', position=0.75),
        yaxis5=_overlay_axis('温度 (°C)', 'purple', 'left', position=0.15),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        height=600,
        margin=dict(l=100, r=200, t=100, b=100),
        hovermode='x unified',
    )
    return fig_trend


def _fit_predictor(model_type, X_train, y_train):
    """Train the selected regressor and return ``(model, predict)``.

    ``predict`` takes a feature frame in original units and returns
    predictions in original target units. For the models listed in
    ``_SCALED_MODELS`` it applies the fitted input scaler before predicting
    and inverts the target scaling afterwards.

    Raises:
        ValueError: if *model_type* is not one of the supported models.
    """
    factory = _MODEL_FACTORIES.get(model_type)
    if factory is None:
        raise ValueError(f"Unsupported model type: {model_type}")
    model = factory()

    if model_type not in _SCALED_MODELS:
        model.fit(X_train, y_train)
        return model, model.predict

    scaler_X = StandardScaler()
    scaler_y = MinMaxScaler()
    model.fit(
        scaler_X.fit_transform(X_train),
        scaler_y.fit_transform(y_train.values.reshape(-1, 1)).ravel(),
    )

    def predict(features):
        scaled = model.predict(scaler_X.transform(features))
        return scaler_y.inverse_transform(scaled.reshape(-1, 1)).ravel()

    return model, predict


def _render_model_evaluation(model_type, time_test, y_test, y_pred):
    """Show the test-set metrics, the actual-vs-predicted chart and the residual plot."""
    actual = y_test.to_numpy()

    r2 = r2_score(actual, y_pred)
    mse = mean_squared_error(actual, y_pred)
    mae = mean_absolute_error(actual, y_pred)
    rmse = np.sqrt(mse)

    metrics_cols = st.columns(2)
    with metrics_cols[0]:
        st.metric("R² 得分", f"{r2:.4f}")
        st.metric("均方误差 (MSE)", f"{mse:.6f}")
    with metrics_cols[1]:
        st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
        st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")

    # --- Actual vs. predicted ---
    st.subheader("🔄 实际值与预测值对比")

    # The test rows are a shuffled sample; order them by timestamp so the
    # x-axis really is the time axis its title announces.
    compare_df = pd.DataFrame({
        '时间': time_test.to_numpy(),
        '实际值': actual,
        '预测值': y_pred,
    }).sort_values('时间')

    fig_compare = go.Figure()
    fig_compare.add_trace(go.Scatter(
        x=compare_df['时间'],
        y=compare_df['实际值'],
        name='实际值',
        mode='lines+markers',
        line=dict(color='blue', width=2),
    ))
    fig_compare.add_trace(go.Scatter(
        x=compare_df['时间'],
        y=compare_df['预测值'],
        name='预测值',
        mode='lines+markers',
        line=dict(color='red', width=2, dash='dash'),
    ))
    fig_compare.update_layout(
        title=f'测试集: 实际米重 vs 预测米重 ({model_type})',
        xaxis=dict(title='时间'),
        yaxis=dict(title='米重 (Kg/m)'),
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        height=400,
    )
    st.plotly_chart(fig_compare, width='stretch')

    # --- Residuals ---
    st.subheader("📉 残差分析")

    residuals = actual - y_pred

    fig_residual = go.Figure()
    fig_residual.add_trace(go.Scatter(
        x=y_pred,
        y=residuals,
        mode='markers',
        marker=dict(color='green', size=8, opacity=0.6),
    ))
    # Zero line: points above/below it are under-/over-predictions.
    fig_residual.add_shape(
        type="line",
        x0=y_pred.min(),
        y0=0,
        x1=y_pred.max(),
        y1=0,
        line=dict(color="red", width=2, dash="dash"),
    )
    fig_residual.update_layout(
        title='残差图',
        xaxis=dict(title='预测值'),
        yaxis=dict(title='残差'),
        height=400,
    )
    st.plotly_chart(fig_residual, width='stretch')


def _render_feature_importance(model, feature_names):
    """Show a bar chart of the fitted model's ``feature_importances_``."""
    st.subheader("⚖️ 特征重要性分析")

    feature_importance = pd.DataFrame({
        '特征': feature_names,
        '重要性': model.feature_importances_,
    }).sort_values('重要性', ascending=False)

    fig_importance = px.bar(
        feature_importance,
        x='特征',
        y='重要性',
        title='特征重要性',
        color='重要性',
        color_continuous_scale='viridis',
    )
    fig_importance.update_layout(
        xaxis=dict(tickangle=-45),
        height=400,
    )
    st.plotly_chart(fig_importance, width='stretch')


def _render_prediction_form(df_clean, predict):
    """Render one numeric input per feature and predict the weight on demand."""
    st.subheader("🔮 米重预测")
    st.write("输入特征值进行米重预测:")

    predict_cols = st.columns(2)
    input_features = {}

    for i, feature in enumerate(_DEFAULT_FEATURES):
        with predict_cols[i % 2]:
            min_val = float(df_clean[feature].min())
            max_val = float(df_clean[feature].max())
            # Floating-point rounding can push the mean of a (near-)constant
            # column marginally outside [min, max], which number_input
            # rejects; clamp the default into the allowed range.
            mean_val = min(max(float(df_clean[feature].mean()), min_val), max_val)

            input_features[feature] = st.number_input(
                f"{feature}",
                key=f"ma_pred_{feature}",
                value=mean_val,
                min_value=min_val,
                max_value=max_val,
                step=0.1,
            )

    if st.button("预测米重"):
        # Single-row frame whose column order follows _DEFAULT_FEATURES,
        # i.e. the order used for training.
        input_df = pd.DataFrame([input_features])
        predicted_weight = predict(input_df)[0]
        st.success(f"预测米重: {predicted_weight:.4f} Kg/m")


def _render_prediction_analysis(df_analysis, model_type):
    """Train the selected model on *df_analysis* and render all model output.

    Shows a warning and stops when a required feature column is missing or
    fewer than ``_MIN_TRAINING_ROWS`` complete rows are available. Any
    exception raised during training or rendering is reported with
    ``st.error`` instead of propagating.
    """
    st.subheader("📊 高级预测分析")

    missing_features = [f for f in _DEFAULT_FEATURES if f not in df_analysis.columns]
    if missing_features:
        st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
        return

    try:
        # Keep only rows where every feature and the target are present.
        df_clean = df_analysis.dropna(subset=_DEFAULT_FEATURES + [_TARGET_COLUMN])
        if len(df_clean) < _MIN_TRAINING_ROWS:
            st.warning("数据量不足,无法进行有效的预测分析")
            return

        X_final = df_clean[_DEFAULT_FEATURES]
        y_final = df_clean[_TARGET_COLUMN]

        # Timestamps are split together with X/y so each test row keeps its
        # time for the comparison chart.
        X_train, X_test, y_train, y_test, _, time_test = train_test_split(
            X_final, y_final, df_clean['time'], test_size=0.2, random_state=42
        )

        model, predict = _fit_predictor(model_type, X_train, y_train)
        y_pred = predict(X_test)

        _render_model_evaluation(model_type, time_test, y_test, y_pred)

        if model_type in _TREE_MODELS:
            _render_feature_importance(model, X_train.columns)

        _render_prediction_form(df_clean, predict)
    except Exception as e:
        # UI boundary: show the failure on the page rather than crashing it.
        st.error(f"模型训练或预测失败: {str(e)}")


def _render_preview_and_export(df_analysis):
    """Show the first 20 integrated rows and offer the full frame as a CSV download."""
    st.subheader("🔍 数据预览")
    st.dataframe(df_analysis.head(20), width='stretch')

    st.subheader("💾 导出数据")
    csv = df_analysis.to_csv(index=False)
    st.download_button(
        label="导出整合后的数据 (CSV)",
        data=csv,
        file_name=f"metered_weight_advanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
        mime="text/csv",
        help="点击按钮导出整合后的米重分析数据",
    )


def show_metered_weight_advanced():
    """Render the advanced metered-weight prediction page.

    Flow:
      1. Render the query/model configuration panel.
      2. When the analyse button is clicked, fetch extruder, main-speed and
         temperature data for the selected range and cache it in
         ``st.session_state``.
      3. Whenever cached data exists, integrate it, plot the raw trends,
         train the selected model, and render the evaluation charts, the
         prediction form, a data preview and a CSV export.

    The service calls are assumed to return a pandas DataFrame or None
    (NOTE(review): inferred from the None/empty checks; confirm against the
    service implementations).
    """
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()

    st.title("米重高级预测分析")

    _init_session_defaults()
    start_date, end_date, query_button, model_type = _render_query_panel()

    # Expand the selected dates to cover the whole first and last day.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())

    if query_button:
        with st.spinner("正在获取数据..."):
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)

            if not any(_has_rows(df) for df in (df_extruder_full, df_main_speed, df_temp)):
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                # Drop stale results so the previous analysis is not shown
                # for a range that returned nothing.
                for key in _CACHE_KEYS + _QUERY_RANGE_KEYS:
                    if key in st.session_state:
                        del st.session_state[key]
                return

            # Cache the result so widget-triggered reruns (e.g. the predict
            # button) can reuse it without querying again.
            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt

    if not all(key in st.session_state for key in _CACHE_KEYS):
        # Nothing has been queried yet.
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
        return

    with st.spinner("正在分析数据..."):
        df_extruder_full = st.session_state['cached_extruder_full']
        df_main_speed = st.session_state['cached_main_speed']
        df_temp = st.session_state['cached_temp']

        if not any(_has_rows(df) for df in (df_extruder_full, df_main_speed, df_temp)):
            st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
            return

        df_analysis = _integrate_data(df_extruder_full, df_main_speed, df_temp)
        if df_analysis is None or df_analysis.empty:
            st.warning("数据整合失败,请检查数据质量或调整时间范围。")
            return

        # --- Raw data trends ---
        st.subheader("📈 原始数据趋势图")
        fig_trend = _build_trend_figure(df_extruder_full, df_main_speed, df_temp)
        st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True})

        # --- Model training, evaluation and interactive prediction ---
        _render_prediction_analysis(df_analysis, model_type)

        # --- Preview and export; these depend only on the integrated data ---
        _render_preview_and_export(df_analysis)
|