瀏覽代碼

请求处理

Zzcoded 5 天之前
父節點
當前提交
3534cf3e32

+ 17 - 0
main/server/core/api/base_handler.py

@@ -0,0 +1,17 @@
+from aiohttp import web
+from config.logger import setup_logger
+
+class BaseHandler:
+    def __init__(self, config: dict):
+        self.config = config
+        self.logger = setup_logger()
+
+    def _add_cors_headers(self, response: web.Response):
+        # - "Access-Control-Allow-Headers" 指定了允许哪些自定义请求头(如clinet-id、content-type、device-id)可以被前端发送到服务器。
+        # - "Access-Control-Allow-Credentials" 设置为"true",表示允许浏览器携带cookie等凭证信息进行跨域请求。
+        # - "Accsss-Control-Allow-Origin"(注意拼写应为"Access-Control-Allow-Origin")设置为"*",表示允许所有域名的前端访问该API。
+        response.headers["Access-Control-Allow-Headers"] = (
+            "clinet-id, content-type, device-id"
+        )
+        response.headers["Access-Control-Allow-Credentials"] = "true"
+        response.headers["Accsss-Control-Allow-Origin"] = "*"

+ 101 - 0
main/server/core/api/ota_handler.py

@@ -0,0 +1,101 @@
+import json
+import time
+from aiohttp import web
+from core.utils.util import get_local_ip
+from core.api.base_handler import BaseHandler
+
+TAG = __name__
+
+
+class OTAHandler(BaseHandler):
+    def __init__(self,config:dict):
+        super().__init__(config)
+
+    def _get_websocket_url(self,local_ip:str,port:int) -> str:
+        """获取websocket的url
+        
+        Args:
+            local_ip: 本地ip
+            port: 端口
+        
+        Returns:
+            websocket的url
+        """
+        server_config = self.config.get("server",{})
+        websocket_config = server_config.get("websocket",{})
+        
+        if not websocket_config:
+            self.logger.error(f"{TAG} websocket_config is not found")
+            return ""
+        websocket_url = websocket_config.get("url","")
+
+
+        """
+        后续优化点
+        """
+        if "你的" not in websocket_config:
+            return websocket_config
+        else:
+            return f"ws://{local_ip}:{port}/xiaozhi/v1/"
+
+    async def handle_post(self, request):
+        """处理post请求
+        
+        Args:
+            request: 请求对象
+        
+        Returns:
+            response: 响应对象
+        """
+        try:
+            data = await request.text()
+            self.logger.bind(tag=TAG).debug(f"OTA请求方法:{request.method}")
+            self.logger.bind(tag=TAG).debug(f"OTA请求数据:{data}")
+            self.logger.bind(tag=TAG).debug(f"OTA请求头:{request.headers}")
+            
+            device_id = request.headers.get("device_id","")
+            if not device_id:
+                raise web.HTTPBadRequest(text="OTA请求设备ID为空")
+            self.logger.bind(tag=TAG).info(f"OTA请求设备ID:{device_id}")
+           
+            data_json = json.loads(data)
+
+            server_config = self.config.get("server",{})
+            port = int(server_config.get("port",8080))
+            local_ip = get_local_ip()
+
+            return_json = {
+                "server_time": {
+                    "timestamp": int(round(time.time() * 1000)),
+                    "timezone_offset": server_config.get("timezone_offset",8) *60,
+                },
+                "fireware": {
+                    "version": data_json["application"].get("version","1.0.0"),
+                    "url": "",
+                },
+                "websocket": {
+                    "url": self._get_websocket_url(local_ip,port),
+                },
+            }
+            response = web.Response(
+                text = json.dumps(return_json, separators=(",",":")),
+                content_type = "application/json",
+            )
+        except Exception as e:
+            return_json = {"success": False, "message": "request error, " + str(e)}
+            response = web.Response(
+                text = json.dumps(return_json, separators=(",",":")),
+                content_type = "application/json",
+            )
+        finally:
+            self._add_cors_headers(response)
+            return response
+
+
+            
+
+
+
+
+
+    

+ 159 - 0
main/server/core/api/vision_handler.py

