2
0

2 Коммиты 897ad177b3 ... 0f2c50f0b0

Автор SHA1 Сообщение Дата
  Zzcoded 0f2c50f0b0 8-1-2连接 2 дней назад
  Zzcoded 5b71d1f66a 8-1服务器框架 2 дней назад

+ 2 - 2
main/server/app.py

@@ -6,8 +6,8 @@ from aioconsole import ainput
 from config.settings import load_config
 from config.logger import setup_logging
 from core.utils.util import check_ffmpeg_installed, get_local_ip, validate_mcp_endpoint
-from core.websocket.server import WebSocketServer
-from core.simple_http.server import SimpleHttpServer
+from core.websocket_server import WebSocketServer
+from core.http_server import SimpleHttpServer
 
 TAG = __name__
 logger = setup_logging()

+ 67 - 0
main/server/core/auth.py

@@ -0,0 +1,67 @@
+from config.logger import setup_logging
+
+TAG = __name__
+logger = setup_logging()
+
+class AuthenticationError(Exception):
+    """Raised by AuthMiddleware.authenticate when a request is rejected."""
+
+
+class AuthMiddleware:
+    """Bearer-token authentication configured under ``config["server"]["auth"]``.
+
+    Three keys of the ``auth`` section are read: ``enabled`` (bool),
+    ``tokens`` (list of ``{"token": ..., "name": ...}`` items) and
+    ``allowed_devices`` (device ids that skip the token check).
+    """
+
+    def __init__(self, config):
+        self.config = config
+        self.auth_config = config["server"].get("auth", {})
+
+        # Token lookup table: token string -> configured name.
+        self.tokens = {
+            item["token"]: item["name"]
+            for item in self.auth_config.get("tokens", [])
+        }
+        # Device whitelist.
+        self.allowed_devices = set(self.auth_config.get("allowed_devices", []))
+
+    async def authenticate(self, headers: dict):
+        """
+        Validate the token carried in the request headers.
+
+        Args:
+            headers (dict): request headers; ``device_id`` and
+                ``Authorization`` are read.
+
+        Returns:
+            bool: always True when the request is accepted.
+
+        Raises:
+            AuthenticationError: the Authorization header is missing, is not
+                a ``Bearer`` header, or carries an unknown token.
+        """
+        if not self.auth_config.get("enabled", False):
+            return True
+
+        # Whitelisted devices are accepted without a token. The original
+        # condition was inverted ("not in"), which let every device that was
+        # NOT whitelisted skip authentication.
+        device_id = headers.get("device_id", "")
+        if self.allowed_devices and device_id in self.allowed_devices:
+            return True
+
+        # Validate the Authorization header.
+        auth_header = headers.get("Authorization", "")
+        if not auth_header.startswith("Bearer "):
+            logger.bind(tag=TAG).error("Missing or invalid Authorization header")
+            raise AuthenticationError("Missing or invalid Authorization header")
+
+        token = auth_header.split(" ")[1]
+        if token not in self.tokens:
+            # NOTE(review): the rejected token value is written to the log and
+            # to the exception message — confirm this is acceptable.
+            logger.bind(tag=TAG).error(f"Invalid token: {token}")
+            raise AuthenticationError(f"Invalid token: {token}")
+
+        logger.bind(tag=TAG).info(f"Authentication successful - Device: {device_id}, Token: {self.tokens[token]}")
+        return True
+
+    def get_token_name(self, token: str) -> str:
+        """
+        Return the name configured for a token.
+
+        Args:
+            token (str): token string
+
+        Returns:
+            str: the configured name, or "Unknown" when the token is not known
+        """
+        return self.tokens.get(token, "Unknown")

+ 142 - 0
main/server/core/connnection.py

