1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
# -*- coding: utf-8 -*-
"""
DictReader and DictWriter.
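This code is adapted from the DictReader and DictWriter classes in Python's csv
module. It differs mainly in that it uses the `reader` and `writer` classes
from this package, and in that DictReader warns when the header contains
duplicate field names.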
Author: Gertjan van den Burg
"""
import warnings
from collections import OrderedDict
from .read import reader
from .write import writer
class DictReader(object):
    """Read CSV rows as ordered dictionaries keyed by field name.

    Mirrors ``csv.DictReader`` from the standard library, but parses the
    input with this package's ``reader``.

    Parameters
    ----------
    f :
        Input source, passed straight through to ``reader``.
    fieldnames : sequence, optional
        Keys for the returned dictionaries. When ``None``, the first row
        produced by the reader is used as the header.
    restkey : optional
        Key under which the surplus values of a row that is longer than
        ``fieldnames`` are stored (as a list).
    restval : optional
        Value used for fields that are missing from a row that is shorter
        than ``fieldnames``.
    dialect : default "excel"
        Dialect passed to ``reader`` together with ``*args`` and ``**kwds``.
    """

    def __init__(
        self,
        f,
        fieldnames=None,
        restkey=None,
        restval=None,
        dialect="excel",
        *args,
        **kwds
    ):
        self._fieldnames = fieldnames
        self.restkey = restkey
        self.restval = restval
        self.reader = reader(f, dialect, *args, **kwds)
        self.dialect = dialect
        # Mirrors the underlying reader's line counter; 0 means nothing has
        # been read yet (used by __next__ to trigger the header read).
        self.line_num = 0

    def __iter__(self):
        return self

    @property
    def fieldnames(self):
        """Field names used as dictionary keys.

        When none were supplied, the first row of the input is read and used
        as the header. Returns ``None`` if no fieldnames were supplied and
        the input is empty. Accessing this property also syncs ``line_num``
        with the underlying reader.
        """
        if self._fieldnames is None:
            try:
                self._fieldnames = next(self.reader)
            except StopIteration:
                # Empty input: leave fieldnames as None (as csv.DictReader
                # does) instead of failing on the duplicate check below.
                pass
            # Note: this was added because I don't think it's expected that
            # Python simply drops information if there are duplicate headers.
            # There is discussion on this issue in the Python bug tracker
            # here: https://bugs.python.org/issue17537 (see linked thread
            # therein). A warning is easy enough to suppress and should
            # ensure that the user is at least aware of this behavior.
            if self._fieldnames is not None and len(self._fieldnames) != len(
                set(self._fieldnames)
            ):
                warnings.warn(
                    "fieldnames are not unique, some columns will be dropped."
                )
        self.line_num = self.reader.line_num
        return self._fieldnames

    @fieldnames.setter
    def fieldnames(self, value):
        self._fieldnames = value

    def __next__(self):
        """Return the next non-blank row as an ``OrderedDict``.

        Rows shorter than the fieldnames are padded with ``restval``; extra
        values of longer rows are collected in a list under ``restkey``.
        Raises StopIteration when the underlying reader is exhausted.
        """
        if self.line_num == 0:
            # Accessed only for its side effect: reads the header row when
            # no fieldnames were supplied.
            self.fieldnames
        row = next(self.reader)
        # Skip blank rows; they would otherwise produce dicts full of
        # restval values.
        while row == []:
            row = next(self.reader)
        self.line_num = self.reader.line_num

        # Evaluate the property once per row rather than on every use.
        fieldnames = self.fieldnames
        d = OrderedDict(zip(fieldnames, row))
        lf = len(fieldnames)
        lr = len(row)
        if lf < lr:
            d[self.restkey] = row[lf:]
        elif lf > lr:
            for key in fieldnames[lr:]:
                d[key] = self.restval
        return d
class DictWriter(object):
    """Write dictionaries as CSV rows.

    Mirrors ``csv.DictWriter`` from the standard library, but writes through
    this package's ``writer``.

    Parameters
    ----------
    f :
        Output target, passed straight through to ``writer``.
    fieldnames : sequence
        Keys (and their order) that are written for every row.
    restval : default ""
        Value written for fields missing from a row dictionary.
    extrasaction : {"raise", "ignore"}, default "raise"
        What to do with keys that are not in ``fieldnames``. Compared
        case-insensitively.
    dialect : default "excel"
        Dialect passed to ``writer`` together with ``*args`` and ``**kwds``.

    Raises
    ------
    ValueError
        If ``extrasaction`` is neither "raise" nor "ignore".
    """

    def __init__(
        self,
        f,
        fieldnames,
        restval="",
        extrasaction="raise",
        dialect="excel",
        *args,
        **kwds
    ):
        self.fieldnames = fieldnames
        self.restval = restval
        if extrasaction.lower() not in ("raise", "ignore"):
            raise ValueError(
                "extrasaction (%s) must be 'raise' or 'ignore'" % extrasaction
            )
        self.extrasaction = extrasaction
        self.writer = writer(f, dialect, *args, **kwds)

    def writeheader(self):
        """Write a row containing the field names themselves."""
        header = dict(zip(self.fieldnames, self.fieldnames))
        return self.writerow(header)

    def _dict_to_list(self, rowdict):
        """Return the values of ``rowdict`` in ``fieldnames`` order.

        Missing keys are filled with ``restval``. When ``extrasaction`` is
        "raise", keys not listed in ``fieldnames`` raise ValueError.
        """
        # __init__ validates extrasaction case-insensitively, so compare it
        # the same way here; otherwise "RAISE" would act like "ignore".
        if self.extrasaction.lower() == "raise":
            wrong_fields = rowdict.keys() - self.fieldnames
            if wrong_fields:
                raise ValueError(
                    "dict contains fields not in fieldnames: "
                    + ", ".join([repr(x) for x in wrong_fields])
                )
        return (rowdict.get(key, self.restval) for key in self.fieldnames)

    def writerow(self, rowdict):
        """Write a single dictionary as one CSV row."""
        return self.writer.writerow(self._dict_to_list(rowdict))

    def writerows(self, rowdicts):
        """Write each dictionary in ``rowdicts`` as a CSV row."""
        return self.writer.writerows(map(self._dict_to_list, rowdicts))
|