File: setup.py

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (113 lines) | stat: -rw-r--r-- 3,791 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Synthesizes the C++ pybind11 wrapper code and builds a dynamic Python extension."""

import os
import platform
import re
from setuptools import setup

from torch.utils.cpp_extension import BuildExtension


def generate_pybind_wrapper(path, headers, has_cuda):
    """Write a pybind11 C++ source that exposes every Halide op to PyTorch.

    Args:
        path: destination of the generated ``.cpp`` file (overwritten).
        headers: Halide op header file names such as ``"op.h"``. For each one,
            ``"op.pytorch.h"`` is included and the C++ function ``op_th_`` is
            bound under the Python name ``"op"``.
        has_cuda: when true, ``HalidePyTorchCudaHelpers.h`` is included too.
    """
    op_names = [os.path.splitext(header)[0] for header in headers]

    lines = ['#include "torch/extension.h"', ""]
    if has_cuda:
        lines.append('#include "HalidePyTorchCudaHelpers.h"')
    lines.append('#include "HalidePyTorchHelpers.h"')
    lines.extend(f'#include "{op}.pytorch.h"' for op in op_names)

    # Blank line, then the module definition with one binding per op.
    lines.append("")
    lines.append("PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {")
    lines.extend(
        f'  m.def("{op}", &{op}_th_, "PyTorch wrapper of the Halide pipeline {op}");'
        for op in op_names
    )
    lines.append("}")

    source = "\n".join(lines) + "\n"
    with open(path, "w") as handle:
        handle.write(source)


# Suffix of the Halide-generated PyTorch wrapper headers, one per op.
_PYTORCH_WRAPPER_SUFFIX = ".pytorch.h"


def _collect_halide_ops(build_dir):
    """Find the Halide-generated PyTorch ops in ``build_dir``.

    Every file named ``<op>.pytorch.h`` is treated as one op. It is expected to
    be accompanied by the static library ``<op>.a`` and the header ``<op>.h``
    in the same directory; their existence is not checked here.

    Args:
        build_dir: directory containing the generated files.

    Returns:
        A ``(libraries, headers)`` pair of lists, ordered by file name:
        ``libraries`` holds the full paths of the ``.a`` files to link against,
        and ``headers`` holds the bare ``<op>.h`` file names to include in the
        pybind wrapper.
    """
    libraries = []
    headers = []
    # sorted() makes the generated wrapper and link order reproducible;
    # os.listdir() order is arbitrary.
    for fname in sorted(os.listdir(build_dir)):
        # endswith() instead of a prefix regex match, so that files such as
        # "op.pytorch.h.bak" are not picked up. A bare ".pytorch.h" names no op.
        if not fname.endswith(_PYTORCH_WRAPPER_SUFFIX) or fname == _PYTORCH_WRAPPER_SUFFIX:
            continue
        # Derive the op name from the file name only. Splitting the full path
        # on "." would truncate it at any dot inside build_dir (e.g. "./bin").
        op_name = fname[: -len(_PYTORCH_WRAPPER_SUFFIX)]
        libraries.append(os.path.join(build_dir, op_name + ".a"))
        headers.append(op_name + ".h")
    return libraries, headers


if __name__ == "__main__":
    # This is where the generated Halide ops headers live. We also generate the
    # .cpp wrapper in this directory.
    build_dir = os.getenv("BIN")
    if build_dir is None or not os.path.exists(build_dir):
        raise ValueError(f"Bin directory {build_dir} is invalid")

    # Path to a distribution of Halide
    halide_dir = os.getenv("HALIDE_DISTRIB_PATH")
    if halide_dir is None or not os.path.exists(halide_dir):
        raise ValueError(f"Halide directory {halide_dir} is invalid")

    # Unset or "0" disables CUDA; any other value (including "") enables it.
    has_cuda = os.getenv("HAS_CUDA", "0") != "0"

    include_dirs = [build_dir, os.path.join(halide_dir, "include")]
    # PyTorch 2.x headers require C++17. torch's BuildExtension only adds its
    # own -std= flag when none is passed, so an explicit older standard here
    # would override it and break the build.
    compile_args = ["-std=c++17", "-g"]
    if platform.system() == "Darwin":  # on osx libstdc++ causes trouble
        compile_args += ["-stdlib=libc++"]

    ext_name = "halide_ops"
    # Halide op libraries to link to, and op headers to include in the wrapper.
    hl_libs, hl_headers = _collect_halide_ops(build_dir)
    if not hl_headers:
        print(f"Warning: no *{_PYTORCH_WRAPPER_SUFFIX} files found in {build_dir}")

    # C++ wrapper code that includes all the Halide ops so that they end up in
    # a single Python extension.
    wrapper_path = os.path.join(build_dir, "pybind_wrapper.cpp")
    sources = [wrapper_path]

    print("Generating CUDA wrapper" if has_cuda else "Generating CPU wrapper")
    generate_pybind_wrapper(wrapper_path, hl_headers, has_cuda)

    if has_cuda:
        from torch.utils.cpp_extension import CUDAExtension

        extension = CUDAExtension(
            ext_name,
            sources,
            include_dirs=include_dirs,
            extra_objects=hl_libs,
            # Halide ops need the full cuda lib, not just the RT library
            libraries=["cuda"],
            extra_compile_args=compile_args,
        )
    else:
        from torch.utils.cpp_extension import CppExtension

        extension = CppExtension(
            ext_name,
            sources,
            include_dirs=include_dirs,
            extra_objects=hl_libs,
            extra_compile_args=compile_args,
        )

    # Build the Python extension module
    setup(
        name=ext_name,
        verbose=True,
        url="",
        author_email="your@email.com",
        author="Some Author",
        version="0.0.0",
        ext_modules=[extension],
        cmdclass={"build_ext": BuildExtension},
    )