File: cuda_dlink_extension.cpp

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (21 lines) | stat: -rw-r--r-- 748 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <limits>

#include <torch/extension.h>

// Declare the function from cuda_dlink_extension.cu.
// Only the prototype is visible in this translation unit; the definition lives
// in the CUDA source file and is resolved when the extension is linked.
// NOTE(review): presumably `a`, `b` and `output` are device pointers to at
// least `size` floats and the function writes output[i] = a[i] + b[i] for
// i in [0, size) — confirm against the .cu implementation.
void add_cuda(const float* a, const float* b, float* output, int size);

// Element-wise sum of two float32 CUDA tensors, computed by the external
// add_cuda launcher (defined in cuda_dlink_extension.cu).
//
// @param a  float32 tensor on a CUDA device.
// @param b  float32 tensor with the same sizes as `a`, on the same device.
// @return   a new contiguous tensor on `a`'s device holding a + b.
// @throws   c10::Error (RuntimeError in Python) if any input check fails.
at::Tensor add(at::Tensor a, at::Tensor b) {
  // TORCH_CHECK reports its message when the condition is FALSE, so each
  // message states the requirement that was violated.
  TORCH_CHECK(a.device().is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.device().is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a.device() == b.device(),
              "a and b must be on the same CUDA device");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b should have same size");

  // add_cuda receives bare pointers plus an element count, so it walks each
  // buffer as a flat array. Force contiguous layouts so that flat element i
  // of `a`, `b` and `output` all refer to the same logical index, even for
  // transposed or sliced views.
  const at::Tensor a_contig = a.contiguous();
  const at::Tensor b_contig = b.contiguous();
  at::Tensor output = at::empty_like(a_contig, at::MemoryFormat::Contiguous);

  const int64_t numel = a_contig.numel();
  if (numel == 0) {
    // Nothing to compute; skip launching device work for empty tensors.
    return output;
  }
  // add_cuda takes the element count as `int`; reject sizes that the
  // narrowing conversion would silently truncate.
  TORCH_CHECK(numel <= std::numeric_limits<int>::max(),
              "add: tensors with more than ", std::numeric_limits<int>::max(),
              " elements are not supported (got ", numel, ")");

  // Run on the inputs' device rather than whichever device is current.
  const c10::DeviceGuard device_guard(a_contig.device());
  add_cuda(a_contig.data_ptr<float>(), b_contig.data_ptr<float>(),
           output.data_ptr<float>(), static_cast<int>(numel));

  return output;
}

// Python binding table for this extension: publishes `add` under the Python
// name "add", with "a + b" as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  ext.def(/*name=*/"add", /*f=*/&add, /*doc=*/"a + b");
}