@@ -0,0 +1,142 @@
+import os
+import sys
+import copy
+import json
+import uuid
+import time
+import queue
+import asyncio
+import threading
+import traceback
+import subprocess
+import websockets
+from config.logger import setup_logging
+from core.auth import AuthMiddleware, AuthenticationError
+from concurrent.futures import ThreadPoolExecutor
+from core.utils.voiceprint_provider import VoiceprintProvider
+from collections import deque
+from core.utils.dialogue import Dialogue
+from core.utils.prompt_manager import PromptManager
+
+
+
+TAG = __name__
+
+# Runs at import time of this module.
+# NOTE(review): auto_import_modules is not imported in this file's visible
+# import block, so this line raises NameError as written — confirm which
+# module provides it. Presumably it loads the plugins under
+# "plugins_func.functions" (inferred from the name only).
+auto_import_modules("plugins_func.functions")
+
+class TTSSException(RuntimeError):
+    """RuntimeError subclass declared by this module; not raised in the visible code."""
+
+
+class ConnectionHandler:
+    """Per-connection state holder for one client session.
+
+    The constructor only initialises attributes; no I/O is started here.
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        _vad,
+        _asr,
+        _llm,
+        _memory,
+        _intent,
+        server=None,
+    ):
+        """
+        Args:
+            config: server configuration; kept as ``common_config`` and
+                deep-copied into ``config`` so per-connection changes do not
+                leak into the shared dict.
+            _vad: VAD component, stored as ``self._vad``.
+            _asr: ASR component, stored as ``self._asr``.
+            _llm: LLM component, stored as ``self.llm``.
+            _memory: memory component, stored as ``self.memory``.
+            _intent: intent component, stored as ``self.intent``.
+            server: optional reference to the owning server instance.
+
+        Raises:
+            KeyError: ``config`` has no ``"exit_commands"`` entry (or no
+                ``"server"`` entry, via AuthMiddleware).
+        """
+        self.common_config = config
+        self.config = copy.deepcopy(config)
+        self.session_id = str(uuid.uuid4())
+        self.logger = setup_logging()
+        self.server = server  # reference to the owning server instance
+
+        self.auth = AuthMiddleware(config)
+        self.need_bind = False
+        self.bind_code = None
+        self.read_config_from_api = self.config.get("read_config_from_api", False)
+
+        self.websocket = None
+        self.headers = None
+        self.device_id = None
+        self.client_ip = None
+        self.prompt = None
+        self.welcome_msg = None
+        self.max_output_size = 0
+        self.chat_history_conf = 0
+        self.audio_format = "opus"
+
+        # Client state
+        self.client_abort = False
+        self.client_is_speaking = False
+        self.client_listen_mode = "auto"
+
+        # Threads / tasks
+        self.loop = asyncio.get_event_loop()
+        self.stop_event = threading.Event()
+        self.executor = ThreadPoolExecutor(max_workers=5)
+
+        # Reporting queue and thread
+        self.report_queue = queue.Queue()
+        self.report_thread = None
+        # ASR and TTS reporting can be tuned separately here later; for now
+        # both follow read_config_from_api.
+        self.report_asr_enable = self.read_config_from_api
+        self.report_tts_enable = self.read_config_from_api
+
+        # Components this connection depends on
+        self.vad = None
+        self.asr = None
+        self.tts = None
+        self._asr = _asr
+        self._vad = _vad
+        self.llm = _llm
+        self.memory = _memory
+        self.intent = _intent
+
+        # Voiceprint recognition is managed per connection
+        self.voiceprint_provider = None
+
+        # VAD state
+        self.client_audio_buffer = bytearray()
+        self.client_have_voice = False
+        self.last_activity_time = 0.0  # unified activity timestamp (milliseconds)
+        self.client_voice_stop = False
+        self.client_voice_window = deque(maxlen=5)
+        self.last_is_voice = False
+
+        # ASR state
+        # A deployment may share one local ASR instance between connections,
+        # so ASR-related variables must not live on that shared instance;
+        # they are kept here as private per-connection state.
+        self.asr_audio = []
+        self.asr_audio_queue = queue.Queue()
+
+        # LLM state
+        self.llm_finish_task = True
+        self.dialogue = Dialogue()
+
+        # TTS state
+        self.sentence_id = None
+        self.tts_MessageText = ""
+
+        # IoT state
+        self.iot_descriptors = {}
+        self.func_handler = None
+
+        self.cmd_exit = self.config["exit_commands"]
+        # Length of the longest exit command (0 when the list is empty).
+        self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)
+
+        # Whether to close the connection once the chat has finished
+        self.close_after_chat = False
+        self.load_function_plugins = False
+        self.intent_type = "nointent"
+
+        # Configured no-voice limit plus 60 extra seconds; per the original
+        # note, the connection is closed when the user stays silent that long
+        # after the first utterance ended.
+        self.timeout_seconds = (
+            int(self.config.get("close_connection_no_voice_time", 120)) + 60
+        )
+        self.timeout_task = None
+
+        # {"mcp": true} means the MCP feature is enabled
+        self.features = None
+
+        # Prompt manager
+        self.prompt_manager = PromptManager(config, self.logger)
+
+
+# Backward-compatible alias for the original (misspelled) class name.
+ConnnectionHandler = ConnectionHandler
+
+
+

+ 74 - 0
main/server/core/http_server.py

@@ -0,0 +1,74 @@
+import asyncio
+from aiohttp import web
+from config.logger import setup_logging
+from core.api.ota_handler import OTAHandler 
+from core.api.vision_handler import VisionHandler
+
+TAG = __name__
+
+
+class SimpleHttpServer:
+    """aiohttp server exposing the OTA and vision-explain HTTP endpoints."""
+
+    def __init__(self, config: dict):
+        self.config = config
+        self.logger = setup_logging()
+        self.ota_handler = OTAHandler(config)
+        self.vision_handler = VisionHandler(config)
+
+    def _get_websocket_url(self, local_ip: str, port: int) -> str:
+        """
+        Return the WebSocket URL to advertise.
+
+        Args:
+            local_ip (str): local IP address
+            port (int): WebSocket port used for the fallback URL
+
+        Returns:
+            str: the configured ``server.websocket`` value when it is set and
+            does not contain "你"; otherwise ``ws://{local_ip}:{port}``.
+        """
+        server_config = self.config["server"]
+        websocket_config = server_config.get("websocket", {})
+
+        # A value containing "你" is treated as not configured (presumably
+        # the untouched placeholder text from the sample config — confirm).
+        if websocket_config and "你" not in websocket_config:
+            return websocket_config
+        # The fallback previously indexed websocket_config['port'], which
+        # fails for both an empty value and a placeholder string; the `port`
+        # argument is what the caller supplies for this purpose.
+        return f"ws://{local_ip}:{port}"
+
+    async def start(self):
+        """Register the routes, bind the HTTP site and block until cancelled.
+
+        Does nothing when the configured port is falsy.
+        """
+        server_config = self.config["server"]
+        host = server_config.get("ip", "0.0.0.0")
+        # NOTE(review): WebSocketServer.start also reads server["port"];
+        # confirm the two servers are not meant to use separate keys.
+        port = server_config.get("port", 8083)
+        if not port:
+            return
+
+        app = web.Application()
+        # NOTE(review): ConnnectionHandler reads "read_config_from_api" from
+        # the top-level config, this reads it from the "server" section —
+        # confirm which one is intended.
+        read_config_from_api = server_config.get("read_config_from_api", False)
+
+        if not read_config_from_api:
+            # Without the control console (single-module mode) a simple OTA
+            # endpoint is added so the websocket address can be handed out.
+            app.add_routes(
+                [
+                    web.get("/xiaozhi/ota", self.ota_handler.handle_get),
+                    web.post("/xiaozhi/ota", self.ota_handler.handle_post),
+                    web.options("/xiaozhi/ota", self.ota_handler.handle_post),
+                ]
+            )
+
+        # Vision-explain routes are always registered.
+        app.add_routes(
+            [
+                web.get("/mcp/vision/explain", self.vision_handler.handle_get),
+                web.post("/mcp/vision/explain", self.vision_handler.handle_post),
+                web.options("/mcp/vision/explain", self.vision_handler.handle_post),
+            ]
+        )
+
+        # Start serving.
+        runner = web.AppRunner(app)
+        await runner.setup()
+        try:
+            site = web.TCPSite(runner, host, port)
+            await site.start()
+
+            # Keep the coroutine alive; it wakes once an hour and sleeps again.
+            while True:
+                await asyncio.sleep(3600)
+        finally:
+            # Release the listening socket when the task is cancelled or
+            # binding fails.
+            await runner.cleanup()

+ 121 - 0
main/server/core/utils/dialogue.py

@@ -0,0 +1,121 @@
+import uuid
+import re
+from typing import List, Dict
+from datetime import datetime
+
+class Message:
+    """One dialogue entry: role, optional content and optional tool-call data."""
+
+    def __init__(
+        self,
+        role: str,
+        content: str = None,
+        uniq_id: str = None,
+        tool_calls = None,
+        tool_call_id = None,
+    ):
+        """
+        Args:
+            role: message role (the visible code checks for "system" and "tool").
+            content: message text, may be None.
+            uniq_id: identifier for the message; a new UUID4 string is
+                generated when omitted.
+            tool_calls: tool-call payload, passed through unchanged.
+            tool_call_id: id of the tool call this message answers.
+        """
+        self.uniq_id = uniq_id if uniq_id is not None else str(uuid.uuid4())
+        # The attribute was originally misspelled "unid_id"; keep that name
+        # readable for any code that already uses it.
+        self.unid_id = self.uniq_id
+        self.role = role
+        self.content = content
+        self.tool_calls = tool_calls
+        self.tool_call_id = tool_call_id
+
+
+
+class Dialogue:
+    """Ordered list of Message objects plus helpers that render it for an LLM."""
+
+    def __init__(self):
+        self.dialogue: List[Message] = []
+        # Creation time of this dialogue, formatted as a string.
+        self.current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    def put(self, message: Message):
+        """Append a message to the dialogue."""
+        self.dialogue.append(message)
+
+    def getMessages(self, m, dialogue):
+        """Append the dict form of message ``m`` to the list ``dialogue``.
+
+        Messages with tool calls emit ``role`` + ``tool_calls``; "tool"
+        messages emit ``role`` + ``tool_call_id`` + ``content``; everything
+        else emits ``role`` + ``content``.
+        """
+        if m.tool_calls is not None:
+            dialogue.append({"role": m.role, "tool_calls": m.tool_calls})
+        elif m.role == "tool":
+            # A fresh UUID4 is used when the message carries no tool_call_id.
+            tool_call_id = uuid.uuid4() if m.tool_call_id is None else m.tool_call_id
+            dialogue.append(
+                {
+                    "role": m.role,
+                    "tool_call_id": str(tool_call_id),
+                    "content": m.content,
+                }
+            )
+        else:
+            dialogue.append({"role": m.role, "content": m.content})
+
+    def get_llm_dialogue(self) -> List[Dict[str, str]]:
+        """Render the dialogue without memory or voiceprint data."""
+        # Delegates so every call path shares one rendering implementation.
+        return self.get_llm_dialogue_with_memory(None, None)
+
+    def update_system_message(self, new_content: str):
+        """
+        Replace the content of the system message, creating it if absent.
+
+        Args:
+            new_content (str): new system message content
+        """
+        system_message = next((msg for msg in self.dialogue if msg.role == "system"), None)
+        if system_message:
+            system_message.content = new_content
+        else:
+            self.put(Message(role="system", content=new_content))
+
+    @staticmethod
+    def _format_speakers_info(voiceprint_config) -> str:
+        """Build the ``<speakers_info>`` block, or "" when no speakers are configured."""
+        if not isinstance(voiceprint_config, dict):
+            return ""
+        speakers = voiceprint_config.get("speakers", [])
+        if not speakers or not isinstance(speakers, (list, tuple)):
+            return ""
+
+        lines = []
+        for speaker_str in speakers:
+            # Entries are "id,name[,description]"; malformed ones are skipped.
+            if not isinstance(speaker_str, str):
+                continue
+            parts = speaker_str.split(",", 2)
+            if len(parts) < 2:
+                continue
+            name = parts[1].strip()
+            # A missing description is rendered as an empty string.
+            description = parts[2].strip() if len(parts) >= 3 else ""
+            lines.append(f"\n- {name}:{description}")
+        return "\n\n<speakers_info>" + "".join(lines) + "\n\n</speakers_info>"
+
+    def get_llm_dialogue_with_memory(
+        self, memory_str: str = None, voiceprint_config: dict = None
+    ) -> List[Dict[str, str]]:
+        """
+        Render the dialogue as a list of role/content dicts.
+
+        Args:
+            memory_str: when not None, replaces the body of the ``<memory>``
+                tag found in the system prompt.
+            voiceprint_config: optional dict whose "speakers" list is
+                appended to the system prompt as a ``<speakers_info>`` block.
+
+        Returns:
+            The enhanced system message first (if one exists), followed by
+            every non-system message in insertion order.
+        """
+        dialogue = []
+
+        system_message = next(
+            (msg for msg in self.dialogue if msg.role == "system"), None
+        )
+
+        if system_message:
+            # Base system prompt.
+            enhanced_system_prompt = system_message.content or ""
+
+            # Per-speaker descriptions. Missing or malformed configuration is
+            # ignored explicitly instead of through a bare except.
+            enhanced_system_prompt += self._format_speakers_info(voiceprint_config)
+
+            # Replace whatever is between <memory> and </memory>.
+            if memory_str is not None:
+                memory_block = f"<memory>\n{memory_str}\n</memory>"
+                # A callable replacement keeps backslashes in memory_str
+                # literal; a plain string would be parsed for escapes.
+                enhanced_system_prompt = re.sub(
+                    r"<memory>.*?</memory>",
+                    lambda _match: memory_block,
+                    enhanced_system_prompt,
+                    flags=re.DOTALL,
+                )
+            dialogue.append({"role": "system", "content": enhanced_system_prompt})
+
+        # User, assistant and tool messages; the raw system message is skipped.
+        for m in self.dialogue:
+            if m.role != "system":
+                self.getMessages(m, dialogue)
+
+        return dialogue

+ 85 - 0
main/server/core/utils/modules_initialize.py

@@ -0,0 +1,85 @@
+from typing import Dict, Any
+from config.logger import setup_logging
+from core.utils import tts, llm, intent, memory, vad, asr
+
+TAG = __name__
+logger = setup_logging()
+
+# *** 需要修改的点 ***
+def _resolve_module_type(config, section, selected):
+    """Return the "type" entry of ``config[section][selected]``, else ``selected``."""
+    module_config = config[section][selected]
+    return module_config["type"] if "type" in module_config else selected
+
+
+def initialize_modules(
+        logger,
+        config: Dict[str, Any],
+        init_vad = False,
+        init_asr = False,
+        init_llm = False,
+        init_tts = False,
+        init_memory = False,
+        init_intent = False,
+) -> Dict[str, Any]:
+    """
+    Create the requested module instances.
+
+    Args:
+        logger: logger used for the per-module success messages
+        config: configuration dict
+        init_vad: resolve the VAD module (instance creation not implemented yet)
+        init_asr: accepted but not acted on in the current implementation
+        init_llm: create the LLM module under the "llm" key
+        init_tts: create the TTS module under the "tts" key
+        init_memory: create the Memory module under the "memory" key
+        init_intent: create the Intent module under the "intent" key
+
+    Returns:
+        Dict[str, Any]: the created modules keyed by lower-case module kind;
+        empty when nothing was requested.
+    """
+    modules = {}
+
+    if init_tts:
+        select_tts_module = config["selected_modules"]["TTS"]
+        # NOTE(review): initialize_tts is neither defined nor imported in the
+        # visible file — confirm where it should come from.
+        modules["tts"] = initialize_tts(config)
+        logger.bind(tag=TAG).info(f"初始化组件: tts成功{select_tts_module}")
+
+    # LLM module
+    if init_llm:
+        select_llm_module = config["selected_modules"]["LLM"]
+        modules["llm"] = llm.create_instance(
+            _resolve_module_type(config, "LLM", select_llm_module),
+            config["LLM"][select_llm_module],
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: llm成功{select_llm_module}")
+
+    # Intent module
+    if init_intent:
+        select_intent_module = config["selected_modules"]["Intent"]
+        modules["intent"] = intent.create_instance(
+            _resolve_module_type(config, "Intent", select_intent_module),
+            config["Intent"][select_intent_module],
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: intent成功{select_intent_module}")
+
+    # Memory module
+    if init_memory:
+        select_memory_module = config["selected_modules"]["Memory"]
+        modules["memory"] = memory.create_instance(
+            _resolve_module_type(config, "Memory", select_memory_module),
+            config["Memory"][select_memory_module],
+            config.get("summaryMemory", None),
+        )
+        logger.bind(tag=TAG).info(f"初始化组件: memory成功{select_memory_module}")
+
+    # VAD module
+    if init_vad:
+        select_vad_module = config["selected_modules"]["VAD"]
+        # TODO: the original stops after resolving the type; the VAD instance
+        # is not created yet, so no "vad" key is added. The lookup is kept so
+        # a missing VAD config entry still fails here.
+        _resolve_module_type(config, "VAD", select_vad_module)
+
+    # TODO: init_asr is accepted but no ASR initialisation exists yet.
+
+    # The original fell off the end and returned None despite the documented
+    # dict return value.
+    return modules

+ 63 - 0
main/server/core/utils/prompt_manager.py

@@ -0,0 +1,63 @@
+"""
+系统提示词管理器模块,
+负责管理和更新系统提示词,包括快速初始化和异步增强功能
+"""
+
+import os
+import cnlunar
+from typing import Dict, Any
+from config.logger import setup_logging
+
+TAG = __name__
+
+# English weekday name -> Chinese weekday name.
+WEEKDAY_MAP = {
+    "Monday": "星期一",
+    "Tuesday": "星期二",
+    "Wednesday": "星期三",
+    "Thursday": "星期四",
+    "Friday": "星期五",
+    "Saturday": "星期六",
+    "Sunday": "星期日",
+}
+
+# Fixed list of emoji characters; not referenced in the visible part of this
+# module (presumably the set of emoji the prompt may mention — confirm).
+EMOJI_List = [
+    "😶",
+    "🙂",
+    "😆",
+    "😂",
+    "😔",
+    "😠",
+    "😭",
+    "😍",
+    "😳",
+    "😲",
+    "😱",
+    "🤔",
+    "😉",
+    "😎",
+    "😌",
+    "🤤",
+    "😘",
+    "😏",
+    "😴",
+    "😜",
+    "🙄",
+]
+
+class PromptManager:
+    """系统提示词管理器,负责管理和更新系统提示词"""
+    def __init__(self, config: Dict[str, Any], logger):
+        """Store config and logger, bind the shared cache manager, load the base template.
+
+        A falsy ``logger`` is replaced by ``setup_logging()``.
+        """
+        self.config = config
+        self.logger = logger if logger else setup_logging()
+        self.last_update_time = 0
+        self.base_prompt_template = None
+
+        # The global cache manager is imported inside the constructor rather
+        # than at module level.
+        from core.utils.cathe.manager import cache_manager, CacheType
+
+        self.CacheType = CacheType
+        self.cache_manager = cache_manager
+
+        self._load_base_template()
+
+    def 

+ 134 - 0
main/server/core/utils/voiceprint_provider.py

@@ -0,0 +1,134 @@
+import asyncio
+import json
+import time
+import aiohttp
+from urllib.parse import urlparse, parse_qs
+from typing import Optional, Dict
+from config.logger import setup_logging
+
+TAG = __name__
+logger = setup_logging()
+
+
+class VoiceprintProvider:
+    """Client for a voiceprint (speaker identification) HTTP service.
+
+    Configuration keys read: ``url`` (service URL carrying a ``key`` query
+    parameter) and ``speakers`` (list of "id,name[,description]" strings).
+    ``enabled`` is True only when the URL, the key and at least one speaker
+    id are all present.
+    """
+
+    def __init__(self, config: dict):
+        self.original_url = config.get("url", "")
+        self.speakers = config.get("speakers", [])
+        self.speaker_map = self._parse_speakers()
+
+        # API address and key, filled in below when the URL is usable.
+        self.api_url = None
+        self.api_key = None
+        self.speaker_ids = []
+        self.enabled = False
+
+        if not self.original_url:
+            logger.bind(tag=TAG).warning("声纹识别URL未配置,声纹识别将被禁用")
+            return
+
+        # The key is carried as a query parameter of the configured URL.
+        parsed_url = urlparse(self.original_url)
+        self.api_key = parse_qs(parsed_url.query).get('key', [''])[0]
+        if not self.api_key:
+            logger.bind(tag=TAG).error("URL中未找到key参数,声纹识别将被禁用")
+            return
+
+        # Identify endpoint on the same scheme/host as the configured URL.
+        self.api_url = f"{parsed_url.scheme}://{parsed_url.netloc}/voiceprint/identify"
+
+        # The first comma-separated field of each entry is the speaker id;
+        # non-string entries are skipped.
+        self.speaker_ids = [
+            speaker_str.split(",", 2)[0].strip()
+            for speaker_str in self.speakers
+            if isinstance(speaker_str, str)
+        ]
+        if not self.speaker_ids:
+            logger.bind(tag=TAG).warning("未配置有效的说话人,声纹识别将被禁用")
+            return
+
+        self.enabled = True
+        logger.bind(tag=TAG).info(f"声纹识别已配置: API={self.api_url}, 说话人={len(self.speaker_ids)}个")
+
+    def _parse_speakers(self) -> Dict[str, Dict[str, str]]:
+        """Parse ``self.speakers`` into ``{id: {"name": ..., "description": ...}}``.
+
+        Entries are "id,name[,description]". The description is optional and
+        defaults to "", so every entry whose id is sent to the service also
+        has a name to report back. Entries without a name are skipped.
+        """
+        speaker_map = {}
+        for speaker_str in self.speakers:
+            try:
+                parts = speaker_str.split(",", 2)
+            except AttributeError as e:
+                # Not a string: report it and keep parsing the rest.
+                logger.bind(tag=TAG).warning(f"解析说话人配置失败: {speaker_str}, 错误: {e}")
+                continue
+            if len(parts) < 2:
+                continue
+            speaker_id = parts[0].strip()
+            speaker_map[speaker_id] = {
+                "name": parts[1].strip(),
+                "description": parts[2].strip() if len(parts) >= 3 else "",
+            }
+        return speaker_map
+
+    async def identify_speaker(self, audio_data: bytes, session_id: str) -> Optional[str]:
+        """Ask the service who is speaking in ``audio_data``.
+
+        Args:
+            audio_data: audio bytes, uploaded as ``audio.wav``.
+            session_id: accepted for the caller's convenience; not used here.
+
+        Returns:
+            The configured speaker name; "未知说话人" when the service returns
+            an id that is not configured; None when the provider is disabled,
+            the request times out, the service answers with a non-200 status,
+            or any other error occurs.
+        """
+        if not self.enabled or not self.api_url or not self.api_key:
+            logger.bind(tag=TAG).debug("声纹识别功能已禁用或未配置,跳过识别")
+            return None
+
+        api_start_time = time.monotonic()
+        try:
+            # Request headers
+            headers = {
+                'Authorization': f'Bearer {self.api_key}',
+                'Accept': 'application/json'
+            }
+
+            # multipart/form-data payload
+            data = aiohttp.FormData()
+            data.add_field('speaker_ids', ','.join(self.speaker_ids))
+            data.add_field('file', audio_data, filename='audio.wav', content_type='audio/wav')
+
+            timeout = aiohttp.ClientTimeout(total=10)
+
+            # Network request
+            async with aiohttp.ClientSession(timeout=timeout) as session:
+                async with session.post(self.api_url, headers=headers, data=data) as response:
+                    if response.status != 200:
+                        logger.bind(tag=TAG).error(f"声纹识别API错误: HTTP {response.status}")
+                        return None
+
+                    result = await response.json()
+                    speaker_id = result.get("speaker_id")
+                    score = result.get("score", 0)
+                    total_elapsed_time = time.monotonic() - api_start_time
+
+                    logger.bind(tag=TAG).info(f"声纹识别耗时: {total_elapsed_time:.3f}s")
+
+                    # Confidence check: a low score is only logged, the
+                    # result is still used.
+                    if score < 0.5:
+                        logger.bind(tag=TAG).warning(f"声纹识别置信度较低: {score:.3f}")
+
+                    if speaker_id and speaker_id in self.speaker_map:
+                        return self.speaker_map[speaker_id]["name"]
+
+                    logger.bind(tag=TAG).warning(f"未识别的说话人ID: {speaker_id}")
+                    return "未知说话人"
+
+        except asyncio.TimeoutError:
+            elapsed = time.monotonic() - api_start_time
+            logger.bind(tag=TAG).error(f"声纹识别超时: {elapsed:.3f}s")
+            return None
+        except Exception as e:
+            # Best effort: any failure is logged and reported as "no result".
+            logger.bind(tag=TAG).error(f"声纹识别失败: {e}")
+            return None

+ 44 - 0
main/server/core/websocket_server.py

@@ -0,0 +1,44 @@
+import asyncio
+import websockets
+from config.logger import setup_logging
+from core.utils.modules_initialize import initialize_modules
+from config.config_loader import get_config_from_api
+from core.connnection import ConnectionHandler
+
+TAG = __name__
+
+
+class WebSocketServer:
+    """WebSocket front end: builds the shared modules once and creates a
+    ConnectionHandler for each incoming connection."""
+
+    def __init__(self, config: dict):
+        self.config = config
+        self.logger = setup_logging()
+        self.config_lock = asyncio.Lock()
+
+        # A module is initialised only when it is listed in selected_modules.
+        selected = self.config["selected_modules"]
+        modules = initialize_modules(
+            self.logger,
+            self.config,
+            init_vad="VAD" in selected,
+            init_asr="ASR" in selected,
+            init_llm="LLM" in selected,
+            init_tts="TTS" in selected,
+            init_memory="Memory" in selected,
+            init_intent="Intent" in selected,
+        ) or {}  # tolerate an initialiser that returns nothing
+        self.vad = modules.get("vad")
+        self.asr = modules.get("asr")
+        self.llm = modules.get("llm")
+        self.intent = modules.get("intent")
+        self.memory = modules.get("memory")
+
+        self.active_connections = set()
+
+    async def start(self):
+        """Serve WebSocket connections on the configured ip/port until cancelled."""
+        server_config = self.config["server"]
+        host = server_config.get("ip", "0.0.0.0")
+        # NOTE(review): SimpleHttpServer.start also reads server["port"];
+        # confirm the two servers are not meant to use separate keys.
+        port = server_config.get("port", 8080)
+
+        # No _process_request hook is defined on this class yet; referencing
+        # it directly raised AttributeError. None is the library default and
+        # the hook is picked up automatically once it is added.
+        process_request = getattr(self, "_process_request", None)
+
+        async with websockets.serve(
+            self._handle_connection, host, port, process_request=process_request
+        ):
+            # Never completes: keeps the server running until cancellation.
+            await asyncio.Future()
+
+    async def _handle_connection(self, websocket):
+        """Handle a new connection; each one gets its own ConnectionHandler."""
+        # Arguments follow the handler's constructor order:
+        # (config, vad, asr, llm, memory, intent, server).
+        handler = ConnectionHandler(
+            self.config,
+            self.vad,
+            self.asr,
+            self.llm,
+            self.memory,
+            self.intent,
+            self,
+        )
+        # TODO: the handler is not yet given `websocket` or started — no
+        # entry point for that exists on the handler in the visible code.