瀏覽代碼

主程序流程

Zzcoded 5 天之前
父節點
當前提交
897ad177b3
共有 5 個文件被更改,包括 316 次插入0 次删除
  1. 136 0
      main/server/app.py
  2. 22 0
      main/server/core/handle/abortHandle.py
  3. 89 0
      main/server/core/utils/util.py
  4. 32 0
      main/server/core/utils/vllm.py
  5. 37 0
      main/server/requirements.txt

+ 136 - 0
main/server/app.py

@@ -0,0 +1,136 @@
+import sys
+import uuid
+import signal
+import asyncio
+from aioconsole import ainput
+from config.settings import load_config
+from config.logger import setup_logging
+from core.utils.util import check_ffmpeg_installed, get_local_ip, validate_mcp_endpoint
+from core.websocket.server import WebSocketServer
+from core.simple_http.server import SimpleHttpServer
+
+TAG = __name__
+logger = setup_logging()
+
+async def wait_for_exit() -> None:
+    """
+    阻塞直到收到 Ctrl-C / SIGINT / SIGTERM 信号
+    - Unix: 使用 add_signal_handler 注册信号处理
+    - Windows: 使用 signal.pause() 等待信号
+    """
+    loop = asyncio.get_running_loop()
+    stop_event = asyncio.Event()
+
+    if sys.platform != "win32":  # Unix / macOS
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            loop.add_signal_handler(sig, stop_event.set)
+        await stop_event.wait()
+    else:
+        # Windows:await一个永远pending的fut,
+        # 让 KeyboardInterrupt 冒泡到 asyncio.run,以此消除遗留普通线程导致进程退出阻塞的问题
+        try:
+            await asyncio.Future()
+        except KeyboardInterrupt:  # Ctrl‑C
+            pass
+
+
+async def monitor_stdin():
+    """监控标准输入,消费回车键"""
+    while True:
+        await ainput() # 异步等待输入, 消费回车
+    
+async def main():
+    check_ffmpeg_installed()
+    config = load_config()
+
+    # 默认使用manager-api的secret作为auth_key
+    # 如果secret为空,则生成随机密钥
+    # auth_key用于jwt认证,比如视觉分析接口的jwt认证
+    auth_key = config.get("manager-api",{}).get("secret","")
+    if not auth_key or len(auth_key)  == 0 or "你" in auth_key:
+        auth_key = str(uuid.uuid4().hex)
+    config["server"]["auth_key"] = auth_key
+
+    # 添加 stdin 监控任务
+    stdin_task = asyncio.create_task(monitor_stdin())
+
+    # 启动WebSocket 服务器
+    ws_server = WebSocketServer(config)
+    ws_task = asyncio.create_task(ws_server.start())
+
+    # 启动simple http 服务器
+    ota_server = SimpleHttpServer(config)
+    ota_task = asyncio.create_task(ota_server.start())
+
+    read_config_from_api = config.get("read_config_from_api",False)
+    port = int(config.get("server",{}).get("port",8003))
+    if not read_config_from_api:
+        logger.bind(tag=TAG).info(
+            "OTA接口是\t\thttp://{}:{}/xiaozhi/ota/",
+            get_local_ip(),
+            port,
+        )
+    logger.bind(tag=TAG).info(
+        "视觉分析接口是\thttp://{}:{}/mcp/vision/explain",
+        get_local_ip(),
+        port,
+    )
+    mcp_endpoint = config.get("mcp_endpoint",None)
+    if mcp_endpoint is not None and "你" not in mcp_endpoint:
+        if validate_mcp_endpoint(mcp_endpoint):
+            logger.bind(tag=TAG).info("MCP端点是\t\t{}",mcp_endpoint)
+            # 将mcp计入点地址转为调用点
+            mcp_endpoint = mcp_endpoint.replace("/mcp/","/mcp/call/")
+            config["mcp_endpoint"] = mcp_endpoint
+        else:
+            logger.bind(tag=TAG).error("MCP端点无效,请检查配置")
+            config["mcp_endpoint"] = "你的接入点 websocket地址"
+        
+    # 获取websocket配置
+    websocket_port = 8000
+    sever_config = config.get("server",{})
+    if isinstance(sever_config,dict):
+        websocket_port = int(sever_config.get("websocket_port",8000))
+
+    logger.bind(tag=TAG).info(
+        "WebSocket地址是\t\tws://{}:{}",
+        get_local_ip(),
+        websocket_port,
+    )
+
+    logger.bind(tag=TAG).info(
+        "=======上面的地址是websocket协议地址,请勿用浏览器访问======="
+    )
+    logger.bind(tag=TAG).info(
+        "如想测试websocket请用谷歌浏览器打开test目录下的test_page.html"
+    )
+    logger.bind(tag=TAG).info(
+        "=============================================================\n"
+    )
+
+    try:
+        await wait_for_exit()  # 阻塞直到收到退出信号
+    except asyncio.CancelledError:
+        print("收到退出信号,准备退出")
+    finally:
+        # 取消所有任务
+        stdin_task.cancel()
+        ws_task.cancel()
+        if ota_task:
+            ota_task.cancel()
+
+        # 等待所有任务完成
+        await asyncio.wait(
+            [stdin_task, ws_task, ota_task] if ota_task else [stdin_task, ws_task],
+            timeout = 3,
+            return_when = asyncio.ALL_COMPLETED,
+        )
+        print("服务器已关闭,程序退出。")
+
+
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        print("手动中断,程序终止。")

+ 22 - 0
main/server/core/handle/abortHandle.py

