"""Synthesizes the C++ wrapper code and builds a dynamic Python extension."""
import os
import platform
import re
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension
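

# Hypothetical helper, not called by the build below: a reusable sketch of the
# HAS_CUDA check in the __main__ block. As in that check, an unset variable or
# the exact string "0" counts as False, and any other value counts as True.
def env_flag(name):
    """Returns True if environment variable `name` is set to anything but "0"."""
    value = os.getenv(name)
    return value is not None and value != "0"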


def generate_pybind_wrapper(path, headers, has_cuda):
    """Writes a C++ source file at `path` that registers every Halide op
    listed in `headers` with a single pybind11 module."""
    s = "#include \"torch/extension.h\"\n\n"
    if has_cuda:
        s += "#include \"HalidePyTorchCudaHelpers.h\"\n"
    s += "#include \"HalidePyTorchHelpers.h\"\n"
    for h in headers:
        # Include the Halide-generated PyTorch wrapper of each op
        s += "#include \"{}\"\n".format(os.path.splitext(h)[0] + ".pytorch.h")

    s += "\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n"
    for h in headers:
        # Expose the op's <name>_th_ wrapper function as <name> in Python
        name = os.path.splitext(h)[0]
        s += "  m.def(\"{}\", &{}_th_, \"PyTorch wrapper of the Halide pipeline {}\");\n".format(
            name, name, name)
    s += "}\n"
    with open(path, 'w') as fid:
        fid.write(s)
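

# Hypothetical helper, not called by the build below: a sketch of the op
# discovery loop in the __main__ block, assuming the naming convention used
# there (op "foo" ships "foo.pytorch.h", "foo.h" and "foo.a" in build_dir).
# The suffix is stripped from the file name only, so dots in build_dir itself
# (e.g. "../bin") cannot truncate the paths. Names are sorted so that the link
# and registration order is deterministic.
def find_halide_ops(build_dir, suffix=".pytorch.h"):
    """Returns (library paths, header names) for every op found in build_dir."""
    libs, headers = [], []
    for f in sorted(os.listdir(build_dir)):
        if not f.endswith(suffix):
            continue
        base = f[:-len(suffix)]
        libs.append(os.path.join(build_dir, base + ".a"))
        headers.append(base + ".h")
    return libs, headers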


if __name__ == "__main__":
    # This is where the generated Halide op headers live. We also generate
    # the .cpp wrapper in this directory.
    build_dir = os.getenv("BIN")
    if build_dir is None or not os.path.exists(build_dir):
        raise ValueError("Bin directory {} is invalid".format(build_dir))

    # Path to a distribution of Halide
    halide_dir = os.getenv("HALIDE_DISTRIB_PATH")
    if halide_dir is None or not os.path.exists(halide_dir):
        raise ValueError("Halide directory {} is invalid".format(halide_dir))

    # Build the CUDA variant unless HAS_CUDA is unset or "0"
    has_cuda = os.getenv("HAS_CUDA")
    if has_cuda is None or has_cuda == "0":
        has_cuda = False
    else:
        has_cuda = True

    include_dirs = [build_dir, os.path.join(halide_dir, "include")]

    # Recent versions of PyTorch (1.7.1 at least) require C++14 to compile
    # extensions
    compile_args = ["-std=c++14", "-g"]
    if platform.system() == "Darwin":  # on macOS, libstdc++ causes trouble
        compile_args += ["-stdlib=libc++"]

    # Each Halide op "foo" comes with a generated PyTorch wrapper header
    # "foo.pytorch.h", a pipeline header "foo.h" and a static library "foo.a"
    re_cc = re.compile(r".*\.pytorch\.h$")
    hl_srcs = [f for f in os.listdir(build_dir) if re_cc.match(f)]

    ext_name = "halide_ops"
    hl_libs = []  # Halide op libraries to link to
    hl_headers = []  # Halide op headers to include in the wrapper
    for f in hl_srcs:
        # Strip the ".pytorch.h" suffix from the file name only, so that dots
        # elsewhere in the path (e.g. "../bin") do not truncate it
        hl_base = os.path.join(build_dir, f[:-len(".pytorch.h")])
        # Add all Halide-generated libraries
        hl_libs.append(hl_base + ".a")
        # Record the matching Halide-generated header
        hl_headers.append(os.path.basename(hl_base) + ".h")

    # C++ wrapper that includes every Halide op header, so that all the Halide
    # ops end up in a single Python extension
    wrapper_path = os.path.join(build_dir, "pybind_wrapper.cpp")
    sources = [wrapper_path]

    if has_cuda:
        print("Generating CUDA wrapper")
        generate_pybind_wrapper(wrapper_path, hl_headers, True)
        from torch.utils.cpp_extension import CUDAExtension
        extension = CUDAExtension(ext_name, sources,
                                  include_dirs=include_dirs,
                                  extra_objects=hl_libs,
                                  # Halide ops need the full CUDA driver library
                                  # (libcuda), not just the CUDA runtime library
                                  libraries=["cuda"],
                                  extra_compile_args=compile_args)
    else:
        print("Generating CPU wrapper")
        generate_pybind_wrapper(wrapper_path, hl_headers, False)
        from torch.utils.cpp_extension import CppExtension
        extension = CppExtension(ext_name, sources,
                                 include_dirs=include_dirs,
                                 extra_objects=hl_libs,
                                 extra_compile_args=compile_args)

    # Build the Python extension module
    setup(name=ext_name,
          verbose=True,
          url="",
          author_email="your@email.com",
          author="Some Author",
          version="0.0.0",
          ext_modules=[extension],
          cmdclass={"build_ext": BuildExtension})