'''
Text detection model: https://github.com/argman/EAST
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1

CRNN text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
How to convert the pretrained .pth model to .onnx,
using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py:

import torch
from models.crnn import CRNN

model = CRNN(32, 1, 37, 256)
model.load_state_dict(torch.load('crnn.pth'))
dummy_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)

More converted onnx text recognition models can be downloaded directly here:
Download link: https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing
These models were taken from here: https://github.com/clovaai/deep-text-recognition-benchmark
'''
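
# Example invocation (the file names are placeholders for wherever you saved
# the downloaded detector and recognizer models):
#   python text_detection.py --input sample.jpg --model frozen_east_text_detection.pb --ocr crnn.onnx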

# Import required modules
import numpy as np
import cv2 as cv
import math
import argparse

############ Add argument parser for command line arguments ############
parser = argparse.ArgumentParser(
    description="Use this script to run the TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained by converting the pretrained CRNN model from the github repository "
                "https://github.com/meijieru/crnn.pytorch to .onnx format, or you can download a trained OCR model "
                "directly from https://drive.google.com/drive/folders/1cTbQ3nuZG-EKWak6emD_s8_hHXWz7lAr?usp=sharing")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file that contains the trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help='Path to a binary .pb or .onnx file that contains the trained recognition network.')
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be a multiple of 32.')
parser.add_argument('--height', type=int, default=320,
                    help='Preprocess input image by resizing to a specific height. It should be a multiple of 32.')
parser.add_argument('--thr', type=float, default=0.5,
                    help='Confidence threshold.')
parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()

############ Utility functions ############

def fourPointsTransform(frame, vertices):
    # Warp the quadrilateral text region onto a fixed 100x32 patch,
    # the input size the CRNN recognizer expects.
    vertices = np.asarray(vertices)
    outputSize = (100, 32)
    targetVertices = np.array([
        [0, outputSize[1] - 1],
        [0, 0],
        [outputSize[0] - 1, 0],
        [outputSize[0] - 1, outputSize[1] - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    result = cv.warpPerspective(frame, rotationMatrix, outputSize)
    return result
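
# decodeText() performs greedy CTC decoding: class 0 is the CTC blank (written
# as '-' in the intermediate string), and repeated characters not separated by
# a blank are collapsed, so a raw per-frame output of "hh-e-l-ll-oo" decodes
# to "hello".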

def decodeText(scores):
    text = ""
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    for i in range(scores.shape[0]):
        c = np.argmax(scores[i][0])
        if c != 0:
            text += alphabet[c - 1]
        else:
            text += '-'

    # adjacent same letters as well as background text must be removed to get the final output
    char_list = []
    for i in range(len(text)):
        if text[i] != '-' and (not (i > 0 and text[i] == text[i - 1])):
            char_list.append(text[i])
    return ''.join(char_list)
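
# decodeBoundingBoxes() interprets the raw EAST outputs: the score map holds a
# per-pixel text confidence at 1/4 of the input resolution (hence the x * 4.0
# and y * 4.0 offsets), and the five geometry channels hold the distances from
# each feature-map cell to the top, right, bottom and left edges of the rotated
# box, plus the box's rotation angle.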

def decodeBoundingBoxes(scores, geometry, scoreThresh):
    detections = []
    confidences = []

    ############ CHECK DIMENSIONS AND SHAPES OF geometry AND scores ############
    assert len(scores.shape) == 4, "Incorrect dimensions of scores"
    assert len(geometry.shape) == 4, "Incorrect dimensions of geometry"
    assert scores.shape[0] == 1, "Invalid dimensions of scores"
    assert geometry.shape[0] == 1, "Invalid dimensions of geometry"
    assert scores.shape[1] == 1, "Invalid dimensions of scores"
    assert geometry.shape[1] == 5, "Invalid dimensions of geometry"
    assert scores.shape[2] == geometry.shape[2], "Invalid dimensions of scores and geometry"
    assert scores.shape[3] == geometry.shape[3], "Invalid dimensions of scores and geometry"

    height = scores.shape[2]
    width = scores.shape[3]
    for y in range(0, height):
        # Extract data from scores
        scoresData = scores[0][0][y]
        x0_data = geometry[0][0][y]
        x1_data = geometry[0][1][y]
        x2_data = geometry[0][2][y]
        x3_data = geometry[0][3][y]
        anglesData = geometry[0][4][y]
        for x in range(0, width):
            score = scoresData[x]

            # If score is lower than threshold score, move to next x
            if (score < scoreThresh):
                continue

            # Calculate offset
            offsetX = x * 4.0
            offsetY = y * 4.0
            angle = anglesData[x]

            # Calculate cos and sin of angle
            cosA = math.cos(angle)
            sinA = math.sin(angle)
            h = x0_data[x] + x2_data[x]
            w = x1_data[x] + x3_data[x]

            # Calculate offset
            offset = ([offsetX + cosA * x1_data[x] + sinA * x2_data[x], offsetY - sinA * x1_data[x] + cosA * x2_data[x]])

            # Find points for rectangle
            p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
            p3 = (-cosA * w + offset[0], sinA * w + offset[1])
            center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
            detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
            confidences.append(float(score))

    # Return detections and confidences
    return [detections, confidences]
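
# Each detection is a (center, (width, height), angle-in-degrees) tuple, i.e.
# the same layout as an OpenCV RotatedRect, which is what cv.dnn.NMSBoxesRotated
# expects in main() when suppressing overlapping boxes.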

def main():
    # Read and store arguments
    confThreshold = args.thr
    nmsThreshold = args.nms
    inpWidth = args.width
    inpHeight = args.height
    modelDetector = args.model
    modelRecognition = args.ocr

    # Load network
    detector = cv.dnn.readNet(modelDetector)
    recognizer = cv.dnn.readNet(modelRecognition)

    # Create a new named window
    kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
    cv.namedWindow(kWinName, cv.WINDOW_NORMAL)

    outNames = []
    outNames.append("feature_fusion/Conv_7/Sigmoid")
    outNames.append("feature_fusion/concat_3")
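    # "Sigmoid" is the per-pixel text/no-text score map and "concat_3" is the
    # geometry map; both layer names come from the frozen EAST TensorFlow graph.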

    # Open a video file or an image file or a camera stream
    cap = cv.VideoCapture(args.input if args.input else 0)

    tickmeter = cv.TickMeter()
    while cv.waitKey(1) < 0:
        # Read frame
        hasFrame, frame = cap.read()
        if not hasFrame:
            cv.waitKey()
            break

        # Get frame height and width
        height_ = frame.shape[0]
        width_ = frame.shape[1]
        rW = width_ / float(inpWidth)
        rH = height_ / float(inpHeight)

        # Create a 4D blob from frame.
        blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
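        # (123.68, 116.78, 103.94) are the per-channel means the EAST model was
        # trained with (the ImageNet RGB means); swapRB=True converts OpenCV's
        # BGR frame to RGB before the subtraction.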

        # Run the detection model
        detector.setInput(blob)
        tickmeter.start()
        outs = detector.forward(outNames)
        tickmeter.stop()

        # Get scores and geometry
        scores = outs[0]
        geometry = outs[1]
        [boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)

        # Apply NMS
        indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
        for i in indices:
            # get 4 corners of the rotated rect
            # (note: depending on your OpenCV version, NMSBoxesRotated may
            # return a flat index array, in which case index with boxes[i])
            vertices = cv.boxPoints(boxes[i[0]])
            # scale the bounding box coordinates based on the respective ratios
            for j in range(4):
                vertices[j][0] *= rW
                vertices[j][1] *= rH

            # get cropped image using perspective transform
            if modelRecognition:
                cropped = fourPointsTransform(frame, vertices)
                cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)

                # Create a 4D blob from cropped image
                blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
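                # mean=127.5 and scalefactor=1/127.5 map the grayscale crop from
                # [0, 255] into [-1, 1], the input range the CRNN was trained on.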
                recognizer.setInput(blob)

                # Run the recognition model
                tickmeter.start()
                result = recognizer.forward()
                tickmeter.stop()

                # decode the result into text
                wordRecognized = decodeText(result)
                cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
                           0.5, (255, 0, 0))

            for j in range(4):
                p1 = (int(vertices[j][0]), int(vertices[j][1]))
                p2 = (int(vertices[(j + 1) % 4][0]), int(vertices[(j + 1) % 4][1]))
                cv.line(frame, p1, p2, (0, 255, 0), 1)

        # Put efficiency information
        label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
        cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))

        # Display the frame
        cv.imshow(kWinName, frame)
        tickmeter.reset()


if __name__ == "__main__":
    main()
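

# A minimal sanity check for the recognizer (a sketch, assuming the CRNN export
# from the docstring above; the expected output shape is specific to that
# architecture with 100-pixel-wide inputs):
#   import numpy as np, cv2 as cv
#   net = cv.dnn.readNet("crnn.onnx")
#   net.setInput(np.zeros((1, 1, 32, 100), dtype=np.float32))
#   print(net.forward().shape)  # typically (26, 1, 37): T x batch x alphabet+blank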