@@ -0,0 +1,22 @@
+import json
+
+TAG = __name__
+
+async def handleAbortMessage(conn):
+    """
+    处理终止消息
+    Args:
+        conn: 连接对象
+    Returns:
+        None
+    """
+    conn.logger.bind(TAG).info("收到终止消息,准备关闭连接")
+    # 设置为打断状态,会自动打断llm,tts任务
+    conn.client_abort = True
+    conn.clear_queues()
+    # 打断客户端说话状态
+    await conn.websocket.send(
+        json.dumps("type": "tts", "state": "stop","session_id": conn.session_id)
+    )
+    conn.clearSpeakStatus()
+    conn.logger.bind(TAG).info("连接已关闭")

+ 89 - 0
main/server/core/utils/util.py

@@ -0,0 +1,89 @@
+import json
+import socket
+import subprocess
+
+
+TAG = __name__
+
+def is_valid_image(file_data:bytes) -> bool:
+    """
+    验证图片是否有效
+    Args:
+        file_data: 文件数据
+    Returns:
+        bool: 是否有效
+    """
+    # 常见图片格式的魔数(文件头)
+    image_signatures = {
+        b"\xff\xd8\xff": "JPEG",
+        b"\x89PNG\r\n\x1a\n": "PNG",
+        b"GIF87a": "GIF",
+        b"GIF89a": "GIF",
+        b"BM": "BMP",
+        b"II*\x00": "TIFF",
+        b"MM\x00*": "TIFF",
+        b"RIFF": "WEBP",
+    }
+    
+    for signatures in image_signatures:
+        if file_data.startswith(signatures):
+            return True
+        
+    return False
+
+
+def check_ffmpeg_installed():
+    """
+    检查ffmpeg是否安装
+    """
+    ffmpeg_installed = False
+    try:
+        # 执行ffmpeg -version,并捕获输出
+        result = subprocess.run(
+            ["ffmpeg", "-version"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+            check=True, # 如果返回码非零抛出异常
+        )
+        # 检查输出中是否包含版本信息(可选)
+        output = result.stdout + result.stderr
+        if "ffmpeg version" in output.lower():
+            ffmpeg_installed = True
+        return ffmpeg_installed
+    except subprocess.CalledProcessError:
+        ffmpeg_installed = False
+    if not ffmpeg_installed:
+        raise RuntimeError("ffmpeg未安装,请先安装ffmpeg")
+    
+def validate_mcp_endpoint(mcp_endpoint:str) -> bool:
+    """
+    验证mcp端点是否有效
+    Args:
+        endpoint: mcp端点
+    Returns:
+        bool: 是否有效
+    """
+    if not mcp_endpoint.startswith("ws"):
+        return False
+    if "key" in mcp_endpoint.lower() or "call" in mcp_endpoint.lower():
+        return False
+    if "/mcp/" not in mcp_endpoint:
+        return False
+    
+    return True
+
+def get_local_ip() -> str:
+    """
+    获取本地IP地址
+    Returns:
+        str: 本地IP地址
+    """
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) # 创建IPV4 UDP socket
+        s.connect(("8.8.8.8", 80))
+        local_ip = s.getsockname()[0]
+        s.close()
+        return local_ip
+    except Exception as e:
+        return "127.0.0.1"

+ 32 - 0
main/server/core/utils/vllm.py

@@ -0,0 +1,32 @@
+import os
+import sys
+
+# 增添项目根目录到python路径
+current_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
+sys.path.insert(0, project_root)
+
+from config.logger import setup_logging
+import importlib
+logger = setup_logging()
+
+
+def create_llm_instance(class_name:str, *args, **kwargs):
+    """
+    创建LLM实例
+    Args:
+        class_name: 类名
+        *args: 参数
+        **kwargs: 关键字参数
+    Returns:
+        object: LLM实例
+    """
+    if os.path.exists(os.path.join("core","providers","vllm",f"{class_name}.py")):
+        lib_name = f"core.providers.vllm.{class_name}"
+        if lib_name not in sys.modules:
+            sys.modules[lib_name] = importlib.import_module(f"{lib_name}")
+        return sys.modules[lib_name].VLLMProvider(*args, **kwargs)
+    
+    raise ValueError(f"无法找到VLLM模块{class_name}")
+
+

+ 37 - 0
main/server/requirements.txt

@@ -0,0 +1,37 @@
+pyyml==0.0.2
+torch==2.2.2
+silero_vad==5.1.2
+websockets==14.2
+opuslib_next==1.1.2
+numpy==1.26.4
+pydub==0.25.1
+funasr==1.2.3
+torchaudio==2.2.2
+openai==1.61.0
+google-generativeai==0.8.4
+edge_tts==7.0.0
+httpx==0.27.2
+aiohttp==3.9.3
+aiohttp_cors==0.7.0
+ormsgpack==1.7.0
+ruamel.yaml==0.18.10
+loguru==0.7.3
+requests==2.32.3
+cozepy==0.12.0
+mem0ai==0.1.62
+bs4==0.0.2
+modelscope==1.23.2
+sherpa_onnx==1.12.4
+mcp==1.8.1
+cnlunar==0.2.0
+PySocks==1.7.1
+dashscope==1.23.1
+baidu-aip==4.16.13
+chardet==5.2.0
+aioconsole==0.8.1
+markitdown==0.1.1
+mcp-proxy==0.8.0
+PyJWT==2.8.0
+psutil==7.0.0
+portalocker==2.10.1
+Jinja2==3.1.6