import numpy as np
import cv2 as cv
import argparse
import os

'''
You can download the converted onnx model from https://drive.google.com/drive/folders/1wLtxyao4ItAg8tt4Sb63zt6qXzhcQoR6?usp=sharing
or convert the model yourself.

You can get the original pre-trained Jasper model from NVIDIA: https://ngc.nvidia.com/catalog/models/nvidia:jasper_pyt_onnx_fp16_amp/files
Download and unzip: `$ wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/jasper_pyt_onnx_fp16_amp/versions/20.10.0/zip -O jasper_pyt_onnx_fp16_amp_20.10.0.zip && unzip -o ./jasper_pyt_onnx_fp16_amp_20.10.0.zip && unzip -o ./jasper_pyt_onnx_fp16_amp.zip`
You can get the script to convert the model here: https://gist.github.com/spazewalker/507f1529e19aea7e8417f6e935851a01

You can convert the model using the following steps:
1. Import onnx and numpy, then load the original model
```
import onnx
import numpy as np
from onnx import numpy_helper

model = onnx.load("./jasper-onnx/1/model.onnx")
```
2. Change the data type of the input layer
```
inp = model.graph.input[0]
model.graph.input.remove(inp)
inp.type.tensor_type.elem_type = 1
model.graph.input.insert(0, inp)
```
3. Change the data type of the output layer
```
out = model.graph.output[0]
model.graph.output.remove(out)
out.type.tensor_type.elem_type = 1
model.graph.output.insert(0, out)
```
4. Change the data type of every initializer and cast its values from FP16 to FP32
```
for i, init in enumerate(model.graph.initializer):
    model.graph.initializer.remove(init)
    init.data_type = 1
    init.raw_data = np.frombuffer(init.raw_data, count=np.prod(init.dims), dtype=np.float16).astype(np.float32).tobytes()
    model.graph.initializer.insert(i, init)
```
5. Add an additional reshape node to handle the inconsistent input between the Python and C++ APIs of OpenCV
(see https://github.com/opencv/opencv/issues/19091).
Make and insert a new node with the 'Reshape' operation and the required initializer:
```
tensor = numpy_helper.from_array(np.array([0, 64, -1]), name='shape_reshape')
model.graph.initializer.insert(0, tensor)
node = onnx.helper.make_node(op_type='Reshape', inputs=['input__0', 'shape_reshape'], outputs=['input_reshaped'], name='reshape__0')
model.graph.node.insert(0, node)
model.graph.node[1].input[0] = 'input_reshaped'
```
6. Finally, save the converted model
```
with open('jasper_dynamic_input_float.onnx', 'wb') as f:
    onnx.save_model(model, f)
```

Original repo: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechRecognition/Jasper
'''
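
# A quick sanity check after conversion (a sketch: the file name matches step 6 above,
# and a dummy input stands in for real filterbank features):
#
#   net = cv.dnn.readNetFromONNX('jasper_dynamic_input_float.onnx')
#   dummy = np.zeros((1, 64, 500), dtype=np.float32)  # (batch, n_filt, time frames)
#   net.setInput(dummy)
#   out = net.forward()
#   print(out.shape)  # per-frame scores over 29 symbols (28 characters + CTC blank)
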
class FilterbankFeatures:
    def __init__(self,
                 sample_rate=16000, window_size=0.02, window_stride=0.01,
                 n_fft=512, preemph=0.97, n_filt=64, lowfreq=0,
                 highfreq=None, log=True, dither=1e-5):
        '''
        Initializes the pre-processing class. Default values are the values used by the Jasper
        architecture for pre-processing. For more details, refer to the paper here:
        https://arxiv.org/abs/1904.03288
        '''
        self.win_length = int(sample_rate * window_size)  # frame size
        self.hop_length = int(sample_rate * window_stride)  # stride
        self.n_fft = n_fft or 2 ** np.ceil(np.log2(self.win_length))
        self.log = log
        self.dither = dither
        self.n_filt = n_filt
        self.preemph = preemph
        highfreq = highfreq or sample_rate / 2
        self.window_tensor = np.hanning(self.win_length)

        self.filterbanks = self.mel(sample_rate, self.n_fft, n_mels=n_filt, fmin=lowfreq, fmax=highfreq)
        self.filterbanks = self.filterbanks.astype(np.float32)
        self.filterbanks = np.expand_dims(self.filterbanks, 0)
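
    # With the defaults above and 16 kHz audio: win_length = 16000 * 0.02 = 320 samples and
    # hop_length = 16000 * 0.01 = 160 samples. The 320-sample Hann window is zero-padded to
    # n_fft = 512 inside stft(), so every frame yields 1 + 512 // 2 = 257 frequency bins.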

    def normalize_batch(self, x, seq_len):
        '''
        Normalizes each feature channel to zero mean and unit variance over the valid
        frames of each utterance.
        '''
        x_mean = np.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype)
        x_std = np.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype)
        for i in range(x.shape[0]):
            x_mean[i, :] = np.mean(x[i, :, :seq_len[i]], axis=1)
            x_std[i, :] = np.std(x[i, :, :seq_len[i]], axis=1)
        # make sure x_std is not zero
        x_std += 1e-10
        return (x - np.expand_dims(x_mean, 2)) / np.expand_dims(x_std, 2)

    def calculate_features(self, x, seq_len):
        '''
        Calculates filterbank features.
        args:
            x : mono channel audio
            seq_len : length of the audio sample
        returns:
            x : filterbank features
        '''
        dtype = x.dtype

        seq_len = np.ceil(seq_len / self.hop_length)
        seq_len = np.array(seq_len, dtype=np.int32)

        # dither
        if self.dither > 0:
            x += self.dither * np.random.randn(*x.shape)

        # do preemphasis
        if self.preemph is not None:
            x = np.concatenate(
                (np.expand_dims(x[0], -1), x[1:] - self.preemph * x[:-1]), axis=0)

        # Short Time Fourier Transform
        x = self.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                      win_length=self.win_length,
                      fft_window=self.window_tensor)

        # get power spectrum
        x = (x**2).sum(-1)

        # dot with filterbank energies
        x = np.matmul(np.array(self.filterbanks, dtype=x.dtype), x)

        # log features if required
        if self.log:
            x = np.log(x + 1e-20)

        # normalize if required
        x = self.normalize_batch(x, seq_len).astype(dtype)
        return x
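
    # Example usage of this class (a sketch: `audio` stands for any 1-D float array of
    # 16 kHz samples, e.g. the first value returned by readAudioFile below):
    #
    #   feature_extractor = FilterbankFeatures()
    #   seq_len = np.array([audio.shape[0]], dtype=np.int32)
    #   feats = feature_extractor.calculate_features(x=audio, seq_len=seq_len)
    #   # feats has shape (1, 64, num_frames) and is passed to net.setInput() unchanged.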

    # Mel frequency calculation
    def hz_to_mel(self, frequencies):
        '''
        Converts frequencies from Hz to mel scale. Input can be a number or a vector.
        '''
        frequencies = np.asanyarray(frequencies)

        f_min = 0.0
        f_sp = 200.0 / 3
        mels = (frequencies - f_min) / f_sp

        # Fill in the log-scale part
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region

        if frequencies.ndim:
            # If we have array data, vectorize
            log_t = frequencies >= min_log_hz
            mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
        elif frequencies >= min_log_hz:
            # If we have scalar data, handle it directly
            mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
        return mels

    def mel_to_hz(self, mels):
        '''
        Converts frequencies from mel to Hz scale. Input can be a number or a vector.
        '''
        mels = np.asanyarray(mels)

        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels

        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region

        if mels.ndim:
            # If we have vector data, vectorize
            log_t = mels >= min_log_mel
            freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
        elif mels >= min_log_mel:
            # If we have scalar data, handle it directly
            freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
        return freqs
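
    # Worked example of the Slaney-style scale implemented above: below 1000 Hz the mapping is
    # linear, so hz_to_mel(1000.0) = 1000 / (200 / 3) = 15.0 mel; above 1000 Hz it is logarithmic,
    # e.g. hz_to_mel(8000.0) = 15.0 + log(8000 / 1000) / (log(6.4) / 27) ≈ 45.25 mel.
    # mel_to_hz() inverts both branches exactly.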

    def mel_frequencies(self, n_mels=128, fmin=0.0, fmax=11025.0):
        '''
        Calculates n mel frequencies between 2 frequencies
        args:
            n_mels : number of bands
            fmin : min frequency
            fmax : max frequency
        returns:
            mels : vector of mel frequencies
        '''
        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = self.hz_to_mel(fmin)
        max_mel = self.hz_to_mel(fmax)

        mels = np.linspace(min_mel, max_mel, n_mels)
        return self.mel_to_hz(mels)

    def mel(self, sr, n_fft, n_mels=128, fmin=0.0, fmax=None, dtype=np.float32):
        '''
        Generates mel filterbank
        args:
            sr : Sampling rate
            n_fft : number of FFT components
            n_mels : number of Mel bands to generate
            fmin : lowest frequency (in Hz)
            fmax : highest frequency (in Hz). sr/2.0 if None
            dtype : the data type of the output basis.
        returns:
            mels : Mel transform matrix
        '''
        # default Max freq = half of sampling rate
        if fmax is None:
            fmax = float(sr) / 2

        # Initialize the weights
        n_mels = int(n_mels)
        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)

        # Center freqs of each FFT bin
        fftfreqs = np.linspace(0, float(sr) / 2, int(1 + n_fft // 2), endpoint=True)

        # 'Center freqs' of mel bands - uniformly spaced between limits
        mel_f = self.mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax)

        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)

        for i in range(n_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]

            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))

        # Using Slaney-style mel which is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
        weights *= enorm[:, np.newaxis]
        return weights
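
    # With the defaults used by this sample (n_filt=64, n_fft=512), mel() returns a (64, 257)
    # weight matrix; __init__ then adds a leading batch axis, so self.filterbanks has
    # shape (1, 64, 257).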

    # STFT preparation
    def pad_window_center(self, data, size, axis=-1, **kwargs):
        '''
        Centers the data and pads.
        args:
            data : Vector to be padded and centered
            size : Length to pad data
            axis : Axis along which to pad and center the data
            kwargs : arguments passed to np.pad
        return : centered and padded data
        '''
        kwargs.setdefault("mode", "constant")
        n = data.shape[axis]
        lpad = int((size - n) // 2)
        lengths = [(0, 0)] * data.ndim
        lengths[axis] = (lpad, int(size - n - lpad))
        if lpad < 0:
            raise Exception(
                ("Target size ({:d}) must be at least input size ({:d})").format(size, n)
            )
        return np.pad(data, lengths, **kwargs)

    def frame(self, x, frame_length, hop_length):
        '''
        Slices a data array into (overlapping) frames.
        args:
            x : array to frame
            frame_length : length of frame
            hop_length : Number of steps to advance between frames
        return : A framed view of `x`
        '''
        if x.shape[-1] < frame_length:
            raise Exception(
                "Input is too short (n={:d})"
                " for frame_length={:d}".format(x.shape[-1], frame_length)
            )
        x = np.asfortranarray(x)
        n_frames = 1 + (x.shape[-1] - frame_length) // hop_length
        strides = np.asarray(x.strides)
        new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize
        shape = list(x.shape)[:-1] + [frame_length, n_frames]
        strides = list(strides) + [hop_length * new_stride]
        return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
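
    # Example: a 1-D signal of 16000 samples framed with frame_length=512 and hop_length=160
    # gives n_frames = 1 + (16000 - 512) // 160 = 97, i.e. a view of shape (512, 97).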

    def dtype_r2c(self, d, default=np.complex64):
        '''
        Find the complex numpy dtype corresponding to a real dtype.
        args:
            d : The real-valued dtype to convert to complex.
            default : The default complex target type, if `d` does not match a known dtype
        return : The complex dtype
        '''
        mapping = {
            np.dtype(np.float32): np.complex64,
            np.dtype(np.float64): np.complex128,
        }
        dt = np.dtype(d)
        if dt.kind == "c":
            return dt
        return np.dtype(mapping.get(dt, default))

    def stft(self, y, n_fft, hop_length=None, win_length=None, fft_window=None, pad_mode='reflect', return_complex=False):
        '''
        Short Time Fourier Transform. The STFT represents a signal in the time-frequency
        domain by computing discrete Fourier transforms (DFT) over short overlapping windows.
        args:
            y : input signal
            n_fft : length of the windowed signal after padding with zeros.
            hop_length : number of audio samples between adjacent STFT columns.
            win_length : Each frame of audio is windowed by a window of length win_length and
                then padded with zeros to match n_fft
            fft_window : a vector or array of length `n_fft` having values computed by a
                window function
            pad_mode : mode used while padding the signal
            return_complex : returns array with complex data type if `True`
        return : Matrix of short-term Fourier transform coefficients.
        '''
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = int(win_length // 4)
        if y.ndim != 1:
            raise Exception(f'Invalid input shape. Only mono-channel audio is supported. Input must have shape (Audio,). Got {y.shape}')

        # Pad the window out to n_fft size
        fft_window = self.pad_window_center(fft_window, n_fft)

        # Reshape so that the window can be broadcast
        fft_window = fft_window.reshape((-1, 1))

        # Pad the time series so that frames are centered
        y = np.pad(y, int(n_fft // 2), mode=pad_mode)

        # Window the time series.
        y_frames = self.frame(y, frame_length=n_fft, hop_length=hop_length)

        # Convert data type to complex
        dtype = self.dtype_r2c(y.dtype)

        # Pre-allocate the STFT matrix
        stft_matrix = np.empty((int(1 + n_fft // 2), y_frames.shape[-1]), dtype=dtype, order="F")
        stft_matrix = np.fft.rfft(fft_window * y_frames, axis=0)
        return stft_matrix if return_complex else np.stack((stft_matrix.real, stft_matrix.imag), axis=-1)
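
    # For 1-D real input the result has shape (1 + n_fft // 2, n_frames) when return_complex=True,
    # or (1 + n_fft // 2, n_frames, 2) with the real and imaginary parts stacked on the last
    # axis otherwise, which is the form calculate_features() consumes.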

class Decoder:
    '''
    Used for decoding the output of the Jasper model.
    '''
    def __init__(self):
        labels = [' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'"]
        self.labels_map = {i: label for i, label in enumerate(labels)}
        self.blank_id = 28

    def decode(self, x):
        """
        Takes the output of the Jasper model and performs greedy CTC decoding to
        remove duplicates and blank symbols. Returns the prediction.
        """
        x = np.argmax(x, axis=-1)
        hypotheses = []
        prediction = x.tolist()
        # CTC decoding procedure
        decoded_prediction = []
        previous = self.blank_id
        for p in prediction:
            if (p != previous or previous == self.blank_id) and p != self.blank_id:
                decoded_prediction.append(p)
            previous = p
        hypothesis = ''.join([self.labels_map[c] for c in decoded_prediction])
        hypotheses.append(hypothesis)
        return hypotheses
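
# A small illustration of the greedy CTC decoding above (not part of the sample's control flow):
# with blank_id = 28, the per-frame argmax sequence [8, 8, 28, 5, 12, 12, 28, 12, 15] collapses
# repeats, drops blanks, and decodes to "hello". Equivalently, with one-hot rows:
#
#   decoder = Decoder()
#   print(decoder.decode(np.eye(29)[[8, 8, 28, 5, 12, 12, 28, 12, 15]]))  # ['hello']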

def predict(features, net, decoder):
    '''
    Passes the features through the Jasper model and decodes the output to an English transcript.
    args:
        features : input features, calculated using the FilterbankFeatures class
        net : Jasper model dnn.Net object
        decoder : Decoder object
    return : Predicted text
    '''
    # make prediction
    net.setInput(features)
    output = net.forward()

    # decode output to transcript
    prediction = decoder.decode(output.squeeze(0))
    return prediction[0]

def readAudioFile(file, audioStream):
    cap = cv.VideoCapture(file)
    samplingRate = 16000
    params = np.asarray([cv.CAP_PROP_AUDIO_STREAM, audioStream,
                         cv.CAP_PROP_VIDEO_STREAM, -1,
                         cv.CAP_PROP_AUDIO_DATA_DEPTH, cv.CV_32F,
                         cv.CAP_PROP_AUDIO_SAMPLES_PER_SECOND, samplingRate
                         ])
    cap.open(file, cv.CAP_ANY, params)
    if not cap.isOpened():
        print("Error: Can't read audio file:", file, "with audioStream =", audioStream)
        return
    audioBaseIndex = int(cap.get(cv.CAP_PROP_AUDIO_BASE_INDEX))
    inputAudio = []
    while True:
        if cap.grab():
            frame = np.asarray([])
            frame = cap.retrieve(frame, audioBaseIndex)
            for i in range(len(frame[1][0])):
                inputAudio.append(frame[1][0][i])
        else:
            break
    inputAudio = np.asarray(inputAudio, dtype=np.float64)
    return inputAudio, samplingRate

def readAudioMicrophone(microTime):
    cap = cv.VideoCapture()
    samplingRate = 16000
    params = np.asarray([cv.CAP_PROP_AUDIO_STREAM, 0,
                         cv.CAP_PROP_VIDEO_STREAM, -1,
                         cv.CAP_PROP_AUDIO_DATA_DEPTH, cv.CV_32F,
                         cv.CAP_PROP_AUDIO_SAMPLES_PER_SECOND, samplingRate
                         ])
    cap.open(0, cv.CAP_ANY, params)
    if not cap.isOpened():
        print("Error: Can't open microphone")
        print("Error: problems with audio reading, check input arguments")
        return
    audioBaseIndex = int(cap.get(cv.CAP_PROP_AUDIO_BASE_INDEX))
    cvTickFreq = cv.getTickFrequency()
    sysTimeCurr = cv.getTickCount()
    sysTimePrev = sysTimeCurr
    inputAudio = []
    while (sysTimeCurr - sysTimePrev) / cvTickFreq < microTime:
        if cap.grab():
            frame = np.asarray([])
            frame = cap.retrieve(frame, audioBaseIndex)
            for i in range(len(frame[1][0])):
                inputAudio.append(frame[1][0][i])
            sysTimeCurr = cv.getTickCount()
        else:
            print("Error: Grab error")
            break
    inputAudio = np.asarray(inputAudio, dtype=np.float64)
    print("Number of samples: ", len(inputAudio))
    return inputAudio, samplingRate

if __name__ == '__main__':
    # Computation backends supported by layers
    backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV)
    # Target devices for computation
    targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16)

    parser = argparse.ArgumentParser(description='This script runs the Jasper speech recognition model',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_type', type=str, required=True, help='file or microphone')
    parser.add_argument('--micro_time', type=int, default=15, help='Duration of microphone capture in seconds. Must be more than 6 seconds.')
    parser.add_argument('--input_audio', type=str, help='Path to an input audio file, or path to a .txt file listing one audio file path per line.')
    parser.add_argument('--audio_stream', type=int, default=0, help='CAP_PROP_AUDIO_STREAM value')
    parser.add_argument('--show_spectrogram', action='store_true', help='Whether to show a spectrogram of the input audio.')
    parser.add_argument('--model', type=str, default='jasper.onnx', help='Path to the onnx file of Jasper. default="jasper.onnx"')
    parser.add_argument('--output', type=str, help='Path to the file where the recognized transcript will be saved. Leave empty to print to the console.')
    parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
                        help='Select a computation backend: '
                             "%d: automatically (by default) "
                             "%d: OpenVINO Inference Engine "
                             "%d: OpenCV Implementation " % backends)
    parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
                        help='Select a target device: '
                             "%d: CPU target (by default) "
                             "%d: OpenCL "
                             "%d: OpenCL FP16 " % targets)
    args, _ = parser.parse_known_args()

    if args.input_audio and not os.path.isfile(args.input_audio):
        raise OSError("Input audio file does not exist")
    if not os.path.isfile(args.model):
        raise OSError("Jasper model file does not exist")

    features = []
    if args.input_type == "file":
        if not args.input_audio:
            raise OSError("--input_audio is required when --input_type is 'file'")
        if args.input_audio.endswith('.txt'):
            with open(args.input_audio) as f:
                content = f.readlines()
                content = [x.strip() for x in content]
            audio_file_paths = content
            for audio_file_path in audio_file_paths:
                if not os.path.isfile(audio_file_path):
                    raise OSError(f"Audio file ({audio_file_path}) does not exist")
        else:
            audio_file_paths = [args.input_audio]
        audio_file_paths = [os.path.abspath(x) for x in audio_file_paths]

        # Read audio files
        for audio_file_path in audio_file_paths:
            audio = readAudioFile(audio_file_path, args.audio_stream)
            if audio is None:
                raise Exception(f"Can't read {audio_file_path}. Try a different format")
            features.append(audio[0])
    elif args.input_type == "microphone":
        # Read audio from microphone
        audio = readAudioMicrophone(args.micro_time)
        if audio is None:
            raise Exception("Can't read audio from the microphone")
        features.append(audio[0])
    else:
        raise Exception(f"input_type {args.input_type} is not supported. Please enter 'file' or 'microphone'")

    # Get filterbank features
    feature_extractor = FilterbankFeatures()
    for i in range(len(features)):
        X = features[i]
        seq_len = np.array([X.shape[0]], dtype=np.int32)
        features[i] = feature_extractor.calculate_features(x=X, seq_len=seq_len)

    # Load network
    net = cv.dnn.readNetFromONNX(args.model)
    net.setPreferableBackend(args.backend)
    net.setPreferableTarget(args.target)

    # Show spectrogram if required
    if args.show_spectrogram and not (args.input_audio and args.input_audio.endswith('.txt')):
        img = cv.normalize(src=features[0][0], dst=None, alpha=0, beta=255, norm_type=cv.NORM_MINMAX, dtype=cv.CV_8U)
        img = cv.applyColorMap(img, cv.COLORMAP_JET)
        cv.imshow('spectrogram', img)
        cv.waitKey(0)

    # Initialize decoder
    decoder = Decoder()

    # Make prediction
    prediction = []
    print("Predicting...")
    for feature in features:
        print(f"\rAudio file {len(prediction) + 1}/{len(features)}", end='')
        prediction.append(predict(feature, net, decoder))
    print("")

    # save transcript if required
    if args.output:
        with open(args.output, 'w') as f:
            for pred in prediction:
                f.write(pred + '\n')
        print("Transcript was written to {}".format(args.output))
    else:
        print(prediction)
    cv.destroyAllWindows()
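
# Example invocations (a sketch: the audio paths are placeholders, and the converted model is
# assumed to be saved as jasper.onnx next to this script):
#   python speech_recognition.py --input_type file --input_audio audio.wav
#   python speech_recognition.py --input_type file --input_audio audio_list.txt --output transcripts.txt
#   python speech_recognition.py --input_type microphone --micro_time 10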