1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
/*
* tensor_func.h
*
* Created on: Dec 30, 2013
* Author: evaleev
*/
#ifndef BTAS_TENSOR_FUNC_H_
#define BTAS_TENSOR_FUNC_H_
#include <btas/tensorview.h>
namespace btas {
// Nref<X>: the type X with any top-level reference (& or &&) removed;
// cv-qualifiers are kept. Used by TensorViewOf to inspect forwarded types.
template <typename Type> using Nref = typename std::remove_reference<Type>::type;
// Maps Tensor -> TensorView,
// TensorView -> TensorView
// appropriately transferring constness of the storage, that is,
// if _T is const, uses const _T::storage_type, otherwise just _T::storage_type
template<typename _T>
using TensorViewOf = TensorView<typename Nref<_T>::value_type,
typename Nref<_T>::range_type,
typename std::conditional<std::is_const<Nref<_T>>::value,
const typename Nref<_T>::storage_type,
typename Nref<_T>::storage_type
>::type>;
template<typename _T,
typename _Permutation>
TensorViewOf<_T>
permute( _T&& t,
_Permutation p) {
return make_view( permute(t.range(), p), t.storage() );
}
template<typename _T,
typename _U>
TensorViewOf<_T>
permute( _T&& t,
std::initializer_list<_U> p) {
return make_view( permute(t.range(), p), t.storage() );
}
template <typename _T>
TensorViewOf<_T>
diag(_T&& T)
{
return make_view(diag(T.range()),T.storage());
}
template <typename _T, typename ArrayType>
TensorViewOf<_T>
tieIndex(_T&& T,
const ArrayType& inds)
{
return make_view(tieIndex(T.range(),inds),T.storage());
}
template <typename _T, typename... _args>
TensorViewOf<_T>
tieIndex(_T&& T,
size_t i0,
const _args&... rest)
{
const auto size = 1 + sizeof...(rest);
std::array<size_t,size> inds = { i0, static_cast<size_t>(rest)...};
return make_view(tieIndex(T.range(),inds),T.storage());
}
template <typename _T>
TensorViewOf<_T>
group(_T&& T,
size_t istart,
size_t iend)
{
return make_view(group(T.range(),istart,iend),T.storage());
}
template <typename _T>
TensorViewOf<_T>
flatten(_T&& T)
{
return make_view(flatten(T.range()),T.storage());
}
} // namespace btas
#endif /* BTAS_TENSOR_FUNC_H_ */
|