Browse Source

fix :新增任务类型 方向识别、车辆+方向识别

zzf 4 months ago
parent
commit
7c0c3b9e4c

+ 21 - 3
AIT/api/record.go

@@ -26,9 +26,12 @@ func GetRecord() *RecordService {
 // @Summary 新增测试
 // @Description 新增测试
 // @Tags 测试记录
-// @Param recordFile  formData  file true "文件"
+// @Param file  formData  file true "文件"
 // @Param userName  formData  string true "用户名"
 // @Param fileType  formData  string true "文件类型"
+// @Param positiveDir formData string false "方向"
+// @Param negativaDir formData string false "方向"
+// @Param taskType formData string true "car_classification :车辆识别 ; car_direction:方向; class_direc =car_classification+car_direction;"
 // @Accept  json
 // @Produce  json
 // @Success 200 {string} string	"ok"
@@ -48,11 +51,24 @@ func (f *RecordService) SaveRecord(c *gin.Context) {
 		response.Failed(errmsg.ParamInvalid+",用户名为空", c)
 		return
 	}
-	fileType := c.PostForm("fileType")
 
+	fileType := c.PostForm("fileType")
+	positiveDir := c.PostForm("positiveDir")
+	negativaDir := c.PostForm("negativaDir")
+	taskType := c.PostForm("taskType")
 	message["msg_type"] = fileType
 	message["Sender"] = "server"
 	message["Recipient"] = "client"
+	message["positive_dir"] = positiveDir
+	message["negativa_dir"] = negativaDir
+	message["task_type"] = taskType
+	if (taskType == "car_direction" || taskType == "class_direc") && (len(positiveDir) == 0 || len(negativaDir) == 0) {
+		response.Failed(errmsg.ParamInvalid+",方向为空", c)
+		return
+	}
+	//task_type 类型:car_classification :车辆识别  car_direction方向 class_direc =car_classification+car_direction
+	//positive_dir
+	//negativa_dir
 	//添加记录
 	var record models.Record
 	//文件格式限制
@@ -90,7 +106,9 @@ func (f *RecordService) SaveRecord(c *gin.Context) {
 	record.FileType = fileType
 	record.Id = common.GenerateUUID()
 	record.FileUrl = fileURL
-
+	record.TaskType = taskType
+	record.PositiveDir = positiveDir
+	record.NegativaDir = negativaDir
 	message["recordId"] = record.Id
 	//invoke AI
 

+ 20 - 1
AIT/docs/docs.go

@@ -74,7 +74,7 @@ const docTemplate = `{
                     {
                         "type": "file",
                         "description": "文件",
-                        "name": "recordFile",
+                        "name": "file",
                         "in": "formData",
                         "required": true
                     },
@@ -91,6 +91,25 @@ const docTemplate = `{
                         "name": "fileType",
                         "in": "formData",
                         "required": true
+                    },
+                    {
+                        "type": "string",
+                        "description": "方向",
+                        "name": "positiveDir",
+                        "in": "formData"
+                    },
+                    {
+                        "type": "string",
+                        "description": "方向",
+                        "name": "negativaDir",
+                        "in": "formData"
+                    },
+                    {
+                        "type": "string",
+                        "description": "car_classification :车辆识别 ; car_direction:方向; class_direc =car_classification+car_direction;",
+                        "name": "taskType",
+                        "in": "formData",
+                        "required": true
                     }
                 ],
                 "responses": {

+ 20 - 1
AIT/docs/swagger.json

@@ -65,7 +65,7 @@
                     {
                         "type": "file",
                         "description": "文件",
-                        "name": "recordFile",
+                        "name": "file",
                         "in": "formData",
                         "required": true
                     },
@@ -82,6 +82,25 @@
                         "name": "fileType",
                         "in": "formData",
                         "required": true
+                    },
+                    {
+                        "type": "string",
+                        "description": "方向",
+                        "name": "positiveDir",
+                        "in": "formData"
+                    },
+                    {
+                        "type": "string",
+                        "description": "方向",
+                        "name": "negativaDir",
+                        "in": "formData"
+                    },
+                    {
+                        "type": "string",
+                        "description": "car_classification :车辆识别 ; car_direction:方向; class_direc =car_classification+car_direction;",
+                        "name": "taskType",
+                        "in": "formData",
+                        "required": true
                     }
                 ],
                 "responses": {

+ 14 - 1
AIT/docs/swagger.yaml

@@ -36,7 +36,7 @@ paths:
       parameters:
       - description: 文件
         in: formData
-        name: recordFile
+        name: file
         required: true
         type: file
       - description: 用户名
@@ -49,6 +49,19 @@ paths:
         name: fileType
         required: true
         type: string
+      - description: 方向
+        in: formData
+        name: positiveDir
+        type: string
+      - description: 方向
+        in: formData
+        name: negativaDir
+        type: string
+      - description: car_classification :车辆识别 ; car_direction:方向; class_direc =car_classification+car_direction;
+        in: formData
+        name: taskType
+        required: true
+        type: string
       produces:
       - application/json
       responses:

+ 1 - 1
AIT/go.mod

@@ -3,11 +3,11 @@ module AIT
 go 1.18
 
 require (
-	github.com/codyguo/godaemon v0.0.0-20200413142854-c36b39fdd071
 	github.com/gin-gonic/gin v1.9.1
 	github.com/golang-jwt/jwt v3.2.2+incompatible
 	github.com/google/uuid v1.1.2
 	github.com/gorilla/websocket v1.5.0
+	github.com/shopspring/decimal v1.3.1
 	github.com/spf13/viper v1.16.0
 	github.com/swaggo/files v1.0.1
 	github.com/swaggo/gin-swagger v1.6.0

+ 2 - 2
AIT/go.sum

@@ -58,8 +58,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
 github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
 github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
-github.com/codyguo/godaemon v0.0.0-20200413142854-c36b39fdd071 h1:tbnIzvu8FEN9+zlJWxoNfPHeSYeAeAPSAQO7Q5ayQUE=
-github.com/codyguo/godaemon v0.0.0-20200413142854-c36b39fdd071/go.mod h1:RDz1idHRmqQc5EUSvX8/YJjOvGno6wqfCvt8ZOjDii8=
 github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -222,6 +220,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
+github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
+github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
 github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM=
 github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
 github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=

+ 1 - 1
AIT/initialize/run.go

@@ -1,7 +1,7 @@
 package initialize
 
 func Run() {
-	//SwagInit()
+	SwagInit()
 	LoadConfig()
 	Mysql()
 	go Manager.Start()

+ 1 - 1
AIT/initialize/swag.go

@@ -7,7 +7,7 @@ import (
 	"os/exec"
 )
 
-// SwagInit 初始化swag
+// SwagInit 初始化swag http://localhost:8000/swagger/index.html
 func SwagInit() {
 	cmd := exec.Command("swag", "init")
 	fmt.Println("Cmd", cmd.Args)

+ 17 - 7
AIT/initialize/ws.go

@@ -28,11 +28,12 @@ type Client struct {
 
 // Message is return msg
 type Message struct {
-	Sender    string `json:"sender,omitempty"`
-	Recipient string `json:"recipient,omitempty"`
-	Content   string `json:"content,omitempty"`
-	RecordId  string `json:"recordId,omitempty"`
-	Status    bool   `json:"status,omitempty"`
+	Sender           string `json:"sender,omitempty"`
+	Recipient        string `json:"recipient,omitempty"`
+	CarContent       string `json:"car_content,omitempty"`
+	DirectionContent string `json:"direction_content,omitempty"`
+	RecordId         string `json:"recordId,omitempty"`
+	Status           bool   `json:"status,omitempty"`
 }
 
 // Manager define a ws server manager
@@ -85,9 +86,18 @@ func (manager *ClientManager) Start() {
 
 						content := ""
 						if MessageStruct.Status {
-							content = errmsg.TestResultSuccess + MessageStruct.Content
+							if len(MessageStruct.DirectionContent) > 0 {
+								content = errmsg.TestResultSuccess + MessageStruct.CarContent + ";" + MessageStruct.DirectionContent
+							} else {
+								content = errmsg.TestResultSuccess + MessageStruct.CarContent
+							}
+
 						} else {
-							content = errmsg.TestResultFailed + MessageStruct.Content
+							if len(MessageStruct.DirectionContent) > 0 {
+								content = errmsg.TestResultSuccess + MessageStruct.CarContent + ";" + MessageStruct.DirectionContent
+							} else {
+								content = errmsg.TestResultSuccess + MessageStruct.CarContent
+							}
 						}
 						recordService.UpdateRecordById(MessageStruct.RecordId, content)
 					}

+ 17 - 1
AIT/main.go

@@ -1,6 +1,9 @@
 package main
 
-import "AIT/initialize"
+import (
+	"AIT/initialize"
+	"os"
+)
 
 //@title AIT-go AI测试
 //@version 1.0
@@ -10,4 +13,17 @@ import "AIT/initialize"
 //@contact.url https://www.cnblogs.com/wormworm/
 func main() {
 	initialize.Run()
+	//var f1 = float64(5000)
+	//var f2 = 6757.94
+	//decimalValue := decimal.NewFromFloat(f1).Add(decimal.NewFromFloat(f2))
+	//fmt.Println(decimalValue.Float64())
+	path := "./logs/2023-11-30/"
+	err := os.MkdirAll(path, os.ModePerm)
+	if err != nil {
+		panic(err)
+		return
+	}
+	file, _ := os.OpenFile(path+"a.log", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
+	defer file.Close()
+	file.WriteString("lksjdlsdklas")
 }

+ 12 - 9
AIT/models/record.go

@@ -9,15 +9,18 @@ type RecordCreateParam struct {
 }
 
 type Record struct {
-	Id         string `form:"id"`
-	UserName   string `gorm:"user_name"`
-	FileType   string `gorm:"file_type"`
-	FilePath   string `gorm:"file_path"`
-	FileName   string `gorm:"file_name"`
-	FileUrl    string `gorm:"file_url"`
-	CreateTime string `gorm:"create_time"`
-	Result     string `gorm:"result"`
-	UpdateTime string `gorm:"update_time"`
+	Id          string `form:"id"`
+	UserName    string `gorm:"user_name"`
+	FileType    string `gorm:"file_type"`
+	FilePath    string `gorm:"file_path"`
+	FileName    string `gorm:"file_name"`
+	FileUrl     string `gorm:"file_url"`
+	CreateTime  string `gorm:"create_time"`
+	TaskType    string `gorm:"task_type"`
+	PositiveDir string `gorm:"positive_dir"`
+	NegativaDir string `gorm:"negativa_dir"`
+	Result      string `gorm:"result"`
+	UpdateTime  string `gorm:"update_time"`
 }
 
 type RecordListParam struct {

+ 7 - 1
classify-ai-backend/algo/car_cf.py

@@ -27,4 +27,10 @@ def car_classification(file_path):
         status = True
     else:
         status = False
-    return (status, car_name)
+    direction = 'null'
+    return (status, car_name,direction)
+
+if __name__ == "__main__":
+    file_path = 'D:\\Desktop\小组\\声音项目_wx\\数据\\1\\2.txt'
+    result = car_classification(file_path)
+    print('-----------done----------',result)

+ 67 - 34
classify-ai-backend/main.py

@@ -7,19 +7,35 @@ import config
 from loguru import logger
 import json
 from algo.car_cf import car_classification
+from algo.car_direction import car_dircetion
+from algo.class_direction import class_direction
+
 
 logger.add(config.log_path, rotation="50 MB", encoding='utf-8')
 
-CONNECT_TYPE = 'connect'
-ERROR_TYPE = 'error'
-CLOSE_TYPE = 'close'
 XLS_TYPE = 'xls'
 XLSX_TYPE = 'xlsx'
-TXT_TYPE = 'txt'
+TXT_TYPE ='txt'
 WAV_TYPE = 'wav'
-JSON_TYPE = 'json'
+JSON_TYPE ='json'
 NPY_TYPE = 'npy'
 
+handlers = {
+XLS_TYPE : 'xls',
+XLSX_TYPE : 'xlsx',
+TXT_TYPE : 'txt',
+WAV_TYPE : 'wav',
+JSON_TYPE :'json',
+NPY_TYPE : 'npy'
+}
+
+
+CLOSE_TYPE  = 'close'
+ERROR_TYPE = 'error'
+CONNECT_TYPE = 'connect'
+CAR_CLASS = 'car_classification'
+CAR_DIREC = 'car_direction'
+CLASS_DIREC = 'class_direc'
 
 class MessageOperation:
 
@@ -30,39 +46,54 @@ class MessageOperation:
         try:
             ask = json.loads(msg)
             self.msg_type = ask['msg_type']
-            self.msg_type = ask['msg_type']
+            self.positive_dir = ask['positive_dir']
+            self.negativa_dir = ask['negativa_dir']
+            self.task_type = ask['task_type']
             self.record_id = ask['recordId']
             self.msg_content = ask['content']
-            if self.msg_type == XLS_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == XLSX_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == TXT_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == WAV_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == JSON_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == NPY_TYPE:
-                self.response_status, self.response_content = car_classification(
-                    self.msg_content)
-            elif self.msg_type == CLOSE_TYPE:
-                self.response_status = True
-                self.response_content = "close success"
-            else:
-                self.response_status = False
-                self.response_content = "无法识别"
-                # self.record_id = "null"
+            if self.task_type == CAR_CLASS:
+                if self.msg_type in handlers:
+                    self.response_status,self.response_car_content,self.response_direction_content = car_classification(
+                        self.msg_content)
+                elif self.msg_type == CLOSE_TYPE:
+                    self.response_status = True
+                    self.response_car_content = "close success"
+                    self.response_direction_content = 'close success'
+                else:
+                    self.response_status = False
+                    self.response_car_content = "无法识别"
+                    self.response_direction_content = '无法识别'
+                    # self.record_id = "null"
+            elif self.task_type == CAR_DIREC:
+                if self.msg_type in handlers:
+                    self.response_status, self.response_direction_content, self.response_car_content = car_dircetion(
+                        self.msg_content,self.positive_dir,self.negativa_dir)
+                elif self.msg_type == CLOSE_TYPE:
+                    self.response_status = True
+                    self.response_direction_content = "close success"
+                    self.response_car_content = "close success"
+                else:
+                    self.response_status = False
+                    self.response_direction_content = "无法识别"
+                    self.response_car_content = "无法识别"
+            elif self.task_type == CLASS_DIREC:
+                if self.msg_type in handlers:
+                    self.response_status, self.response_car_content,self.response_direction_content = class_direction(
+                        self.msg_content,self.positive_dir,self.negativa_dir)
+                elif self.msg_type == CLOSE_TYPE:
+                    self.response_status = True
+                    self.response_car_content = "close success"
+                    # self.response_direction_content = "close success"
+                else:
+                    self.response_status = False
+                    self.response_car_content = "无法识别"
+                    # self.response_direction_content = "无法识别"
+                
         except Exception as e:
             self.msg_type = ERROR_TYPE
             self.response_status = False
             self.response_content = str(e)
-            # self.record_id = "null"
+
 
     def pack(self) -> str:
         return json.dumps({
@@ -71,7 +102,8 @@ class MessageOperation:
             'recordId': self.record_id,
             'status': self.response_status,
             'msg_type': self.msg_type,
-            'content': self.response_content,
+            'car_content': self.response_car_content,
+            'direction_content': self.response_direction_content,
             'time': str(datetime.datetime.now())
         })
 
@@ -88,7 +120,8 @@ async def main():
             'status': True,
             'msg_type': CONNECT_TYPE,
             'pid': pid,
-            'content': '',
+            'car_content': '',
+            'direction_content': '',
             'time': str(datetime.datetime.now())
         })
         logger.debug(f">>> {greeting}")