|
@@ -17,6 +17,7 @@ from core.utils.voiceprint_provider import VoiceprintProvider
|
|
|
from collections import deque
|
|
|
from core.utils.dialogue import Dialogue
|
|
|
from core.utils.prompt_manager import PromptManager
|
|
|
+from config.config_loader import get_private_config_from_api
|
|
|
|
|
|
|
|
|
|
|
@@ -138,5 +139,83 @@ class ConnnectionHandler:
|
|
|
# 初始化提示词管理器
|
|
|
self.prompt_manager = PromptManager(config, self.logger)
|
|
|
|
|
|
+ async def handle_connection(self, ws):
|
|
|
+ try:
|
|
|
+ self.headers = dict(ws.request.headers)
|
|
|
+ if self.headers.get("device-id", None) is None:
|
|
|
+ # 尝试从 URL 的查询参数中获取 device-id
|
|
|
+ from urllib.parse import parse_qs, urlparse
|
|
|
+
|
|
|
+ # 从 WebSocket 请求中获取路径
|
|
|
+ request_path = ws.request.path
|
|
|
+ if not request_path:
|
|
|
+ self.logger.bind(tag=TAG).error("无法获取请求路径")
|
|
|
+ return
|
|
|
+ parsed_url = urlparse(request_path)
|
|
|
+ query_params = parse_qs(parsed_url.query)
|
|
|
+ if "device-id" in query_params:
|
|
|
+ self.headers["device-id"] = query_params["device-id"][0]
|
|
|
+ self.headers["client-id"] = query_params["client-id"][0]
|
|
|
+ else:
|
|
|
+ await ws.send("端口正常,如需测试连接,请使用test_page.html")
|
|
|
+ await self.close(ws)
|
|
|
+ return
|
|
|
+ real_ip = self.headers.get("x-real-ip") or self.headers.get("x-forwarded-for")
|
|
|
+ if real_ip:
|
|
|
+ self.client_ip = real_ip.split(",")[0].strip()
|
|
|
+ else:
|
|
|
+ self.client_ip = ws.remote_address[0]
|
|
|
+ self.logger.bind(tag=TAG).info(f"{self.client_ip} conn - Headers: {self.headers}")
|
|
|
+ await self.auth.authenticate(self.headers)
|
|
|
+
|
|
|
+ # 验证通过,继续处理
|
|
|
+ self.websocket = ws
|
|
|
+ self.device_id = self.headers.get("device-id", None)
|
|
|
+ self.last_activity_time = time.time() * 1000
|
|
|
+
|
|
|
+ self.timeout_task = asyncio.create_task(self._check_timeout())
|
|
|
+ self.welcome_msg = self.config["xiaozhi"]
|
|
|
+ self.welcome_msg["session_id"] = self.session_id
|
|
|
+
|
|
|
+ self._init
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ async def _check_timeout(self):
|
|
|
+ """"检查连接超时"""
|
|
|
+ try:
|
|
|
+ while not self.stop_event.is_set():
|
|
|
+ if self.last_activity_time > 0.0:
|
|
|
+ current_time = time.time() * 1000
|
|
|
+ if (current_time - self.last_activity_time) > self.timeout_seconds * 1000:
|
|
|
+ if not self.stop_event.is_set():
|
|
|
+ self.logger.bind(tag=TAG).info("连接超时,关闭连接")
|
|
|
+ self.stop_event.set()
|
|
|
+ # 确保不会因为异常堵塞
|
|
|
+ try:
|
|
|
+ await self.close(self.websocket)
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.bind(tag=TAG).error(f"关闭连接时发生错误: {str(e)}")
|
|
|
+ break
|
|
|
+ await asyncio.sleep(10)
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.bind(tag=TAG).error(f"检查连接超时时发生错误: {str(e)}")
|
|
|
+ finally:
|
|
|
+ self.logger.bind(tag=TAG).info("检查连接超时任务结束")
|
|
|
+
|
|
|
+ def _initialize_private_config(self):
|
|
|
+ """
|
|
|
+
|
|
|
+
|
|
|
+ """
|
|
|
+ if not self.read_config_from_api:
|
|
|
+ return
|
|
|
+ try:
|
|
|
+ begin_time = time.time()
|
|
|
+ private_config = get_private_config_from_api(self.device_id)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
|