shimmer_ssd.config
import pprint
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Literal, Self

from cfg_tools import ParsedModel, load_config_files
from cfg_tools.utils import validate_and_fill_missing
from pydantic import (
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    field_serializer,
    field_validator,
    model_validator,
)
from pydantic_core import core_schema
from shimmer import __version__
from shimmer.version import __version__ as shimmer_version
from simple_shapes_dataset import DomainType
from typing_extensions import TypedDict

from shimmer_ssd import PROJECT_DIR


class DomainModuleVariant(Enum):
    """
    This is used to select a particular DomainModule.
    Each domain can have different variants of domain modules.

    For example "attr" will load the default attribute module that uses a VAE,
    whereas "attr_legacy" will load a domain module that directly passes attributes
    to the GW.
    """

    # Attribute modules
    # -----------------
    # Attribute module using a VAE to encode the attribute vector
    attr = (DomainType.attr, "default")
    # This is the module used in Devillers et al. paper. There is no VAE and the
    # attributes are used directly as the unimodal latent representations
    attr_legacy = (DomainType.attr, "legacy")
    # Same as "attr" but adds unpaired attributes (information not available in the
    # other domains).
    attr_unpaired = (DomainType.attr, "unpaired")

    # Visual modules
    # --------------
    # Visual VAE
    v = (DomainType.v, "default")
    # Same as "v", but uses pre-saved latent VAE representation for faster training.
    # This skips the image loading and encoding and only loads latent representation.
    # The downside is that you cannot access the default image, but you can reconstruct
    # it with "decode_images".
    v_latents = (DomainType.v_latents, "default")
    # Same as "v_latents" but adds an unpaired value (random information not available
    # in the other domains).
    v_latents_unpaired = (DomainType.v_latents, "unpaired")

    # Text modules
    # ------------
    # Text domain.
    t = (DomainType.t, "default")
    t_attr = (DomainType.t, "t2attr")

    def __init__(self, kind: DomainType, model_variant: str) -> None:
        """
        The two elements of the tuple are put in the `kind` and the `model_variant`
        properties.
        """
        self.kind = kind
        self.model_variant = model_variant

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Define how this type is validated and serialized by pydantic.
        It can take a str related to the enum keys.
        It should be serialized as the enum key name.
        """

        def validate_from_str(v: str) -> DomainModuleVariant:
            """
            Use names instead of values to select enums
            """
            assert v in DomainModuleVariant.__members__, (
                f"Domain type `{v}` is not a member "
                f"of {list(DomainModuleVariant.__members__.keys())}"
            )
            return DomainModuleVariant[v]

        from_str_schema = core_schema.no_info_plain_validator_function(
            validate_from_str
        )

        def serialize_domain_variant(v: DomainModuleVariant) -> str:
            return v.name

        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(DomainModuleVariant),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_domain_variant
            ),
        )


class Logging(BaseModel):
    # List of medias (images/text) that will be logged.
    # The list is defined in individual train config (e.g. train_v.yaml, train_gw.yaml)
    filter_images: Sequence[str] | None = None
    log_train_medias_every_n_epochs: int | None = 10
    log_val_medias_every_n_epochs: int | None = 10


class Slurm(BaseModel):
    """
    Slurm config for https://github.com/bdvllrs/auto-sbatch
    """

    script: str
    run_workdir: str
    python_env: str
    command: str

    pre_modules: Sequence[str] = []
    run_modules: Sequence[str] = []
    args: Mapping[str, Any] = {}

    grid_search: Mapping[str, Sequence[Any]] | None = None
    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None


class Optim(BaseModel):
    """
    Optimizer config
    """

    # learning rate (max learning rate when using with the default OneCycle
    # learning rate scheduler)
    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)
    # for information about the other config values
    lr: float = 5e-3
    # TODO: remove as redundant with lr
    max_lr: float = 5e-3
    start_lr: float = 5e-4
    end_lr: float = 5e-4
    pct_start: float = 0.2
    weight_decay: float = 1e-5


class Training(BaseModel):
    """
    Training related config.
    As these config depend on what you are training,
    they are defined in the script related yaml files (e.g. `train_v.yaml`,
    `train_gw.yaml`).
    """

    batch_size: int = 2056
    num_workers: int = 16
    devices: int = 1  # num of devices (gpus) to use
    accelerator: str = "gpu"

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
    fast_dev_run: bool = False
    # number of training steps
    max_steps: int = 200_000
    enable_progress_bar: bool = True

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
    # you may want to set to "16-mixed" if your gpu allows mixed precision
    precision: Any = 32
    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision
    # you may want to decrease to "medium"
    float32_matmul_precision: str = "highest"

    # Optimizer config
    optim: Optim = Optim()


class ExploreVAE(BaseModel):
    # the VAE checkpoint to use
    checkpoint: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class ExploreGW(BaseModel):
    checkpoint: str  # The GW checkpoint to use
    domain: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class Visualization(BaseModel):
    # when exploring a vae
    explore_vae: ExploreVAE
    # when exploring the GW
    explore_gw: ExploreGW


class WanDB(BaseModel):
    enabled: bool = False
    # where to save the logs
    save_dir: str = ""
    # the "entity/project" values of your wandb project
    project: str = ""
    entity: str = ""
    # see https://docs.wandb.ai/ref/python/init/
    reinit: bool = False


class Dataset(BaseModel):
    """
    Simple shapes dataset infos
    """

    # Path to the dataset obtainable on https://github.com/ruflab/simple-shapes-dataset
    path: Path


class VisualModule(BaseModel):
    """
    Config for the visual domain module
    """

    num_channels: int = 3  # num channels of image to configure the VAE
    ae_dim: int = 256  # AE hidden dim
    latent_dim: int = 8  # latent dim of the VAE
    beta: float = 0.1  # beta for betaVAE

    # Whether the model is color blind.
    # It adds a transform on the dataset that averages all color channels.
    # NOTE: this only works when using the "v" domain and not "v_latents" where
    # visual latent representations are already extracted. In that case, use a different
    # presaved-path in `domain_args`.
    color_blind: bool = False


class AttributeModule(BaseModel):
    """
    Config for the attribute domain module
    """

    latent_dim: int = 10  # latent dim of the VAE
    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
    beta: float = 0.05  # for betaVAE
    coef_categories: float = 1
    coef_attributes: float = 1

    # Whether to remove rotation information from the attribute.
    nullify_rotation: bool = False


class TextModule(BaseModel):
    """
    Config for the text domain module
    """

    # which file to load for text
    # The file generated with `shapesd create` is "latent"
    # If you use an older version (like if you downloaded directly the dataset) it
    # should be "bert-base-uncased"
    latent_filename: str = "latent"
    # path to the vocab file
    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
    # path to the merge path
    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
    # max sequence length of text sequence
    seq_length: int = 64
    vocab_size: int = 822

    # VAE configuration
    latent_dim: int = 64
    hidden_dim: int = 256
    beta: float = 0.1


class DomainModules(BaseModel):
    visual: VisualModule = VisualModule()
    attribute: AttributeModule = AttributeModule()
    text: TextModule = TextModule()


class EncodersConfig(BaseModel):
    """
    Encoder architecture config
    """

    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32

    # The model will have an extra linear before and after the n_layers
    # Hence the total will be `2 + n_layers`
    n_layers: int | Mapping[DomainModuleVariant, int] = 3


class LoadedDomainConfig(BaseModel):
    """
    Domain params for the active domains of the run
    """

    # path to the pretrained module
    checkpoint_path: Path
    # domain to select
    domain_type: DomainModuleVariant
    # domain module specific arguments
    args: Mapping[str, Any] = {}


class DomainProportion(BaseModel):
    """
    Deprecated, DomainProportionT will be used instead.
    """

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated to
    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
    domains: Sequence[str]


class DomainProportionT(TypedDict):
    """This replaces `DomainProportion` in future config."""

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated to
    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
    domains: Sequence[str]


class GlobalWorkspace(BaseModel):
    # latent dim of the GW
    latent_dim: int = 12
    # whether to use the fusion GW
    use_fusion_model: bool = False
    # softmax temp for the Softmax distributed selection used by the fusion model
    selection_temperature: float = 0.2
    # whether to learn the logit scale of the contrastive loss like in the clip paper
    learn_logit_scale: bool = False
    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
    # with associated params
    vsepp_contrastive_loss: bool = False
    vsepp_margin: float = 0.2
    vsepp_measure: Literal["cosine", "order"] = "cosine"
    vsepp_max_violation: bool = True
    # whether to use linear encoders and decoders for the GW
    linear_domains: bool = False
    # whether to use bias when using linear encoders and decoders
    linear_domains_use_bias: bool = False
    # encoder architecture config
    encoders: EncodersConfig = EncodersConfig()
    # decoder architecture config
    decoders: EncodersConfig = EncodersConfig()
    # coefs of each loss. The total loss is computed using the given values and coefs
    # you can select any available loss generated by the loss functions
    loss_coefficients: Mapping[str, float] = {
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.01,
        "fused": 1.0,
    }
    # checkpoint of the GW for downstream, visualization tasks, or migrations
    checkpoint: Path | None = None
    # deprecated, use Config.domain_data_args instead
    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
        default=None, deprecated="Use `config.domain_data_args` instead."
    )
    # deprecated, use Config.domains instead
    domains: Sequence[LoadedDomainConfig] | None = Field(
        default=None, deprecated="Use `config.domains` instead."
    )
    # deprecated, use Config.domain_proportions instead
    domain_proportions: Sequence[DomainProportion] | None = Field(
        default=None, deprecated="Use `config.domain_proportions` instead."
    )


class ShimmerConfigInfo(BaseModel):
    """
    Some shimmer related config
    """

    # version of shimmer used
    version: str = shimmer_version
    # whether started in debug mode
    debug: bool = False
    # params that were passed through CLI
    cli: Any = {}


class Config(ParsedModel):
    seed: int = 0  # training seed
    ood_seed: int | None = None  # Out of distribution seed
    default_root_dir: Path = (
        Path("./checkpoints")  # Path where to save and load logs and checkpoints
    )
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # proportion of each domain in the dataset relative to the size of the dataset
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # data related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Replace the format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        in the yaml file into a Mapping[frozenset[str], float]
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        # Serialize back to the list-of-dicts yaml format (inverse of the
        # `domain_proportion_validator` transformation).
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]

    @model_validator(mode="after")
    def check_selected_domains_have_non_null_proportion(self) -> Self:
        # Every active domain must appear in `domain_proportions` with a
        # strictly positive proportion; a missing entry defaults to 0.
        for domain in self.domains:
            domain_base_name = domain.domain_type.kind.value.base
            group = frozenset([domain_base_name])
            if self.domain_proportions.get(group, 0) <= 0:
                raise ValueError(
                    "Selected domains in `domains` should have a non-zero "
                    "proportion in `domain_proportions` "
                    f"but '{domain_base_name}' has a missing or non-positive "
                    "proportion in `domain_proportions`."
                )
        return self


def use_deprecated_vals(config: Config) -> Config:
    """
    Migrate deprecated `config.global_workspace.*` values to their new
    top-level locations, emitting a DeprecationWarning for each one used.
    """
    if config.global_workspace.domain_args is not None:
        config.domain_data_args = config.global_workspace.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domains is not None:
        config.domains = config.global_workspace.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domain_proportions is not None:
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in config.global_workspace.domain_proportions
        }

        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config


def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    """
    Load and validate the project `Config` from yaml files in `path`.

    Files are loaded in order: `load_files`, then `main.yaml`, `local.yaml`
    and (when `debug_mode`) `debug.yaml` if they exist; later files override
    earlier ones. CLI arguments (when `use_cli`) override everything.
    """
    path = Path(path)
    conf_files = []
    if load_files is not None:
        conf_files.extend(load_files)
    if (path / "main.yaml").exists():
        conf_files.append("main.yaml")
    if (path / "local.yaml").exists():
        conf_files.append("local.yaml")

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    # Record shimmer version / debug state / CLI overrides alongside the config.
    config_dict.update(
        {
            "__shimmer__": {
                "version": __version__,
                "debug": debug_mode,
                "cli": cli_config,
            }
        }
    )

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )

    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf
28class DomainModuleVariant(Enum): 29 """ 30 This is used to select a particular DomainModule. 31 Each domain can have different variants of domain modules. 32 33 For example "attr" will load the default attribute module that uses a VAE, 34 whereas "attr_legacy" will load a domain module that directly passes attributes 35 to the GW. 36 """ 37 38 # Attribute modules 39 # ----------------- 40 # Attribute module using a VAE to encode the attribute vector 41 attr = (DomainType.attr, "default") 42 # This is the module used in Devillers et al. paper. There is no VAE and the 43 # attributes are used directly as the unimodal latent representations 44 attr_legacy = (DomainType.attr, "legacy") 45 # Same as "attr" but adds an unpaired attributes (information not available in the 46 # other domains). 47 attr_unpaired = (DomainType.attr, "unpaired") 48 49 # Visual modules 50 # -------------- 51 # Visual VAE 52 v = (DomainType.v, "default") 53 # Same as "v", but uses pre-saved latent VAE representation for faster training. 54 # This skips the image loading and encoding and only loads latent representation. 55 # The downside is that you cannot access the default image, but you can reconstruct 56 # it with "decode_images". 57 v_latents = (DomainType.v_latents, "default") 58 # Same as "v_latents" but adds an unpaired value (radom information not available 59 # in the other domains). 60 v_latents_unpaired = (DomainType.v_latents, "unpaired") 61 62 # Text modules 63 # ------------ 64 # Text domain. 65 t = (DomainType.t, "default") 66 t_attr = (DomainType.t, "t2attr") 67 68 def __init__(self, kind: DomainType, model_variant: str) -> None: 69 """ 70 The two elements of the tuple are put in the the `kind` and the `model_variant` 71 properties. 
72 """ 73 self.kind = kind 74 self.model_variant = model_variant 75 76 @classmethod 77 def __get_pydantic_core_schema__( 78 cls, _source_type: Any, _handler: GetCoreSchemaHandler 79 ) -> core_schema.CoreSchema: 80 """ 81 Define how this type is validated and serialized by pydantic. 82 It can take a str related to the enum keys. 83 It should be serialized as the the enum key name. 84 """ 85 86 def validate_from_str(v: str) -> DomainModuleVariant: 87 """ 88 Use names instead of values to select enums 89 """ 90 assert v in DomainModuleVariant.__members__, ( 91 f"Domain type `{v}` is not a member " 92 f"of {list(DomainModuleVariant.__members__.keys())}" 93 ) 94 return DomainModuleVariant[v] 95 96 from_str_schema = core_schema.no_info_plain_validator_function( 97 validate_from_str 98 ) 99 100 def serialize_domain_variant(v: DomainModuleVariant) -> str: 101 return v.name 102 103 return core_schema.json_or_python_schema( 104 json_schema=from_str_schema, 105 python_schema=core_schema.union_schema( 106 [ 107 core_schema.is_instance_schema(DomainModuleVariant), 108 from_str_schema, 109 ] 110 ), 111 serialization=core_schema.plain_serializer_function_ser_schema( 112 serialize_domain_variant 113 ), 114 )
This is used to select a particular DomainModule. Each domain can have different variants of domain modules.
For example "attr" will load the default attribute module that uses a VAE, whereas "attr_legacy" will load a domain module that directly passes attributes to the GW.
68 def __init__(self, kind: DomainType, model_variant: str) -> None: 69 """ 70 The two elements of the tuple are put in the the `kind` and the `model_variant` 71 properties. 72 """ 73 self.kind = kind 74 self.model_variant = model_variant
The two elements of the tuple are put in the kind
and the model_variant
properties.
117class Logging(BaseModel): 118 # List of medias (images/text) that will be logged. 119 # The list is defined in individual train config (e.g. train_v.yaml, train_gw.yaml) 120 filter_images: Sequence[str] | None = None 121 log_train_medias_every_n_epochs: int | None = 10 122 log_val_medias_every_n_epochs: int | None = 10
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
125class Slurm(BaseModel): 126 """ 127 Slurm config for https://github.com/bdvllrs/auto-sbatch 128 """ 129 130 script: str 131 run_workdir: str 132 python_env: str 133 command: str 134 135 pre_modules: Sequence[str] = [] 136 run_modules: Sequence[str] = [] 137 args: Mapping[str, Any] = {} 138 139 grid_search: Mapping[str, Sequence[Any]] | None = None 140 grid_search_exclude: Sequence[Mapping[str, Any]] | None = None
Slurm config for https://github.com/bdvllrs/auto-sbatch
143class Optim(BaseModel): 144 """ 145 Optimizer config 146 """ 147 148 # learning rate (max learning rate when using with the default OneCycle 149 # learning rate scheduler) 150 # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR) 151 # for information about the other config values 152 lr: float = 5e-3 153 # TODO: remove as redundant with lr 154 max_lr: float = 5e-3 155 start_lr: float = 5e-4 156 end_lr: float = 5e-4 157 pct_start: float = 0.2 158 weight_decay: float = 1e-5
Optimizer config
161class Training(BaseModel): 162 """ 163 Training related config. 164 As these config depend on what you are training, 165 they are defined in the script related yaml files (e.g. `train_v.yaml`, 166 `train_gw.yaml`). 167 """ 168 169 batch_size: int = 2056 170 num_workers: int = 16 171 devices: int = 1 # num of devices (gpus) to use 172 accelerator: str = "gpu" 173 174 # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run 175 fast_dev_run: bool = False 176 # number of training steps 177 max_steps: int = 200_000 178 enable_progress_bar: bool = True 179 180 # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision 181 # you may want to set to "16-mixed" if your gpu allows mixed precision 182 precision: Any = 32 183 # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision 184 # you may want to decrease to "medium" 185 float32_matmul_precision: str = "highest" 186 187 # Optimizer config 188 optim: Optim = Optim()
Training related config.
As these config depend on what you are training,
they are defined in the script related yaml files (e.g. train_v.yaml
,
train_gw.yaml
).
191class ExploreVAE(BaseModel): 192 # the VAE checkpoint to use 193 checkpoint: str 194 num_samples: int = 9 195 range_start: int = -3 196 range_end: int = 3 197 wandb_name: str | None = None
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
200class ExploreGW(BaseModel): 201 checkpoint: str # The GW checkpoint to use 202 domain: str 203 num_samples: int = 9 204 range_start: int = -3 205 range_end: int = 3 206 wandb_name: str | None = None
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
209class Visualization(BaseModel): 210 # when exploring a vae 211 explore_vae: ExploreVAE 212 # when exploring the GW 213 explore_gw: ExploreGW
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
class WanDB(BaseModel):
    """
    Weights & Biases logging configuration.
    """

    # whether wandb logging is active
    enabled: bool = False
    # where to save the logs
    save_dir: str = ""
    # the "entity/project" values of your wandb project
    project: str = ""
    entity: str = ""
    # see https://docs.wandb.ai/ref/python/init/
    reinit: bool = False
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
class Dataset(BaseModel):
    """
    Simple shapes dataset infos
    """

    # Path to the dataset obtainable on https://github.com/ruflab/simple-shapes-dataset
    path: Path
