shimmer_ssd.config

import pprint
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Literal, Self

from cfg_tools import ParsedModel, load_config_files
from cfg_tools.utils import validate_and_fill_missing
from pydantic import (
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    field_serializer,
    field_validator,
    model_validator,
)
from pydantic_core import core_schema
from shimmer import __version__
from shimmer.version import __version__ as shimmer_version
from simple_shapes_dataset import DomainType
from typing_extensions import TypedDict

from shimmer_ssd import PROJECT_DIR


class DomainModuleVariant(Enum):
    """
    This is used to select a particular DomainModule.
    Each domain can have different variants of domain modules.

    For example, "attr" will load the default attribute module that uses a VAE,
    whereas "attr_legacy" will load a domain module that directly passes attributes
    to the GW.
    """

    # Attribute modules
    # -----------------
    # Attribute module using a VAE to encode the attribute vector
    attr = (DomainType.attr, "default")
    # This is the module used in the Devillers et al. paper. There is no VAE and the
    # attributes are used directly as the unimodal latent representations
    attr_legacy = (DomainType.attr, "legacy")
    # Same as "attr" but adds an unpaired attribute (information not available in the
    # other domains).
    attr_unpaired = (DomainType.attr, "unpaired")

    # Visual modules
    # --------------
    # Visual VAE
    v = (DomainType.v, "default")
    # Same as "v", but uses pre-saved latent VAE representations for faster training.
    # This skips the image loading and encoding and only loads latent representations.
    # The downside is that you cannot access the original image, but you can
    # reconstruct it with "decode_images".
    v_latents = (DomainType.v_latents, "default")
    # Same as "v_latents" but adds an unpaired value (random information not available
    # in the other domains).
    v_latents_unpaired = (DomainType.v_latents, "unpaired")

    # Text modules
    # ------------
    # Text domain.
    t = (DomainType.t, "default")
    t_attr = (DomainType.t, "t2attr")

    def __init__(self, kind: DomainType, model_variant: str) -> None:
        """
        The two elements of the tuple are stored in the `kind` and `model_variant`
        properties.
        """
        self.kind = kind
        self.model_variant = model_variant

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Define how this type is validated and serialized by pydantic.
        It accepts a str matching one of the enum member names.
        It is serialized as the enum member name.
        """

        def validate_from_str(v: str) -> DomainModuleVariant:
            """
            Use names instead of values to select enum members
            """
            assert v in DomainModuleVariant.__members__, (
                f"Domain type `{v}` is not a member "
                f"of {list(DomainModuleVariant.__members__.keys())}"
            )
            return DomainModuleVariant[v]

        from_str_schema = core_schema.no_info_plain_validator_function(
            validate_from_str
        )

        def serialize_domain_variant(v: DomainModuleVariant) -> str:
            return v.name

        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(DomainModuleVariant),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_domain_variant
            ),
        )


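A minimal usage sketch (not part of the module) showing how pydantic resolves a
DomainModuleVariant from a yaml string via the schema above, and that it
serializes back to the member name. The `Demo` model is hypothetical:

    from pydantic import BaseModel

    from shimmer_ssd.config import DomainModuleVariant

    class Demo(BaseModel):  # hypothetical model, for illustration only
        domain_type: DomainModuleVariant

    demo = Demo.model_validate({"domain_type": "v_latents"})
    assert demo.domain_type is DomainModuleVariant.v_latents
    assert demo.domain_type.model_variant == "default"
    # serialization uses the member name, not the (DomainType, str) tuple value
    assert demo.model_dump()["domain_type"] == "v_latents"
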
class Logging(BaseModel):
    # List of media (images/text) that will be logged.
    # The list is defined in the individual train configs (e.g. train_v.yaml,
    # train_gw.yaml).
    filter_images: Sequence[str] | None = None
    log_train_medias_every_n_epochs: int | None = 10
    log_val_medias_every_n_epochs: int | None = 10


class Slurm(BaseModel):
    """
    Slurm config for https://github.com/bdvllrs/auto-sbatch
    """

    script: str
    run_workdir: str
    python_env: str
    command: str

    pre_modules: Sequence[str] = []
    run_modules: Sequence[str] = []
    args: Mapping[str, Any] = {}

    grid_search: Mapping[str, Sequence[Any]] | None = None
    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None


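As a sketch, a Slurm config with a grid search; every value here is a
placeholder, and the exact field semantics (e.g. dotted parameter paths in
`grid_search`) are defined by auto-sbatch, not by this module:

    slurm = Slurm.model_validate(
        {
            "script": "scripts/train_gw.py",           # placeholder
            "run_workdir": "~/shimmer-ssd",            # placeholder
            "python_env": "~/envs/shimmer",            # placeholder
            "command": "python scripts/train_gw.py",   # placeholder
            # assumed dotted-path convention, see auto-sbatch docs
            "grid_search": {"training.optim.lr": [1e-3, 5e-3]},
        }
    )
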
class Optim(BaseModel):
    """
    Optimizer config
    """

    # learning rate (max learning rate when used with the default OneCycle
    # learning rate scheduler)
    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
    # for information about the other config values
    lr: float = 5e-3
    # TODO: remove as redundant with lr
    max_lr: float = 5e-3
    start_lr: float = 5e-4
    end_lr: float = 5e-4
    pct_start: float = 0.2
    weight_decay: float = 1e-5


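A hedged sketch of how these fields can map onto torch's OneCycleLR; the actual
training scripts may wire this differently, and the div_factor derivations
below are an assumption based on the field names:

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    optim_cfg = Optim()  # defaults from above
    model = torch.nn.Linear(4, 4)  # stand-in model
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=optim_cfg.max_lr, weight_decay=optim_cfg.weight_decay
    )
    scheduler = OneCycleLR(
        optimizer,
        max_lr=optim_cfg.max_lr,
        total_steps=200_000,  # e.g. Training.max_steps
        pct_start=optim_cfg.pct_start,
        # initial lr = max_lr / div_factor = start_lr
        div_factor=optim_cfg.max_lr / optim_cfg.start_lr,
        # final lr = initial lr / final_div_factor = end_lr
        final_div_factor=optim_cfg.start_lr / optim_cfg.end_lr,
    )
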
class Training(BaseModel):
    """
    Training-related config.
    As these values depend on what you are training,
    they are defined in the script-specific yaml files (e.g. `train_v.yaml`,
    `train_gw.yaml`).
    """

    batch_size: int = 2056
    num_workers: int = 16
    devices: int = 1  # number of devices (GPUs) to use
    accelerator: str = "gpu"

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
    fast_dev_run: bool = False
    # number of training steps
    max_steps: int = 200_000
    enable_progress_bar: bool = True

    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
    # you may want to set this to "16-mixed" if your GPU allows mixed precision
    precision: Any = 32
    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    # you may want to decrease this to "medium"
    float32_matmul_precision: str = "highest"

    # Optimizer config
    optim: Optim = Optim()


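For example, enabling the mixed-precision settings suggested in the comments
above (a sketch, not from the source; how the values reach the Lightning
Trainer is up to the training script):

    import torch

    train_cfg = Training.model_validate(
        {"precision": "16-mixed", "float32_matmul_precision": "medium"}
    )
    torch.set_float32_matmul_precision(train_cfg.float32_matmul_precision)
    # `precision`, `devices`, etc. would then be forwarded to the Trainer
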
class ExploreVAE(BaseModel):
    # the VAE checkpoint to use
    checkpoint: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class ExploreGW(BaseModel):
    checkpoint: str  # the GW checkpoint to use
    domain: str
    num_samples: int = 9
    range_start: int = -3
    range_end: int = 3
    wandb_name: str | None = None


class Visualization(BaseModel):
    # when exploring a VAE
    explore_vae: ExploreVAE
    # when exploring the GW
    explore_gw: ExploreGW


class WanDB(BaseModel):
    enabled: bool = False
    # where to save the logs
    save_dir: str = ""
    # the "entity/project" values of your wandb project
    project: str = ""
    entity: str = ""
    # see https://docs.wandb.ai/ref/python/init/
    reinit: bool = False


class Dataset(BaseModel):
    """
    Simple shapes dataset info
    """

    # Path to the dataset, obtainable at https://github.com/ruflab/simple-shapes-dataset
    path: Path


class VisualModule(BaseModel):
    """
    Config for the visual domain module
    """

    num_channels: int = 3  # number of image channels used to configure the VAE
    ae_dim: int = 256  # AE hidden dim
    latent_dim: int = 8  # latent dim of the VAE
    beta: float = 0.1  # beta for betaVAE

    # Whether the model is color blind.
    # This adds a transform on the dataset that averages all color channels.
    # NOTE: it only works when using the "v" domain and not "v_latents", where
    # visual latent representations are already extracted. In that case, use a
    # different presaved path in `domain_args`.
    color_blind: bool = False


class AttributeModule(BaseModel):
    """
    Config for the attribute domain module
    """

    latent_dim: int = 10  # latent dim of the VAE
    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
    beta: float = 0.05  # beta for betaVAE
    coef_categories: float = 1
    coef_attributes: float = 1

    # Whether to remove rotation information from the attributes.
    nullify_rotation: bool = False


class TextModule(BaseModel):
    """
    Config for the text domain module
    """

    # which file to load for text
    # The file generated with `shapesd create` is "latent".
    # If you use an older version (e.g. if you downloaded the dataset directly),
    # it should be "bert-base-uncased".
    latent_filename: str = "latent"
    # path to the vocab file
    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
    # path to the merges file
    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
    # max sequence length of the text sequences
    seq_length: int = 64
    vocab_size: int = 822

    # VAE configuration
    latent_dim: int = 64
    hidden_dim: int = 256
    beta: float = 0.1


class DomainModules(BaseModel):
    visual: VisualModule = VisualModule()
    attribute: AttributeModule = AttributeModule()
    text: TextModule = TextModule()


class EncodersConfig(BaseModel):
    """
    Encoder architecture config
    """

    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32

    # The model will have an extra linear layer before and after the n_layers
    # hidden layers, so the total is `2 + n_layers`.
    n_layers: int | Mapping[DomainModuleVariant, int] = 3


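`hidden_dim` and `n_layers` accept either a single int or a per-domain mapping;
string keys are parsed into DomainModuleVariant members as described earlier.
A sketch:

    enc = EncodersConfig.model_validate(
        {"hidden_dim": {"v_latents": 64, "attr": 32}, "n_layers": 3}
    )
    assert enc.hidden_dim[DomainModuleVariant.v_latents] == 64
    assert enc.n_layers == 3
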
class LoadedDomainConfig(BaseModel):
    """
    Domain params for the active domains of the run
    """

    # path to the pretrained module
    checkpoint_path: Path
    # domain to select
    domain_type: DomainModuleVariant
    # domain module specific arguments
    args: Mapping[str, Any] = {}


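A sketch of one such entry (the checkpoint path is a placeholder):

    domain = LoadedDomainConfig.model_validate(
        {
            "checkpoint_path": "./checkpoints/v_latents.ckpt",  # placeholder
            "domain_type": "v_latents",
        }
    )
    # the proportion check in Config below keys groups on the base name, here "v"
    assert domain.domain_type.kind.value.base == "v"
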
class DomainProportion(BaseModel):
    """
    Deprecated, `DomainProportionT` will be used instead.
    """

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated with
    # e.g. if domains is ["v", "t"], it gives the proportion of paired v, t data
    domains: Sequence[str]


class DomainProportionT(TypedDict):
    """This replaces `DomainProportion` in future configs."""

    # proportion for some domains
    # should be in [0, 1]
    proportion: float
    # list of domains the proportion is associated with
    # e.g. if domains is ["v", "t"], it gives the proportion of paired v, t data
    domains: Sequence[str]


class GlobalWorkspace(BaseModel):
    # latent dim of the GW
    latent_dim: int = 12
    # whether to use the fusion GW
    use_fusion_model: bool = False
    # softmax temperature for the softmax distributed selection used by the fusion model
    selection_temperature: float = 0.2
    # whether to learn the logit scale of the contrastive loss like in the CLIP paper
    learn_logit_scale: bool = False
    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
    # with the associated params
    vsepp_contrastive_loss: bool = False
    vsepp_margin: float = 0.2
    vsepp_measure: Literal["cosine", "order"] = "cosine"
    vsepp_max_violation: bool = True
    # whether to use linear encoders and decoders for the GW
    linear_domains: bool = False
    # whether to use a bias when using linear encoders and decoders
    linear_domains_use_bias: bool = False
    # encoder architecture config
    encoders: EncodersConfig = EncodersConfig()
    # decoder architecture config
    decoders: EncodersConfig = EncodersConfig()
    # coefs of each loss. The total loss is computed as the sum of the individual
    # losses weighted by these coefs; you can select any available loss generated
    # by the loss functions.
    loss_coefficients: Mapping[str, float] = {
        "cycles": 1.0,
        "demi_cycles": 1.0,
        "translations": 1.0,
        "contrastives": 0.01,
        "fused": 1.0,
    }
    # checkpoint of the GW for downstream tasks, visualization, or migrations
    checkpoint: Path | None = None
    # deprecated, use Config.domain_data_args instead
    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
        default=None, deprecated="Use `config.domain_data_args` instead."
    )
    # deprecated, use Config.domains instead
    domains: Sequence[LoadedDomainConfig] | None = Field(
        default=None, deprecated="Use `config.domains` instead."
    )
    # deprecated, use Config.domain_proportions instead
    domain_proportions: Sequence[DomainProportion] | None = Field(
        default=None, deprecated="Use `config.domain_proportions` instead."
    )


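Note that when `loss_coefficients` is set on the model directly, the whole
mapping is replaced rather than merged with the default. A sketch of changing
the contrastive coefficient while keeping the other defaults:

    gw = GlobalWorkspace.model_validate(
        {
            "loss_coefficients": {
                "cycles": 1.0,
                "demi_cycles": 1.0,
                "translations": 1.0,
                "contrastives": 0.05,  # changed from the 0.01 default
                "fused": 1.0,
            },
        }
    )
    assert gw.loss_coefficients["contrastives"] == 0.05
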
class ShimmerConfigInfo(BaseModel):
    """
    Some shimmer-related config
    """

    # version of shimmer used
    version: str = shimmer_version
    # whether the run was started in debug mode
    debug: bool = False
    # params that were passed through the CLI
    cli: Any = {}


class Config(ParsedModel):
    seed: int = 0  # training seed
    ood_seed: int | None = None  # out-of-distribution seed
    default_root_dir: Path = (
        Path("./checkpoints")  # path where to save and load logs and checkpoints
    )
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # proportion of each domain group in the dataset relative to the dataset size
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # data-related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Convert the yaml format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        into a Mapping[frozenset[str], float]
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]

    @model_validator(mode="after")
    def check_selected_domains_have_non_null_proportion(self) -> Self:
        for domain in self.domains:
            domain_base_name = domain.domain_type.kind.value.base
            group = frozenset([domain_base_name])
            if self.domain_proportions.get(group, 0) <= 0:
                raise ValueError(
                    "Selected domains in `domains` should have a non-zero "
                    "proportion in `domain_proportions`, "
                    f"but '{domain_base_name}' is not part of `domain_proportions`."
                )
        return self


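A sketch of both accepted `domain_proportions` forms, assuming `Config`
validates like a regular pydantic model here (the dataset path is a
placeholder):

    import yaml

    list_form = yaml.safe_load(
        """
        - domains: ["v"]
          proportion: 1.0
        - domains: ["v", "t"]
          proportion: 0.5
        """
    )
    cfg = Config.model_validate(
        {"dataset": {"path": "/path/to/simple_shapes"}, "domain_proportions": list_form}
    )
    assert cfg.domain_proportions[frozenset({"v", "t"})] == 0.5
    # the serializer converts back to the list-of-dicts form
    assert {"domains": ["v"], "proportion": 1.0} in cfg.model_dump()["domain_proportions"]
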
def use_deprecated_vals(config: Config) -> Config:
    # migrate values from deprecated config locations
    if config.global_workspace.domain_args is not None:
        config.domain_data_args = config.global_workspace.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domains is not None:
        config.domains = config.global_workspace.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domain_proportions is not None:
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in config.global_workspace.domain_proportions
        }

        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config


def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    path = Path(path)
    conf_files = []
    if load_files is not None:
        conf_files.extend(load_files)
    if (path / "main.yaml").exists():
        conf_files.append("main.yaml")
    if (path / "local.yaml").exists():
        conf_files.append("local.yaml")

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    config_dict.update(
        {
            "__shimmer__": {
                "version": __version__,
                "debug": debug_mode,
                "cli": cli_config,
            }
        }
    )

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )
    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf
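
Typical usage (paths and file names are placeholders): `main.yaml` and
`local.yaml` are picked up automatically from the config directory,
script-specific files are passed via `load_files`, and CLI arguments can
override values:

    config = load_config(
        "./config",                     # placeholder config directory
        load_files=["train_gw.yaml"],   # placeholder script-specific file
        debug_mode=False,
        log_config=True,
    )
    print(config.training.max_steps, config.global_workspace.latent_dim)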
class DomainModuleVariant(enum.Enum):
 28class DomainModuleVariant(Enum):
 29    """
 30    This is used to select a particular DomainModule.
 31    Each domain can have different variants of domain modules.
 32
 33    For example "attr" will load the default attribute module that uses a VAE,
 34    whereas "attr_legacy" will load a domain module that directly passes attributes
 35    to the GW.
 36    """
 37
 38    # Attribute modules
 39    # -----------------
 40    # Attribute module using a VAE to encode the attribute vector
 41    attr = (DomainType.attr, "default")
 42    # This is the module used in Devillers et al. paper. There is no VAE and the
 43    # attributes are used directly as the unimodal latent representations
 44    attr_legacy = (DomainType.attr, "legacy")
 45    # Same as "attr" but adds an unpaired attributes (information not available in the
 46    # other domains).
 47    attr_unpaired = (DomainType.attr, "unpaired")
 48
 49    # Visual modules
 50    # --------------
 51    # Visual VAE
 52    v = (DomainType.v, "default")
 53    # Same as "v", but uses pre-saved latent VAE representation for faster training.
 54    # This skips the image loading and encoding and only loads latent representation.
 55    # The downside is that you cannot access the default image, but you can reconstruct
 56    # it with "decode_images".
 57    v_latents = (DomainType.v_latents, "default")
 58    # Same as "v_latents" but adds an unpaired value (radom information not available
 59    # in the other domains).
 60    v_latents_unpaired = (DomainType.v_latents, "unpaired")
 61
 62    # Text modules
 63    # ------------
 64    # Text domain.
 65    t = (DomainType.t, "default")
 66    t_attr = (DomainType.t, "t2attr")
 67
 68    def __init__(self, kind: DomainType, model_variant: str) -> None:
 69        """
 70        The two elements of the tuple are put in the the `kind` and the `model_variant`
 71        properties.
 72        """
 73        self.kind = kind
 74        self.model_variant = model_variant
 75
 76    @classmethod
 77    def __get_pydantic_core_schema__(
 78        cls, _source_type: Any, _handler: GetCoreSchemaHandler
 79    ) -> core_schema.CoreSchema:
 80        """
 81        Define how this type is validated and serialized by pydantic.
 82        It can take a str related to the enum keys.
 83        It should be serialized as the the enum key name.
 84        """
 85
 86        def validate_from_str(v: str) -> DomainModuleVariant:
 87            """
 88            Use names instead of values to select enums
 89            """
 90            assert v in DomainModuleVariant.__members__, (
 91                f"Domain type `{v}` is not a member "
 92                f"of {list(DomainModuleVariant.__members__.keys())}"
 93            )
 94            return DomainModuleVariant[v]
 95
 96        from_str_schema = core_schema.no_info_plain_validator_function(
 97            validate_from_str
 98        )
 99
100        def serialize_domain_variant(v: DomainModuleVariant) -> str:
101            return v.name
102
103        return core_schema.json_or_python_schema(
104            json_schema=from_str_schema,
105            python_schema=core_schema.union_schema(
106                [
107                    core_schema.is_instance_schema(DomainModuleVariant),
108                    from_str_schema,
109                ]
110            ),
111            serialization=core_schema.plain_serializer_function_ser_schema(
112                serialize_domain_variant
113            ),
114        )

This is used to select a particular DomainModule. Each domain can have different variants of domain modules.

For example "attr" will load the default attribute module that uses a VAE, whereas "attr_legacy" will load a domain module that directly passes attributes to the GW.

DomainModuleVariant(kind: simple_shapes_dataset.domain.DomainType, model_variant: str)
68    def __init__(self, kind: DomainType, model_variant: str) -> None:
69        """
70        The two elements of the tuple are put in the the `kind` and the `model_variant`
71        properties.
72        """
73        self.kind = kind
74        self.model_variant = model_variant

The two elements of the tuple are put in the the kind and the model_variant properties.

attr = <DomainModuleVariant.attr: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'default')>
attr_legacy = <DomainModuleVariant.attr_legacy: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'legacy')>
attr_unpaired = <DomainModuleVariant.attr_unpaired: (<DomainType.attr: DomainDesc(base='attr', kind='attr')>, 'unpaired')>
v = <DomainModuleVariant.v: (<DomainType.v: DomainDesc(base='v', kind='v')>, 'default')>
v_latents = <DomainModuleVariant.v_latents: (<DomainType.v_latents: DomainDesc(base='v', kind='v_latents')>, 'default')>
v_latents_unpaired = <DomainModuleVariant.v_latents_unpaired: (<DomainType.v_latents: DomainDesc(base='v', kind='v_latents')>, 'unpaired')>
t = <DomainModuleVariant.t: (<DomainType.t: DomainDesc(base='t', kind='t')>, 'default')>
t_attr = <DomainModuleVariant.t_attr: (<DomainType.t: DomainDesc(base='t', kind='t')>, 't2attr')>
kind
model_variant
class Logging(pydantic.main.BaseModel):
117class Logging(BaseModel):
118    # List of medias (images/text) that will be logged.
119    # The list is defined in individual train config (e.g. train_v.yaml, train_gw.yaml)
120    filter_images: Sequence[str] | None = None
121    log_train_medias_every_n_epochs: int | None = 10
122    log_val_medias_every_n_epochs: int | None = 10

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
filter_images: Sequence[str] | None
log_train_medias_every_n_epochs: int | None
log_val_medias_every_n_epochs: int | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Slurm(pydantic.main.BaseModel):
125class Slurm(BaseModel):
126    """
127    Slurm config for https://github.com/bdvllrs/auto-sbatch
128    """
129
130    script: str
131    run_workdir: str
132    python_env: str
133    command: str
134
135    pre_modules: Sequence[str] = []
136    run_modules: Sequence[str] = []
137    args: Mapping[str, Any] = {}
138
139    grid_search: Mapping[str, Sequence[Any]] | None = None
140    grid_search_exclude: Sequence[Mapping[str, Any]] | None = None
script: str
run_workdir: str
python_env: str
command: str
pre_modules: Sequence[str]
run_modules: Sequence[str]
args: Mapping[str, typing.Any]
grid_search_exclude: Sequence[Mapping[str, typing.Any]] | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Optim(pydantic.main.BaseModel):
143class Optim(BaseModel):
144    """
145    Optimizer config
146    """
147
148    # learning rate (max learning rate when using with the default OneCycle
149    # learning rate scheduler)
150    # see https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html#torch.optim.lr_scheduler.OneCycleLR)
151    # for information about the other config values
152    lr: float = 5e-3
153    # TODO: remove as redundant with lr
154    max_lr: float = 5e-3
155    start_lr: float = 5e-4
156    end_lr: float = 5e-4
157    pct_start: float = 0.2
158    weight_decay: float = 1e-5

Optimizer config

lr: float
max_lr: float
start_lr: float
end_lr: float
pct_start: float
weight_decay: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Training(pydantic.main.BaseModel):
161class Training(BaseModel):
162    """
163    Training related config.
164    As these config depend on what you are training,
165    they are defined in the script related yaml files (e.g. `train_v.yaml`,
166    `train_gw.yaml`).
167    """
168
169    batch_size: int = 2056
170    num_workers: int = 16
171    devices: int = 1  # num of devices (gpus) to use
172    accelerator: str = "gpu"
173
174    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run
175    fast_dev_run: bool = False
176    # number of training steps
177    max_steps: int = 200_000
178    enable_progress_bar: bool = True
179
180    # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#precision
181    # you may want to set to "16-mixed" if your gpu allows mixed precision
182    precision: Any = 32
183    # see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch-set-float32-matmul-precision
184    # you may want to decrease to "medium"
185    float32_matmul_precision: str = "highest"
186
187    # Optimizer config
188    optim: Optim = Optim()

Training related config. As these config depend on what you are training, they are defined in the script related yaml files (e.g. train_v.yaml, train_gw.yaml).

batch_size: int
num_workers: int
devices: int
accelerator: str
fast_dev_run: bool
max_steps: int
enable_progress_bar: bool
precision: Any
float32_matmul_precision: str
optim: Optim
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ExploreVAE(pydantic.main.BaseModel):
191class ExploreVAE(BaseModel):
192    # the VAE checkpoint to use
193    checkpoint: str
194    num_samples: int = 9
195    range_start: int = -3
196    range_end: int = 3
197    wandb_name: str | None = None

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
checkpoint: str
num_samples: int
range_start: int
range_end: int
wandb_name: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ExploreGW(pydantic.main.BaseModel):
200class ExploreGW(BaseModel):
201    checkpoint: str  # The GW checkpoint to use
202    domain: str
203    num_samples: int = 9
204    range_start: int = -3
205    range_end: int = 3
206    wandb_name: str | None = None

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
checkpoint: str
domain: str
num_samples: int
range_start: int
range_end: int
wandb_name: str | None
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Visualization(pydantic.main.BaseModel):
209class Visualization(BaseModel):
210    # when exploring a vae
211    explore_vae: ExploreVAE
212    # when exploring the GW
213    explore_gw: ExploreGW

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
explore_vae: ExploreVAE
explore_gw: ExploreGW
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class WanDB(pydantic.main.BaseModel):
216class WanDB(BaseModel):
217    enabled: bool = False
218    # where to save the logs
219    save_dir: str = ""
220    # the "entity/project" values of your wandb project
221    project: str = ""
222    entity: str = ""
223    # see https://docs.wandb.ai/ref/python/init/
224    reinit: bool = False

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
enabled: bool
save_dir: str
project: str
entity: str
reinit: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Dataset(pydantic.main.BaseModel):
227class Dataset(BaseModel):
228    """
229    Simple shapes dataset infos
230    """
231
232    # Path to the dataset obtainable on https://github.com/ruflab/simple-shapes-dataset
233    path: Path

Simple shapes dataset infos

path: pathlib.Path
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class VisualModule(pydantic.main.BaseModel):
236class VisualModule(BaseModel):
237    """
238    Config for the visual domain module
239    """
240
241    num_channels: int = 3  # num channels of image to configure the VAE
242    ae_dim: int = 256  # AE hidden dim
243    latent_dim: int = 8  # latent dim of the VAE
244    beta: float = 0.1  # beta for betaVAE
245
246    # Whether the model is color blind.
247    # It adds a transform on the dataset that averages all color channels.
248    # NOTE: this only works when using the "v" domain and not "v_latents" where
249    # visual latent representations are already extracted. In that case, use a different
250    # presaved-path in `domain_args`.
251    color_blind: bool = False

Config for the visual domain module

num_channels: int
ae_dim: int
latent_dim: int
beta: float
color_blind: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class AttributeModule(pydantic.main.BaseModel):
254class AttributeModule(BaseModel):
255    """
256    Config for the attribute domain module
257    """
258
259    latent_dim: int = 10  # latent dim of the VAE
260    hidden_dim: int = 64  # hidden dim of the AE (encoders and decoders)
261    beta: float = 0.05  # for betaVAE
262    coef_categories: float = 1
263    coef_attributes: float = 1
264
265    # Whether to remove rotation information from the attribute.
266    nullify_rotation: bool = False

Config for the attribute domain module

latent_dim: int
hidden_dim: int
beta: float
coef_categories: float
coef_attributes: float
nullify_rotation: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class TextModule(pydantic.main.BaseModel):
269class TextModule(BaseModel):
270    """
271    Config for the text domain module
272    """
273
274    # which file to load for text
275    # The file generated with `shapesd create` is "latent"
276    # If you use an older version (like if you downloaded directly the dataset) it
277    # should be "bert-base-uncased"
278    latent_filename: str = "latent"
279    # path to the vocab file
280    vocab_path: str = str((PROJECT_DIR / "tokenizer/vocab.json").resolve())
281    # path to the merge path
282    merges_path: str = str((PROJECT_DIR / "tokenizer/merges.txt").resolve())
283    # max sequence length of text sequence
284    seq_length: int = 64
285    vocab_size: int = 822
286
287    # VAE configuration
288    latent_dim: int = 64
289    hidden_dim: int = 256
290    beta: float = 0.1

Config for the text domain module

latent_filename: str
vocab_path: str
merges_path: str
seq_length: int
vocab_size: int
latent_dim: int
hidden_dim: int
beta: float
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainModules(pydantic.main.BaseModel):
293class DomainModules(BaseModel):
294    visual: VisualModule = VisualModule()
295    attribute: AttributeModule = AttributeModule()
296    text: TextModule = TextModule()

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
visual: VisualModule
attribute: AttributeModule
text: TextModule
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class EncodersConfig(pydantic.main.BaseModel):
299class EncodersConfig(BaseModel):
300    """
301    Encoder architecture config
302    """
303
304    hidden_dim: int | Mapping[DomainModuleVariant, int] = 32
305
306    # The model will have an extra linear before and after the n_layers
307    # Hence the total will be `2 + n_layers`
308    n_layers: int | Mapping[DomainModuleVariant, int] = 3

Encoder architecture config

hidden_dim: int | Mapping[DomainModuleVariant, int]
n_layers: int | Mapping[DomainModuleVariant, int]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class LoadedDomainConfig(pydantic.main.BaseModel):
311class LoadedDomainConfig(BaseModel):
312    """
313    Domain params for the active domains of the run
314    """
315
316    # path to the pretrained module
317    checkpoint_path: Path
318    # domain to select
319    domain_type: DomainModuleVariant
320    # domain module specific arguments
321    args: Mapping[str, Any] = {}

Domain params for the active domains of the run

checkpoint_path: pathlib.Path
domain_type: DomainModuleVariant
args: Mapping[str, typing.Any]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainProportion(pydantic.main.BaseModel):
324class DomainProportion(BaseModel):
325    """
326    Deprecated, DomainProportionT will be used instead.
327    """
328
329    # proportion for some domains
330    # should be in [0, 1]
331    proportion: float
332    # list of domains the proportion is associated to
333    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
334    domains: Sequence[str]

Deprecated, DomainProportionT will be used instead.

proportion: float
domains: Sequence[str]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class DomainProportionT(typing_extensions.TypedDict):
337class DomainProportionT(TypedDict):
338    """This replaces `DomainProportion` in future config."""
339
340    # proportion for some domains
341    # should be in [0, 1]
342    proportion: float
343    # list of domains the proportion is associated to
344    # e.g. if domains: ["v", "t"], then it gives the prop of paired v, t data
345    domains: Sequence[str]

This replaces DomainProportion in future config.

proportion: float
domains: Sequence[str]
class GlobalWorkspace(pydantic.main.BaseModel):
348class GlobalWorkspace(BaseModel):
349    # latent dim of the GW
350    latent_dim: int = 12
351    # whether to use the fusion GW
352    use_fusion_model: bool = False
353    # softmax temp for the Softmax distruted selection used by the fusion model
354    selection_temperature: float = 0.2
355    # whether to learn the logit scale of the contrastive loss like in the clip paper
356    learn_logit_scale: bool = False
357    # whether to use the VSEPP (https://github.com/fartashf/vsepp) contrastive loss
358    # with associated params
359    vsepp_contrastive_loss: bool = False
360    vsepp_margin: float = 0.2
361    vsepp_measure: Literal["cosine", "order"] = "cosine"
362    vsepp_max_violation: bool = True
363    # whether to use linear encoders and decoders for the GW
364    linear_domains: bool = False
365    # whether to use bias when using linear encoders and decoders
366    linear_domains_use_bias: bool = False
367    # encoder architecture config
368    encoders: EncodersConfig = EncodersConfig()
369    # decoder architecture config
370    decoders: EncodersConfig = EncodersConfig()
371    # coefs of each loss. The total loss is computed using the given values and coefs
372    # you can select any available loss generated by the loss functions
373    loss_coefficients: Mapping[str, float] = {
374        "cycles": 1.0,
375        "demi_cycles": 1.0,
376        "translations": 1.0,
377        "contrastives": 0.01,
378        "fused": 1.0,
379    }
380    # checkpoint of the GW for downstream, visualization tasks, or migrations
381    checkpoint: Path | None = None
382    # deprecated, use Config.domain_data_args instead
383    domain_args: Mapping[str, Mapping[str, Any]] | None = Field(
384        default=None, deprecated="Use `config.domain_data_args` instead."
385    )
386    # deprecated, use Config.domains instead
387    domains: Sequence[LoadedDomainConfig] | None = Field(
388        default=None, deprecated="Use `config.domains` instead."
389    )
390    # deprecated, use Config.domain_proportions instead
391    domain_proportions: Sequence[DomainProportion] | None = Field(
392        default=None, deprecated="Use `config.domain_proportions` instead."
393    )

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
  • __class_vars__: The names of the class variables defined on the model.
  • __private_attributes__: Metadata about the private attributes of the model.
  • __signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.
  • __pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
  • __pydantic_core_schema__: The core schema of the model.
  • __pydantic_custom_init__: Whether the model has a custom __init__ function.
  • __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces Model.__validators__ and Model.__root_validators__ from Pydantic V1.
  • __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
  • __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
  • __pydantic_post_init__: The name of the post-init method for the model, if defined.
  • __pydantic_root_model__: Whether the model is a [RootModel][pydantic.root_model.RootModel].
  • __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model.
  • __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model.
  • __pydantic_fields__: A dictionary of field names and their corresponding [FieldInfo][pydantic.fields.FieldInfo] objects.
  • __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [ComputedFieldInfo][pydantic.fields.ComputedFieldInfo] objects.
  • __pydantic_extra__: A dictionary containing extra values, if [extra][pydantic.config.ConfigDict.extra] is set to 'allow'.
  • __pydantic_fields_set__: The names of fields explicitly set during instantiation.
  • __pydantic_private__: Values of private attributes set on the model instance.
latent_dim: int
use_fusion_model: bool
selection_temperature: float
learn_logit_scale: bool
vsepp_contrastive_loss: bool
vsepp_margin: float
vsepp_measure: Literal['cosine', 'order']
vsepp_max_violation: bool
linear_domains: bool
linear_domains_use_bias: bool
encoders: EncodersConfig
decoders: EncodersConfig
loss_coefficients: Mapping[str, float]
checkpoint: pathlib.Path | None
domain_args: Mapping[str, Mapping[str, typing.Any]] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
domains: Sequence[LoadedDomainConfig] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
domain_proportions: Sequence[DomainProportion] | None

Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.

Attributes:
  • msg: The deprecation message to be emitted.
  • wrapped_property: The property instance if the deprecated field is a computed field, or None.
  • field_name: The name of the field being deprecated.
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class ShimmerConfigInfo(pydantic.main.BaseModel):
396class ShimmerConfigInfo(BaseModel):
397    """
398    Some shimmer related config
399    """
400
401    # version of shimmer used
402    version: str = shimmer_version
403    # whether started in debug mode
404    debug: bool = False
405    # params that were passed through CLI
406    cli: Any = {}

Some shimmer related config

version: str
debug: bool
cli: Any
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Config(cfg_tools.data_parser.ParsedModel):
class Config(ParsedModel):
    seed: int = 0  # training seed
    ood_seed: int | None = None  # out-of-distribution seed
    default_root_dir: Path = (
        Path("./checkpoints")  # path where logs and checkpoints are saved and loaded
    )
    # Dataset information
    dataset: Dataset
    # Training config
    training: Training = Training()
    # Wandb configuration
    wandb: WanDB = WanDB()
    # Logging configuration
    logging: Logging = Logging()
    # Add a title to your wandb run
    title: str | None = Field(None, alias="t")
    # Add a description to your run
    desc: str | None = Field(None, alias="d")
    # Proportion of each domain group relative to the size of the dataset
    domain_proportions: Mapping[frozenset[str], float] = {}
    # Config of the different domain modules
    domain_modules: DomainModules = DomainModules()
    # Domain params for the active domains of the run
    domains: Sequence[LoadedDomainConfig] = []
    # Data-related args used by the dataloader
    domain_data_args: Mapping[str, Mapping[str, Any]] = {}
    # Global workspace configuration
    global_workspace: GlobalWorkspace = GlobalWorkspace()
    # Config used during visualization
    visualization: Visualization | None = None
    # Slurm config when starting on a cluster
    slurm: Slurm | None = None
    __shimmer__: ShimmerConfigInfo = ShimmerConfigInfo()

    @field_validator("domain_proportions", mode="before")
    @classmethod
    def domain_proportion_validator(
        cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]
    ) -> Mapping[frozenset[str], float]:
        """
        Convert the YAML list format:
        ```
        - domains: ["v"]
          proportion: 1.0
        ```
        into a Mapping[frozenset[str], float].
        """
        if isinstance(value, Mapping):
            return value
        else:
            return {frozenset(item["domains"]): item["proportion"] for item in value}

    @field_serializer("domain_proportions")
    def serialize_domain_proportions(
        self, domain_proportions: Mapping[frozenset[str], float], _info
    ) -> list[DomainProportionT]:
        return [
            {"domains": list(domains), "proportion": proportion}
            for domains, proportion in domain_proportions.items()
        ]

    @model_validator(mode="after")
    def check_selected_domains_have_non_null_proportion(self) -> Self:
        for domain in self.domains:
            domain_base_name = domain.domain_type.kind.value.base
            group = frozenset([domain_base_name])
            if self.domain_proportions.get(group, 0) <= 0:
                raise ValueError(
                    "Selected domains in `domains` should have a non-zero "
                    "proportion in `domain_proportions` "
                    f"but '{domain_base_name}' is not part of `domain_proportions`."
                )
        return self

seed: int
ood_seed: int | None
default_root_dir: pathlib.Path
dataset: Dataset
training: Training
wandb: WanDB
logging: Logging
title: str | None
desc: str | None
domain_proportions: Mapping[frozenset[str], float]
domain_modules: DomainModules
domains: Sequence[LoadedDomainConfig]
domain_data_args: Mapping[str, Mapping[str, typing.Any]]
global_workspace: GlobalWorkspace
visualization: Visualization | None
slurm: Slurm | None
@field_validator('domain_proportions', mode='before')
@classmethod
def domain_proportion_validator(cls, value: Sequence[DomainProportionT] | Mapping[frozenset[str], float]) -> Mapping[frozenset[str], float]:

Convert the YAML list format:

- domains: ["v"]
  proportion: 1.0

into a Mapping[frozenset[str], float].

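For illustration, a minimal sketch of the transformation this validator performs (the domain names are placeholders):

```
raw = [
    {"domains": ["v"], "proportion": 1.0},
    {"domains": ["v", "attr"], "proportion": 0.5},
]
# Same comprehension as the validator's `else` branch:
parsed = {frozenset(item["domains"]): item["proportion"] for item in raw}
assert parsed == {frozenset({"v"}): 1.0, frozenset({"v", "attr"}): 0.5}
```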
@field_serializer('domain_proportions')
def serialize_domain_proportions(self, domain_proportions: Mapping[frozenset[str], float], _info) -> list[DomainProportionT]:
@model_validator(mode='after')
def check_selected_domains_have_non_null_proportion(self) -> Self:
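A standalone sketch of the check, with hypothetical domain names: every selected domain's base name must map to a positive proportion.

```
domain_proportions = {frozenset({"v"}): 1.0}
for name in ["v", "attr"]:  # "attr" is selected but has no proportion
    if domain_proportions.get(frozenset([name]), 0) <= 0:
        raise ValueError(f"'{name}' is not part of `domain_proportions`.")
# Raises ValueError for "attr".
```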
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

def use_deprecated_vals(config: Config) -> Config:
def use_deprecated_vals(config: Config) -> Config:
    # use deprecated values
    if config.global_workspace.domain_args is not None:
        config.domain_data_args = config.global_workspace.domain_args
        warnings.warn(
            "Deprecated `config.global_workspace.domain_args`, "
            "use `config.domain_data_args` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domains is not None:
        config.domains = config.global_workspace.domains
        warnings.warn(
            "Deprecated `config.global_workspace.domains`, "
            "use `config.domains` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    if config.global_workspace.domain_proportions is not None:
        config.domain_proportions = {
            frozenset(item.domains): item.proportion
            for item in config.global_workspace.domain_proportions
        }

        warnings.warn(
            "Deprecated `config.global_workspace.domain_proportions`, "
            "use `config.domain_proportions` instead",
            DeprecationWarning,
            stacklevel=2,
        )
    return config
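In short, this mirrors the three deprecated `global_workspace` fields to their top-level successors: `domain_args` becomes `domain_data_args`, `domains` becomes `domains`, and `domain_proportions` (a sequence of `DomainProportion` entries) is converted into the top-level `Mapping[frozenset[str], float]` form, with a `DeprecationWarning` emitted for each deprecated field that is set.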
def load_config(path: str | pathlib.Path, load_files: list[str] | None = None, use_cli: bool = True, debug_mode: bool = False, argv: list[str] | None = None, log_config: bool = False) -> Config:
def load_config(
    path: str | Path,
    load_files: list[str] | None = None,
    use_cli: bool = True,
    debug_mode: bool = False,
    argv: list[str] | None = None,
    log_config: bool = False,
) -> Config:
    path = Path(path)
    conf_files = []
    if load_files is not None:
        conf_files.extend(load_files)
    if (path / "main.yaml").exists():
        conf_files.append("main.yaml")
    if (path / "local.yaml").exists():
        conf_files.append("local.yaml")

    if debug_mode and (path / "debug.yaml").exists():
        conf_files.append("debug.yaml")

    config_dict, cli_config = load_config_files(path, conf_files, use_cli, argv)

    config_dict.update(
        {
            "__shimmer__": {
                "version": __version__,
                "debug": debug_mode,
                "cli": cli_config,
            }
        }
    )

    conf = use_deprecated_vals(
        validate_and_fill_missing(config_dict, Config, path, "local.yaml")
    )
    if log_config:
        print("Loaded config:")
        pprint.pp(dict(conf))
    return conf
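A minimal usage sketch, assuming a `./config` directory containing a `main.yaml` (and optionally `local.yaml` and `debug.yaml`); the attributes printed at the end are just examples:

```
from shimmer_ssd.config import load_config

conf = load_config(
    "./config",
    use_cli=True,      # allow overriding values from the command line
    debug_mode=False,  # when True, ./config/debug.yaml is also loaded
    log_config=True,   # pretty-print the loaded config
)
print(conf.seed, conf.dataset)
```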