1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
|
# -*- coding: utf-8 -*-
"""
DictReader and DictWriter.
This code is entirely copied from the Python csv module. The only exception is
that it uses the `reader` and `writer` classes from our package.
Author: Gertjan van den Burg
"""
from __future__ import annotations
import warnings
from collections import OrderedDict
from collections.abc import Collection
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import TypeVar
from typing import Union
from typing import cast
from clevercsv.read import reader
from clevercsv.write import writer
if TYPE_CHECKING:
from clevercsv._types import SupportsWrite
from clevercsv._types import _DialectLike
from clevercsv._types import _DictReadMapping
_T = TypeVar("_T")


class DictReader(
    Generic[_T], Iterator["_DictReadMapping[Union[_T, Any], Union[str, Any]]"]
):
    """Iterate over delimited text, yielding one ordered mapping per row.

    Rows come from the package's ``reader``; each row is paired with the
    field names. When no field names are supplied, the first row read from
    the input is used as the header.

    Parameters
    ----------
    f : Iterable[str]
        Source of lines, passed unchanged to ``reader``.
    fieldnames : Optional[Sequence[_T]]
        Keys for the returned mappings. ``None`` means: read them from the
        first row of ``f``.
    restkey : Optional[str]
        Key under which the surplus values of a row that is longer than the
        field names are stored (as a list).
    restval : Optional[str]
        Value assigned to field names for which a short row has no value.
    dialect : _DialectLike
        Dialect passed to ``reader``.
    *args, **kwds
        Passed through to ``reader``.
    """

    def __init__(
        self,
        f: Iterable[str],
        fieldnames: Optional[Sequence[_T]] = None,
        restkey: Optional[str] = None,
        restval: Optional[str] = None,
        dialect: "_DialectLike" = "excel",
        *args: Any,
        **kwds: Any,
    ) -> None:
        self._fieldnames = fieldnames
        self.restkey = restkey
        self.restval = restval
        self.reader: reader = reader(f, dialect, *args, **kwds)
        self.dialect = dialect
        # Mirrors ``self.reader.line_num``; 0 means nothing has been read.
        self.line_num = 0

    def __iter__(self) -> "DictReader[_T]":
        return self

    @property
    def fieldnames(self) -> Sequence[_T]:
        """Return the field names, reading the header row if necessary.

        If the input is exhausted before a header could be read, an empty
        list is returned (and not cached) instead of failing, so iterating
        an empty input simply yields no rows.

        A warning is issued when the field names are not unique, because
        later columns overwrite earlier ones in the returned mappings.
        Accessing the property also syncs ``line_num`` with the reader.
        """
        names = self._fieldnames
        if names is None:
            try:
                header = next(self.reader)
            except StopIteration:
                # Empty input: there is no header. Leave ``_fieldnames``
                # unset so the state matches "nothing read yet".
                names = []
            else:
                names = self._fieldnames = [cast(_T, f) for f in header]

        # Note: this was added because I don't think it's expected that Python
        # simply drops information if there are duplicate headers. There is
        # discussion on this issue in the Python bug tracker here:
        # https://bugs.python.org/issue17537 (see linked thread therein). A
        # warning is easy enough to suppress and should ensure that the user
        # is at least aware of this behavior.
        if len(names) != len(set(names)):
            warnings.warn(
                "fieldnames are not unique, some columns will be dropped."
            )

        self.line_num = self.reader.line_num
        return names

    @fieldnames.setter
    def fieldnames(self, value: Sequence[_T]) -> None:
        self._fieldnames = value

    def __next__(self) -> "_DictReadMapping[Union[_T, Any], Union[str, Any]]":
        """Return the next non-blank row as an ordered mapping.

        Raises
        ------
        StopIteration
            When the underlying reader is exhausted (including when the
            input is empty and no header could be read).
        """
        if self.line_num == 0:
            # Accessed only for its side effect: consumes the header row
            # when no field names were given.
            _ = self.fieldnames
        row = next(self.reader)
        self.line_num = self.reader.line_num

        # Blank rows are skipped rather than returned as mappings that hold
        # nothing but ``restval``.
        while row == []:
            row = next(self.reader)

        # Resolve the property once per row: every access re-checks
        # uniqueness and re-syncs ``line_num`` (which also accounts for any
        # blank rows skipped above).
        names = self.fieldnames
        d: _DictReadMapping = OrderedDict(zip(names, row))
        n_names = len(names)
        n_values = len(row)
        if n_names < n_values:
            # Surplus values are collected under ``restkey``.
            d[self.restkey] = row[n_names:]
        elif n_names > n_values:
            # Missing values are filled in with ``restval``.
            for key in names[n_values:]:
                d[key] = self.restval
        return d
