baoshiwei
2026-02-02 4048393750de17cfa2ae59fec1380a81ea2b2a6b
feat: 添加米重分析模块并优化综合看板功能

- 新增米重综合分析、相关性分析、回归分析和高级预测分析页面
- 在综合看板中添加时间偏移功能以对齐上下游数据
- 优化图表交互功能,添加统一悬停模式和缩放控制
- 更新数据库配置和依赖项,添加scikit-learn和pytorch
- 改进数据查询逻辑,添加裁切计数字段
- 修复数据展示问题,调整单位显示和过滤异常值
已修改9个文件
已添加6个文件
4407 ■■■■■ 文件已修改
.env 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md 22 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/comprehensive_dashboard.py 119 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/main_process_dashboard.py 39 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_advanced copy.py 832 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_advanced pytorch.py 873 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_advanced.py 635 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_correlation.py 777 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_dashboard.py 430 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/metered_weight_regression.py 620 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/pages/sorting_dashboard.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/services/data_processing_service.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
app/services/main_process_service.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
dashboard.py 36 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
requirements.txt 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
.env
@@ -1,6 +1,6 @@
# Database Configuration
DB_HOST=localhost
DB_PORT=5433
DB_HOST=192.168.21.6
DB_PORT=5432
DB_NAME=aics
DB_USER=aics
DB_PASSWORD=123456lb
DB_PASSWORD=123456
README.md
@@ -1 +1,21 @@
本项目采用MIT许可证。
本项目采用MIT许可证。
本地测试----postgresql数据库------------
用户:postgres
密码:123456
数据库:aics
用户名:aics
密码:123456lb
-----------------------------------------
正新项目现场---postgresql-------------
用户:postgres
密码:123456
数据库:aics
用户名:aics
密码:lanbaoit-123
------------------------------------------
app/pages/comprehensive_dashboard.py
@@ -14,7 +14,7 @@
    main_process_service = MainProcessService()
    # é¡µé¢æ ‡é¢˜
    st.title("多维综合分析")
    st.title("条重综合分析")
    # åˆå§‹åŒ–会话状态用于日期同步
    if 'comp_start_date' not in st.session_state:
@@ -23,6 +23,8 @@
        st.session_state['comp_end_date'] = datetime.now().date()
    if 'comp_quick_select' not in st.session_state:
        st.session_state['comp_quick_select'] = "最近7天"
    if 'time_offset' not in st.session_state:
        st.session_state['time_offset'] = 0
    # å®šä¹‰å›žè°ƒå‡½æ•°
    def update_dates(qs):
@@ -93,8 +95,22 @@
                on_change=on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 æŸ¥è¯¢", key="comp_query", width='stretch')
        # åœ¨ç¬¬äºŒè¡Œæ·»åŠ æ—¶é—´åç§»é…ç½®
        st.markdown("---")
        offset_cols = st.columns([2, 4, 2])
        with offset_cols[0]:
            st.write("⏱️ **生产对齐配置**")
        with offset_cols[1]:
            time_offset = st.slider(
                "挤出/主流程数据向后偏移 (分钟)",
                min_value=0,
                max_value=60,
                value=st.session_state['time_offset'],
                help="由于胎面从挤出到分拣需要时间,将上游数据向后移动,使其与分拣磅秤上的重量数据在时间轴上对齐。"
            )
            st.session_state['time_offset'] = time_offset
        with offset_cols[2]:
            query_button = st.button("🚀 å¼€å§‹åˆ†æž", key="comp_query", width='stretch')
    # è½¬æ¢ä¸ºdatetime对象
    start_dt = datetime.combine(start_date, datetime.min.time())
@@ -103,20 +119,32 @@
    # æŸ¥è¯¢å¤„理
    if query_button:
        with st.spinner("正在聚合多源数据..."):
            # 1. èŽ·å–åˆ†æ‹£ç£…ç§¤æ•°æ®
            # èŽ·å–åç§»é‡
            offset_delta = timedelta(minutes=st.session_state['time_offset'])
            # 1. èŽ·å–åˆ†æ‹£ç£…ç§¤æ•°æ® (作为基准,不偏移)
            df_sorting = sorting_service.get_sorting_scale_data(start_dt, end_dt)
            # 2. èŽ·å–æŒ¤å‡ºæœºæ•°æ®
            # 2. èŽ·å–æŒ¤å‡ºæœºæ•°æ® (应用偏移)
            df_extruder = extruder_service.get_extruder_data(start_dt, end_dt)
            # 3. èŽ·å–ä¸»æµç¨‹æŽ§åˆ¶æ•°æ®
            if df_extruder is not None and not df_extruder.empty:
                df_extruder['time'] = df_extruder['time'] + offset_delta
            # 3. èŽ·å–ä¸»æµç¨‹æŽ§åˆ¶æ•°æ® (应用偏移)
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            if df_main_speed is not None and not df_main_speed.empty:
                df_main_speed['time'] = df_main_speed['time'] + offset_delta
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
            if df_temp is not None and not df_temp.empty:
                df_temp['time'] = df_temp['time'] + offset_delta
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_sorting is not None and not df_sorting.empty,
                df_extruder is not None and not df_extruder.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
@@ -207,14 +235,14 @@
            # æ·»åŠ æŒ¤å‡ºæœºç±³é‡
            if df_extruder is not None and not df_extruder.empty:
                fig.add_trace(go.Scatter(
                    x=df_extruder['time'],
                    y=df_extruder['metered_weight'],
                    name='挤出机米重 (g/m)',
                    mode='lines',
                    line=dict(color='green', width=1.5),
                    yaxis='y2'
                ))
                # fig.add_trace(go.Scatter(
                #     x=df_extruder['time'],
                #     y=df_extruder['metered_weight'],
                #     name='挤出机米重 (Kg/m)',
                #     mode='lines',
                #     line=dict(color='green', width=1.5),
                #     yaxis='y2'
                # ))
                # æ·»åŠ æŒ¤å‡ºæœºå®žé™…è½¬é€Ÿ
                fig.add_trace(go.Scatter(
                    x=df_extruder['time'], 
@@ -235,14 +263,26 @@
                    line=dict(color='red', width=1.5),
                    yaxis='y3' # å…±ç”¨é€Ÿåº¦è½´
                ))
                # æ·»åŠ è£åˆ‡è®¡æ•°
                if 'cutting_count' in df_main_speed.columns:
                    fig.add_trace(go.Scatter(
                        x=df_main_speed['time'],
                        y=df_main_speed['cutting_count'],
                        name='裁切计数',
                        mode='lines',
                        line=dict(color='purple', width=1.5),
                        yaxis='y5'
                    ))
            # æ·»åŠ æ¸©åº¦è®¾å®šå€¼
            if df_temp is not None and not df_temp.empty:
                temp_fields = {
                    'nakata_extruder_screw_set_temp': '螺杆设定 (°C)',
                    'nakata_extruder_rear_barrel_set_temp': '后机筒设定 (°C)',
                    'nakata_extruder_front_barrel_set_temp': '前机筒设定 (°C)',
                    'nakata_extruder_head_set_temp': '机头设定 (°C)'
                    'nakata_extruder_screw_display_temp': '螺杆显示 (°C)',
                    'nakata_extruder_rear_barrel_display_temp': '后机筒显示 (°C)',
                    'nakata_extruder_front_barrel_display_temp': '前机筒显示 (°C)',
                    'nakata_extruder_head_display_temp': '机头显示 (°C)'
                }
                colors = ['#FF4B4B', '#FF8C00', '#FFD700', '#DA70D6']
                for i, (field, label) in enumerate(temp_fields.items()):
@@ -257,7 +297,7 @@
            # è®¾ç½®å¤šåæ ‡è½´å¸ƒå±€
            fig.update_layout(
                title='多维综合趋势分析',
                title='条重综合趋势分析',
                xaxis=dict(
                    title='时间',
                    rangeslider=dict(visible=True),
@@ -269,7 +309,7 @@
                    tickfont=dict(color='blue')
                ),
                yaxis2=dict(
                    title='米重 (g/m)',
                    title='米重 (Kg/m)',
                    title_font=dict(color='green'),
                    tickfont=dict(color='green'),
                    overlaying='y',
@@ -293,6 +333,15 @@
                    anchor='free',
                    position=0.15
                ),
                yaxis5=dict(
                    title='裁切计数',
                    title_font=dict(color='purple'),
                    tickfont=dict(color='purple'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.7
                ),
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
@@ -309,21 +358,21 @@
            st.plotly_chart(fig, width='stretch', config={'scrollZoom': True})
            
            # æ•°æ®æ‘˜è¦
            st.subheader("📊 æ•°æ®æ‘˜è¦")
            summary_cols = st.columns(4)
            # st.subheader("📊 æ•°æ®æ‘˜è¦")
            # summary_cols = st.columns(4)
            
            with summary_cols[0]:
                if df_sorting is not None and not df_sorting.empty:
                    st.metric("平均重量", f"{df_sorting['weight'].mean():.2f} kg")
            # with summary_cols[0]:
            #     if df_sorting is not None and not df_sorting.empty:
            #         st.metric("平均重量", f"{df_sorting['weight'].mean():.2f} kg")
            
            with summary_cols[1]:
                if df_extruder is not None and not df_extruder.empty:
                    st.metric("平均米重", f"{df_extruder['metered_weight'].mean():.2f} g/m")
            # with summary_cols[1]:
            #     if df_extruder is not None and not df_extruder.empty:
            #         st.metric("平均米重", f"{df_extruder['metered_weight'].mean():.2f} Kg/m")
            
            with summary_cols[2]:
                if df_main_speed is not None and not df_main_speed.empty:
                    st.metric("平均主速", f"{df_main_speed['process_main_speed'].mean():.2f} M/Min")
            # with summary_cols[2]:
            #     if df_main_speed is not None and not df_main_speed.empty:
            #         st.metric("平均主速", f"{df_main_speed['process_main_speed'].mean():.2f} M/Min")
            
            with summary_cols[3]:
                if df_temp is not None and not df_temp.empty:
                    st.metric("平均螺杆温控", f"{df_temp['nakata_extruder_screw_set_temp'].mean():.1f} Â°C")
            # with summary_cols[3]:
            #     if df_temp is not None and not df_temp.empty:
            #         st.metric("平均螺杆温控", f"{df_temp['nakata_extruder_screw_set_temp'].mean():.1f} Â°C")
app/pages/main_process_dashboard.py
@@ -111,8 +111,19 @@
                fig_speed = px.line(df_speed, x='time', y='process_main_speed', 
                                   title="流程主速度 (M/Min)",
                                   labels={'time': '时间', 'process_main_speed': '主速度 (M/Min)'})
                fig_speed.update_layout(xaxis=dict(rangeslider=dict(visible=True), type='date'))
                st.plotly_chart(fig_speed, width='stretch', config={'scrollZoom': True})
                fig_speed.update_layout(
                    xaxis=dict(rangeslider=dict(visible=True), type='date'),
                                            yaxis=dict(fixedrange=False),
                                            hovermode='x unified',
                    dragmode='zoom',
                )
                st.plotly_chart(fig_speed, width='stretch', config={
                    'scrollZoom': True,
                    'modeBarButtonsToAdd': ['zoom2d', 'zoomIn2d', 'zoomOut2d'],
                    'doubleClick': 'reset',
                    'displayModeBar': True,
                    'toImageButtonOptions': {'format': 'png'}
                })
            else:
                st.info("该时间段内无主速度数据")
@@ -126,9 +137,17 @@
                    title="电机线速 (M/Min)", 
                    xaxis_title="时间", 
                    yaxis_title="线速 (M/Min)",
                    xaxis=dict(rangeslider=dict(visible=True), type='date')
                    xaxis=dict(rangeslider=dict(visible=True), type='date'),
                                            yaxis=dict(fixedrange=False),
                                            hovermode='x unified',
                    dragmode='zoom'
                )
                st.plotly_chart(fig_motor, width='stretch', config={'scrollZoom': True})
                st.plotly_chart(fig_motor, width='stretch', config={
                    'scrollZoom': True,
                    'modeBarButtonsToAdd': ['zoom2d', 'zoomIn2d', 'zoomOut2d'],
                    'doubleClick': 'reset',
                    'displayModeBar': True
                })
            else:
                st.info("该时间段内无电机监控数据")
@@ -165,8 +184,16 @@
                    title="中田挤出机温度 (°C)", 
                    xaxis_title="时间", 
                    yaxis_title="温度 (°C)",
                    xaxis=dict(rangeslider=dict(visible=True), type='date')
                    xaxis=dict(rangeslider=dict(visible=True), type='date'),
                                            yaxis=dict(fixedrange=False),
                                            hovermode='x unified',
                    dragmode='zoom'
                )
                st.plotly_chart(fig_temp, width='stretch', config={'scrollZoom': True})
                st.plotly_chart(fig_temp, width='stretch', config={
                    'scrollZoom': True,
                    'modeBarButtonsToAdd': ['zoom2d', 'zoomIn2d', 'zoomOut2d'],
                    'doubleClick': 'reset',
                    'displayModeBar': True
                })
            else:
                st.info("该时间段内无温度控制数据")
app/pages/metered_weight_advanced copy.py
对比新文件
@@ -0,0 +1,832 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
# å°è¯•导入深度学习库
use_deep_learning = False
try:
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, GRU, Dense, Dropout, Bidirectional
    from tensorflow.keras.optimizers import Adam
    use_deep_learning = True
except ImportError:
    st.warning("未检测到TensorFlow/Keras,深度学习模型将不可用。请安装tensorflow以使用LSTM/GRU模型。")
