1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
// SPDX-FileCopyrightInfo: Copyright © DUNE Project contributors, see file LICENSE.md in module root
// SPDX-License-Identifier: LicenseRef-GPL-2.0-only-with-DUNE-exception
#ifndef DUNE_PYTHON_COMMON_NUMPYVECTOR_HH
#define DUNE_PYTHON_COMMON_NUMPYVECTOR_HH
#include <dune/common/exceptions.hh>
#include <dune/common/densevector.hh>
#include <dune/common/ftraits.hh>
#include <dune/python/pybind11/numpy.h>
#include <dune/python/pybind11/pybind11.h>
#include <dune/python/pybind11/stl.h>
namespace Dune
{
namespace Python
{
// Internal Forward Declarations
// -----------------------------
// NumPyVector is only declared here, so that the DenseMatVecTraits and
// FieldTraits specializations (outside namespace Python) can name it
// before the class itself is defined.
template< class T >
class NumPyVector;
} // namespace Python
// DenseMatVecTraits for NumPyVector
// ---------------------------------
template< class T >
struct DenseMatVecTraits< Python::NumPyVector< T > >
{
  // Type information required by the DenseVector< NumPyVector< T > > base class:
  // the storage behind a NumPyVector is a pybind11 NumPy array of T.
  using derived_type   = Python::NumPyVector< T >;
  using container_type = pybind11::array_t< T >;
  using value_type     = T;
  using size_type      = std::size_t;
};
// FieldTraits for NumPyVector
// ---------------------------
template< class T >
struct FieldTraits< Python::NumPyVector< T > >
{
  // A NumPyVector< T > inherits the field and real types of its entry type T.
  using field_type = typename FieldTraits< T >::field_type;
  using real_type  = typename FieldTraits< T >::real_type;
};
namespace Python
{
template< class T >
class NumPyVector
  : public DenseVector< NumPyVector< T > >
{
  typedef NumPyVector< T > This;
  typedef DenseVector< NumPyVector< T > > Base;

public:
  typedef typename Base::size_type size_type;
  typedef typename Base::value_type value_type;

  /**
   * \brief create a new one-dimensional NumPy array with the given number of entries
   *
   * The array is created from a buffer_info without a source pointer, so NumPy
   * allocates the storage; treat the entries as uninitialized.
   *
   * \param[in] size  number of entries
   */
  explicit NumPyVector ( size_type size )
    : array_( pybind11::buffer_info( nullptr, sizeof( T ),
                                     // format() (unlike ::value) is also provided for structured dtypes
                                     pybind11::format_descriptor< T >::format(), 1, { size }, { sizeof( T ) } )
            ),
      dataPtr_( static_cast< value_type * >( array_.request(true).ptr ) ),
      size_(size)
  {}

  /**
   * \brief wrap an existing one-dimensional buffer
   *
   * The buffer is converted to pybind11::array_t< T >. If no conversion is needed,
   * the vector shares memory with the given object, so writes are visible from Python.
   *
   * \param[in] buf  one-dimensional, unit-stride buffer
   *
   * \throws InvalidStateException if the array is not one-dimensional or not contiguous
   * \note requesting a writable buffer fails (pybind11 throws) for read-only arrays
   */
  NumPyVector ( pybind11::buffer buf )
    : array_( buf ),
      dataPtr_( nullptr ),
      size_( 0 )
  {
    // Query array_ rather than buf: if a dtype conversion took place, array_ is a
    // different object, and its layout is the one that data() exposes.
    if( array_.ndim() != 1 )
      DUNE_THROW( InvalidStateException, "NumPyVector can only be created from one-dimensional array" );
    size_ = static_cast< size_type >( array_.shape( 0 ) );
    // Element access below assumes unit stride. array_t< T > (forcecast only) keeps
    // strided views such as a[::2] or a[:,0] as they are, which would otherwise be
    // read and written at the wrong addresses. Arrays with at most one entry are
    // always fine, whatever stride NumPy reports for them.
    if( (size_ > 1) && (array_.strides( 0 ) != array_.itemsize()) )
      DUNE_THROW( InvalidStateException, "NumPyVector can only be created from a contiguous array" );
    dataPtr_ = static_cast< value_type * >( array_.request(true).ptr );
  }

  // A NumPyVector is a view bound to one NumPy array; copying and moving are disabled.
  NumPyVector ( const This &other ) = delete;
  NumPyVector ( This &&other ) = delete;

  ~NumPyVector() = default;

  This &operator= ( const This &other ) = delete;
  This &operator= ( This &&other ) = delete;

  /** \brief access the underlying NumPy array (shares the reference, no data copy) */
  operator pybind11::array_t< T > () const { return array_; }

  const value_type &operator[] ( size_type index ) const
  {
    return data()[ index ];
  }

  value_type &operator[] ( size_type index )
  {
    return data()[ index ];
  }

  /** \brief element access used by the DenseVector base class */
  value_type &vec_access ( size_type index )
  {
    return data()[ index ];
  }

  /** \brief element access used by the DenseVector base class */
  const value_type &vec_access ( size_type index ) const
  {
    return data()[ index ];
  }

  /** \brief pointer to the first entry (entries are stored with unit stride) */
  const value_type *data () const
  {
    assert( dataPtr_ );
    return dataPtr_;
  }

  /** \brief pointer to the first entry (entries are stored with unit stride) */
  value_type *data ()
  {
    assert( dataPtr_ );
    return dataPtr_;
  }

  pybind11::array_t< T > &coefficients()
  {
    return array_;
  }

  // Returns a const reference: binding a non-const reference to a member of
  // *this inside a const member function does not compile when instantiated.
  const pybind11::array_t< T > &coefficients() const
  {
    return array_;
  }

  size_type size () const
  {
    return size_;
  }

  /** \brief number of entries, as required by the DenseVector base class */
  size_type vec_size () const
  {
    return size_;
  }

protected:
  pybind11::array_t< T > array_;  // owning reference to the NumPy array
  value_type* dataPtr_;           // cached pointer to the first entry of array_
  size_type size_;                // number of entries of array_
};
} // namespace Python
} // namespace Dune
#endif // #ifndef DUNE_PYTHON_COMMON_NUMPYVECTOR_HH
|