'''
    Text detection model: https://github.com/argman/EAST
    Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1

    CRNN text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
    How to convert the pretrained .pth checkpoint to .onnx,
    using the model classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py

    import torch
    from models.crnn import CRNN

    model = CRNN(32, 1, 37, 256)
    model.load_state_dict(torch.load('crnn.pth'))
    dummy_input = torch.randn(1, 1, 32, 100)
    torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)

    More converted ONNX text recognition models can be downloaded directly here:
    Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
    These models were taken from here: https://github.com/clovaai/deep-text-recognition-benchmark
'''

# Import required modules
import argparse
import math

import cv2 as cv
import numpy as np

############ Add argument parser for command line arguments ############
parser = argparse.ArgumentParser(
    description="Use this script to run the TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained by converting the pretrained CRNN model from "
                "https://github.com/meijieru/crnn.pytorch to .onnx format, or you can download a trained OCR "
                "model directly from https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file containing the trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help='Path to a binary .pb or .onnx file containing the trained recognition network.')
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
parser.add_argument('--height', type=int, default=320,
                    help='Preprocess input image by resizing to a specific height. It should be a multiple of 32.')
parser.add_argument('--thr', type=float, default=0.5,
                    help='Confidence threshold.')
parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()
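
# Example invocation (the script and file names here are illustrative; point the
# arguments at wherever you stored the downloaded models):
#   python text_detection.py --model frozen_east_text_detection.pb --ocr crnn.onnx --input test.jpg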

############ Utility functions ############
def fourPointsTransform(frame, vertices):
    # Warp the rotated quadrilateral `vertices` into an upright 100x32 patch,
    # the fixed input size the CRNN recognizer expects.
    vertices = np.asarray(vertices, dtype="float32")
    outputSize = (100, 32)
    targetVertices = np.array([
        [0, outputSize[1] - 1],
        [0, 0],
        [outputSize[0] - 1, 0],
        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
    return result
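
# A quick way to sanity-check the transform (the coordinates are made up; the
# corner order must be bottom-left, top-left, top-right, bottom-right to match
# targetVertices above):
#   quad = np.array([[10, 60], [10, 20], [120, 20], [120, 60]], dtype="float32")
#   crop = fourPointsTransform(frame, quad)  # crop.shape == (32, 100, 3)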

def decodeText(scores):
    text = ""
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    for i in range(scores.shape[0]):
        c = np.argmax(scores[i][0])
        if c != 0:
            text += alphabet[c - 1]
        else:
            text += '-'

    # adjacent same letters as well as background text must be removed to get the final output
    char_list = []
    for i in range(len(text)):
        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
            char_list.append(text[i])
    return ''.join(char_list)
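
# decodeText above is greedy CTC decoding: class 0 is the CTC "blank" (written as
# '-'), and repeated predictions not separated by a blank collapse into a single
# character. For example, the raw per-timestep argmax string "hh-e-l-lo" decodes
# to "hello"; the blank between the two 'l' runs is what preserves the double letter.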

def decodeBoundingBoxes(scores, geometry, scoreThresh):
    detections = []
    confidences = []

    ############ CHECK DIMENSIONS AND SHAPES OF geometry AND scores ############
    assert len(scores.shape) == 4, "Incorrect dimensions of scores"
    assert len(geometry.shape) == 4, "Incorrect dimensions of geometry"
    assert scores.shape[0] == 1, "Invalid dimensions of scores"
    assert geometry.shape[0] == 1, "Invalid dimensions of geometry"
    assert scores.shape[1] == 1, "Invalid dimensions of scores"
    assert geometry.shape[1] == 5, "Invalid dimensions of geometry"
    assert scores.shape[2] == geometry.shape[2], "Invalid dimensions of scores and geometry"
    assert scores.shape[3] == geometry.shape[3], "Invalid dimensions of scores and geometry"
    height = scores.shape[2]
    width = scores.shape[3]
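
    # EAST predicts on a stride-4 feature map: cell (x, y) corresponds to pixel
    # (4x, 4y) of the resized input. Per cell, the five geometry channels hold the
    # distances from that pixel to the top, right, bottom and left edges of the
    # rotated box, plus the box rotation angle (see the EAST paper, arXiv:1704.03155).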
    for y in range(0, height):
        # Extract data from scores
        scoresData = scores[0][0][y]
        x0_data = geometry[0][0][y]
        x1_data = geometry[0][1][y]
        x2_data = geometry[0][2][y]
        x3_data = geometry[0][3][y]
        anglesData = geometry[0][4][y]
        for x in range(0, width):
            score = scoresData[x]

            # If score is lower than threshold score, move to next x
            if score < scoreThresh:
                continue

            # Calculate offset
            offsetX = x * 4.0
            offsetY = y * 4.0
            angle = anglesData[x]

            # Calculate cos and sin of angle
            cosA = math.cos(angle)
            sinA = math.sin(angle)
            h = x0_data[x] + x2_data[x]
            w = x1_data[x] + x3_data[x]

            # Compute the rotated offset point on the box boundary
            offset = ([offsetX + cosA * x1_data[x] + sinA * x2_data[x],
                       offsetY - sinA * x1_data[x] + cosA * x2_data[x]])

            # Find points for rectangle
            p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
            confidences.append(float(score))

    # Return detections and confidences
    return [detections, confidences]
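
# Each detection is a ((centerX, centerY), (width, height), angle) triplet, i.e.
# the cv.RotatedRect layout that cv.dnn.NMSBoxesRotated and cv.boxPoints consume
# in main() below; the angle is in degrees.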

def main():
    # Read and store arguments
    confThreshold = args.thr
    nmsThreshold = args.nms
    inpWidth = args.width
    inpHeight = args.height
    modelDetector = args.model
    modelRecognition = args.ocr

    # Load the detection and recognition networks
    detector = cv.dnn.readNet(modelDetector)
    recognizer = cv.dnn.readNet(modelRecognition)

    # Create a new named window
    kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
    cv.namedWindow(kWinName, cv.WINDOW_NORMAL)

    outNames = []
    outNames.append("feature_fusion/Conv_7/Sigmoid")
    outNames.append("feature_fusion/concat_3")
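    # These are the two output layers of the EAST graph: the Sigmoid layer yields
    # the per-cell text confidence scores and the concat layer yields the geometry.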

    # Open a video file or an image file or a camera stream
    cap = cv.VideoCapture(args.input if args.input else 0)

    tickmeter = cv.TickMeter()
    while cv.waitKey(1) < 0:
        # Read frame
        hasFrame, frame = cap.read()
        if not hasFrame:
            cv.waitKey()
            break

        # Get frame height and width
        height_ = frame.shape[0]
        width_ = frame.shape[1]
        rW = width_ / float(inpWidth)
        rH = height_ / float(inpHeight)

        # Create a 4D blob from the frame
        blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
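        # (123.68, 116.78, 103.94) are the per-channel means the EAST model was
        # trained with (the usual ImageNet means); swapRB=True converts OpenCV's
        # BGR channel order to the RGB order the model expects.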

        # Run the detection model
        detector.setInput(blob)
        tickmeter.start()
        outs = detector.forward(outNames)
        tickmeter.stop()

        # Get scores and geometry
        scores = outs[0]
        geometry = outs[1]
        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)

        # Apply non-maximum suppression on the rotated boxes
        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
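        # Depending on the OpenCV version, NMSBoxesRotated returns either a flat
        # index array or an Nx1 array; flattening below handles both layouts.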
        for i in np.asarray(indices).flatten():
            # get 4 corners of the rotated rect
            vertices = cv.boxPoints(boxes[i])
            # scale the bounding box coordinates based on the respective ratios
            for j in range(4):
                vertices[j][0] *= rW
                vertices[j][1] *= rH

            # get cropped image using perspective transform
            if modelRecognition:
                cropped = fourPointsTransform(frame, vertices)
                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)

                # Create a 4D blob from the cropped image
                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
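                # (mean=127.5 with scalefactor=1/127.5 normalizes pixel values to
                # [-1, 1], which matches the normalization the CRNN model expects.)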
                recognizer.setInput(blob)

                # Run the recognition model
                tickmeter.start()
                result = recognizer.forward()
                tickmeter.stop()

                # decode the result into text
                wordRecognized = decodeText(result)
                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])),
                           cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0))

            for j in range(4):
                p1 = (int(vertices[j][0]), int(vertices[j][1]))
                p2 = (int(vertices[(j + 1) % 4][0]), int(vertices[(j + 1) % 4][1]))
                cv.line(frame, p1, p2, (0, 255, 0), 1)

        # Put efficiency information
        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

        # Display the frame
        cv.imshow(kWinName, frame)
        tickmeter.reset()


if __name__ == "__main__":
    main()