Simple shapes dataset infos
class VisualModule(BaseModel):
    """
    Config for the visual domain module
    """

    num_channels: int = 3  # num channels of image to configure the VAE
    ae_dim: int = 256  # AE hidden dim
    latent_dim: int = 8  # latent dim of the VAE
    beta: float = 0.1  # beta for betaVAE

    # Whether the model is color blind.
    # It adds a transform on the dataset that averages all color channels.
    # NOTE: this only works when using the "v" domain and not "v_latents" where
    # visual latent representations are already extracted. In that case, use a
    # different presaved-path in `domain_args`.
    color_blind: bool = False
Config for the visual domain module
class AttributeModule(BaseModel):
    """
    Config for the attribute domain module
    """

    latent_dim: int = 10  # latent dim of the VAE
    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
    beta: float = 0.05  # for betaVAE
    # loss coefficients for the category / attribute parts of the loss
    coef_categories: float = 1
    coef_attributes: float = 1

    # Whether to remove rotation information from the attribute.
    nullify_rotation: bool = False
Config for the attribute domain module
class TextModule(BaseModel):
    """
    Config for the text domain module
    """

    # which file to load for text.
    # The file generated with `shapesd create` is "latent".
    # If you use an older version (e.g. if you downloaded the dataset directly)
    # it should be "bert-base-uncased".
    latent_filename: str = "latent"
    # path to the vocab file
    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
    # path to the merges file
    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
    # max sequence length of text sequence
    seq_length: int = 64
    vocab_size: int = 822

    # VAE configuration
    latent_dim: int = 64
    hidden_dim: int = 256
    beta: float = 0.1
