1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
|
"""
internetarchive.cli.cli_utils
"""
# Copyright (C) 2012-2024 Internet Archive
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
import argparse
import json
import os
import signal
import sys
from collections import defaultdict
from collections.abc import Iterable
from typing import Mapping
from urllib.parse import parse_qsl
from internetarchive.utils import InvalidIdentifierException, validate_s3_identifier
def get_args_dict(args: list[str],
                  query_string: bool = False,
                  header: bool = False) -> dict:
    """Parse ``KEY:VALUE`` style command-line arguments into a dict.

    Args:
        args: Strings to parse. ``None`` or an empty list yields an empty
            result; a single non-list value is treated as a one-element list.
        query_string: When true, each argument is parsed as a URL query
            string (``k=v&k2=v2``). An argument containing ``:`` but no ``=``
            is first rewritten from ``k:v;k2:v2`` form. Only the first ``:``
            of each ``;``-separated pair is converted, so values may
            themselves contain colons (e.g. times such as ``12:30``).
        header: Accepted for backward compatibility; not used here.

    Returns:
        A mapping (a ``defaultdict(list)``) of key to value. Outside
        query-string mode, a key given more than once collects its distinct
        values, in first-seen order, in a list. A key with a single value maps
        to that value directly. In query-string mode a repeated key keeps its
        last value.

    Raises:
        AssertionError: If a ``KEY:VALUE`` argument has an empty value.
        ValueError: If an argument (outside query-string mode) has no ``:``.
    """
    args = args or []
    if not isinstance(args, list):
        args = [args]
    metadata: dict[str, list | str] = defaultdict(list)
    for md in args:
        if query_string:
            if (":" in md) and ("=" not in md):
                # Convert "k:v;k2:v2" to "k=v&k2=v2", touching only the first
                # colon of each pair so colons inside values survive.
                md = "&".join(pair.replace(":", "=", 1) for pair in md.split(";"))
            # parse_qsl drops blank values by default, so every value is
            # non-empty here.
            for key, value in parse_qsl(md):
                metadata[key] = value
        else:
            key, value = md.split(":", 1)
            if not value:
                # Raised explicitly rather than via ``assert`` so the check
                # still runs under ``python -O``; the exception type is kept
                # for existing callers.
                raise AssertionError(f"missing value for key {key!r}: {md!r}")
            if value not in metadata[key]:
                metadata[key].append(value)  # type: ignore
    for key in metadata:  # noqa: PLC0206
        # Flatten single item lists.
        if len(metadata[key]) <= 1:
            metadata[key] = metadata[key][0]
    return metadata
def convert_str_list_to_unicode(str_list: list[bytes]):
    """Decode every byte string in ``str_list`` with the filesystem encoding.

    Args:
        str_list: Byte strings to decode.

    Returns:
        A new list of ``str`` in the same order as the input.

    Raises:
        UnicodeDecodeError: If an item is not valid in the filesystem
            encoding (decoding uses the default strict error handler).
    """
    fs_encoding = sys.getfilesystemencoding()
    decoded = []
    for raw in str_list:
        decoded.append(raw.decode(fs_encoding))
    return decoded
def validate_identifier(identifier):
    """Check ``identifier`` with ``validate_s3_identifier`` for argparse use.

    Args:
        identifier: The identifier string supplied on the command line.

    Returns:
        ``identifier`` unchanged when ``validate_s3_identifier`` accepts it.

    Raises:
        argparse.ArgumentTypeError: If ``validate_s3_identifier`` raises
            ``InvalidIdentifierException``. The message is copied from that
            exception, which is chained as the cause.
    """
    try:
        validate_s3_identifier(identifier)
    except InvalidIdentifierException as e:
        raise argparse.ArgumentTypeError(str(e)) from e
    return identifier
def flatten_list(lst):
    """Recursively flatten nested iterables into a single flat list.

    Strings and bytes-like objects are treated as atoms. They are appended
    whole rather than being split into characters or integers.

    Args:
        lst: An iterable whose items may themselves be (nested) iterables.

    Returns:
        A new list holding the leaves of ``lst`` in depth-first order.
    """
    result = []
    for item in lst:
        if isinstance(item, Iterable) and not isinstance(item, (str, bytes, bytearray)):
            result.extend(flatten_list(item))  # Recursively flatten
        else:
            result.append(item)  # Leaf value (including str/bytes): keep whole
    return result
