mirror of
https://github.com/twitter/the-algorithm-ml.git
synced 2024-12-24 15:21:47 +01:00
67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
"""Base class for all config (forbids extra fields)."""
|
|
|
|
import collections
|
|
import functools
|
|
import yaml
|
|
|
|
import pydantic
|
|
|
|
|
|
class BaseConfig(pydantic.BaseModel):
|
|
"""Base class for all derived config classes.
|
|
|
|
This class provides some convenient functionality:
|
|
- Disallows extra fields when constructing an object. User error
|
|
should be reduced by exact arguments.
|
|
- "one_of" fields. A subclass can group optional fields and enforce
|
|
that only one of the fields be set. For example:
|
|
|
|
```
|
|
class ExampleConfig(BaseConfig):
|
|
x: int = Field(None, one_of="group_1")
|
|
y: int = Field(None, one_of="group_1")
|
|
|
|
ExampleConfig(x=1) # ok
|
|
ExampleConfig(y=1) # ok
|
|
ExampleConfig(x=1, y=1) # throws error
|
|
```
|
|
"""
|
|
|
|
class Config:
|
|
"""Forbids extras."""
|
|
|
|
extra = pydantic.Extra.forbid # noqa
|
|
|
|
@classmethod
|
|
@functools.lru_cache()
|
|
def _field_data_map(cls, field_data_name):
|
|
"""Create a map of fields with provided the field data."""
|
|
schema = cls.schema()
|
|
one_of = collections.defaultdict(list)
|
|
for field, fdata in schema["properties"].items():
|
|
if field_data_name in fdata:
|
|
one_of[fdata[field_data_name]].append(field)
|
|
return one_of
|
|
|
|
@pydantic.root_validator
|
|
def _one_of_check(cls, values):
|
|
"""Validate that all 'one of' fields are appear exactly once."""
|
|
one_of_map = cls._field_data_map("one_of")
|
|
for one_of, field_names in one_of_map.items():
|
|
if sum([values.get(n, None) is not None for n in field_names]) != 1:
|
|
raise ValueError(f"Exactly one of {','.join(field_names)} required.")
|
|
return values
|
|
|
|
@pydantic.root_validator
|
|
def _at_most_one_of_check(cls, values):
|
|
"""Validate that all 'at_most_one_of' fields appear at most once."""
|
|
at_most_one_of_map = cls._field_data_map("at_most_one_of")
|
|
for one_of, field_names in at_most_one_of_map.items():
|
|
if sum([values.get(n, None) is not None for n in field_names]) > 1:
|
|
raise ValueError(f"At most one of {','.join(field_names)} can be set.")
|
|
return values
|
|
|
|
def pretty_print(self) -> str:
|
|
"""Return a human legible (yaml) representation of the config useful for logging."""
|
|
return yaml.dump(self.dict())
|