@@ -0,0 +1,159 @@
+import json
+import copy
+from aiohttp import web
+from config.logger import setup_logger
+from core.utils.auth import AuthToken
+from core.utils.util import is_valid_image
+from core.utils.vllm import create_llm_instance
+from config.config_loader import get_private_config_from_api
+from plugins_func.register import Action
+
+TAG = __name__
+
+MAX_FILE_SIZE = 5 * 1024 * 1024
+
+class VisionHandler:
+    def __init__(self, config:dict):
+        self.config = config
+        self.logger = setup_logger()
+        self.auth = AuthToken(config["server"]["auth_key"])
+
+    def _create_error_response(self,message:str) -> dict:
+        """
+        创建错误响应
+        Args:
+            message: 错误信息
+        Returns:
+            dict: 错误响应
+        """
+        return {"success": False, "message": message}
+
+    def _verify_auth_token(self,request) -> Tuple[bool, Optional[str]]:
+        """
+        验证认证token
+        Args:
+            request: 请求对象
+        Returns:
+            Tuple[bool, Optional[str]]: 验证结果和设备ID
+        """
+        auth_header = request.headers.get("Authorization","")
+        if not auth_header.startswith("Bearer "):
+            return False, None
+        
+        token = auth_header[7:]
+        return self.auth.verify_token(token)
+    
+    async def handle_post(self, request):
+        """
+        处理MCP Vision Post请求
+        Args:
+            request: 请求对象
+        Returns:
+            web.Response: 响应对象
+        """
+        response = None # 初始化response对象
+        try:
+            # 验证token
+            is_valid, token_device_id = self._verify_auth_token(request)
+            if not is_valid:
+                response = web.Response(
+                    text=json.dumps(
+                        self._create_error_response("无效认证的token或token已过期")
+                    ),
+                    content_type="application/json",
+                    status=401
+                )
+                return response
+            
+            # 获取请求头信息
+            device_id = request.headers.get("Device-Id","")
+            client_id = request.headers.get("Client_Id","")
+            if device_id != token_device_id:
+                raise ValueError("设备ID与token不匹配")
+            
+            reader = await request.multipart()
+
+            # 读取question字段
+            question_field = await reader.next()
+            if question_field is None:
+                raise ValueError("缺少question字段")
+            question = await question_field.text()
+            self.logger.bind(TAG).debug(f"Question:{question}")
+
+            # 读取图片文件
+            image_field = await reader.next()
+            if image_field is None:
+                raise ValueError("缺少图片文件")
+            image_data = await image_field.read()
+            if not image_data:
+                raise ValueError("图片文件为空")
+            if (len(image_data) > MAX_FILE_SIZE):
+                raise ValueError(f"图片文件大小超过限制,最大支持{MAX_FILE_SIZE/1024/1024}MB")
+            if not is_valid_image(image_data):
+                raise ValueError("图片文件格式不支持")
+            image_base64 = base64.b64encode(image_data).decode("utf-8")
+
+            current_config = copy.deepcopy(self.config)
+            read_config_from_api = current_config.get("read_config_from_api",False)
+            if read_config_from_api:
+                current_config = get_private_config_from_api(
+                    current_config,
+                    device_id,
+                    client_id,
+                )
+            select_vllm_module = current_config["selected_module"].get("VLLM")
+            if not select_vllm_module:
+                raise ValueError("您还未配置视觉分析模型")
+
+            vllm_type = (
+                select_vllm_module
+                if "type" not in current_config["VLLM"][select_vllm_module]
+                else current_config["VLLM"][select_vllm_module]["type"]
+            )
+
+            if not vllm_type:
+                raise ValueError(f"无法找到VLLM模块对应的供应器{vllm_type}")
+            
+            vllm = create_llm_instance(
+                vllm_type,
+                current_config["VLLM"][select_vllm_module]
+            )
+
+            result = vllm.response(question,image_base64)
+
+            return_json = {
+                "success": True,
+                "action": Action.RESPONSE.name,
+                "response": result,
+            }
+
+            response = web.Response(
+                text=json.dumps(return_json, separators=(",",":")),
+                content_type="application/json",
+            )
+        except ValueError as e:
+            self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常(e)")
+            return_json = self._create_error_response(str(e))
+            response = web.Response(
+                text=json.dumps(return_json, separators=(",",":")),
+                content_type="application/json",
+            )
+        except Exception as e:
+            self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常(e)")
+            return_json = self._create_error_response(str(e))
+            response = web.Response(
+                text=json.dumps(return_json, separators=(",",":")),
+                content_type="application/json",
+            )
+        finally:
+            if response:
+                self._add_cors_headers(response)
+            return response
+    
+    def _add_cors_headers(self, response):
+        """添加CORS头信息"""
+        response.headers["Access-Control-Allow-Headers"] = (
+            "client-id, content-type, device-id"
+        )
+        response.headers["Access-Control-Allow-Credentials"] = "true"
+        response.headers["Access-Control-Allow-Origin"] = "*"

+ 129 - 0
main/server/core/utils/auth.py

