Forráskód Böngészése

8-1服务器框架

Zzcoded 2 napja
szülő
commit
5b71d1f66a

+ 2 - 2
main/server/app.py

@@ -6,8 +6,8 @@ from aioconsole import ainput
 from config.settings import load_config
 from config.logger import setup_logging
 from core.utils.util import check_ffmpeg_installed, get_local_ip, validate_mcp_endpoint
-from core.websocket.server import WebSocketServer
-from core.simple_http.server import SimpleHttpServer
+from core.websocket_server import WebSocketServer
+from core.http_server import SimpleHttpServer
 
 TAG = __name__
 logger = setup_logging()

+ 67 - 0
main/server/core/auth.py

@@ -0,0 +1,67 @@
+from config.logger import setup_logging
+
+TAG = __name__
+logger = setup_logging()
+
+class AuthenticationError(Exception):
+    pass
+
+
+class AuthMiddleware:
+    def __init__(self,config):
+        self.config = config
+        self.auth_config = config["server"].get("auth",{})
+
+        # 构建token查询表
+        self.tokens = {
+            item["token"]: item["name"]
+            for item in self.auth_config.get("tokens",[])
+        }
+        # 设备白名单
+        self.allowed_devices = set(
+            self.auth_config.get("allowed_devices",[])
+        )
+
+    async def authenticate(self, headers: dict):
+        """
+        验证请求头中的token
+
+        Args:
+            headers (dict): 请求头
+
+        Returns:
+            str: 用户名
+        """
+        if not self.auth_config.get("enabled",False):
+            return True
+        
+        # 检查设备是否在白名单中
+        device_id = headers.get("device_id","")
+        if self.allowed_devices and device_id not in self.allowed_devices:
+            return True
+        
+        # 验证Authorization header
+        auth_header = headers.get("Authorization","")
+        if not auth_header.startswith("Bearer "):
+            logger.bind(tag=TAG).error("Missing or invalid Authorization header")
+            raise AuthenticationError("Missing or invalid Authorization header")
+        
+        token = auth_header.split(" ")[1]
+        if token not in self.tokens:
+            logger.bind(tag=TAG).error(f"Invalid token: {token}")
+            raise AuthenticationError(f"Invalid token: {token}")
+        
+        logger.bind(tag=TAG).info(f"Authentication successful - Device: {device_id}, Token: {self.tokens[token]}")
+        return True
+    
+    def get_token_name(self, token: str) -> str:
+        """
+        获取token对应的名称
+
+        Args:
+            token (str): 令牌
+
+        Returns:
+            str: 令牌名称
+        """
+        return self.tokens.get(token, "Unknown")

+ 121 - 0
main/server/core/utils/dialogue.py