def show_metered_weight_advanced():
    # åˆå§‹åŒ–服务
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()
    # é¡µé¢æ ‡é¢˜
    st.title("米重高级预测分析")
    # åˆå§‹åŒ–会话状态
    if 'ma_start_date' not in st.session_state:
        st.session_state['ma_start_date'] = datetime.now().date() - timedelta(days=7)
    if 'ma_end_date' not in st.session_state:
        st.session_state['ma_end_date'] = datetime.now().date()
    if 'ma_quick_select' not in st.session_state:
        st.session_state['ma_quick_select'] = "最近7天"
    if 'ma_model_type' not in st.session_state:
        st.session_state['ma_model_type'] = 'RandomForest'
    if 'ma_sequence_length' not in st.session_state:
        st.session_state['ma_sequence_length'] = 10
    # é»˜è®¤ç‰¹å¾åˆ—表(不再允许用户选择)
    default_features = ['螺杆转速', '机头压力', '流程主速', '螺杆温度',
                       '后机筒温度', '前机筒温度', '机头温度']
    # å®šä¹‰å›žè°ƒå‡½æ•°
    def update_dates(qs):
        st.session_state['ma_quick_select'] = qs
        today = datetime.now().date()
        if qs == "今天":
            st.session_state['ma_start_date'] = today
            st.session_state['ma_end_date'] = today
        elif qs == "最近3天":
            st.session_state['ma_start_date'] = today - timedelta(days=3)
            st.session_state['ma_end_date'] = today
        elif qs == "最近7天":
            st.session_state['ma_start_date'] = today - timedelta(days=7)
            st.session_state['ma_end_date'] = today
        elif qs == "最近30天":
            st.session_state['ma_start_date'] = today - timedelta(days=30)
            st.session_state['ma_end_date'] = today
        # æ¸…除之前的缓存数据和分析标志
        for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end', 'analysis_completed']:
            if key in st.session_state:
                del st.session_state[key]
    def on_date_change():
        st.session_state['ma_quick_select'] = "自定义"
        # æ¸…除之前的缓存数据和分析标志
        for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end', 'analysis_completed']:
            if key in st.session_state:
                del st.session_state[key]
    # æŸ¥è¯¢æ¡ä»¶åŒºåŸŸ
    with st.expander("🔍 æŸ¥è¯¢é…ç½®", expanded=True):
        # æ·»åŠ è‡ªå®šä¹‰ CSS å®žçŽ°å“åº”å¼æ¢è¡Œ
        st.markdown("""
            <style>
            /* å¼ºåˆ¶åˆ—容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* é’ˆå¯¹æ—¥æœŸè¾“入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """, unsafe_allow_html=True)
        # åˆ›å»ºå¸ƒå±€
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        options = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
        for i, option in enumerate(options):
            with cols[i]:
                # æ ¹æ®å½“前选择状态决定按钮类型
                button_type = "primary" if st.session_state['ma_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_ma_{option}", width='stretch', type=button_type):
                    update_dates(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="ma_start_date",
                on_change=on_date_change
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="ma_end_date",
                on_change=on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 å¼€å§‹åˆ†æž", key="ma_query", width='stretch')
        # æ¨¡åž‹é…ç½®
        st.markdown("---")
        st.write("🤖 **模型配置**")
        model_cols = st.columns(2)
        with model_cols[0]:
            # æ¨¡åž‹ç±»åž‹é€‰æ‹©
            model_options = ['RandomForest', 'GradientBoosting', 'SVR', 'MLP']
            if use_deep_learning:
                model_options.extend(['LSTM', 'GRU', 'BiLSTM'])
            model_type = st.selectbox(
                "模型类型",
                options=model_options,
                key="ma_model_type",
                help="选择用于预测的模型类型"
            )
        with model_cols[1]:
            # åºåˆ—长度(仅适用于深度学习模型)
            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
                sequence_length = st.slider(
                    "序列长度",
                    min_value=5,
                    max_value=30,
                    value=st.session_state['ma_sequence_length'],
                    step=1,
                    help="用于深度学习模型的时间序列长度",
                    key="ma_sequence_length"
                )
            else:
                st.session_state['ma_sequence_length'] = 10
                st.write("序列长度: 10 (默认,仅适用于深度学习模型)")
    # è½¬æ¢ä¸ºdatetime对象
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())
    # æŸ¥è¯¢å¤„理
    if query_button:
        with st.spinner("正在获取数据..."):
            # 1. èŽ·å–å®Œæ•´çš„æŒ¤å‡ºæœºæ•°æ®
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            # 2. èŽ·å–ä¸»æµç¨‹æŽ§åˆ¶æ•°æ®
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_full is not None and not df_extruder_full.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                # æ¸…除缓存数据
                for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end']:
                    if key in st.session_state:
                        del st.session_state[key]
                return
            # ç¼“存数据到会话状态
            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt
            # è®¾ç½®åˆ†æžå®Œæˆæ ‡å¿—
            st.session_state['analysis_completed'] = True
    # æ•°æ®å¤„理和分析
    if all(key in st.session_state for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp']) and st.session_state.get('analysis_completed', False):
        with st.spinner("正在分析数据..."):
            # èŽ·å–ç¼“å­˜æ•°æ®
            df_extruder_full = st.session_state['cached_extruder_full']
            df_main_speed = st.session_state['cached_main_speed']
            df_temp = st.session_state['cached_temp']
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_full is not None and not df_extruder_full.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                return
            # æ•°æ®æ•´åˆä¸Žé¢„处理
            def integrate_data(df_extruder_full, df_main_speed, df_temp):
                # ç¡®ä¿æŒ¤å‡ºæœºæ•°æ®å­˜åœ¨
                if df_extruder_full is None or df_extruder_full.empty:
                    return None
                # åˆ›å»ºåªåŒ…含米重和时间的主数据集
                df_merged = df_extruder_full[['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']].copy()
                # æ•´åˆä¸»æµç¨‹æ•°æ®
                if df_main_speed is not None and not df_main_speed.empty:
                    df_main_speed = df_main_speed[['time', 'process_main_speed']]
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_main_speed.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # æ•´åˆæ¸©åº¦æ•°æ®
                if df_temp is not None and not df_temp.empty:
                    temp_cols = ['time', 'nakata_extruder_screw_display_temp',
                               'nakata_extruder_rear_barrel_display_temp',
                               'nakata_extruder_front_barrel_display_temp',
                               'nakata_extruder_head_display_temp']
                    df_temp_subset = df_temp[temp_cols].copy()
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_temp_subset.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # é‡å‘½ååˆ—以提高可读性
                df_merged.rename(columns={
                    'screw_speed_actual': '螺杆转速',
                    'head_pressure': '机头压力',
                    'process_main_speed': '流程主速',
                    'nakata_extruder_screw_display_temp': '螺杆温度',
                    'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
                    'nakata_extruder_front_barrel_display_temp': '前机筒温度',
                    'nakata_extruder_head_display_temp': '机头温度'
                }, inplace=True)
                # æ¸…理数据
                df_merged.dropna(subset=['metered_weight'], inplace=True)
                return df_merged
            # æ‰§è¡Œæ•°æ®æ•´åˆ
            df_analysis = integrate_data(df_extruder_full, df_main_speed, df_temp)
            if df_analysis is None or df_analysis.empty:
                st.warning("数据整合失败,请检查数据质量或调整时间范围。")
                return
            # é‡å‘½åç±³é‡åˆ—
            df_analysis.rename(columns={'metered_weight': '米重'}, inplace=True)
            # --- åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾ ---
            st.subheader("📈 åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾")
            # åˆ›å»ºè¶‹åŠ¿å›¾
            fig_trend = go.Figure()
            # æ·»åŠ ç±³é‡æ•°æ®
            if df_extruder_full is not None and not df_extruder_full.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['metered_weight'],
                    name='米重 (Kg/m)',
                    mode='lines',
                    line=dict(color='blue', width=2)
                ))
                # æ·»åŠ èžºæ†è½¬é€Ÿ
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['screw_speed_actual'],
                    name='螺杆转速 (RPM)',
                    mode='lines',
                    line=dict(color='green', width=1.5),
                    yaxis='y2'
                ))
                # æ·»åŠ æœºå¤´åŽ‹åŠ›
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['head_pressure'],
                    name='机头压力',
                    mode='lines',
                    line=dict(color='orange', width=1.5),
                    yaxis='y3'
                ))
            # æ·»åŠ æµç¨‹ä¸»é€Ÿ
            if df_main_speed is not None and not df_main_speed.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_main_speed['time'],
                    y=df_main_speed['process_main_speed'],
                    name='流程主速 (M/Min)',
                    mode='lines',
                    line=dict(color='red', width=1.5),
                    yaxis='y4'
                ))
            # æ·»åŠ æ¸©åº¦æ•°æ®
            if df_temp is not None and not df_temp.empty:
                # èžºæ†æ¸©åº¦
                fig_trend.add_trace(go.Scatter(
                    x=df_temp['time'],
                    y=df_temp['nakata_extruder_screw_display_temp'],
                    name='螺杆温度 (°C)',
                    mode='lines',
                    line=dict(color='purple', width=1),
                    yaxis='y5'
                ))
            # é…ç½®è¶‹åŠ¿å›¾å¸ƒå±€
            fig_trend.update_layout(
                title='原始数据趋势',
                xaxis=dict(
                    title='时间',
                    rangeslider=dict(visible=True),
                    type='date'
                ),
                yaxis=dict(
                    title='米重 (Kg/m)',
                    title_font=dict(color='blue'),
                    tickfont=dict(color='blue')
                ),
                yaxis2=dict(
                    title='螺杆转速 (RPM)',
                    title_font=dict(color='green'),
                    tickfont=dict(color='green'),
                    overlaying='y',
                    side='right'
                ),
                yaxis3=dict(
                    title='机头压力',
                    title_font=dict(color='orange'),
                    tickfont=dict(color='orange'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.85
                ),
                yaxis4=dict(
                    title='流程主速 (M/Min)',
                    title_font=dict(color='red'),
                    tickfont=dict(color='red'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.75
                ),
                yaxis5=dict(
                    title='温度 (°C)',
                    title_font=dict(color='purple'),
                    tickfont=dict(color='purple'),
                    overlaying='y',
                    side='left',
                    anchor='free',
                    position=0.15
                ),
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1
                ),
                height=600,
                margin=dict(l=100, r=200, t=100, b=100),
                hovermode='x unified'
            )
            # æ˜¾ç¤ºè¶‹åŠ¿å›¾
            st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True})
            # --- é«˜çº§é¢„测分析 ---
            st.subheader("📊 é«˜çº§é¢„测分析")
            # æ£€æŸ¥æ‰€æœ‰é»˜è®¤ç‰¹å¾æ˜¯å¦åœ¨æ•°æ®ä¸­
            missing_features = [f for f in default_features if f not in df_analysis.columns]
            if missing_features:
                st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
            else:
                try:
                    # å‡†å¤‡æ•°æ®
                    # é¦–先确保df_analysis中没有NaN值
                    df_analysis_clean = df_analysis.dropna(subset=default_features + ['米重'])
                    # æ£€æŸ¥æ¸…理后的数据量
                    if len(df_analysis_clean) < 30:
                        st.warning("数据量不足,无法进行有效的预测分析")
                    else:
                        # åˆ›å»ºä¸€ä¸ªæ–°çš„DataFrame来存储所有特征和目标变量
                        all_features = df_analysis_clean[default_features + ['米重']].copy()
                        # æ·»åŠ æ—¶é—´ç›¸å…³ç‰¹å¾
                        if 'time' in df_analysis_clean.columns:
                            all_features['hour'] = df_analysis_clean['time'].dt.hour
                            all_features['minute'] = df_analysis_clean['time'].dt.minute
                            all_features['second'] = df_analysis_clean['time'].dt.second
                            all_features['time_of_day'] = all_features['hour'] * 3600 + all_features['minute'] * 60 + all_features['second']
                        else:
                            all_features['hour'] = 0
                            all_features['minute'] = 0
                            all_features['second'] = 0
                            all_features['time_of_day'] = 0
                        # æ·»åŠ æ»žåŽç‰¹å¾
                        for feature in default_features:
                            for lag in [1, 2, 3]:
                                all_features[f'{feature}_lag{lag}'] = all_features[feature].shift(lag)
                                all_features[f'{feature}_diff{lag}'] = all_features[feature].diff(lag)
                        # æ·»åŠ æ»šåŠ¨ç»Ÿè®¡ç‰¹å¾
                        for feature in default_features:
                            all_features[f'{feature}_rolling_mean'] = all_features[feature].rolling(window=5).mean()
                            all_features[f'{feature}_rolling_std'] = all_features[feature].rolling(window=5).std()
                            all_features[f'{feature}_rolling_min'] = all_features[feature].rolling(window=5).min()
                            all_features[f'{feature}_rolling_max'] = all_features[feature].rolling(window=5).max()
                        # æ¸…理所有NaN值
                        all_features_clean = all_features.dropna()
                        # æ£€æŸ¥æ¸…理后的数据量
                        if len(all_features_clean) < 20:
                            st.warning("特征工程后数据量不足,无法进行有效的预测分析")
                        else:
                            # åˆ†ç¦»ç‰¹å¾å’Œç›®æ ‡å˜é‡
                            feature_columns = [col for col in all_features_clean.columns if col != '米重']
                            X_final = all_features_clean[feature_columns]
                            y_final = all_features_clean['米重']
                            # æ£€æŸ¥æœ€ç»ˆæ•°æ®é‡
                            if len(X_final) >= 20:
                                # åˆ†å‰²è®­ç»ƒé›†å’Œæµ‹è¯•集
                                X_train, X_test, y_train, y_test = train_test_split(X_final, y_final, test_size=0.2, random_state=42)
                                # æ•°æ®æ ‡å‡†åŒ–
                                scaler_X = StandardScaler()
                                scaler_y = MinMaxScaler()
                                X_train_scaled = scaler_X.fit_transform(X_train)
                                X_test_scaled = scaler_X.transform(X_test)
                                y_train_scaled = scaler_y.fit_transform(y_train.values.reshape(-1, 1)).ravel()
                                y_test_scaled = scaler_y.transform(y_test.values.reshape(-1, 1)).ravel()
                                # æ¨¡åž‹è®­ç»ƒ
                                model = None
                                y_pred = None
                                if model_type == 'RandomForest':
                                    # éšæœºæ£®æž—回归
                                    model = RandomForestRegressor(n_estimators=100, random_state=42)
                                    model.fit(X_train, y_train)
                                    y_pred = model.predict(X_test)
                                elif model_type == 'GradientBoosting':
                                    # æ¢¯åº¦æå‡å›žå½’
                                    model = GradientBoostingRegressor(n_estimators=100, random_state=42)
                                    model.fit(X_train, y_train)
                                    y_pred = model.predict(X_test)
                                elif model_type == 'SVR':
                                    # æ”¯æŒå‘量回归
                                    model = SVR(kernel='rbf', C=1.0, gamma='scale')
                                    model.fit(X_train_scaled, y_train_scaled)
                                    y_pred_scaled = model.predict(X_test_scaled)
                                    y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                                elif model_type == 'MLP':
                                    # å¤šå±‚感知器回归
                                    model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
                                    model.fit(X_train_scaled, y_train_scaled)
                                    y_pred_scaled = model.predict(X_test_scaled)
                                    y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                                elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                    # å‡†å¤‡æ—¶é—´åºåˆ—数据
                                    sequence_length = st.session_state['ma_sequence_length']
                                    def create_sequences(X, y, seq_length):
                                        X_seq = []
                                        y_seq = []
                                        # ç¡®ä¿X和y的长度一致
                                        min_len = min(len(X), len(y))
                                        # ç¡®ä¿X和y的长度至少为seq_length + 1
                                        if min_len <= seq_length:
                                            return np.array([]), np.array([])
                                        # æˆªæ–­X和y到相同长度
                                        X_trimmed = X[:min_len]
                                        y_trimmed = y[:min_len]
                                        # åˆ›å»ºåºåˆ—
                                        for i in range(len(X_trimmed) - seq_length):
                                            X_seq.append(X_trimmed[i:i+seq_length])
                                            y_seq.append(y_trimmed[i+seq_length])
                                        return np.array(X_seq), np.array(y_seq)
                                    # ä¸ºæ·±åº¦å­¦ä¹ æ¨¡åž‹åˆ›å»ºåºåˆ—
                                    X_train_seq, y_train_seq = create_sequences(X_train_scaled, y_train_scaled, sequence_length)
                                    X_test_seq, y_test_seq = create_sequences(X_test_scaled, y_test_scaled, sequence_length)
                                    # ç¡®ä¿åºåˆ—数据长度一致
                                    if len(X_train_seq) != len(y_train_seq):
                                        min_len_train = min(len(X_train_seq), len(y_train_seq))
                                        X_train_seq = X_train_seq[:min_len_train]
                                        y_train_seq = y_train_seq[:min_len_train]
                                    if len(X_test_seq) != len(y_test_seq):
                                        min_len_test = min(len(X_test_seq), len(y_test_seq))
                                        X_test_seq = X_test_seq[:min_len_test]
                                        y_test_seq = y_test_seq[:min_len_test]
                                    # æ£€æŸ¥åˆ›å»ºçš„序列是否为空
                                    if len(X_train_seq) == 0 or len(y_train_seq) == 0:
                                        st.warning(f"数据量不足,无法创建有效的LSTM序列。需要至少 {sequence_length + 1} ä¸ªæ ·æœ¬ï¼Œå½“前只有 {min(len(X_train_scaled), len(y_train_scaled))} ä¸ªæ ·æœ¬ã€‚")
                                        # ä½¿ç”¨éšæœºæ£®æž—作为备选模型
                                        model = RandomForestRegressor(n_estimators=100, random_state=42)
                                        model.fit(X_train, y_train)
                                        y_pred = model.predict(X_test)
                                    else:
                                        # æž„建深度学习模型
                                        input_shape = (sequence_length, X_train_scaled.shape[1])
                                        deep_model = Sequential()
                                        if model_type == 'LSTM':
                                            deep_model.add(LSTM(64, return_sequences=True, input_shape=input_shape))
                                            deep_model.add(LSTM(32, return_sequences=False))
                                        elif model_type == 'GRU':
                                            deep_model.add(GRU(64, return_sequences=True, input_shape=input_shape))
                                            deep_model.add(GRU(32, return_sequences=False))
                                        elif model_type == 'BiLSTM':
                                            deep_model.add(Bidirectional(LSTM(64, return_sequences=True), input_shape=input_shape))
                                            deep_model.add(Bidirectional(LSTM(32, return_sequences=False)))
                                        deep_model.add(Dense(32, activation='relu'))
                                        deep_model.add(Dropout(0.2))
                                        deep_model.add(Dense(1))
                                        # ç¼–译模型
                                        deep_model.compile(optimizer=Adam(learning_rate=0.001), loss='mean_squared_error')
                                        # è®­ç»ƒæ¨¡åž‹
                                        # ç¡®ä¿X_train_seq和y_train_seq长度一致
                                        min_len_train = min(len(X_train_seq), len(y_train_seq))
                                        min_len_test = min(len(X_test_seq), len(y_test_seq))
                                        if min_len_train > 0 and min_len_test > 0:
                                            X_train_seq_trimmed = X_train_seq[:min_len_train]
                                            y_train_seq_trimmed = y_train_seq[:min_len_train]
                                            X_test_seq_trimmed = X_test_seq[:min_len_test]
                                            y_test_seq_trimmed = y_test_seq[:min_len_test]
                                            history = deep_model.fit(
                                                X_train_seq_trimmed, y_train_seq_trimmed,
                                                validation_data=(X_test_seq_trimmed, y_test_seq_trimmed),
                                                epochs=50,
                                                batch_size=32,
                                                verbose=0
                                            )
                                        else:
                                            st.warning("数据量不足,无法训练深度学习模型")
                                            # ä½¿ç”¨éšæœºæ£®æž—作为备选模型
                                            model = RandomForestRegressor(n_estimators=100, random_state=42)
                                            model.fit(X_train, y_train)
                                            y_pred = model.predict(X_test)
                                            # ç¡®ä¿y_test和y_pred长度一致
                                            min_len = min(len(y_test), len(y_pred))
                                            if min_len > 0:
                                                y_test_trimmed = y_test[:min_len]
                                                y_pred_trimmed = y_pred[:min_len]
                                            else:
                                                y_test_trimmed = y_test
                                                y_pred_trimmed = y_pred
                                        # é¢„测
                                        if 'X_test_seq_trimmed' in locals():
                                            y_pred_scaled = deep_model.predict(X_test_seq_trimmed).ravel()
                                        else:
                                            y_pred_scaled = deep_model.predict(X_test_seq).ravel()
                                        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                                        # ä¿å­˜æ¨¡åž‹
                                        model = deep_model
                                # è®¡ç®—评估指标
                                # ç¡®ä¿y_test和y_pred长度一致
                                min_len = min(len(y_test), len(y_pred))
                                if min_len > 0:
                                    y_test_trimmed = y_test[:min_len]
                                    y_pred_trimmed = y_pred[:min_len]
                                    r2 = r2_score(y_test_trimmed, y_pred_trimmed)
                                    mse = mean_squared_error(y_test_trimmed, y_pred_trimmed)
                                    mae = mean_absolute_error(y_test_trimmed, y_pred_trimmed)
                                    rmse = np.sqrt(mse)
                                else:
                                    r2 = 0
                                    mse = 0
                                    mae = 0
                                    rmse = 0
                                # æ˜¾ç¤ºæ¨¡åž‹æ€§èƒ½
                                metrics_cols = st.columns(2)
                                with metrics_cols[0]:
                                    st.metric("R² å¾—分", f"{r2:.4f}")
                                    st.metric("均方误差 (MSE)", f"{mse:.6f}")
                                with metrics_cols[1]:
                                    st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
                                    st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")
                                # --- å®žé™…值与预测值对比 ---
                                st.subheader("🔄 å®žé™…值与预测值对比")
                                # åˆ›å»ºå¯¹æ¯”数据
                                compare_df = pd.DataFrame({
                                    '实际值': y_test_trimmed,
                                    '预测值': y_pred_trimmed
                                })
                                compare_df = compare_df.sort_index()
                                # åˆ›å»ºå¯¹æ¯”图
                                fig_compare = go.Figure()
                                fig_compare.add_trace(go.Scatter(
                                    x=compare_df.index,
                                    y=compare_df['实际值'],
                                    name='实际值',
                                    mode='lines+markers',
                                    line=dict(color='blue', width=2)
                                ))
                                fig_compare.add_trace(go.Scatter(
                                    x=compare_df.index,
                                    y=compare_df['预测值'],
                                    name='预测值',
                                    mode='lines+markers',
                                    line=dict(color='red', width=2, dash='dash')
                                ))
                                fig_compare.update_layout(
                                    title=f'测试集: å®žé™…米重 vs é¢„测米重 ({model_type})',
                                    xaxis=dict(title='时间'),
                                    yaxis=dict(title='米重 (Kg/m)'),
                                    legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
                                    height=400
                                )
                                st.plotly_chart(fig_compare, width='stretch')
                                # --- æ®‹å·®åˆ†æž ---
                                st.subheader("📉 æ®‹å·®åˆ†æž")
                                # è®¡ç®—残差
                                residuals = y_test_trimmed - y_pred_trimmed
                                # åˆ›å»ºæ®‹å·®å›¾
                                fig_residual = go.Figure()
                                fig_residual.add_trace(go.Scatter(
                                    x=y_pred,
                                    y=residuals,
                                    mode='markers',
                                    marker=dict(color='green', size=8, opacity=0.6)
                                ))
                                fig_residual.add_shape(
                                    type="line",
                                    x0=y_pred.min(),
                                    y0=0,
                                    x1=y_pred.max(),
                                    y1=0,
                                    line=dict(color="red", width=2, dash="dash")
                                )
                                fig_residual.update_layout(
                                    title='残差图',
                                    xaxis=dict(title='预测值'),
                                    yaxis=dict(title='残差'),
                                    height=400
                                )
                                st.plotly_chart(fig_residual, width='stretch')
                                # --- ç‰¹å¾é‡è¦æ€§ï¼ˆå¦‚果模型支持) ---
                                if model_type in ['RandomForest', 'GradientBoosting']:
                                    st.subheader("⚖️ ç‰¹å¾é‡è¦æ€§åˆ†æž")
                                    # è®¡ç®—特征重要性
                                    feature_importance = pd.DataFrame({
                                        '特征': X_train.columns,
                                        '重要性': model.feature_importances_
                                    })
                                    feature_importance = feature_importance.sort_values('重要性', ascending=False)
                                    # åˆ›å»ºç‰¹å¾é‡è¦æ€§å›¾
                                    fig_importance = px.bar(
                                        feature_importance,
                                        x='特征',
                                        y='重要性',
                                        title='特征重要性',
                                        color='重要性',
                                        color_continuous_scale='viridis'
                                    )
                                    fig_importance.update_layout(
                                        xaxis=dict(tickangle=-45),
                                        height=400
                                    )
                                    st.plotly_chart(fig_importance, width='stretch')
                                # --- é¢„测功能 ---
                                st.subheader("🔮 ç±³é‡é¢„测")
                                # åˆ›å»ºé¢„测表单,使用form包装以防止输入时触发重新分析
                                with st.form(key="prediction_form"):
                                    st.write("输入特征值进行米重预测:")
                                    predict_cols = st.columns(2)
                                    input_features = {}
                                    for i, feature in enumerate(default_features):
                                        with predict_cols[i % 2]:
                                            # èŽ·å–ç‰¹å¾çš„ç»Ÿè®¡ä¿¡æ¯
                                            min_val = df_analysis_clean[feature].min()
                                            max_val = df_analysis_clean[feature].max()
                                            mean_val = df_analysis_clean[feature].mean()
                                            input_features[feature] = st.number_input(
                                                f"{feature}",
                                                key=f"ma_pred_{feature}",
                                                value=float(mean_val),
                                                min_value=float(min_val),
                                                max_value=float(max_val),
                                                step=0.1
                                            )
                                    # é¢„测按钮
                                    predict_button = st.form_submit_button("预测米重")
                                if predict_button:
                                    # å‡†å¤‡é¢„测数据
                                    input_df = pd.DataFrame([input_features])
                                    # æ·»åŠ æ—¶é—´ç‰¹å¾ï¼ˆä½¿ç”¨å½“å‰æ—¶é—´ï¼‰
                                    current_time = datetime.now()
                                    time_features_input = pd.DataFrame({
                                        'hour': [current_time.hour],
                                        'minute': [current_time.minute],
                                        'second': [current_time.second],
                                        'time_of_day': [current_time.hour * 3600 + current_time.minute * 60 + current_time.second]
                                    })
                                    # æ·»åŠ æ»žåŽç‰¹å¾ï¼ˆä½¿ç”¨è¾“å…¥å€¼ä½œä¸ºæ›¿ä»£ï¼‰
                                    for feature in default_features:
                                        for lag in [1, 2, 3]:
                                            time_features_input[f'{feature}_lag{lag}'] = input_features[feature]
                                            time_features_input[f'{feature}_diff{lag}'] = 0.0
                                    # æ·»åŠ æ»šåŠ¨ç»Ÿè®¡ç‰¹å¾ï¼ˆä½¿ç”¨è¾“å…¥å€¼ä½œä¸ºæ›¿ä»£ï¼‰
                                    for feature in default_features:
                                        time_features_input[f'{feature}_rolling_mean'] = input_features[feature]
                                        time_features_input[f'{feature}_rolling_std'] = 0.0
                                        time_features_input[f'{feature}_rolling_min'] = input_features[feature]
                                        time_features_input[f'{feature}_rolling_max'] = input_features[feature]
                                    # åˆå¹¶ç‰¹å¾
                                    input_combined = pd.concat([input_df, time_features_input], axis=1)
                                    # é¢„测
                                    if model_type in ['SVR', 'MLP']:
                                        input_scaled = scaler_X.transform(input_combined)
                                        prediction_scaled = model.predict(input_scaled)
                                        predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
                                    elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                        # ä¸ºæ·±åº¦å­¦ä¹ æ¨¡åž‹åˆ›å»ºåºåˆ—
                                        input_scaled = scaler_X.transform(input_combined)
                                        # é‡å¤è¾“入以创建序列
                                        sequence_length = st.session_state['ma_sequence_length']
                                        input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
                                        prediction_scaled = model.predict(input_seq).ravel()[0]
                                        predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
                                    else:
                                        predicted_weight = model.predict(input_combined)[0]
                                    # æ˜¾ç¤ºé¢„测结果
                                    st.success(f"预测米重: {predicted_weight:.4f} Kg/m")
                                # --- æ•°æ®é¢„览 ---
                                st.subheader("🔍 æ•°æ®é¢„览")
                                st.dataframe(df_analysis.head(20), width='stretch')
                                # --- å¯¼å‡ºæ•°æ® ---
                                st.subheader("💾 å¯¼å‡ºæ•°æ®")
                                # å°†æ•°æ®è½¬æ¢ä¸ºCSV格式
                                csv = df_analysis.to_csv(index=False)
                                # åˆ›å»ºä¸‹è½½æŒ‰é’®
                                st.download_button(
                                    label="导出整合后的数据 (CSV)",
                                    data=csv,
                                    file_name=f"metered_weight_advanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                                    mime="text/csv",
                                    help="点击按钮导出整合后的米重分析数据"
                                )
                except Exception as e:
                    st.error(f"模型训练或预测失败: {str(e)}")
    else:
        # æç¤ºç”¨æˆ·ç‚¹å‡»å¼€å§‹åˆ†æžæŒ‰é’®
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
app/pages/metered_weight_advanced pytorch.py
对比新文件
@@ -0,0 +1,873 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
# Try to import the optional deep-learning stack (PyTorch). The LSTM / GRU /
# BiLSTM options are only added to the model selector when this succeeds.
use_deep_learning = False
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    use_deep_learning = True
    # Pick the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    st.success(f"使用设备: {device}")
