File: segment_coo_cpu.h

pytorch-scatter 2.1.2-4
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (11 lines) | stat: -rw-r--r-- 434 bytes
#pragma once

#include "../extensions.h"

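// Reduces `src` along the last dimension of the sorted `index` tensor using
// `reduce` (e.g. "sum", "mean", "min", "max"). `dim_size` sets the output
// size when `optional_out` is not given. The second returned tensor holds
// the argmin/argmax positions for "min"/"max" and is empty otherwise.
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size, std::string reduce);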

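// Gathers values from `src` along the last dimension of the sorted `index`
// tensor: output position e receives the `src` entry selected by index[e].
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
                             torch::optional<torch::Tensor> optional_out);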
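To make the semantics of these two declarations concrete, here is a minimal, dependency-free 1-D sketch. It is not the pytorch-scatter implementation: the names `segment_coo_ref` and `gather_coo_ref` are hypothetical, the sketch works on `std::vector<double>` instead of `torch::Tensor`, supports only "sum", "mean", "min" and "max", omits the arg tensor, and leaves empty segments at 0. It assumes, as the COO layout requires, that `index` is sorted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical 1-D reference for segment_coo: out[i] reduces every src[e]
// with index[e] == i. `index` must be sorted and within [0, dim_size).
std::vector<double> segment_coo_ref(const std::vector<double>& src,
                                    const std::vector<int64_t>& index,
                                    int64_t dim_size,
                                    const std::string& reduce) {
  if (src.size() != index.size())
    throw std::invalid_argument("src and index must have the same length");
  if (reduce != "sum" && reduce != "mean" && reduce != "min" && reduce != "max")
    throw std::invalid_argument("unsupported reduce: " + reduce);
  // COO format: entries belonging to the same segment are contiguous.
  assert(std::is_sorted(index.begin(), index.end()) && "index must be sorted");

  std::vector<double> out(dim_size, 0.0);   // empty segments stay 0
  std::vector<int64_t> count(dim_size, 0);  // elements seen per segment
  for (size_t e = 0; e < src.size(); ++e) {
    const int64_t i = index[e];
    if (i < 0 || i >= dim_size) throw std::out_of_range("index out of range");
    if (count[i] == 0) {
      out[i] = src[e];  // first element initialises the segment
    } else if (reduce == "sum" || reduce == "mean") {
      out[i] += src[e];
    } else if (reduce == "min") {
      out[i] = std::min(out[i], src[e]);
    } else {  // "max"
      out[i] = std::max(out[i], src[e]);
    }
    ++count[i];
  }
  if (reduce == "mean")
    for (int64_t i = 0; i < dim_size; ++i)
      if (count[i] > 0) out[i] /= static_cast<double>(count[i]);
  return out;
}

// Hypothetical 1-D reference for gather_coo: out[e] = src[index[e]].
// std::vector::at throws std::out_of_range for invalid indices.
std::vector<double> gather_coo_ref(const std::vector<double>& src,
                                   const std::vector<int64_t>& index) {
  std::vector<double> out(index.size());
  for (size_t e = 0; e < index.size(); ++e)
    out[e] = src.at(static_cast<size_t>(index[e]));
  return out;
}
```

For example, with `src = {1, 2, 3, 4, 5}`, `index = {0, 0, 1, 1, 3}` and `dim_size = 4`, the "sum" reduction yields `{3, 7, 0, 5}`, and gathering that result with the same `index` yields `{3, 3, 7, 7, 5}`, broadcasting each segment's value back to its members.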