@@ -0,0 +1,121 @@
+import uuid
+import re
+from typing import List, Dict
+from datetime import datetime
+
+class Message:
+    def __init__(
+        self,
+        role: str,
+        content: str = None,
+        uniq_id: str = None,
+        tool_calls = None,
+        tool_call_id = None,
+    ):
+        self.unid_id = uniq_id if uniq_id is not None else str(uuid.uuid4())
+        self.role = role
+        self.content = content
+        self.tool_calls = tool_calls
+        self.tool_call_id = tool_call_id
+
+
+
+class Dialogue:
+    def __init__(self):
+        self.dialogue: List[Message] = []
+        # 获取当前时间
+        self.current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    def put(self, message: Message):
+        self.dialogue.append(message)
+
+    def getMessages(self, m, dialogue):
+        if m.tool_Calls is not None:
+            dialogue.append({"role": m.role, "tool_calls": m.tool_calls})
+        elif m.role == "tool":
+            dialogue.append(
+                {
+                    "role": m.role,
+                    "tool_call_id": str(uuid.uuid64() if m.tool_call_id is None else m.tool_call_id),
+                    "content": m.content,
+                }
+            )
+        else:
+            dialogue.append(
+                {
+                    "role": m.role,
+                    "content": m.content,
+                }
+            )
+
+    def get_llm_dialogue(self) -> List[Dict[str, str]]:
+        # 直接调用get_llm_dialogue_with_memory,传入None作为memory_str
+        # 这样确保说话人功能在所有调用路径下都生效
+        return self.get_llm_dialogue_with_memory(None, None)
+    
+    def update_system_message(self, new_content: str):
+        """
+        更新系统消息
+
+        Args:
+            new_content (str): 新的系统消息内容
+        """
+        system_message = next((msg for msg in self.dialogue if msg.role == "system"), None)
+        if system_message:
+            system_message.content = new_content
+        else:
+            self.put(Message(role="system", content=new_content))
+
+    def get_llm_dialogue_with_memory(
+        self, memory_str: str = None, voiceprint_config: dict = None
+    ) -> List[Dict[str, str]]:
+        # 构建对话
+        dialogue = []
+
+        # 添加系统提示和记忆
+        system_message = next(
+            (msg for msg in self.dialogue if msg.role == "system"), None
+        )
+
+        if system_message:
+            # 基础系统提示
+            enhanced_system_prompt = system_message.content
+
+            # 添加说话人个性化描述
+            try:
+                speakers = voiceprint_config.get("speakers", [])
+                if speakers:
+                    enhanced_system_prompt += "\n\n<speakers_info>"
+                    for speaker_str in speakers:
+                        try:
+                            parts = speaker_str.split(",", 2)
+                            if len(parts) >= 2:
+                                name = parts[1].strip()
+                                # 如果描述为空,则为""
+                                description = (
+                                    parts[2].strip() if len(parts) >= 3 else ""
+                                )
+                                enhanced_system_prompt += f"\n- {name}:{description}"
+                        except:
+                            pass
+                    enhanced_system_prompt += "\n\n</speakers_info>"
+            except:
+                # 配置读取失败时忽略错误,不影响其他功能
+                pass
+
+            # 使用正则表达式匹配 <memory> 标签,不管中间有什么内容
+            if memory_str is not None:
+                enhanced_system_prompt = re.sub(
+                    r"<memory>.*?</memory>",
+                    f"<memory>\n{memory_str}\n</memory>",
+                    enhanced_system_prompt,
+                    flags=re.DOTALL,
+                )
+            dialogue.append({"role": "system", "content": enhanced_system_prompt})
+
+        # 添加用户和助手的对话
+        for m in self.dialogue:
+            if m.role != "system":  # 跳过原始的系统消息
+                self.getMessages(m, dialogue)
+
+        return dialogue

+ 85 - 0
main/server/core/utils/modules_initialize.py