except (ImportError, OSError):
    # OSError is caught as well: a broken PyTorch install (e.g. a native
    # library that fails to load) can raise it instead of ImportError.
    #
    # The model classes below subclass ``nn.Module`` when this module is
    # executed, so without a placeholder the page would die with NameError
    # instead of showing the warning. A plain ``object`` base keeps the
    # module importable. Those classes are not meant to be used on this
    # path, because ``use_deep_learning`` stays False and the deep-learning
    # model options are hidden.
    from types import SimpleNamespace
    nn = SimpleNamespace(Module=object)
    device = None
    st.warning("未检测到PyTorch,深度学习模型将不可用。请安装pytorch以使用LSTM/GRU模型。")
# PyTorch深度学习模型定义
class LSTMModel(nn.Module):
    """Stacked LSTM regressor mapping a sequence to one scalar per sample.

    The recurrent stack is built with ``batch_first=True``. Only the output
    of the final time step is fed to a small head
    (``hidden_dim -> 32 -> 1``) with ReLU and dropout (p=0.2) in between.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
        super().__init__()
        # Attribute names (lstm / fc1 / dropout / fc2) become the state_dict
        # keys, so they are kept unchanged.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, 32)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        """Return predictions of shape (batch, 1).

        ``x`` is laid out as (batch, seq_len, input_dim), because the LSTM
        uses ``batch_first=True``.
        """
        sequence_out, _state = self.lstm(x)
        final_step = sequence_out[:, -1]
        hidden = torch.relu(self.fc1(final_step))
        return self.fc2(self.dropout(hidden))
class GRUModel(nn.Module):
    """Stacked GRU regressor: a sequence goes in, one scalar per sample comes out.

    Args:
        input_dim: Features per time step.
        hidden_dim: GRU hidden size (default 64).
        num_layers: Number of stacked GRU layers (default 2).

    Input is laid out as (batch, seq_len, input_dim), since the GRU uses
    ``batch_first=True``. The output has shape (batch, 1).
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True)
        # Regression head applied to the last time step's output:
        # hidden_dim -> 32 -> ReLU -> dropout(0.2) -> 1.
        self.fc1 = nn.Linear(hidden_dim, 32)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        """Run the GRU and map its final-step output to one prediction."""
        outputs = self.gru(x)[0]
        features = outputs[:, -1, :]
        activated = torch.relu(self.fc1(features))
        dropped = self.dropout(activated)
        return self.fc2(dropped)
class BiLSTMModel(nn.Module):
    """Bidirectional stacked LSTM regressor producing one scalar per sample.

    The sequence summary fed to the head concatenates two outputs:
    - the forward direction's output at the last time step;
    - the backward direction's output at the first time step.
    Each direction has read the whole sequence at exactly that position.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of each direction.
        num_layers: Number of stacked bidirectional LSTM layers.

    Input is laid out as (batch, seq_len, input_dim) (``batch_first=True``).
    The output has shape (batch, 1).
    """

    def __init__(self, input_dim, hidden_dim=64, num_layers=2):
        super(BiLSTMModel, self).__init__()
        self.bilstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        # Forward and backward features are concatenated, hence 2 * hidden_dim.
        self.fc1 = nn.Linear(hidden_dim * 2, 32)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        out, _ = self.bilstm(x)
        hidden = self.bilstm.hidden_size
        # With bidirectional=True the last dimension of ``out`` is
        # [forward | backward]. The backward direction runs from the end of
        # the sequence towards the start. Its output at the LAST step has
        # therefore only seen the final element. Taking out[:, -1, :] for
        # both halves would discard almost all backward context. Use the
        # backward output at the FIRST step instead.
        forward_final = out[:, -1, :hidden]
        backward_final = out[:, 0, hidden:]
        out = torch.cat((forward_final, backward_final), dim=1)
        out = torch.relu(self.fc1(out))
        out = self.dropout(out)
        out = self.fc2(out)
        return out
def show_metered_weight_advanced():
    # åˆå§‹åŒ–服务
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()
    # é¡µé¢æ ‡é¢˜
    st.title("米重高级预测分析")
    # åˆå§‹åŒ–会话状态
    if 'ma_start_date' not in st.session_state:
        st.session_state['ma_start_date'] = datetime.now().date() - timedelta(days=7)
    if 'ma_end_date' not in st.session_state:
        st.session_state['ma_end_date'] = datetime.now().date()
    if 'ma_quick_select' not in st.session_state:
        st.session_state['ma_quick_select'] = "最近7天"
    if 'ma_model_type' not in st.session_state:
        st.session_state['ma_model_type'] = 'RandomForest'
    if 'ma_sequence_length' not in st.session_state:
        st.session_state['ma_sequence_length'] = 10
    # é»˜è®¤ç‰¹å¾åˆ—表(不再允许用户选择)
    default_features = ['螺杆转速', '机头压力', '流程主速', '螺杆温度',
                       '后机筒温度', '前机筒温度', '机头温度']
    # å®šä¹‰å›žè°ƒå‡½æ•°
    def update_dates(qs):
        st.session_state['ma_quick_select'] = qs
        today = datetime.now().date()
        if qs == "今天":
            st.session_state['ma_start_date'] = today
            st.session_state['ma_end_date'] = today
        elif qs == "最近3天":
            st.session_state['ma_start_date'] = today - timedelta(days=3)
            st.session_state['ma_end_date'] = today
        elif qs == "最近7天":
            st.session_state['ma_start_date'] = today - timedelta(days=7)
            st.session_state['ma_end_date'] = today
        elif qs == "最近30天":
            st.session_state['ma_start_date'] = today - timedelta(days=30)
            st.session_state['ma_end_date'] = today
    def on_date_change():
        st.session_state['ma_quick_select'] = "自定义"
    # æŸ¥è¯¢æ¡ä»¶åŒºåŸŸ
    with st.expander("🔍 æŸ¥è¯¢é…ç½®", expanded=True):
        # æ·»åŠ è‡ªå®šä¹‰ CSS å®žçŽ°å“åº”å¼æ¢è¡Œ
        st.markdown("""
            <style>
            /* å¼ºåˆ¶åˆ—容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* é’ˆå¯¹æ—¥æœŸè¾“入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """, unsafe_allow_html=True)
        # åˆ›å»ºå¸ƒå±€
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        options = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
        for i, option in enumerate(options):
            with cols[i]:
                # æ ¹æ®å½“前选择状态决定按钮类型
                button_type = "primary" if st.session_state['ma_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_ma_{option}", width='stretch', type=button_type):
                    update_dates(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="ma_start_date",
                on_change=on_date_change
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="ma_end_date",
                on_change=on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 å¼€å§‹åˆ†æž", key="ma_query", width='stretch')
        # æ¨¡åž‹é…ç½®
        st.markdown("---")
        st.write("🤖 **模型配置**")
        model_cols = st.columns(2)
        with model_cols[0]:
            # æ¨¡åž‹ç±»åž‹é€‰æ‹©
            model_options = ['RandomForest', 'GradientBoosting', 'SVR', 'MLP']
            if use_deep_learning:
                model_options.extend(['LSTM', 'GRU', 'BiLSTM'])
            model_type = st.selectbox(
                "模型类型",
                options=model_options,
                key="ma_model_type",
                help="选择用于预测的模型类型"
            )
        with model_cols[1]:
            # åºåˆ—长度(仅适用于深度学习模型)
            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
                sequence_length = st.slider(
                    "序列长度",
                    min_value=5,
                    max_value=30,
                    value=st.session_state['ma_sequence_length'],
                    step=1,
                    help="用于深度学习模型的时间序列长度",
                    key="ma_sequence_length"
                )
            else:
                st.session_state['ma_sequence_length'] = 10
                st.write("序列长度: 10 (默认,仅适用于深度学习模型)")
    # è½¬æ¢ä¸ºdatetime对象
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())
    # æŸ¥è¯¢å¤„理
    if query_button:
        with st.spinner("正在获取数据..."):
            # 1. èŽ·å–å®Œæ•´çš„æŒ¤å‡ºæœºæ•°æ®
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            # 2. èŽ·å–ä¸»æµç¨‹æŽ§åˆ¶æ•°æ®
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_full is not None and not df_extruder_full.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                # æ¸…除缓存数据
                for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end']:
                    if key in st.session_state:
                        del st.session_state[key]
                return
            # ç¼“存数据到会话状态
            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt
    # æ•°æ®å¤„理和分析
    if all(key in st.session_state for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp']):
        with st.spinner("正在分析数据..."):
            # èŽ·å–ç¼“å­˜æ•°æ®
            df_extruder_full = st.session_state['cached_extruder_full']
            df_main_speed = st.session_state['cached_main_speed']
            df_temp = st.session_state['cached_temp']
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_full is not None and not df_extruder_full.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                return
            # æ•°æ®æ•´åˆä¸Žé¢„处理
            def integrate_data(df_extruder_full, df_main_speed, df_temp):
                # ç¡®ä¿æŒ¤å‡ºæœºæ•°æ®å­˜åœ¨
                if df_extruder_full is None or df_extruder_full.empty:
                    return None
                # åˆ›å»ºåªåŒ…含米重和时间的主数据集
                df_merged = df_extruder_full[['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']].copy()
                # æ•´åˆä¸»æµç¨‹æ•°æ®
                if df_main_speed is not None and not df_main_speed.empty:
                    df_main_speed = df_main_speed[['time', 'process_main_speed']]
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_main_speed.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # æ•´åˆæ¸©åº¦æ•°æ®
                if df_temp is not None and not df_temp.empty:
                    temp_cols = ['time', 'nakata_extruder_screw_display_temp',
                               'nakata_extruder_rear_barrel_display_temp',
                               'nakata_extruder_front_barrel_display_temp',
                               'nakata_extruder_head_display_temp']
                    df_temp_subset = df_temp[temp_cols].copy()
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_temp_subset.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # é‡å‘½ååˆ—以提高可读性
                df_merged.rename(columns={
                    'screw_speed_actual': '螺杆转速',
                    'head_pressure': '机头压力',
                    'process_main_speed': '流程主速',
                    'nakata_extruder_screw_display_temp': '螺杆温度',
                    'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
                    'nakata_extruder_front_barrel_display_temp': '前机筒温度',
                    'nakata_extruder_head_display_temp': '机头温度'
                }, inplace=True)
                # æ¸…理数据
                df_merged.dropna(subset=['metered_weight'], inplace=True)
                return df_merged
            # æ‰§è¡Œæ•°æ®æ•´åˆ
            df_analysis = integrate_data(df_extruder_full, df_main_speed, df_temp)
            if df_analysis is None or df_analysis.empty:
                st.warning("数据整合失败,请检查数据质量或调整时间范围。")
                return
            # é‡å‘½åç±³é‡åˆ—
            df_analysis.rename(columns={'metered_weight': '米重'}, inplace=True)
            # --- åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾ ---
            st.subheader("📈 åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾")
            # åˆ›å»ºè¶‹åŠ¿å›¾
            fig_trend = go.Figure()
            # æ·»åŠ ç±³é‡æ•°æ®
            if df_extruder_full is not None and not df_extruder_full.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['metered_weight'],
                    name='米重 (Kg/m)',
                    mode='lines',
                    line=dict(color='blue', width=2)
                ))
                # æ·»åŠ èžºæ†è½¬é€Ÿ
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['screw_speed_actual'],
                    name='螺杆转速 (RPM)',
                    mode='lines',
                    line=dict(color='green', width=1.5),
                    yaxis='y2'
                ))
                # æ·»åŠ æœºå¤´åŽ‹åŠ›
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_full['time'],
                    y=df_extruder_full['head_pressure'],
                    name='机头压力',
                    mode='lines',
                    line=dict(color='orange', width=1.5),
                    yaxis='y3'
                ))
            # æ·»åŠ æµç¨‹ä¸»é€Ÿ
            if df_main_speed is not None and not df_main_speed.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_main_speed['time'],
                    y=df_main_speed['process_main_speed'],
                    name='流程主速 (M/Min)',
                    mode='lines',
                    line=dict(color='red', width=1.5),
                    yaxis='y4'
                ))
            # æ·»åŠ æ¸©åº¦æ•°æ®
            if df_temp is not None and not df_temp.empty:
                # èžºæ†æ¸©åº¦
                fig_trend.add_trace(go.Scatter(
                    x=df_temp['time'],
                    y=df_temp['nakata_extruder_screw_display_temp'],
                    name='螺杆温度 (°C)',
                    mode='lines',
                    line=dict(color='purple', width=1),
                    yaxis='y5'
                ))
            # é…ç½®è¶‹åŠ¿å›¾å¸ƒå±€
            fig_trend.update_layout(
                title='原始数据趋势',
                xaxis=dict(
                    title='时间',
                    rangeslider=dict(visible=True),
                    type='date'
                ),
                yaxis=dict(
                    title='米重 (Kg/m)',
                    title_font=dict(color='blue'),
                    tickfont=dict(color='blue')
                ),
                yaxis2=dict(
                    title='螺杆转速 (RPM)',
                    title_font=dict(color='green'),
                    tickfont=dict(color='green'),
                    overlaying='y',
                    side='right'
                ),
                yaxis3=dict(
                    title='机头压力',
                    title_font=dict(color='orange'),
                    tickfont=dict(color='orange'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.85
                ),
                yaxis4=dict(
                    title='流程主速 (M/Min)',
                    title_font=dict(color='red'),
                    tickfont=dict(color='red'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.75
                ),
                yaxis5=dict(
                    title='温度 (°C)',
                    title_font=dict(color='purple'),
                    tickfont=dict(color='purple'),
                    overlaying='y',
                    side='left',
                    anchor='free',
                    position=0.15
                ),
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1
                ),
                height=600,
                margin=dict(l=100, r=200, t=100, b=100),
                hovermode='x unified'
            )
            # æ˜¾ç¤ºè¶‹åŠ¿å›¾
            st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True})
            # --- é«˜çº§é¢„测分析 ---
            st.subheader("📊 é«˜çº§é¢„测分析")
            # æ£€æŸ¥æ‰€æœ‰é»˜è®¤ç‰¹å¾æ˜¯å¦åœ¨æ•°æ®ä¸­
            missing_features = [f for f in default_features if f not in df_analysis.columns]
            if missing_features:
                st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
            else:
                # å‡†å¤‡æ•°æ®
                X = df_analysis[default_features]
                y = df_analysis['米重']
                # æ¸…理数据中的NaN值
                combined = pd.concat([X, y], axis=1)
                combined_clean = combined.dropna()
                # æ£€æŸ¥æ¸…理后的数据量
                if len(combined_clean) < 30:
                    st.warning("数据量不足,无法进行有效的预测分析")
                else:
                    # é‡æ–°åˆ†ç¦»X和y
                    X_clean = combined_clean[default_features]
                    y_clean = combined_clean['米重']
                    # ç‰¹å¾å·¥ç¨‹ï¼šæ·»åŠ æ—¶é—´ç›¸å…³ç‰¹å¾
                    # ç¡®ä¿ä½¿ç”¨æ—¶é—´åˆ—作为索引
                    if 'time' in combined_clean.columns:
                        # å°†time列设置为索引
                        combined_clean = combined_clean.set_index('time')
                        # åˆ›å»ºæ—¶é—´ç‰¹å¾
                        time_features = pd.DataFrame(index=combined_clean.index)
                        time_features['hour'] = combined_clean.index.hour
                        time_features['minute'] = combined_clean.index.minute
                        time_features['second'] = combined_clean.index.second
                        time_features['time_of_day'] = time_features['hour'] * 3600 + time_features['minute'] * 60 + time_features['second']
                    else:
                        # å¦‚果没有time列,创建空的时间特征
                        time_features = pd.DataFrame(index=combined_clean.index)
                        time_features['hour'] = 0
                        time_features['minute'] = 0
                        time_features['second'] = 0
                        time_features['time_of_day'] = 0
                    # æ·»åŠ æ»žåŽç‰¹å¾
                    for feature in default_features:
                        for lag in [1, 2, 3]:
                            time_features[f'{feature}_lag{lag}'] = X_clean[feature].shift(lag)
                            time_features[f'{feature}_diff{lag}'] = X_clean[feature].diff(lag)
                    # æ·»åŠ æ»šåŠ¨ç»Ÿè®¡ç‰¹å¾
                    for feature in default_features:
                        time_features[f'{feature}_rolling_mean'] = X_clean[feature].rolling(window=5).mean()
                        time_features[f'{feature}_rolling_std'] = X_clean[feature].rolling(window=5).std()
                        time_features[f'{feature}_rolling_min'] = X_clean[feature].rolling(window=5).min()
                        time_features[f'{feature}_rolling_max'] = X_clean[feature].rolling(window=5).max()
                    # æ¸…理滞后特征和滚动统计特征产生的NaN值
                    time_features.dropna(inplace=True)
                    # å¯¹é½X和y
                    common_index = time_features.index.intersection(y_clean.index)
                    X_final = pd.concat([X_clean.loc[common_index], time_features.loc[common_index]], axis=1)
                    y_final = y_clean.loc[common_index]
                    # æ£€æŸ¥æœ€ç»ˆæ•°æ®é‡
                    if len(X_final) < 20:
                        st.warning("特征工程后数据量不足,无法进行有效的预测分析")
                    else:
                        # åˆ†å‰²è®­ç»ƒé›†å’Œæµ‹è¯•集
                        X_train, X_test, y_train, y_test = train_test_split(X_final, y_final, test_size=0.2, random_state=42)
                        # æ•°æ®æ ‡å‡†åŒ–
                        scaler_X = StandardScaler()
                        scaler_y = MinMaxScaler()
                        X_train_scaled = scaler_X.fit_transform(X_train)
                        X_test_scaled = scaler_X.transform(X_test)
                        y_train_scaled = scaler_y.fit_transform(y_train.values.reshape(-1, 1)).ravel()
                        y_test_scaled = scaler_y.transform(y_test.values.reshape(-1, 1)).ravel()
                        # æ¨¡åž‹è®­ç»ƒ
                        model = None
                        y_pred = None
                        try:
                            if model_type == 'RandomForest':
                                # éšæœºæ£®æž—回归
                                model = RandomForestRegressor(n_estimators=100, random_state=42)
                                model.fit(X_train, y_train)
                                y_pred = model.predict(X_test)
                            elif model_type == 'GradientBoosting':
                                # æ¢¯åº¦æå‡å›žå½’
                                model = GradientBoostingRegressor(n_estimators=100, random_state=42)
                                model.fit(X_train, y_train)
                                y_pred = model.predict(X_test)
                            elif model_type == 'SVR':
                                # æ”¯æŒå‘量回归
                                model = SVR(kernel='rbf', C=1.0, gamma='scale')
                                model.fit(X_train_scaled, y_train_scaled)
                                y_pred_scaled = model.predict(X_test_scaled)
                                y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                            elif model_type == 'MLP':
                                # å¤šå±‚感知器回归
                                model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
                                model.fit(X_train_scaled, y_train_scaled)
                                y_pred_scaled = model.predict(X_test_scaled)
                                y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                            elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                # å‡†å¤‡æ—¶é—´åºåˆ—数据
                                sequence_length = st.session_state['ma_sequence_length']
                                def create_sequences(X, y, seq_length):
                                    X_seq = []
                                    y_seq = []
                                    # ç¡®ä¿X和y的长度一致
                                    min_len = min(len(X), len(y))
                                    # ç¡®ä¿X和y的长度至少为seq_length + 1
                                    if min_len <= seq_length:
                                        return np.array([]), np.array([])
                                    # æˆªæ–­X和y到相同长度
                                    X_trimmed = X[:min_len]
                                    y_trimmed = y[:min_len]
                                    # åˆ›å»ºåºåˆ—
                                    for i in range(len(X_trimmed) - seq_length):
                                        X_seq.append(X_trimmed[i:i+seq_length])
                                        y_seq.append(y_trimmed[i+seq_length])
                                    return np.array(X_seq), np.array(y_seq)
                                # ä¸ºæ·±åº¦å­¦ä¹ æ¨¡åž‹åˆ›å»ºåºåˆ—
                                X_train_seq, y_train_seq = create_sequences(X_train_scaled, y_train_scaled, sequence_length)
                                X_test_seq, y_test_seq = create_sequences(X_test_scaled, y_test_scaled, sequence_length)
                                # æ£€æŸ¥åˆ›å»ºçš„序列是否为空
                                if len(X_train_seq) == 0 or len(y_train_seq) == 0:
                                    st.warning(f"数据量不足,无法创建有效的LSTM序列。需要至少 {sequence_length + 1} ä¸ªæ ·æœ¬ï¼Œå½“前只有 {min(len(X_train_scaled), len(y_train_scaled))} ä¸ªæ ·æœ¬ã€‚")
                                    # ä½¿ç”¨éšæœºæ£®æž—作为备选模型
                                    model = RandomForestRegressor(n_estimators=100, random_state=42)
                                    model.fit(X_train, y_train)
                                    y_pred = model.predict(X_test)
                                else:
                                    # è½¬æ¢ä¸ºPyTorch张量并移动到设备
                                    X_train_tensor = torch.tensor(X_train_seq, dtype=torch.float32).to(device)
                                    y_train_tensor = torch.tensor(y_train_seq, dtype=torch.float32).unsqueeze(1).to(device)
                                    X_test_tensor = torch.tensor(X_test_seq, dtype=torch.float32).to(device)
                                    y_test_tensor = torch.tensor(y_test_seq, dtype=torch.float32).unsqueeze(1).to(device)
                                    # æž„建PyTorch模型并移动到设备
                                    input_dim = X_train_scaled.shape[1]
                                    if model_type == 'LSTM':
                                        deep_model = LSTMModel(input_dim).to(device)
                                    elif model_type == 'GRU':
                                        deep_model = GRUModel(input_dim).to(device)
                                    elif model_type == 'BiLSTM':
                                        deep_model = BiLSTMModel(input_dim).to(device)
                                    # å®šä¹‰æŸå¤±å‡½æ•°å’Œä¼˜åŒ–器
                                    criterion = nn.MSELoss()
                                    optimizer = optim.Adam(deep_model.parameters(), lr=0.001)
                                    # æ˜¾ç¤ºä½¿ç”¨çš„设备
                                    st.info(f"使用设备: {device}")
                                    # è®­ç»ƒæ¨¡åž‹
                                    num_epochs = 50
                                    batch_size = 32
                                    for epoch in range(num_epochs):
                                        deep_model.train()
                                        optimizer.zero_grad()
                                        # å‰å‘ä¼ æ’­
                                        outputs = deep_model(X_train_tensor)
                                        loss = criterion(outputs, y_train_tensor)
                                        # åå‘传播和优化
                                        loss.backward()
                                        optimizer.step()
                                    # é¢„测
                                    deep_model.eval()
                                    with torch.no_grad():
                                        y_pred_scaled_tensor = deep_model(X_test_tensor)
                                        y_pred_scaled = y_pred_scaled_tensor.numpy().ravel()
                                        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).ravel()
                                    # å°†y_test_seq转换回原始尺度
                                    y_test_actual = scaler_y.inverse_transform(y_test_seq.reshape(-1, 1)).ravel()
                                # ä¿å­˜æ¨¡åž‹
                                model = deep_model
                            # è®¡ç®—评估指标
                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                # ä½¿ç”¨è½¬æ¢åŽçš„y_test_seq作为真实值
                                r2 = r2_score(y_test_actual, y_pred)
                                mse = mean_squared_error(y_test_actual, y_pred)
                                mae = mean_absolute_error(y_test_actual, y_pred)
                                rmse = np.sqrt(mse)
                            else:
                                # ä½¿ç”¨åŽŸå§‹çš„y_test作为真实值
                                r2 = r2_score(y_test, y_pred)
                                mse = mean_squared_error(y_test, y_pred)
                                mae = mean_absolute_error(y_test, y_pred)
                                rmse = np.sqrt(mse)
                            # æ˜¾ç¤ºæ¨¡åž‹æ€§èƒ½
                            metrics_cols = st.columns(2)
                            with metrics_cols[0]:
                                st.metric("R² å¾—分", f"{r2:.4f}")
                                st.metric("均方误差 (MSE)", f"{mse:.6f}")
                            with metrics_cols[1]:
                                st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
                                st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")
                            # --- å®žé™…值与预测值对比 ---
                            st.subheader("🔄 å®žé™…值与预测值对比")
                            # åˆ›å»ºå¯¹æ¯”数据
                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                # ä½¿ç”¨è½¬æ¢åŽçš„y_test_actual
                                compare_df = pd.DataFrame({
                                    '实际值': y_test_actual,
                                    '预测值': y_pred
                                })
                            else:
                                # ä½¿ç”¨åŽŸå§‹çš„y_test
                                compare_df = pd.DataFrame({
                                    '实际值': y_test,
                                    '预测值': y_pred
                                })
                            compare_df = compare_df.sort_index()
                            # åˆ›å»ºå¯¹æ¯”图
                            fig_compare = go.Figure()
                            fig_compare.add_trace(go.Scatter(
                                x=compare_df.index,
                                y=compare_df['实际值'],
                                name='实际值',
                                mode='lines+markers',
                                line=dict(color='blue', width=2)
                            ))
                            fig_compare.add_trace(go.Scatter(
                                x=compare_df.index,
                                y=compare_df['预测值'],
                                name='预测值',
                                mode='lines+markers',
                                line=dict(color='red', width=2, dash='dash')
                            ))
                            fig_compare.update_layout(
                                title=f'测试集: å®žé™…米重 vs é¢„测米重 ({model_type})',
                                xaxis=dict(title='时间'),
                                yaxis=dict(title='米重 (Kg/m)'),
                                legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
                                height=400
                            )
                            st.plotly_chart(fig_compare, width='stretch')
                            # --- æ®‹å·®åˆ†æž ---
                            st.subheader("📉 æ®‹å·®åˆ†æž")
                            # è®¡ç®—残差
                            if model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                # ä½¿ç”¨è½¬æ¢åŽçš„y_test_actual
                                residuals = y_test_actual - y_pred
                            else:
                                # ä½¿ç”¨åŽŸå§‹çš„y_test
                                residuals = y_test - y_pred
                            # åˆ›å»ºæ®‹å·®å›¾
                            fig_residual = go.Figure()
                            fig_residual.add_trace(go.Scatter(
                                x=y_pred,
                                y=residuals,
                                mode='markers',
                                marker=dict(color='green', size=8, opacity=0.6)
                            ))
                            fig_residual.add_shape(
                                type="line",
                                x0=y_pred.min(),
                                y0=0,
                                x1=y_pred.max(),
                                y1=0,
                                line=dict(color="red", width=2, dash="dash")
                            )
                            fig_residual.update_layout(
                                title='残差图',
                                xaxis=dict(title='预测值'),
                                yaxis=dict(title='残差'),
                                height=400
                            )
                            st.plotly_chart(fig_residual, width='stretch')
                            # --- ç‰¹å¾é‡è¦æ€§ï¼ˆå¦‚果模型支持) ---
                            if model_type in ['RandomForest', 'GradientBoosting']:
                                st.subheader("⚖️ ç‰¹å¾é‡è¦æ€§åˆ†æž")
                                # è®¡ç®—特征重要性
                                feature_importance = pd.DataFrame({
                                    '特征': X_train.columns,
                                    '重要性': model.feature_importances_
                                })
                                feature_importance = feature_importance.sort_values('重要性', ascending=False)
                                # åˆ›å»ºç‰¹å¾é‡è¦æ€§å›¾
                                fig_importance = px.bar(
                                    feature_importance,
                                    x='特征',
                                    y='重要性',
                                    title='特征重要性',
                                    color='重要性',
                                    color_continuous_scale='viridis'
                                )
                                fig_importance.update_layout(
                                    xaxis=dict(tickangle=-45),
                                    height=400
                                )
                                st.plotly_chart(fig_importance, width='stretch')
                            # --- é¢„测功能 ---
                            st.subheader("🔮 ç±³é‡é¢„测")
                            # åˆ›å»ºé¢„测表单
                            st.write("输入特征值进行米重预测:")
                            predict_cols = st.columns(2)
                            input_features = {}
                            for i, feature in enumerate(default_features):
                                with predict_cols[i % 2]:
                                    # èŽ·å–ç‰¹å¾çš„ç»Ÿè®¡ä¿¡æ¯
                                    min_val = X_clean[feature].min()
                                    max_val = X_clean[feature].max()
                                    mean_val = X_clean[feature].mean()
                                    input_features[feature] = st.number_input(
                                        f"{feature}",
                                        key=f"ma_pred_{feature}",
                                        value=float(mean_val),
                                        min_value=float(min_val),
                                        max_value=float(max_val),
                                        step=0.1
                                    )
                            if st.button("预测米重"):
                                # å‡†å¤‡é¢„测数据
                                input_df = pd.DataFrame([input_features])
                                # æ·»åŠ æ—¶é—´ç‰¹å¾ï¼ˆä½¿ç”¨å½“å‰æ—¶é—´ï¼‰
                                current_time = datetime.now()
                                time_features_input = pd.DataFrame({
                                    'hour': [current_time.hour],
                                    'minute': [current_time.minute],
                                    'second': [current_time.second],
                                    'time_of_day': [current_time.hour * 3600 + current_time.minute * 60 + current_time.second]
                                })
                                # æ·»åŠ æ»žåŽç‰¹å¾ï¼ˆä½¿ç”¨è¾“å…¥å€¼ä½œä¸ºæ›¿ä»£ï¼‰
                                for feature in default_features:
                                    for lag in [1, 2, 3]:
                                        time_features_input[f'{feature}_lag{lag}'] = input_features[feature]
                                        time_features_input[f'{feature}_diff{lag}'] = 0.0
                                # æ·»åŠ æ»šåŠ¨ç»Ÿè®¡ç‰¹å¾ï¼ˆä½¿ç”¨è¾“å…¥å€¼ä½œä¸ºæ›¿ä»£ï¼‰
                                for feature in default_features:
                                    time_features_input[f'{feature}_rolling_mean'] = input_features[feature]
                                    time_features_input[f'{feature}_rolling_std'] = 0.0
                                    time_features_input[f'{feature}_rolling_min'] = input_features[feature]
                                    time_features_input[f'{feature}_rolling_max'] = input_features[feature]
                                # åˆå¹¶ç‰¹å¾
                                input_combined = pd.concat([input_df, time_features_input], axis=1)
                                # é¢„测
                                if model_type in ['SVR', 'MLP']:
                                    input_scaled = scaler_X.transform(input_combined)
                                    prediction_scaled = model.predict(input_scaled)
                                    predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
                                elif use_deep_learning and model_type in ['LSTM', 'GRU', 'BiLSTM']:
                                    # ä¸ºæ·±åº¦å­¦ä¹ æ¨¡åž‹åˆ›å»ºåºåˆ—
                                    input_scaled = scaler_X.transform(input_combined)
                                    # é‡å¤è¾“入以创建序列
                                    input_seq = np.tile(input_scaled, (sequence_length, 1)).reshape(1, sequence_length, -1)
                                    # è½¬æ¢ä¸ºPyTorch张量并移动到设备
                                    input_tensor = torch.tensor(input_seq, dtype=torch.float32).to(device)
                                    # é¢„测
                                    model.eval()
                                    with torch.no_grad():
                                        prediction_scaled_tensor = model(input_tensor)
                                        prediction_scaled = prediction_scaled_tensor.cpu().numpy().ravel()[0]
                                    predicted_weight = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).ravel()[0]
                                else:
                                    predicted_weight = model.predict(input_combined)[0]
                                # æ˜¾ç¤ºé¢„测结果
                                st.success(f"预测米重: {predicted_weight:.4f} Kg/m")
                            # --- æ•°æ®é¢„览 ---
                            st.subheader("🔍 æ•°æ®é¢„览")
                            st.dataframe(df_analysis.head(20), width='stretch')
                            # --- å¯¼å‡ºæ•°æ® ---
                            st.subheader("💾 å¯¼å‡ºæ•°æ®")
                            # å°†æ•°æ®è½¬æ¢ä¸ºCSV格式
                            csv = df_analysis.to_csv(index=False)
                            # åˆ›å»ºä¸‹è½½æŒ‰é’®
                            st.download_button(
                                label="导出整合后的数据 (CSV)",
                                data=csv,
                                file_name=f"metered_weight_advanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
                                mime="text/csv",
                                help="点击按钮导出整合后的米重分析数据"
                            )
                        except Exception as e:
                            st.error(f"模型训练或预测失败: {str(e)}")
    else:
        # æç¤ºç”¨æˆ·ç‚¹å‡»å¼€å§‹åˆ†æžæŒ‰é’®
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
app/pages/metered_weight_advanced.py
对比新文件
@@ -0,0 +1,635 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
# Model input features, named as they appear after `_ma_integrate_data` renames
# the raw columns. The user can no longer choose features; all are always used.
_MA_DEFAULT_FEATURES = ['螺杆转速', '机头压力', '流程主速', '螺杆温度',
                        '后机筒温度', '前机筒温度', '机头温度']

# Quick-select button label -> number of days subtracted from today for the
# start date. "自定义" (custom) is deliberately absent: it leaves dates as-is.
_MA_QUICK_RANGE_DAYS = {"今天": 0, "最近3天": 3, "最近7天": 7, "最近30天": 30}

# Order in which the quick-select buttons are rendered.
_MA_QUICK_OPTIONS = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]

# Session-state keys that cache the raw query results between reruns.
_MA_CACHE_KEYS = ['cached_extruder_full', 'cached_main_speed', 'cached_temp']

# Temperature columns taken from the temperature-control data.
_MA_TEMP_COLUMNS = ['nakata_extruder_screw_display_temp',
                    'nakata_extruder_rear_barrel_display_temp',
                    'nakata_extruder_front_barrel_display_temp',
                    'nakata_extruder_head_display_temp']

# Raw column name -> display name used in the analysis table, charts and CSV.
_MA_COLUMN_LABELS = {
    'metered_weight': '米重',
    'screw_speed_actual': '螺杆转速',
    'head_pressure': '机头压力',
    'process_main_speed': '流程主速',
    'nakata_extruder_screw_display_temp': '螺杆温度',
    'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
    'nakata_extruder_front_barrel_display_temp': '前机筒温度',
    'nakata_extruder_head_display_temp': '机头温度',
}

# Minimum number of complete rows required before a model is trained.
_MA_MIN_ROWS = 30


def _ma_update_dates(qs):
    """Apply the quick-select range *qs* to this page's session-state dates."""
    st.session_state['ma_quick_select'] = qs
    days = _MA_QUICK_RANGE_DAYS.get(qs)
    if days is None:
        # Custom selection: only the highlighted button changes.
        return
    today = datetime.now().date()
    st.session_state['ma_start_date'] = today - timedelta(days=days)
    st.session_state['ma_end_date'] = today


def _ma_on_date_change():
    """Mark the range as custom once the user edits a date input directly."""
    st.session_state['ma_quick_select'] = "自定义"


def _ma_has_any_data(*frames):
    """Return True if at least one of *frames* is a non-empty DataFrame."""
    return any(df is not None and not df.empty for df in frames)


def _ma_asof_merge(left, right):
    """Attach to each row of *left* the nearest *right* reading within 1 minute.

    Both frames must have a 'time' column. *left* must already be sorted by
    'time'. *right* is cleaned and sorted here, because `merge_asof` rejects
    null keys and unsorted input.
    """
    return pd.merge_asof(
        left,
        right.dropna(subset=['time']).sort_values('time'),
        on='time',
        direction='nearest',
        tolerance=pd.Timedelta('1min'),
    )


def _ma_integrate_data(df_extruder_full, df_main_speed, df_temp):
    """Align main-process and temperature readings onto the extruder timeline.

    Returns a DataFrame with 'time', '米重' and the renamed feature columns.
    It is sorted by time, has a fresh 0..n-1 index, and contains only rows
    with a metered weight. Returns None when there is no extruder data.
    """
    if not _ma_has_any_data(df_extruder_full):
        return None
    merged = (
        df_extruder_full[['time', 'metered_weight', 'screw_speed_actual', 'head_pressure']]
        .dropna(subset=['time'])  # merge_asof raises on null keys
        .sort_values('time')
    )
    if _ma_has_any_data(df_main_speed):
        merged = _ma_asof_merge(merged, df_main_speed[['time', 'process_main_speed']])
    if _ma_has_any_data(df_temp):
        merged = _ma_asof_merge(merged, df_temp[['time'] + _MA_TEMP_COLUMNS])
    merged = merged.rename(columns=_MA_COLUMN_LABELS)
    # A unique positional index lets test-set rows be mapped back to timestamps.
    return merged.dropna(subset=['米重']).reset_index(drop=True)


def _ma_axis(title, color, **extra):
    """Build a y-axis layout dict whose title and ticks share *color*."""
    return dict(title=title, title_font=dict(color=color), tickfont=dict(color=color), **extra)


def _ma_build_trend_figure(df_extruder_full, df_main_speed, df_temp):
    """Plot the raw signals, one colour-coded y axis per quantity."""
    fig = go.Figure()
    if _ma_has_any_data(df_extruder_full):
        fig.add_trace(go.Scatter(
            x=df_extruder_full['time'], y=df_extruder_full['metered_weight'],
            name='米重 (Kg/m)', mode='lines', line=dict(color='blue', width=2)
        ))
        fig.add_trace(go.Scatter(
            x=df_extruder_full['time'], y=df_extruder_full['screw_speed_actual'],
            name='螺杆转速 (RPM)', mode='lines', line=dict(color='green', width=1.5), yaxis='y2'
        ))
        fig.add_trace(go.Scatter(
            x=df_extruder_full['time'], y=df_extruder_full['head_pressure'],
            name='机头压力', mode='lines', line=dict(color='orange', width=1.5), yaxis='y3'
        ))
    if _ma_has_any_data(df_main_speed):
        fig.add_trace(go.Scatter(
            x=df_main_speed['time'], y=df_main_speed['process_main_speed'],
            name='流程主速 (M/Min)', mode='lines', line=dict(color='red', width=1.5), yaxis='y4'
        ))
    if _ma_has_any_data(df_temp):
        fig.add_trace(go.Scatter(
            x=df_temp['time'], y=df_temp['nakata_extruder_screw_display_temp'],
            name='螺杆温度 (°C)', mode='lines', line=dict(color='purple', width=1), yaxis='y5'
        ))
    fig.update_layout(
        title='原始数据趋势',
        xaxis=dict(
            title='时间',
            rangeslider=dict(visible=True),
            type='date',
            # Free-anchored y axes at 0.15 / 0.75 / 0.85 must lie outside the
            # plotting domain, otherwise they are drawn on top of the data.
            domain=[0.2, 0.7],
        ),
        yaxis=_ma_axis('米重 (Kg/m)', 'blue'),
        yaxis2=_ma_axis('螺杆转速 (RPM)', 'green', overlaying='y', side='right'),
        yaxis3=_ma_axis('机头压力', 'orange', overlaying='y', side='right',
                        anchor='free', position=0.85),
        yaxis4=_ma_axis('流程主速 (M/Min)', 'red', overlaying='y', side='right',
                        anchor='free', position=0.75),
        yaxis5=_ma_axis('温度 (°C)', 'purple', overlaying='y', side='left',
                        anchor='free', position=0.15),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=600,
        margin=dict(l=100, r=200, t=100, b=100),
        hovermode='x unified'
    )
    return fig


def _ma_fit_model(model_type, X_train, y_train):
    """Fit the selected regressor and return ``(model, predict)``.

    SVR and MLP are scale-sensitive. They are trained on standardized features
    and a min-max scaled target. The tree ensembles use raw values. ``predict``
    takes a DataFrame with the training columns and returns a 1-D array in
    the original 米重 units (Kg/m).

    Raises:
        ValueError: if *model_type* is not one of the supported names.
    """
    if model_type == 'RandomForest':
        model, needs_scaling = RandomForestRegressor(n_estimators=100, random_state=42), False
    elif model_type == 'GradientBoosting':
        model, needs_scaling = GradientBoostingRegressor(n_estimators=100, random_state=42), False
    elif model_type == 'SVR':
        model, needs_scaling = SVR(kernel='rbf', C=1.0, gamma='scale'), True
    elif model_type == 'MLP':
        model, needs_scaling = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42), True
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    if not needs_scaling:
        model.fit(X_train, y_train)
        return model, lambda X: np.asarray(model.predict(X))

    y_column = y_train.to_numpy().reshape(-1, 1)
    scaler_X = StandardScaler().fit(X_train)
    scaler_y = MinMaxScaler().fit(y_column)
    model.fit(scaler_X.transform(X_train), scaler_y.transform(y_column).ravel())

    def predict(X):
        y_scaled = np.asarray(model.predict(scaler_X.transform(X)))
        return scaler_y.inverse_transform(y_scaled.reshape(-1, 1)).ravel()

    return model, predict


