#!/usr/bin/env python3
"""Script to generate unannotated baseline stubs using stubgen.
Basic usage:
$ python3 scripts/create_baseline_stubs.py <project on PyPI>
Run with -h for more help.
"""
from __future__ import annotations
import argparse
import asyncio
import glob
import os.path
import re
import subprocess
import sys
import urllib.parse
from importlib.metadata import distribution
import aiohttp
import termcolor
PYRIGHT_CONFIG = "pyrightconfig.stricter.json"
def search_pip_freeze_output(project: str, output: str) -> tuple[str, str] | None:
    """Find *project* in ``pip freeze`` output.

    Matching is case-insensitive and treats '-' and '_' as interchangeable,
    so "typed_ast" matches "typed-ast" and vice versa.

    Return (name as it appears in the output, installed version), or None
    if the project is not listed.
    """
    fuzzy_name = re.sub(r"[-_]", "[-_]", project)
    pattern = f"^({fuzzy_name})==(.*)"
    match = re.search(pattern, output, flags=re.MULTILINE | re.IGNORECASE)
    return None if match is None else (match.group(1), match.group(2))
def get_installed_package_info(project: str) -> tuple[str, str] | None:
    """Look up *project* in the output of ``pip freeze``.

    The project name is matched somewhat fuzzily (case-insensitive; '-'
    matches '_' and vice versa).

    Return (normalized project name, installed version) on success, or
    None if the project is not installed.
    """
    freeze = subprocess.run(["pip", "freeze"], capture_output=True, text=True, check=True)
    return search_pip_freeze_output(project, freeze.stdout)
def run_stubgen(package: str, output: str) -> None:
    """Generate draft stubs for *package* into *output* using mypy's stubgen.

    Raises CalledProcessError if stubgen exits non-zero (check=True).
    """
    print(f"Running stubgen: stubgen -o {output} -p {package}")
    command = ["stubgen", "-o", output, "-p", package, "--export-less"]
    subprocess.run(command, check=True)
def run_stubdefaulter(stub_dir: str) -> None:
    """Add inferred default values to the stubs in *stub_dir*.

    A non-zero exit is tolerated (no check=True): stubdefaulter is best-effort.
    """
    print(f"Running stubdefaulter: stubdefaulter --packages {stub_dir}")
    command = ["stubdefaulter", "--packages", stub_dir]
    subprocess.run(command)
def run_black(stub_dir: str) -> None:
    """Format every generated ``.pyi`` file under *stub_dir* with Black.

    Black is invoked through pre-commit so the repo's configured hook
    version is used; a non-zero exit is tolerated (no check=True).
    """
    print(f"Running Black: black {stub_dir}")
    pyi_files = glob.iglob(f"{stub_dir}/**/*.pyi")
    subprocess.run(["pre-commit", "run", "black", "--files", *pyi_files])
def run_ruff(stub_dir: str) -> None:
    """Apply Ruff's automatic fixes to the stubs in *stub_dir*.

    Runs ruff via the current interpreter (`python -m ruff`); lint errors
    that cannot be auto-fixed do not fail the script (no check=True).
    """
    print(f"Running Ruff: ruff check {stub_dir} --fix-only")
    ruff_command = [sys.executable, "-m", "ruff", "check", stub_dir, "--fix-only"]
    subprocess.run(ruff_command)
async def get_project_urls_from_pypi(project: str, session: aiohttp.ClientSession) -> dict[str, str]:
    """Fetch the ``project_urls`` mapping for *project* from PyPI's JSON API.

    Returns an empty dict on any non-200 response or when the project
    declares no URLs.
    """
    endpoint = f"https://pypi.org/pypi/{urllib.parse.quote(project)}/json"
    async with session.get(endpoint) as response:
        if response.status != 200:
            return {}
        payload: dict[str, dict[str, dict[str, str]]] = await response.json()
        return payload["info"].get("project_urls") or {}