Config for the text domain module
class DomainModules(BaseModel):
    """
    Groups the per-domain module configurations.
    """

    visual: VisualModule = VisualModule()
    attribute: AttributeModule = AttributeModule()
    text: TextModule = TextModule()
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
class EncodersConfig(BaseModel):
    """
    Encoder architecture config
    """

    # hidden dim of the encoder, either one value for all domains or one value
    # per domain variant
    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32

    # The model will have an extra linear before and after the n_layers
    # Hence the total will be `2 + n_layers`
    n_layers: int | Mapping[DomainModuleVariant, int] = 3
Encoder architecture config
class LoadedDomainConfig(BaseModel):
    """
    Domain params for the active domains of the run
    """

    # path to the pretrained module
    checkpoint_path: Path
    # domain to select
    domain_type: DomainModuleVariant
    # domain module specific arguments
    args: Mapping[str, Any] = {}
Domain params for the active domains of the run
class DomainProportion(BaseModel):
    """
    Deprecated, DomainProportionT will be used instead.
    """

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated to
    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
    domains: Sequence[str]
Deprecated, DomainProportionT will be used instead.
class DomainProportionT(TypedDict):
    """This replaces `DomainProportion` in future config."""

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated to
    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
    domains: Sequence[str]
This replaces DomainProportion
in future config.
class GlobalWorkspace(BaseModel):
    """
    Global workspace (GW) model configuration.
    """

    # latent dim of the GW
    latent_dim: int = 12
    # whether to use the fusion GW
    use_fusion_model: bool = False
    # softmax temperature for the softmax distributed selection used by the
    # fusion model
    selection_temperature: float = 0.2
    # whether to learn the logit scale of the contrastive loss like in the
    # CLIP paper
    learn_logit_scale: bool = False
    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive
    # loss with associated params
    vsepp_contrastive_loss: bool = False
    vsepp_margin: float = 0.2
    vsepp_measure: Literal["cosine", "order"] = "cosine"
    vsepp_max_violation: bool = True
    # whether to use linear encoders and decoders for the GW
    linear_domains: bool = False
    # whether to use bias when using linear encoders and decoders
    linear_domains_use_bias: bool = False
    # encoder architecture config
    encoders: EncodersConfig = EncodersConfig()
    # decoder architecture config
    decoders: EncodersConfig = EncodersConfig()
    # coefs of each loss. The total loss is computed using the given values
    # and coefs; you can select any available loss generated by the loss
    # functions.
    loss_coefficients: Mapping[str, float] = {
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.01,
        "fused": 1.0,
    }
    # checkpoint of the GW for downstream, visualization tasks, or migrations
    checkpoint: Path | None = None
    # deprecated, use Config.domain_data_args instead
    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
        default=None, deprecated="Use `config.domain_data_args` instead."
    )
    # deprecated, use Config.domains instead
    domains: Sequence[LoadedDomainConfig] | None = Field(
        default=None, deprecated="Use `config.domains` instead."
    )
    # deprecated, use Config.domain_proportions instead
    domain_proportions: Sequence[DomainProportion] | None = Field(
        default=None, deprecated="Use `config.domain_proportions` instead."
    )
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
Attributes:
- msg: The deprecation message to be emitted.
- wrapped_property: The property instance if the deprecated field is a computed field, or
None
. - field_name: The name of the field being deprecated.
Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
Attributes:
- msg: The deprecation message to be emitted.
- wrapped_property: The property instance if the deprecated field is a computed field, or
None
. - field_name: The name of the field being deprecated.
Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
Attributes:
- msg: The deprecation message to be emitted.
- wrapped_property: The property instance if the deprecated field is a computed field, or
None
. - field_name: The name of the field being deprecated.
class ShimmerConfigInfo(BaseModel):
    """
    Some shimmer related config
    """

    # version of shimmer used
    version: str = shimmer_version
    # whether started in debug mode
    debug: bool = False
    # params that were passed through CLI
    cli: Any = {}
