1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
|
#!/usr/bin/env python3
import argparse
import os
import shutil
import subprocess
import sys
import tempfile
import urllib.error
import urllib.request
from typing import cast, List, NoReturn, Optional
def parse_arguments() -> argparse.Namespace:
    """
    Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: ``PR_NUMBER`` (int), ``directory`` (str or None,
        default None) and ``strip`` (int, default 1).
    """
    overview = (
        "Download and apply a Pull Request (PR) patch from the PyTorch GitHub repository "
        "to your local PyTorch installation.\n\n"
        "Best Practice: Since this script involves hot-patching PyTorch, it's recommended to use "
        "a disposable environment like a Docker container or a dedicated Python virtual environment (venv). "
        "This ensures that if the patching fails, you can easily recover by resetting the environment."
    )
    usage_examples = (
        "Example:\n"
        " python nightly_hotpatch.py 12345\n"
        " python nightly_hotpatch.py 12345 --directory /path/to/pytorch --strip 1\n\n"
        "These commands will download the patch for PR #12345 and apply it to your local "
        "PyTorch installation."
    )

    cli = argparse.ArgumentParser(
        description=overview,
        epilog=usage_examples,
        # Raw formatter keeps the explicit newlines in the texts above.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    pr_help = (
        "The number of the Pull Request (PR) from the PyTorch GitHub repository "
        "to download and apply as a patch."
    )
    dir_help = (
        "Optional. Specify the target directory to apply the patch. "
        "If not provided, the script will use the PyTorch installation path."
    )
    strip_help = (
        "Optional. Specify the strip count to remove leading directories "
        "from file paths in the patch. Default is 1."
    )

    cli.add_argument("PR_NUMBER", type=int, help=pr_help)
    cli.add_argument("--directory", "-d", type=str, default=None, help=dir_help)
    cli.add_argument("--strip", "-p", type=int, default=1, help=strip_help)

    return cli.parse_args()
def get_pytorch_path() -> str:
    """
    Locate the installed ``torch`` package and return its parent directory.

    The parent of the package directory is the directory the patch is
    applied in (announced as "Parent directory for patching").

    Returns:
        str: ``os.path.dirname`` of the first entry of ``torch.__path__``.

    Exits:
        With status 1 (via ``handle_import_error``) if ``torch`` cannot be imported.
    """
    try:
        import torch
    except ImportError:
        handle_import_error()

    package_dir: str = cast(List[str], torch.__path__)[0]
    patch_root: str = os.path.dirname(package_dir)

    print(f"PyTorch is installed at: {package_dir}")
    print(f"Parent directory for patching: {patch_root}")
    return patch_root
def handle_import_error() -> NoReturn:
    """
    Report that PyTorch is missing and terminate the process.

    Exits:
        Always raises ``SystemExit`` with status 1; it never returns.
    """
    message = "Error: PyTorch is not installed in the current Python environment."
    print(message)
    raise SystemExit(1)