@@ -0,0 +1,129 @@
+import jwt
+import time
+import json 
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+import base64
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives import padding
+from cryptography.hazmat.backends import default_backend
+TAG = __name__
+import base64
+
+class AuthToken:
+    def __init__(self, secret_key: str):
+        self.secret_key = secret_key.encode() # 将密钥转换为字节
+        self.encryption_key = self.derive_key()
+
+    def derive_key(self, length: int = 32) -> bytes:
+        """ 派生固定长度的密钥 """
+        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+        from cryptography.hazmat.primitives import hashes
+        
+        # 使用固定盐值(实际生产环境用随机盐)
+        salt = b"xiaozhi_secret_key"
+        kdf = PBKDF2HMAC(
+            algorithm=hashes.SHA256(),
+            length=length,
+            salt=salt,
+            iterations=100000,  # 迭代次数
+            backend=default_backend()   # 加密算法的后端实现,通常用default_backend()即可。
+        )
+        return kdf.derive(self.secret_key) # 派生密钥
+    
+    def _encrypt_payload(self, payload: dict) -> str:
+        """使用AES-GCM加密整个payload"""
+        payload_json = json.dumps(payload)
+
+        # 生成随机IV
+        iv = os.urandom(12)
+        # 创建加密器
+        cipher = Cipher(
+            algorithms.AES(self.encryption_key),
+            modes.GCM(iv),
+            backend=default_backend(),
+        )
+        encryptor = cipher.encryptor()
+
+        # 加密并生成标签
+        ciphertext = encryptor.update(payload_json.encode()) + encryptor.finalize()
+        tag = encryptor.tag
+
+        # 组合 IV + 密文 + 标签
+        encrypted_data = iv + ciphertext + tag
+        return base64.urlsafe_b64encode(encrypted_data).decode()
+    
+    def _decrypt_payload(self, encrypted_data: str) -> dict:
+        """解密AES-GCM加密的payload"""
+        # 解码base64
+        data = base64.urlsafe_b64decode(encrypted_data.encode())
+        # 拆分组件
+        iv = data[:12]
+        tag = data[-16:]
+        ciphertext = data[12:-16]
+
+        # 创建解密器
+        cipher = Cipher(
+            algorithms.AES(self.encryption_key),
+            modes.GCM(iv, tag),
+            backend=default_backend(),
+        )
+        decryptor = cipher.decryptor()
+
+        # 解密
+        plaintext = decryptor.update(ciphertext) + decryptor.finalize()
+        return json.loads(plaintext.decode())
+    
+    def generate_token(self, device_id: str) -> str:
+        """
+        生成JWT token
+        Args:
+            device_id: 设备ID
+        Returns:
+            str: JWT token
+        """
+        # 设置过期时间为1小时后
+        expire_time = datetime.now(timezone.utc) + timedelta(hours=1)
+        payload = {"device_id": device_id,"exp": expire_time.timestamp()}
+        # 加密payload
+        encrypted_payload = self._encrypt_payload(payload)
+        # 组合外层payload
+        outer_payload = {"data": encrypted_payload}
+        # 使用HS256算法生成JWT token
+        token = jwt.encode(outer_payload, self.secret_key, algorithm="HS256")
+        return token
+    
+    def verify_token(self, token: str) -> Optional[dict]:
+        """
+        验证JWT token
+        Args:
+            token: JWT token
+        Returns:
+            Optional[dict]: 解密后的payload
+        """
+        try:
+            # 先验证外层JWT (签名和过期时间)
+            outer_payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
+            
+            # 解密内层payload
+            inner_payload = self._decrypt_payload(outer_payload["data"])
+
+            if inner_payload["exp"] < time.time():
+                return False, None
+            
+            return True, inner_payload["device_id"]
+
+        except jwt.InvalidTokenError:
+            return False, None
+        except json.JSONDecodeError:
+            return False, None
+        except Exception as e:
+            print(f"AuthToken verify_token error: {e}")
+            return False, None
+
+
+
+
+
+

+ 57 - 0
main/server/plugins_func/register.py

@@ -0,0 +1,57 @@
+from config.logger import setup_logging
+from enum import Enum
+
+TAG = __name__
+
+logger = setup_logging()
+
+class ToolType(Enum):
+    NONE = (1, "调用完工具后,不做其他操作")
+    WAIT = (2, "调用工具,等待函数返回")
+    CHANGE_SYS_PROMPT = (3, "修改系统提示词,切换角色性格和职责")
+    SYSTEM_CTL = (4, "系统控制, 影响正常的对话流程,如退出,播放音乐等,需要传递conn参数")
+    INT_CTL = (5, "IOT设备控制,需要传递conn参数")
+    MCP_CLIENT = (6, "MCP客户端")
+
+    def __init__(self,code,message):
+        self.code = code
+        self.message = message
+
+
+class Action(Enum):
+    ERROR = (-1, "错误")
+    NOTFOUND = (0, "没有找到函数")
+    NONE = (1,"啥也不干")
+    RESPONSE = (2, "直接响应") 
+    REQLLM = (3, "调用函数后再请求llm生成回复")
+
+    def __init__(self, code, message):
+        self.code = code
+        self.message = message
+
+
+class ActionResponse:
+    def __init__(self, action: Action, result=None, response=None):
+        self.action = action
+        self.result = result
+        self.response = response
+    
+
+class FunctionItem:
+    def __init__(self, name, description, func ,type):
+        self.name = name
+        self.description = description
+        self.func = func
+        self.type = type
+
+
+
+
+# 后续等用到再开发
+class DeviceTypeRegister:
+    """
+    设备类型注册表,用于管理IOT设备类型及其函数
+    """
+    
+    def __init__(self):
+        self.type_functions = {}   # type_signature -> {func_name: FunctionItem}