1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
|
import argparse
import hashlib
import json
import logging
import os
import platform
import stat
import subprocess
import sys
import textwrap
import urllib.error
import urllib.request
from pathlib import Path
# Operating-system name of the machine running this script, as reported by
# the platform module (e.g. "Linux", "Darwin"). It selects the per-platform
# entry of the linter config when run as a script.
HOST_PLATFORM = platform.uname().system
# PyTorch directory root: ask git first, and fall back to a path relative to
# this file when git is unavailable or this is not a git checkout.
try:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        # Silence git's "not a git repository" message; the fallback below
        # handles that case.
        stderr=subprocess.DEVNULL,
        check=True,
    )
    PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except (subprocess.CalledProcessError, OSError):
    # CalledProcessError: git ran but failed (e.g. not inside a repository).
    # OSError (e.g. FileNotFoundError): the git executable could not be run,
    # which is what happens when git is not installed.
    # Either way, use the directory three levels above the one containing
    # this file, i.e. strip four components from this file's absolute path.
    path_ = os.path.abspath(__file__)
    for _ in range(4):
        path_ = os.path.dirname(path_)
    PYTORCH_ROOT = path_
# Global dry-run switch. When True, check() logs instead of deleting a
# binary whose hash mismatches, and download() returns before fetching
# anything. Reassigned from the --dry-run flag when run as a script.
DRY_RUN: bool = False
def compute_file_sha256(path: str) -> str:
    """Compute the SHA256 hash of a file and return it as a hex string.

    Args:
        path: Path of the file to hash.

    Returns:
        The hex digest of the file's contents, or an empty string if nothing
        exists at ``path`` (which cannot match a real, non-empty reference
        hash).
    """
    # If the file doesn't exist, return an empty string.
    if not os.path.exists(path):
        return ""
    hasher = hashlib.sha256()
    # Read fixed-size chunks instead of iterating the file by "lines": a
    # binary file may contain few or no b"\n" bytes, so line iteration could
    # pull most of the file into memory as one giant line.
    chunk_size = 1 << 16
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    # Return the hash as a hexadecimal string.
    return hasher.hexdigest()
def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """
    Pretty printer for file download progress.

    Intended as the ``reporthook`` of ``urllib.request.urlretrieve``, which
    passes the number of blocks transferred so far, the block size in bytes,
    and the total size in bytes (-1 when the size is unknown). Draws a
    64-column bar on stdout, overwriting the current line each call.
    """
    # An unknown (-1) total size gives no meaningful percentage, and a zero
    # total size would divide by zero; draw nothing in either case.
    if file_size <= 0:
        return
    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
    # The bar never ends with a newline, so flush explicitly or a buffered
    # stdout may not show the updates at all.
    sys.stdout.flush()
def check(binary_path: Path, reference_hash: str) -> bool:
    """Verify that ``binary_path`` exists and its SHA256 equals ``reference_hash``.

    Returns True only on an exact hash match. A missing file yields False.
    A file with the wrong hash is deleted and False is returned. In dry-run
    mode the file is kept, and a critical message asks the user to remove it.
    If deletion fails, the error is logged critically and False is still
    returned.
    """
    if binary_path.exists():
        actual_hash = compute_file_sha256(str(binary_path))
    else:
        logging.info(f"{binary_path} does not exist.")
        return False

    if actual_hash == reference_hash:
        return True

    logging.warning(
        textwrap.dedent(
            f"""\
            Found binary hash does not match reference!
            Found hash: {actual_hash}
            Reference hash: {reference_hash}
            Deleting {binary_path} just to be safe.
            """
        )
    )

    if not DRY_RUN:
        try:
            binary_path.unlink()
        except OSError as err:
            logging.critical(f"Failed to delete binary: {err}")
            logging.critical(
                "Delete this binary as soon as possible and do not execute it!"
            )
    else:
        logging.critical(
            "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
        )
    return False
def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """
    Download a platform-appropriate binary if one doesn't already exist at the
    expected location, and verify it by checking its SHA256 hash against the
    expected hash.

    Args:
        name: File name of the binary inside ``output_dir``.
        output_dir: Directory to place the binary in (created if missing).
        url: URL to download the binary from.
        reference_bin_hash: Expected SHA256 hex digest of the binary.

    Returns:
        True if a binary with the expected hash is in place, or if in dry-run
        mode a download would have been attempted. False if the download
        failed or the downloaded file did not match the expected hash.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info(f"Correct binary already exists at {binary_path}. Exiting.")
        return True

    logging.info(f"Downloading {url} to {binary_path}")
    if DRY_RUN:
        # Dry run only reports what would happen; don't touch the filesystem.
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    # Create the output folder
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    # Download the binary, drawing a progress bar only on an interactive
    # terminal.
    show_progress = sys.stdout.isatty()
    try:
        urllib.request.urlretrieve(
            url,
            binary_path,
            reporthook=report_download_progress if show_progress else None,
        )
    except OSError as e:
        # urllib.error.URLError and its HTTPError / ContentTooShortError
        # subclasses derive from OSError, as do socket-level failures.
        logging.critical(f"Failed to download {url}: {e}")
        return False
    finally:
        if show_progress:
            # The progress bar redraws one line without a trailing newline;
            # terminate it so subsequent output starts on a fresh line.
            sys.stdout.write("\n")
            sys.stdout.flush()
    logging.info(f"Downloaded {name} successfully.")

    # Check the downloaded binary (check() deletes it on a hash mismatch).
    if not check(binary_path, reference_bin_hash):
        logging.critical(f"Downloaded binary {name} failed its hash check")
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)
    logging.info(f"Using {name} located at {binary_path}")
    return True
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        # A string default, so that omitting the flag compares equal to "0"
        # below. A bool default never equals "0" and would enable dry run.
        default="0",
        help="do not download, just print what would be done",
    )
    args = parser.parse_args()

    # "--dry-run=0" (or no flag) performs real work; any other value is a
    # dry run. Assigning at module scope updates the global read by
    # check() and download().
    DRY_RUN = args.dry_run != "0"

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    with open(args.config_json, encoding="utf-8") as config_file:
        all_configs = json.load(config_file)
    if args.linter not in all_configs:
        logging.error(f"Linter {args.linter} not found in {args.config_json}")
        sys.exit(1)
    config = all_configs[args.linter]

    # The linter's entry is keyed by platform name (as in HOST_PLATFORM);
    # a missing key means this platform is unsupported.
    if HOST_PLATFORM not in config:
        logging.error(f"Unsupported platform: {HOST_PLATFORM}")
        sys.exit(1)

    url = config[HOST_PLATFORM]["download_url"]
    reference_hash = config[HOST_PLATFORM]["hash"]

    ok = download(args.output_name, args.output_dir, url, reference_hash)
    if not ok:
        logging.critical(f"Unable to initialize {args.linter}")
        sys.exit(1)
|