File: numpyvector.hh

package info (click to toggle)
dune-common 2.10.0-6
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 5,824 kB
  • sloc: cpp: 52,256; python: 3,979; sh: 1,658; makefile: 17
file content (155 lines) | stat: -rw-r--r-- 3,822 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// SPDX-FileCopyrightInfo: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#ifndef DUNE_PYTHON_COMMON_NUMPYVECTOR_HH
#define DUNE_PYTHON_COMMON_NUMPYVECTOR_HH

#include <cassert>
#include <cstddef>

#include <dune/common/exceptions.hh>
#include <dune/common/densevector.hh>
#include <dune/common/ftraits.hh>

#include <dune/python/pybind11/numpy.h>
#include <dune/python/pybind11/pybind11.h>
#include <dune/python/pybind11/stl.h>

namespace Dune
{

  namespace Python
  {

    // Early declaration of NumPyVector
    // --------------------------------
    //
    // The traits specializations below need to name NumPyVector before its
    // full definition appears, so only the class template is introduced here.

    template< typename T >
    class NumPyVector;

  } // namespace Python



  // DenseMatVecTraits for NumPyVector
  // ---------------------------------

  /**
   * \brief Type information consumed by DenseVector for NumPyVector.
   *
   * The storage is a pybind11 array of T, elements are of type T, and
   * sizes/indices are expressed as std::size_t.
   */
  template< typename T >
  struct DenseMatVecTraits< Python::NumPyVector< T > >
  {
    using derived_type = Python::NumPyVector< T >;
    using container_type = pybind11::array_t< T >;
    using value_type = T;
    using size_type = std::size_t;
  };



  // FieldTraits for NumPyVector
  // ---------------------------

  /**
   * \brief Field and real types of a NumPyVector are forwarded from the
   *        FieldTraits of its element type T.
   */
  template< typename T >
  struct FieldTraits< Python::NumPyVector< T > >
  {
    using field_type = typename FieldTraits< T >::field_type;
    using real_type = typename FieldTraits< T >::real_type;
  };


  namespace Python
  {

    /**
     * \brief DenseVector interface on top of a one-dimensional pybind11 array.
     *
     * The vector keeps the pybind11::array_t<T> handle together with a raw
     * pointer to its first element and its length. All element access goes
     * through that raw pointer with unit stride, so the wrapped data must be
     * stored contiguously (stride == sizeof(T)).
     *
     * Objects of this class are neither copyable nor movable.
     *
     * \tparam T  element type of the array
     */
    template< class T >
    class NumPyVector
      : public DenseVector< NumPyVector< T > >
    {
      typedef NumPyVector< T > This;
      typedef DenseVector< NumPyVector< T > > Base;

    public:
      typedef typename Base::size_type size_type;
      typedef typename Base::value_type value_type;

      /**
       * \brief Create a new one-dimensional array holding \p size elements.
       *
       * The array is created from a buffer_info with a null data pointer and
       * unit stride. No initial value is written by this constructor.
       */
      explicit NumPyVector ( size_type size )
        : array_( pybind11::buffer_info( nullptr, sizeof( T ),
                  pybind11::format_descriptor< T >::value, 1, { size }, { sizeof( T ) } )
                ),
          dataPtr_( static_cast< value_type * >( array_.request(true).ptr ) ),
          size_(size)
      {}

      /**
       * \brief Wrap an existing one-dimensional buffer.
       *
       * \param buf  Python buffer; it is converted to pybind11::array_t<T>.
       *
       * \throws InvalidStateException if \p buf is not one-dimensional, or if
       *         the converted array does not store its elements contiguously
       *         (e.g. a strided view such as a[::2] or a reversed view a[::-1]).
       *         Element access assumes unit stride, so accepting such a view
       *         would silently address the wrong elements.
       *
       * \note NOTE(review): if \p buf does not already hold elements of type T,
       *       pybind11's forcecast conversion presumably produces a converted
       *       copy, so writes through this vector would not be visible in
       *       \p buf. Confirm against callers.
       */
      NumPyVector ( pybind11::buffer buf )
        : array_( buf ),
          dataPtr_( nullptr ),
          size_( 0 )
      {
        pybind11::buffer_info info = buf.request();
        if (info.ndim != 1)
          DUNE_THROW( InvalidStateException, "NumPyVector can only be created from one-dimensional array" );

        // Inspect the converted array rather than the original buffer: its
        // pointer and stride are what data() actually indexes.
        pybind11::buffer_info arrayInfo = array_.request( true );

        // data()[ i ] assumes unit stride. For fewer than two elements the
        // stride is never used, so only longer arrays are checked.
        if( (arrayInfo.shape[ 0 ] > 1)
            && (arrayInfo.strides[ 0 ] != static_cast< pybind11::ssize_t >( sizeof( T ) )) )
          DUNE_THROW( InvalidStateException, "NumPyVector can only be created from contiguous array" );

        size_ = arrayInfo.shape[ 0 ];
        dataPtr_ = static_cast< value_type * >( arrayInfo.ptr );
      }

      NumPyVector ( const This &other ) = delete;
      NumPyVector ( This &&other ) = delete;

      ~NumPyVector() = default;

      This &operator= ( const This &other ) = delete;
      This &operator= ( This &&other ) = delete;

      /** \brief Return a handle to the wrapped array (a pybind11 handle copy, not a data copy). */
      operator pybind11::array_t< T > () const { return array_; }

      /** \brief Unchecked element access (read-only). */
      const value_type &operator[] ( size_type index ) const
      {
        return data()[ index ];
      }
      /** \brief Unchecked element access. */
      value_type &operator[] ( size_type index )
      {
        return data()[ index ];
      }
      /** \brief Element access hook used by DenseVector. */
      value_type &vec_access ( size_type index )
      {
        return data()[ index ];
      }
      /** \brief Element access hook used by DenseVector (read-only). */
      const value_type &vec_access ( size_type index ) const
      {
        return data()[ index ];
      }

      /** \brief Raw pointer to the first element; asserts that it has been set. */
      inline const value_type *data () const
      {
        assert( dataPtr_ );
        return dataPtr_;
      }
      /** \brief Raw pointer to the first element; asserts that it has been set. */
      inline value_type *data ()
      {
        assert( dataPtr_ );
        return dataPtr_;
      }

      /** \brief Access the wrapped pybind11 array. */
      pybind11::array_t< T > &coefficients()
      {
        return array_;
      }
      /**
       * \brief Access the wrapped pybind11 array (read-only).
       *
       * Returns a const reference. The previous non-const reference could not
       * bind to a member of a const object, so any instantiation of this
       * overload failed to compile.
       */
      const pybind11::array_t< T > &coefficients() const
      {
        return array_;
      }

      /** \brief Number of elements. */
      size_type size () const
      {
        return size_;
      }
      /** \brief Number of elements (hook used by DenseVector). */
      size_type vec_size () const
      {
        return size_;
      }

    protected:
      pybind11::array_t< T > array_;  // owning handle keeping the Python array alive
      value_type* dataPtr_;           // first element of array_, indexed with unit stride
      size_type size_;                // number of elements in array_
    };

  } // namespace Python

} // namespace Dune

#endif // #ifndef DUNE_PYTHON_COMMON_NUMPYVECTOR_HH