File: test_nanobind_adapter.cpp

package info (click to toggle)
gridtools 2.3.9-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 29,480 kB
  • sloc: cpp: 228,792; python: 17,561; javascript: 9,164; ansic: 4,101; sh: 850; makefile: 231; f90: 201
file content (88 lines) | stat: -rw-r--r-- 3,328 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
 * GridTools
 *
 * Copyright (c) 2014-2023, ETH Zurich
 * All rights reserved.
 *
 * Please, refer to the LICENSE file in the root directory.
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include <gridtools/storage/adapter/nanobind_adapter.hpp>

#include <Python.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <type_traits>

#include <gridtools/common/integral_constant.hpp>
#include <gridtools/sid/concept.hpp>

#include <gtest/gtest.h>

namespace nb = nanobind;

class python_init_fixture : public ::testing::Test {
  protected:
    void SetUp() override { Py_Initialize(); }
    void TearDown() override { Py_FinalizeEx(); }
};

// With fully dynamic strides, the SID built from an ndarray must use the
// ndarray's data pointer as its origin and forward both strides unchanged.
TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) {
    // Dummy address: this test only compares it, never dereferences it.
    const auto data = reinterpret_cast<void *>(0xDEADBEEF);
    constexpr int ndim = 2;
    constexpr std::array<std::size_t, ndim> shape = {3, 4};
    // nb::ndarray takes its strides as `const int64_t *`. std::intptr_t is not
    // the same type as std::int64_t on every platform (e.g. `long` vs
    // `long long` on macOS), so spell the element type exactly.
    constexpr std::array<std::int64_t, ndim> strides = {1, 3};
    nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

    const auto sid = gridtools::nanobind::as_sid(ndarray);
    const auto s_origin = sid_get_origin(sid);
    const auto s_strides = sid_get_strides(sid);
    const auto s_ptr = s_origin();

    EXPECT_EQ(s_ptr, data);
    EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
    EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

// A read-only ndarray (nb::ro) must produce a SID whose element type is
// const-qualified, while origin and strides are forwarded as in the mutable case.
TEST_F(python_init_fixture, NanobindAdapterReadOnly) {
    // Dummy address: this test only compares it, never dereferences it.
    const auto data = reinterpret_cast<void *>(0xDEADBEEF);
    constexpr int ndim = 2;
    constexpr std::array<std::size_t, ndim> shape = {3, 4};
    // Must match nb::ndarray's `const int64_t *` strides parameter exactly;
    // std::intptr_t differs from std::int64_t on some platforms (e.g. macOS).
    constexpr std::array<std::int64_t, ndim> strides = {1, 3};
    nb::ndarray<int, nb::shape<-1, -1>, nb::ro> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

    const auto sid = gridtools::nanobind::as_sid(ndarray);
    using element_t = gridtools::sid::element_type<decltype(sid)>;
    static_assert(std::is_same_v<element_t, int const>);

    const auto s_origin = sid_get_origin(sid);
    const auto s_strides = sid_get_strides(sid);
    const auto s_ptr = s_origin();

    EXPECT_EQ(s_ptr, data);
    EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
    EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

// stride_spec<1, -1> requests a compile-time stride of 1 for the first
// dimension and leaves the second one dynamic. Since the ndarray's first stride
// really is 1, as_sid must succeed and expose that stride as a constant.
TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) {
    // Dummy address: this test only uses it to build the ndarray.
    const auto data = reinterpret_cast<void *>(0xDEADBEEF);
    constexpr int ndim = 2;
    constexpr std::array<std::size_t, ndim> shape = {3, 4};
    // Must match nb::ndarray's `const int64_t *` strides parameter exactly;
    // std::intptr_t differs from std::int64_t on some platforms (e.g. macOS).
    constexpr std::array<std::int64_t, ndim> strides = {1, 3};
    nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

    const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{});
    const auto s_strides = sid_get_strides(sid);

    // The requested static stride must be carried by the type itself, not just
    // stored at runtime.
    using static_stride_t = std::decay_t<decltype(gridtools::get<0>(s_strides))>;
    static_assert(static_stride_t::value == 1);

    EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
    EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

// Requesting a compile-time stride (2) that disagrees with the ndarray's actual
// first stride (1) must be rejected by as_sid with std::invalid_argument.
TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) {
    // Dummy address: this test only uses it to build the ndarray.
    const auto data = reinterpret_cast<void *>(0xDEADBEEF);
    constexpr int ndim = 2;
    constexpr std::array<std::size_t, ndim> shape = {3, 4};
    // Must match nb::ndarray's `const int64_t *` strides parameter exactly;
    // std::intptr_t differs from std::int64_t on some platforms (e.g. macOS).
    constexpr std::array<std::int64_t, ndim> strides = {1, 3};
    nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

    EXPECT_THROW(
        gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument);
}