1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
|
#pragma once
#include "extensions.h"
namespace scatter {
/// Reports the CUDA version associated with this library build.
/// Declared here and defined elsewhere (the definition is not in this header).
/// NOTE(review): the encoding of the returned integer (and what is returned
/// in a CPU-only build) is not established by this header — confirm against
/// the definition.
SCATTER_API int64_t cuda_version() noexcept;
namespace detail {
// Namespace-scope variable initialized by *calling* cuda_version(). The call
// runs as part of program / shared-library initialization for any binary that
// includes this header, not lazily on first use of the operators below.
// Presumably this is intended to trigger the library's CUDA-version handling
// at load time — TODO confirm intent.
// SCATTER_INLINE_VARIABLE comes from extensions.h; presumably it expands to
// `inline` (or a compiler-specific equivalent) so that this header can be
// included from several translation units without multiple-definition
// errors — verify.
SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace detail
} // namespace scatter
// ---------------------------------------------------------------------------
// Scatter reductions (declarations only; implemented elsewhere).
//
// Common parameters:
//   src          - input tensor whose entries are scattered.
//   index        - tensor of target positions.
//   dim          - the dimension along which to scatter.
//   optional_out - optional caller-supplied output tensor.
//   dim_size     - optional size of the output along `dim`.
//
// The names suggest sum / product / mean / min / max reductions over the
// `src` entries that map to the same index. That is inferred from the names,
// not shown here, so confirm against the definitions.
// NOTE(review): presumably the output size along `dim` is derived from
// `index` when neither `optional_out` nor `dim_size` is given — TODO confirm.
//
// All parameters are taken by value. These functions are exported
// (SCATTER_API), so changing a parameter type changes the linked symbol and
// breaks callers built against the existing signature.
// ---------------------------------------------------------------------------
SCATTER_API torch::Tensor
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// Returns two tensors. Presumably (reduced values, index of the selected
// source element) — confirm against the definition.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// Returns two tensors. Presumably (reduced values, index of the selected
// source element) — confirm against the definition.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// ---------------------------------------------------------------------------
// Segment reductions addressed by a per-element `index` tensor ("COO" form),
// plus the matching gather. Declarations only; implemented elsewhere.
//
// Unlike the scatter_* functions, these take no `dim` parameter. Which
// dimension is reduced is decided by the implementation; presumably it is the
// dimension matching the last dimension of `index` — TODO confirm.
//   src          - input tensor.
//   index        - per-element segment ids. NOTE(review): presumably these
//                  must be sorted — verify against the implementation.
//   optional_out - optional caller-supplied output tensor.
//   dim_size     - optional output size along the reduced dimension
//                  (segment_* only).
// ---------------------------------------------------------------------------
SCATTER_API torch::Tensor
segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// Returns two tensors. Presumably (reduced values, arg index) — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// Returns two tensors. Presumably (reduced values, arg index) — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
// Gather counterpart of the COO segment ops. It takes no `dim_size`;
// presumably the output shape follows `index` — TODO confirm.
SCATTER_API torch::Tensor
gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
// ---------------------------------------------------------------------------
// Segment reductions addressed by an `indptr` tensor ("CSR" form), plus the
// matching gather. Declarations only; implemented elsewhere.
//
//   src          - input tensor.
//   indptr       - segment boundary tensor. NOTE(review): presumably
//                  consecutive entries delimit one segment (CSR row-pointer
//                  convention) — confirm against the implementation.
//   optional_out - optional caller-supplied output tensor.
//
// There is no `dim_size` parameter here. The output size is determined by the
// implementation; presumably from `indptr` — TODO confirm.
// ---------------------------------------------------------------------------
SCATTER_API torch::Tensor
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
SCATTER_API torch::Tensor
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
// Returns two tensors. Presumably (reduced values, arg index) — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
// Returns two tensors. Presumably (reduced values, arg index) — confirm.
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
// Gather counterpart of the CSR segment ops.
SCATTER_API torch::Tensor
gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
|