Some shimmer related config
class Config(ParsedModel):
    """
    Main configuration of a run.

    Built from the yaml config files (see `load_config`) and possibly
    overridden from the CLI. Validators normalize the yaml representation of
    `domain_proportions` and check that every active domain has data.
    """

    seed: int = 0  # training seed
    ood_seed: int | None = None  # Out of distribution seed
    default_root_dir: Path = (
        Path("./checkpoints")  # Path where to save and load logs and checkpoints
    )
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # proportion of each domain in the dataset relative to the size of the dataset
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # data related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Replace the format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        in the yaml file into a Mapping[frozenset[str], float]
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        # Inverse of `domain_proportion_validator`: dump the mapping back into
        # the list-of-dicts form used in the yaml files.
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]

    @model_validator(mode="after")
    def check_selected_domains_have_non_null_proportion(self) -> Self:
        # Every active domain needs a strictly positive proportion, otherwise
        # no data would ever be loaded for it.
        for domain in self.domains:
            domain_base_name = domain.domain_type.kind.value.base
            group = frozenset([domain_base_name])
            if self.domain_proportions.get(group, 0) <= 0:
                # Distinguish "missing from the mapping" from "present but set
                # to a non-positive value" so the error message is accurate in
                # both cases (the original message claimed the domain was
                # missing even when its proportion was explicitly 0).
                if group in self.domain_proportions:
                    detail = (
                        f"but '{domain_base_name}' has a proportion of "
                        f"{self.domain_proportions[group]}."
                    )
                else:
                    detail = (
                        f"but '{domain_base_name}' is not part of "
                        "`domain_proportions`."
                    )
                raise ValueError(
                    "Selected domains in `domains` should have a non-zero "
                    "proportion in `domain_proportions` " + detail
                )
        return self
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
- __class_vars__: The names of the class variables defined on the model.
- __private_attributes__: Metadata about the private attributes of the model.
- __signature__: The synthesized
__init__
[Signature
][inspect.Signature] of the model. - __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
- __pydantic_core_schema__: The core schema of the model.
- __pydantic_custom_init__: Whether the model has a custom
__init__
function. - __pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces
Model.__validators__
andModel.__root_validators__
from Pydantic V1. - __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
- __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
- __pydantic_post_init__: The name of the post-init method for the model, if defined.
- __pydantic_root_model__: Whether the model is a [
RootModel
][pydantic.root_model.RootModel]. - __pydantic_serializer__: The
pydantic-core
SchemaSerializer
used to dump instances of the model. - __pydantic_validator__: The
pydantic-core
SchemaValidator
used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [
FieldInfo
][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [
ComputedFieldInfo
][pydantic.fields.ComputedFieldInfo] objects. - __pydantic_extra__: A dictionary containing extra values, if [
extra
][pydantic.config.ConfigDict.extra] is set to'allow'
. - __pydantic_fields_set__: The names of fields explicitly set during instantiation.
- __pydantic_private__: Values of private attributes set on the model instance.
443 @field_validator("domain_proportions", mode="before") 444 @classmethod 445 def domain_proportion_validator( 446 cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float] 447 ) -> Mapping[frozenset[str], float]: 448 """ 449 Replace the format: 450 ``` 451 - domains: ["v"] 452 proportion: 1.0 453 ``` 454 in the yaml file into a Mapping[frozenset[str], float] 455 """ 456 if isinstance(value, Mapping): 457 return value 458 else: 459 return {frozenset(item["domains"]): item["proportion"] for item in value}
Replace the format:
- domains: ["v"]
proportion: 1.0
in the yaml file into a Mapping[frozenset[str], float]
461 @field_serializer("domain_proportions") 462 def serialize_domain_proportions( 463 self, domain_proportions: Mapping[frozenset[str], float], _info 464 ) -> list[DomainProportionT]: 465 return [ 466 {"domains": list(domains), "proportion": proportion} 467 for domains, proportion in domain_proportions.items() 468 ]
470 @model_validator(mode="after") 471 def check_selected_domains_have_non_null_proportion(self) -> Self: 472 for domain in self.domains: 473 domain_base_name = domain.domain_type.kind.value.base 474 group = frozenset([domain_base_name]) 475 if self.domain_proportions.get(group, 0) <= 0: 476 raise ValueError( 477 "Selected domains in `domains` should have a non-zero " 478 "proportion in `domain_proportions` " 479 f"but '{domain_base_name}' is not part of `domain_proportions`." 480 ) 481 return self
def use_deprecated_vals(config: Config) -> Config:
    """
    Migrate values set through deprecated `config.global_workspace.*` fields
    to their new location on `config`, warning for each deprecated value used.
    """
    gw = config.global_workspace
    if gw.domain_args is not None:
        config.domain_data_args = gw.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if gw.domains is not None:
        config.domains = gw.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if gw.domain_proportions is not None:
        # Convert the deprecated list form to the new mapping form.
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in gw.domain_proportions
        }
        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config
def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    """
    Load and validate a `Config` from the yaml files found in `path`.

    Loads, in order, the given `load_files`, then "main.yaml", "local.yaml"
    and (in debug mode) "debug.yaml" when they exist, merges in CLI
    overrides when `use_cli` is set, records shimmer version info under
    `__shimmer__`, and applies deprecated-value migrations.
    """
    path = Path(path)
    conf_files: list[str] = list(load_files) if load_files is not None else []
    for candidate in ("main.yaml", "local.yaml"):
        if (path / candidate).exists():
            conf_files.append(candidate)

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    # Record provenance info about this run.
    config_dict["__shimmer__"] = {
        "version": __version__,
        "debug": debug_mode,
        "cli": cli_config,
    }

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )
    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf