## @package db_file_reader
# Module caffe2.python.db_file_reader
from caffe2.python import core, scope, workspace, _import_c_extension as C
from caffe2.python.dataio import Reader
from caffe2.python.dataset import Dataset
from caffe2.python.schema import from_column_list
import os
class DBFileReader(Reader):
default_name_suffix = 'db_file_reader'
"""Reader reads from a DB file.
Example usage:
db_file_reader = DBFileReader(db_path='/tmp/cache.db', db_type='LevelDB')
Args:
db_path: str.
db_type: str. DB type of file. A db_type is registed by
`REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
name: str or None. Name of DBFileReader.
Optional name to prepend to blobs that will store the data.
Default to '<db_name>_<default_name_suffix>'.
batch_size: int.
How many examples are read for each time the read_net is run.
loop_over: bool.
If True given, will go through examples in random order endlessly.
field_names: List[str]. If the schema.field_names() should not in
alphabetic order, it must be specified.
Otherwise, schema will be automatically restored with
schema.field_names() sorted in alphabetic order.
"""
def __init__(
self,
db_path,
db_type,
name=None,
batch_size=100,
loop_over=False,
field_names=None,
):
assert db_path is not None, "db_path can't be None."
assert db_type in C.registered_dbs(), \
"db_type [{db_type}] is not available. \n" \
"Choose one of these: {registered_dbs}.".format(
db_type=db_type,
registered_dbs=C.registered_dbs(),
)
self.db_path = os.path.expanduser(db_path)
self.db_type = db_type
self.name = name or '{db_name}_{default_name_suffix}'.format(
db_name=self._extract_db_name_from_db_path(),
default_name_suffix=self.default_name_suffix,
)
self.batch_size = batch_size
self.loop_over = loop_over
# Before self._init_reader_schema(...),
# self.db_path and self.db_type are required to be set.
super(DBFileReader, self).__init__(self._init_reader_schema(field_names))
self.ds = Dataset(self._schema, self.name + '_dataset')
self.ds_reader = None
    def _init_name(self, name):
        return name or '{}_{}'.format(
            self._extract_db_name_from_db_path(), self.default_name_suffix,
        )
def _init_reader_schema(self, field_names=None):
"""Restore a reader schema from the DB file.
If `field_names` given, restore scheme according to it.
Overwise, loade blobs from the DB file into the workspace,
and restore schema from these blob names.
It is also assumed that:
1). Each field of the schema have corresponding blobs
stored in the DB file.
2). Each blob loaded from the DB file corresponds to
a field of the schema.
3). field_names in the original schema are in alphabetic order,
since blob names loaded to the workspace from the DB file
will be in alphabetic order.
Load a set of blobs from a DB file. From names of these blobs,
restore the DB file schema using `from_column_list(...)`.
Returns:
schema: schema.Struct. Used in Reader.__init__(...).
"""
if field_names:
return from_column_list(field_names)
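        # For example (hypothetical field names): passing
        # field_names=['label', 'feature'] preserves that order, whereas the
        # automatic restoration below sorts the loaded blob names and would
        # yield ['feature', 'label'].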
if self.db_type == "log_file_db":
assert os.path.exists(self.db_path), \
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
with core.NameScope(self.name):
            # blob_prefix is used to avoid name conflicts in the workspace
blob_prefix = scope.CurrentNameScope()
workspace.RunOperatorOnce(
core.CreateOperator(
'Load',
[],
[],
absolute_path=True,
db=self.db_path,
db_type=self.db_type,
load_all=True,
add_prefix=blob_prefix,
)
)
col_names = [
blob_name[len(blob_prefix):] for blob_name in sorted(workspace.Blobs())
if blob_name.startswith(blob_prefix)
]
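        # For example (hypothetical blob name): with self.name set to
        # 'sample_db_file_reader', blob_prefix is 'sample_db_file_reader/',
        # so a loaded blob named 'sample_db_file_reader/feature_a' yields
        # the column name 'feature_a'.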
schema = from_column_list(col_names)
return schema
def setup_ex(self, init_net, finish_net):
"""From the Dataset, create a _DatasetReader and setup a init_net.
Make sure the _init_field_blobs_as_empty(...) is only called once.
Because the underlying NewRecord(...) creats blobs by calling
NextScopedBlob(...), so that references to previously-initiated
empty blobs will be lost, causing accessibility issue.
"""
if self.ds_reader:
self.ds_reader.setup_ex(init_net, finish_net)
else:
self._init_field_blobs_as_empty(init_net)
self._feed_field_blobs_from_db_file(init_net)
self.ds_reader = self.ds.random_reader(
init_net,
batch_size=self.batch_size,
loop_over=self.loop_over,
)
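            # sort_and_shuffle(...) shuffles the example order and
            # computeoffset(...) resets the cursor and computes the offsets
            # used when reading random batches; both are helpers of the
            # _DatasetRandomReader returned by random_reader(...)
            # (see caffe2.python.dataset).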
self.ds_reader.sort_and_shuffle(init_net)
self.ds_reader.computeoffset(init_net)
def read(self, read_net):
assert self.ds_reader, 'setup_ex must be called first'
return self.ds_reader.read(read_net)
def _init_field_blobs_as_empty(self, init_net):
"""Initialize dataset field blobs by creating an empty record"""
with core.NameScope(self.name):
self.ds.init_empty(init_net)
def _feed_field_blobs_from_db_file(self, net):
"""Load from the DB file at db_path and feed dataset field blobs"""
if self.db_type == "log_file_db":
assert os.path.exists(self.db_path), \
'db_path [{db_path}] does not exist'.format(db_path=self.db_path)
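        # The Load op below reads the blobs stored in the DB under
        # `source_blob_names` (the plain field names) and writes them into
        # the dataset's name-scoped field blobs created by
        # _init_field_blobs_as_empty(...).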
net.Load(
[],
self.ds.get_blobs(),
db=self.db_path,
db_type=self.db_type,
absolute_path=True,
source_blob_names=self.ds.field_names(),
)
def _extract_db_name_from_db_path(self):
"""Extract DB name from DB path
E.g. given self.db_path=`/tmp/sample.db`, or
self.db_path = `dper_test_data/cached_reader/sample.db`
it returns `sample`.
Returns:
db_name: str.
"""
return os.path.basename(self.db_path).rsplit('.', 1)[0]