File: tensor_func.h

package info (click to toggle)
bagel 1.2.2-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 134,940 kB
  • sloc: cpp: 1,236,571; javascript: 15,383; python: 1,461; ansic: 674; makefile: 253; sh: 109
file content (92 lines) | stat: -rw-r--r-- 2,423 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*
 * tensor_func.h
 *
 *  Created on: Dec 30, 2013
 *      Author: evaleev
 */

#ifndef BTAS_TENSOR_FUNC_H_
#define BTAS_TENSOR_FUNC_H_

#include <array>
#include <cstddef>
#include <initializer_list>
#include <type_traits>

#include <btas/tensorview.h>

namespace btas {

  // Helper template for TensorViewOf: strips a reference (if any) from Type
  // while keeping cv-qualifiers, e.g.
  //   Nref<X&> == X,  Nref<const X&> == const X,  Nref<X&&> == X,  Nref<X> == X.
  // This lets TensorViewOf inspect the referred-to type of a forwarding
  // reference parameter (the template parameter avoids the reserved name `_T`).
  template<typename Type>
  using Nref = typename std::remove_reference<Type>::type;

  // Maps Tensor     -> TensorView,
  //      TensorView -> TensorView
  // appropriately transferring constness of the storage, that is,
  // if _T is const, uses const _T::storage_type, otherwise just _T::storage_type
  template<typename _T>
  using TensorViewOf = TensorView<typename Nref<_T>::value_type,
                                  typename Nref<_T>::range_type,
                                  typename std::conditional<std::is_const<Nref<_T>>::value,
                                                            const typename Nref<_T>::storage_type,
                                                            typename Nref<_T>::storage_type
                                                           >::type>;

  // Returns a view of `t` whose indices are permuted by `p`.
  //   t : a Tensor or TensorView, taken by forwarding reference
  //   p : permutation, passed to permute() on t.range()
  // The result is built by make_view from the permuted range and t.storage()
  // (presumably aliasing rather than copying that storage; see make_view).
  // Its storage is const iff the referred-to `t` is const (see TensorViewOf).
  // NOTE(review): if `t` is a temporary owning Tensor, the returned view may
  // refer to storage destroyed at the end of the full-expression; confirm
  // callers only pass lvalues or views.
  template<typename TensorT,
           typename Permutation>
  TensorViewOf<TensorT>
  permute(TensorT&& t,
          Permutation p) {
    return make_view(permute(t.range(), p), t.storage());
  }

  // Brace-list form of permute, e.g. permute(t, {2, 0, 1}).
  // A braced list cannot be deduced as a generic `Permutation`, so this
  // overload accepts it as std::initializer_list<Index> and forwards it to
  // permute() on t.range(). The resulting range is combined with t.storage()
  // via make_view; constness of the view's storage follows TensorViewOf.
  template<typename TensorT,
           typename Index>
  TensorViewOf<TensorT>
  permute(TensorT&& t,
          std::initializer_list<Index> p) {
    return make_view(permute(t.range(), p), t.storage());
  }

  // Returns a view of `t` over the range produced by diag(t.range()),
  // paired with t.storage() via make_view.
  // Constness of the view's storage follows TensorViewOf<TensorT>.
  template <typename TensorT>
  TensorViewOf<TensorT>
  diag(TensorT&& t)
  {
    return make_view(diag(t.range()), t.storage());
  }

  // Returns a view of `t` with the indices listed in `inds` tied together.
  //   inds : container of index positions, forwarded unchanged to
  //          tieIndex() on t.range()
  // The resulting range is paired with t.storage() via make_view; constness of
  // the view's storage follows TensorViewOf<TensorT>.
  template <typename TensorT, typename ArrayType>
  TensorViewOf<TensorT>
  tieIndex(TensorT&& t,
           const ArrayType& inds)
  {
    return make_view(tieIndex(t.range(), inds), t.storage());
  }

  // Variadic form of tieIndex: the index positions are given as separate
  // arguments, e.g. tieIndex(t, 0, 2).
  // Each argument is converted to size_t and packed into a
  // std::array<size_t, 1 + sizeof...(rest)>, which is passed to tieIndex() on
  // t.range(). The resulting range is paired with t.storage() via make_view.
  // NOTE(review): a call with a single non-size_t index, e.g. tieIndex(t, 0),
  // selects the ArrayType overload instead (exact match beats the
  // int -> size_t conversion needed here).
  template <typename TensorT, typename... Indices>
  TensorViewOf<TensorT>
  tieIndex(TensorT&& t,
           size_t i0,
           const Indices&... rest)
  {
    // The pack size is a compile-time constant; constexpr states that it is
    // used as the array extent.
    constexpr size_t size = 1 + sizeof...(rest);
    std::array<size_t, size> inds = { i0, static_cast<size_t>(rest)... };
    return make_view(tieIndex(t.range(), inds), t.storage());
  }

  // Returns a view of `t` over the range produced by
  // group(t.range(), istart, iend).
  // The exact meaning of [istart, iend] (inclusive/exclusive end) is defined
  // by the range-level group(); TODO confirm against that overload.
  // The range is paired with t.storage() via make_view; constness of the
  // view's storage follows TensorViewOf<TensorT>.
  template <typename TensorT>
  TensorViewOf<TensorT>
  group(TensorT&& t,
        size_t istart,
        size_t iend)
  {
    return make_view(group(t.range(), istart, iend), t.storage());
  }

  // Returns a view of `t` over the range produced by flatten(t.range()),
  // paired with t.storage() via make_view.
  // Constness of the view's storage follows TensorViewOf<TensorT>.
  template <typename TensorT>
  TensorViewOf<TensorT>
  flatten(TensorT&& t)
  {
    return make_view(flatten(t.range()), t.storage());
  }

} // namespace btas


#endif /* BTAS_TENSOR_FUNC_H_ */