translate_blas.hpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
  2. // Copyright 2008-2016 National ICT Australia (NICTA)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // ------------------------------------------------------------------------
  15. #ifdef ARMA_USE_BLAS
  16. //! \namespace blas namespace for BLAS functions
  17. namespace blas
  18. {
  19. template<typename eT>
  20. inline
  21. void
  22. gemv(const char* transA, const blas_int* m, const blas_int* n, const eT* alpha, const eT* A, const blas_int* ldA, const eT* x, const blas_int* incx, const eT* beta, eT* y, const blas_int* incy)
  23. {
  24. arma_type_check((is_supported_blas_type<eT>::value == false));
  25. #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS)
  26. {
  27. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); }
  28. else if( is_double<eT>::value) { typedef double T; arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); }
  29. else if( is_cx_float<eT>::value) { typedef blas_cxf T; arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); }
  30. else if(is_cx_double<eT>::value) { typedef blas_cxd T; arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy, 1); }
  31. }
  32. #else
  33. {
  34. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_sgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); }
  35. else if( is_double<eT>::value) { typedef double T; arma_fortran(arma_dgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); }
  36. else if( is_cx_float<eT>::value) { typedef blas_cxf T; arma_fortran(arma_cgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); }
  37. else if(is_cx_double<eT>::value) { typedef blas_cxd T; arma_fortran(arma_zgemv)(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy); }
  38. }
  39. #endif
  40. }
  41. template<typename eT>
  42. inline
  43. void
  44. gemm(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* B, const blas_int* ldB, const eT* beta, eT* C, const blas_int* ldC)
  45. {
  46. arma_type_check((is_supported_blas_type<eT>::value == false));
  47. #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS)
  48. {
  49. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); }
  50. else if( is_double<eT>::value) { typedef double T; arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); }
  51. else if( is_cx_float<eT>::value) { typedef blas_cxf T; arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); }
  52. else if(is_cx_double<eT>::value) { typedef blas_cxd T; arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC, 1, 1); }
  53. }
  54. #else
  55. {
  56. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_sgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); }
  57. else if( is_double<eT>::value) { typedef double T; arma_fortran(arma_dgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); }
  58. else if( is_cx_float<eT>::value) { typedef blas_cxf T; arma_fortran(arma_cgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); }
  59. else if(is_cx_double<eT>::value) { typedef blas_cxd T; arma_fortran(arma_zgemm)(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC); }
  60. }
  61. #endif
  62. }
  63. template<typename eT>
  64. inline
  65. void
  66. syrk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const eT* alpha, const eT* A, const blas_int* ldA, const eT* beta, eT* C, const blas_int* ldC)
  67. {
  68. arma_type_check((is_supported_blas_type<eT>::value == false));
  69. #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS)
  70. {
  71. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_ssyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC, 1, 1); }
  72. else if(is_double<eT>::value) { typedef double T; arma_fortran(arma_dsyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC, 1, 1); }
  73. }
  74. #else
  75. {
  76. if( is_float<eT>::value) { typedef float T; arma_fortran(arma_ssyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC); }
  77. else if(is_double<eT>::value) { typedef double T; arma_fortran(arma_dsyrk)(uplo, transA, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)beta, (T*)C, ldC); }
  78. }
  79. #endif
  80. }
  81. template<typename T>
  82. inline
  83. void
  84. herk(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const T* alpha, const std::complex<T>* A, const blas_int* ldA, const T* beta, std::complex<T>* C, const blas_int* ldC)
  85. {
  86. arma_type_check((is_supported_blas_type<T>::value == false));
  87. #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS)
  88. {
  89. if( is_float<T>::value) { typedef float TT; typedef blas_cxf cx_TT; arma_fortran(arma_cherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC, 1, 1); }
  90. else if(is_double<T>::value) { typedef double TT; typedef blas_cxd cx_TT; arma_fortran(arma_zherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC, 1, 1); }
  91. }
  92. #else
  93. {
  94. if( is_float<T>::value) { typedef float TT; typedef blas_cxf cx_TT; arma_fortran(arma_cherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC); }
  95. else if(is_double<T>::value) { typedef double TT; typedef blas_cxd cx_TT; arma_fortran(arma_zherk)(uplo, transA, n, k, (const TT*)alpha, (const cx_TT*)A, ldA, (const TT*)beta, (cx_TT*)C, ldC); }
  96. }
  97. #endif
  98. }
  99. template<typename eT>
  100. inline
  101. eT
  102. dot(const uword n_elem, const eT* x, const eT* y)
  103. {
  104. arma_type_check((is_supported_blas_type<eT>::value == false));
  105. if(is_float<eT>::value)
  106. {
  107. #if defined(ARMA_BLAS_SDOT_BUG)
  108. {
  109. if(n_elem == 0) { return eT(0); }
  110. const char trans = 'T';
  111. const blas_int m = blas_int(n_elem);
  112. const blas_int n = 1;
  113. const blas_int inc = 1;
  114. const eT alpha = eT(1);
  115. const eT beta = eT(0);
  116. eT result[2]; // paranoia: using two elements instead of one
  117. blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
  118. return result[0];
  119. }
  120. #else
  121. {
  122. blas_int n = blas_int(n_elem);
  123. blas_int inc = 1;
  124. typedef float T;
  125. return eT( arma_fortran(arma_sdot)(&n, (const T*)x, &inc, (const T*)y, &inc) );
  126. }
  127. #endif
  128. }
  129. else
  130. if(is_double<eT>::value)
  131. {
  132. blas_int n = blas_int(n_elem);
  133. blas_int inc = 1;
  134. typedef double T;
  135. return eT( arma_fortran(arma_ddot)(&n, (const T*)x, &inc, (const T*)y, &inc) );
  136. }
  137. else
  138. if( (is_cx_float<eT>::value) || (is_cx_double<eT>::value) )
  139. {
  140. if(n_elem == 0) { return eT(0); }
  141. // using gemv() workaround due to compatibility issues with cdotu() and zdotu()
  142. const char trans = 'T';
  143. const blas_int m = blas_int(n_elem);
  144. const blas_int n = 1;
  145. const blas_int inc = 1;
  146. const eT alpha = eT(1);
  147. const eT beta = eT(0);
  148. eT result[2]; // paranoia: using two elements instead of one
  149. blas::gemv(&trans, &m, &n, &alpha, x, &m, y, &inc, &beta, &result[0], &inc);
  150. return result[0];
  151. }
  152. return eT(0);
  153. }
  154. template<typename eT>
  155. arma_inline
  156. eT
  157. asum(const uword n_elem, const eT* x)
  158. {
  159. arma_type_check((is_supported_blas_type<eT>::value == false));
  160. if(is_float<eT>::value)
  161. {
  162. blas_int n = blas_int(n_elem);
  163. blas_int inc = 1;
  164. typedef float T;
  165. return arma_fortran(arma_sasum)(&n, (const T*)x, &inc);
  166. }
  167. else
  168. if(is_double<eT>::value)
  169. {
  170. blas_int n = blas_int(n_elem);
  171. blas_int inc = 1;
  172. typedef double T;
  173. return arma_fortran(arma_dasum)(&n, (const T*)x, &inc);
  174. }
  175. return eT(0);
  176. }
  177. template<typename eT>
  178. arma_inline
  179. eT
  180. nrm2(const uword n_elem, const eT* x)
  181. {
  182. arma_type_check((is_supported_blas_type<eT>::value == false));
  183. if(is_float<eT>::value)
  184. {
  185. blas_int n = blas_int(n_elem);
  186. blas_int inc = 1;
  187. typedef float T;
  188. return arma_fortran(arma_snrm2)(&n, (const T*)x, &inc);
  189. }
  190. else
  191. if(is_double<eT>::value)
  192. {
  193. blas_int n = blas_int(n_elem);
  194. blas_int inc = 1;
  195. typedef double T;
  196. return arma_fortran(arma_dnrm2)(&n, (const T*)x, &inc);
  197. }
  198. return eT(0);
  199. }
  200. } // namespace blas
  201. #endif