perf_batchDistance.cpp 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #include "perf_precomp.hpp"
  2. namespace opencv_test
  3. {
  4. using namespace perf;
  5. CV_ENUM(NormType, NORM_L1, NORM_L2, NORM_L2SQR, NORM_HAMMING, NORM_HAMMING2)
  6. typedef tuple<NormType, MatType, bool> Norm_Destination_CrossCheck_t;
  7. typedef perf::TestBaseWithParam<Norm_Destination_CrossCheck_t> Norm_Destination_CrossCheck;
  8. typedef tuple<NormType, bool> Norm_CrossCheck_t;
  9. typedef perf::TestBaseWithParam<Norm_CrossCheck_t> Norm_CrossCheck;
  10. typedef tuple<MatType, bool> Source_CrossCheck_t;
  11. typedef perf::TestBaseWithParam<Source_CrossCheck_t> Source_CrossCheck;
  12. void generateData( Mat& query, Mat& train, const int sourceType );
  13. PERF_TEST_P(Norm_Destination_CrossCheck, batchDistance_8U,
  14. testing::Combine(testing::Values((int)NORM_L1, (int)NORM_L2SQR),
  15. testing::Values(CV_32S, CV_32F),
  16. testing::Bool()
  17. )
  18. )
  19. {
  20. NormType normType = get<0>(GetParam());
  21. int destinationType = get<1>(GetParam());
  22. bool isCrossCheck = get<2>(GetParam());
  23. int knn = isCrossCheck ? 1 : 0;
  24. Mat queryDescriptors;
  25. Mat trainDescriptors;
  26. Mat dist;
  27. Mat ndix;
  28. generateData(queryDescriptors, trainDescriptors, CV_8U);
  29. TEST_CYCLE()
  30. {
  31. batchDistance(queryDescriptors, trainDescriptors, dist, destinationType, (isCrossCheck) ? ndix : noArray(),
  32. normType, knn, Mat(), 0, isCrossCheck);
  33. }
  34. SANITY_CHECK(dist);
  35. if (isCrossCheck) SANITY_CHECK(ndix);
  36. }
  37. PERF_TEST_P(Norm_CrossCheck, batchDistance_Dest_32S,
  38. testing::Combine(testing::Values((int)NORM_HAMMING, (int)NORM_HAMMING2),
  39. testing::Bool()
  40. )
  41. )
  42. {
  43. NormType normType = get<0>(GetParam());
  44. bool isCrossCheck = get<1>(GetParam());
  45. int knn = isCrossCheck ? 1 : 0;
  46. Mat queryDescriptors;
  47. Mat trainDescriptors;
  48. Mat dist;
  49. Mat ndix;
  50. generateData(queryDescriptors, trainDescriptors, CV_8U);
  51. TEST_CYCLE()
  52. {
  53. batchDistance(queryDescriptors, trainDescriptors, dist, CV_32S, (isCrossCheck) ? ndix : noArray(),
  54. normType, knn, Mat(), 0, isCrossCheck);
  55. }
  56. SANITY_CHECK(dist);
  57. if (isCrossCheck) SANITY_CHECK(ndix);
  58. }
  59. PERF_TEST_P(Source_CrossCheck, batchDistance_L2,
  60. testing::Combine(testing::Values(CV_8U, CV_32F),
  61. testing::Bool()
  62. )
  63. )
  64. {
  65. int sourceType = get<0>(GetParam());
  66. bool isCrossCheck = get<1>(GetParam());
  67. int knn = isCrossCheck ? 1 : 0;
  68. Mat queryDescriptors;
  69. Mat trainDescriptors;
  70. Mat dist;
  71. Mat ndix;
  72. generateData(queryDescriptors, trainDescriptors, sourceType);
  73. declare.time(50);
  74. TEST_CYCLE()
  75. {
  76. batchDistance(queryDescriptors, trainDescriptors, dist, CV_32F, (isCrossCheck) ? ndix : noArray(),
  77. NORM_L2, knn, Mat(), 0, isCrossCheck);
  78. }
  79. SANITY_CHECK(dist);
  80. if (isCrossCheck) SANITY_CHECK(ndix);
  81. }
  82. PERF_TEST_P(Norm_CrossCheck, batchDistance_32F,
  83. testing::Combine(testing::Values((int)NORM_L1, (int)NORM_L2SQR),
  84. testing::Bool()
  85. )
  86. )
  87. {
  88. NormType normType = get<0>(GetParam());
  89. bool isCrossCheck = get<1>(GetParam());
  90. int knn = isCrossCheck ? 1 : 0;
  91. Mat queryDescriptors;
  92. Mat trainDescriptors;
  93. Mat dist;
  94. Mat ndix;
  95. generateData(queryDescriptors, trainDescriptors, CV_32F);
  96. declare.time(100);
  97. TEST_CYCLE()
  98. {
  99. batchDistance(queryDescriptors, trainDescriptors, dist, CV_32F, (isCrossCheck) ? ndix : noArray(),
  100. normType, knn, Mat(), 0, isCrossCheck);
  101. }
  102. SANITY_CHECK(dist, 1e-4);
  103. if (isCrossCheck) SANITY_CHECK(ndix);
  104. }
  105. void generateData( Mat& query, Mat& train, const int sourceType )
  106. {
  107. const int dim = 500;
  108. const int queryDescCount = 300; // must be even number because we split train data in some cases in two
  109. const int countFactor = 4; // do not change it
  110. RNG& rng = theRNG();
  111. // Generate query descriptors randomly.
  112. // Descriptor vector elements are integer values.
  113. Mat buf( queryDescCount, dim, CV_32SC1 );
  114. rng.fill( buf, RNG::UNIFORM, Scalar::all(0), Scalar(3) );
  115. buf.convertTo( query, sourceType );
  116. // Generate train descriptors as follows:
  117. // copy each query descriptor to train set countFactor times
  118. // and perturb some one element of the copied descriptors in
  119. // in ascending order. General boundaries of the perturbation
  120. // are (0.f, 1.f).
  121. train.create( query.rows*countFactor, query.cols, sourceType );
  122. float step = (sourceType == CV_8U ? 256.f : 1.f) / countFactor;
  123. for( int qIdx = 0; qIdx < query.rows; qIdx++ )
  124. {
  125. Mat queryDescriptor = query.row(qIdx);
  126. for( int c = 0; c < countFactor; c++ )
  127. {
  128. int tIdx = qIdx * countFactor + c;
  129. Mat trainDescriptor = train.row(tIdx);
  130. queryDescriptor.copyTo( trainDescriptor );
  131. int elem = rng(dim);
  132. float diff = rng.uniform( step*c, step*(c+1) );
  133. trainDescriptor.col(elem) += diff;
  134. }
  135. }
  136. }
  137. } // namespace