|
@@ -0,0 +1,159 @@
|
|
|
+import json
|
|
|
+import copy
|
|
|
+from aiohttp import web
|
|
|
+from config.logger import setup_logger
|
|
|
+from core.utils.auth import AuthToken
|
|
|
+from core.utils.util import is_valid_image
|
|
|
+from core.utils.vllm import create_llm_instance
|
|
|
+from config.config_loader import get_private_config_from_api
|
|
|
+from plugins_func.register import Action
|
|
|
+
|
|
|
+TAG = __name__
|
|
|
+
|
|
|
+MAX_FILE_SIZE = 5 * 1024 * 1024
|
|
|
+
|
|
|
+class VisionHandler:
|
|
|
+ def __init__(self, config:dict):
|
|
|
+ self.config = config
|
|
|
+ self.logger = setup_logger()
|
|
|
+ self.auth = AuthToken(config["server"]["auth_key"])
|
|
|
+
|
|
|
+ def _create_error_response(self,message:str) -> dict:
|
|
|
+ """
|
|
|
+ 创建错误响应
|
|
|
+ Args:
|
|
|
+ message: 错误信息
|
|
|
+ Returns:
|
|
|
+ dict: 错误响应
|
|
|
+ """
|
|
|
+ return {"success": False, "message": message}
|
|
|
+
|
|
|
+ def _verify_auth_token(self,request) -> Tuple[bool, Optional[str]]:
|
|
|
+ """
|
|
|
+ 验证认证token
|
|
|
+ Args:
|
|
|
+ request: 请求对象
|
|
|
+ Returns:
|
|
|
+ Tuple[bool, Optional[str]]: 验证结果和设备ID
|
|
|
+ """
|
|
|
+ auth_header = request.headers.get("Authorization","")
|
|
|
+ if not auth_header.startswith("Bearer "):
|
|
|
+ return False, None
|
|
|
+
|
|
|
+ token = auth_header[7:]
|
|
|
+ return self.auth.verify_token(token)
|
|
|
+
|
|
|
+ async def handle_post(self, request):
|
|
|
+ """
|
|
|
+ 处理MCP Vision Post请求
|
|
|
+ Args:
|
|
|
+ request: 请求对象
|
|
|
+ Returns:
|
|
|
+ web.Response: 响应对象
|
|
|
+ """
|
|
|
+ response = None # 初始化response对象
|
|
|
+ try:
|
|
|
+ # 验证token
|
|
|
+ is_valid, token_device_id = self._verify_auth_token(request)
|
|
|
+ if not is_valid:
|
|
|
+ response = web.Response(
|
|
|
+ text=json.dumps(
|
|
|
+ self._create_error_response("无效认证的token或token已过期")
|
|
|
+ ),
|
|
|
+ content_type="application/json",
|
|
|
+ status=401
|
|
|
+ )
|
|
|
+ return response
|
|
|
+
|
|
|
+ # 获取请求头信息
|
|
|
+ device_id = request.headers.get("Device-Id","")
|
|
|
+ client_id = request.headers.get("Client_Id","")
|
|
|
+ if device_id != token_device_id:
|
|
|
+ raise ValueError("设备ID与token不匹配")
|
|
|
+
|
|
|
+ reader = await request.multipart()
|
|
|
+
|
|
|
+ # 读取question字段
|
|
|
+ question_field = await reader.next()
|
|
|
+ if question_field is None:
|
|
|
+ raise ValueError("缺少question字段")
|
|
|
+ question = await question_field.text()
|
|
|
+ self.logger.bind(TAG).debug(f"Question:{question}")
|
|
|
+
|
|
|
+ # 读取图片文件
|
|
|
+ image_field = await reader.next()
|
|
|
+ if image_field is None:
|
|
|
+ raise ValueError("缺少图片文件")
|
|
|
+ image_data = await image_field.read()
|
|
|
+ if not image_data:
|
|
|
+ raise ValueError("图片文件为空")
|
|
|
+ if (len(image_data) > MAX_FILE_SIZE):
|
|
|
+ raise ValueError(f"图片文件大小超过限制,最大支持{MAX_FILE_SIZE/1024/1024}MB")
|
|
|
+ if not is_valid_image(image_data):
|
|
|
+ raise ValueError("图片文件格式不支持")
|
|
|
+ image_base64 = base64.b64encode(image_data).decode("utf-8")
|
|
|
+
|
|
|
+ current_config = copy.deepcopy(self.config)
|
|
|
+ read_config_from_api = current_config.get("read_config_from_api",False)
|
|
|
+ if read_config_from_api:
|
|
|
+ current_config = get_private_config_from_api(
|
|
|
+ current_config,
|
|
|
+ device_id,
|
|
|
+ client_id,
|
|
|
+ )
|
|
|
+ select_vllm_module = current_config["selected_module"].get("VLLM")
|
|
|
+ if not select_vllm_module:
|
|
|
+ raise ValueError("您还未配置视觉分析模型")
|
|
|
+
|
|
|
+ vllm_type = (
|
|
|
+ select_vllm_module
|
|
|
+ if "type" not in current_config["VLLM"][select_vllm_module]
|
|
|
+ else current_config["VLLM"][select_vllm_module]["type"]
|
|
|
+ )
|
|
|
+
|
|
|
+ if not vllm_type:
|
|
|
+ raise ValueError(f"无法找到VLLM模块对应的供应器{vllm_type}")
|
|
|
+
|
|
|
+ vllm = create_llm_instance(
|
|
|
+ vllm_type,
|
|
|
+ current_config["VLLM"][select_vllm_module]
|
|
|
+ )
|
|
|
+
|
|
|
+ result = vllm.response(question,image_base64)
|
|
|
+
|
|
|
+ return_json = {
|
|
|
+ "success": True,
|
|
|
+ "action": Action.RESPONSE.name,
|
|
|
+ "response": result,
|
|
|
+ }
|
|
|
+
|
|
|
+ response = web.Response(
|
|
|
+ text=json.dumps(return_json, separators=(",",":")),
|
|
|
+ content_type="application/json",
|
|
|
+ )
|
|
|
+ except ValueError as e:
|
|
|
+ self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常(e)")
|
|
|
+ return_json = self._create_error_response(str(e))
|
|
|
+ response = web.Response(
|
|
|
+ text=json.dumps(return_json, separators=(",",":")),
|
|
|
+ content_type="application/json",
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常(e)")
|
|
|
+ return_json = self._create_error_response(str(e))
|
|
|
+ response = web.Response(
|
|
|
+ text=json.dumps(return_json, separators=(",",":")),
|
|
|
+ content_type="application/json",
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+ if response:
|
|
|
+ self._add_cors_headers(response)
|
|
|
+ return response
|
|
|
+
|
|
|
+ def _add_cors_headers(self, response):
|
|
|
+ """添加CORS头信息"""
|
|
|
+ response.headers["Access-Control-Allow-Headers"] = (
|
|
|
+ "client-id, content-type, device-id"
|
|
|
+ )
|
|
|
+ response.headers["Access-Control-Allow-Credentials"] = "true"
|
|
|
+ response.headers["Access-Control-Allow-Origin"] = "*"
|