// cv2_numpy.hpp
  1. #ifndef CV2_NUMPY_HPP
  2. #define CV2_NUMPY_HPP
  3. #include "cv2.hpp"
  4. #include "opencv2/core.hpp"
  5. class NumpyAllocator : public cv::MatAllocator
  6. {
  7. public:
  8. NumpyAllocator() { stdAllocator = cv::Mat::getStdAllocator(); }
  9. ~NumpyAllocator() {}
  10. cv::UMatData* allocate(PyObject* o, int dims, const int* sizes, int type, size_t* step) const;
  11. cv::UMatData* allocate(int dims0, const int* sizes, int type, void* data, size_t* step, cv::AccessFlag flags, cv::UMatUsageFlags usageFlags) const CV_OVERRIDE;
  12. bool allocate(cv::UMatData* u, cv::AccessFlag accessFlags, cv::UMatUsageFlags usageFlags) const CV_OVERRIDE;
  13. void deallocate(cv::UMatData* u) const CV_OVERRIDE;
  14. const cv::MatAllocator* stdAllocator;
  15. };
  16. extern NumpyAllocator g_numpyAllocator;
  17. //======================================================================================================================
  18. // HACK(?): function from cv2_util.hpp
  19. extern int failmsg(const char *fmt, ...);
  20. namespace {
  21. template<class T>
  22. NPY_TYPES asNumpyType()
  23. {
  24. return NPY_OBJECT;
  25. }
  26. template<>
  27. NPY_TYPES asNumpyType<bool>()
  28. {
  29. return NPY_BOOL;
  30. }
  31. #define CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION(src, dst) \
  32. template<> \
  33. NPY_TYPES asNumpyType<src>() \
  34. { \
  35. return NPY_##dst; \
  36. } \
  37. template<> \
  38. NPY_TYPES asNumpyType<u##src>() \
  39. { \
  40. return NPY_U##dst; \
  41. }
  42. CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION(int8_t, INT8);
  43. CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION(int16_t, INT16);
  44. CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION(int32_t, INT32);
  45. CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION(int64_t, INT64);
  46. #undef CV_GENERATE_INTEGRAL_TYPE_NPY_CONVERSION
  47. template<>
  48. NPY_TYPES asNumpyType<float>()
  49. {
  50. return NPY_FLOAT;
  51. }
  52. template<>
  53. NPY_TYPES asNumpyType<double>()
  54. {
  55. return NPY_DOUBLE;
  56. }
  57. template <class T>
  58. PyArray_Descr* getNumpyTypeDescriptor()
  59. {
  60. return PyArray_DescrFromType(asNumpyType<T>());
  61. }
  62. template <>
  63. PyArray_Descr* getNumpyTypeDescriptor<size_t>()
  64. {
  65. #if SIZE_MAX == ULONG_MAX
  66. return PyArray_DescrFromType(NPY_ULONG);
  67. #elif SIZE_MAX == ULLONG_MAX
  68. return PyArray_DescrFromType(NPY_ULONGLONG);
  69. #else
  70. return PyArray_DescrFromType(NPY_UINT);
  71. #endif
  72. }
  // Tells whether `value` (of type U) lies within the range of type T.
  //
  // When T and U are integer types of different signedness, a plain comparison
  // would apply the usual arithmetic conversions and silently turn a negative
  // operand into a huge unsigned one (e.g. -1 would look representable as
  // unsigned, and 0u would look unrepresentable as int). That case is handled
  // explicitly; every other combination uses the direct range comparison.
  template <class T, class U>
  bool isRepresentable(U value)
  {
      typedef std::numeric_limits<T> TargetLimits;
      typedef std::numeric_limits<U> SourceLimits;

      if (TargetLimits::is_integer && SourceLimits::is_integer &&
          TargetLimits::is_signed != SourceLimits::is_signed)
      {
          // A negative value never fits into an unsigned target.
          if (SourceLimits::is_signed && value < U())
          {
              return false;
          }
          // Both operands are non-negative here, so comparing them as wide
          // unsigned integers is lossless.
          return static_cast<uint64_t>(value) <= static_cast<uint64_t>(TargetLimits::max());
      }
      return (TargetLimits::min() <= value) && (value <= TargetLimits::max());
  }
  77. template<class T>
  78. bool canBeSafelyCasted(PyObject* obj, PyArray_Descr* to)
  79. {
  80. return PyArray_CanCastTo(PyArray_DescrFromScalar(obj), to) != 0;
  81. }
  82. template<>
  83. bool canBeSafelyCasted<size_t>(PyObject* obj, PyArray_Descr* to)
  84. {
  85. PyArray_Descr* from = PyArray_DescrFromScalar(obj);
  86. if (PyArray_CanCastTo(from, to))
  87. {
  88. return true;
  89. }
  90. else
  91. {
  92. // False negative scenarios:
  93. // - Signed input is positive so it can be safely cast to unsigned output
  94. // - Input has wider limits but value is representable within output limits
  95. // - All the above
  96. if (PyDataType_ISSIGNED(from))
  97. {
  98. int64_t input = 0;
  99. PyArray_CastScalarToCtype(obj, &input, getNumpyTypeDescriptor<int64_t>());
  100. return (input >= 0) && isRepresentable<size_t>(static_cast<uint64_t>(input));
  101. }
  102. else
  103. {
  104. uint64_t input = 0;
  105. PyArray_CastScalarToCtype(obj, &input, getNumpyTypeDescriptor<uint64_t>());
  106. return isRepresentable<size_t>(input);
  107. }
  108. return false;
  109. }
  110. }
  111. template<class T>
  112. bool parseNumpyScalar(PyObject* obj, T& value)
  113. {
  114. if (PyArray_CheckScalar(obj))
  115. {
  116. // According to the numpy documentation:
  117. // There are 21 statically-defined PyArray_Descr objects for the built-in data-types
  118. // So descriptor pointer is not owning.
  119. PyArray_Descr* to = getNumpyTypeDescriptor<T>();
  120. if (canBeSafelyCasted<T>(obj, to))
  121. {
  122. PyArray_CastScalarToCtype(obj, &value, to);
  123. return true;
  124. }
  125. }
  126. return false;
  127. }
  128. struct SafeSeqItem
  129. {
  130. PyObject * item;
  131. SafeSeqItem(PyObject *obj, size_t idx) { item = PySequence_GetItem(obj, idx); }
  132. ~SafeSeqItem() { Py_XDECREF(item); }
  133. private:
  134. SafeSeqItem(const SafeSeqItem&); // = delete
  135. SafeSeqItem& operator=(const SafeSeqItem&); // = delete
  136. };
  137. template <class T>
  138. class RefWrapper
  139. {
  140. public:
  141. RefWrapper(T& item) : item_(item) {}
  142. T& get() CV_NOEXCEPT { return item_; }
  143. private:
  144. T& item_;
  145. };
  146. // In order to support this conversion on 3.x branch - use custom reference_wrapper
  147. // and C-style array instead of std::array<T, N>
  148. template <class T, std::size_t N>
  149. bool parseSequence(PyObject* obj, RefWrapper<T> (&value)[N], const ArgInfo& info)
  150. {
  151. if (!obj || obj == Py_None)
  152. {
  153. return true;
  154. }
  155. if (!PySequence_Check(obj))
  156. {
  157. failmsg("Can't parse '%s'. Input argument doesn't provide sequence "
  158. "protocol", info.name);
  159. return false;
  160. }
  161. const std::size_t sequenceSize = PySequence_Size(obj);
  162. if (sequenceSize != N)
  163. {
  164. failmsg("Can't parse '%s'. Expected sequence length %lu, got %lu",
  165. info.name, N, sequenceSize);
  166. return false;
  167. }
  168. for (std::size_t i = 0; i < N; ++i)
  169. {
  170. SafeSeqItem seqItem(obj, i);
  171. if (!pyopencv_to(seqItem.item, value[i].get(), info))
  172. {
  173. failmsg("Can't parse '%s'. Sequence item with index %lu has a "
  174. "wrong type", info.name, i);
  175. return false;
  176. }
  177. }
  178. return true;
  179. }
  180. } // namespace
  181. #endif // CV2_NUMPY_HPP