shimmer.modules.global_workspace

  1from collections.abc import Iterable, Mapping
  2from pathlib import Path
  3from typing import Any, Generic, TypedDict, TypeVar, cast
  4
  5import torch
  6from lightning.pytorch import LightningModule
  7from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
  8from torch.nn import Module, ModuleDict
  9from torch.optim.lr_scheduler import OneCycleLR
 10
 11from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
 12from shimmer.modules.domain import DomainModule
 13from shimmer.modules.gw_module import (
 14    GWModule,
 15    GWModuleBase,
 16    GWModuleBayesian,
 17)
 18from shimmer.modules.losses import (
 19    BroadcastLossCoefs,
 20    GWLosses,
 21    GWLosses2Domains,
 22    GWLossesBase,
 23    GWLossesBayesian,
 24    LossCoefs,
 25)
 26from shimmer.modules.selection import (
 27    FixedSharedSelection,
 28    RandomSelection,
 29    SelectionBase,
 30    SingleDomainSelection,
 31)
 32from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations
 33from shimmer.types import (
 34    LatentsDomainGroupsDT,
 35    LatentsDomainGroupsT,
 36    ModelModeT,
 37    RawDomainGroupsDT,
 38    RawDomainGroupsT,
 39    RawDomainGroupT,
 40)
 41from shimmer.utils import groups_batch_size
 42
 43
 44class SchedulerArgs(TypedDict, total=False):
 45    """TypedDict of arguments passed to the OneCycle scheduler"""
 46
 47    max_lr: float
 48    """Maximum learning rate"""
 49
 50    total_steps: int
 51    """Total number of steps"""
 52
 53
 54class GWPredictionsBase(TypedDict):
 55    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
 56
 57    states: dict[str, torch.Tensor]
 58    """
 59    GW state representation from domain groups with only one domain.
  60    The keys are the domain names.
 61    """
 62
 63
 64_T_gw_mod = TypeVar("_T_gw_mod", bound=GWModuleBase)
 65_T_selection_mod = TypeVar("_T_selection_mod", bound=SelectionBase)
 66_T_loss_mod = TypeVar("_T_loss_mod", bound=GWLossesBase)
 67
 68
 69class GlobalWorkspaceBase(
 70    Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
 71):
 72    """
 73    Global Workspace Lightning Module.
 74
 75    This is the base class to build the Global Workspace.
 76    """
 77
 78    def __init__(
 79        self,
 80        gw_mod: _T_gw_mod,
 81        selection_mod: _T_selection_mod,
 82        loss_mod: _T_loss_mod,
 83        optim_lr: float = 1e-3,
 84        optim_weight_decay: float = 0.0,
 85        scheduler_args: SchedulerArgs | None = None,
 86    ) -> None:
 87        """
 88        Initializes a GW
 89
 90        Args:
 91            gw_mod (`GWModuleBase`): the GWModule
 92            selection_mod (`SelectionBase`): selection module
 93            loss_mod (`GWLossesBase`): module to compute the GW losses.
 94            optim_lr (`float`): learning rate
 95            optim_weight_decay (`float`): weight decay
 96            scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
 97                scheduler parameters.
 98        """
 99        super().__init__()
100        self.save_hyperparameters(
101            ignore=[
102                "gw_mod",
103                "selection_mod",
104                "domain_mods",
105                "loss_mod",
106                "domain_descriptions",
107                "contrastive_loss",
108                "cont_loss_bayesian",
109                "gw_encoders",
110                "gw_decoders",
111            ]
112        )
113
114        self.gw_mod = gw_mod
115        """ a `GWModuleBase` implementation."""
116
117        self.selection_mod = selection_mod
118        """A `SelectionBase` implementation."""
119
120        self.loss_mod = loss_mod
121        """The module that computes losses of the GW"""
122
123        self.optim_lr = optim_lr
124        self.optim_weight_decay = optim_weight_decay
125        self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
126        if scheduler_args is not None:
127            self.scheduler_args.update(scheduler_args)
128
129    @property
130    def domain_mods(self) -> Mapping[str, DomainModule]:
131        return self.gw_mod.domain_mods
132
133    @property
134    def workspace_dim(self) -> int:
135        """Dimension of the GW."""
136        return self.gw_mod.workspace_dim
137
138    def encode_and_fuse(
139        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
140    ) -> dict[frozenset[str], torch.Tensor]:
141        """
142        Encode a group of latent representations into the GW representation.
143
144        Args:
145            x (`LatentsDomainGroupsT`): the input domain representations.
146            selection_module (`SelectionBase`): the selection module used for fusion.
147
148        Returns:
149            `dict[frozenset[str], torch.Tensor]`: the GW representations.
150        """
151        return {
152            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
153            for domains, latents in x.items()
154        }
155
156    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
157        """
158        Encode a group of latent representations into the pre-fusion GW representation.
159
160        Args:
161            x (`LatentsDomainGroupsT`): the input domain representations.
162
163        Returns:
164            `LatentsDomainGroupsDT`: the pre-fusion GW representations.
165        """
166        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}
167
168    def fuse(
169        self,
170        x: LatentsDomainGroupsT,
171        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
172    ) -> dict[frozenset[str], torch.Tensor]:
173        """
174        Fuses a group of latent representations into the GW representation.
175
176        Args:
177            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
178            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
179                selection scores for each group
180
181        Returns:
182            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
183        """
184        return {
185            domains: self.gw_mod.fuse(latents, selection_scores[domains])
186            for domains, latents in x.items()
187        }
188
189    def decode(
190        self,
191        z: Mapping[frozenset[str], torch.Tensor],
192        domains: Iterable[str] | None = None,
193    ) -> LatentsDomainGroupsDT:
194        """
195        Decode the group GW representation into given `domains`.
196
197        Args:
198            z (`Mapping[frozenset[str], torch.Tensor]`): the GW representations.
199            domains (`Iterable[str] | None`): iterable of domains to decode to.
200
201        Returns:
202            `LatentsDomainGroupsDT`: the decoded unimodal representations.
203        """
204        return {
205            domain_names: self.gw_mod.decode(gw_rep, domains)
206            for domain_names, gw_rep in z.items()
207        }
208
209    def forward(  # type: ignore
210        self,
211        latent_domains: LatentsDomainGroupsT,
212    ) -> GWPredictionsBase:
213        """
214        Computes the GW states for the given groups of domains.
215
216        Args:
217            latent_domains (`LatentsDomainGroupsT`): groups of domains for the computation.
218
219        Returns:
220            `GWPredictionsBase`: the predictions on the batch.
221        """
222
223        return GWPredictionsBase(states=self.batch_gw_states(latent_domains))
224
225    def batch_gw_states(
226        self, latent_domains: LatentsDomainGroupsT
227    ) -> dict[str, torch.Tensor]:
228        """
229        Computes the GW states of a batch of groups of domains.
230
231        Args:
232            latent_domains (`LatentsDomainGroupsT`): the batch of groups of domains
233
234        Returns:
235            `dict[str, torch.Tensor]`: states for each domain.
236        """
237        predictions: dict[str, torch.Tensor] = {}
238        for domains, latents in latent_domains.items():
239            if len(domains) > 1:
240                continue
241            domain_name = list(domains)[0]
242            z = self.gw_mod.encode_and_fuse(
243                latents, selection_module=self.selection_mod
244            )
245            predictions[domain_name] = z
246        return predictions
247
248    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
249        """
250        Encodes a domain from the domain data into the unimodal representation.
251
252        This is a convenient proxy for the `DomainModule.encode` method and is
253        equivalent to:
254        ```python
255        self.domain_mods[name].encode(domain)
256        ```
257
258        Args:
259            domain (`Any`): the domain data
260            name (`str`): domain name to encode
261
262        Returns:
263            `torch.Tensor`: the domain's unimodal representation.
264        """
265        return self.domain_mods[name].encode(domain)
266
267    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
268        """
269        Encode all domains in the batch.
270
271        Args:
272            batch (`RawDomainGroupsT`): the batch of
273                domain groups with raw unimodal data to encode into groups of latent
274                representations.
275
276        Returns:
277            `LatentsDomainGroupsDT`: the domains' unimodal representations.
278        """
279        return {
280            domains: {
281                name: self.domain_mods[name].encode(domain)
282                for name, domain in data.items()
283            }
284            for domains, data in batch.items()
285        }
286
287    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
288        """
289        Decodes a domain from the unimodal representation into the domain data.
290
291        This is a convenient proxy for the `DomainModule.decode` method and is
292        equivalent to:
293        ```python
294        self.domain_mods[name].decode(domain)
295        ```
296
297        Args:
298            domain (`torch.Tensor`): the domain's unimodal latent representation
299            name (`str`): domain name to decode
300
301        Returns:
302            `Any`: the domain's raw data.
303        """
304        return self.domain_mods[name].decode(domain)
305
306    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
307        """
308        Decodes all domains in the batch.
309
310        Args:
311            latents_domain (`LatentsDomainGroupsT`): the batch of
312                domain groups with unimodal latent representation to decode into
313                groups of raw data.
314
315        Returns:
316            `RawDomainGroupsDT`: the domains' raw data.
317        """
318        return {
319            domains: {
320                name: self.domain_mods[name].decode(domain)
321                for name, domain in latents.items()
322            }
323            for domains, latents in latents_domain.items()
324        }
325
326    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
327        """
328        The generic step used in `training_step`, `validation_step` and
329        `test_step`.
330
331        Args:
332            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
333            mode (`ModelModeT`): the mode ("train", "val", "test", "val/ood" or "test/ood").
334
335        Returns:
336            `torch.Tensor`: the loss to train on.
337        """
338        domain_latents = self.encode_domains(batch)
339        batch_size = groups_batch_size(domain_latents)
340
341        loss_output = self.loss_mod.step(domain_latents, mode)
342
343        for name, metric in loss_output.all.items():
344            self.log(
345                f"{mode}/{name}",
346                metric,
347                batch_size=batch_size,
348                add_dataloader_idx=False,
349            )
350
351        return loss_output.loss
352
353    def validation_step(  # type: ignore
354        self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
355    ) -> torch.Tensor:
356        """Validation step used by lightning"""
357
358        batch = {frozenset(data.keys()): data}
359        for domain in data:
360            batch[frozenset([domain])] = {domain: data[domain]}
361        if dataloader_idx == 0:
362            return self.generic_step(batch, mode="val")
363        return self.generic_step(batch, mode="val/ood")
364
365    def test_step(  # type: ignore
366        self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
367    ) -> torch.Tensor:
368        """Test step used by lightning"""
369
370        batch = {frozenset(data.keys()): data}
371        for domain in data:
372            batch[frozenset([domain])] = {domain: data[domain]}
373        if dataloader_idx == 0:
374            return self.generic_step(batch, mode="test")
375        return self.generic_step(batch, mode="test/ood")
376
377    def training_step(  # type: ignore
378        self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
379    ) -> torch.Tensor:
380        """Training step used by lightning"""
381
382        return self.generic_step(batch, mode="train")
383
384    def predict_step(  # type: ignore
385        self, data: Mapping[str, Any], batch_idx: int
386    ) -> GWPredictionsBase:
387        """Predict step used by lightning"""
388
389        batch = {frozenset(data.keys()): data}
390        for domain in data:
391            batch[frozenset([domain])] = {domain: data[domain]}
392
393        domain_latents = self.encode_domains(batch)
394        return self.forward(domain_latents)
395
396    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
397        """
398        Configure the model's optimizers.
399
400        Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
401        scheduler.
402        """
403
404        optimizer = torch.optim.AdamW(
405            self.parameters(),
406            lr=self.optim_lr,
407            weight_decay=self.optim_weight_decay,
408        )
409
410        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
411
412        return {
413            "optimizer": optimizer,
414            "lr_scheduler": {
415                "scheduler": lr_scheduler,
416                "interval": "step",
417            },
418        }
419
420
421def freeze_domain_modules(
422    domain_mods: Mapping[str, DomainModule],
423) -> dict[str, DomainModule]:
424    """
425    Freezes the weights of the domain modules and sets them to eval mode.
426
427    .. note::
428        The output is cast as `dict[str, DomainModule]` for better
429        auto-completion, but is actually a torch `ModuleDict`.
430
431    Args:
432        domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
433
434    Returns:
435        `ModuleDict`: frozen modules.
436    """
437
438    for mod in domain_mods.values():
439        mod.freeze()
440    # Cast to dict for better auto-completion; the object is actually a ModuleDict
441    return cast(dict[str, DomainModule], ModuleDict(domain_mods))
442
443
444class GWPredictions(GWPredictionsBase):
445    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
446
447    demi_cycles: dict[str, torch.Tensor]
448    """
449    Demi-cycle predictions of the model for each domain. Only computed on domain
450    groups with only one domain.
451    """
452
453    cycles: dict[tuple[str, str], torch.Tensor]
454    """
455    Cycle predictions of the model from one domain through another one.
456    Only computed on domain groups with more than one domain.
457    The keys are tuples of (start domain, intermediary domain).
458    """
459
460    translations: dict[tuple[str, str], torch.Tensor]
461    """
462    Translation predictions of the model from one domain through another one.
463
464    Only computed on domain groups with more than one domain.
465    The keys are tuples of (start domain, target domain).
466    """
467
468
469class GlobalWorkspace2Domains(
470    GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
471):
472    """
473    A simple flavor of GlobalWorkspaceBase for up to two domains.
474
475    This is used to simplify Global Workspace instantiation and only overrides the
476    `__init__` method.
477    """
478
479    def __init__(
480        self,
481        domain_mods: Mapping[str, DomainModule],
482        gw_encoders: Mapping[str, Module],
483        gw_decoders: Mapping[str, Module],
484        workspace_dim: int,
485        loss_coefs: LossCoefs,
486        optim_lr: float = 1e-3,
487        optim_weight_decay: float = 0.0,
488        scheduler_args: SchedulerArgs | None = None,
489        learn_logit_scale: bool = False,
490        contrastive_loss: ContrastiveLossType | None = None,
491    ) -> None:
492        """
493        Initializes a Global Workspace
494
495        Args:
496            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
497                connected to the GW. Keys are domain names, values are the
498                `DomainModule`.
499            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
500                name to a `torch.nn.Module` whose role is to encode a unimodal
501                latent representation into a GW representation (pre-fusion).
502            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
503                name to a `torch.nn.Module` whose role is to decode a GW
504                representation into a unimodal latent representation.
505            workspace_dim (`int`): dimension of the GW.
506            loss_coefs (`LossCoefs`): loss coefficients
507            optim_lr (`float`): learning rate
508            optim_weight_decay (`float`): weight decay
509            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
510            learn_logit_scale (`bool`): whether to learn the logit scale of the
511                contrastive loss when using the default contrastive loss.
512            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
513                function used for alignment. `learn_logit_scale` will not affect custom
514                contrastive losses.
515        """
516        domain_mods = freeze_domain_modules(domain_mods)
517
518        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
519        if contrastive_loss is None:
520            contrastive_loss = ContrastiveLoss(
521                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
522            )
523        selection_mod = SingleDomainSelection()
524        loss_mod = GWLosses2Domains(
525            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
526        )
527
528        super().__init__(
529            gw_mod,
530            selection_mod,
531            loss_mod,
532            optim_lr,
533            optim_weight_decay,
534            scheduler_args,
535        )
536
537    def forward(  # type: ignore
538        self,
539        latent_domains: LatentsDomainGroupsT,
540    ) -> GWPredictions:
541        """
542        Computes demi-cycles, cycles, and translations.
543
544        Args:
545            latent_domains (`LatentsDomainGroupsT`): groups of domains for the computation.
546
547        Returns:
548            `GWPredictions`: the predictions on the batch.
549        """
550        return GWPredictions(
551            demi_cycles=batch_demi_cycles(
552                self.gw_mod, self.selection_mod, latent_domains
553            ),
554            cycles=batch_cycles(
555                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
556            ),
557            translations=batch_translations(
558                self.gw_mod, self.selection_mod, latent_domains
559            ),
560            **super().forward(latent_domains),
561        )
562
563
564class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
565    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
566
567    This is used to simplify Global Workspace instantiation and only overrides the
568    `__init__` method.
569    """
570
571    def __init__(
572        self,
573        domain_mods: Mapping[str, DomainModule],
574        gw_encoders: Mapping[str, Module],
575        gw_decoders: Mapping[str, Module],
576        workspace_dim: int,
577        loss_coefs: BroadcastLossCoefs,
578        selection_temperature: float = 0.2,
579        optim_lr: float = 1e-3,
580        optim_weight_decay: float = 0.0,
581        scheduler_args: SchedulerArgs | None = None,
582        learn_logit_scale: bool = False,
583        contrastive_loss: ContrastiveLossType | None = None,
584    ) -> None:
585        """
586        Initializes a Global Workspace
587
588        Args:
589            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
590                connected to the GW. Keys are domain names, values are the
591                `DomainModule`.
592            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
593                name to a `torch.nn.Module` whose role is to encode a unimodal
594                latent representation into a GW representation (pre-fusion).
595            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
596                name to a `torch.nn.Module` whose role is to decode a GW
597                representation into a unimodal latent representation.
598            workspace_dim (`int`): dimension of the GW.
599            loss_coefs (`BroadcastLossCoefs`): coefficients for the broadcast losses.
600            selection_temperature (`float`): temperature value for the RandomSelection
601                module.
602            optim_lr (`float`): learning rate
603            optim_weight_decay (`float`): weight decay
604            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
605            learn_logit_scale (`bool`): whether to learn the logit scale of the
606                contrastive loss when using the default contrastive loss.
607            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
608                function used for alignment. `learn_logit_scale` will not affect custom
609                contrastive losses.
610        """
611        domain_mods = freeze_domain_modules(domain_mods)
612        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
613
614        if contrastive_loss is None:
615            contrastive_loss = ContrastiveLoss(
616                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
617            )
618
619        selection_mod = RandomSelection(selection_temperature)
620        loss_mod = GWLosses(
621            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
622        )
623
624        super().__init__(
625            gw_mod,
626            selection_mod,
627            loss_mod,
628            optim_lr,
629            optim_weight_decay,
630            scheduler_args,
631        )
632
633    def forward(  # type: ignore
634        self,
635        latent_domains: LatentsDomainGroupsT,
636    ) -> GWPredictions:
637        """
638        Computes demi-cycles, cycles, and translations.
639
640        Args:
641            latent_domains (`LatentsDomainGroupsT`): groups of domains for the computation.
642
643        Returns:
644            `GWPredictions`: the predictions on the batch.
645        """
646        return GWPredictions(
647            demi_cycles=batch_demi_cycles(
648                self.gw_mod, self.selection_mod, latent_domains
649            ),
650            cycles=batch_cycles(
651                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
652            ),
653            translations=batch_translations(
654                self.gw_mod, self.selection_mod, latent_domains
655            ),
656            # TODO: add other combinations
657            **super().forward(latent_domains),
658        )
659
660
661class GlobalWorkspaceBayesian(
662    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
663):
664    """
665    A simple GlobalWorkspaceBase flavor for up to two domains with Bayesian
666    uncertainty prediction.
667
668    This is used to simplify Global Workspace instantiation and only overrides the
669    `__init__` method.
670    """
671
672    def __init__(
673        self,
674        domain_mods: Mapping[str, DomainModule],
675        gw_encoders: Mapping[str, Module],
676        gw_decoders: Mapping[str, Module],
677        workspace_dim: int,
678        loss_coefs: BroadcastLossCoefs,
679        sensitivity_selection: float = 1,
680        sensitivity_precision: float = 1,
681        optim_lr: float = 1e-3,
682        optim_weight_decay: float = 0.0,
683        scheduler_args: SchedulerArgs | None = None,
684        learn_logit_scale: bool = False,
685        use_normalized_constrastive: bool = True,
686        contrastive_loss: ContrastiveLossType | None = None,
687        precision_softmax_temp: float = 0.01,
688    ) -> None:
689        """
690        Initializes a Global Workspace
691
692        Args:
693            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
694                connected to the GW. Keys are domain names, values are the
695                `DomainModule`.
696            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
697                name to a `torch.nn.Module` whose role is to encode a unimodal
698                latent representation into a GW representation (pre-fusion).
699            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
700                name to a `torch.nn.Module` whose role is to decode a GW
701                representation into a unimodal latent representation.
702            workspace_dim (`int`): dimension of the GW.
703            loss_coefs (`BroadcastLossCoefs`): loss coefficients
704            sensitivity_selection (`float`): sensitivity coefficient $c'_1$
705            sensitivity_precision (`float`): sensitivity coefficient $c'_2$
706            optim_lr (`float`): learning rate
707            optim_weight_decay (`float`): weight decay
708            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
709            learn_logit_scale (`bool`): whether to learn the logit scale of the
710                contrastive loss when using the default contrastive loss.
711            use_normalized_constrastive (`bool`): whether to normalize the contrastive
712                loss by the precision coefficients
713            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
714                function used for alignment. `learn_logit_scale` will not affect custom
715                contrastive losses.
716            precision_softmax_temp (`float`): temperature used in the softmax of the
717                precisions
718        """
719        domain_mods = freeze_domain_modules(domain_mods)
720
721        gw_mod = GWModuleBayesian(
722            domain_mods,
723            workspace_dim,
724            gw_encoders,
725            gw_decoders,
726            sensitivity_selection,
727            sensitivity_precision,
728            precision_softmax_temp,
729        )
730
731        selection_mod = FixedSharedSelection()
732
733        if contrastive_loss is None:
734            contrastive_loss = ContrastiveLoss(
735                torch.tensor([1]).log(), "mean", learn_logit_scale)
736
737        loss_mod = GWLossesBayesian(
738            gw_mod,
739            selection_mod,
740            domain_mods,
741            loss_coefs,
742            contrastive_loss,
743            use_normalized_constrastive,
744        )
745
746        super().__init__(
747            gw_mod,
748            selection_mod,
749            loss_mod,
750            optim_lr,
751            optim_weight_decay,
752            scheduler_args,
753        )
754
755    def forward(  # type: ignore
756        self,
757        latent_domains: LatentsDomainGroupsT,
758    ) -> GWPredictions:
759        """
760        Computes demi-cycles, cycles, and translations.
761
762        Args:
763            latent_domains (`LatentsDomainGroupsT`): groups of domains for the computation.
764
765        Returns:
766            `GWPredictions`: the predictions on the batch.
767        """
768        return GWPredictions(
769            demi_cycles=batch_demi_cycles(
770                self.gw_mod, self.selection_mod, latent_domains
771            ),
772            cycles=batch_cycles(
773                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
774            ),
775            translations=batch_translations(
776                self.gw_mod, self.selection_mod, latent_domains
777            ),
778            **super().forward(latent_domains),
779        )
780
781
782def pretrained_global_workspace(
783    checkpoint_path: str | Path,
784    domain_mods: Mapping[str, DomainModule],
785    gw_encoders: Mapping[str, Module],
786    gw_decoders: Mapping[str, Module],
787    workspace_dim: int,
788    loss_coefs: LossCoefs,
789    contrastive_fn: ContrastiveLossType,
790    **kwargs,
791) -> GlobalWorkspace2Domains:
792    """
793    Load a `GlobalWorkspace2Domains` from a checkpoint.
794
795    Args:
796        checkpoint_path (`str | Path`): path to checkpoint
797        domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
798            connected to the GW. Keys are domain names, values are the
799            `DomainModule`.
800        gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
801            name to a `torch.nn.Module` whose role is to encode a unimodal
802            latent representation into a GW representation (pre-fusion).
803        gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
804            name to a `torch.nn.Module` whose role is to decode a GW
805            representation into a unimodal latent representation.
806        workspace_dim (`int`): dimension of the GW.
807        loss_coefs (`LossCoefs`): loss coefficients
808        contrastive_fn (`ContrastiveLossType`): the contrastive loss
809            function used to compute the alignment (contrastive) loss
810            between domains.
811        **kwargs: additional arguments to pass to
812            `GlobalWorkspace.load_from_checkpoint`.
813
814    Returns:
815        `GlobalWorkspace2Domains`: the pretrained global workspace.
816
817    Raises:
818        `TypeError`: if the loaded model is not a `GlobalWorkspace2Domains`.
819    """
820    domain_mods = freeze_domain_modules(domain_mods)
821    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
822    selection_mod = SingleDomainSelection()
823    loss_mod = GWLosses2Domains(
824        gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
825    )
826
827    gw = GlobalWorkspace2Domains.load_from_checkpoint(
828        checkpoint_path,
829        gw_mod=gw_mod,
830        selection_mod=selection_mod,
831        loss_coefs=loss_coefs,
832        loss_mod=loss_mod,
833        **kwargs,
834    )
835    if not isinstance(gw, GlobalWorkspace2Domains):
836        raise TypeError("model should be of type GlobalWorkspace2Domains")
837    return gw
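
Example usage (editor's sketch, not part of the module): the snippet below shows how a `GlobalWorkspace2Domains` could be assembled and queried. The `DummyDomainModule`, the domain names, the dimensions, the loss-coefficient keys and the scheduler arguments are hypothetical placeholders; only the classes and functions defined above are assumed.

```python
import torch
from torch import nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import GlobalWorkspace2Domains


class DummyDomainModule(DomainModule):
    """Hypothetical identity domain module, for illustration only.

    Assumes `DomainModule.__init__` takes the unimodal latent dimension.
    """

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z


latent_dim, workspace_dim = 32, 16  # illustrative dimensions

# One DomainModule per domain (two hypothetical domains "v" and "t").
domain_mods = {"v": DummyDomainModule(latent_dim), "t": DummyDomainModule(latent_dim)}

# Pre-fusion encoders/decoders between the unimodal latents and the GW.
gw_encoders = {name: nn.Linear(latent_dim, workspace_dim) for name in domain_mods}
gw_decoders = {name: nn.Linear(workspace_dim, latent_dim) for name in domain_mods}

gw = GlobalWorkspace2Domains(
    domain_mods,
    gw_encoders,
    gw_decoders,
    workspace_dim,
    # Assumed `LossCoefs` keys; adjust to the coefficients you actually train with.
    loss_coefs={"demi_cycles": 1.0, "cycles": 1.0, "translations": 1.0, "contrastives": 0.01},
    scheduler_args={"max_lr": 1e-3, "total_steps": 10_000},
)

# Group raw data by frozenset of domain names (as the Lightning steps do),
# encode it into unimodal latents, then compute the GW predictions.
data = {"v": torch.randn(8, latent_dim), "t": torch.randn(8, latent_dim)}
batch = {frozenset(data): data}
batch.update({frozenset([name]): {name: x} for name, x in data.items()})

latents = gw.encode_domains(batch)
predictions = gw(latents)  # GWPredictions: states, demi_cycles, cycles, translations
```

Loading a checkpoint with `pretrained_global_workspace` follows the same pattern, with the checkpoint path and a contrastive loss function passed in addition to the arguments above.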
class SchedulerArgs(typing.TypedDict):
45class SchedulerArgs(TypedDict, total=False):
46    """TypedDict of arguments passed to the OneCycle scheduler"""
47
48    max_lr: float
49    """Maximum learning rate"""
50
51    total_steps: int
52    """Total number of steps"""

TypedDict of arguments passed to the OneCycle scheduler

max_lr: float

Maximum learning rate

total_steps: int

Total number of steps

class GWPredictionsBase(typing.TypedDict):
55class GWPredictionsBase(TypedDict):
56    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
57
58    states: dict[str, torch.Tensor]
59    """
60    GW state representation from domain groups with only one domain.
61    The key represent the domain's name.
62    """

TypedDict of the output given when calling GlobalWorkspaceBase.predict

states: dict[str, torch.Tensor]

GW state representation from domain groups with only one domain. The key represent the domain's name.

class GlobalWorkspaceBase(typing.Generic[~_T_gw_mod, ~_T_selection_mod, ~_T_loss_mod], lightning.pytorch.core.module.LightningModule):
 70class GlobalWorkspaceBase(
 71    Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
 72):
 73    """
 74    Global Workspace Lightning Module.
 75
 76    This is the base class to build the Global Workspace.
 77    """
 78
 79    def __init__(
 80        self,
 81        gw_mod: _T_gw_mod,
 82        selection_mod: _T_selection_mod,
 83        loss_mod: _T_loss_mod,
 84        optim_lr: float = 1e-3,
 85        optim_weight_decay: float = 0.0,
 86        scheduler_args: SchedulerArgs | None = None,
 87    ) -> None:
 88        """
 89        Initializes a GW
 90
 91        Args:
 92            gw_mod (`GWModuleBase`): the GWModule
 93            selection_mod (`SelectionBase`): selection module
 94            loss_mod (`GWLossesBase`): module to compute the GW losses.
 95            optim_lr (`float`): learning rate
 96            optim_weight_decay (`float`): weight decay
 97            scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
 98                scheduler parameters.
 99        """
100        super().__init__()
101        self.save_hyperparameters(
102            ignore=[
103                "gw_mod",
104                "selection_mod",
105                "domain_mods",
106                "loss_mod",
107                "domain_descriptions",
108                "contrastive_loss",
109                "cont_loss_bayesian",
110                "gw_encoders",
111                "gw_decoders",
112            ]
113        )
114
115        self.gw_mod = gw_mod
116        """ a `GWModuleBase` implementation."""
117
118        self.selection_mod = selection_mod
119        """A `SelectionBase` implementation."""
120
121        self.loss_mod = loss_mod
122        """The module that computes losses of the GW"""
123
124        self.optim_lr = optim_lr
125        self.optim_weight_decay = optim_weight_decay
126        self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
127        if scheduler_args is not None:
128            self.scheduler_args.update(scheduler_args)
129
130    @property
131    def domain_mods(self) -> Mapping[str, DomainModule]:
132        return self.gw_mod.domain_mods
133
134    @property
135    def workspace_dim(self) -> int:
136        """Dimension of the GW."""
137        return self.gw_mod.workspace_dim
138
139    def encode_and_fuse(
140        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
141    ) -> dict[frozenset[str], torch.Tensor]:
142        """
143        Encode a group of latent representations into the GW representation.
144
145        Args:
146            x (`LatentsDomainGroupsT`): the input domain representations.
147            selection_scores (`Mapping[str, torch.Tensor]`):
148
149        Returns:
150            `dict[frozenset[str], torch.Tensor]`: the GW representations.
151        """
152        return {
153            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
154            for domains, latents in x.items()
155        }
156
157    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
158        """
159        Encode a group of latent representations into the pre-fusion GW representation.
160
161        Args:
162            x (`LatentsDomainGroupsT`): the input domain representations.
163
164        Returns:
165            `LatensDomainGroupsDT`: the GW representations.
166        """
167        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}
168
169    def fuse(
170        self,
171        x: LatentsDomainGroupsT,
172        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
173    ) -> dict[frozenset[str], torch.Tensor]:
174        """
175        Fuses a group of latent representations into the GW representation.
176
177        Args:
178            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
179            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
180                selection scores for each group
181
182        Returns:
183            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
184        """
185        return {
186            domains: self.gw_mod.fuse(latents, selection_scores[domains])
187            for domains, latents in x.items()
188        }
189
190    def decode(
191        self,
192        z: Mapping[frozenset[str], torch.Tensor],
193        domains: Iterable[str] | None = None,
194    ) -> LatentsDomainGroupsDT:
195        """
196        Decode the group GW representation into given `domains`.
197
198        Args:
199            z (`torch.Tensor`): the GW representation.
200            domains (`Iterable[str]`): iterable of domains to decode.
201
202        Returns:
203            `dict[str, torch.Tensor]`: the decoded unimodal representations.
204        """
205        return {
206            domain_names: self.gw_mod.decode(gw_rep, domains)
207            for domain_names, gw_rep in z.items()
208        }
209
210    def forward(  # type: ignore
211        self,
212        latent_domains: LatentsDomainGroupsT,
213    ) -> GWPredictionsBase:
214        """
215        Computes demi-cycles, cycles, and translations.
216
217        Args:
218            latent_domains (`LatentsT`): Groups of domains for the computation.
219
220        Returns:
221            `GWPredictionsBase`: the predictions on the batch.
222        """
223
224        return GWPredictionsBase(states=self.batch_gw_states(latent_domains))
225
226    def batch_gw_states(
227        self, latent_domains: LatentsDomainGroupsT
228    ) -> dict[str, torch.Tensor]:
229        """
230        Comptues GW states of a batch of groups of domains.
231
232        Args:
233            latent_domains (`LatentsT`): the batch of groups of domains
234
235        Returns:
236            `dict[str, torch.Tensor]`: states for each domain.
237        """
238        predictions: dict[str, torch.Tensor] = {}
239        for domains, latents in latent_domains.items():
240            if len(domains) > 1:
241                continue
242            domain_name = list(domains)[0]
243            z = self.gw_mod.encode_and_fuse(
244                latents, selection_module=self.selection_mod
245            )
246            predictions[domain_name] = z
247        return predictions
248
249    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
250        """
251        Encodes a domain from the domain data into the unimodal representation.
252
253        This is a convenient proxy for the `DomainModule.encode` method and is
254        equivalent to:
255        ```python
256        self.domain_mods[name].encode(domain)
257        ```
258
259        Args:
260            domain (`Any`): the domain data
261            name (`str`): domain name to encode
262
263        Returns:
264            `torch.Tensor`: the domain's unimodal representation.
265        """
266        return self.domain_mods[name].encode(domain)
267
268    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
269        """
270        Encode all domains in the batch.
271
272        Args:
273            batch (`RawDomainGroupsT`): the batch of
274                domain groups with raw unimodal data to encode into groups of latent
275                representations.
276
277        Returns:
278            `LatentsDomainGroupsDT`: the domains' unimodal representations.
279        """
280        return {
281            domains: {
282                name: self.domain_mods[name].encode(domain)
283                for name, domain in data.items()
284            }
285            for domains, data in batch.items()
286        }
287
288    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
289        """
290        Decodes a domain from the unimodal representation into the domain data.
291
292        This is a convenient proxy for the `DomainModule.encode` method and is
293        equivalent to:
294        ```python
295        self.domain_mods[name].decode(domain)
296        ```
297
298        Args:
299            domain (`torch.Tensor`): the domain data
300            name (`str`): domain name to encode
301
302        Returns:
303            `Any`: the domain's raw data.
304        """
305        return self.domain_mods[name].decode(domain)
306
307    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
308        """
309        Decodes all domains in the batch.
310
311        Args:
312            batch (`LatentsDomainGroupsT`): the batch of
313                domain groups with unimodal latent representation to decode into
314                groups of raw data.
315
316        Returns:
317            `LatentsDomainGroupsDT`: the domains' raw data.
318        """
319        return {
320            domains: {
321                name: self.domain_mods[name].decode(domain)
322                for name, domain in latents.items()
323            }
324            for domains, latents in latents_domain.items()
325        }
326
327    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
328        """
329        The generic step used in `training_step`, `validation_step` and
330        `test_step`.
331
332        Args:
333            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
334            mode (`ModelModeT`):
335
336        Returns:
337            `torch.Tensor`: the loss to train on.
338        """
339        domain_latents = self.encode_domains(batch)
340        batch_size = groups_batch_size(domain_latents)
341
342        loss_output = self.loss_mod.step(domain_latents, mode)
343
344        for name, metric in loss_output.all.items():
345            self.log(
346                f"{mode}/{name}",
347                metric,
348                batch_size=batch_size,
349                add_dataloader_idx=False,
350            )
351
352        return loss_output.loss
353
354    def validation_step(  # type: ignore
355        self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
356    ) -> torch.Tensor:
357        """Validation step used by lightning"""
358
359        batch = {frozenset(data.keys()): data}
360        for domain in data:
361            batch[frozenset([domain])] = {domain: data[domain]}
362        if dataloader_idx == 0:
363            return self.generic_step(batch, mode="val")
364        return self.generic_step(batch, mode="val/ood")
365
366    def test_step(  # type: ignore
367        self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
368    ) -> torch.Tensor:
369        """Test step used by lightning"""
370
371        batch = {frozenset(data.keys()): data}
372        for domain in data:
373            batch[frozenset([domain])] = {domain: data[domain]}
374        if dataloader_idx == 0:
375            return self.generic_step(batch, mode="test")
376        return self.generic_step(batch, mode="test/ood")
377
378    def training_step(  # type: ignore
379        self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
380    ) -> torch.Tensor:
381        """Training step used by lightning"""
382
383        return self.generic_step(batch, mode="train")
384
385    def predict_step(  # type: ignore
386        self, data: Mapping[str, Any], batch_idx: int
387    ) -> GWPredictionsBase:
388        """Predict step used by lightning"""
389
390        batch = {frozenset(data.keys()): data}
391        for domain in data:
392            batch[frozenset([domain])] = {domain: data[domain]}
393
394        domain_latents = self.encode_domains(batch)
395        return self.forward(domain_latents)
396
397    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
398        """
399        Configure models optimizers.
400
401        Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
402        scheduler.
403        """
404
405        optimizer = torch.optim.AdamW(
406            self.parameters(),
407            lr=self.optim_lr,
408            weight_decay=self.optim_weight_decay,
409        )
410
411        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
412
413        return {
414            "optimizer": optimizer,
415            "lr_scheduler": {
416                "scheduler": lr_scheduler,
417                "interval": "step",
418            },
419        }

Global Workspace Lightning Module.

This is the base class to build the Global Workspace.

gw_mod

a GWModuleBase implementation.

selection_mod

A SelectionBase implementation.

loss_mod

The module that computes losses of the GW

optim_lr
optim_weight_decay
scheduler_args
domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]
130    @property
131    def domain_mods(self) -> Mapping[str, DomainModule]:
132        return self.gw_mod.domain_mods
workspace_dim: int
134    @property
135    def workspace_dim(self) -> int:
136        """Dimension of the GW."""
137        return self.gw_mod.workspace_dim

Dimension of the GW.

def encode_and_fuse( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], selection_module: shimmer.modules.selection.SelectionBase) -> dict[frozenset[str], torch.Tensor]:
139    def encode_and_fuse(
140        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
141    ) -> dict[frozenset[str], torch.Tensor]:
142        """
143        Encode a group of latent representations into the GW representation.
144
145        Args:
146            x (`LatentsDomainGroupsT`): the input domain representations.
147            selection_scores (`Mapping[str, torch.Tensor]`):
148
149        Returns:
150            `dict[frozenset[str], torch.Tensor]`: the GW representations.
151        """
152        return {
153            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
154            for domains, latents in x.items()
155        }

Encode a group of latent representations into the GW representation.

Arguments:
  • x (LatentsDomainGroupsT): the input domain representations.
  • selection_scores (Mapping[str, torch.Tensor]):
Returns:

dict[frozenset[str], torch.Tensor]: the GW representations.

def encode( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, torch.Tensor]]:
157    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
158        """
159        Encode a group of latent representations into the pre-fusion GW representation.
160
161        Args:
162            x (`LatentsDomainGroupsT`): the input domain representations.
163
164        Returns:
165            `LatensDomainGroupsDT`: the GW representations.
166        """
167        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}

Encode a group of latent representations into the pre-fusion GW representation.

Arguments:
  • x (LatentsDomainGroupsT): the input domain representations.
Returns:

LatensDomainGroupsDT: the GW representations.

def fuse( self, x: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]], selection_scores: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], torch.Tensor]:
169    def fuse(
170        self,
171        x: LatentsDomainGroupsT,
172        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
173    ) -> dict[frozenset[str], torch.Tensor]:
174        """
175        Fuses a group of latent representations into the GW representation.
176
177        Args:
178            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
179            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
180                selection scores for each group
181
182        Returns:
183            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
184        """
185        return {
186            domains: self.gw_mod.fuse(latents, selection_scores[domains])
187            for domains, latents in x.items()
188        }

Fuses a group of latent representations into the GW representation.

Arguments:
  • x (LatentsDomainGroupsT): the pre-fusion latent representations
  • selection_scores (Mapping[frozenset[str], Mapping[str, torch.Tensor]]): selection scores for each group
Returns:

dict[frozenset[str], torch.Tensor]: GW representation of each group

def decode( self, z: collections.abc.Mapping[frozenset[str], torch.Tensor], domains: collections.abc.Iterable[str] | None = None) -> dict[frozenset[str], dict[str, torch.Tensor]]:
190    def decode(
191        self,
192        z: Mapping[frozenset[str], torch.Tensor],
193        domains: Iterable[str] | None = None,
194    ) -> LatentsDomainGroupsDT:
195        """
196        Decode the group GW representation into given `domains`.
197
198        Args:
199            z (`torch.Tensor`): the GW representation.
200            domains (`Iterable[str]`): iterable of domains to decode.
201
202        Returns:
203            `dict[str, torch.Tensor]`: the decoded unimodal representations.
204        """
205        return {
206            domain_names: self.gw_mod.decode(gw_rep, domains)
207            for domain_names, gw_rep in z.items()
208        }

Decode the group GW representation into given domains.

Arguments:
  • z (torch.Tensor): the GW representation.
  • domains (Iterable[str]): iterable of domains to decode.
Returns:

dict[str, torch.Tensor]: the decoded unimodal representations.

def batch_gw_states( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
226    def batch_gw_states(
227        self, latent_domains: LatentsDomainGroupsT
228    ) -> dict[str, torch.Tensor]:
229        """
230        Comptues GW states of a batch of groups of domains.
231
232        Args:
233            latent_domains (`LatentsT`): the batch of groups of domains
234
235        Returns:
236            `dict[str, torch.Tensor]`: states for each domain.
237        """
238        predictions: dict[str, torch.Tensor] = {}
239        for domains, latents in latent_domains.items():
240            if len(domains) > 1:
241                continue
242            domain_name = list(domains)[0]
243            z = self.gw_mod.encode_and_fuse(
244                latents, selection_module=self.selection_mod
245            )
246            predictions[domain_name] = z
247        return predictions

Comptues GW states of a batch of groups of domains.

Arguments:
  • latent_domains (LatentsT): the batch of groups of domains
Returns:

dict[str, torch.Tensor]: states for each domain.

def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
249    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
250        """
251        Encodes a domain from the domain data into the unimodal representation.
252
253        This is a convenient proxy for the `DomainModule.encode` method and is
254        equivalent to:
255        ```python
256        self.domain_mods[name].encode(domain)
257        ```
258
259        Args:
260            domain (`Any`): the domain data
261            name (`str`): domain name to encode
262
263        Returns:
264            `torch.Tensor`: the domain's unimodal representation.
265        """
266        return self.domain_mods[name].encode(domain)

Encodes a domain from the domain data into the unimodal representation.

This is a convenient proxy for the DomainModule.encode method and is equivalent to:

self.domain_mods[name].encode(domain)
Arguments:
  • domain (Any): the domain data
  • name (str): domain name to encode
Returns:

torch.Tensor: the domain's unimodal representation.

def encode_domains( self, batch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]]) -> dict[frozenset[str], dict[str, torch.Tensor]]:
268    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
269        """
270        Encode all domains in the batch.
271
272        Args:
273            batch (`RawDomainGroupsT`): the batch of
274                domain groups with raw unimodal data to encode into groups of latent
275                representations.
276
277        Returns:
278            `LatentsDomainGroupsDT`: the domains' unimodal representations.
279        """
280        return {
281            domains: {
282                name: self.domain_mods[name].encode(domain)
283                for name, domain in data.items()
284            }
285            for domains, data in batch.items()
286        }

Encode all domains in the batch.

Arguments:
  • batch (RawDomainGroupsT): the batch of domain groups with raw unimodal data to encode into groups of latent representations.
Returns:

LatentsDomainGroupsDT: the domains' unimodal representations.

def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
288    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
289        """
290        Decodes a domain from the unimodal representation into the domain data.
291
292        This is a convenient proxy for the `DomainModule.encode` method and is
293        equivalent to:
294        ```python
295        self.domain_mods[name].decode(domain)
296        ```
297
298        Args:
299            domain (`torch.Tensor`): the domain data
300            name (`str`): domain name to encode
301
302        Returns:
303            `Any`: the domain's raw data.
304        """
305        return self.domain_mods[name].decode(domain)

Decodes a domain from the unimodal representation into the domain data.

This is a convenient proxy for the DomainModule.encode method and is equivalent to:

self.domain_mods[name].decode(domain)
Arguments:
  • domain (torch.Tensor): the domain data
  • name (str): domain name to encode
Returns:

Any: the domain's raw data.

def decode_domains( self, latents_domain: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> dict[frozenset[str], dict[str, typing.Any]]:
307    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
308        """
309        Decodes all domains in the batch.
310
311        Args:
312            batch (`LatentsDomainGroupsT`): the batch of
313                domain groups with unimodal latent representation to decode into
314                groups of raw data.
315
316        Returns:
317            `LatentsDomainGroupsDT`: the domains' raw data.
318        """
319        return {
320            domains: {
321                name: self.domain_mods[name].decode(domain)
322                for name, domain in latents.items()
323            }
324            for domains, latents in latents_domain.items()
325        }

Decodes all domains in the batch.

Arguments:
  • batch (LatentsDomainGroupsT): the batch of domain groups with unimodal latent representation to decode into groups of raw data.
Returns:

LatentsDomainGroupsDT: the domains' raw data.

def generic_step( self, batch: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, typing.Any]], mode: Literal['train', 'val', 'test', 'val/ood', 'test/ood']) -> torch.Tensor:
327    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
328        """
329        The generic step used in `training_step`, `validation_step` and
330        `test_step`.
331
332        Args:
333            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
334            mode (`ModelModeT`):
335
336        Returns:
337            `torch.Tensor`: the loss to train on.
338        """
339        domain_latents = self.encode_domains(batch)
340        batch_size = groups_batch_size(domain_latents)
341
342        loss_output = self.loss_mod.step(domain_latents, mode)
343
344        for name, metric in loss_output.all.items():
345            self.log(
346                f"{mode}/{name}",
347                metric,
348                batch_size=batch_size,
349                add_dataloader_idx=False,
350            )
351
352        return loss_output.loss

The generic step used in training_step, validation_step and test_step.

Arguments:
  • batch (RawDomainGroupsT): the batch of groups of raw unimodal data.
  • mode (ModelModeT): the model mode ("train", "val", "test", "val/ood" or "test/ood").
Returns:

torch.Tensor: the loss to train on.
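The Lightning step hooks build on this method; conceptually (a sketch, not necessarily the exact shimmer implementation), each hook forwards its batch with the matching mode:

# inside GlobalWorkspaceBase (sketch)
def training_step(self, batch, batch_idx):
    return self.generic_step(batch, mode="train")

def validation_step(self, batch, batch_idx):
    return self.generic_step(batch, mode="val")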

def freeze_domain_modules( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule]) -> dict[str, shimmer.modules.domain.DomainModule]:
422def freeze_domain_modules(
423    domain_mods: Mapping[str, DomainModule],
424) -> dict[str, DomainModule]:
425    """
426    Freezes the weights and sets the domain modules to eval mode.
427
428    .. note::
429        The output is cast to the `dict[str, DomainModule]` type for better
430        auto-completion, but is actually a torch `ModuleDict`.
431
432    Args:
433        domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze
434
435    Returns:
436        `ModuleDict`: frozen modules.
437    """
438
439    for mod in domain_mods.values():
440        mod.freeze()
441    # Cast for better auto-completion at the expense of ModuleDict
442    return cast(dict[str, DomainModule], ModuleDict(domain_mods))

Freezes the weights and sets the domain modules to eval mode.

The output is cast to the dict[str, DomainModule] type for better auto-completion, but is actually a torch ModuleDict.

Arguments:
  • domain_mods (Mapping[str, DomainModule]): mapping of domain modules to freeze
Returns:

ModuleDict: frozen modules.
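A minimal usage sketch, assuming v_mod and t_mod are already-trained DomainModule instances (hypothetical names):

from shimmer.modules.global_workspace import freeze_domain_modules

domain_mods = freeze_domain_modules({"v": v_mod, "t": t_mod})
# The result is typed as dict[str, DomainModule] but is a torch ModuleDict,
# so assigning it to a module attribute registers the (frozen) submodules.
assert all(not p.requires_grad for p in domain_mods["v"].parameters())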

class GWPredictions(builtins.dict):
445class GWPredictions(GWPredictionsBase):
446    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""
447
448    demi_cycles: dict[str, torch.Tensor]
449    """
450    Demi-cycle predictions of the model for each domain. Only computed on domain
451    groups with only one domain.
452    """
453
454    cycles: dict[tuple[str, str], torch.Tensor]
455    """
456    Cycle predictions of the model from one domain through another one.
457    Only computed on domain groups with more than one domain.
458    The keys are tuples with the start domain and intermediary domain.
459    """
460
461    translations: dict[tuple[str, str], torch.Tensor]
462    """
463    Translation predictions of the model from one domain through another one.
464
465    Only computed on domain groups with more than one domain.
466    The keys are tuples with start domain and target domain.
467    """

TypedDict of the output given when calling GlobalWorkspaceBase.predict

demi_cycles: dict[str, torch.Tensor]

Demi-cycle predictions of the model for each domain. Only computed on domain groups with only one domain.

cycles: dict[tuple[str, str], torch.Tensor]

Cycle predictions of the model from one domain through another one. Only computed on domain groups with more than one domain. The keys are tuples with the start domain and intermediary domain.

translations: dict[tuple[str, str], torch.Tensor]

Translation predictions of the model from one domain through another one.

Only computed on domain groups with more than one domain. The keys are tuples with start domain and target domain.

states: dict[str, torch.Tensor]
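A short sketch of how the prediction dictionaries are keyed, assuming gw is a Global Workspace over hypothetical domains "v" and "t" and latent_domains is a batch of latent domain groups:

predictions = gw(latent_domains)  # GWPredictions

state_v  = predictions["states"]["v"]               # GW state of the single-domain group {"v"}
demi_v   = predictions["demi_cycles"]["v"]          # v -> GW -> v
cycle_vt = predictions["cycles"][("v", "t")]        # v -> GW -> t -> GW -> v
trans_vt = predictions["translations"][("v", "t")]  # v -> GW -> t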
470class GlobalWorkspace2Domains(
471    GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
472):
473    """
474    A simple flavor of GlobalWorkspaceBase for at most two domains.
475
476    This is used to simplify a Global Workspace instantiation and only overrides the
477    `__init__` method.
478    """
479
480    def __init__(
481        self,
482        domain_mods: Mapping[str, DomainModule],
483        gw_encoders: Mapping[str, Module],
484        gw_decoders: Mapping[str, Module],
485        workspace_dim: int,
486        loss_coefs: LossCoefs,
487        optim_lr: float = 1e-3,
488        optim_weight_decay: float = 0.0,
489        scheduler_args: SchedulerArgs | None = None,
490        learn_logit_scale: bool = False,
491        contrastive_loss: ContrastiveLossType | None = None,
492    ) -> None:
493        """
494        Initializes a Global Workspace
495
496        Args:
497            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
498                connected to the GW. Keys are domain names, values are the
499                `DomainModule`.
500            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
501                name to a `torch.nn.Module` whose role is to encode a
502                unimodal latent representation into a GW representation (pre fusion).
503            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
504                name to a `torch.nn.Module` whose role is to decode a
505                GW representation into a unimodal latent representation.
506            workspace_dim (`int`): dimension of the GW.
507            loss_coefs (`LossCoefs`): loss coefficients
508            optim_lr (`float`): learning rate
509            optim_weight_decay (`float`): weight decay
510            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
511            learn_logit_scale (`bool`): whether to learn the logit scale of the
512                contrastive loss when using the default contrastive loss.
513            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
514                function used for alignment. `learn_logit_scale` will not affect custom
515                contrastive losses.
516        """
517        domain_mods = freeze_domain_modules(domain_mods)
518
519        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
520        if contrastive_loss is None:
521            contrastive_loss = ContrastiveLoss(
522                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
523            )
524        selection_mod = SingleDomainSelection()
525        loss_mod = GWLosses2Domains(
526            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
527        )
528
529        super().__init__(
530            gw_mod,
531            selection_mod,
532            loss_mod,
533            optim_lr,
534            optim_weight_decay,
535            scheduler_args,
536        )
537
538    def forward(  # type: ignore
539        self,
540        latent_domains: LatentsDomainGroupsT,
541    ) -> GWPredictions:
542        """
543        Computes demi-cycles, cycles, and translations.
544
545        Args:
546            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
547
548        Returns:
549            `GWPredictions`: the predictions on the batch.
550        """
551        return GWPredictions(
552            demi_cycles=batch_demi_cycles(
553                self.gw_mod, self.selection_mod, latent_domains
554            ),
555            cycles=batch_cycles(
556                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
557            ),
558            translations=batch_translations(
559                self.gw_mod, self.selection_mod, latent_domains
560            ),
561            **super().forward(latent_domains),
562        )

A simple flavor of GlobalWorkspaceBase for at most two domains.

This is used to simplify a Global Workspace instantiation and only overrides the __init__ method.

GlobalWorkspace2Domains( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.LossCoefs, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None)
480    def __init__(
481        self,
482        domain_mods: Mapping[str, DomainModule],
483        gw_encoders: Mapping[str, Module],
484        gw_decoders: Mapping[str, Module],
485        workspace_dim: int,
486        loss_coefs: LossCoefs,
487        optim_lr: float = 1e-3,
488        optim_weight_decay: float = 0.0,
489        scheduler_args: SchedulerArgs | None = None,
490        learn_logit_scale: bool = False,
491        contrastive_loss: ContrastiveLossType | None = None,
492    ) -> None:
493        """
494        Initializes a Global Workspace
495
496        Args:
497            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
498                connected to the GW. Keys are domain names, values are the
499                `DomainModule`.
500            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
501                name to a `torch.nn.Module` whose role is to encode a
502                unimodal latent representation into a GW representation (pre fusion).
503            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
504                name to a `torch.nn.Module` whose role is to decode a
505                GW representation into a unimodal latent representation.
506            workspace_dim (`int`): dimension of the GW.
507            loss_coefs (`LossCoefs`): loss coefficients
508            optim_lr (`float`): learning rate
509            optim_weight_decay (`float`): weight decay
510            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
511            learn_logit_scale (`bool`): whether to learn the logit scale of the
512                contrastive loss when using the default contrastive loss.
513            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
514                function used for alignment. `learn_logit_scale` will not affect custom
515                contrastive losses.
516        """
517        domain_mods = freeze_domain_modules(domain_mods)
518
519        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
520        if contrastive_loss is None:
521            contrastive_loss = ContrastiveLoss(
522                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
523            )
524        selection_mod = SingleDomainSelection()
525        loss_mod = GWLosses2Domains(
526            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
527        )
528
529        super().__init__(
530            gw_mod,
531            selection_mod,
532            loss_mod,
533            optim_lr,
534            optim_weight_decay,
535            scheduler_args,
536        )

Initializes a Global Workspace

Arguments:
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains connected to the GW. Keys are domain names, values are the DomainModule.
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to encode a unimodal latent representation into a GW representation (pre fusion).
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to decode a GW representation into a unimodal latent representation.
  • workspace_dim (int): dimension of the GW.
  • loss_coefs (LossCoefs): loss coefficients
  • optim_lr (float): learning rate
  • optim_weight_decay (float): weight decay
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • learn_logit_scale (bool): whether to learn the logit scale of the contrastive loss when using the default contrastive loss.
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss function used for alignment. learn_logit_scale will not affect custom contrastive losses.
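For illustration, a minimal construction sketch. The domain modules v_mod and t_mod, their latent dimensions, and the particular LossCoefs keys used here are assumptions for the example, not values prescribed by shimmer:

from torch import nn
from shimmer.modules.global_workspace import GlobalWorkspace2Domains

workspace_dim = 16
gw = GlobalWorkspace2Domains(
    domain_mods={"v": v_mod, "t": t_mod},  # pretrained DomainModule instances (assumed)
    gw_encoders={"v": nn.Linear(32, workspace_dim), "t": nn.Linear(64, workspace_dim)},
    gw_decoders={"v": nn.Linear(workspace_dim, 32), "t": nn.Linear(workspace_dim, 64)},
    workspace_dim=workspace_dim,
    loss_coefs={"demi_cycles": 1.0, "cycles": 1.0, "translations": 1.0, "contrastives": 0.01},
    optim_lr=1e-3,
)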
def forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions:
538    def forward(  # type: ignore
539        self,
540        latent_domains: LatentsDomainGroupsT,
541    ) -> GWPredictions:
542        """
543        Computes demi-cycles, cycles, and translations.
544
545        Args:
546            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
547
548        Returns:
549            `GWPredictions`: the predictions on the batch.
550        """
551        return GWPredictions(
552            demi_cycles=batch_demi_cycles(
553                self.gw_mod, self.selection_mod, latent_domains
554            ),
555            cycles=batch_cycles(
556                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
557            ),
558            translations=batch_translations(
559                self.gw_mod, self.selection_mod, latent_domains
560            ),
561            **super().forward(latent_domains),
562        )

Computes demi-cycles, cycles, and translations.

Arguments:
  • latent_domains (LatentsDomainGroupsT): Groups of domains for the computation.
Returns:

GWPredictions: the predictions on the batch.

565class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
566    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
567
568    This is used to simplify a Global Workspace instantiation and only overrides the
569    `__init__` method.
570    """
571
572    def __init__(
573        self,
574        domain_mods: Mapping[str, DomainModule],
575        gw_encoders: Mapping[str, Module],
576        gw_decoders: Mapping[str, Module],
577        workspace_dim: int,
578        loss_coefs: BroadcastLossCoefs,
579        selection_temperature: float = 0.2,
580        optim_lr: float = 1e-3,
581        optim_weight_decay: float = 0.0,
582        scheduler_args: SchedulerArgs | None = None,
583        learn_logit_scale: bool = False,
584        contrastive_loss: ContrastiveLossType | None = None,
585    ) -> None:
586        """
587        Initializes a Global Workspace
588
589        Args:
590            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
591                connected to the GW. Keys are domain names, values are the
592                `DomainModule`.
593            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
594                name to a `torch.nn.Module` whose role is to encode a
595                unimodal latent representation into a GW representation (pre fusion).
596            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
597                name to a `torch.nn.Module` whose role is to decode a
598                GW representation into a unimodal latent representation.
599            workspace_dim (`int`): dimension of the GW.
600            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
601            selection_temperature (`float`): temperature value for the RandomSelection
602                module.
603            optim_lr (`float`): learning rate
604            optim_weight_decay (`float`): weight decay
605            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
606            learn_logit_scale (`bool`): whether to learn the logit scale of the
607                contrastive loss when using the default contrastive loss.
608            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
609                function used for alignment. `learn_logit_scale` will not affect custom
610                contrastive losses.
611        """
612        domain_mods = freeze_domain_modules(domain_mods)
613        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
614
615        if contrastive_loss is None:
616            contrastive_loss = ContrastiveLoss(
617                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
618            )
619
620        selection_mod = RandomSelection(selection_temperature)
621        loss_mod = GWLosses(
622            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
623        )
624
625        super().__init__(
626            gw_mod,
627            selection_mod,
628            loss_mod,
629            optim_lr,
630            optim_weight_decay,
631            scheduler_args,
632        )
633
634    def forward(  # type: ignore
635        self,
636        latent_domains: LatentsDomainGroupsT,
637    ) -> GWPredictions:
638        """
639        Computes demi-cycles, cycles, and translations.
640
641        Args:
642            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
643
644        Returns:
645            `GWPredictions`: the predictions on the batch.
646        """
647        return GWPredictions(
648            demi_cycles=batch_demi_cycles(
649                self.gw_mod, self.selection_mod, latent_domains
650            ),
651            cycles=batch_cycles(
652                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
653            ),
654            translations=batch_translations(
655                self.gw_mod, self.selection_mod, latent_domains
656            ),
657            # TODO: add other combinations
658            **super().forward(latent_domains),
659        )

The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

This is used to simplify a Global Workspace instantiation and only overrides the __init__ method.

GlobalWorkspace( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.BroadcastLossCoefs, selection_temperature: float = 0.2, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None)
572    def __init__(
573        self,
574        domain_mods: Mapping[str, DomainModule],
575        gw_encoders: Mapping[str, Module],
576        gw_decoders: Mapping[str, Module],
577        workspace_dim: int,
578        loss_coefs: BroadcastLossCoefs,
579        selection_temperature: float = 0.2,
580        optim_lr: float = 1e-3,
581        optim_weight_decay: float = 0.0,
582        scheduler_args: SchedulerArgs | None = None,
583        learn_logit_scale: bool = False,
584        contrastive_loss: ContrastiveLossType | None = None,
585    ) -> None:
586        """
587        Initializes a Global Workspace
588
589        Args:
590            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
591                connected to the GW. Keys are domain names, values are the
592                `DomainModule`.
593            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
594                name to a `torch.nn.Module` whose role is to encode a
595                unimodal latent representation into a GW representation (pre fusion).
596            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
597                name to a `torch.nn.Module` whose role is to decode a
598                GW representation into a unimodal latent representation.
599            workspace_dim (`int`): dimension of the GW.
600            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
601            selection_temperature (`float`): temperature value for the RandomSelection
602                module.
603            optim_lr (`float`): learning rate
604            optim_weight_decay (`float`): weight decay
605            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
606            learn_logit_scale (`bool`): whether to learn the logit scale of the
607                contrastive loss when using the default contrastive loss.
608            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
609                function used for alignment. `learn_logit_scale` will not affect custom
610                contrastive losses.
611        """
612        domain_mods = freeze_domain_modules(domain_mods)
613        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
614
615        if contrastive_loss is None:
616            contrastive_loss = ContrastiveLoss(
617                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
618            )
619
620        selection_mod = RandomSelection(selection_temperature)
621        loss_mod = GWLosses(
622            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
623        )
624
625        super().__init__(
626            gw_mod,
627            selection_mod,
628            loss_mod,
629            optim_lr,
630            optim_weight_decay,
631            scheduler_args,
632        )

Initializes a Global Workspace

Arguments:
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains connected to the GW. Keys are domain names, values are the DomainModule.
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to encode a unimodal latent representation into a GW representation (pre fusion).
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to decode a GW representation into a unimodal latent representation.
  • workspace_dim (int): dimension of the GW.
  • loss_coefs (BroadcastLossCoefs): loss coefs for the losses.
  • selection_temperature (float): temperature value for the RandomSelection module.
  • optim_lr (float): learning rate
  • optim_weight_decay (float): weight decay
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • learn_logit_scale (bool): whether to learn the logit scale of the contrastive loss when using the default contrastive loss.
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss function used for alignment. learn_logit_scale will not affect custom contrastive losses.
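Once constructed and trained, the end-to-end inference path (raw data, to latents, to GW predictions, back to raw data) looks roughly like the sketch below; gw, raw_v and raw_t are hypothetical, and "v"/"t" stand for the domain names used at construction:

batch = {frozenset({"v", "t"}): {"v": raw_v, "t": raw_t}}  # RawDomainGroupsT

latent_domains = gw.encode_domains(batch)   # unimodal latent representations
predictions = gw(latent_domains)            # GWPredictions
t_latent_from_v = predictions["translations"][("v", "t")]
t_raw_from_v = gw.decode_domain(t_latent_from_v, "t")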
def forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions:
634    def forward(  # type: ignore
635        self,
636        latent_domains: LatentsDomainGroupsT,
637    ) -> GWPredictions:
638        """
639        Computes demi-cycles, cycles, and translations.
640
641        Args:
642            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
643
644        Returns:
645            `GWPredictions`: the predictions on the batch.
646        """
647        return GWPredictions(
648            demi_cycles=batch_demi_cycles(
649                self.gw_mod, self.selection_mod, latent_domains
650            ),
651            cycles=batch_cycles(
652                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
653            ),
654            translations=batch_translations(
655                self.gw_mod, self.selection_mod, latent_domains
656            ),
657            # TODO: add other combinations
658            **super().forward(latent_domains),
659        )

Computes demi-cycles, cycles, and translations.

Arguments:
  • latent_domains (LatentsDomainGroupsT): Groups of domains for the computation.
Returns:

GWPredictions: the predictions on the batch.

662class GlobalWorkspaceBayesian(
663    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
664):
665    """
666    A simple GlobalWorkspaceBase flavor for at most two domains, with Bayesian
667    uncertainty prediction.
668
669    This is used to simplify a Global Workspace instantiation and only overrides the
670    `__init__` method.
671    """
672
673    def __init__(
674        self,
675        domain_mods: Mapping[str, DomainModule],
676        gw_encoders: Mapping[str, Module],
677        gw_decoders: Mapping[str, Module],
678        workspace_dim: int,
679        loss_coefs: BroadcastLossCoefs,
680        sensitivity_selection: float = 1,
681        sensitivity_precision: float = 1,
682        optim_lr: float = 1e-3,
683        optim_weight_decay: float = 0.0,
684        scheduler_args: SchedulerArgs | None = None,
685        learn_logit_scale: bool = False,
686        use_normalized_constrastive: bool = True,
687        contrastive_loss: ContrastiveLossType | None = None,
688        precision_softmax_temp: float = 0.01,
689    ) -> None:
690        """
691        Initializes a Global Workspace
692
693        Args:
694            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
695                connected to the GW. Keys are domain names, values are the
696                `DomainModule`.
697            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
698                name to a `torch.nn.Module` whose role is to encode a
699                unimodal latent representation into a GW representation (pre fusion).
700            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
701                name to a `torch.nn.Module` whose role is to decode a
702                GW representation into a unimodal latent representation.
703            workspace_dim (`int`): dimension of the GW.
704            loss_coefs (`BroadcastLossCoefs`): loss coefficients
705            sensitivity_selection (`float`): sensitivity coef $c'_1$
706            sensitivity_precision (`float`): sensitivity coef $c'_2$
707            optim_lr (`float`): learning rate
708            optim_weight_decay (`float`): weight decay
709            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
710            learn_logit_scale (`bool`): whether to learn the logit scale of the
711                contrastive loss when using the default contrastive loss.
712            use_normalized_constrastive (`bool`): whether to use the contrastive loss
713                normalized by the precision coefficients
714            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
715                function used for alignment. `learn_logit_scale` will not affect custom
716                contrastive losses.
717            precision_softmax_temp (`float`): temperature to use in softmax of
718                precision
719        """
720        domain_mods = freeze_domain_modules(domain_mods)
721
722        gw_mod = GWModuleBayesian(
723            domain_mods,
724            workspace_dim,
725            gw_encoders,
726            gw_decoders,
727            sensitivity_selection,
728            sensitivity_precision,
729            precision_softmax_temp,
730        )
731
732        selection_mod = FixedSharedSelection()
733
734        contrastive_loss = ContrastiveLoss(
735            torch.tensor([1]).log(), "mean", learn_logit_scale
736        )
737
738        loss_mod = GWLossesBayesian(
739            gw_mod,
740            selection_mod,
741            domain_mods,
742            loss_coefs,
743            contrastive_loss,
744            use_normalized_constrastive,
745        )
746
747        super().__init__(
748            gw_mod,
749            selection_mod,
750            loss_mod,
751            optim_lr,
752            optim_weight_decay,
753            scheduler_args,
754        )
755
756    def forward(  # type: ignore
757        self,
758        latent_domains: LatentsDomainGroupsT,
759    ) -> GWPredictions:
760        """
761        Computes demi-cycles, cycles, and translations.
762
763        Args:
764            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
765
766        Returns:
767            `GWPredictions`: the predictions on the batch.
768        """
769        return GWPredictions(
770            demi_cycles=batch_demi_cycles(
771                self.gw_mod, self.selection_mod, latent_domains
772            ),
773            cycles=batch_cycles(
774                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
775            ),
776            translations=batch_translations(
777                self.gw_mod, self.selection_mod, latent_domains
778            ),
779            **super().forward(latent_domains),
780        )

A simple GlobalWorkspaceBase flavor for at most two domains, with Bayesian uncertainty prediction.

This is used to simplify a Global Workspace instantiation and only overrides the __init__ method.

GlobalWorkspaceBayesian( domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.BroadcastLossCoefs, sensitivity_selection: float = 1, sensitivity_precision: float = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, use_normalized_constrastive: bool = True, contrastive_loss: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput] | None = None, precision_softmax_temp: float = 0.01)
673    def __init__(
674        self,
675        domain_mods: Mapping[str, DomainModule],
676        gw_encoders: Mapping[str, Module],
677        gw_decoders: Mapping[str, Module],
678        workspace_dim: int,
679        loss_coefs: BroadcastLossCoefs,
680        sensitivity_selection: float = 1,
681        sensitivity_precision: float = 1,
682        optim_lr: float = 1e-3,
683        optim_weight_decay: float = 0.0,
684        scheduler_args: SchedulerArgs | None = None,
685        learn_logit_scale: bool = False,
686        use_normalized_constrastive: bool = True,
687        contrastive_loss: ContrastiveLossType | None = None,
688        precision_softmax_temp: float = 0.01,
689    ) -> None:
690        """
691        Initializes a Global Workspace
692
693        Args:
694            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
695                connected to the GW. Keys are domain names, values are the
696                `DomainModule`.
697            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
698                name to a `torch.nn.Module` whose role is to encode a
699                unimodal latent representation into a GW representation (pre fusion).
700            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
701                name to a `torch.nn.Module` whose role is to decode a
702                GW representation into a unimodal latent representation.
703            workspace_dim (`int`): dimension of the GW.
704            loss_coefs (`BroadcastLossCoefs`): loss coefficients
705            sensitivity_selection (`float`): sensitivity coef $c'_1$
706            sensitivity_precision (`float`): sensitivity coef $c'_2$
707            optim_lr (`float`): learning rate
708            optim_weight_decay (`float`): weight decay
709            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
710            learn_logit_scale (`bool`): whether to learn the logit scale of the
711                contrastive loss when using the default contrastive loss.
712            use_normalized_constrastive (`bool`): whether to use the contrastive loss
713                normalized by the precision coefficients
714            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
715                function used for alignment. `learn_logit_scale` will not affect custom
716                contrastive losses.
717            precision_softmax_temp (`float`): temperature to use in softmax of
718                precision
719        """
720        domain_mods = freeze_domain_modules(domain_mods)
721
722        gw_mod = GWModuleBayesian(
723            domain_mods,
724            workspace_dim,
725            gw_encoders,
726            gw_decoders,
727            sensitivity_selection,
728            sensitivity_precision,
729            precision_softmax_temp,
730        )
731
732        selection_mod = FixedSharedSelection()
733
734        contrastive_loss = ContrastiveLoss(
735            torch.tensor([1]).log(), "mean", learn_logit_scale
736        )
737
738        loss_mod = GWLossesBayesian(
739            gw_mod,
740            selection_mod,
741            domain_mods,
742            loss_coefs,
743            contrastive_loss,
744            use_normalized_constrastive,
745        )
746
747        super().__init__(
748            gw_mod,
749            selection_mod,
750            loss_mod,
751            optim_lr,
752            optim_weight_decay,
753            scheduler_args,
754        )

Initializes a Global Workspace

Arguments:
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains connected to the GW. Keys are domain names, values are the DomainModule.
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to encode a unimodal latent representation into a GW representation (pre fusion).
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to decode a GW representation into a unimodal latent representation.
  • workspace_dim (int): dimension of the GW.
  • loss_coefs (BroadcastLossCoefs): loss coefficients
  • sensitivity_selection (float): sensitivity coef $c'_1$
  • sensitivity_precision (float): sensitivity coef $c'_2$
  • optim_lr (float): learning rate
  • optim_weight_decay (float): weight decay
  • scheduler_args (SchedulerArgs | None): optimization scheduler's arguments
  • learn_logit_scale (bool): whether to learn the logit scale of the contrastive loss when using the default contrastive loss.
  • use_normalized_constrastive (bool): whether to use the contrastive loss normalized by the precision coefficients
  • contrastive_loss (ContrastiveLossType | None): a contrastive loss function used for alignment. learn_logit_scale will not affect custom contrastive losses.
  • precision_softmax_temp (float): temperature to use in softmax of precision
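A construction sketch mirroring the previous examples; the domain modules v_mod and t_mod, the dimensions, and coefs (a BroadcastLossCoefs mapping) are assumed to be defined elsewhere, and the sensitivity/temperature values below are only illustrative:

from torch import nn
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian

gw = GlobalWorkspaceBayesian(
    domain_mods={"v": v_mod, "t": t_mod},
    gw_encoders={"v": nn.Linear(32, 16), "t": nn.Linear(64, 16)},
    gw_decoders={"v": nn.Linear(16, 32), "t": nn.Linear(16, 64)},
    workspace_dim=16,
    loss_coefs=coefs,
    sensitivity_selection=1.0,
    sensitivity_precision=1.0,
    precision_softmax_temp=0.01,
)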
def forward( self, latent_domains: collections.abc.Mapping[frozenset[str], collections.abc.Mapping[str, torch.Tensor]]) -> GWPredictions:
756    def forward(  # type: ignore
757        self,
758        latent_domains: LatentsDomainGroupsT,
759    ) -> GWPredictions:
760        """
761        Computes demi-cycles, cycles, and translations.
762
763        Args:
764            latent_domains (`LatentsDomainGroupsT`): Groups of domains for the computation.
765
766        Returns:
767            `GWPredictions`: the predictions on the batch.
768        """
769        return GWPredictions(
770            demi_cycles=batch_demi_cycles(
771                self.gw_mod, self.selection_mod, latent_domains
772            ),
773            cycles=batch_cycles(
774                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
775            ),
776            translations=batch_translations(
777                self.gw_mod, self.selection_mod, latent_domains
778            ),
779            **super().forward(latent_domains),
780        )

Computes demi-cycles, cycles, and translations.

Arguments:
  • latent_domains (LatentsDomainGroupsT): Groups of domains for the computation.
Returns:

GWPredictions: the predictions on the batch.

def pretrained_global_workspace( checkpoint_path: str | pathlib.Path, domain_mods: collections.abc.Mapping[str, shimmer.modules.domain.DomainModule], gw_encoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], gw_decoders: collections.abc.Mapping[str, torch.nn.modules.module.Module], workspace_dim: int, loss_coefs: shimmer.modules.losses.LossCoefs, contrastive_fn: collections.abc.Callable[[torch.Tensor, torch.Tensor], shimmer.modules.domain.LossOutput], **kwargs) -> GlobalWorkspace2Domains:
783def pretrained_global_workspace(
784    checkpoint_path: str | Path,
785    domain_mods: Mapping[str, DomainModule],
786    gw_encoders: Mapping[str, Module],
787    gw_decoders: Mapping[str, Module],
788    workspace_dim: int,
789    loss_coefs: LossCoefs,
790    contrastive_fn: ContrastiveLossType,
791    **kwargs,
792) -> GlobalWorkspace2Domains:
793    """
794    Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.
795
796    Args:
797        checkpoint_path (`str | Path`): path to checkpoint
798        domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
799            connected to the GW. Keys are domain names, values are the
800            `DomainModule`.
801        gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
802            name to a `torch.nn.Module` whose role is to encode a
803            unimodal latent representation into a GW representation (pre fusion).
804        gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain
805            name to a `torch.nn.Module` whose role is to decode a
806            GW representation into a unimodal latent representation.
807        workspace_dim (`int`): dimension of the GW.
808        loss_coefs (`LossCoefs`): loss coefficients
809        contrastive_fn (`ContrastiveLossType`): the contrastive loss
810            function used for alignment (passed to the loss module of the
811            loaded Global Workspace).
812        **kwargs: additional arguments to pass to
813            `GlobalWorkspace2Domains.load_from_checkpoint`.
814
815    Returns:
816        `GlobalWorkspace2Domains`: the pretrained `GlobalWorkspace2Domains`.
817
818    Raises:
819        `TypeError`: if the loaded type is not `GlobalWorkspace2Domains`.
820    """
821    domain_mods = freeze_domain_modules(domain_mods)
822    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
823    selection_mod = SingleDomainSelection()
824    loss_mod = GWLosses2Domains(
825        gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
826    )
827
828    gw = GlobalWorkspace2Domains.load_from_checkpoint(
829        checkpoint_path,
830        gw_mod=gw_mod,
831        selection_mod=selection_mod,
832        loss_coefs=loss_coefs,
833        loss_mod=loss_mod,
834        **kwargs,
835    )
836    if not isinstance(gw, GlobalWorkspace2Domains):
837        raise TypeError("model should be of type GlobalWorkspace")
838    return gw

Load a GlobalWorkspace flavor of GlobalWorkspaceBase from a checkpoint.

Arguments:
  • checkpoint_path (str | Path): path to checkpoint
  • domain_mods (Mapping[str, DomainModule]): mapping of the domains connected to the GW. Keys are domain names, values are the DomainModule.
  • gw_encoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to encode a unimodal latent representation into a GW representation (pre fusion).
  • gw_decoders (Mapping[str, torch.nn.Module]): mapping from each domain name to a torch.nn.Module whose role is to decode a GW representation into a unimodal latent representation.
  • workspace_dim (int): dimension of the GW.
  • loss_coefs (LossCoefs): loss coefficients
  • contrastive_fn (ContrastiveLossType): the contrastive loss function used for alignment (passed to the loss module of the loaded Global Workspace).
  • **kwargs: additional arguments to pass to GlobalWorkspace2Domains.load_from_checkpoint.
Returns:

GlobalWorkspace2Domains: the pretrained GlobalWorkspace2Domains.

Raises: