the-algorithm-ml/core/config/base_config.py

67 lines
2.2 KiB
Python

"""Base class for all config (forbids extra fields)."""
import collections
import functools
import yaml
import pydantic
class BaseConfig(pydantic.BaseModel):
"""Base class for all derived config classes.
This class provides some convenient functionality:
- Disallows extra fields when constructing an object. User error
should be reduced by exact arguments.
- "one_of" fields. A subclass can group optional fields and enforce
that only one of the fields be set. For example:
```
class ExampleConfig(BaseConfig):
x: int = Field(None, one_of="group_1")
y: int = Field(None, one_of="group_1")
ExampleConfig(x=1) # ok
ExampleConfig(y=1) # ok
ExampleConfig(x=1, y=1) # throws error
```
"""
class Config:
"""Forbids extras."""
extra = pydantic.Extra.forbid # noqa
@classmethod
@functools.lru_cache()
def _field_data_map(cls, field_data_name):
"""Create a map of fields with provided the field data."""
schema = cls.schema()
one_of = collections.defaultdict(list)
for field, fdata in schema["properties"].items():
if field_data_name in fdata:
one_of[fdata[field_data_name]].append(field)
return one_of
@pydantic.root_validator
def _one_of_check(cls, values):
"""Validate that all 'one of' fields are appear exactly once."""
one_of_map = cls._field_data_map("one_of")
for one_of, field_names in one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) != 1:
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
return values
@pydantic.root_validator
def _at_most_one_of_check(cls, values):
"""Validate that all 'at_most_one_of' fields appear at most once."""
at_most_one_of_map = cls._field_data_map("at_most_one_of")
for one_of, field_names in at_most_one_of_map.items():
if sum([values.get(n, None) is not None for n in field_names]) > 1:
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
return values
def pretty_print(self) -> str:
"""Return a human legible (yaml) representation of the config useful for logging."""
return yaml.dump(self.dict())