File: setup.py

package info (click to toggle)
halide 14.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 49,124 kB
  • sloc: cpp: 238,722; makefile: 4,303; python: 4,047; java: 1,575; sh: 1,384; pascal: 211; xml: 165; javascript: 43; ansic: 34
file content (102 lines) | stat: -rw-r--r-- 3,931 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Synthesizes the cpp wrapper code and builds dynamic Python extension."""
import os
import platform
import re
from setuptools import setup, find_packages

from torch.utils.cpp_extension import BuildExtension
import torch as th


def generate_pybind_wrapper(path, headers, has_cuda):
    """Write a pybind11 C++ source that exposes every Halide op to PyTorch.

    The generated file includes "torch/extension.h", the Halide PyTorch
    helper header (preceded by the CUDA helper header when ``has_cuda`` is
    true) and, for each header ``<op>.h`` in ``headers``, the matching
    ``<op>.pytorch.h`` wrapper. It then opens a ``PYBIND11_MODULE`` block
    that binds each op under the Python name ``<op>`` to the C++ function
    ``<op>_th_``.

    Args:
        path: Destination of the generated .cpp file; it is overwritten.
        headers: Halide op header file names. The op name is the file name
            with its last extension removed (``os.path.splitext``).
        has_cuda: When true, also include "HalidePyTorchCudaHelpers.h".
    """
    op_names = [os.path.splitext(header)[0] for header in headers]
    bind_fmt = ('  m.def("{0}", &{0}_th_, '
                '"PyTorch wrapper of the Halide pipeline {0}");')

    lines = ['#include "torch/extension.h"', '']
    if has_cuda:
        lines.append('#include "HalidePyTorchCudaHelpers.h"')
    lines.append('#include "HalidePyTorchHelpers.h"')
    lines.extend('#include "{}.pytorch.h"'.format(op) for op in op_names)

    # Blank line, then the module definition with one binding per op.
    lines.append('')
    lines.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
    lines.extend(bind_fmt.format(op) for op in op_names)
    lines.append('}')

    with open(path, 'w') as out:
        out.write("\n".join(lines) + "\n")

if __name__ == "__main__":
    # Directory holding the Halide-generated op headers (<op>.pytorch.h,
    # <op>.h) and static libraries (<op>.a). The .cpp pybind wrapper is also
    # generated here. isdir (not just exists) because it is passed to
    # os.listdir below.
    build_dir = os.getenv("BIN")
    if build_dir is None or not os.path.isdir(build_dir):
        raise ValueError("Bin directory {} is invalid".format(build_dir))

    # Path to a distribution of Halide; its include/ directory is added to
    # the compiler include path.
    halide_dir = os.getenv("HALIDE_DISTRIB_PATH")
    if halide_dir is None or not os.path.isdir(halide_dir):
        raise ValueError("Halide directory {} is invalid".format(halide_dir))

    # Build the CUDA flavour unless HAS_CUDA is unset, empty or "0". An empty
    # value (e.g. a Makefile variable that expanded to nothing) counts as
    # "no CUDA".
    has_cuda = os.getenv("HAS_CUDA", "") not in ("", "0")

    include_dirs = [build_dir, os.path.join(halide_dir, "include")]
    # Note that recent versions of PyTorch (at least 1.7.1) require C++14
    # in order to compile extensions.
    compile_args = ["-std=c++14", "-g"]
    if platform.system() == "Darwin":  # on macOS libstdc++ causes trouble
        compile_args += ["-stdlib=libc++"]

    # Discover the generated ops through their PyTorch wrapper headers.
    # An exact suffix test rejects near-misses such as editor backups
    # ("<op>.pytorch.h~"). Sorting makes the generated wrapper and the link
    # order independent of the arbitrary os.listdir order.
    wrapper_suffix = ".pytorch.h"
    hl_srcs = sorted(
        f for f in os.listdir(build_dir) if f.endswith(wrapper_suffix))

    ext_name = "halide_ops"
    hl_libs = []  # Halide op libraries to link to
    hl_headers = []  # Halide op headers to include in the wrapper
    for f in hl_srcs:
        # Take the op name from the file name only. Splitting the full path
        # on "." would cut it at any dot inside build_dir (e.g. "./bin").
        op_name = f.split(".")[0]

        # Add all Halide-generated libraries
        hl_libs.append(os.path.join(build_dir, op_name + ".a"))

        # The wrapper #includes headers by bare name (build_dir is on the
        # include path).
        hl_headers.append(op_name + ".h")

    # C++ wrapper code that includes every op header so that all the Halide
    # ops end up in a single Python extension.
    wrapper_path = os.path.join(build_dir, "pybind_wrapper.cpp")
    sources = [wrapper_path]

    print("Generating CUDA wrapper" if has_cuda else "Generating CPU wrapper")
    generate_pybind_wrapper(wrapper_path, hl_headers, has_cuda)

    if has_cuda:
        from torch.utils.cpp_extension import CUDAExtension
        extension = CUDAExtension(ext_name, sources,
                                  include_dirs=include_dirs,
                                  extra_objects=hl_libs,
                                  libraries=["cuda"],  # Halide ops need the full cuda lib, not just the RT library
                                  extra_compile_args=compile_args)
    else:
        from torch.utils.cpp_extension import CppExtension
        extension = CppExtension(ext_name, sources,
                                 include_dirs=include_dirs,
                                 extra_objects=hl_libs,
                                 extra_compile_args=compile_args)

    # Build the Python extension module
    setup(name=ext_name,
          verbose=True,
          url="",
          author_email="your@email.com",
          author="Some Author",
          version="0.0.0",
          ext_modules=[extension],
          cmdclass={"build_ext": BuildExtension}
          )