فهرست منبع

8.2-2连接超时处理

Zzcoded 1 روز پیش
والد
کامیت
a968b223ab
1فایلهای تغییر یافته به همراه79 افزوده شده و 0 حذف شده
  1. 79 0
      main/server/core/connnection.py

+ 79 - 0
main/server/core/connnection.py

@@ -17,6 +17,7 @@ from core.utils.voiceprint_provider import VoiceprintProvider
 from collections import deque
 from core.utils.dialogue import Dialogue
 from core.utils.prompt_manager import PromptManager
+from config.config_loader import get_private_config_from_api
 
 
 
@@ -138,5 +139,83 @@ class ConnnectionHandler:
         # 初始化提示词管理器
         self.prompt_manager = PromptManager(config, self.logger)
 
+    async def handle_connection(self, ws):
+        try:
+            self.headers = dict(ws.request.headers)
+            if self.headers.get("device-id", None) is None:
+                # 尝试从 URL 的查询参数中获取 device-id
+                from urllib.parse import parse_qs, urlparse
+
+                # 从 WebSocket 请求中获取路径
+                request_path = ws.request.path
+                if not request_path:
+                    self.logger.bind(tag=TAG).error("无法获取请求路径")
+                    return
+                parsed_url = urlparse(request_path)
+                query_params = parse_qs(parsed_url.query)
+                if "device-id" in query_params:
+                    self.headers["device-id"] = query_params["device-id"][0]
+                    self.headers["client-id"] = query_params["client-id"][0]
+                else:
+                    await ws.send("端口正常,如需测试连接,请使用test_page.html")
+                    await self.close(ws)
+                    return
+            real_ip = self.headers.get("x-real-ip") or self.headers.get("x-forwarded-for")
+            if real_ip:
+                self.client_ip = real_ip.split(",")[0].strip()
+            else:
+                self.client_ip = ws.remote_address[0]
+            self.logger.bind(tag=TAG).info(f"{self.client_ip} conn - Headers: {self.headers}")
+            await self.auth.authenticate(self.headers)
+
+            # 验证通过,继续处理
+            self.websocket = ws
+            self.device_id = self.headers.get("device-id", None)
+            self.last_activity_time = time.time() * 1000
+
+            self.timeout_task = asyncio.create_task(self._check_timeout())
+            self.welcome_msg = self.config["xiaozhi"]
+            self.welcome_msg["session_id"] = self.session_id
+
+            self._init
+
+
+
+    async def _check_timeout(self):
+        """"检查连接超时"""
+        try:
+            while not self.stop_event.is_set():
+                if self.last_activity_time > 0.0:
+                    current_time = time.time() * 1000
+                    if (current_time - self.last_activity_time) > self.timeout_seconds * 1000:
+                        if not self.stop_event.is_set():
+                            self.logger.bind(tag=TAG).info("连接超时,关闭连接")
+                            self.stop_event.set()
+                            # 确保不会因为异常堵塞
+                            try:
+                                await self.close(self.websocket)
+                            except Exception as e:
+                                self.logger.bind(tag=TAG).error(f"关闭连接时发生错误: {str(e)}")
+                        break
+                await asyncio.sleep(10)
+        except Exception as e:
+            self.logger.bind(tag=TAG).error(f"检查连接超时时发生错误: {str(e)}")
+        finally:
+            self.logger.bind(tag=TAG).info("检查连接超时任务结束")
+
+    def _initialize_private_config(self):
+        """
+        
+        
+        """
+        if not self.read_config_from_api:
+            return
+        try:
+            begin_time = time.time()
+            private_config = get_private_config_from_api(self.device_id)
+
+
+
+