kalman.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. #!/usr/bin/env python
  2. """
  3. Tracking of rotating point.
  4. Point moves in a circle and is characterized by a 1D state.
  5. state_k+1 = state_k + speed + process_noise N(0, 1e-5)
  6. The speed is constant.
  7. Both state and measurements vectors are 1D (a point angle),
  8. Measurement is the real state + gaussian noise N(0, 1e-1).
  9. The real and the measured points are connected with red line segment,
  10. the real and the estimated points are connected with yellow line segment,
  11. the real and the corrected estimated points are connected with green line segment.
  12. (if Kalman filter works correctly,
  13. the yellow segment should be shorter than the red one and
  14. the green segment should be shorter than the yellow one).
  15. Pressing any key (except ESC) will reset the tracking.
  16. Pressing ESC will stop the program.
  17. """
  18. # Python 2/3 compatibility
  19. import sys
  20. PY3 = sys.version_info[0] == 3
  21. if PY3:
  22. long = int
  23. import numpy as np
  24. import cv2 as cv
  25. from math import cos, sin, sqrt, pi
  26. def main():
  27. img_height = 500
  28. img_width = 500
  29. kalman = cv.KalmanFilter(2, 1, 0)
  30. code = long(-1)
  31. num_circle_steps = 12
  32. while True:
  33. img = np.zeros((img_height, img_width, 3), np.uint8)
  34. state = np.array([[0.0],[(2 * pi) / num_circle_steps]]) # start state
  35. kalman.transitionMatrix = np.array([[1., 1.], [0., 1.]]) # F. input
  36. kalman.measurementMatrix = 1. * np.eye(1, 2) # H. input
  37. kalman.processNoiseCov = 1e-5 * np.eye(2) # Q. input
  38. kalman.measurementNoiseCov = 1e-1 * np.ones((1, 1)) # R. input
  39. kalman.errorCovPost = 1. * np.eye(2, 2) # P._k|k KF state var
  40. kalman.statePost = 0.1 * np.random.randn(2, 1) # x^_k|k KF state var
  41. while True:
  42. def calc_point(angle):
  43. return (np.around(img_width / 2. + img_width / 3.0 * cos(angle), 0).astype(int),
  44. np.around(img_height / 2. - img_width / 3.0 * sin(angle), 1).astype(int))
  45. img = img * 1e-3
  46. state_angle = state[0, 0]
  47. state_pt = calc_point(state_angle)
  48. # advance Kalman filter to next timestep
  49. # updates statePre, statePost, errorCovPre, errorCovPost
  50. # k-> k+1, x'(k) = A*x(k)
  51. # P'(k) = temp1*At + Q
  52. prediction = kalman.predict()
  53. predict_pt = calc_point(prediction[0, 0]) # equivalent to calc_point(kalman.statePre[0,0])
  54. # generate measurement
  55. measurement = kalman.measurementNoiseCov * np.random.randn(1, 1)
  56. measurement = np.dot(kalman.measurementMatrix, state) + measurement
  57. measurement_angle = measurement[0, 0]
  58. measurement_pt = calc_point(measurement_angle)
  59. # correct the state estimates based on measurements
  60. # updates statePost & errorCovPost
  61. kalman.correct(measurement)
  62. improved_pt = calc_point(kalman.statePost[0, 0])
  63. # plot points
  64. cv.drawMarker(img, measurement_pt, (0, 0, 255), cv.MARKER_SQUARE, 5, 2)
  65. cv.drawMarker(img, predict_pt, (0, 255, 255), cv.MARKER_SQUARE, 5, 2)
  66. cv.drawMarker(img, improved_pt, (0, 255, 0), cv.MARKER_SQUARE, 5, 2)
  67. cv.drawMarker(img, state_pt, (255, 255, 255), cv.MARKER_STAR, 10, 1)
  68. # forecast one step
  69. cv.drawMarker(img, calc_point(np.dot(kalman.transitionMatrix, kalman.statePost)[0, 0]),
  70. (255, 255, 0), cv.MARKER_SQUARE, 12, 1)
  71. cv.line(img, state_pt, measurement_pt, (0, 0, 255), 1, cv.LINE_AA, 0) # red measurement error
  72. cv.line(img, state_pt, predict_pt, (0, 255, 255), 1, cv.LINE_AA, 0) # yellow pre-meas error
  73. cv.line(img, state_pt, improved_pt, (0, 255, 0), 1, cv.LINE_AA, 0) # green post-meas error
  74. # update the real process
  75. process_noise = sqrt(kalman.processNoiseCov[0, 0]) * np.random.randn(2, 1)
  76. state = np.dot(kalman.transitionMatrix, state) + process_noise # x_k+1 = F x_k + w_k
  77. cv.imshow("Kalman", img)
  78. code = cv.waitKey(1000)
  79. if code != -1:
  80. break
  81. if code in [27, ord('q'), ord('Q')]:
  82. break
  83. print('Done')
  84. if __name__ == '__main__':
  85. print(__doc__)
  86. main()
  87. cv.destroyAllWindows()