# ML Collections
ML Collections is a library of Python Collections designed for ML use cases.
## ConfigDict
`ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot
access to nested elements. Together, they are intended to be the main way of
expressing configurations of experiments and models.
This document describes example usage of `ConfigDict`, `FrozenConfigDict`, and
`FieldReference`.
### Features
* Dot-based access to fields.
* Locking mechanism to prevent spelling mistakes.
* Lazy computation.
* `FrozenConfigDict` class which is immutable and hashable.
* Type safety.
* "Did you mean" functionality.
* Human-readable printing (with valid references and cycles), using valid YAML
  format.
* Fields can be passed as keyword arguments using the `**` operator.
* There is one exception to the strong type-safety of the ConfigDict: `int`
values can be passed in to fields of type `float`. In such a case, the value
is type-converted to a `float` before being stored. (Back in the day of
Python 2, there was a similar exception to allow both `str` and `unicode`
values in string fields.)
### Basic Usage
```python
from ml_collections import config_dict
cfg = config_dict.ConfigDict()
cfg.float_field = 12.6
cfg.integer_field = 123
cfg.another_integer_field = 234
cfg.nested = config_dict.ConfigDict()
cfg.nested.string_field = 'tom'
print(cfg.integer_field) # Prints 123.
print(cfg['integer_field']) # Prints 123 as well.
try:
  cfg.integer_field = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)
cfg.float_field = 12  # Works: `int` values can be assigned to `float` fields.
cfg.nested.string_field = u'bob'  # `str` fields can store Unicode strings.
print(cfg)
```
### FrozenConfigDict
A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`:
```python
from ml_collections import config_dict
initial_dictionary = {
    'int': 1,
    'list': [1, 2],
    'tuple': (1, 2, 3),
    'set': {1, 2, 3, 4},
    'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}
cfg = config_dict.ConfigDict(initial_dictionary)
frozen_dict = config_dict.FrozenConfigDict(initial_dictionary)
print(frozen_dict.tuple) # Prints tuple (1, 2, 3)
print(frozen_dict.list) # Prints tuple (1, 2)
print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2)
frozen_cfg = config_dict.FrozenConfigDict(cfg)
print(frozen_cfg == frozen_dict) # True
print(hash(frozen_cfg) == hash(frozen_dict)) # True
try:
  frozen_dict.int = 2  # Raises AttributeError as FrozenConfigDict is immutable.
except AttributeError as e:
  print(e)
# Converting between `FrozenConfigDict` and `ConfigDict`:
thawed_frozen_cfg = config_dict.ConfigDict(frozen_dict)
print(thawed_frozen_cfg == cfg) # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()
print(frozen_cfg_to_cfg == cfg) # True
```
### FieldReferences and placeholders
A `FieldReference` is useful for having multiple fields use the same value. It
can also be used for [lazy computation](#lazy-computation).
You can use `placeholder()` as a shortcut to create a `FieldReference` (field)
with a `None` default value. This is useful if a program uses optional
configuration fields.
```python
from ml_collections import config_dict
placeholder = config_dict.FieldReference(0)
cfg = config_dict.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = config_dict.ConfigDict()
cfg.nested.placeholder = placeholder
try:
  cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)
cfg.optional = 1555  # Works fine.
cfg.placeholder = 1  # Changes the value of both placeholder and
                     # nested.placeholder fields.
print(cfg)
```
Note that the indirection provided by a `FieldReference` is lost when its value
is read back through a `ConfigDict`: attribute access returns the current value,
not the reference.
```python
from ml_collections import config_dict
placeholder = config_dict.FieldReference(0)
cfg = config_dict.ConfigDict()
cfg.field1 = placeholder
cfg.field2 = placeholder  # This field will be tied to cfg.field1.
cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.
```
### Lazy computation
Using a `FieldReference` in a standard operation (addition, subtraction,
multiplication, etc.) will return another `FieldReference` that points to the
original's value. You can use `FieldReference.get()` to execute the operations
and get the reference's computed value, and `FieldReference.set()` to change the
original reference's value.
```python
from ml_collections import config_dict
ref = config_dict.FieldReference(1)
print(ref.get()) # Prints 1
add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten
add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer
print(add_ten) # Prints 11
print(add_ten_lazy.get()) # Prints 11 because ref's value is 1
# Addition is lazily computed for FieldReferences so changing ref will change
# the value that add_ten_lazy evaluates to. The plain integer add_ten is fixed.
ref.set(5)
print(add_ten) # Prints 11
print(add_ten_lazy.get()) # Prints 15 because ref's value is 5
```
If a `FieldReference` has `None` as its original value, or any operation has an
argument of `None`, then the lazy computation will evaluate to `None`.
We can also use fields in a `ConfigDict` in lazy computation. In this case a
field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference_field = config_dict.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5
# No lazy evaluations because we didn't use get_ref()
config.no_lazy = config.integer_field * config.float_field
# This will lazily evaluate ONLY config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field
# This will lazily evaluate ONLY config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')
# This will lazily evaluate BOTH config.integer_field and config.float_field
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))
config.integer_field = 3
print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value
print(config.lazy_integer) # Prints 7.5
config.float_field = 3.5
print(config.lazy_float) # Prints 7.0
print(config.lazy_both) # Prints 10.5
```
#### Changing lazily computed values
Lazily computed values in a ConfigDict can be overridden in the same way as
regular values. The reference to the `FieldReference` used for the lazy
computation will be lost and all computations downstream in the reference graph
will use the new value.
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10
config.reference_1 = config.get_ref('reference') + 20
config.reference_1_0 = config.get_ref('reference_1') + 100
print(config.reference) # Prints 1.
print(config.reference_0) # Prints 11.
print(config.reference_1) # Prints 21.
print(config.reference_1_0) # Prints 121.
config.reference_1 = 30
print(config.reference) # Prints 1 (unchanged).
print(config.reference_0) # Prints 11 (unchanged).
print(config.reference_1) # Prints 30.
print(config.reference_1_0) # Prints 130.
```
#### Cycles
You cannot create cycles using references. Fortunately,
[the only way](#changing-lazily-computed-values) to create a cycle is by
assigning a computed field to one that *is not* the result of computation. This
is forbidden:
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10
try:
  # Raises a MutabilityError because setting config.integer_field would
  # cause a cycle.
  config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
  print(e)
```
#### One-way references
One gotcha with `get_ref` is that it creates a bi-directional dependency when no operations are performed on the value.
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference')
config.reference_0 = 2
print(config.reference) # Prints 2.
print(config.reference_0) # Prints 2.
```
This can be avoided by using `get_oneway_ref` instead of `get_ref`.
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.reference = 1
config.reference_0 = config.get_oneway_ref('reference')
config.reference_0 = 2
print(config.reference) # Prints 1.
print(config.reference_0) # Prints 2.
```
### Advanced usage
Here are some more advanced examples showing lazy computation with different
operators and data types.
```python
from ml_collections import config_dict
config = config_dict.ConfigDict()
config.float_field = 12.6
config.integer_field = 123
config.list_field = [0, 1, 2]
config.float_multiply_field = config.get_ref('float_field') * 3
print(config.float_multiply_field) # Prints 37.8
config.float_field = 10.0
print(config.float_multiply_field) # Prints 30.0
config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5]
config.list_field = [-1]
print(config.longer_list_field) # Prints [-1, 3, 4, 5]
# Both operands can be references
config.ref_subtraction = (
    config.get_ref('float_field') - config.get_ref('integer_field'))
print(config.ref_subtraction) # Prints -113.0
config.integer_field = 10
print(config.ref_subtraction) # Prints 0.0
```
### Equality checking
You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict`
and `FrozenConfigDict` objects.
```python
from ml_collections import config_dict
dict_1 = {'list': [1, 2]}
dict_2 = {'list': (1, 2)}
cfg_1 = config_dict.ConfigDict(dict_1)
frozen_cfg_1 = config_dict.FrozenConfigDict(dict_1)
frozen_cfg_2 = config_dict.FrozenConfigDict(dict_2)
# True because FrozenConfigDict converts lists to tuples
print(frozen_cfg_1.items() == frozen_cfg_2.items())
# False because == distinguishes the underlying difference
print(frozen_cfg_1 == frozen_cfg_2)
# False because == distinguishes these types
print(frozen_cfg_1 == cfg_1)
# But eq_as_configdict() treats both as ConfigDict, so these are True:
print(frozen_cfg_1.eq_as_configdict(cfg_1))
print(cfg_1.eq_as_configdict(frozen_cfg_1))
```
### Equality checking with lazy computation
Equality checks compare the computed values. Two configs are equal even if
their computations differ, as long as those computations produce the same
values.
```python
from ml_collections import config_dict
cfg_1 = config_dict.ConfigDict()
cfg_1.a = 1
cfg_1.b = cfg_1.get_ref('a') + 2
cfg_2 = config_dict.ConfigDict()
cfg_2.a = 1
cfg_2.b = cfg_2.get_ref('a') * 3
# True because all computed values are the same
print(cfg_1 == cfg_2)
```
### Locking and copying
Here is an example with `lock()` and `deepcopy()`:
```python
import copy
from ml_collections import config_dict
cfg = config_dict.ConfigDict()
cfg.integer_field = 123
# Locking prohibits adding and deleting fields but allows modifying the
# values of existing fields.
cfg.lock()
try:
  cfg.intagar_field = 124  # Misspelled name: this field does not exist.
except AttributeError as e:  # Raises AttributeError and suggests valid field.
  print(e)
with cfg.unlocked():
  cfg.intagar_field = 1555  # Works fine.
# Get a copy of the config dict.
new_cfg = copy.deepcopy(cfg)
new_cfg.integer_field = -123 # Works fine.
print(cfg)
print(new_cfg)
```
Output:
```
'Key "intagar_field" does not exist and cannot be added since the config is locked. Other fields present: "{\'integer_field\': 123}"\nDid you mean "integer_field" instead of "intagar_field"?'
intagar_field: 1555
integer_field: 123
intagar_field: 1555
integer_field: -123
```
### Dictionary attributes and initialization
```python
from ml_collections import config_dict
referenced_dict = {'inner_float': 3.14}
d = {
    'referenced_dict_1': referenced_dict,
    'referenced_dict_2': referenced_dict,
    'list_containing_dict': [{'key': 'value'}],
}
# We can initialize on a dictionary
cfg = config_dict.ConfigDict(d)
# Reference structure is preserved
print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True
# And the dict attributes have been converted to ConfigDict
print(type(cfg.referenced_dict_1)) # ConfigDict
# However, the initialization does not look inside of lists, so dicts inside
# lists are not converted to ConfigDict
print(type(cfg.list_containing_dict[0])) # dict
```
### More Examples
For more examples, take a look at
[`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples).
For examples and gotchas specifically about initializing a ConfigDict, see
[`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py).
## Config Flags
This library adds flag definitions to `absl.flags` to handle config files. It
does not wrap `absl.flags` so if using any standard flag definitions alongside
config file flags, users must also import `absl.flags`.
Currently, this module adds two new flag types, namely `DEFINE_config_file`
which accepts a path to a Python file that generates a configuration, and
`DEFINE_config_dict` which accepts a configuration directly. Configurations are
dict-like structures (see [ConfigDict](#configdict)) whose nested elements
can be overridden using special command-line flags. See the examples below
for more details.
### Usage
Use `ml_collections.config_flags` alongside `absl.flags`. For
example:
`script.py`:
```python
from absl import app
from absl import flags
from ml_collections import config_flags
_CONFIG = config_flags.DEFINE_config_file('my_config')
# absl flag definitions take a help string as their third argument.
_MY_FLAG = flags.DEFINE_integer('my_flag', None, 'An example integer flag.')

def main(_):
  print(_CONFIG.value)
  print(_MY_FLAG.value)

if __name__ == '__main__':
  app.run(main)
```
`config.py`:
```python
# Note that this is a valid Python script.
# get_config() can return an arbitrary dict-like object. However, it is advised
# to use ml_collections.config_dict.ConfigDict.
# See ml_collections/config_dict/examples/config_dict_basic.py
from ml_collections import config_dict
def get_config():
  config = config_dict.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = config_dict.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
```
Warning: If you are using a pickle-based distributed programming framework such
as [Launchpad](https://github.com/deepmind/launchpad#readme), be aware of
limitations on the structure of this script that are
[described below](#config_files_and_pickling).
Now, after running:
```bash
python script.py --my_config=config.py \
--my_config.field1=8 \
--my_config.nested.field=2.1 \
--my_config.tuple='(1, 2, (1, 2))'
```
we get:
```
field1: 8
field2: tom
nested:
field: 2.1
tuple: !!python/tuple
- 1
- 2
- !!python/tuple
- 1
- 2
```
Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`. The main
difference is that the configuration is defined in `script.py` instead of in a
separate file.
`script.py`:
```python
from absl import app
from ml_collections import config_dict
from ml_collections import config_flags
config = config_dict.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = config_dict.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)
_CONFIG = config_flags.DEFINE_config_dict('my_config', config)

def main(_):
  print(_CONFIG.value)

if __name__ == '__main__':
  app.run(main)
```
`config_file` flags are compatible with the command-line flag syntax. All the
following options are supported for non-boolean values in configurations:
* `-(-)config.field=value`
* `-(-)config.field value`
Options for boolean values are slightly different:
* `-(-)config.boolean_field`: set boolean value to True.
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or
`False`.
Note that `-(-)config.boolean_field value` is not supported.
### Parameterising the get_config() function
It's sometimes useful to be able to pass parameters into `get_config`, and
change what is returned based on this configuration. One example is if you are
grid searching over parameters which have a different hierarchical structure -
the flag needs to be present in the resulting ConfigDict. It would be possible
to include the union of all possible leaf values in your ConfigDict,
but this produces a confusing config result as you have to remember which
parameters will actually have an effect and which won't.
A better system is to pass some configuration, indicating which structure of
ConfigDict should be returned. An example is the following config file:
```python
from ml_collections import config_dict
def get_config(config_string):
  possible_structures = {
      'linear': config_dict.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': config_dict.ConfigDict({
              'output_size': 42,
          }),
      }),
      'lstm': config_dict.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': config_dict.ConfigDict({
              'hidden_size': 108,
          }),
      }),
  }
  return possible_structures[config_string]
```
The value of `config_string` will be anything that is to the right of the first
colon in the config file path, if one exists. If no colon exists, no value is
passed to `get_config` (producing a `TypeError` if `get_config` expects a value).
The above example can be run like:
```bash
python script.py -- --config=path_to_config.py:linear \
--config.model_config.output_size=256
```
or like:
```bash
python script.py -- --config=path_to_config.py:lstm \
--config.model_config.hidden_size=512
```
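
The colon rule described above can be illustrated in plain Python. This is only a sketch of the rule as stated, not the library's own parsing code; `split_config_spec` is a hypothetical helper.

```python
def split_config_spec(spec):
  """Hypothetical helper that mirrors the colon rule described above."""
  path, colon, config_string = spec.partition(':')
  # Without a colon there is nothing to pass to get_config().
  return path, (config_string if colon else None)


with_string = split_config_spec('path_to_config.py:linear')
without_string = split_config_spec('path_to_config.py')
# Only the first colon splits; later colons stay in the config string.
two_colons = split_config_spec('path_to_config.py:a:b')

print(with_string)
print(without_string)
print(two_colons)
```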
### Additional features
* Loads any valid Python script which defines a `get_config()` function
  returning any Python object.
* Automatic locking of the loaded object, if the loaded object defines a
  callable `.lock()` method.
* Supports command-line overriding of arbitrarily nested values in dict-like
  objects (with key/attribute based getters/setters) of the following types:
    * `int`
    * `float`
    * `bool`
    * `str`
    * `tuple` (but **not** `list`)
    * `enum.Enum`
* Overriding is type safe.
* Overriding of a `tuple` can be done by passing in the `tuple` value as a
  string (see the example in the [Usage](#usage) section).
* The overriding `tuple` object can be of a different length and have
  different item types than the original. Nested tuples are also supported.
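
To make the tuple points concrete, here is a rough sketch of how a string such as the one in the [Usage](#usage) example can be turned into a nested tuple. It uses the standard library's `ast.literal_eval` purely as an illustration; `parse_tuple_override` is a hypothetical helper, and how `config_flags` parses the string internally is not described here.

```python
import ast


def parse_tuple_override(text):
  """Hypothetical parser for a tuple passed as a command-line string."""
  value = ast.literal_eval(text)
  if not isinstance(value, tuple):
    raise ValueError(f'Expected a tuple, got {type(value).__name__}')
  return value


original = (1, 2, 3)
# The override has the same length here, but a nested tuple as its last item.
override = parse_tuple_override('(1, 2, (1, 2))')
print(override)
```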
### Config Files and Pickling {#config_files_and_pickling}
This is likely to be troublesome:
```python {.bad}
import dataclasses

@dataclasses.dataclass
class MyRecord:
  num_balloons: int
  color: str

def get_config():
  return MyRecord(num_balloons=99, color='red')
```
This is not:
```python {.good}
import dataclasses

def get_config():
  @dataclasses.dataclass
  class MyRecord:
    num_balloons: int
    color: str
  return MyRecord(num_balloons=99, color='red')
```
#### Explanation
A config file is a Python module but it is not imported through Python's usual
module-importing mechanism.
Meanwhile, serialization libraries such as [`cloudpickle`](
https://github.com/cloudpipe/cloudpickle#readme) (which is used by [Launchpad](
https://github.com/deepmind/launchpad#readme)) and [Apache Beam](
https://beam.apache.org/) expect to be able to pickle an object without also
pickling every type to which it refers, on the assumption that types defined
at module scope can later be reconstructed simply by re-importing the modules
in which they are defined.
That assumption does not hold for a type that is defined at module scope in a
config file, because the config file can't be imported the usual way. The
symptom of this will be an `ImportError` when unpickling an object.
The treatment is to move types from module scope into `get_config()` so that
they will be serialized along with the values that have those types.
## Authors
* Sergio Gómez Colmenarejo - sergomez@google.com
* Wojciech Marian Czarnecki - lejlot@google.com
* Nicholas Watters
* Mohit Reddy - mohitreddy@google.com