@@ -0,0 +1,85 @@
+from typing import Dict, Any
+from config.logger import setup_logging
+from core.utils import tts, llm, intent, memory, vad, asr
+
+TAG = __name__
+logger = setup_logging()
+
+# *** 需要修改的点 ***
+def initialize_modules(
+        logger,
+        config: Dict[str, Any],
+        init_vad = False,
+        init_asr = False,
+        init_llm = False,
+        init_tts = False,
+        init_memory = False,
+        init_intent = False,
+) -> Dict[str, Any]:
+    """
+    初始化所有模块组件
+
+    Args:
+        config: 配置字典
+
+    Returns:
+        Dict[str, Any]: 初始化后的模块字典
+    """
+    modules = {}
+    
+    if init_tts:
+        select_tts_module = config["selected_modules"]["TTS"]
+        modules["tts"] = initialize_tts(config)
+        logger.bind(tag=TAG).info(f"初始化组件: tts成功{select_tts_module}")
+
+    # 初始化LLM模块
+    if init_llm:
+        select_llm_module = config["selected_modules"]["LLM"]
+        llm_type = (
+            select_llm_module
+            if "type" not in config["LLM"][select_llm_module]
+            else config["LLM"][select_llm_module]["type"]
+        )
+        modules["llm"] = llm.create_instance(
+            llm_type,
+            config["LLM"][select_llm_module],
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: llm成功{select_llm_module}")
+
+    # 初始化Intent模块
+    if init_intent:
+        select_intent_module = config["selected_modules"]["Intent"]
+        intent_type = (
+            select_intent_module
+            if "type" not in config["Intent"][select_intent_module]
+            else config["Intent"][select_intent_module]["type"]
+        )
+        modules["intent"] = intent.create_instance(
+            intent_type,
+            config["Intent"][select_intent_module],
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: intent成功{select_intent_module}")
+
+    # 初始化Memory模块
+    if init_memory:
+        select_memory_module = config["selected_modules"]["Memory"]
+        memory_type = (
+            select_memory_module
+            if "type" not in config["Memory"][select_memory_module]
+            else config["Memory"][select_memory_module]["type"]
+        )
+        modules["memory"] = memory.create_instance(
+            memory_type,
+            config["Memory"][select_memory_module],
+            config.get("summaryMemory", None),
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: memory成功{select_memory_module}")
+        
+    # 初始化VAD模块
+    if init_vad:
+        select_vad_module = config["selected_modules"]["VAD"]
+        vad_type = (
+            select_vad_module
+            if "type" not in config["VAD"][select_vad_module]
+            else config["VAD"][select_vad_module]["type"]
+        )

+ 134 - 0
main/server/core/utils/voiceprint_provider.py

@@ -0,0 +1,134 @@
+import asyncio
+import json
+import time
+import aiohttp
+from urllib.parse import urlparse, parse_qs
+from typing import Optional, Dict
+from config.logger import setup_logging
+
+TAG = __name__
+logger = setup_logging()
+
+
+class VoiceprintProvider:
+    """声纹识别服务提供者"""
+    
+    def __init__(self, config: dict):
+        self.original_url = config.get("url", "")
+        self.speakers = config.get("speakers", [])
+        self.speaker_map = self._parse_speakers()
+        
+        # 解析API地址和密钥
+        self.api_url = None
+        self.api_key = None
+        self.speaker_ids = []
+        
+        if not self.original_url:
+            logger.bind(tag=TAG).warning("声纹识别URL未配置,声纹识别将被禁用")
+            self.enabled = False
+        else:
+            # 解析URL和key
+            parsed_url = urlparse(self.original_url)
+            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+            
+            # 从查询参数中提取key
+            query_params = parse_qs(parsed_url.query)
+            self.api_key = query_params.get('key', [''])[0]
+            
+            if not self.api_key:
+                logger.bind(tag=TAG).error("URL中未找到key参数,声纹识别将被禁用")
+                self.enabled = False
+            else:
+                # 构造identify接口地址
+                self.api_url = f"{base_url}/voiceprint/identify"
+                
+                # 提取speaker_ids
+                for speaker_str in self.speakers:
+                    try:
+                        parts = speaker_str.split(",", 2)
+                        if len(parts) >= 1:
+                            speaker_id = parts[0].strip()
+                            self.speaker_ids.append(speaker_id)
+                    except Exception:
+                        continue
+                
+                # 检查是否有有效的说话人配置
+                if not self.speaker_ids:
+                    logger.bind(tag=TAG).warning("未配置有效的说话人,声纹识别将被禁用")
+                    self.enabled = False
+                else:
+                    self.enabled = True
+                    logger.bind(tag=TAG).info(f"声纹识别已配置: API={self.api_url}, 说话人={len(self.speaker_ids)}个")
+    
+    def _parse_speakers(self) -> Dict[str, Dict[str, str]]:
+        """解析说话人配置"""
+        speaker_map = {}
+        for speaker_str in self.speakers:
+            try:
+                parts = speaker_str.split(",", 2)
+                if len(parts) >= 3:
+                    speaker_id, name, description = parts[0].strip(), parts[1].strip(), parts[2].strip()
+                    speaker_map[speaker_id] = {
+                        "name": name,
+                        "description": description
+                    }
+            except Exception as e:
+                logger.bind(tag=TAG).warning(f"解析说话人配置失败: {speaker_str}, 错误: {e}")
+        return speaker_map
+    
+    async def identify_speaker(self, audio_data: bytes, session_id: str) -> Optional[str]:
+        """识别说话人"""
+        if not self.enabled or not self.api_url or not self.api_key:
+            logger.bind(tag=TAG).debug("声纹识别功能已禁用或未配置,跳过识别")
+            return None
+            
+        try:
+            api_start_time = time.monotonic()
+            
+            # 准备请求头
+            headers = {
+                'Authorization': f'Bearer {self.api_key}',
+                'Accept': 'application/json'
+            }
+            
+            # 准备multipart/form-data数据
+            data = aiohttp.FormData()
+            data.add_field('speaker_ids', ','.join(self.speaker_ids))
+            data.add_field('file', audio_data, filename='audio.wav', content_type='audio/wav')
+            
+            timeout = aiohttp.ClientTimeout(total=10)
+            
+            # 网络请求
+            async with aiohttp.ClientSession(timeout=timeout) as session:
+                async with session.post(self.api_url, headers=headers, data=data) as response:
+                    
+                    if response.status == 200:
+                        result = await response.json()
+                        speaker_id = result.get("speaker_id")
+                        score = result.get("score", 0)
+                        total_elapsed_time = time.monotonic() - api_start_time
+                        
+                        logger.bind(tag=TAG).info(f"声纹识别耗时: {total_elapsed_time:.3f}s")
+                        
+                        # 置信度检查
+                        if score < 0.5:
+                            logger.bind(tag=TAG).warning(f"声纹识别置信度较低: {score:.3f}")
+                        
+                        if speaker_id and speaker_id in self.speaker_map:
+                            result_name = self.speaker_map[speaker_id]["name"]
+                            return result_name
+                        else:
+                            logger.bind(tag=TAG).warning(f"未识别的说话人ID: {speaker_id}")
+                            return "未知说话人"
+                    else:
+                        logger.bind(tag=TAG).error(f"声纹识别API错误: HTTP {response.status}")
+                        return None
+                        
+        except asyncio.TimeoutError:
+            elapsed = time.monotonic() - api_start_time
+            logger.bind(tag=TAG).error(f"声纹识别超时: {elapsed:.3f}s")
+            return None
+        except Exception as e:
+            elapsed = time.monotonic() - api_start_time
+            logger.bind(tag=TAG).error(f"声纹识别失败: {e}")
+            return None

+ 44 - 0
main/server/core/websocket_server.py

@@ -0,0 +1,44 @@
+import asyncio
+import websockets
+from config.logger import setup_logging
+from core.utils.modules_initialize import initialize_modules
+from config.config_loader import get_config_from_api
+from core.connnection import ConnectionHandler
+
+TAG = __name__
+
+
+class WebSocketServer:
+    def __init__(self, config:dict):
+        self.config = config
+        self.logger = setup_logging()
+        self.config_lock = asyncio.Lock()
+        modules = initialize_modules(
+            self.logger,
+            self.config,
+            "VAD" in self.config["selected_modules"],
+            "ASR" in self.config["selected_modules"],
+            "LLM" in self.config["selected_modules"],
+            "TTS" in self.config["selected_modules"],
+            "Memory" in self.config["selected_modules"],
+            "Intent" in self.config["selected_modules"],
+        )
+        self.vad = modules["vad"] if "vad" in modules else None
+        self.asr = modules["asr"] if "asr" in modules else None
+        self.llm = modules["llm"] if "llm" in modules else None
+        self.intent = modules["intent"] if "intent" in modules else None
+        self.memory = modules["memory"] if "memory" in modules else None
+
+        self.active_connections = set()
+
+    async def start(self):
+        server_config = self.config["server"]
+        host = server_config.get("ip", "0.0.0.0")
+        port = server_config.get("port", 8080)
+
+        async with websockets.serve(self._handle_connection, host, port, process_request=self._process_request):
+            await asyncio.Future()
+
+    async def _handle_connection(self,websocket):
+        """处理新连接, 每次创建独立的ConnectionHandler"""
+        handler = ConnectionHandler(websocket, self.config, self.logger)