ref_reduce_arg.impl.hpp 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. // This file is part of OpenCV project.
  2. // It is subject to the license terms in the LICENSE file found in the top-level directory
  3. // of this distribution and at http://opencv.org/license.html.
  4. #ifndef OPENCV_TEST_REF_REDUCE_ARG_HPP
  5. #define OPENCV_TEST_REF_REDUCE_ARG_HPP
  6. #include "opencv2/core/detail/dispatch_helper.impl.hpp"
  7. #include <algorithm>
  8. #include <numeric>
  9. namespace cvtest {
  10. template <class Cmp, typename T>
  11. struct reduceMinMaxImpl
  12. {
  13. void operator()(const cv::Mat& src, cv::Mat& dst, const int axis) const
  14. {
  15. Cmp cmp;
  16. std::vector<int> sizes(src.dims);
  17. std::copy(src.size.p, src.size.p + src.dims, sizes.begin());
  18. std::vector<cv::Range> idx(sizes.size(), cv::Range(0, 1));
  19. idx[axis] = cv::Range::all();
  20. const int n = std::accumulate(begin(sizes), end(sizes), 1, std::multiplies<int>());
  21. const std::vector<int> newShape{1, src.size[axis]};
  22. for (int i = 0; i < n ; ++i)
  23. {
  24. cv::Mat sub = src(idx);
  25. auto begin = sub.begin<T>();
  26. auto it = std::min_element(begin, sub.end<T>(), cmp);
  27. *dst(idx).ptr<int32_t>() = static_cast<int32_t>(std::distance(begin, it));
  28. for (int j = static_cast<int>(idx.size()) - 1; j >= 0; --j)
  29. {
  30. if (j == axis)
  31. {
  32. continue;
  33. }
  34. const int old_s = idx[j].start;
  35. const int new_s = (old_s + 1) % sizes[j];
  36. if (new_s > old_s)
  37. {
  38. idx[j] = cv::Range(new_s, new_s + 1);
  39. break;
  40. }
  41. idx[j] = cv::Range(0, 1);
  42. }
  43. }
  44. }
  45. };
  46. template<template<class> class Cmp>
  47. struct MinMaxReducer{
  48. template <typename T>
  49. using Impl = reduceMinMaxImpl<Cmp<T>, T>;
  50. static void reduce(const Mat& src, Mat& dst, int axis)
  51. {
  52. axis = (axis + src.dims) % src.dims;
  53. CV_Assert(src.channels() == 1 && axis >= 0 && axis < src.dims);
  54. std::vector<int> sizes(src.dims);
  55. std::copy(src.size.p, src.size.p + src.dims, sizes.begin());
  56. sizes[axis] = 1;
  57. dst.create(sizes, CV_32SC1); // indices
  58. dst.setTo(cv::Scalar::all(0));
  59. cv::detail::depthDispatch<Impl>(src.depth(), src, dst, axis);
  60. }
  61. };
  62. }
  63. #endif //OPENCV_TEST_REF_REDUCE_ARG_HPP