def _ma_render_prediction_analysis(df_analysis, model_type):
    """Train the selected model on *df_analysis* and render all report sections.

    This covers the metrics, the actual-vs-predicted chart, residuals, feature
    importance, the interactive prediction form, and the preview and export of
    *df_analysis*. Any training or prediction error is shown with `st.error`.
    """
    missing_features = [f for f in _MA_DEFAULT_FEATURES if f not in df_analysis.columns]
    if missing_features:
        st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
        return
    try:
        df_clean = df_analysis.dropna(subset=_MA_DEFAULT_FEATURES + ['米重'])
        if len(df_clean) < _MA_MIN_ROWS:
            st.warning("数据量不足,无法进行有效的预测分析")
            return
        X = df_clean[_MA_DEFAULT_FEATURES]
        y = df_clean['米重']
        # NOTE(review): a shuffled split on time-ordered rows lets neighbouring
        # samples leak between train and test, so scores may be optimistic.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        model, predict = _ma_fit_model(model_type, X_train, y_train)
        y_pred = predict(X_test)
        y_true = y_test.to_numpy()

        # --- Model performance ---
        r2 = r2_score(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        mae = mean_absolute_error(y_true, y_pred)
        rmse = np.sqrt(mse)
        metrics_cols = st.columns(2)
        with metrics_cols[0]:
            st.metric("R² 得分", f"{r2:.4f}")
            st.metric("均方误差 (MSE)", f"{mse:.6f}")
        with metrics_cols[1]:
            st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
            st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")

        # --- Actual vs. predicted, on the real timestamps of the test rows ---
        st.subheader("🔄 实际值与预测值对比")
        compare_df = pd.DataFrame({
            '时间': df_clean.loc[y_test.index, 'time'].to_numpy(),
            '实际值': y_true,
            '预测值': y_pred,
        }).sort_values('时间')
        fig_compare = go.Figure()
        fig_compare.add_trace(go.Scatter(
            x=compare_df['时间'], y=compare_df['实际值'], name='实际值',
            mode='lines+markers', line=dict(color='blue', width=2)
        ))
        fig_compare.add_trace(go.Scatter(
            x=compare_df['时间'], y=compare_df['预测值'], name='预测值',
            mode='lines+markers', line=dict(color='red', width=2, dash='dash')
        ))
        fig_compare.update_layout(
            title=f'测试集: 实际米重 vs 预测米重 ({model_type})',
            xaxis=dict(title='时间'),
            yaxis=dict(title='米重 (Kg/m)'),
            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
            height=400
        )
        st.plotly_chart(fig_compare, width='stretch')

        # --- Residuals ---
        st.subheader("📉 残差分析")
        residuals = y_true - y_pred
        fig_residual = go.Figure()
        fig_residual.add_trace(go.Scatter(
            x=y_pred, y=residuals, mode='markers',
            marker=dict(color='green', size=8, opacity=0.6)
        ))
        fig_residual.add_shape(
            type="line", x0=y_pred.min(), y0=0, x1=y_pred.max(), y1=0,
            line=dict(color="red", width=2, dash="dash")
        )
        fig_residual.update_layout(
            title='残差图', xaxis=dict(title='预测值'), yaxis=dict(title='残差'), height=400
        )
        st.plotly_chart(fig_residual, width='stretch')

        # --- Feature importance (tree ensembles only) ---
        if model_type in ['RandomForest', 'GradientBoosting']:
            st.subheader("⚖️ 特征重要性分析")
            feature_importance = pd.DataFrame({
                '特征': X_train.columns,
                '重要性': model.feature_importances_
            }).sort_values('重要性', ascending=False)
            fig_importance = px.bar(
                feature_importance, x='特征', y='重要性', title='特征重要性',
                color='重要性', color_continuous_scale='viridis'
            )
            fig_importance.update_layout(xaxis=dict(tickangle=-45), height=400)
            st.plotly_chart(fig_importance, width='stretch')

        # --- Interactive single-point prediction ---
        st.subheader("🔮 米重预测")
        st.write("输入特征值进行米重预测:")
        predict_cols = st.columns(2)
        input_features = {}
        for i, feature in enumerate(_MA_DEFAULT_FEATURES):
            with predict_cols[i % 2]:
                column = df_clean[feature]
                input_features[feature] = st.number_input(
                    f"{feature}",
                    key=f"ma_pred_{feature}",
                    value=float(column.mean()),
                    min_value=float(column.min()),
                    max_value=float(column.max()),
                    step=0.1
                )
        if st.button("预测米重"):
            # Column order matches the training features (dict keeps insertion order).
            predicted_weight = float(predict(pd.DataFrame([input_features]))[0])
            st.success(f"预测米重: {predicted_weight:.4f} Kg/m")

        # --- Data preview ---
        st.subheader("🔍 数据预览")
        st.dataframe(df_analysis.head(20), width='stretch')

        # --- CSV export ---
        st.subheader("💾 导出数据")
        st.download_button(
            label="导出整合后的数据 (CSV)",
            data=df_analysis.to_csv(index=False),
            file_name=f"metered_weight_advanced_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            mime="text/csv",
            help="点击按钮导出整合后的米重分析数据"
        )
    except Exception as e:
        # UI boundary: surface any modelling failure to the user instead of crashing the page.
        st.error(f"模型训练或预测失败: {str(e)}")


def show_metered_weight_advanced():
    """Render the "米重高级预测分析" (advanced metered-weight prediction) page.

    Clicking the query button fetches extruder, main-process speed and
    temperature data for the selected date range, and caches the result in
    session state. Every later rerun re-integrates the cached data, draws the
    raw trend chart, trains the selected model and renders the prediction
    report.
    """
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()

    st.title("米重高级预测分析")

    # Seed this page's session state on first visit only.
    today = datetime.now().date()
    session_defaults = {
        'ma_start_date': today - timedelta(days=7),
        'ma_end_date': today,
        'ma_quick_select': "最近7天",
        'ma_model_type': 'RandomForest',
        'ma_sequence_length': 10,
    }
    for key, value in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # --- Query configuration ---
    with st.expander("🔍 查询配置", expanded=True):
        # Custom CSS so the control columns wrap responsively on narrow screens.
        st.markdown("""
            <style>
            /* 强制列容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* 针对日期输入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """, unsafe_allow_html=True)
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        for col, option in zip(cols, _MA_QUICK_OPTIONS):
            with col:
                # Highlight the button matching the current quick selection.
                button_type = "primary" if st.session_state['ma_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_ma_{option}", width='stretch', type=button_type):
                    _ma_update_dates(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="ma_start_date",
                on_change=_ma_on_date_change
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="ma_end_date",
                on_change=_ma_on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 开始分析", key="ma_query", width='stretch')

        # --- Model configuration ---
        st.markdown("---")
        st.write("🤖 **模型配置**")
        model_cols = st.columns(2)
        with model_cols[0]:
            model_type = st.selectbox(
                "模型类型",
                options=['RandomForest', 'GradientBoosting', 'SVR', 'MLP'],
                key="ma_model_type",
                help="选择用于预测的模型类型"
            )

    # Whole-day bounds for the selected dates.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())

    # --- Fetch and cache data on explicit query ---
    if query_button:
        if start_dt > end_dt:
            st.warning("开始日期不能晚于结束日期,请调整查询条件。")
            return
        with st.spinner("正在获取数据..."):
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
            if not _ma_has_any_data(df_extruder_full, df_main_speed, df_temp):
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                # Drop stale results so an old range is not analysed on the next rerun.
                for key in _MA_CACHE_KEYS + ['last_query_start', 'last_query_end']:
                    if key in st.session_state:
                        del st.session_state[key]
                return
            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt

    if not all(key in st.session_state for key in _MA_CACHE_KEYS):
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
        return

    # --- Analyse cached data ---
    with st.spinner("正在分析数据..."):
        df_extruder_full, df_main_speed, df_temp = (st.session_state[key] for key in _MA_CACHE_KEYS)
        if not _ma_has_any_data(df_extruder_full, df_main_speed, df_temp):
            st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
            return

        df_analysis = _ma_integrate_data(df_extruder_full, df_main_speed, df_temp)
        if df_analysis is None or df_analysis.empty:
            st.warning("数据整合失败,请检查数据质量或调整时间范围。")
            return

        st.subheader("📈 原始数据趋势图")
        st.plotly_chart(
            _ma_build_trend_figure(df_extruder_full, df_main_speed, df_temp),
            width='stretch',
            config={'scrollZoom': True}
        )

        st.subheader("📊 高级预测分析")
        _ma_render_prediction_analysis(df_analysis, model_type)
app/pages/metered_weight_correlation.py
对比新文件
@@ -0,0 +1,777 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
# Raw temperature column, display name and trend-chart colour for each extruder temperature.
_MC_TEMP_SERIES = (
    ('nakata_extruder_screw_display_temp', '螺杆温度', 'purple'),
    ('nakata_extruder_rear_barrel_display_temp', '后机筒温度', 'pink'),
    ('nakata_extruder_front_barrel_display_temp', '前机筒温度', 'brown'),
    ('nakata_extruder_head_display_temp', '机头温度', 'gray'),
)

# Raw column name -> display name used by every table and chart on this page.
_MC_RENAME = {
    'metered_weight': '米重',
    'screw_speed_actual': '螺杆转速',
    'head_pressure': '机头压力',
    'process_main_speed': '流程主速',
    **{raw: label for raw, label, _color in _MC_TEMP_SERIES},
}

# Columns of the correlation matrix and data previews, in display order.
_MC_ANALYSIS_COLS = ('米重', '螺杆转速', '机头压力', '流程主速', '螺杆温度', '后机筒温度', '前机筒温度', '机头温度')

# (display name, unit) of every parameter that is correlated against 米重.
_MC_PARAMS = (
    ('螺杆转速', 'RPM'),
    ('机头压力', ''),
    ('流程主速', 'M/Min'),
    ('螺杆温度', '°C'),
    ('后机筒温度', '°C'),
    ('前机筒温度', '°C'),
    ('机头温度', '°C'),
)

# Session-state keys holding the last query result.
# NOTE(review): these keys are not page-prefixed; if another page writes the same keys,
# this page will render that page's frames — confirm this sharing is intended.
_MC_CACHE_KEYS = ('cached_extruder_full', 'cached_main_speed', 'cached_temp')
_MC_QUERY_STATE_KEYS = _MC_CACHE_KEYS + ('last_query_start', 'last_query_end')

# Quick-select label -> days back from today; "自定义" (custom) keeps the chosen dates.
_MC_QUICK_RANGES = {"今天": 0, "最近3天": 3, "最近7天": 7, "最近30天": 30}
_MC_QUICK_OPTIONS = ("今天", "最近3天", "最近7天", "最近30天", "自定义")

# merge_asof pairs each 米重 sample with the nearest other sample at most this far away.
_MC_MERGE_TOLERANCE = '1min'

# Diverging blue-white-red scale for correlation heatmaps (centred on 0).
_MC_CORR_SCALE = ["#0000FF", "#FFFFFF", "#FF0000"]

# Responsive layout for the query panel: let columns wrap, widen the two date inputs.
_MC_PANEL_CSS = """
            <style>
            /* 强制列容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* 针对日期输入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """


def _mc_has_rows(df):
    """Return True when *df* is a non-empty DataFrame (None counts as empty)."""
    return df is not None and not df.empty


def _mc_init_session_state():
    """Seed the page's session-state defaults (last 7 days, zero offset) on first visit."""
    today = datetime.now().date()
    defaults = {
        'mc_start_date': today - timedelta(days=7),
        'mc_end_date': today,
        'mc_quick_select': "最近7天",
        'mc_time_offset': 0.0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _mc_apply_quick_select(qs):
    """Record quick-select choice *qs*; for preset ranges also set start/end dates."""
    st.session_state['mc_quick_select'] = qs
    days = _MC_QUICK_RANGES.get(qs)
    if days is None:  # "自定义": keep whatever dates are currently chosen
        return
    today = datetime.now().date()
    st.session_state['mc_start_date'] = today - timedelta(days=days)
    st.session_state['mc_end_date'] = today


def _mc_mark_custom_range():
    """on_change callback: a manual date edit switches the quick-select to "自定义"."""
    st.session_state['mc_quick_select'] = "自定义"


def _mc_render_query_panel():
    """Render the query expander and return ``(start_date, end_date, query_clicked)``.

    The time-offset slider value is stored in ``st.session_state['mc_time_offset']``.
    """
    with st.expander("🔍 查询配置", expanded=True):
        st.markdown(_MC_PANEL_CSS, unsafe_allow_html=True)
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        # Quick-select buttons must be handled before the date inputs are instantiated:
        # Streamlit forbids writing a widget's session-state key after the widget exists.
        for col, option in zip(cols, _MC_QUICK_OPTIONS):
            with col:
                button_type = "primary" if st.session_state['mc_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_mc_{option}", width='stretch', type=button_type):
                    _mc_apply_quick_select(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="mc_start_date",
                on_change=_mc_mark_custom_range
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="mc_end_date",
                on_change=_mc_mark_custom_range
            )
        with cols[7]:
            query_clicked = st.button("🚀 开始分析", key="mc_query", width='stretch')

        # Time alignment between the 米重 signal and the other process signals.
        st.markdown("---")
        offset_cols = st.columns([2, 4, 2])
        with offset_cols[0]:
            st.write("⏱️ **数据对齐调整**")
        with offset_cols[1]:
            time_offset = st.slider(
                "时间偏移 (分钟)",
                min_value=0.0,
                max_value=5.0,
                value=st.session_state['mc_time_offset'],
                step=0.1,
                help="调整主流程和温度数据的时间偏移,使其与挤出机米重数据对齐。"
            )
            st.session_state['mc_time_offset'] = time_offset
        with offset_cols[2]:
            st.write(f"当前偏移: {time_offset} 分钟")
    return start_date, end_date, query_clicked


def _mc_fetch_and_cache(extruder_service, main_process_service, start_dt, end_dt):
    """Query all sources for [start_dt, end_dt] and cache the frames in session state.

    Returns False (after showing a warning and clearing any stale cache) when no
    source returned data, True otherwise.
    """
    with st.spinner("正在获取数据..."):
        df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
        df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
        df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)

        if not any(_mc_has_rows(df) for df in (df_extruder_full, df_main_speed, df_temp)):
            st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
            for key in _MC_QUERY_STATE_KEYS:
                if key in st.session_state:
                    del st.session_state[key]
            return False

        st.session_state['cached_extruder_full'] = df_extruder_full
        st.session_state['cached_main_speed'] = df_main_speed
        st.session_state['cached_temp'] = df_temp
        st.session_state['last_query_start'] = start_dt
        st.session_state['last_query_end'] = end_dt
    return True


def _mc_integrate_data(df_extruder_filtered, df_main_speed, df_temp):
    """Align every signal onto the (shifted) 米重 timeline.

    *df_extruder_filtered* must contain ``weight_time`` (the shifted 米重 timestamp),
    ``time``, ``metered_weight``, ``screw_speed_actual`` and ``head_pressure``.
    Each 米重 sample is paired with the nearest sample of every other signal within
    ``_MC_MERGE_TOLERANCE``; unmatched values stay NaN.  Optional sources (main speed,
    temperatures) are merged only when present, and only the temperature columns that
    exist are used.  Columns are renamed to their display names (``_MC_RENAME``) and
    rows without 米重 are dropped.

    Returns None when there is no extruder data to anchor the timeline.
    """
    if not _mc_has_rows(df_extruder_filtered):
        return None
    tolerance = pd.Timedelta(_MC_MERGE_TOLERANCE)

    def merge_nearest(left, right):
        # merge_asof needs both sides sorted on the key.
        return pd.merge_asof(
            left.sort_values('time'),
            right.sort_values('time'),
            on='time',
            direction='nearest',
            tolerance=tolerance
        )

    # The shifted 米重 time becomes the reference timeline.
    df_weight = df_extruder_filtered[['weight_time', 'metered_weight']].rename(columns={'weight_time': 'time'})
    # Screw speed and head pressure keep their original timestamps; one as-of merge
    # picks the same nearest extruder row for both.
    df_merged = merge_nearest(
        df_weight,
        df_extruder_filtered[['time', 'screw_speed_actual', 'head_pressure']]
    )
    if _mc_has_rows(df_main_speed):
        df_merged = merge_nearest(df_merged, df_main_speed[['time', 'process_main_speed']])
    if _mc_has_rows(df_temp):
        temp_cols = ['time'] + [raw for raw, _label, _color in _MC_TEMP_SERIES if raw in df_temp.columns]
        df_merged = merge_nearest(df_merged, df_temp[temp_cols])

    df_merged = df_merged.rename(columns=_MC_RENAME)
    return df_merged.dropna(subset=['米重'])


def _mc_build_trend_figure(df_extruder_filtered, df_main_speed, df_temp, offset_minutes):
    """Build the multi-axis raw-data trend chart (米重 drawn on its shifted timeline)."""
    fig = go.Figure()
    if _mc_has_rows(df_extruder_filtered):
        fig.add_trace(go.Scatter(
            x=df_extruder_filtered['weight_time'],  # shifted time
            y=df_extruder_filtered['metered_weight'],
            name='米重 (Kg/m) [已偏移]',
            mode='lines',
            line=dict(color='blue', width=2)
        ))
        fig.add_trace(go.Scatter(
            x=df_extruder_filtered['time'],  # original time
            y=df_extruder_filtered['screw_speed_actual'],
            name='螺杆转速 (RPM)',
            mode='lines',
            line=dict(color='green', width=1.5),
            yaxis='y2'
        ))
        fig.add_trace(go.Scatter(
            x=df_extruder_filtered['time'],  # original time; values > 2 already filtered out
            y=df_extruder_filtered['head_pressure'],
            name='机头压力 (≤2)',
            mode='lines',
            line=dict(color='orange', width=1.5),
            yaxis='y3'
        ))
    if _mc_has_rows(df_main_speed):
        fig.add_trace(go.Scatter(
            x=df_main_speed['time'],
            y=df_main_speed['process_main_speed'],
            name='流程主速 (M/Min)',
            mode='lines',
            line=dict(color='red', width=1.5),
            yaxis='y4'
        ))
    if _mc_has_rows(df_temp):
        for raw_col, label, color in _MC_TEMP_SERIES:
            if raw_col in df_temp.columns:
                fig.add_trace(go.Scatter(
                    x=df_temp['time'],
                    y=df_temp[raw_col],
                    name=f'{label} (°C)',
                    mode='lines',
                    line=dict(color=color, width=1),
                    yaxis='y5'
                ))

    fig.update_layout(
        title=f'原始数据趋势 (米重向前偏移 {offset_minutes} 分钟)',
        xaxis=dict(
            title='时间',
            rangeslider=dict(visible=True),
            type='date'
        ),
        yaxis=dict(
            title='米重 (Kg/m)',
            title_font=dict(color='blue'),
            tickfont=dict(color='blue')
        ),
        yaxis2=dict(
            title='螺杆转速 (RPM)',
            title_font=dict(color='green'),
            tickfont=dict(color='green'),
            overlaying='y',
            side='right'
        ),
        yaxis3=dict(
            title='机头压力',
            title_font=dict(color='orange'),
            tickfont=dict(color='orange'),
            overlaying='y',
            side='right',
            anchor='free',
            position=0.85
        ),
        yaxis4=dict(
            title='流程主速 (M/Min)',
            title_font=dict(color='red'),
            tickfont=dict(color='red'),
            overlaying='y',
            side='right',
            anchor='free',
            position=0.75
        ),
        yaxis5=dict(
            title='温度 (°C)',
            title_font=dict(color='purple'),
            tickfont=dict(color='purple'),
            overlaying='y',
            side='left',
            anchor='free',
            position=0.15
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        height=600,
        margin=dict(l=100, r=200, t=100, b=100),
        hovermode='x unified',
        dragmode='select',
    )
    return fig


def _mc_box_time_range(box, times):
    """Return the box selection's x-extent as an ordered ``(start, end)`` Timestamp pair.

    The two x bounds are sorted, so a right-to-left drag works too.  When *times* is
    tz-aware, naive bounds are localized to its timezone so they compare cleanly.
    This assumes plotly reports wall-clock times — TODO confirm for tz-aware data.
    """
    tz = times.dt.tz
    start, end = sorted(pd.Timestamp(value, tz=tz) for value in box['x'][:2])
    return start, end


def _mc_extract_selection(selection, df_analysis):
    """Return the rows of *df_analysis* inside the trend-chart box selection, or None.

    Shows a usage hint when nothing is box-selected, and a warning when *df_analysis*
    has no ``time`` column.
    """
    if not (selection.selection and selection.selection.box):
        st.info("请使用矩形框选工具选择时间范围(已自动启用选择模式)")
        return None
    if 'time' not in df_analysis.columns:
        st.warning("数据中缺少time列,无法进行范围过滤")
        return None
    range_start, range_end = _mc_box_time_range(selection.selection.box[0], df_analysis['time'])
    # df_analysis['time'] is the shifted 米重 time, matching the x-values of the 米重 trace.
    selected = df_analysis[df_analysis['time'].between(range_start, range_end)].copy()
    st.write(f"选中范围内的数据点数量: {len(selected)}")
    return selected


def _mc_safe_corr(df, param_name, target='米重'):
    """Pearson r between *param_name* and *target* over rows where both are present.

    Returns None instead of NaN/raising when a column is missing, fewer than two
    complete rows exist, the values are not numeric, or r is undefined (a constant column).
    """
    if param_name not in df.columns or target not in df.columns:
        return None
    valid = df[[param_name, target]].dropna()
    if len(valid) < 2:
        return None
    try:
        # A constant column gives 0/0 inside corrcoef; silence that and map NaN to None below.
        with np.errstate(divide='ignore', invalid='ignore'):
            coef = float(np.corrcoef(valid[target], valid[param_name])[0, 1])
    except (TypeError, ValueError, ArithmeticError):
        return None
    return coef if np.isfinite(coef) else None


def _mc_correlation_level(coef):
    """Classify |r| as 强 (>0.7), 中等 (>0.3) or 弱; None/NaN gives 无法计算."""
    if coef is None or not np.isfinite(coef):
        return '无法计算'
    strength = abs(coef)
    if strength > 0.7:
        return '强'
    if strength > 0.3:
        return '中等'
    return '弱'


def _mc_add_trend_line(fig, df, param_name, target='米重'):
    """Best-effort: overlay a least-squares line of *target* on *param_name* onto *fig*."""
    valid = df[[param_name, target]].dropna()
    if len(valid) < 2 or valid[param_name].nunique() < 2:
        return  # a line needs at least two distinct x values
    try:
        x = valid[param_name].astype(float)
        slope, intercept = np.polyfit(x, valid[target].astype(float), 1)
    except (TypeError, ValueError, np.linalg.LinAlgError):
        return  # the scatter is still useful without a fit line
    x_ends = np.array([x.min(), x.max()])
    fig.add_trace(go.Scatter(
        x=x_ends,
        y=slope * x_ends + intercept,
        mode='lines',
        name='趋势线',
        line=dict(color='red', width=2)
    ))


def _mc_render_param_scatter(df, param_name, unit, title_suffix):
    """Render one "<param> vs 米重" scatter with trend line and correlation annotation."""
    if param_name not in df.columns:
        st.warning(f"数据中缺少 {param_name} 列")
        return
    corr_coef = _mc_safe_corr(df, param_name)
    fig = px.scatter(
        df,
        x=param_name,
        y='米重',
        title=f"{param_name} vs 米重{title_suffix}",
        labels={param_name: f"{param_name} ({unit})" if unit else param_name, '米重': '米重 (Kg/m)'}
    )
    _mc_add_trend_line(fig, df, param_name)
    corr_text = f"相关系数: {corr_coef:.4f}" if corr_coef is not None else "相关系数: 无法计算"
    fig.add_annotation(
        x=0.05, y=0.95,
        xref='paper', yref='paper',
        text=corr_text,
        showarrow=False,
        font=dict(size=12, color="black"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1
    )
    st.plotly_chart(fig, width='stretch')


def _mc_render_scatter_grid(df, title_suffix=""):
    """Render the scatter of every parameter in ``_MC_PARAMS`` against 米重, two per row."""
    for start in range(0, len(_MC_PARAMS), 2):
        row_cols = st.columns(2)
        for slot, (param_name, unit) in zip(row_cols, _MC_PARAMS[start:start + 2]):
            with slot:
                _mc_render_param_scatter(df, param_name, unit, title_suffix)


def _mc_render_corr_heatmap(df, cols, title, height, margin):
    """Render the correlation-matrix heatmap of *cols* in *df*."""
    fig = px.imshow(
        df[cols].corr(),
        text_auto=True,
        aspect="auto",
        title=title,
        color_continuous_scale=_MC_CORR_SCALE,
        color_continuous_midpoint=0,
        labels=dict(color="相关系数")
    )
    fig.update_layout(
        height=height,
        margin=dict(l=margin, r=margin, t=margin, b=margin),
        xaxis=dict(tickangle=-45),
        yaxis=dict(tickangle=0)
    )
    st.plotly_chart(fig, width='stretch')


def _mc_correlation_table(df_analysis):
    """Build the per-parameter correlation table, strongest |r| first.

    Parameters missing from *df_analysis* are skipped; rows whose coefficient cannot
    be computed are labelled 无法计算 and sorted last.
    """
    rows = []
    for param_name, _unit in _MC_PARAMS:
        if param_name not in df_analysis.columns:
            continue
        coef = _mc_safe_corr(df_analysis, param_name)
        rows.append({'参数': param_name, '相关系数': coef, '相关程度': _mc_correlation_level(coef)})
    corr_df = pd.DataFrame(rows, columns=['参数', '相关系数', '相关程度'])
    if corr_df.empty:
        return corr_df
    strength = pd.to_numeric(corr_df['相关系数'], errors='coerce').abs()
    return corr_df.loc[strength.sort_values(ascending=False, na_position='last').index]


def _mc_render_selection_details(selected_data, analysis_cols):
    """Render heatmap, scatters, summary metrics and preview for the box-selected rows."""
    st.subheader("📊 框选范围细节分析")
    _mc_render_corr_heatmap(selected_data, analysis_cols, "框选范围参数相关性矩阵", height=400, margin=80)

    st.subheader("📈 框选范围参数与米重散点图")
    _mc_render_scatter_grid(selected_data, "（框选范围）")

    st.subheader("📊 框选范围数据摘要")
    metrics = (
        ('米重', "平均米重", " Kg/m"),
        ('螺杆转速', "平均螺杆转速", " RPM"),
        ('流程主速', "平均流程主速", " M/Min"),
        ('机头压力', "平均机头压力", ""),
    )
    for slot, (col_name, label, suffix) in zip(st.columns(4), metrics):
        with slot:
            if col_name in selected_data.columns:
                st.metric(label, f"{selected_data[col_name].mean():.2f}{suffix}")

    st.subheader("🔍 框选范围数据预览")
    st.dataframe(selected_data[analysis_cols].head(10), width='stretch')


def _mc_render_analysis(offset_minutes):
    """Align the cached frames using *offset_minutes* and render every analysis section."""
    df_extruder_full = st.session_state['cached_extruder_full']
    df_main_speed = st.session_state['cached_main_speed']
    df_temp = st.session_state['cached_temp']

    df_extruder_filtered = None
    if _mc_has_rows(df_extruder_full):
        # Keep head-pressure readings <= 2. .copy() so adding weight_time below does
        # not raise SettingWithCopyWarning.
        df_extruder_filtered = df_extruder_full[df_extruder_full['head_pressure'] <= 2].copy()
        # Only the 米重 samples are shifted; every other signal keeps its own timestamp.
        df_extruder_filtered['weight_time'] = df_extruder_filtered['time'] - timedelta(minutes=offset_minutes)

    if not any(_mc_has_rows(df) for df in (df_extruder_filtered, df_main_speed, df_temp)):
        st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
        return

    df_analysis = _mc_integrate_data(df_extruder_filtered, df_main_speed, df_temp)
    if not _mc_has_rows(df_analysis):
        st.warning("数据整合失败,请检查数据质量或调整时间范围。")
        return

    # Main-speed / temperature columns only exist when those sources returned data;
    # indexing with the full list would raise KeyError.
    analysis_cols = [col for col in _MC_ANALYSIS_COLS if col in df_analysis.columns]

    st.subheader("📈 原始数据趋势图")
    fig_trend = _mc_build_trend_figure(df_extruder_filtered, df_main_speed, df_temp, offset_minutes)
    selection = st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True}, on_select='rerun')

    selected_data = _mc_extract_selection(selection, df_analysis)
    if _mc_has_rows(selected_data) and st.button("🔍 细节分析"):
        _mc_render_selection_details(selected_data, analysis_cols)

    st.subheader("📊 相关性矩阵热力图")
    _mc_render_corr_heatmap(df_analysis, analysis_cols, "参数相关性矩阵", height=500, margin=100)

    st.subheader("📈 参数与米重散点图")
    _mc_render_scatter_grid(df_analysis)

    st.subheader("📋 相关性统计")
    st.dataframe(_mc_correlation_table(df_analysis), width='stretch')

    st.subheader("🔍 数据预览")
    st.dataframe(df_analysis[analysis_cols].head(20), width='stretch')


def show_metered_weight_correlation():
    """Render the "米重相关性分析" (metered-weight correlation) page.

    "开始分析" fetches extruder, main-process speed and temperature data for the chosen
    dates and caches it in ``st.session_state``.  Every rerun (e.g. after moving the
    time-offset slider) re-aligns the cached data and renders the trend chart,
    correlation heatmap, scatter plots, correlation table and a preview.  A box
    selection on the trend chart enables a "细节分析" view of just that window.
    """
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()

    st.title("米重相关性分析")
    _mc_init_session_state()
    start_date, end_date, query_clicked = _mc_render_query_panel()

    # Whole-day bounds: 00:00 of the start date through the last instant of the end date.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())

    if query_clicked and not _mc_fetch_and_cache(extruder_service, main_process_service, start_dt, end_dt):
        return

    if not all(key in st.session_state for key in _MC_CACHE_KEYS):
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
        return

    with st.spinner("正在分析数据相关性..."):
        _mc_render_analysis(st.session_state['mc_time_offset'])
app/pages/metered_weight_dashboard.py
对比新文件
@@ -0,0 +1,430 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
def _has_rows(df):
    """Return True when *df* is not ``None`` and holds at least one row.

    Every data source on this page is treated as "no data" when it is either
    ``None`` or an empty frame.
    """
    return df is not None and not df.empty


def _add_line(fig, x, y, name, yaxis, **line_style):
    """Append a ``mode='lines'`` scatter trace to *fig*, bound to *yaxis*.

    ``line_style`` is forwarded as the trace's ``line`` dict (color/width/dash).
    """
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        name=name,
        mode='lines',
        line=dict(**line_style),
        yaxis=yaxis,
    ))


def _multi_axis_layout(title, temp_axis_title, right_margin):
    """Return the layout shared by both trend charts on this page.

    Axes: y1 metered weight (left), y2 screw speed (right), y3 head pressure
    (free, x=0.85), y4 speeds in M/Min (free, x=0.75), y5 temperature
    (free, x=0.15, left side).

    Args:
        title: Chart title.
        temp_axis_title: Title of the temperature axis (y5).
        right_margin: Right margin in pixels, leaving room for the right axes.
    """
    def overlay_axis(axis_title, color, side, position=None):
        axis = dict(
            title=axis_title,
            title_font=dict(color=color),
            tickfont=dict(color=color),
            overlaying='y',
            side=side,
        )
        if position is not None:
            # Free-floating axis placed at a fixed paper-x position.
            axis.update(anchor='free', position=position)
        return axis

    # NOTE(review): xaxis.domain is left at its default [0, 1], so the free
    # axes at 0.15/0.75/0.85 are drawn inside the plot area; consider
    # narrowing the domain if they overlap the traces.
    return dict(
        title=title,
        xaxis=dict(
            title='时间',
            rangeslider=dict(visible=True),
            type='date'
        ),
        yaxis=dict(
            title='米重 (Kg/m)',
            title_font=dict(color='blue'),
            tickfont=dict(color='blue')
        ),
        yaxis2=overlay_axis('挤出机转速 (RPM)', 'green', 'right'),
        yaxis3=overlay_axis('机头压力', 'orange', 'right', 0.85),
        yaxis4=overlay_axis('流程主速 (M/Min)', 'red', 'right', 0.75),
        yaxis5=overlay_axis(temp_axis_title, 'purple', 'left', 0.15),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        height=600,
        margin=dict(l=100, r=right_margin, t=100, b=100),
        hovermode='x unified',
    )


def _build_actual_params_figure(df_extruder, df_main_speed, df_temp, df_motor):
    """Build the chart of metered weight against *actual* process values.

    Any input may be ``None`` or empty; its traces are then omitted. The
    pull-out line speeds share the M/Min axis (y4) with the process main speed.
    """
    fig = go.Figure()
    if _has_rows(df_extruder):
        _add_line(fig, df_extruder['time'], df_extruder['metered_weight'],
                  '米重 (Kg/m)', 'y1', color='blue', width=2)
        _add_line(fig, df_extruder['time'], df_extruder['screw_speed_actual'],
                  '挤出机实际转速 (RPM)', 'y2', color='green', width=1.5)
        _add_line(fig, df_extruder['time'], df_extruder['head_pressure'],
                  '挤出机机头压力', 'y3', color='orange', width=1.5)
    if _has_rows(df_main_speed):
        _add_line(fig, df_main_speed['time'], df_main_speed['process_main_speed'],
                  '流程主速 (M/Min)', 'y4', color='red', width=1.5)
    if _has_rows(df_temp):
        temp_display_fields = {
            'nakata_extruder_screw_display_temp': '螺杆显示 (°C)',
            'nakata_extruder_rear_barrel_display_temp': '后机筒显示 (°C)',
            'nakata_extruder_front_barrel_display_temp': '前机筒显示 (°C)',
            'nakata_extruder_head_display_temp': '机头显示 (°C)',
        }
        for field, label in temp_display_fields.items():
            # Skip absent columns instead of failing the whole page.
            if field in df_temp.columns:
                _add_line(fig, df_temp['time'], df_temp[field], label, 'y5', width=1)
    if _has_rows(df_motor):
        _add_line(fig, df_motor['time'], df_motor['m1_line_speed'],
                  '拉出一段线速 (M/Min)', 'y4', color='cyan', width=1.5)
        _add_line(fig, df_motor['time'], df_motor['m2_line_speed'],
                  '拉出二段线速 (M/Min)', 'y4', color='teal', width=1.5)
    # No dedicated line-speed axis: those traces are bound to y4 above, so
    # defining an extra axis would only render an empty, mislabelled one.
    fig.update_layout(**_multi_axis_layout('米重与实际参数趋势分析', '温度显示 (°C)', 200))
    return fig


def _build_setting_params_figure(df_extruder, df_main_speed, df_temp):
    """Build the chart of metered weight against *set-point* values.

    Set-point traces are dashed. Any input may be ``None`` or empty; its
    traces are then omitted.
    """
    fig = go.Figure()
    if _has_rows(df_extruder):
        _add_line(fig, df_extruder['time'], df_extruder['metered_weight'],
                  '米重 (Kg/m)', 'y1', color='blue', width=2)
        _add_line(fig, df_extruder['time'], df_extruder['screw_speed_set'],
                  '挤出机设定转速 (RPM)', 'y2', color='green', width=1.5, dash='dash')
        _add_line(fig, df_extruder['time'], df_extruder['head_pressure'],
                  '挤出机机头压力', 'y3', color='orange', width=1.5)
    if _has_rows(df_main_speed):
        _add_line(fig, df_main_speed['time'], df_main_speed['process_main_speed'],
                  '流程主速 (M/Min)', 'y4', color='red', width=1.5)
    if _has_rows(df_temp):
        temp_set_fields = {
            'nakata_extruder_screw_set_temp': '螺杆设定 (°C)',
            'nakata_extruder_rear_barrel_set_temp': '后机筒设定 (°C)',
            'nakata_extruder_front_barrel_set_temp': '前机筒设定 (°C)',
            'nakata_extruder_head_set_temp': '机头设定 (°C)',
        }
        for field, label in temp_set_fields.items():
            if field in df_temp.columns:
                _add_line(fig, df_temp['time'], df_temp[field], label, 'y5',
                          width=1, dash='dash')
    fig.update_layout(**_multi_axis_layout('米重与设定参数趋势分析', '温度设定 (°C)', 150))
    return fig


def show_metered_weight_dashboard():
    """Render the metered-weight overview page (米重综合分析).

    The page offers quick date-range buttons and custom start/end date pickers.
    Their state lives in ``st.session_state`` under the ``mw_*`` keys. When the
    analyse button is clicked, the page does the following for the selected days:

    1. Loads extruder, cutting-setting, temperature-control and
       motor-monitoring data.
    2. Draws two multi-axis trend charts: metered weight vs. actual values,
       and metered weight vs. set points.
    """
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()

    st.title("米重综合分析")

    # Seed the keys backing the date widgets once per session.
    if 'mw_start_date' not in st.session_state:
        st.session_state['mw_start_date'] = datetime.now().date() - timedelta(days=7)
    if 'mw_end_date' not in st.session_state:
        st.session_state['mw_end_date'] = datetime.now().date()
    if 'mw_quick_select' not in st.session_state:
        st.session_state['mw_quick_select'] = "最近7天"

    # Quick-select label -> number of days back from today.
    quick_ranges = {"今天": 0, "最近3天": 3, "最近7天": 7, "最近30天": 30}

    def update_dates(qs):
        """Record the quick-select choice and move the date range to match."""
        st.session_state['mw_quick_select'] = qs
        days = quick_ranges.get(qs)
        if days is None:
            # "自定义" keeps whatever dates the pickers currently hold.
            return
        today = datetime.now().date()
        st.session_state['mw_start_date'] = today - timedelta(days=days)
        st.session_state['mw_end_date'] = today

    def on_date_change():
        """Switch the quick-select highlight to "custom" after a manual edit."""
        st.session_state['mw_quick_select'] = "自定义"

    with st.expander("🔍 查询配置", expanded=True):
        # Custom CSS so the column row wraps on narrow screens.
        # NOTE(review): recent Streamlit versions tag columns as
        # data-testid="stColumn"; confirm the "column" selector still matches.
        st.markdown("""
            <style>
            /* 强制列容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* 针对日期输入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """, unsafe_allow_html=True)
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        options = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
        # Buttons must be handled before the date widgets are instantiated,
        # because update_dates writes to the widgets' session-state keys.
        for i, option in enumerate(options):
            with cols[i]:
                button_type = "primary" if st.session_state['mw_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_mw_{option}", width='stretch', type=button_type):
                    update_dates(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="mw_start_date",
                on_change=on_date_change
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="mw_end_date",
                on_change=on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 开始分析", key="mw_query", width='stretch')

    if not query_button:
        return

    if start_date > end_date:
        st.warning("开始日期不能晚于结束日期,请重新选择时间范围。")
        return

    # Whole-day window: from 00:00 on the start date to the last instant of the end date.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())

    with st.spinner("正在聚合数据..."):
        df_extruder = extruder_service.get_extruder_data(start_dt, end_dt)
        if _has_rows(df_extruder):
            # Head-pressure readings above 2 are discarded as outliers.
            df_extruder = df_extruder[df_extruder['head_pressure'] <= 2]

        df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
        df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
        df_motor = main_process_service.get_motor_monitoring_data(start_dt, end_dt)
        if _has_rows(df_motor):
            # Raw line speeds are divided by 10 to plot in M/Min (presumably
            # stored in 0.1 M/Min units; confirm). assign() returns a new
            # frame, so the service's object is never mutated.
            df_motor = df_motor.assign(
                m1_line_speed=df_motor['m1_line_speed'] / 10,
                m2_line_speed=df_motor['m2_line_speed'] / 10,
            )

        if not any(_has_rows(df) for df in (df_extruder, df_main_speed, df_temp, df_motor)):
            st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
            return

        st.subheader("📈 米重与实际参数分析")
        fig_actual = _build_actual_params_figure(df_extruder, df_main_speed, df_temp, df_motor)
        st.plotly_chart(fig_actual, width='stretch', config={'scrollZoom': True})

        st.subheader("📈 米重与设定参数分析")
        fig_setting = _build_setting_params_figure(df_extruder, df_main_speed, df_temp)
        st.plotly_chart(fig_setting, width='stretch', config={'scrollZoom': True})
app/pages/metered_weight_regression.py
对比新文件
@@ -0,0 +1,620 @@
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from app.services.extruder_service import ExtruderService
from app.services.main_process_service import MainProcessService
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
def show_metered_weight_regression():
    # åˆå§‹åŒ–服务
    extruder_service = ExtruderService()
    main_process_service = MainProcessService()
    # é¡µé¢æ ‡é¢˜
    st.title("米重多元线性回归分析")
    # åˆå§‹åŒ–会话状态用于日期同步
    if 'mr_start_date' not in st.session_state:
        st.session_state['mr_start_date'] = datetime.now().date() - timedelta(days=7)
    if 'mr_end_date' not in st.session_state:
        st.session_state['mr_end_date'] = datetime.now().date()
    if 'mr_quick_select' not in st.session_state:
        st.session_state['mr_quick_select'] = "最近7天"
    if 'mr_time_offset' not in st.session_state:
        st.session_state['mr_time_offset'] = 0.0
    if 'mr_selected_features' not in st.session_state:
        st.session_state['mr_selected_features'] = [
            '螺杆转速', '机头压力', '流程主速', '螺杆温度',
            '后机筒温度', '前机筒温度', '机头温度'
        ]
    # å®šä¹‰å›žè°ƒå‡½æ•°
    def update_dates(qs):
        st.session_state['mr_quick_select'] = qs
        today = datetime.now().date()
        if qs == "今天":
            st.session_state['mr_start_date'] = today
            st.session_state['mr_end_date'] = today
        elif qs == "最近3天":
            st.session_state['mr_start_date'] = today - timedelta(days=3)
            st.session_state['mr_end_date'] = today
        elif qs == "最近7天":
            st.session_state['mr_start_date'] = today - timedelta(days=7)
            st.session_state['mr_end_date'] = today
        elif qs == "最近30天":
            st.session_state['mr_start_date'] = today - timedelta(days=30)
            st.session_state['mr_end_date'] = today
    def on_date_change():
        st.session_state['mr_quick_select'] = "自定义"
    # æŸ¥è¯¢æ¡ä»¶åŒºåŸŸ
    with st.expander("🔍 æŸ¥è¯¢é…ç½®", expanded=True):
        # æ·»åŠ è‡ªå®šä¹‰ CSS å®žçŽ°å“åº”å¼æ¢è¡Œ
        st.markdown("""
            <style>
            /* å¼ºåˆ¶åˆ—容器换行 */
            [data-testid="stExpander"] [data-testid="column"] {
                flex: 1 1 120px !important;
                min-width: 120px !important;
            }
            /* é’ˆå¯¹æ—¥æœŸè¾“入框列稍微加宽一点 */
            @media (min-width: 768px) {
                [data-testid="stExpander"] [data-testid="column"]:nth-child(6),
                [data-testid="stExpander"] [data-testid="column"]:nth-child(7) {
                    flex: 2 1 180px !important;
                    min-width: 180px !important;
                }
            }
            </style>
            """, unsafe_allow_html=True)
        # åˆ›å»ºå¸ƒå±€
        cols = st.columns([1, 1, 1, 1, 1, 1.5, 1.5, 1])
        options = ["今天", "最近3天", "最近7天", "最近30天", "自定义"]
        for i, option in enumerate(options):
            with cols[i]:
                # æ ¹æ®å½“前选择状态决定按钮类型
                button_type = "primary" if st.session_state['mr_quick_select'] == option else "secondary"
                if st.button(option, key=f"btn_mr_{option}", width='stretch', type=button_type):
                    update_dates(option)
                    st.rerun()
        with cols[5]:
            start_date = st.date_input(
                "开始日期",
                label_visibility="collapsed",
                key="mr_start_date",
                on_change=on_date_change
            )
        with cols[6]:
            end_date = st.date_input(
                "结束日期",
                label_visibility="collapsed",
                key="mr_end_date",
                on_change=on_date_change
            )
        with cols[7]:
            query_button = st.button("🚀 å¼€å§‹åˆ†æž", key="mr_query", width='stretch')
        # æ•°æ®å¯¹é½è°ƒæ•´
        st.markdown("---")
        offset_cols = st.columns([2, 4, 2])
        with offset_cols[0]:
            st.write("⏱️ **数据对齐调整**")
        with offset_cols[1]:
            time_offset = st.slider(
                "时间偏移 (分钟)",
                min_value=0.0,
                max_value=5.0,
                value=st.session_state['mr_time_offset'],
                step=0.1,
                help="调整主流程和温度数据的时间偏移,使其与挤出机米重数据对齐。"
            )
            st.session_state['mr_time_offset'] = time_offset
        with offset_cols[2]:
            st.write(f"当前偏移: {time_offset} åˆ†é’Ÿ")
        # ç‰¹å¾é€‰æ‹©
        st.markdown("---")
        st.write("📋 **特征选择**")
        feature_cols = st.columns(2)
        all_features = [
            '螺杆转速', '机头压力', '流程主速', '螺杆温度',
            '后机筒温度', '前机筒温度', '机头温度'
        ]
        for i, feature in enumerate(all_features):
            with feature_cols[i % 2]:
                st.session_state['mr_selected_features'] = [
                    f for f in st.session_state['mr_selected_features'] if f in all_features
                ]
                if st.checkbox(
                    feature,
                    key=f"feat_{feature}",
                    value=feature in st.session_state['mr_selected_features']
                ):
                    if feature not in st.session_state['mr_selected_features']:
                        st.session_state['mr_selected_features'].append(feature)
                else:
                    if feature in st.session_state['mr_selected_features']:
                        st.session_state['mr_selected_features'].remove(feature)
        if not st.session_state['mr_selected_features']:
            st.warning("至少需要选择一个特征变量")
    # è½¬æ¢ä¸ºdatetime对象
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())
    # æŸ¥è¯¢å¤„理 - ä»…获取数据并缓存到会话状态
    if query_button:
        with st.spinner("正在获取数据..."):
            # 1. èŽ·å–å®Œæ•´çš„æŒ¤å‡ºæœºæ•°æ®ï¼ˆåŒ…å«æ‰€æœ‰æ—¶é—´ç‚¹çš„èžºæ†è½¬é€Ÿå’Œæœºå¤´åŽ‹åŠ›ï¼‰
            df_extruder_full = extruder_service.get_extruder_data(start_dt, end_dt)
            # 2. èŽ·å–ä¸»æµç¨‹æŽ§åˆ¶æ•°æ®
            df_main_speed = main_process_service.get_cutting_setting_data(start_dt, end_dt)
            df_temp = main_process_service.get_temperature_control_data(start_dt, end_dt)
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_full is not None and not df_extruder_full.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                # æ¸…除缓存数据
                for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp', 'last_query_start', 'last_query_end']:
                    if key in st.session_state:
                        del st.session_state[key]
                return
            # ç¼“存数据到会话状态
            st.session_state['cached_extruder_full'] = df_extruder_full
            st.session_state['cached_main_speed'] = df_main_speed
            st.session_state['cached_temp'] = df_temp
            st.session_state['last_query_start'] = start_dt
            st.session_state['last_query_end'] = end_dt
    # æ•°æ®å¤„理和图表渲染 - æ¯æ¬¡åº”用重新运行时执行(包括调整时间偏移时)
    if all(key in st.session_state for key in ['cached_extruder_full', 'cached_main_speed', 'cached_temp']):
        with st.spinner("正在分析数据..."):
            # èŽ·å–ç¼“å­˜æ•°æ®
            df_extruder_full = st.session_state['cached_extruder_full']
            df_main_speed = st.session_state['cached_main_speed']
            df_temp = st.session_state['cached_temp']
            # èŽ·å–å½“å‰æ—¶é—´åç§»é‡
            offset_delta = timedelta(minutes=st.session_state['mr_time_offset'])
            # å¤„理数据
            if df_extruder_full is not None and not df_extruder_full.empty:
                # è¿‡æ»¤æœºå¤´åŽ‹åŠ›å¤§äºŽ2的值
                df_extruder_filtered = df_extruder_full[df_extruder_full['head_pressure'] <= 2]
                # ä¸ºç±³é‡æ•°æ®åˆ›å»ºåç§»åŽçš„æ—¶é—´åˆ—(只对米重数据进行时间偏移)
                df_extruder_filtered['weight_time'] = df_extruder_filtered['time'] - offset_delta
            else:
                df_extruder_filtered = None
            # æ£€æŸ¥æ˜¯å¦æœ‰æ•°æ®
            has_data = any([
                df_extruder_filtered is not None and not df_extruder_filtered.empty,
                df_main_speed is not None and not df_main_speed.empty,
                df_temp is not None and not df_temp.empty
            ])
            if not has_data:
                st.warning("所选时间段内未找到任何数据,请尝试调整查询条件。")
                return
            # æ•°æ®æ•´åˆä¸Žé¢„处理
            def integrate_data(df_extruder_filtered, df_main_speed, df_temp):
                # ç¡®ä¿æŒ¤å‡ºæœºæ•°æ®å­˜åœ¨
                if df_extruder_filtered is None or df_extruder_filtered.empty:
                    return None
                # åˆ›å»ºåªåŒ…含米重和偏移时间的主数据集
                df_weight = df_extruder_filtered[['weight_time', 'metered_weight']].copy()
                df_weight.rename(columns={'weight_time': 'time'}, inplace=True)  # å°†weight_time重命名为time作为基准时间
                # åˆ›å»ºåŒ…含螺杆转速和原始时间的完整数据集
                df_screw = df_extruder_filtered[['time', 'screw_speed_actual']].copy()
                # åˆ›å»ºåŒ…含机头压力和原始时间的完整数据集
                df_pressure = df_extruder_filtered[['time', 'head_pressure']].copy()
                # ä½¿ç”¨åç§»åŽçš„米重时间整合螺杆转速数据
                df_merged = pd.merge_asof(
                    df_weight.sort_values('time'),
                    df_screw.sort_values('time'),
                    on='time',
                    direction='nearest',
                    tolerance=pd.Timedelta('1min')
                )
                # ä½¿ç”¨åç§»åŽçš„米重时间整合机头压力数据
                df_merged = pd.merge_asof(
                    df_merged.sort_values('time'),
                    df_pressure.sort_values('time'),
                    on='time',
                    direction='nearest',
                    tolerance=pd.Timedelta('1min')
                )
                # æ•´åˆä¸»æµç¨‹æ•°æ®
                if df_main_speed is not None and not df_main_speed.empty:
                    df_main_speed = df_main_speed[['time', 'process_main_speed']]
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_main_speed.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # æ•´åˆæ¸©åº¦æ•°æ®
                if df_temp is not None and not df_temp.empty:
                    temp_cols = ['time', 'nakata_extruder_screw_display_temp',
                               'nakata_extruder_rear_barrel_display_temp',
                               'nakata_extruder_front_barrel_display_temp',
                               'nakata_extruder_head_display_temp']
                    df_temp_subset = df_temp[temp_cols].copy()
                    df_merged = pd.merge_asof(
                        df_merged.sort_values('time'),
                        df_temp_subset.sort_values('time'),
                        on='time',
                        direction='nearest',
                        tolerance=pd.Timedelta('1min')
                    )
                # é‡å‘½ååˆ—以提高可读性
                df_merged.rename(columns={
                    'screw_speed_actual': '螺杆转速',
                    'head_pressure': '机头压力',
                    'process_main_speed': '流程主速',
                    'nakata_extruder_screw_display_temp': '螺杆温度',
                    'nakata_extruder_rear_barrel_display_temp': '后机筒温度',
                    'nakata_extruder_front_barrel_display_temp': '前机筒温度',
                    'nakata_extruder_head_display_temp': '机头温度'
                }, inplace=True)
                # æ¸…理数据
                df_merged.dropna(subset=['metered_weight'], inplace=True)
                return df_merged
            # æ‰§è¡Œæ•°æ®æ•´åˆ
            df_analysis = integrate_data(df_extruder_filtered, df_main_speed, df_temp)
            if df_analysis is None or df_analysis.empty:
                st.warning("数据整合失败,请检查数据质量或调整时间范围。")
                return
            # é‡å‘½åç±³é‡åˆ—
            df_analysis.rename(columns={'metered_weight': '米重'}, inplace=True)
            # --- åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾ ---
            st.subheader("📈 åŽŸå§‹æ•°æ®è¶‹åŠ¿å›¾")
            # åˆ›å»ºè¶‹åŠ¿å›¾
            fig_trend = go.Figure()
            # æ·»åŠ ç±³é‡æ•°æ®ï¼ˆä½¿ç”¨åç§»åŽçš„æ—¶é—´ï¼‰
            if df_extruder_filtered is not None and not df_extruder_filtered.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_filtered['weight_time'],  # ä½¿ç”¨åç§»åŽçš„æ—¶é—´
                    y=df_extruder_filtered['metered_weight'],
                    name='米重 (Kg/m) [已偏移]',
                    mode='lines',
                    line=dict(color='blue', width=2)
                ))
                # æ·»åŠ èžºæ†è½¬é€Ÿï¼ˆä½¿ç”¨åŽŸå§‹æ—¶é—´ï¼‰
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_filtered['time'],  # ä½¿ç”¨åŽŸå§‹æ—¶é—´
                    y=df_extruder_filtered['screw_speed_actual'],
                    name='螺杆转速 (RPM)',
                    mode='lines',
                    line=dict(color='green', width=1.5),
                    yaxis='y2'
                ))
                # æ·»åŠ æœºå¤´åŽ‹åŠ›ï¼ˆä½¿ç”¨åŽŸå§‹æ—¶é—´ï¼Œå·²è¿‡æ»¤å¤§äºŽ2的值)
                fig_trend.add_trace(go.Scatter(
                    x=df_extruder_filtered['time'],  # ä½¿ç”¨åŽŸå§‹æ—¶é—´
                    y=df_extruder_filtered['head_pressure'],
                    name='机头压力 (≤2)',
                    mode='lines',
                    line=dict(color='orange', width=1.5),
                    yaxis='y3'
                ))
            # æ·»åŠ æµç¨‹ä¸»é€Ÿ
            if df_main_speed is not None and not df_main_speed.empty:
                fig_trend.add_trace(go.Scatter(
                    x=df_main_speed['time'],
                    y=df_main_speed['process_main_speed'],
                    name='流程主速 (M/Min)',
                    mode='lines',
                    line=dict(color='red', width=1.5),
                    yaxis='y4'
                ))
            # æ·»åŠ æ¸©åº¦æ•°æ®
            if df_temp is not None and not df_temp.empty:
                # èžºæ†æ¸©åº¦
                fig_trend.add_trace(go.Scatter(
                    x=df_temp['time'],
                    y=df_temp['nakata_extruder_screw_display_temp'],
                    name='螺杆温度 (°C)',
                    mode='lines',
                    line=dict(color='purple', width=1),
                    yaxis='y5'
                ))
            # é…ç½®è¶‹åŠ¿å›¾å¸ƒå±€
            fig_trend.update_layout(
                title=f'原始数据趋势 (米重向前偏移 {st.session_state["mr_time_offset"]} åˆ†é’Ÿ)',
                xaxis=dict(
                    title='时间',
                    rangeslider=dict(visible=True),
                    type='date'
                ),
                yaxis=dict(
                    title='米重 (Kg/m)',
                    title_font=dict(color='blue'),
                    tickfont=dict(color='blue')
                ),
                yaxis2=dict(
                    title='螺杆转速 (RPM)',
                    title_font=dict(color='green'),
                    tickfont=dict(color='green'),
                    overlaying='y',
                    side='right'
                ),
                yaxis3=dict(
                    title='机头压力',
                    title_font=dict(color='orange'),
                    tickfont=dict(color='orange'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.85
                ),
                yaxis4=dict(
                    title='流程主速 (M/Min)',
                    title_font=dict(color='red'),
                    tickfont=dict(color='red'),
                    overlaying='y',
                    side='right',
                    anchor='free',
                    position=0.75
                ),
                yaxis5=dict(
                    title='温度 (°C)',
                    title_font=dict(color='purple'),
                    tickfont=dict(color='purple'),
                    overlaying='y',
                    side='left',
                    anchor='free',
                    position=0.15
                ),
                legend=dict(
                    orientation="h",
                    yanchor="bottom",
                    y=1.02,
                    xanchor="right",
                    x=1
                ),
                height=600,
                margin=dict(l=100, r=200, t=100, b=100),
                hovermode='x unified'
            )
            # æ˜¾ç¤ºè¶‹åŠ¿å›¾
            st.plotly_chart(fig_trend, width='stretch', config={'scrollZoom': True})
            # --- å¤šå…ƒçº¿æ€§å›žå½’分析 ---
            st.subheader("📊 å¤šå…ƒçº¿æ€§å›žå½’分析")
            # æ£€æŸ¥æ˜¯å¦é€‰æ‹©äº†ç‰¹å¾
            if not st.session_state['mr_selected_features']:
                st.warning("请至少选择一个特征变量进行回归分析")
            else:
                # æ£€æŸ¥æ‰€æœ‰é€‰æ‹©çš„特征是否在数据中
                missing_features = [f for f in st.session_state['mr_selected_features'] if f not in df_analysis.columns]
                if missing_features:
                    st.warning(f"数据中缺少以下特征: {', '.join(missing_features)}")
                else:
                    # å‡†å¤‡æ•°æ®
                    X = df_analysis[st.session_state['mr_selected_features']]
                    y = df_analysis['米重']
                    # æ¸…理数据中的NaN值
                    combined = pd.concat([X, y], axis=1)
                    combined_clean = combined.dropna()
                    # æ£€æŸ¥æ¸…理后的数据量
                    if len(combined_clean) < 10:
                        st.warning("数据量不足或包含过多NaN值,无法进行有效的回归分析")
                    else:
                        # é‡æ–°åˆ†ç¦»X和y
                        X_clean = combined_clean[st.session_state['mr_selected_features']]
                        y_clean = combined_clean['米重']
                        # åˆ†å‰²è®­ç»ƒé›†å’Œæµ‹è¯•集
                        X_train, X_test, y_train, y_test = train_test_split(X_clean, y_clean, test_size=0.2, random_state=42)
                        # è®­ç»ƒæ¨¡åž‹
                        model = LinearRegression()
                        model.fit(X_train, y_train)
                        # é¢„测
                        y_pred = model.predict(X_test)
                        y_train_pred = model.predict(X_train)
                        # è®¡ç®—评估指标
                        r2 = r2_score(y_test, y_pred)
                        mse = mean_squared_error(y_test, y_pred)
                        mae = mean_absolute_error(y_test, y_pred)
                        rmse = np.sqrt(mse)
                        # æ˜¾ç¤ºæ¨¡åž‹æ€§èƒ½
                        metrics_cols = st.columns(2)
                        with metrics_cols[0]:
                            st.metric("R² å¾—分", f"{r2:.4f}")
                            st.metric("均方误差 (MSE)", f"{mse:.6f}")
                        with metrics_cols[1]:
                            st.metric("平均绝对误差 (MAE)", f"{mae:.6f}")
                            st.metric("均方根误差 (RMSE)", f"{rmse:.6f}")
                        # --- å®žé™…值与预测值对比 ---
                        st.subheader("🔄 å®žé™…值与预测值对比")
                        # åˆ›å»ºå¯¹æ¯”数据
                        compare_df = pd.DataFrame({
                            '实际值': y_test,
                            '预测值': y_pred
                        })
                        compare_df = compare_df.sort_index()
                        # åˆ›å»ºå¯¹æ¯”图
                        fig_compare = go.Figure()
                        fig_compare.add_trace(go.Scatter(
                            x=compare_df.index,
                            y=compare_df['实际值'],
                            name='实际值',
                            mode='lines+markers',
                            line=dict(color='blue', width=2)
                        ))
                        fig_compare.add_trace(go.Scatter(
                            x=compare_df.index,
                            y=compare_df['预测值'],
                            name='预测值',
                            mode='lines+markers',
                            line=dict(color='red', width=2, dash='dash')
                        ))
                        fig_compare.update_layout(
                            title='测试集: å®žé™…米重 vs é¢„测米重',
                            xaxis=dict(title='样本索引'),
                            yaxis=dict(title='米重 (Kg/m)'),
                            legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
                            height=400
                        )
                        st.plotly_chart(fig_compare, width='stretch')
                        # --- æ®‹å·®åˆ†æž ---
                        st.subheader("📉 æ®‹å·®åˆ†æž")
                        # è®¡ç®—残差
                        residuals = y_test - y_pred
                        # åˆ›å»ºæ®‹å·®å›¾
                        fig_residual = go.Figure()
                        fig_residual.add_trace(go.Scatter(
                            x=y_pred,
                            y=residuals,
                            mode='markers',
                            marker=dict(color='green', size=8, opacity=0.6)
                        ))
                        fig_residual.add_shape(
                            type="line",
                            x0=y_pred.min(),
                            y0=0,
                            x1=y_pred.max(),
                            y1=0,
                            line=dict(color="red", width=2, dash="dash")
                        )
                        fig_residual.update_layout(
                            title='残差图',
                            xaxis=dict(title='预测值'),
                            yaxis=dict(title='残差'),
                            height=400
                        )
                        st.plotly_chart(fig_residual, width='stretch')
                        # --- ç‰¹å¾é‡è¦æ€§ ---
                        st.subheader("⚖️ ç‰¹å¾é‡è¦æ€§åˆ†æž")
                        # è®¡ç®—特征重要性(基于系数绝对值)
                        feature_importance = pd.DataFrame({
                            '特征': st.session_state['mr_selected_features'],
                            '系数': model.coef_,
                            '重要性': np.abs(model.coef_)
                        })
                        feature_importance = feature_importance.sort_values('重要性', ascending=False)
                        # åˆ›å»ºç‰¹å¾é‡è¦æ€§å›¾
                        fig_importance = px.bar(
                            feature_importance,
                            x='特征',
                            y='重要性',
                            title='特征重要性(基于系数绝对值)',
                            color='重要性',
                            color_continuous_scale='viridis'
                        )
                        fig_importance.update_layout(
                            xaxis=dict(tickangle=-45),
                            height=400
                        )
                        st.plotly_chart(fig_importance, width='stretch')
                        # æ˜¾ç¤ºç³»æ•°è¡¨
                        st.write("### æ¨¡åž‹ç³»æ•°")
                        coef_df = pd.DataFrame({
                            '特征': ['截距'] + st.session_state['mr_selected_features'],
                            '系数': [model.intercept_] + list(model.coef_)
                        })
                        st.dataframe(coef_df, use_container_width=True)
                        # --- é¢„测功能 ---
                        st.subheader("🔮 ç±³é‡é¢„测")
                        # åˆ›å»ºé¢„测表单
                        st.write("输入特征值进行米重预测:")
                        predict_cols = st.columns(2)
                        input_features = {}
                        for i, feature in enumerate(st.session_state['mr_selected_features']):
                            with predict_cols[i % 2]:
                                # èŽ·å–ç‰¹å¾çš„ç»Ÿè®¡ä¿¡æ¯
                                min_val = df_analysis[feature].min()
                                max_val = df_analysis[feature].max()
                                mean_val = df_analysis[feature].mean()
                                input_features[feature] = st.number_input(
                                    f"{feature}",
                                    key=f"pred_{feature}",
                                    value=float(mean_val),
                                    min_value=float(min_val),
                                    max_value=float(max_val),
                                    step=0.1
                                )
                        if st.button("预测米重"):
                            # å‡†å¤‡é¢„测数据
                            input_data = [[input_features[feature] for feature in st.session_state['mr_selected_features']]]
                            # é¢„测
                            predicted_weight = model.predict(input_data)[0]
                            # æ˜¾ç¤ºé¢„测结果
                            st.success(f"预测米重: {predicted_weight:.4f} Kg/m")
                        # --- æ•°æ®é¢„览 ---
                        st.subheader("🔍 æ•°æ®é¢„览")
                        st.dataframe(df_analysis.head(20), use_container_width=True)
    else:
        # æç¤ºç”¨æˆ·ç‚¹å‡»å¼€å§‹åˆ†æžæŒ‰é’®
        st.info("请选择时间范围并点击'开始分析'按钮获取数据。")
app/pages/sorting_dashboard.py
@@ -183,7 +183,9 @@
                            rangeslider=dict(visible=True)
                        ),
                        yaxis=dict(fixedrange=False),
                        hovermode='x unified',
                        dragmode='zoom'
                    )
                    
                    # é…ç½®å›¾è¡¨å‚æ•°
