shimmer_ssd.config
```python
import pprint
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Literal

from cfg_tools import ParsedModel, load_config_files
from cfg_tools.utils import validate_and_fill_missing
from pydantic import (
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    field_serializer,
    field_validator,
)
from pydantic_core import core_schema
from shimmer import __version__
from shimmer.version import __version__ as shimmer_version
from simple_shapes_dataset import DomainType
from typing_extensions import TypedDict

from shimmer_ssd import PROJECT_DIR
```
```python
class DomainModuleVariant(Enum):
    """
    This is used to select a particular DomainModule.
    Each domain can have different variants of domain modules.

    For example, "attr" will load the default attribute module that uses a VAE,
    whereas "attr_legacy" will load a domain module that directly passes attributes
    to the GW.
    """

    # Attribute modules
    # -----------------
    # Attribute module using a VAE to encode the attribute vector
    attr = (DomainType.attr, "default")
    # This is the module used in the Devillers et al. paper. There is no VAE and the
    # attributes are used directly as the unimodal latent representations
    attr_legacy = (DomainType.attr, "legacy")
    # Same as "attr" but adds unpaired attributes (information not available in the
    # other domains).
    attr_unpaired = (DomainType.attr, "unpaired")

    # Visual modules
    # --------------
    # Visual VAE
    v = (DomainType.v, "default")
    # Same as "v", but uses pre-saved latent VAE representations for faster training.
    # This skips the image loading and encoding and only loads latent representations.
    # The downside is that you cannot access the original image, but you can
    # reconstruct it with "decode_images".
    v_latents = (DomainType.v_latents, "default")
    # Same as "v_latents" but adds an unpaired value (random information not available
    # in the other domains).
    v_latents_unpaired = (DomainType.v_latents, "unpaired")

    # Text modules
    # ------------
    # Text domain.
    t = (DomainType.t, "default")
    t_attr = (DomainType.t, "t2attr")

    def __init__(self, kind: DomainType, model_variant: str) -> None:
        """
        The two elements of the tuple are put in the `kind` and the `model_variant`
        properties.
        """
        self.kind = kind
        self.model_variant = model_variant

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Define how this type is validated and serialized by pydantic.
        It can take a str matching one of the enum keys.
        It is serialized as the enum key name.
        """

        def validate_from_str(v: str) -> DomainModuleVariant:
            """
            Use names instead of values to select enums
            """
            assert v in DomainModuleVariant.__members__, (
                f"Domain type `{v}` is not a member "
                f"of {list(DomainModuleVariant.__members__.keys())}"
            )
            return DomainModuleVariant[v]

        from_str_schema = core_schema.no_info_plain_validator_function(
            validate_from_str
        )

        def serialize_domain_variant(v: DomainModuleVariant) -> str:
            return v.name

        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(DomainModuleVariant),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_domain_variant
            ),
        )
```
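Thanks to this custom core schema, any pydantic model with a `DomainModuleVariant` field accepts the enum *name* as a string and serializes back to that name. A minimal sketch; the `Demo` model is hypothetical:

```python
from pydantic import BaseModel
from simple_shapes_dataset import DomainType

from shimmer_ssd.config import DomainModuleVariant


class Demo(BaseModel):
    domain: DomainModuleVariant


m = Demo(domain="v_latents")  # validated by enum name, not by the tuple value
assert m.domain is DomainModuleVariant.v_latents
assert m.domain.kind is DomainType.v_latents
assert m.model_dump()["domain"] == "v_latents"  # serialized as the key name
```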
```python
class Logging(BaseModel):
    # List of media (images/text) that will be logged.
    # The list is defined in the individual train configs (e.g. train_v.yaml, train_gw.yaml)
    filter_images: Sequence[str] | None = None
    log_train_medias_every_n_epochs: int | None = 10
    log_val_medias_every_n_epochs: int | None = 10
```
```python
class Slurm(BaseModel):
    """
    Slurm config for https://github.com/bdvllrs/auto-sbatch
    """

    script: str
    run_workdir: str
    python_env: str
    command: str

    pre_modules: Sequence[str] = []
    run_modules: Sequence[str] = []
    args: Mapping[str, Any] = {}

    grid_search: Mapping[str, Sequence[Any]] | None = None
    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None
```
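For illustration, a `Slurm` section could be built like this (all values below are hypothetical placeholders, not project defaults):

```python
from shimmer_ssd.config import Slurm

slurm = Slurm(
    script="scripts/train_gw.py",    # script that auto-sbatch will launch
    run_workdir="/scratch/my-run",   # working directory on the cluster
    python_env="shimmer-env",        # python environment to activate
    command="python {script_name}",  # command template
    grid_search={"training.optim.lr": [1e-3, 5e-3]},  # one job per value
)
```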
```python
class Optim(BaseModel):
    """
    Optimizer config
    """

    # learning rate (max learning rate when used with the default OneCycle
    # learning rate scheduler)
    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR
    # for information about the other config values
    lr: float = 5e-3
    # TODO: remove as redundant with lr
    max_lr: float = 5e-3
    start_lr: float = 5e-4
    end_lr: float = 5e-4
    pct_start: float = 0.2
    weight_decay: float = 1e-5
```
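`start_lr` and `end_lr` have no direct `OneCycleLR` arguments; one plausible wiring (an assumption, not taken from the training scripts) translates them into `div_factor` and `final_div_factor`:

```python
import torch

from shimmer_ssd.config import Optim

optim_cfg = Optim()
model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.AdamW(
    model.parameters(), lr=optim_cfg.max_lr, weight_decay=optim_cfg.weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=optim_cfg.max_lr,
    total_steps=200_000,
    pct_start=optim_cfg.pct_start,
    div_factor=optim_cfg.max_lr / optim_cfg.start_lr,        # initial_lr = max_lr / div_factor
    final_div_factor=optim_cfg.start_lr / optim_cfg.end_lr,  # min_lr = initial_lr / final_div_factor
)
```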
```python
class Training(BaseModel):
    """
    Training-related config.
    As these settings depend on what you are training,
    they are defined in the script-specific yaml files (e.g. `train_v.yaml`,
    `train_gw.yaml`).
    """

    batch_size: int = 2056
    num_workers: int = 16
    devices: int = 1  # number of devices (gpus) to use
    accelerator: str = "gpu"

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
    fast_dev_run: bool = False
    # number of training steps
    max_steps: int = 200_000
    enable_progress_bar: bool = True

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
    # you may want to set this to "16-mixed" if your gpu supports mixed precision
    precision: Any = 32
    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision
    # you may want to decrease this to "medium"
    float32_matmul_precision: str = "highest"

    # Optimizer config
    optim: Optim = Optim()
```
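Most of these fields correspond one-to-one to Lightning `Trainer` arguments; a sketch of how they might be wired (the actual training scripts may differ):

```python
import torch
from lightning.pytorch import Trainer

from shimmer_ssd.config import Training

training = Training()
torch.set_float32_matmul_precision(training.float32_matmul_precision)
trainer = Trainer(
    accelerator=training.accelerator,
    devices=training.devices,
    max_steps=training.max_steps,
    precision=training.precision,
    fast_dev_run=training.fast_dev_run,
    enable_progress_bar=training.enable_progress_bar,
)
```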
```python
class ExploreVAE(BaseModel):
    # the VAE checkpoint to use
    checkpoint: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None
```
```python
class ExploreGW(BaseModel):
    checkpoint: str  # the GW checkpoint to use
    domain: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None
```
```python
class Visualization(BaseModel):
    # when exploring a VAE
    explore_vae: ExploreVAE
    # when exploring the GW
    explore_gw: ExploreGW
```
```python
class WanDB(BaseModel):
    enabled: bool = False
    # where to save the logs
    save_dir: str = ""
    # the "entity/project" values of your wandb project
    project: str = ""
    entity: str = ""
    # see https://docs.wandb.ai/ref/python/init/
    reinit: bool = False
```
```python
class Dataset(BaseModel):
    """
    Simple shapes dataset info
    """

    # Path to the dataset, obtainable at https://github.com/ruflab/simple-shapes-dataset
    path: Path
```
```python
class VisualModule(BaseModel):
    """
    Config for the visual domain module
    """

    num_channels: int = 3  # number of image channels used to configure the VAE
    ae_dim: int = 256  # AE hidden dim
    latent_dim: int = 8  # latent dim of the VAE
    beta: float = 0.1  # beta for betaVAE

    # Whether the model is color-blind.
    # It adds a transform on the dataset that averages all color channels.
    # NOTE: this only works with the "v" domain and not "v_latents", where the
    # visual latent representations have already been extracted. In that case, use a
    # different presaved path in `domain_args`.
    color_blind: bool = False
```
```python
class AttributeModule(BaseModel):
    """
    Config for the attribute domain module
    """

    latent_dim: int = 10  # latent dim of the VAE
    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
    beta: float = 0.05  # beta for betaVAE
    coef_categories: float = 1
    coef_attributes: float = 1

    # Whether to remove rotation information from the attributes.
    nullify_rotation: bool = False
```
```python
class TextModule(BaseModel):
    """
    Config for the text domain module
    """

    # which file to load for text
    # The file generated with `shapesd create` is "latent".
    # If you use an older version (e.g. if you downloaded the dataset directly), it
    # should be "bert-base-uncased".
    latent_filename: str = "latent"
    # path to the vocab file
    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
    # path to the merges file
    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
    # max length of text sequences
    seq_length: int = 64
    vocab_size: int = 822

    # VAE configuration
    latent_dim: int = 64
    hidden_dim: int = 256
    beta: float = 0.1
    reconstruction_coef: float = 1.0
    kl_coef: float = 0.1
```
```python
class DomainModules(BaseModel):
    visual: VisualModule = VisualModule()
    attribute: AttributeModule = AttributeModule()
    text: TextModule = TextModule()
```
```python
class EncodersConfig(BaseModel):
    """
    Encoder architecture config
    """

    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32

    # The model will have an extra linear layer before and after the n_layers.
    # Hence the total will be `2 + n_layers`.
    n_layers: int | Mapping[DomainModuleVariant, int] = 3
```
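Both fields accept either a single int or a per-domain mapping keyed by `DomainModuleVariant`. A small sketch:

```python
from shimmer_ssd.config import DomainModuleVariant, EncodersConfig

# One width for every domain encoder (2 + 3 = 5 linear layers in total):
enc = EncodersConfig(hidden_dim=32, n_layers=3)

# Or a different width per domain variant:
enc = EncodersConfig(
    hidden_dim={
        DomainModuleVariant.v_latents: 64,
        DomainModuleVariant.attr: 32,
    },
    n_layers=3,
)
```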
```python
class LoadedDomainConfig(BaseModel):
    """
    Domain params for the active domains of the run
    """

    # path to the pretrained module
    checkpoint_path: Path
    # domain to select
    domain_type: DomainModuleVariant
    # domain-module-specific arguments
    args: Mapping[str, Any] = {}
```
```python
class DomainProportion(BaseModel):
    """
    Deprecated, DomainProportionT will be used instead.
    """

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion applies to
    # e.g. if domains is ["v", "t"], then it gives the proportion of paired v, t data
    domains: Sequence[str]
```
```python
class DomainProportionT(TypedDict):
    """This replaces `DomainProportion` in future config."""

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion applies to
    # e.g. if domains is ["v", "t"], then it gives the proportion of paired v, t data
    domains: Sequence[str]
```
```python
class GlobalWorkspace(BaseModel):
    # latent dim of the GW
    latent_dim: int = 12
    # whether to use the fusion GW
    use_fusion_model: bool = False
    # softmax temperature for the softmax distribution selection used by the fusion model
    selection_temperature: float = 0.2
    # whether to learn the logit scale of the contrastive loss as in the CLIP paper
    learn_logit_scale: bool = False
    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
    # with its associated params
    vsepp_contrastive_loss: bool = False
    vsepp_margin: float = 0.2
    vsepp_measure: Literal["cosine", "order"] = "cosine"
    vsepp_max_violation: bool = True
    # whether to use linear encoders and decoders for the GW
    linear_domains: bool = False
    # whether to use a bias when using linear encoders and decoders
    linear_domains_use_bias: bool = False
    # encoder architecture config
    encoders: EncodersConfig = EncodersConfig()
    # decoder architecture config
    decoders: EncodersConfig = EncodersConfig()
    # coefs of each loss. The total loss is computed using the given values and coefs.
    # You can select any available loss generated by the loss functions.
    loss_coefficients: Mapping[str, float] = {
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.01,
        "fused": 1.0,
    }
    # checkpoint of the GW for downstream tasks, visualization, or migrations
    checkpoint: Path | None = None
    # deprecated, use Config.domain_data_args instead
    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
        default=None, deprecated="Use `config.domain_data_args` instead."
    )
    # deprecated, use Config.domains instead
    domains: Sequence[LoadedDomainConfig] | None = Field(
        default=None, deprecated="Use `config.domains` instead."
    )
    # deprecated, use Config.domain_proportions instead
    domain_proportions: Sequence[DomainProportion] | None = Field(
        default=None, deprecated="Use `config.domain_proportions` instead."
    )
```
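As an illustration, a fusion-model run with rebalanced loss coefficients might be configured like this (the values are illustrative, not recommended settings):

```python
from shimmer_ssd.config import GlobalWorkspace

gw = GlobalWorkspace(
    latent_dim=16,
    use_fusion_model=True,
    selection_temperature=0.5,
    loss_coefficients={
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.05,
        "fused": 1.0,
    },
)
```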
```python
class ShimmerConfigInfo(BaseModel):
    """
    Some shimmer-related config
    """

    # version of shimmer used
    version: str = shimmer_version
    # whether started in debug mode
    debug: bool = False
    # params that were passed through the CLI
    cli: Any = {}
```
````python
class Config(ParsedModel):
    seed: int = 0  # training seed
    ood_seed: int | None = None  # out-of-distribution seed
    default_root_dir: Path = (
        Path("./checkpoints")  # path where to save and load logs and checkpoints
    )
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # proportion of each domain in the dataset relative to the size of the dataset
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # data-related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Convert the format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        used in the yaml file into a Mapping[frozenset[str], float]
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]
````
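Together, the validator and serializer let `domain_proportions` round-trip between the yaml list form and the internal mapping. A minimal sketch, assuming `ParsedModel` validates like a plain pydantic model; the dataset path is a placeholder:

```python
from shimmer_ssd.config import Config

cfg = Config.model_validate(
    {
        "dataset": {"path": "/path/to/simple-shapes-dataset"},  # placeholder
        "domain_proportions": [
            {"domains": ["v"], "proportion": 1.0},
            {"domains": ["v", "t"], "proportion": 0.5},
        ],
    }
)
assert cfg.domain_proportions[frozenset(["v", "t"])] == 0.5
# Serialization restores the list form (domain order within an entry
# is not guaranteed, since frozensets are unordered).
```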
```python
def use_deprecated_vals(config: Config) -> Config:
    # migrate deprecated values to their new locations
    if config.global_workspace.domain_args is not None:
        config.domain_data_args = config.global_workspace.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domains is not None:
        config.domains = config.global_workspace.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domain_proportions is not None:
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in config.global_workspace.domain_proportions
        }

        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config
```
```python
def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    path = Path(path)
    conf_files = []
    if load_files is not None:
        conf_files.extend(load_files)
    if (path / "main.yaml").exists():
        conf_files.append("main.yaml")
    if (path / "local.yaml").exists():
        conf_files.append("local.yaml")

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    config_dict.update(
        {
            "__shimmer__": {
                "version": __version__,
                "debug": debug_mode,
                "cli": cli_config,
            }
        }
    )

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )
    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf
```
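Typical usage, assuming a hypothetical `./config` folder containing a `main.yaml` (and optionally `local.yaml` and `debug.yaml`):

```python
from pathlib import Path

from shimmer_ssd.config import load_config

config = load_config(
    Path("./config"),  # folder holding main.yaml / local.yaml / debug.yaml
    use_cli=False,     # skip CLI-argument overrides for this example
    log_config=True,   # pretty-print the parsed config
)
print(config.training.max_steps)
```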