File: process_commit.py

package info (click to toggle)
pytorch-vision 0.21.0-3
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 20,228 kB
  • sloc: python: 65,904; cpp: 11,406; ansic: 2,459; java: 550; sh: 265; xml: 79; objc: 56; makefile: 33
file content (81 lines) | stat: -rw-r--r-- 2,470 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in
'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled,
this script is a no-op.

Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to torchvision
with no labeling responsibility, so we don't want to bother them.
"""

import sys
from typing import Any, Optional, Set, Tuple

import requests

# A PR counts as properly labeled when it carries at least one label from
# PRIMARY_LABELS and at least one label from SECONDARY_LABELS.
# Label names must match the GitHub label names exactly (case-sensitive).
PRIMARY_LABELS = {
    "bc-breaking",
    "bug",
    "code quality",
    "deprecation",
    "enhancement",
    "new feature",
    "other",
    "prototype",
}

SECONDARY_LABELS = {
    # "module: ..." labels
    "module: c++ frontend",
    "module: ci",
    "module: datasets",
    "module: documentation",
    "module: io",
    "module: models",
    "module: models.quantization",
    "module: onnx",
    "module: ops",
    "module: reference scripts",
    "module: rocm",
    "module: tests",
    "module: transforms",
    "module: utils",
    "module: video",
    # remaining accepted secondary labels
    "dependency issue",
    "Perf",
    "Revert(ed)",
    "topic: build",
}


def query_torchvision(cmd: str, *, accept: str, timeout: float = 30.0) -> Any:
    """Send a GET request to the GitHub REST API for the pytorch/vision repository.

    Args:
        cmd: Path relative to ``https://api.github.com/repos/pytorch/vision/``,
            e.g. ``"pulls/123"``.
        accept: Value sent as the HTTP ``Accept`` header.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block forever and hang the calling job.

    Returns:
        The decoded JSON body of a successful response.

    Raises:
        requests.HTTPError: If GitHub answers with a 4xx/5xx status. Raising here
            avoids handing an error payload (a dict) to callers that expect
            real data and would then fail with an unrelated ``KeyError``.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(
        f"https://api.github.com/repos/pytorch/vision/{cmd}",
        headers={"Accept": accept},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()


def get_pr_number(commit_hash: str) -> Optional[int]:
    """Look up the pull request associated with ``commit_hash``.

    Queries GitHub's "list pull requests associated with a commit" endpoint
    (https://docs.github.com/en/rest/reference/repos#list-pull-requests-associated-with-a-commit)
    and returns the ``number`` of the first PR in the response, or ``None``
    when the response is empty (no associated PR).
    """
    associated_pulls = query_torchvision(
        f"commits/{commit_hash}/pulls",
        accept="application/vnd.github.groot-preview+json",
    )
    return associated_pulls[0]["number"] if associated_pulls else None


def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]:
    """Fetch the login of whoever merged PR ``pr_number`` and the PR's label names.

    See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request

    Returns:
        A ``(merger_login, label_names)`` tuple.
    """
    pull_request = query_torchvision(f"pulls/{pr_number}", accept="application/vnd.github.v3+json")

    # NOTE(review): if the API reports ``merged_by`` as null (e.g. an unmerged
    # PR), this subscript raises TypeError — confirm that cannot happen here.
    merger_login = pull_request["merged_by"]["login"]

    label_names = set()
    for label_entry in pull_request["labels"]:
        label_names.add(label_entry["name"])

    return merger_login, label_names


if __name__ == "__main__":
    commit_hash = sys.argv[1]
    pr_number = get_pr_number(commit_hash)
    if not pr_number:
        sys.exit(0)

    merger, labels = get_pr_merger_and_labels(pr_number)
    is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels))

    if not is_properly_labeled:
        print(f"@{merger}")