shimmer_ssd.config

import pprint
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Literal

from cfg_tools import ParsedModel, load_config_files
from cfg_tools.utils import validate_and_fill_missing
from pydantic import (
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    field_serializer,
    field_validator,
)
from pydantic_core import core_schema
from shimmer import __version__
from shimmer.version import __version__ as shimmer_version
from simple_shapes_dataset import DomainType
from typing_extensions import TypedDict

from shimmer_ssd import PROJECT_DIR


class DomainModuleVariant(Enum):
    """
    This is used to select a particular DomainModule.
    Each domain can have different variants of domain modules.

    For example, "attr" will load the default attribute module that uses a VAE,
    whereas "attr_legacy" will load a domain module that passes the attributes
    directly to the GW.
    """

    # Attribute modules
    # -----------------
    # Attribute module using a VAE to encode the attribute vector
    attr = (DomainType.attr, "default")
    # This is the module used in the Devillers et al. paper. There is no VAE and the
    # attributes are used directly as the unimodal latent representations
    attr_legacy = (DomainType.attr, "legacy")
    # Same as "attr" but adds an unpaired attribute (information not available in the
    # other domains).
    attr_unpaired = (DomainType.attr, "unpaired")

    # Visual modules
    # --------------
    # Visual VAE
    v = (DomainType.v, "default")
    # Same as "v", but uses pre-saved latent VAE representations for faster training.
    # This skips the image loading and encoding and only loads latent representations.
    # The downside is that you cannot access the original image, but you can
    # reconstruct it with "decode_images".
    v_latents = (DomainType.v_latents, "default")
    # Same as "v_latents" but adds an unpaired value (random information not available
    # in the other domains).
    v_latents_unpaired = (DomainType.v_latents, "unpaired")

    # Text modules
    # ------------
    # Text domain.
    t = (DomainType.t, "default")
    t_attr = (DomainType.t, "t2attr")

    def __init__(self, kind: DomainType, model_variant: str) -> None:
        """
        The two elements of the tuple are stored in the `kind` and `model_variant`
        properties.
        """
        self.kind = kind
        self.model_variant = model_variant

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Define how this type is validated and serialized by pydantic.
        It can take a str matching one of the enum keys.
        It is serialized as the enum key name.
        """

        def validate_from_str(v: str) -> DomainModuleVariant:
            """
            Use names instead of values to select enum members
            """
            assert v in DomainModuleVariant.__members__, (
                f"Domain type `{v}` is not a member "
                f"of {list(DomainModuleVariant.__members__.keys())}"
            )
            return DomainModuleVariant[v]

        from_str_schema = core_schema.no_info_plain_validator_function(
            validate_from_str
        )

        def serialize_domain_variant(v: DomainModuleVariant) -> str:
            return v.name

        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(DomainModuleVariant),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_domain_variant
            ),
        )


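A minimal sketch of how this custom schema behaves in practice (illustrative only; `_Demo` is a hypothetical model, not part of the module):

```
from pydantic import BaseModel


class _Demo(BaseModel):
    variant: DomainModuleVariant


# Validation accepts the enum member *name* (e.g. as written in a yaml file)...
demo = _Demo.model_validate({"variant": "attr_legacy"})
assert demo.variant is DomainModuleVariant.attr_legacy
assert demo.variant.kind is DomainType.attr
assert demo.variant.model_variant == "legacy"
# ...and serialization emits the name back, not the (DomainType, str) tuple.
assert demo.model_dump() == {"variant": "attr_legacy"}
```
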
class Logging(BaseModel):
    # List of media (images/text) that will be logged.
    # The list is defined in the individual train configs (e.g. train_v.yaml,
    # train_gw.yaml).
    filter_images: Sequence[str] | None = None
    log_train_medias_every_n_epochs: int | None = 10
    log_val_medias_every_n_epochs: int | None = 10


class Slurm(BaseModel):
    """
    Slurm config for https://github.com/bdvllrs/auto-sbatch
    """

    script: str
    run_workdir: str
    python_env: str
    command: str

    pre_modules: Sequence[str] = []
    run_modules: Sequence[str] = []
    args: Mapping[str, Any] = {}

    grid_search: Mapping[str, Sequence[Any]] | None = None
    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None


class Optim(BaseModel):
    """
    Optimizer config
    """

    # learning rate (the max learning rate when used with the default OneCycle
    # learning rate scheduler);
    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR
    # for information about the other config values
    lr: float = 5e-3
    # TODO: remove as redundant with lr
    max_lr: float = 5e-3
    start_lr: float = 5e-4
    end_lr: float = 5e-4
    pct_start: float = 0.2
    weight_decay: float = 1e-5


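A sketch of how these fields could map onto torch's `OneCycleLR` (an assumption for illustration; the actual wiring lives in the training scripts, and `total_steps` below is a placeholder):

```
import torch

optim_cfg = Optim()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=optim_cfg.max_lr, weight_decay=optim_cfg.weight_decay
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=optim_cfg.max_lr,
    total_steps=1000,  # placeholder; e.g. training.max_steps
    pct_start=optim_cfg.pct_start,
    # OneCycleLR expresses the start/end learning rates as divisors:
    # initial_lr = max_lr / div_factor
    div_factor=optim_cfg.max_lr / optim_cfg.start_lr,
    # min_lr = initial_lr / final_div_factor
    final_div_factor=optim_cfg.start_lr / optim_cfg.end_lr,
)
```
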
class Training(BaseModel):
    """
    Training related config.
    As these settings depend on what you are training,
    they are defined in the script-specific yaml files (e.g. `train_v.yaml`,
    `train_gw.yaml`).
    """

    batch_size: int = 2056
    num_workers: int = 16
    devices: int = 1  # number of devices (gpus) to use
    accelerator: str = "gpu"

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
    fast_dev_run: bool = False
    # number of training steps
    max_steps: int = 200_000
    enable_progress_bar: bool = True

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
    # you may want to set this to "16-mixed" if your gpu supports mixed precision
    precision: Any = 32
    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision
    # you may want to decrease this to "medium"
    float32_matmul_precision: str = "highest"

    # Optimizer config
    optim: Optim = Optim()


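A sketch of how these values might be forwarded to a Lightning `Trainer` (an assumption for illustration; the exact call is made by the training scripts):

```
import lightning.pytorch as pl
import torch

t = Training()
torch.set_float32_matmul_precision(t.float32_matmul_precision)
trainer = pl.Trainer(
    devices=t.devices,
    accelerator=t.accelerator,
    fast_dev_run=t.fast_dev_run,
    max_steps=t.max_steps,
    enable_progress_bar=t.enable_progress_bar,
    precision=t.precision,
)
```
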
class ExploreVAE(BaseModel):
    # the VAE checkpoint to use
    checkpoint: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class ExploreGW(BaseModel):
    checkpoint: str  # the GW checkpoint to use
    domain: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class Visualization(BaseModel):
    # when exploring a VAE
    explore_vae: ExploreVAE
    # when exploring the GW
    explore_gw: ExploreGW


class WanDB(BaseModel):
    enabled: bool = False
    # where to save the logs
    save_dir: str = ""
    # the "entity/project" values of your wandb project
    project: str = ""
    entity: str = ""
    # see https://docs.wandb.ai/ref/python/init/
    reinit: bool = False


class Dataset(BaseModel):
    """
    Simple shapes dataset info
    """

    # Path to the dataset, available at https://github.com/ruflab/simple-shapes-dataset
    path: Path


class VisualModule(BaseModel):
    """
    Config for the visual domain module
    """

    num_channels: int = 3  # number of image channels used to configure the VAE
    ae_dim: int = 256  # AE hidden dim
    latent_dim: int = 8  # latent dim of the VAE
    beta: float = 0.1  # beta for betaVAE

    # Whether the model is color blind.
    # This adds a transform on the dataset that averages all color channels.
    # NOTE: this only works when using the "v" domain and not "v_latents", where
    # visual latent representations are already extracted. In that case, use a
    # different pre-saved path in `domain_data_args`.
    color_blind: bool = False


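A minimal sketch of the color-blind transform described above (an assumption for illustration; the actual dataset transform may be implemented differently):

```
import torch


def color_blind_transform(img: torch.Tensor) -> torch.Tensor:
    # Average the color channels of a (C, H, W) image and broadcast the
    # result back to all channels, producing a grayscale image.
    return img.mean(dim=0, keepdim=True).expand_as(img)
```
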
class AttributeModule(BaseModel):
    """
    Config for the attribute domain module
    """

    latent_dim: int = 10  # latent dim of the VAE
    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
    beta: float = 0.05  # beta for betaVAE
    coef_categories: float = 1
    coef_attributes: float = 1

    # Whether to remove rotation information from the attributes.
    nullify_rotation: bool = False


class TextModule(BaseModel):
    """
    Config for the text domain module
    """

    # which file to load for text
    # The file generated with `shapesd create` is "latent".
    # If you use an older version (e.g. if you downloaded the dataset directly),
    # it should be "bert-base-uncased".
    latent_filename: str = "latent"
    # path to the vocab file
    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
    # path to the merges file
    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
    # max length of text sequences
    seq_length: int = 64
    vocab_size: int = 822

    # VAE configuration
    latent_dim: int = 64
    hidden_dim: int = 256
    beta: float = 0.1
    reconstruction_coef: float = 1.0
    kl_coef: float = 0.1


class DomainModules(BaseModel):
    visual: VisualModule = VisualModule()
    attribute: AttributeModule = AttributeModule()
    text: TextModule = TextModule()


class EncodersConfig(BaseModel):
    """
    Encoder architecture config
    """

    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32

    # The model has an extra linear layer before and after the `n_layers` hidden
    # layers, so the total number of layers is `2 + n_layers`.
    n_layers: int | Mapping[DomainModuleVariant, int] = 3


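Both fields accept either a single int or a per-domain mapping keyed by `DomainModuleVariant` (member names work as keys thanks to the enum schema above). An illustrative sketch:

```
enc = EncodersConfig.model_validate(
    {"hidden_dim": {"v_latents": 64, "attr": 32}, "n_layers": 3}
)
assert enc.hidden_dim[DomainModuleVariant.v_latents] == 64
assert enc.n_layers == 3
```
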
class LoadedDomainConfig(BaseModel):
    """
    Domain params for the active domains of the run
    """

    # path to the pretrained module
    checkpoint_path: Path
    # domain to select
    domain_type: DomainModuleVariant
    # domain module specific arguments
    args: Mapping[str, Any] = {}


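For example (illustrative; the checkpoint path is a placeholder, not a value the module defines):

```
domain = LoadedDomainConfig.model_validate(
    {
        "checkpoint_path": "checkpoints/vae_v.ckpt",  # placeholder path
        "domain_type": "v_latents",  # resolved by enum member name
        "args": {},  # module-specific arguments, if any
    }
)
assert domain.domain_type is DomainModuleVariant.v_latents
```
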
class DomainProportion(BaseModel):
    """
    Deprecated, `DomainProportionT` will be used instead.
    """

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated with
    # e.g. if domains is ["v", "t"], then it gives the proportion of paired v, t data
    domains: Sequence[str]


class DomainProportionT(TypedDict):
    """This replaces `DomainProportion` in future config."""

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated with
    # e.g. if domains is ["v", "t"], then it gives the proportion of paired v, t data
    domains: Sequence[str]


class GlobalWorkspace(BaseModel):
    # latent dim of the GW
    latent_dim: int = 12
    # whether to use the fusion GW
    use_fusion_model: bool = False
    # softmax temperature for the softmax-distributed selection used by the fusion
    # model
    selection_temperature: float = 0.2
    # whether to learn the logit scale of the contrastive loss, as in the CLIP paper
    learn_logit_scale: bool = False
    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
    # with the associated params
    vsepp_contrastive_loss: bool = False
    vsepp_margin: float = 0.2
    vsepp_measure: Literal["cosine", "order"] = "cosine"
    vsepp_max_violation: bool = True
    # whether to use linear encoders and decoders for the GW
    linear_domains: bool = False
    # whether to use a bias when using linear encoders and decoders
    linear_domains_use_bias: bool = False
    # encoder architecture config
    encoders: EncodersConfig = EncodersConfig()
    # decoder architecture config
    decoders: EncodersConfig = EncodersConfig()
    # coefficient of each loss. The total loss is computed from the given losses and
    # coefficients; you can select any available loss generated by the loss functions
    loss_coefficients: Mapping[str, float] = {
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.01,
        "fused": 1.0,
    }
    # checkpoint of the GW for downstream or visualization tasks, or migrations
    checkpoint: Path | None = None
    # deprecated, use Config.domain_data_args instead
    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
        default=None, deprecated="Use `config.domain_data_args` instead."
    )
    # deprecated, use Config.domains instead
    domains: Sequence[LoadedDomainConfig] | None = Field(
        default=None, deprecated="Use `config.domains` instead."
    )
    # deprecated, use Config.domain_proportions instead
    domain_proportions: Sequence[DomainProportion] | None = Field(
        default=None, deprecated="Use `config.domain_proportions` instead."
    )


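Overriding the architecture or the losses is a matter of passing only the fields you want to change; the rest keep their defaults. A small illustrative sketch:

```
gw = GlobalWorkspace(
    latent_dim=16,
    use_fusion_model=True,
    loss_coefficients={"translations": 1.0, "contrastives": 0.05},
)
assert gw.encoders.hidden_dim == 32  # unchanged default
```
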
class ShimmerConfigInfo(BaseModel):
    """
    Shimmer-related config info
    """

    # version of shimmer used
    version: str = shimmer_version
    # whether the run was started in debug mode
    debug: bool = False
    # params that were passed through the CLI
    cli: Any = {}


class Config(ParsedModel):
    seed: int = 0  # training seed
    ood_seed: int | None = None  # out-of-distribution seed
    # path where logs and checkpoints are saved and loaded
    default_root_dir: Path = Path("./checkpoints")
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # proportion of each domain group in the dataset, relative to the dataset size
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # data related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Convert the yaml format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        into a Mapping[frozenset[str], float].
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]


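The validator and serializer above make `domain_proportions` round-trip between the yaml list-of-entries format and the internal frozenset-keyed mapping. A sketch (assuming `ParsedModel` validates like a regular pydantic `BaseModel`; the dataset path is a placeholder):

```
cfg = Config.model_validate(
    {
        "dataset": {"path": "/path/to/simple_shapes_dataset"},  # placeholder
        "domain_proportions": [
            {"domains": ["v"], "proportion": 1.0},
            {"domains": ["v", "t"], "proportion": 0.5},
        ],
    }
)
assert cfg.domain_proportions[frozenset({"v", "t"})] == 0.5
# Serialization converts the mapping back to the list-of-entries format
# (the order of names inside "domains" is unspecified, since keys are frozensets).
print(cfg.model_dump()["domain_proportions"])
```
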
def use_deprecated_vals(config: Config) -> Config:
    # use deprecated values
    if config.global_workspace.domain_args is not None:
        config.domain_data_args = config.global_workspace.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domains is not None:
        config.domains = config.global_workspace.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domain_proportions is not None:
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in config.global_workspace.domain_proportions
        }

        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config


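A sketch of the migration this performs on a legacy config (again assuming `ParsedModel` validates like a regular pydantic `BaseModel`; the dataset path is a placeholder):

```
legacy = Config.model_validate(
    {
        "dataset": {"path": "/path/to/simple_shapes_dataset"},  # placeholder
        "global_workspace": {
            "domain_proportions": [{"domains": ["v"], "proportion": 1.0}]
        },
    }
)
migrated = use_deprecated_vals(legacy)  # emits a DeprecationWarning
assert migrated.domain_proportions == {frozenset({"v"}): 1.0}
```
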
def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    path = Path(path)
    conf_files = []
    if load_files is not None:
        conf_files.extend(load_files)
    if (path / "main.yaml").exists():
        conf_files.append("main.yaml")
    if (path / "local.yaml").exists():
        conf_files.append("local.yaml")

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    config_dict.update(
        {
            "__shimmer__": {
                "version": __version__,
                "debug": debug_mode,
                "cli": cli_config,
            }
        }
    )

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )
    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf
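
Typical usage (illustrative; the config directory and the extra file name are placeholders): `load_config` collects `main.yaml`, `local.yaml`, any explicitly listed files, and `debug.yaml` in debug mode, merges them with CLI overrides, and validates the result into a `Config`.

```
config = load_config(
    "./config",  # placeholder directory containing main.yaml / local.yaml
    load_files=["train_gw.yaml"],  # placeholder script-specific config
    debug_mode=False,
    log_config=True,
)
print(config.training.max_steps)
```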
class DomainModuleVariant(enum.Enum):
 27class DomainModuleVariant(Enum):
 28    """
 29    This is used to select a particular DomainModule.
 30    Each domain can have different variants of domain modules.
 31
 32    For example "attr" will load the default attribute module that uses a VAE,
 33    whereas "attr_legacy" will load a domain module that directly passes attributes
 34    to the GW.
 35    """
 36
 37    # Attribute modules
 38    # -----------------
 39    # Attribute module using a VAE to encode the attribute vector
 40    attr = (DomainType.attr, "default")
 41    # This is the module used in Devillers et al. paper. There is no VAE and the
 42    # attributes are used directly as the unimodal latent representations
 43    attr_legacy = (DomainType.attr, "legacy")
 44    # Same as "attr" but adds an unpaired attributes (information not available in the
 45    # other domains).
 46    attr_unpaired = (DomainType.attr, "unpaired")
 47
 48    # Visual modules
 49    # --------------
 50    # Visual VAE
 51    v = (DomainType.v, "default")
 52    # Same as "v", but uses pre-saved latent VAE representation for faster training.
 53    # This skips the image loading and encoding and only loads latent representation.
 54    # The downside is that you cannot access the default image, but you can reconstruct
 55    # it with "decode_images".
 56    v_latents = (DomainType.v_latents, "default")
 57    # Same as "v_latents" but adds an unpaired value (radom information not available
 58    # in the other domains).
 59    v_latents_unpaired = (DomainType.v_latents, "unpaired")
 60
 61    # Text modules
 62    # ------------
 63    # Text domain.
 64    t = (DomainType.t, "default")
 65    t_attr = (DomainType.t, "t2attr")
 66
 67    def __init__(self, kind: DomainType, model_variant: str) -> None:
 68        """
 69        The two elements of the tuple are put in the the `kind` and the `model_variant`
 70        properties.
 71        """
 72        self.kind = kind
 73        self.model_variant = model_variant
 74
 75    @classmethod
 76    def __get_pydantic_core_schema__(
 77        cls, _source_type: Any, _handler: GetCoreSchemaHandler
 78    ) -> core_schema.CoreSchema:
 79        """
 80        Define how this type is validated and serialized by pydantic.
 81        It can take a str related to the enum keys.
 82        It should be serialized as the the enum key name.
 83        """
 84
 85        def validate_from_str(v: str) -> DomainModuleVariant:
 86            """
 87            Use names instead of values to select enums
 88            """
 89            assert v in DomainModuleVariant.__members__, (
 90                f"Domain type `{v}` is not a member "
 91                f"of {list(DomainModuleVariant.__members__.keys())}"
 92            )
 93            return DomainModuleVariant[v]
 94
 95        from_str_schema = core_schema.no_info_plain_validator_function(
 96            validate_from_str
 97        )
 98
 99        def serialize_domain_variant(v: DomainModuleVariant) -> str:
100            return v.name
101
102        return core_schema.json_or_python_schema(
103            json_schema=from_str_schema,
104            python_schema=core_schema.union_schema(
105                [
106                    core_schema.is_instance_schema(DomainModuleVariant),
107                    from_str_schema,
108                ]
109            ),
110            serialization=core_schema.plain_serializer_function_ser_schema(
111                serialize_domain_variant
112            ),
113        )

This is used to select a particular DomainModule. Each domain can have different variants of domain modules.

For example "attr" will load the default attribute module that uses a VAE, whereas "attr_legacy" will load a domain module that directly passes attributes to the GW.

DomainModuleVariant(kind: simple_shapes_dataset.domain.DomainType, model_variant: str)
67    def __init__(self, kind: DomainType, model_variant: str) -> None:
68        """
69        The two elements of the tuple are put in the the `kind` and the `model_variant`
70        properties.
71        """
72        self.kind = kind
73        self.model_variant = model_variant

The two elements of the tuple are put in the the kind and the model_variant properties.

attr = <DomainModuleVariant.attr: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'default')>
attr_legacy = <DomainModuleVariant.attr_legacy: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'legacy')>
attr_unpaired = <DomainModuleVariant.attr_unpaired: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'unpaired')>
v = <DomainModuleVariant.v: (<DomainType.v: DomainDesc(base='v', kind='v')>, 'default')>
v_latents = <DomainModuleVariant.v_latents: (<DomainType.v_latents: DomainDesc(base='v', kind='v_latents')>, 'default')>
v_latents_unpaired = <DomainModuleVariant.v_latents_unpaired: (<DomainType.v_latents: DomainDesc(base='v', kind='v_latents')>, 'unpaired')>
t = <DomainModuleVariant.t: (<DomainType.t: DomainDesc(base='t', kind='t')>, 'default')>
t_attr = <DomainModuleVariant.t_attr: (<DomainType.t: DomainDesc(base='t', kind='t')>, 't2attr')>
kind
model_variant
class Logging(pydantic.main.BaseModel):
116class Logging(BaseModel):
117    # List of medias (images/text) that will be logged.
118    # The list is defined in individual train config (e.g. train_v.yaml, train_gw.yaml)
119    filter_images: Sequence[str] | None = None
120    log_train_medias_every_n_epochs: int | None = 10
121    log_val_medias_every_n_epochs: int | None = 10

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
filter_images: Sequence[str] | None
log_train_medias_every_n_epochs: int | None
log_val_medias_every_n_epochs: int | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Slurm(pydantic.main.BaseModel):
124class Slurm(BaseModel):
125    """
126    Slurm config for https://github.com/bdvllrs/auto-sbatch
127    """
128
129    script: str
130    run_workdir: str
131    python_env: str
132    command: str
133
134    pre_modules: Sequence[str] = []
135    run_modules: Sequence[str] = []
136    args: Mapping[str, Any] = {}
137
138    grid_search: Mapping[str, Sequence[Any]] | None = None
139    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None
script: str
run_workdir: str
python_env: str
command: str
pre_modules: Sequence[str]
run_modules: Sequence[str]
args: Mapping[str, typing.Any]
grid_search_exclude: Sequence[Mapping[str, typing.Any]] | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Optim(pydantic.main.BaseModel):
142class Optim(BaseModel):
143    """
144    Optimizer config
145    """
146
147    # learning rate (max learning rate when using with the default OneCycle
148    # learning rate scheduler)
149    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)
150    # for information about the other config values
151    lr: float = 5e-3
152    # TODO: remove as redundant with lr
153    max_lr: float = 5e-3
154    start_lr: float = 5e-4
155    end_lr: float = 5e-4
156    pct_start: float = 0.2
157    weight_decay: float = 1e-5

Optimizer config

lr: float
max_lr: float
start_lr: float
end_lr: float
pct_start: float
weight_decay: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Training(pydantic.main.BaseModel):
160class Training(BaseModel):
161    """
162    Training related config.
163    As these config depend on what you are training,
164    they are defined in the script related yaml files (e.g. `train_v.yaml`,
165    `train_gw.yaml`).
166    """
167
168    batch_size: int = 2056
169    num_workers: int = 16
170    devices: int = 1  # num of devices (gpus) to use
171    accelerator: str = "gpu"
172
173    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
174    fast_dev_run: bool = False
175    # number of training steps
176    max_steps: int = 200_000
177    enable_progress_bar: bool = True
178
179    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
180    # you may want to set to "16-mixed" if your gpu allows mixed precision
181    precision: Any = 32
182    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision
183    # you may want to decrease to "medium"
184    float32_matmul_precision: str = "highest"
185
186    # Optimizer config
187    optim: Optim = Optim()

Training related config. As these config depend on what you are training, they are defined in the script related yaml files (e.g. train_v.yaml, train_gw.yaml).

batch_size: int
num_workers: int
devices: int
accelerator: str
fast_dev_run: bool
max_steps: int
enable_progress_bar: bool
precision: Any
float32_matmul_precision: str
optim: Optim
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ExploreVAE(pydantic.main.BaseModel):
190class ExploreVAE(BaseModel):
191    # the VAE checkpoint to use
192    checkpoint: str
193    num_samples: int = 9
194    range_start: int = -3
195    range_end: int = 3
196    wandb_name: str | None = None

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
checkpoint: str
num_samples: int
range_start: int
range_end: int
wandb_name: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ExploreGW(pydantic.main.BaseModel):
199class ExploreGW(BaseModel):
200    checkpoint: str  # The GW checkpoint to use
201    domain: str
202    num_samples: int = 9
203    range_start: int = -3
204    range_end: int = 3
205    wandb_name: str | None = None

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
checkpoint: str
domain: str
num_samples: int
range_start: int
range_end: int
wandb_name: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Visualization(pydantic.main.BaseModel):
208class Visualization(BaseModel):
209    # when exploring a vae
210    explore_vae: ExploreVAE
211    # when exploring the GW
212    explore_gw: ExploreGW

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
explore_vae: ExploreVAE
explore_gw: ExploreGW
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WanDB(pydantic.main.BaseModel):
215class WanDB(BaseModel):
216    enabled: bool = False
217    # where to save the logs
218    save_dir: str = ""
219    # the "entity/project" values of your wandb project
220    project: str = ""
221    entity: str = ""
222    # see https://docs.wandb.ai/ref/python/init/
223    reinit: bool = False

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
enabled: bool
save_dir: str
project: str
entity: str
reinit: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Dataset(pydantic.main.BaseModel):
226class Dataset(BaseModel):
227    """
228    Simple shapes dataset infos
229    """
230
231    # Path to the dataset obtainable on https://github.com/ruflab/simple-shapes-dataset
232    path: Path

Simple shapes dataset infos

path: pathlib.Path
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class VisualModule(pydantic.main.BaseModel):
235class VisualModule(BaseModel):
236    """
237    Config for the visual domain module
238    """
239
240    num_channels: int = 3  # num channels of image to configure the VAE
241    ae_dim: int = 256  # AE hidden dim
242    latent_dim: int = 8  # latent dim of the VAE
243    beta: float = 0.1  # beta for betaVAE
244
245    # Whether the model is color blind.
246    # It adds a transform on the dataset that averages all color channels.
247    # NOTE: this only works when using the "v" domain and not "v_latents" where
248    # visual latent representations are already extracted. In that case, use a different
249    # presaved-path in `domain_args`.
250    color_blind: bool = False

Config for the visual domain module

num_channels: int
ae_dim: int
latent_dim: int
beta: float
color_blind: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class AttributeModule(pydantic.main.BaseModel):
253class AttributeModule(BaseModel):
254    """
255    Config for the attribute domain module
256    """
257
258    latent_dim: int = 10  # latent dim of the VAE
259    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
260    beta: float = 0.05  # for betaVAE
261    coef_categories: float = 1
262    coef_attributes: float = 1
263
264    # Whether to remove rotation information from the attribute.
265    nullify_rotation: bool = False

Config for the attribute domain module

latent_dim: int
hidden_dim: int
beta: float
coef_categories: float
coef_attributes: float
nullify_rotation: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TextModule(pydantic.main.BaseModel):
268class TextModule(BaseModel):
269    """
270    Config for the text domain module
271    """
272
273    # which file to load for text
274    # The file generated with `shapesd create` is "latent"
275    # If you use an older version (like if you downloaded directly the dataset) it
276    # should be "bert-base-uncased"
277    latent_filename: str = "latent"
278    # path to the vocab file
279    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
280    # path to the merge path
281    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
282    # max sequence length of text sequence
283    seq_length: int = 64
284    vocab_size: int = 822
285
286    # VAE configuration
287    latent_dim: int = 64
288    hidden_dim: int = 256
289    beta: float = 0.1
290    reconstruction_coef: float = 1.0
291    kl_coef: float = 0.1

Config for the text domain module

latent_filename: str
vocab_path: str
merges_path: str
seq_length: int
vocab_size: int
latent_dim: int
hidden_dim: int
beta: float
reconstruction_coef: float
kl_coef: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainModules(pydantic.main.BaseModel):
294class DomainModules(BaseModel):
295    visual: VisualModule = VisualModule()
296    attribute: AttributeModule = AttributeModule()
297    text: TextModule = TextModule()

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
visual: VisualModule
attribute: AttributeModule
text: TextModule
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EncodersConfig(pydantic.main.BaseModel):
300class EncodersConfig(BaseModel):
301    """
302    Encoder architecture config
303    """
304
305    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32
306
307    # The model will have an extra linear before and after the n_layers
308    # Hence the total will be `2 + n_layers`
309    n_layers: int | Mapping[DomainModuleVariant, int] = 3

Encoder architecture config

hidden_dim: int | Mapping[DomainModuleVariant, int]
n_layers: int | Mapping[DomainModuleVariant, int]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class LoadedDomainConfig(pydantic.main.BaseModel):
312class LoadedDomainConfig(BaseModel):
313    """
314    Domain params for the active domains of the run
315    """
316
317    # path to the pretrained module
318    checkpoint_path: Path
319    # domain to select
320    domain_type: DomainModuleVariant
321    # domain module specific arguments
322    args: Mapping[str, Any] = {}

Domain params for the active domains of the run

checkpoint_path: pathlib.Path
domain_type: DomainModuleVariant
args: Mapping[str, typing.Any]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainProportion(pydantic.main.BaseModel):
325class DomainProportion(BaseModel):
326    """
327    Deprecated, DomainProportionT will be used instead.
328    """
329
330    # proportion for some domains
331    # should be in [0, 1]
332    proportion: float
333    # list of domains the proportion is associated to
334    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
335    domains: Sequence[str]

Deprecated, DomainProportionT will be used instead.

proportion: float
domains: Sequence[str]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainProportionT(typing_extensions.TypedDict):
338class DomainProportionT(TypedDict):
339    """This replaces `DomainProportion` in future config."""
340
341    # proportion for some domains
342    # should be in [0, 1]
343    proportion: float
344    # list of domains the proportion is associated to
345    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
346    domains: Sequence[str]

This replaces DomainProportion in future config.

proportion: float
domains: Sequence[str]
class GlobalWorkspace(pydantic.main.BaseModel):
349class GlobalWorkspace(BaseModel):
350    # latent dim of the GW
351    latent_dim: int = 12
352    # whether to use the fusion GW
353    use_fusion_model: bool = False
354    # softmax temp for the Softmax distruted selection used by the fusion model
355    selection_temperature: float = 0.2
356    # whether to learn the logit scale of the contrastive loss like in the clip paper
357    learn_logit_scale: bool = False
358    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
359    # with associated params
360    vsepp_contrastive_loss: bool = False
361    vsepp_margin: float = 0.2
362    vsepp_measure: Literal["cosine", "order"] = "cosine"
363    vsepp_max_violation: bool = True
364    # whether to use linear encoders and decoders for the GW
365    linear_domains: bool = False
366    # whether to use bias when using linear encoders and decoders
367    linear_domains_use_bias: bool = False
368    # encoder architecture config
369    encoders: EncodersConfig = EncodersConfig()
370    # decoder architecture config
371    decoders: EncodersConfig = EncodersConfig()
372    # coefs of each loss. The total loss is computed using the given values and coefs
373    # you can select any available loss generated by the loss functions
374    loss_coefficients: Mapping[str, float] = {
375        "cycles": 1.0,
376        "demi_cycles": 1.0,
377        "translations": 1.0,
378        "contrastives": 0.01,
379        "fused": 1.0,
380    }
381    # checkpoint of the GW for downstream, visualization tasks, or migrations
382    checkpoint: Path | None = None
383    # deprecated, use Config.domain_data_args instead
384    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
385        default=None, deprecated="Use `config.domain_data_args` instead."
386    )
387    # deprecated, use Config.domains instead
388    domains: Sequence[LoadedDomainConfig] | None = Field(
389        default=None, deprecated="Use `config.domains` instead."
390    )
391    # deprecated, use Config.domain_proportions instead
392    domain_proportions: Sequence[DomainProportion] | None = Field(
393        default=None, deprecated="Use `config.domain_proportions` instead."
394    )

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
latent_dim: int
use_fusion_model: bool
selection_temperature: float
learn_logit_scale: bool
vsepp_contrastive_loss: bool
vsepp_margin: float
vsepp_measure: Literal['cosine', 'order']
vsepp_max_violation: bool
linear_domains: bool
linear_domains_use_bias: bool
encoders: EncodersConfig
decoders: EncodersConfig
loss_coefficients: Mapping[str, float]
checkpoint: pathlib.Path | None
domain_args: Mapping[str, Mapping[str, typing.Any]] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
domains: Sequence[LoadedDomainConfig] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
domain_proportions: Sequence[DomainProportion] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ShimmerConfigInfo(pydantic.main.BaseModel):
397class ShimmerConfigInfo(BaseModel):
398    """
399    Some shimmer related config
400    """
401
402    # version of shimmer used
403    version: str = shimmer_version
404    # whether started in debug mode
405    debug: bool = False
406    # params that were passed through CLI
407    cli: Any = {}

Some shimmer related config

version: str
debug: bool
cli: Any
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Config(cfg_tools.data_parser.ParsedModel):
410class Config(ParsedModel):
411    seed: int = 0  # training seed
412    ood_seed: int | None = None  # Out of distribution seed
413    default_root_dir: Path = (
414        Path("./checkpoints")  # Path where to save and load logs and checkpoints
415    )
416    # Dataset information
417    dataset: Dataset
418    # Training config
419    training: Training = Training()
420    # Wandb configuration
421    wandb: WanDB = WanDB()
422    # Logging configuration
423    logging: Logging = Logging()
424    # Add a title to your wandb run
425    title: str | None = Field(None, alias="t")
426    # Add a description to your run
427    desc: str | None = Field(None, alias="d")
428    # proportion of each domain in the dataset relative to the size of the dataset
429    domain_proportions: Mapping[frozenset[str], float] = {}
430    # Config of the different domain modules
431    domain_modules: DomainModules = DomainModules()
432    # Domain params for the active domains of the run
433    domains: Sequence[LoadedDomainConfig] = []
434    # data related args used by the dataloader
435    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
436    # Global workspace configuration
437    global_workspace: GlobalWorkspace = GlobalWorkspace()
438    # Config used during visualization
439    visualization: Visualization | None = None
441    # Slurm config when starting on a cluster
441    slurm: Slurm | None = None
442    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()
443
444    @field_validator("domain_proportions", mode="before")
445    @classmethod
446    def domain_proportion_validator(
447        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
448    ) -> Mapping[frozenset[str], float]:
449        """
450        Convert the format:
451        ```
452        - domains: ["v"]
453          proportion: 1.0
454        ```
455        in the YAML file into a Mapping[frozenset[str], float]
456        """
457        if isinstance(value, Mapping):
458            return value
459        else:
460            return {frozenset(item["domains"]): item["proportion"] for item in value}
461
462    @field_serializer("domain_proportions")
463    def serialize_domain_proportions(
464        self, domain_proportions: Mapping[frozenset[str], float], _info
465    ) -> list[DomainProportionT]:
466        return [
467            {"domains": list(domains), "proportion": proportion}
468            for domains, proportion in domain_proportions.items()
469        ]
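
For illustration, a minimal sketch of how a YAML fragment maps onto these fields. PyYAML (`yaml`) is used only for this sketch and is not a dependency stated by the module; the required `dataset` section is elided because its keys belong to the `Dataset` model:

```python
import yaml  # PyYAML, an assumption of this sketch

fragment = yaml.safe_load(
    """
seed: 42
t: my-run        # `t` is the declared alias of `title`
d: a short note  # `d` is the declared alias of `desc`
domain_proportions:
  - domains: ["v"]
    proportion: 1.0
  - domains: ["v", "attr"]
    proportion: 0.5
"""
)

# `domain_proportion_validator` (mode="before") normalizes the list form
# into the runtime mapping:
proportions = {
    frozenset(item["domains"]): item["proportion"]
    for item in fragment["domain_proportions"]
}
assert proportions[frozenset({"v", "attr"})] == 0.5
```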

seed: int
ood_seed: int | None
default_root_dir: pathlib.Path
dataset: Dataset
training: Training
wandb: WanDB
logging: Logging
title: str | None
desc: str | None
domain_proportions: Mapping[frozenset[str], float]
domain_modules: DomainModules
domains: Sequence[LoadedDomainConfig]
domain_data_args: Mapping[str, Mapping[str, typing.Any]]
global_workspace: GlobalWorkspace
visualization: Visualization | None
slurm: Slurm | None
@field_validator('domain_proportions', mode='before')
@classmethod
def domain_proportion_validator(cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]) -> Mapping[frozenset[str], float]:
444    @field_validator("domain_proportions", mode="before")
445    @classmethod
446    def domain_proportion_validator(
447        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
448    ) -> Mapping[frozenset[str], float]:
449        """
450        Convert the format:
451        ```
452        - domains: ["v"]
453          proportion: 1.0
454        ```
455        in the YAML file into a Mapping[frozenset[str], float]
456        """
457        if isinstance(value, Mapping):
458            return value
459        else:
460            return {frozenset(item["domains"]): item["proportion"] for item in value}

Convert the format:

- domains: ["v"]
  proportion: 1.0

in the YAML file into a Mapping[frozenset[str], float].

@field_serializer('domain_proportions')
def serialize_domain_proportions(self, domain_proportions: Mapping[frozenset[str], float], _info) -> list[DomainProportionT]:
462    @field_serializer("domain_proportions")
463    def serialize_domain_proportions(
464        self, domain_proportions: Mapping[frozenset[str], float], _info
465    ) -> list[DomainProportionT]:
466        return [
467            {"domains": list(domains), "proportion": proportion}
468            for domains, proportion in domain_proportions.items()
469        ]
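
The serializer is the inverse of the validator above: it turns the runtime mapping back into the list-of-dicts form so the config can be dumped to YAML again. A round-trip sketch in plain Python, without the pydantic machinery:

```python
proportions = {frozenset({"v"}): 1.0, frozenset({"v", "attr"}): 0.5}

# What `serialize_domain_proportions` produces:
serialized = [
    {"domains": list(domains), "proportion": proportion}
    for domains, proportion in proportions.items()
]

# What `domain_proportion_validator` recovers from it:
assert {
    frozenset(item["domains"]): item["proportion"] for item in serialized
} == proportions
```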

def use_deprecated_vals(config: Config) -> Config:
472def use_deprecated_vals(config: Config) -> Config:
473    # migrate values set via deprecated fields to their new locations
474    if config.global_workspace.domain_args is not None:
475        config.domain_data_args = config.global_workspace.domain_args
476        warnings.warn(
477            "Deprecated `config.global_workspace.domain_args`, "
478            "use `config.domain_data_args` instead",
479            DeprecationWarning,
480            stacklevel=2,
481        )
482    if config.global_workspace.domains is not None:
483        config.domains = config.global_workspace.domains
484        warnings.warn(
485            "Deprecated `config.global_workspace.domains`, "
486            "use `config.domains` instead",
487            DeprecationWarning,
488            stacklevel=2,
489        )
490    if config.global_workspace.domain_proportions is not None:
491        config.domain_proportions = {
492            frozenset(item.domains): item.proportion
493            for item in config.global_workspace.domain_proportions
494        }
495
496        warnings.warn(
497            "Deprecated `config.global_workspace.domain_proportions`, "
498            "use `config.domain_proportions` instead",
499            DeprecationWarning,
500            stacklevel=2,
501        )
502    return config
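
A hedged usage sketch: `use_deprecated_vals` copies any value still set under the deprecated `global_workspace.*` locations to its new top-level field, emitting a `DeprecationWarning` for each move (`config` is assumed to have been loaded elsewhere):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = use_deprecated_vals(config)  # no-op when nothing deprecated is set

for w in caught:
    print(w.message)
# e.g. Deprecated `config.global_workspace.domain_args`,
#      use `config.domain_data_args` instead
```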
def load_config(path: str | pathlib.Path, load_files: list[str] | None = None, use_cli: bool = True, debug_mode: bool = False, argv: list[str] | None = None, log_config: bool = False) -> Config:
505def load_config(
506    path: str | Path,
507    load_files: list[str] | None = None,
508    use_cli: bool = True,
509    debug_mode: bool = False,
510    argv: list[str] | None = None,
511    log_config: bool = False,
512) -> Config:
513    path = Path(path)
514    conf_files = []
515    if load_files is not None:
516        conf_files.extend(load_files)
517    if (path / "main.yaml").exists():
518        conf_files.append("main.yaml")
519    if (path / "local.yaml").exists():
520        conf_files.append("local.yaml")
521
522    if debug_mode and (path / "debug.yaml").exists():
523        conf_files.append("debug.yaml")
524
525    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)
526
527    config_dict.update(
528        {
529            "__shimmer__": {
530                "version": __version__,
531                "debug": debug_mode,
532                "cli": cli_config,
533            }
534        }
535    )
536
537    conf = use_deprecated_vals(
538        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
539    )
540    if log_config:
541        print("Loaded config:")
542        pprint.pp(dict(conf))
543    return conf
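
A usage sketch, assuming a `./config` directory containing a `main.yaml` (with `local.yaml` and, in debug mode, `debug.yaml` layered on top when present):

```python
from shimmer_ssd.config import load_config

config = load_config(
    "./config",
    debug_mode=False,  # when True, also loads debug.yaml if it exists
    use_cli=True,      # let cfg_tools pick up overrides from argv
    log_config=True,   # pretty-print the validated config after loading
)
print(config.seed, config.default_root_dir)
```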