def download_patch(
    pr_number: int,
    repo_url: str,
    download_dir: str,
    timeout: Optional[float] = 60.0,
) -> str:
    """
    Download the diff of a pull request and save it as a patch file.

    The diff is fetched from ``{repo_url}/pull/{pr_number}.diff`` and written
    to ``{download_dir}/pr-{pr_number}.patch``.

    Args:
        pr_number (int): The pull request number.
        repo_url (str): Base URL of the repository hosting the PR.
        download_dir (str): Existing directory in which to store the patch.
        timeout (Optional[float]): Seconds to wait for the network before giving
            up; ``None`` waits indefinitely. Defaults to 60.

    Returns:
        str: The path to the downloaded patch file.

    Exits:
        With status 1 if the download fails or the downloaded diff is empty.
    """
    patch_url = f"{repo_url}/pull/{pr_number}.diff"
    patch_file = os.path.join(download_dir, f"pr-{pr_number}.patch")
    print(f"Downloading PR #{pr_number} patch from {patch_url}...")
    try:
        # A timeout keeps a stalled connection from hanging the script forever.
        with urllib.request.urlopen(patch_url, timeout=timeout) as response, open(
            patch_file, "wb"
        ) as out_file:
            shutil.copyfileobj(response, out_file)
    except urllib.error.HTTPError as e:
        print(f"HTTP Error: {e.code} when downloading patch for PR #{pr_number}")
        sys.exit(1)
    except Exception as e:
        # CLI boundary: report any other failure (URLError, timeout, bad URL, I/O) and stop.
        print(f"An error occurred while downloading the patch: {e}")
        sys.exit(1)

    # An empty diff has nothing to apply; fail here with a clear message
    # instead of letting `patch` reject it later.
    if os.path.getsize(patch_file) == 0:
        print(f"Failed to download patch for PR #{pr_number}: received an empty diff")
        sys.exit(1)

    print(f"Patch downloaded to {patch_file}")
    return patch_file
def apply_patch(patch_file: str, target_dir: Optional[str], strip_count: int) -> None:
    """
    Apply a patch file with the external ``patch`` utility.

    Args:
        patch_file (str): Path to the patch file. A relative path is resolved
            against the current working directory.
        target_dir (Optional[str]): Directory to apply the patch in. If None or
            empty, the PyTorch installation path from ``get_pytorch_path()`` is used.
        strip_count (int): Number of leading path components to strip from file
            names in the patch (passed as ``patch -p<N>``).

    Exits:
        With status 1 if the 'patch' utility is missing or cannot be started,
        or if it reports a non-zero exit status.
    """
    if target_dir:
        print(f"Applying patch in directory: {target_dir}")
    else:
        print("No target directory specified. Using PyTorch installation path.")
        # Resolve the default here; passing None through would give `patch`
        # a bogus "-dNone" option.
        target_dir = get_pytorch_path()
    print(f"Applying patch with strip count: {strip_count}")

    # `patch -d` changes directory before it reads `-i`, so a relative patch
    # path must be made absolute first.
    patch_command = [
        "patch",
        "-d",
        target_dir,
        f"-p{strip_count}",
        "-i",
        os.path.abspath(patch_file),
    ]
    print(f"Running command: {' '.join(patch_command)}")
    try:
        # errors="replace": patch output may contain non-UTF-8 file names.
        result = subprocess.run(
            patch_command, capture_output=True, text=True, errors="replace"
        )
    except FileNotFoundError:
        print("Error: The 'patch' utility is not installed or not found in PATH.")
        sys.exit(1)
    except OSError as e:
        print(f"An error occurred while applying the patch: {e}")
        sys.exit(1)

    if result.returncode != 0:
        print("Failed to apply patch.")
        print("Patch output:")
        print(result.stdout)
        print(result.stderr)
        sys.exit(1)
    print("Patch applied successfully.")
def main() -> None:
    """
    Entry point: parse the CLI, choose where to patch, then download and apply the PR diff.

    The patch is applied in ``--directory`` when one is given (it must already
    exist); otherwise in the directory returned by ``get_pytorch_path()``.
    The diff is downloaded into a temporary directory that is removed when
    this function finishes.
    """
    args = parse_arguments()
    target_dir = args.directory

    if not target_dir:
        target_dir = get_pytorch_path()
    elif os.path.isdir(target_dir):
        print(f"Using custom target directory: {target_dir}")
    else:
        print(
            f"Error: The specified target directory '{target_dir}' does not exist."
        )
        sys.exit(1)

    pytorch_repo = "https://github.com/pytorch/pytorch"
    with tempfile.TemporaryDirectory() as scratch_dir:
        downloaded = download_patch(args.PR_NUMBER, pytorch_repo, scratch_dir)
        apply_patch(downloaded, target_dir, args.strip)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    main()
|