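"""Data loaders.

Loaders are registered by name and can be matched to data sources by
mimetype or file extension. Each loader yields the loaded data (or returns
an object usable as a context manager) and wraps parse failures in
LoaderError.
"""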
import contextlib
import csv
import dbm
import defusedxml.ElementTree as ET
import importlib
import json
import mimetypes
import pathlib
import yaml
# Registered LoaderEntry objects, in registration order.
registered_loaders = []


class LoaderError(Exception):
    """Raised when a data source cannot be loaded."""


class LoaderEntry:
    """A registered loader: the callable, its name, and its source matcher."""

    def __init__(self, loader, name, match_source):
        self.loader = loader
        self.name = name
        self.match_source = match_source
def loader_for_source(source, default=None):
    "Return the loader for the named source."
    for e in registered_loaders:
        if e.match_source is not None and e.match_source(source):
            return e.loader
    return default


def loader_by_name(name, default=None):
    "Return the loader registered with the given name."
    for e in registered_loaders:
        if e.name == name:
            return e.loader
    return default
def mimetype_loader(name, mimetype):
    "A data loader for the exact mimetype."
    def check_mimetype(source):
        guess = mimetypes.guess_type(source)[0]
        if not guess:
            return False
        return guess == mimetype
    return data_source_loader(name, check_mimetype)


def lenient_mimetype_loader(name, mimetype_fragment):
    "A data loader for a mimetype containing the given substring."
    def check_mimetype(source):
        guess = mimetypes.guess_type(source)[0]
        if not guess:
            return False
        return mimetype_fragment in guess
    return data_source_loader(name, check_mimetype)


def file_extension_loader(name, extensions):
    "A data loader for filenames ending with one of the given extensions."
    def check_ext(filename):
        return pathlib.Path(filename).suffix.lower() in set(
            e.lower() for e in extensions)
    return data_source_loader(name, check_ext)
def data_source_loader(name, match_source=None):
    """Add a named loader.

    Add a named data loader with an optional function for matching to
    source names.
    """
    def wrap(loader_func):
        registered_loaders.append(LoaderEntry(loader_func, name, match_source))
        return loader_func
    return wrap
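# A minimal sketch of registering an extra loader with the decorator above.
# The "toml" name and tomllib call are illustrative only (they assume
# `import tomllib`, Python 3.11+) and are not part of this module.
#
#   @file_extension_loader("toml", [".toml"])
#   @contextlib.contextmanager
#   def load_toml(source, absolute_resolved_path, **options):
#       with open(absolute_resolved_path, "rb") as f:
#           yield tomllib.load(f)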
@data_source_loader("nodata")
@contextlib.contextmanager
def load_nodata(source, **options):
    "Provide no data."
    yield None
@file_extension_loader("csv", [".csv"])
@contextlib.contextmanager
def load_csv(source,
             absolute_resolved_path,
             headers=False,
             dialect=None,
             encoding='utf-8-sig',
             **options):
    """Load rows from a CSV file.

    With headers=True, produce a list of dicts keyed by the header row;
    otherwise, a list of row lists. With dialect="auto", sniff the dialect
    from the start of the file.
    """
    with open(absolute_resolved_path, 'r', newline='', encoding=encoding) as f:
        if dialect == "auto":
            sample = f.read(8192)
            f.seek(0)
            sniffer = csv.Sniffer()
            dialect = sniffer.sniff(sample)
        if headers:
            if dialect is None:
                r = csv.DictReader(f)
            else:
                r = csv.DictReader(f, dialect=dialect)
        else:
            if dialect is None:
                r = csv.reader(f)
            else:
                r = csv.reader(f, dialect=dialect)
        yield list(r)
@mimetype_loader("json", "application/json")
@contextlib.contextmanager
def load_json(source, absolute_resolved_path, encoding='utf-8-sig', **options):
    "Load data from a JSON file."
    with open(absolute_resolved_path, 'r', encoding=encoding) as f:
        try:
            yield json.load(f)
        except json.decoder.JSONDecodeError as error:
            raise LoaderError(str(error)) from error
@file_extension_loader("yaml", ['.yml', '.yaml'])
@contextlib.contextmanager
def load_yaml(source,
              absolute_resolved_path,
              encoding='utf-8-sig',
              multiple_documents=False,
              **options):
    "Load data from a YAML file."
    with open(absolute_resolved_path, 'r', encoding=encoding) as f:
        try:
            if multiple_documents:
                yield list(
                    yaml.safe_load_all(f)
                )  # force loading all documents now so the file can be closed
            else:
                yield yaml.safe_load(f)
        except yaml.error.MarkedYAMLError as error:
            # Report the error against the source name, not the resolved path.
            for mark in (error.context_mark, error.problem_mark):
                if mark is not None and mark.name == absolute_resolved_path:
                    mark.name = source
            raise LoaderError(str(error)) from error
@lenient_mimetype_loader('xml', 'xml')
@contextlib.contextmanager
def load_xml(source, absolute_resolved_path, **options):
    "Load data from an XML file."
    try:
        yield ET.parse(absolute_resolved_path).getroot()
    except ET.ParseError as error:
        raise LoaderError(str(error)) from error
@file_extension_loader("dbm", ['.dbm'])
def load_dbm(source, absolute_resolved_path, **options):
    "Open a DBM file read-only."
    # Not wrapped in @contextlib.contextmanager: the dbm object is itself a
    # context manager, so callers can use it in a `with` statement directly.
    try:
        return dbm.open(absolute_resolved_path, "r")
    except dbm.error as error:
        raise LoaderError(str(error)) from error
@data_source_loader("import-module")
@contextlib.contextmanager
def load_import_module(source, **options):
    "Import the named module and provide it as the data."
    try:
        yield importlib.import_module(source)
    except ModuleNotFoundError as error:
        raise LoaderError(str(error)) from error
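# A minimal usage sketch (assumed calling convention, not defined here): pick
# a loader for a source and enter it as a context manager to get the data.
#
#   loader = loader_for_source("settings.yaml", default=load_nodata)
#   with loader("settings.yaml",
#               absolute_resolved_path="/abs/settings.yaml") as data:
#       ...  # `data` is the parsed YAML document, or None for load_nodata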