app/services/data_processing_service.py
@@ -211,20 +211,21 @@
        try:
            # è¯†åˆ«æžå€¼ç‚¹
            extreme_points = self.identify_local_maxima(df)
            # print("识别极值点:", extreme_points)
            # è¯†åˆ«é˜¶æ®µæœ€å¤§å€¼
            phase_maxima = self.identify_phase_maxima(df)
            # phase_maxima = self.identify_phase_maxima(df)
            # print("识别阶段最大值:", phase_maxima)
            
            # è®¡ç®—每个极值点的合格率
            if not extreme_points.empty:
                extreme_points['pass_rate'] = extreme_points.apply(self.calculate_pass_rate, axis=1)
            
            # è®¡ç®—整体合格率
            overall_pass_rate = self.calculate_overall_pass_rate(df)
            overall_pass_rate = self.calculate_overall_pass_rate(extreme_points)
            
            return {
                'extreme_points': extreme_points,
                'phase_maxima': phase_maxima,
                'phase_maxima': pd.DataFrame(),
                'overall_pass_rate': overall_pass_rate
            }
        except Exception as e:
app/services/main_process_service.py
@@ -25,7 +25,7 @@
                self.db.connect()
            
            query = """
            SELECT time, process_main_speed
            SELECT time, process_main_speed, cutting_count
            FROM public.aics_main_process_cutting_setting 
            WHERE time BETWEEN %s AND %s 
            ORDER BY time ASC
dashboard.py
@@ -3,6 +3,10 @@
from app.pages.extruder_dashboard import show_extruder_dashboard
from app.pages.main_process_dashboard import show_main_process_dashboard
from app.pages.comprehensive_dashboard import show_comprehensive_dashboard
from app.pages.metered_weight_dashboard import show_metered_weight_dashboard
from app.pages.metered_weight_correlation import show_metered_weight_correlation
from app.pages.metered_weight_regression import show_metered_weight_regression
from app.pages.metered_weight_advanced import show_metered_weight_advanced
# è®¾ç½®é¡µé¢é…ç½®
st.set_page_config(
@@ -35,9 +39,37 @@
comprehensive_page = st.Page(
    show_comprehensive_dashboard,
    title="综合分析",
    title="条重综合分析",
    icon="🌐",
    url_path="comprehensive"
)
metered_weight_page = st.Page(
    show_metered_weight_dashboard,
    title="米重综合分析",
    icon="📏",
    url_path="metered_weight"
)
metered_weight_correlation_page = st.Page(
    show_metered_weight_correlation,
    title="米重相关性分析",
    icon="📊",
    url_path="metered_weight_correlation"
)
metered_weight_regression_page = st.Page(
    show_metered_weight_regression,
    title="米重多元线性回归分析",
    icon="📈",
    url_path="metered_weight_regression"
)
metered_weight_advanced_page = st.Page(
    show_metered_weight_advanced,
    title="米重高级预测分析",
    icon="🤖",
    url_path="metered_weight_advanced"
)
# ä¾§è¾¹æ é¡µè„šä¿¡æ¯
@@ -47,7 +79,7 @@
# å¯¼èˆªé…ç½®
pg = st.navigation({
    "综合分析": [comprehensive_page],
    "综合分析": [comprehensive_page, metered_weight_page, metered_weight_correlation_page, metered_weight_regression_page, metered_weight_advanced_page],
    "分项分析": [sorting_page, extruder_page, main_process_page]
})
requirements.txt
@@ -2,4 +2,7 @@
psycopg2-binary
pandas
plotly
python-dotenv
python-dotenv
scikit-learn
torch
torchvision