async def get_upstream_repo_url(project: str) -> str | None:
    """Try to find the URL of *project*'s source repository.

    Checks the project's PyPI ``project_urls`` for a link to a known code
    hosting site and returns the truncated ``https://site.com/user/repo``
    URL of the first one that responds with HTTP 200, or None.
    """
    # aiohttp is overkill here, but it would also just be silly
    # to have both requests and aiohttp in our requirements-tests.txt file.
    async with aiohttp.ClientSession() as session:
        project_urls = await get_project_urls_from_pypi(project, session)
        if not project_urls:
            return None
        # Order the project URLs so that we put the ones
        # that are most likely to point to the source code first
        urls_to_check: list[str] = []
        url_names_probably_pointing_to_source = ("Source", "Repository", "Homepage")
        for url_name in url_names_probably_pointing_to_source:
            if url := project_urls.get(url_name):
                urls_to_check.append(url)
        # Any remaining URLs are checked last, in mapping order.
        urls_to_check.extend(
            url for url_name, url in project_urls.items() if url_name not in url_names_probably_pointing_to_source
        )
        for url in urls_to_check:
            # Remove `www.`; replace `http://` with `https://`
            url = re.sub(r"^(https?://)?(www\.)?", "https://", url)
            netloc = urllib.parse.urlparse(url).netloc
            if netloc in {"gitlab.com", "github.com", "bitbucket.org", "foss.heptapod.net"}:
                # truncate to https://site.com/user/repo
                upstream_repo_url = "/".join(url.split("/")[:5])
                # Only trust the URL if the repo actually exists (HTTP 200).
                async with session.get(upstream_repo_url) as response:
                    if response.status == 200:
                        return upstream_repo_url
        return None
def create_metadata(project: str, stub_dir: str, version: str) -> None:
    """Create a METADATA.toml file in *stub_dir*.

    Pins ``version`` to ``major.minor.*`` and, when one can be found, records
    the project's upstream repository URL. Exits the process if *version*
    cannot be parsed; does nothing if METADATA.toml already exists.
    """
    # Escape the dot: an unescaped "." matches any character, so a dotless
    # version string like "304" would wrongly be accepted as "3<any>4".
    match = re.match(r"[0-9]+\.[0-9]+", version)
    if match is None:
        sys.exit(f"Error: Cannot parse version number: {version}")
    filename = os.path.join(stub_dir, "METADATA.toml")
    version = match.group(0)  # keep only major.minor
    if os.path.exists(filename):
        return
    metadata = f'version = "{version}.*"\n'
    upstream_repo_url = asyncio.run(get_upstream_repo_url(project))
    if upstream_repo_url is None:
        # Best-effort: warn loudly but still write the file without the URL.
        warning = (
            f"\nCould not find a URL pointing to the source code for {project!r}.\n"
            f"Please add it as `upstream_repository` to `stubs/{project}/METADATA.toml`, if possible!\n"
        )
        print(termcolor.colored(warning, "red"))
    else:
        metadata += f'upstream_repository = "{upstream_repo_url}"\n'
    # Fixed: the f-string previously had no placeholder and printed a
    # literal "(unknown)" instead of the destination path.
    print(f"Writing {filename}")
    with open(filename, "w", encoding="UTF-8") as file:
        file.write(metadata)
def add_pyright_exclusion(stub_dir: str) -> None:
    """Exclude stub_dir from strict pyright checks.

    Inserts *stub_dir* into the "exclude" list of PYRIGHT_CONFIG, keeping
    the third-party ("stubs/...") entries sorted case-insensitively.
    """
    with open(PYRIGHT_CONFIG, encoding="UTF-8") as f:
        lines = f.readlines()
    # Locate the line that opens the "exclude" list.
    i = 0
    while i < len(lines) and not lines[i].strip().startswith('"exclude": ['):
        i += 1
    assert i < len(lines), f"Error parsing {PYRIGHT_CONFIG}"
    # Advance to the list's closing bracket.
    while not lines[i].strip().startswith("]"):
        i += 1
    end = i
    # We assume that all third-party excludes must be at the end of the list.
    # This helps with skipping special entries, such as "stubs/**/@tests/test_cases".
    while lines[i - 1].strip().startswith('"stubs/'):
        i -= 1
    start = i
    before_third_party_excludes = lines[:start]
    third_party_excludes = lines[start:end]
    after_third_party_excludes = lines[end:]
    # Ensure the current last entry ends with a comma so a new entry can follow.
    last_line = third_party_excludes[-1].rstrip()
    if not last_line.endswith(","):
        last_line += ","
    third_party_excludes[-1] = last_line + "\n"
    # Must use forward slash in the .json file
    line_to_add = f' "{stub_dir}",\n'.replace("\\", "/")
    if line_to_add in third_party_excludes:
        print(f"{PYRIGHT_CONFIG} already up-to-date")
        return
    third_party_excludes.append(line_to_add)
    # Case-insensitive sort keeps the list's existing ordering convention.
    third_party_excludes.sort(key=str.lower)
    print(f"Updating {PYRIGHT_CONFIG}")
    with open(PYRIGHT_CONFIG, "w", encoding="UTF-8") as f:
        f.writelines(before_third_party_excludes)
        f.writelines(third_party_excludes)
        f.writelines(after_third_party_excludes)
def main() -> None:
    """Command-line entry point: generate baseline stubs for one project.

    Parses arguments, autodetects the runtime package name if needed,
    runs stubgen/stubdefaulter/Ruff/Black, writes METADATA.toml, and adds
    the new stubs directory to the strict-pyright exclusion list.
    """
    parser = argparse.ArgumentParser(
        description="""Generate baseline stubs automatically for an installed pip package
using stubgen. Also run Black and Ruff. If the name of
the project is different from the runtime Python package name, you may
need to use --package (example: --package yaml PyYAML)."""
    )
    parser.add_argument("project", help="name of PyPI project for which to generate stubs under stubs/")
    parser.add_argument("--package", help="generate stubs for this Python package (default is autodetected)")
    args = parser.parse_args()
    project = args.project
    package = args.package
    # Reject names that could escape "stubs/<project>" or break the regexes below.
    if not re.match(r"[a-zA-Z0-9-_.]+$", project):
        sys.exit(f"Invalid character in project name: {project!r}")
    if not package:
        package = project  # default
        # Try to find which packages are provided by the project
        # Use default if that fails or if several packages are found
        #
        # The importlib.metadata module is used for projects whose name is different
        # from the runtime Python package name (example: PyYAML/yaml)
        dist = distribution(project).read_text("top_level.txt")
        if dist is not None:
            packages = [name for name in dist.split() if not name.startswith("_")]
            if len(packages) == 1:
                package = packages[0]
                print(f'Using detected package "{package}" for project "{project}"', file=sys.stderr)
                print("Suggestion: Try again with --package argument if that's not what you wanted", file=sys.stderr)
    # All paths below are relative to the typeshed repo root.
    if not os.path.isdir("stubs") or not os.path.isdir("stdlib"):
        sys.exit("Error: Current working directory must be the root of typeshed repository")
    # Get normalized project name and version of installed package.
    info = get_installed_package_info(project)
    if info is None:
        print(f'Error: "{project}" is not installed', file=sys.stderr)
        print("", file=sys.stderr)
        print(f'Suggestion: Run "python3 -m pip install {project}" and try again', file=sys.stderr)
        sys.exit(1)
    project, version = info
    stub_dir = os.path.join("stubs", project)
    package_dir = os.path.join(stub_dir, package)
    # Refuse to clobber existing stubs.
    if os.path.exists(package_dir):
        sys.exit(f"Error: {package_dir} already exists (delete it first)")
    run_stubgen(package, stub_dir)
    run_stubdefaulter(stub_dir)
    # Ruff before Black so Black gets the last word on formatting.
    run_ruff(stub_dir)
    run_black(stub_dir)
    create_metadata(project, stub_dir, version)
    # Since the generated stubs won't have many type annotations, we
    # have to exclude them from strict pyright checks.
    add_pyright_exclusion(stub_dir)
    print("\nDone!\n\nSuggested next steps:")
    print(f" 1. Manually review the generated stubs in {stub_dir}")
    print(" 2. Optionally run tests and autofixes (see tests/README.md for details)")
    print(" 3. Commit the changes on a new branch and create a typeshed PR (don't force-push!)")
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()