class FlattenListAction(argparse.Action):
    """Argparse action that accumulates (possibly nested) values into one flat list.

    Each time the option is seen, its values are flattened with
    ``flatten_list`` and added after any values already collected for this
    destination.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # With nargs=None/"?" argparse passes a single string; wrap it so it
        # is kept whole instead of being split into characters.
        if isinstance(values, (str, bytes)):
            values = [values]
        flattened = flatten_list(values)
        # Build a new list rather than extending in place, so a list given as
        # the argument's ``default`` is never mutated across parses.
        existing = getattr(namespace, self.dest, None)
        items = [] if existing is None else list(existing)
        items.extend(flattened)
        setattr(namespace, self.dest, items)
class PostDataAction(argparse.Action):
    """Argparse action that parses each value as JSON and accumulates the results.

    On the first occurrence, a single JSON value is stored as-is and several
    values are stored as a list. Later occurrences append to that list,
    wrapping a previously stored single value first. Invalid JSON aborts via
    ``parser.error``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # With nargs=None argparse passes one string; treat it as one JSON
        # document instead of iterating over its characters.
        if isinstance(values, str):
            values = [values]
        current_value = getattr(namespace, self.dest, None)
        all_values = []
        for value in values:
            try:
                all_values.append(json.loads(value))
            except json.JSONDecodeError:
                parser.error(f"Invalid JSON format for post data: {value}")
        if current_value is None:
            # First occurrence: don't wrap a lone value in a list.
            post_data = all_values[0] if len(all_values) == 1 else all_values
        elif isinstance(current_value, list):
            # NOTE(review): a JSON array passed on the first occurrence is
            # indistinguishable from an accumulated list and gets extended.
            post_data = current_value + all_values
        else:
            # A single object was stored earlier; promote it to a list.
            post_data = [current_value, *all_values]
        setattr(namespace, self.dest, post_data)
class QueryStringAction(argparse.Action):
    """Argparse action that collects ``key=value`` / ``key:value`` pairs into a dict.

    Each value is parsed as a URL query string, so ``a=1&b=2`` yields two
    pairs. A value with a ``:`` but no ``=`` uses its first ``:`` as the
    separator. A key seen once maps to a string. A key seen more than once,
    whether within one occurrence or across repeated occurrences of the
    option, maps to a list of its values in order.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # With nargs=None argparse passes one string; don't split it into chars.
        if isinstance(values, str):
            values = [values]
        # Work on a copy (including list values) so a dict supplied as the
        # argument's ``default`` is never mutated across parses.
        existing = getattr(namespace, self.dest, None) or {}
        current_dict = {
            key: list(value) if isinstance(value, list) else value
            for key, value in existing.items()
        }
        for sublist in values:
            if "=" not in sublist and ":" in sublist:
                sublist = sublist.replace(":", "=", 1)
            key_value_pairs = parse_qsl(sublist)
            if sublist and not key_value_pairs:
                parser.error(f"{option_string} must be formatted as 'key=value' "
                             "or 'key:value'")
            for key, value in key_value_pairs:
                previous = current_dict.get(key)
                if key not in current_dict:
                    current_dict[key] = [value]
                elif isinstance(previous, list):
                    previous.append(value)
                else:
                    # An earlier occurrence already flattened this key to a
                    # scalar; turn it back into a list before appending.
                    current_dict[key] = [previous, value]
        # Flatten single-item lists to plain values.
        for key, value in current_dict.items():
            if isinstance(value, list) and len(value) == 1:
                current_dict[key] = value[0]
        setattr(namespace, self.dest, current_dict)
class MetadataAction(argparse.Action):
    """Argparse action that collects ``KEY:VALUE`` (or ``KEY=VALUE``) items into a dict.

    Only the first separator splits key from value, so values may contain
    ``:``. A key given once maps to its value string. A key given more than
    once maps to a list of all its values in order (duplicates are kept).
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # With nargs=None argparse passes one string; don't split it into chars.
        if isinstance(values, str):
            values = [values]
        # Work on a copy (including list values) so a dict supplied as the
        # argument's ``default`` is never mutated across parses.
        existing = getattr(namespace, self.dest, None) or {}
        current_dict = {
            key: list(value) if isinstance(value, list) else value
            for key, value in existing.items()
        }
        for sublist in values:
            # Accept KEY=VALUE as an alias only when there is no colon at all.
            if ":" not in sublist and "=" in sublist:
                sublist = sublist.replace("=", ":", 1)
            key, sep, value = sublist.partition(":")
            if not sep:
                parser.error(f"{option_string} must be formatted as 'KEY:VALUE'")
                continue  # only reached if parser.error is overridden not to exit
            if key in current_dict:
                if not isinstance(current_dict[key], list):
                    current_dict[key] = [current_dict[key]]
                current_dict[key].append(value)
            else:
                current_dict[key] = value
        setattr(namespace, self.dest, current_dict)
def validate_dir_path(path):
    """
    Return ``path`` unchanged if it names an existing directory.

    Because failures raise ``argparse.ArgumentTypeError``, this can serve as
    an argparse ``type=`` converter.

    Args:
        path (str): Candidate directory path.

    Returns:
        str: The same ``path`` that was passed in.

    Raises:
        argparse.ArgumentTypeError: If ``os.path.isdir(path)`` is false.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f"'{path}' is not a valid directory")
    return path
def exit_on_signal(sig, frame):
    """
    Signal handler that terminates the program with status ``128 + sig``.

    Adding 128 to the signal number follows the common shell convention for
    signal-caused exits, keeping them distinct from ordinary exit codes.
    Suitable for signals such as SIGINT or SIGPIPE.

    Args:
        sig: The signal number delivered to the handler.
        frame: The interrupted stack frame (unused).

    Raises:
        SystemExit: Always, via ``sys.exit(128 + sig)``.
    """
    sys.exit(128 + sig)
|