class DictWriter(Generic[_T]):
    """Write mappings as delimited rows, ordered by a fixed list of fields.

    Each mapping is converted to a row by looking up every entry of
    ``fieldnames`` in turn; the row is then written with the package's
    ``writer``.

    Parameters
    ----------
    f : SupportsWrite[str]
        Output target, passed unchanged to ``writer``.
    fieldnames : Collection[_T]
        Keys that determine which values are written and in what order.
    restval : Optional[Any]
        Value written for a field name that is missing from a row mapping.
    extrasaction : {"raise", "ignore"}
        What to do when a row mapping has keys that are not in
        ``fieldnames``: ``"raise"`` raises ``ValueError``, ``"ignore"``
        drops the extra keys. Matched case-insensitively.
    dialect : _DialectLike
        Dialect passed to ``writer``.
    *args, **kwds
        Passed through to ``writer``.

    Raises
    ------
    ValueError
        If ``extrasaction`` is neither ``"raise"`` nor ``"ignore"``.
    """

    def __init__(
        self,
        f: "SupportsWrite[str]",
        fieldnames: Collection[_T],
        restval: Optional[Any] = "",
        extrasaction: Literal["raise", "ignore"] = "raise",
        dialect: "_DialectLike" = "excel",
        *args: Any,
        **kwds: Any,
    ):
        self.fieldnames = fieldnames
        self.restval = restval
        # Store the normalized value: validation is case-insensitive, so the
        # later ``== "raise"`` comparison must see the lowercase form too.
        # Otherwise e.g. "RAISE" would be accepted but act like "ignore".
        action = extrasaction.lower()
        if action not in ("raise", "ignore"):
            raise ValueError(
                "extrasaction (%s) must be 'raise' or 'ignore'" % extrasaction
            )
        self.extrasaction = action
        self.writer = writer(f, dialect, *args, **kwds)

    def writeheader(self) -> Any:
        """Write the field names as a row; return what ``writerow`` returns."""
        header = dict(zip(self.fieldnames, self.fieldnames))
        return self.writerow(header)

    def _dict_to_list(self, rowdict: Mapping[_T, Any]) -> Iterator[Any]:
        """Yield the values of ``rowdict`` in ``fieldnames`` order.

        Missing keys yield ``restval``. With ``extrasaction == "raise"``,
        keys that are not in ``fieldnames`` cause a ``ValueError``.
        """
        if self.extrasaction == "raise":
            wrong_fields = rowdict.keys() - self.fieldnames
            if wrong_fields:
                raise ValueError(
                    "dict contains fields not in fieldnames: "
                    + ", ".join([repr(x) for x in wrong_fields])
                )
        return (rowdict.get(key, self.restval) for key in self.fieldnames)

    def writerow(self, rowdict: Mapping[_T, Any]) -> Any:
        """Write one mapping as a row; return the underlying writer's result."""
        return self.writer.writerow(self._dict_to_list(rowdict))

    def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None:
        """Write every mapping in ``rowdicts`` as a row."""
        return self.writer.writerows(map(self._dict_to_list, rowdicts))
|