File: setup.py

package info (click to toggle)
ml-dtypes 0.5.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,768 kB
  • sloc: ansic: 48,160; cpp: 26,737; python: 2,344; pascal: 514; makefile: 15
file content (77 lines) | stat: -rw-r--r-- 2,256 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright 2022 The ml_dtypes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setuptool-based build for ml_dtypes."""

import fnmatch
import platform
import numpy as np
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_py import build_py as build_py_orig

# C++ flags for compiling the extension: MSVC-style options on Windows,
# GCC/Clang-style options on every other platform.
COMPILE_ARGS = (
    [
        "/std:c++17",
        "/DEIGEN_MPL2_ONLY",
        "/EHsc",
        "/bigobj",
    ]
    if platform.system() == "Windows"
    else [
        "-std=c++17",
        "-DEIGEN_MPL2_ONLY",
        "-fvisibility=hidden",
        # NumPy reads the floating-point exception state to decide whether to
        # emit warnings such as "invalid value". Without -ftrapping-math the
        # tests show spurious "invalid value" warnings on Mac ARM.
        "-ftrapping-math",
    ]
)

# fnmatch-style patterns, matched case-sensitively against dotted
# "package.module" names, for modules the build_py command below drops.
exclude = ["third_party*"]


class build_py(build_py_orig):  # pylint: disable=invalid-name
  """`build_py` command that filters out modules matching `exclude`."""

  def find_package_modules(self, package, package_dir):
    """Return the parent's module list without the excluded modules.

    Args:
      package: Package name, forwarded unchanged to the parent implementation.
      package_dir: Package directory, forwarded unchanged to the parent.

    Returns:
      A new list of the `(pkg, mod, file)` tuples from the parent whose dotted
      name `pkg + "." + mod` matches none of the glob patterns in the
      module-level `exclude` list (case-sensitive matching).
    """

    def _matches_exclude(dotted_name):
      # True when any exclusion glob matches this "package.module" name.
      return any(
          fnmatch.fnmatchcase(dotted_name, pat=glob) for glob in exclude
      )

    kept = []
    for pkg, mod, file in super().find_package_modules(package, package_dir):
      if _matches_exclude(pkg + "." + mod):
        continue
      kept.append((pkg, mod, file))
    return kept


# Build one C++ extension module from the two sources under ml_dtypes/_src,
# and route `build_py` through the filtering subclass defined above.
setup(
    cmdclass={"build_py": build_py},
    ext_modules=[
        Extension(
            name="ml_dtypes._ml_dtypes_ext",
            sources=["ml_dtypes/_src/dtypes.cc", "ml_dtypes/_src/numpy.cc"],
            # Header search path: vendored Eigen, the repository root, and
            # NumPy's C headers.
            include_dirs=["third_party/eigen", ".", np.get_include()],
            extra_compile_args=COMPILE_ARGS,
        ),
    ],
)