action_recognition.py

import os
import numpy as np
import cv2 as cv
import argparse

from common import findFile

parser = argparse.ArgumentParser(description='Use this script to run action recognition using 3D ResNet34',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--input', '-i', help='Path to input video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', required=True, help='Path to model.')
# NB: 'recongnition' below matches the spelling of the class-list file this sample expects.
parser.add_argument('--classes', default=findFile('action_recongnition_kinetics.txt'), help='Path to classes list.')
# To get the net, download the original repository: https://github.com/kenshohara/video-classification-3d-cnn-pytorch
# For a correct ONNX export, modify the file video-classification-3d-cnn-pytorch/models/resnet.py:
# change
# - def downsample_basic_block(x, planes, stride):
# -     out = F.avg_pool3d(x, kernel_size=1, stride=stride)
# -     zero_pads = torch.Tensor(out.size(0), planes - out.size(1),
# -                              out.size(2), out.size(3),
# -                              out.size(4)).zero_()
# -     if isinstance(out.data, torch.cuda.FloatTensor):
# -         zero_pads = zero_pads.cuda()
# -
# -     out = Variable(torch.cat([out.data, zero_pads], dim=1))
# -     return out
# to
# + def downsample_basic_block(x, planes, stride):
# +     out = F.avg_pool3d(x, kernel_size=1, stride=stride)
# +     out = F.pad(out, (0, 0, 0, 0, 0, 0, 0, int(planes - out.size(1)), 0, 0), "constant", 0)
# +     return out
# Then export to ONNX with torch.onnx.export(model, inputs, model_name).
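# A minimal export sketch, assuming a `model` loaded from that repository and
# the 1x3x16x112x112 input layout (batch, channels, frames, height, width)
# that this script constructs below; the output file name is illustrative:
#
#     import torch
#     dummy_clip = torch.randn(1, 3, 16, 112, 112)
#     torch.onnx.export(model, dummy_clip, 'resnet-34_kinetics.onnx')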
def get_class_names(path):
    class_names = []
    with open(path) as f:
        for row in f:
            class_names.append(row.rstrip('\n'))  # strip only the trailing newline
    return class_names
def classify_video(video_path, net_path):
    SAMPLE_DURATION = 16  # frames per input clip
    SAMPLE_SIZE = 112     # spatial size the network expects
    mean = (114.7748, 107.7354, 99.4750)  # per-channel mean for the Kinetics dataset

    class_names = get_class_names(args.classes)

    net = cv.dnn.readNet(net_path)
    # Prefer the Intel Inference Engine (OpenVINO) backend; requires OpenCV built with IE support
    net.setPreferableBackend(cv.dnn.DNN_BACKEND_INFERENCE_ENGINE)
    net.setPreferableTarget(cv.dnn.DNN_TARGET_CPU)

    winName = 'Deep learning action recognition in OpenCV'
    cv.namedWindow(winName, cv.WINDOW_AUTOSIZE)
    cap = cv.VideoCapture(video_path)
    while cv.waitKey(1) < 0:
        # Collect a clip of SAMPLE_DURATION consecutive frames
        frames = []
        for _ in range(SAMPLE_DURATION):
            hasFrame, frame = cap.read()
            if not hasFrame:
                exit(0)
            frames.append(frame)

        # blobFromImages yields an NxCxHxW blob; rearrange it to
        # 1xCxNxHxW (batch, channels, frames, height, width) for the 3D net
        inputs = cv.dnn.blobFromImages(frames, 1, (SAMPLE_SIZE, SAMPLE_SIZE), mean, True, crop=True)
        inputs = np.transpose(inputs, (1, 0, 2, 3))
        inputs = np.expand_dims(inputs, axis=0)

        net.setInput(inputs)
        outputs = net.forward()
        class_pred = np.argmax(outputs)
        label = class_names[class_pred]
        # Overlay the predicted label on every frame of the clip
        for frame in frames:
            labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv.rectangle(frame, (0, 10 - labelSize[1]),
                         (labelSize[0], 10 + baseLine), (255, 255, 255), cv.FILLED)
            cv.putText(frame, label, (0, 10), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))
            cv.imshow(winName, frame)
            if cv.waitKey(1) & 0xFF == ord('q'):
                break
if __name__ == "__main__":
    args, _ = parser.parse_known_args()
    classify_video(args.input if args.input else 0, args.model)
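# Example invocation (file names are illustrative):
#
#     python action_recognition.py --model resnet-34_kinetics.onnx --input action.mp4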