shimmer.modules.global_workspace
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any, Generic, TypedDict, TypeVar, cast

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch.nn import Module, ModuleDict
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (
    GWModule,
    GWModuleBase,
    GWModuleBayesian,
)
from shimmer.modules.losses import (
    BroadcastLossCoefs,
    GWLosses,
    GWLosses2Domains,
    GWLossesBase,
    GWLossesBayesian,
    LossCoefs,
)
from shimmer.modules.selection import (
    FixedSharedSelection,
    RandomSelection,
    SelectionBase,
    SingleDomainSelection,
)
from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations
from shimmer.types import (
    LatentsDomainGroupsDT,
    LatentsDomainGroupsT,
    ModelModeT,
    RawDomainGroupsDT,
    RawDomainGroupsT,
    RawDomainGroupT,
)
from shimmer.utils import groups_batch_size


class SchedulerArgs(TypedDict, total=False):
    """TypedDict of arguments passed to the OneCycle scheduler"""

    max_lr: float
    """Maximum learning rate"""

    total_steps: int
    """Total number of steps"""


class GWPredictionsBase(TypedDict):
    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""

    states: dict[str, torch.Tensor]
    """
    GW state representation from domain groups with only one domain.
    The keys represent the domain names.
    """


_T_gw_mod = TypeVar("_T_gw_mod", bound=GWModuleBase)
_T_selection_mod = TypeVar("_T_selection_mod", bound=SelectionBase)
_T_loss_mod = TypeVar("_T_loss_mod", bound=GWLossesBase)


class GlobalWorkspaceBase(
    Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule
):
    """
    Global Workspace Lightning Module.

    This is the base class to build the Global Workspace.
    """

    def __init__(
        self,
        gw_mod: _T_gw_mod,
        selection_mod: _T_selection_mod,
        loss_mod: _T_loss_mod,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
    ) -> None:
        """
        Initializes a GW.

        Args:
            gw_mod (`GWModuleBase`): the GWModule
            selection_mod (`SelectionBase`): selection module
            loss_mod (`GWLossesBase`): module to compute the GW losses.
            optim_lr (`float`): learning rate
            optim_weight_decay (`float`): weight decay
            scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define
                scheduler parameters.
        """
        super().__init__()
        self.save_hyperparameters(
            ignore=[
                "gw_mod",
                "selection_mod",
                "domain_mods",
                "loss_mod",
                "domain_descriptions",
                "contrastive_loss",
                "cont_loss_bayesian",
                "gw_encoders",
                "gw_decoders",
            ]
        )

        self.gw_mod = gw_mod
        """A `GWModuleBase` implementation."""

        self.selection_mod = selection_mod
        """A `SelectionBase` implementation."""

        self.loss_mod = loss_mod
        """The module that computes losses of the GW"""

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay
        self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
        if scheduler_args is not None:
            self.scheduler_args.update(scheduler_args)

    @property
    def domain_mods(self) -> Mapping[str, DomainModule]:
        return self.gw_mod.domain_mods

    @property
    def workspace_dim(self) -> int:
        """Dimension of the GW."""
        return self.gw_mod.workspace_dim

    def encode_and_fuse(
        self, x: LatentsDomainGroupsT, selection_module: SelectionBase
    ) -> dict[frozenset[str], torch.Tensor]:
        """
        Encode a group of latent representations into the GW representation.

        Args:
            x (`LatentsDomainGroupsT`): the input domain representations.
            selection_module (`SelectionBase`): the selection module to use.

        Returns:
            `dict[frozenset[str], torch.Tensor]`: the GW representations.
        """
        return {
            domains: self.gw_mod.encode_and_fuse(latents, selection_module)
            for domains, latents in x.items()
        }

    def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT:
        """
        Encode a group of latent representations into the pre-fusion GW representation.

        Args:
            x (`LatentsDomainGroupsT`): the input domain representations.

        Returns:
            `LatentsDomainGroupsDT`: the GW representations.
        """
        return {domains: self.gw_mod.encode(latents) for domains, latents in x.items()}

    def fuse(
        self,
        x: LatentsDomainGroupsT,
        selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
    ) -> dict[frozenset[str], torch.Tensor]:
        """
        Fuses a group of latent representations into the GW representation.

        Args:
            x (`LatentsDomainGroupsT`): the pre-fusion latent representations
            selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`):
                selection scores for each group

        Returns:
            `dict[frozenset[str], torch.Tensor]`: GW representation of each group
        """
        return {
            domains: self.gw_mod.fuse(latents, selection_scores[domains])
            for domains, latents in x.items()
        }

    def decode(
        self,
        z: Mapping[frozenset[str], torch.Tensor],
        domains: Iterable[str] | None = None,
    ) -> LatentsDomainGroupsDT:
        """
        Decode the group GW representation into given `domains`.

        Args:
            z (`torch.Tensor`): the GW representation.
            domains (`Iterable[str]`): iterable of domains to decode.

        Returns:
            `dict[str, torch.Tensor]`: the decoded unimodal representations.
        """
        return {
            domain_names: self.gw_mod.decode(gw_rep, domains)
            for domain_names, gw_rep in z.items()
        }

    def forward(  # type: ignore
        self,
        latent_domains: LatentsDomainGroupsT,
    ) -> GWPredictionsBase:
        """
        Computes demi-cycles, cycles, and translations.

        Args:
            latent_domains (`LatentsT`): Groups of domains for the computation.

        Returns:
            `GWPredictionsBase`: the predictions on the batch.
        """

        return GWPredictionsBase(states=self.batch_gw_states(latent_domains))

    def batch_gw_states(
        self, latent_domains: LatentsDomainGroupsT
    ) -> dict[str, torch.Tensor]:
        """
        Computes GW states of a batch of groups of domains.

        Args:
            latent_domains (`LatentsT`): the batch of groups of domains

        Returns:
            `dict[str, torch.Tensor]`: states for each domain.
        """
        predictions: dict[str, torch.Tensor] = {}
        for domains, latents in latent_domains.items():
            if len(domains) > 1:
                continue
            domain_name = list(domains)[0]
            z = self.gw_mod.encode_and_fuse(
                latents, selection_module=self.selection_mod
            )
            predictions[domain_name] = z
        return predictions

    def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
        """
        Encodes a domain from the domain data into the unimodal representation.

        This is a convenient proxy for the `DomainModule.encode` method and is
        equivalent to:
        ```python
        self.domain_mods[name].encode(domain)
        ```

        Args:
            domain (`Any`): the domain data
            name (`str`): domain name to encode

        Returns:
            `torch.Tensor`: the domain's unimodal representation.
        """
        return self.domain_mods[name].encode(domain)

    def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT:
        """
        Encode all domains in the batch.

        Args:
            batch (`RawDomainGroupsT`): the batch of
                domain groups with raw unimodal data to encode into groups of latent
                representations.

        Returns:
            `LatentsDomainGroupsDT`: the domains' unimodal representations.
        """
        return {
            domains: {
                name: self.domain_mods[name].encode(domain)
                for name, domain in data.items()
            }
            for domains, data in batch.items()
        }

    def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
        """
        Decodes a domain from the unimodal representation into the domain data.

        This is a convenient proxy for the `DomainModule.decode` method and is
        equivalent to:
        ```python
        self.domain_mods[name].decode(domain)
        ```

        Args:
            domain (`torch.Tensor`): the domain data
            name (`str`): domain name to decode

        Returns:
            `Any`: the domain's raw data.
        """
        return self.domain_mods[name].decode(domain)

    def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT:
        """
        Decodes all domains in the batch.

        Args:
            latents_domain (`LatentsDomainGroupsT`): the batch of
                domain groups with unimodal latent representations to decode into
                groups of raw data.

        Returns:
            `RawDomainGroupsDT`: the domains' raw data.
        """
        return {
            domains: {
                name: self.domain_mods[name].decode(domain)
                for name, domain in latents.items()
            }
            for domains, latents in latents_domain.items()
        }

    def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor:
        """
        The generic step used in `training_step`, `validation_step` and
        `test_step`.

        Args:
            batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
            mode (`ModelModeT`): the mode ("train", "val", "test", ...).

        Returns:
            `torch.Tensor`: the loss to train on.
        """
        domain_latents = self.encode_domains(batch)
        batch_size = groups_batch_size(domain_latents)

        loss_output = self.loss_mod.step(domain_latents, mode)

        for name, metric in loss_output.all.items():
            self.log(
                f"{mode}/{name}",
                metric,
                batch_size=batch_size,
                add_dataloader_idx=False,
            )

        return loss_output.loss

    def validation_step(  # type: ignore
        self, data: RawDomainGroupT, batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        """Validation step used by lightning"""

        batch = {frozenset(data.keys()): data}
        for domain in data:
            batch[frozenset([domain])] = {domain: data[domain]}
        if dataloader_idx == 0:
            return self.generic_step(batch, mode="val")
        return self.generic_step(batch, mode="val/ood")

    def test_step(  # type: ignore
        self, data: Mapping[str, Any], batch_idx: int, dataloader_idx: int = 0
    ) -> torch.Tensor:
        """Test step used by lightning"""

        batch = {frozenset(data.keys()): data}
        for domain in data:
            batch[frozenset([domain])] = {domain: data[domain]}
        if dataloader_idx == 0:
            return self.generic_step(batch, mode="test")
        return self.generic_step(batch, mode="test/ood")

    def training_step(  # type: ignore
        self, batch: Mapping[frozenset[str], Mapping[str, Any]], batch_idx: int
    ) -> torch.Tensor:
        """Training step used by lightning"""

        return self.generic_step(batch, mode="train")

    def predict_step(  # type: ignore
        self, data: Mapping[str, Any], batch_idx: int
    ) -> GWPredictionsBase:
        """Predict step used by lightning"""

        batch = {frozenset(data.keys()): data}
        for domain in data:
            batch[frozenset([domain])] = {domain: data[domain]}

        domain_latents = self.encode_domains(batch)
        return self.forward(domain_latents)

    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
        """
        Configure the model's optimizers.

        Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate
        scheduler.
        """

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }


def freeze_domain_modules(
    domain_mods: Mapping[str, DomainModule],
) -> dict[str, DomainModule]:
    """
    Freezes the weights and sets the domain modules to eval mode.

    .. note::
        The output is cast as `dict[str, DomainModule]` type for better
        auto-completion, but is actually a torch `ModuleDict`.

    Args:
        domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze

    Returns:
        `ModuleDict`: frozen modules.
    """

    for mod in domain_mods.values():
        mod.freeze()
    # Cast for better auto-completion at the expense of ModuleDict
    return cast(dict[str, DomainModule], ModuleDict(domain_mods))


class GWPredictions(GWPredictionsBase):
    """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`"""

    demi_cycles: dict[str, torch.Tensor]
    """
    Demi-cycle predictions of the model for each domain. Only computed on domain
    groups with only one domain.
    """

    cycles: dict[tuple[str, str], torch.Tensor]
    """
    Cycle predictions of the model from one domain through another one.
    Only computed on domain groups with more than one domain.
    The keys are tuples with the start domain and the intermediary domain.
    """

    translations: dict[tuple[str, str], torch.Tensor]
    """
    Translation predictions of the model from one domain through another one.

    Only computed on domain groups with more than one domain.
    The keys are tuples with the start domain and the target domain.
    """


class GlobalWorkspace2Domains(
    GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains]
):
    """
    A simple 2-domains max flavor of GlobalWorkspaceBase.

    This is used to simplify a Global Workspace instantiation and only overrides the
    `__init__` method.
    """

    def __init__(
        self,
        domain_mods: Mapping[str, DomainModule],
        gw_encoders: Mapping[str, Module],
        gw_decoders: Mapping[str, Module],
        workspace_dim: int,
        loss_coefs: LossCoefs,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
        learn_logit_scale: bool = False,
        contrastive_loss: ContrastiveLossType | None = None,
    ) -> None:
        """
        Initializes a Global Workspace.

        Args:
            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
                connected to the GW. Keys are domain names, values are the
                `DomainModule`.
            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to encode a
                unimodal latent representation into a GW representation (pre fusion).
            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to decode a
                GW representation into a unimodal latent representation.
            workspace_dim (`int`): dimension of the GW.
            loss_coefs (`LossCoefs`): loss coefficients
            optim_lr (`float`): learning rate
            optim_weight_decay (`float`): weight decay
            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
            learn_logit_scale (`bool`): whether to learn the logit scale when using
                the default contrastive loss.
            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                function used for alignment. `learn_logit_scale` will not affect custom
                contrastive losses.
        """
        domain_mods = freeze_domain_modules(domain_mods)

        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
        if contrastive_loss is None:
            contrastive_loss = ContrastiveLoss(
                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
            )
        selection_mod = SingleDomainSelection()
        loss_mod = GWLosses2Domains(
            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
        )

        super().__init__(
            gw_mod,
            selection_mod,
            loss_mod,
            optim_lr,
            optim_weight_decay,
            scheduler_args,
        )

    def forward(  # type: ignore
        self,
        latent_domains: LatentsDomainGroupsT,
    ) -> GWPredictions:
        """
        Computes demi-cycles, cycles, and translations.

        Args:
            latent_domains (`LatentsT`): Groups of domains for the computation.

        Returns:
            `GWPredictions`: the predictions on the batch.
        """
        return GWPredictions(
            demi_cycles=batch_demi_cycles(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            cycles=batch_cycles(
                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
            ),
            translations=batch_translations(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            **super().forward(latent_domains),
        )


class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]):
    """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.

    This is used to simplify a Global Workspace instantiation and only overrides the
    `__init__` method.
    """

    def __init__(
        self,
        domain_mods: Mapping[str, DomainModule],
        gw_encoders: Mapping[str, Module],
        gw_decoders: Mapping[str, Module],
        workspace_dim: int,
        loss_coefs: BroadcastLossCoefs,
        selection_temperature: float = 0.2,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
        learn_logit_scale: bool = False,
        contrastive_loss: ContrastiveLossType | None = None,
    ) -> None:
        """
        Initializes a Global Workspace.

        Args:
            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
                connected to the GW. Keys are domain names, values are the
                `DomainModule`.
            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to encode a
                unimodal latent representation into a GW representation (pre fusion).
            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to decode a
                GW representation into a unimodal latent representation.
            workspace_dim (`int`): dimension of the GW.
            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
            selection_temperature (`float`): temperature value for the RandomSelection
                module.
            optim_lr (`float`): learning rate
            optim_weight_decay (`float`): weight decay
            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
            learn_logit_scale (`bool`): whether to learn the logit scale when using
                the default contrastive loss.
            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                function used for alignment. `learn_logit_scale` will not affect custom
                contrastive losses.
        """
        domain_mods = freeze_domain_modules(domain_mods)
        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)

        if contrastive_loss is None:
            contrastive_loss = ContrastiveLoss(
                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
            )

        selection_mod = RandomSelection(selection_temperature)
        loss_mod = GWLosses(
            gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
        )

        super().__init__(
            gw_mod,
            selection_mod,
            loss_mod,
            optim_lr,
            optim_weight_decay,
            scheduler_args,
        )

    def forward(  # type: ignore
        self,
        latent_domains: LatentsDomainGroupsT,
    ) -> GWPredictions:
        """
        Computes demi-cycles, cycles, and translations.

        Args:
            latent_domains (`LatentsT`): Groups of domains for the computation.

        Returns:
            `GWPredictions`: the predictions on the batch.
        """
        return GWPredictions(
            demi_cycles=batch_demi_cycles(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            cycles=batch_cycles(
                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
            ),
            translations=batch_translations(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            # TODO: add other combinations
            **super().forward(latent_domains),
        )


class GlobalWorkspaceBayesian(
    GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian]
):
    """
    A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty
    prediction.

    This is used to simplify a Global Workspace instantiation and only overrides the
    `__init__` method.
    """

    def __init__(
        self,
        domain_mods: Mapping[str, DomainModule],
        gw_encoders: Mapping[str, Module],
        gw_decoders: Mapping[str, Module],
        workspace_dim: int,
        loss_coefs: BroadcastLossCoefs,
        sensitivity_selection: float = 1,
        sensitivity_precision: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
        learn_logit_scale: bool = False,
        use_normalized_constrastive: bool = True,
        contrastive_loss: ContrastiveLossType | None = None,
        precision_softmax_temp: float = 0.01,
    ) -> None:
        """
        Initializes a Global Workspace.

        Args:
            domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
                connected to the GW. Keys are domain names, values are the
                `DomainModule`.
            gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to encode a
                unimodal latent representation into a GW representation (pre fusion).
            gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
                name to a `torch.nn.Module` whose role is to decode a
                GW representation into a unimodal latent representation.
            workspace_dim (`int`): dimension of the GW.
            loss_coefs (`BroadcastLossCoefs`): loss coefficients
            sensitivity_selection (`float`): sensitivity coef $c'_1$
            sensitivity_precision (`float`): sensitivity coef $c'_2$
            optim_lr (`float`): learning rate
            optim_weight_decay (`float`): weight decay
            scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
            learn_logit_scale (`bool`): whether to learn the logit scale when using
                the default contrastive loss.
            use_normalized_constrastive (`bool`): whether to normalize the contrastive
                loss by the precision coefficients
            contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                function used for alignment. `learn_logit_scale` will not affect custom
                contrastive losses.
            precision_softmax_temp (`float`): temperature to use in softmax of
                precision
        """
        domain_mods = freeze_domain_modules(domain_mods)

        gw_mod = GWModuleBayesian(
            domain_mods,
            workspace_dim,
            gw_encoders,
            gw_decoders,
            sensitivity_selection,
            sensitivity_precision,
            precision_softmax_temp,
        )

        selection_mod = FixedSharedSelection()

        contrastive_loss = ContrastiveLoss(
            torch.tensor([1]).log(), "mean", learn_logit_scale
        )

        loss_mod = GWLossesBayesian(
            gw_mod,
            selection_mod,
            domain_mods,
            loss_coefs,
            contrastive_loss,
            use_normalized_constrastive,
        )

        super().__init__(
            gw_mod,
            selection_mod,
            loss_mod,
            optim_lr,
            optim_weight_decay,
            scheduler_args,
        )

    def forward(  # type: ignore
        self,
        latent_domains: LatentsDomainGroupsT,
    ) -> GWPredictions:
        """
        Computes demi-cycles, cycles, and translations.

        Args:
            latent_domains (`LatentsT`): Groups of domains for the computation.

        Returns:
            `GWPredictions`: the predictions on the batch.
        """
        return GWPredictions(
            demi_cycles=batch_demi_cycles(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            cycles=batch_cycles(
                self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
            ),
            translations=batch_translations(
                self.gw_mod, self.selection_mod, latent_domains
            ),
            **super().forward(latent_domains),
        )


def pretrained_global_workspace(
    checkpoint_path: str | Path,
    domain_mods: Mapping[str, DomainModule],
    gw_encoders: Mapping[str, Module],
    gw_decoders: Mapping[str, Module],
    workspace_dim: int,
    loss_coefs: LossCoefs,
    contrastive_fn: ContrastiveLossType,
    **kwargs,
) -> GlobalWorkspace2Domains:
    """
    Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.

    Args:
        checkpoint_path (`str | Path`): path to checkpoint
        domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
            connected to the GW. Keys are domain names, values are the
            `DomainModule`.
        gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
            name to a `torch.nn.Module` whose role is to encode a
            unimodal latent representation into a GW representation (pre fusion).
        gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
            name to a `torch.nn.Module` whose role is to decode a
            GW representation into a unimodal latent representation.
        workspace_dim (`int`): dimension of the GW.
        loss_coefs (`LossCoefs`): loss coefficients
        contrastive_fn (`ContrastiveLossType`): a contrastive loss
            function used for alignment. `learn_logit_scale` will not affect custom
            contrastive losses.
        **kwargs: additional arguments to pass to
            `GlobalWorkspace.load_from_checkpoint`.

    Returns:
        `GlobalWorkspace`: the pretrained `GlobalWorkspace`.

    Raises:
        `TypeError`: if loaded type is not `GlobalWorkspace`.
    """
    domain_mods = freeze_domain_modules(domain_mods)
    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)
    selection_mod = SingleDomainSelection()
    loss_mod = GWLosses2Domains(
        gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
    )

    gw = GlobalWorkspace2Domains.load_from_checkpoint(
        checkpoint_path,
        gw_mod=gw_mod,
        selection_mod=selection_mod,
        loss_coefs=loss_coefs,
        loss_mod=loss_mod,
        **kwargs,
    )
    if not isinstance(gw, GlobalWorkspace2Domains):
        raise TypeError("model should be of type GlobalWorkspace")
    return gw
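Before the per-symbol reference below, here is a minimal sketch of loading a trained checkpoint with `pretrained_global_workspace`. The domain modules, encoder/decoder networks, checkpoint path, dimensions and loss coefficients are hypothetical placeholders, not part of this module.

```python
import torch
from torch.nn import Linear
from shimmer.modules.contrastive_loss import ContrastiveLoss
from shimmer.modules.global_workspace import pretrained_global_workspace

# Hypothetical trained DomainModule instances and per-domain encoder/decoder networks.
domain_mods = {"v": my_visual_module, "t": my_text_module}
gw_encoders = {"v": Linear(32, 12), "t": Linear(64, 12)}
gw_decoders = {"v": Linear(12, 32), "t": Linear(12, 64)}

gw = pretrained_global_workspace(
    "checkpoints/gw.ckpt",  # hypothetical path
    domain_mods,
    gw_encoders,
    gw_decoders,
    workspace_dim=12,
    loss_coefs={},  # see LossCoefs in shimmer.modules.losses for the available keys
    contrastive_fn=ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", False),
)
```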
class SchedulerArgs(TypedDict, total=False)
TypedDict of arguments passed to the OneCycle scheduler
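For illustration, a sketch of overriding the scheduler defaults when building a Global Workspace; the step count is a hypothetical value that would normally come from your trainer and dataset setup.

```python
from shimmer.modules.global_workspace import SchedulerArgs

# max_lr and total_steps are forwarded to torch's OneCycleLR by
# GlobalWorkspaceBase.configure_optimizers.
scheduler_args = SchedulerArgs(max_lr=1e-3, total_steps=10_000)
```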
class GWPredictionsBase(TypedDict)
TypedDict of the output given when calling GlobalWorkspaceBase.predict
class GlobalWorkspaceBase(Generic[_T_gw_mod, _T_selection_mod, _T_loss_mod], LightningModule)
Global Workspace Lightning Module.
This is the base class to build the Global Workspace.
@property
def workspace_dim(self) -> int
Dimension of the GW.
def encode_and_fuse(self, x: LatentsDomainGroupsT, selection_module: SelectionBase) -> dict[frozenset[str], torch.Tensor]
Encode a group of latent representations into the GW representation.

Arguments:
- x (`LatentsDomainGroupsT`): the input domain representations.
- selection_module (`SelectionBase`): the selection module to use.

Returns:
`dict[frozenset[str], torch.Tensor]`: the GW representations.
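A minimal sketch of the expected input structure, assuming a model `gw` built from this module and a hypothetical domain "v" with 32-dimensional unimodal latents:

```python
import torch

latents = {frozenset(["v"]): {"v": torch.randn(8, 32)}}  # one group, batch of 8
gw_states = gw.encode_and_fuse(latents, selection_module=gw.selection_mod)
# gw_states[frozenset(["v"])] is expected to have shape (8, gw.workspace_dim)
```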
def encode(self, x: LatentsDomainGroupsT) -> LatentsDomainGroupsDT
Encode a group of latent representations into the pre-fusion GW representation.

Arguments:
- x (`LatentsDomainGroupsT`): the input domain representations.

Returns:
`LatentsDomainGroupsDT`: the GW representations.
def fuse(self, x: LatentsDomainGroupsT, selection_scores: Mapping[frozenset[str], Mapping[str, torch.Tensor]]) -> dict[frozenset[str], torch.Tensor]
Fuses a group of latent representations into the GW representation.

Arguments:
- x (`LatentsDomainGroupsT`): the pre-fusion latent representations.
- selection_scores (`Mapping[frozenset[str], Mapping[str, torch.Tensor]]`): selection scores for each group.

Returns:
`dict[frozenset[str], torch.Tensor]`: GW representation of each group.
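The two-step equivalent of `encode_and_fuse`, useful when you want to inspect or supply the selection scores yourself. Domain names, latent sizes and the uniform scores below are hypothetical, and the exact score shape expected by the `GWModule` may differ from this sketch.

```python
import torch

latents = {
    frozenset(["v", "t"]): {"v": torch.randn(8, 32), "t": torch.randn(8, 64)}
}
pre_fusion = gw.encode(latents)  # per-domain GW-space vectors, before fusion
scores = {
    frozenset(["v", "t"]): {"v": torch.full((8,), 0.5), "t": torch.full((8,), 0.5)}
}
fused = gw.fuse(pre_fusion, scores)
```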
def decode(self, z: Mapping[frozenset[str], torch.Tensor], domains: Iterable[str] | None = None) -> LatentsDomainGroupsDT
Decode the group GW representation into given `domains`.

Arguments:
- z (`torch.Tensor`): the GW representation.
- domains (`Iterable[str]`): iterable of domains to decode.

Returns:
`dict[str, torch.Tensor]`: the decoded unimodal representations.
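Continuing the sketch above, a fused state can be decoded back into a chosen unimodal latent space ("t" is again a hypothetical domain name):

```python
text_latents = gw.decode(fused, domains=["t"])
# text_latents[frozenset(["v", "t"])]["t"] is the "t" latent predicted from the group
```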
def batch_gw_states(self, latent_domains: LatentsDomainGroupsT) -> dict[str, torch.Tensor]
Computes GW states of a batch of groups of domains.

Arguments:
- latent_domains (`LatentsT`): the batch of groups of domains.

Returns:
`dict[str, torch.Tensor]`: states for each domain.
def encode_domain(self, domain: Any, name: str) -> torch.Tensor
Encodes a domain from the domain data into the unimodal representation.

This is a convenient proxy for the `DomainModule.encode` method and is equivalent to:

    self.domain_mods[name].encode(domain)

Arguments:
- domain (`Any`): the domain data.
- name (`str`): domain name to encode.

Returns:
`torch.Tensor`: the domain's unimodal representation.
def encode_domains(self, batch: RawDomainGroupsT) -> LatentsDomainGroupsDT
Encode all domains in the batch.

Arguments:
- batch (`RawDomainGroupsT`): the batch of domain groups with raw unimodal data to encode into groups of latent representations.

Returns:
`LatentsDomainGroupsDT`: the domains' unimodal representations.
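A sketch of the raw-batch layout this method expects: groups keyed by frozensets of domain names, each holding raw data for its domains. The domain names and the `raw_images` / `raw_captions` placeholders are hypothetical.

```python
batch = {
    frozenset(["v"]): {"v": raw_images},
    frozenset(["v", "t"]): {"v": raw_images, "t": raw_captions},
}
domain_latents = gw.encode_domains(batch)
# Same keys as `batch`, with every raw value replaced by its unimodal latent tensor.
```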
def decode_domain(self, domain: torch.Tensor, name: str) -> Any
Decodes a domain from the unimodal representation into the domain data.

This is a convenient proxy for the `DomainModule.decode` method and is equivalent to:

    self.domain_mods[name].decode(domain)

Arguments:
- domain (`torch.Tensor`): the domain data.
- name (`str`): domain name to decode.

Returns:
`Any`: the domain's raw data.
def decode_domains(self, latents_domain: LatentsDomainGroupsT) -> RawDomainGroupsDT
Decodes all domains in the batch.

Arguments:
- latents_domain (`LatentsDomainGroupsT`): the batch of domain groups with unimodal latent representations to decode into groups of raw data.

Returns:
`RawDomainGroupsDT`: the domains' raw data.
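Putting the pieces together, a hedged sketch of a latent-level translation followed by decoding back to raw data; domain names and `raw_images` remain hypothetical placeholders.

```python
latents = gw.encode_domains({frozenset(["v"]): {"v": raw_images}})
states = gw.encode_and_fuse(latents, selection_module=gw.selection_mod)
text_latents = gw.decode(states, domains=["t"])
raw_out = gw.decode_domains(text_latents)
# raw_out[frozenset(["v"])]["t"] holds the predicted raw "t" data for each "v" sample.
```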
def generic_step(self, batch: RawDomainGroupsT, mode: ModelModeT) -> torch.Tensor
The generic step used in `training_step`, `validation_step` and `test_step`.

Arguments:
- batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data.
- mode (`ModelModeT`): the mode ("train", "val", "test", ...).

Returns:
`torch.Tensor`: the loss to train on.
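For context, a minimal training sketch with Lightning. The data module is a hypothetical placeholder; its training batches must be shaped like the mapping used by `training_step` (groups keyed by frozensets of domain names).

```python
from lightning.pytorch import Trainer

trainer = Trainer(max_steps=10_000)
# my_datamodule is a placeholder LightningDataModule whose train batches are
# Mapping[frozenset[str], Mapping[str, Any]] groups, as expected by training_step.
trainer.fit(gw, datamodule=my_datamodule)
```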
def freeze_domain_modules(domain_mods: Mapping[str, DomainModule]) -> dict[str, DomainModule]
Freezes the weights and sets the domain modules to eval mode.

Note: the output is cast as `dict[str, DomainModule]` for better auto-completion, but is actually a torch `ModuleDict`.

Arguments:
- domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze.

Returns:
`ModuleDict`: frozen modules.
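A small usage sketch, assuming two pretrained `DomainModule` instances (`my_visual_module` and `my_text_module` are placeholders, as is `raw_images`):

```python
from shimmer.modules.global_workspace import freeze_domain_modules

frozen = freeze_domain_modules({"v": my_visual_module, "t": my_text_module})
v_latent = frozen["v"].encode(raw_images)  # still usable for inference, no longer trained
```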
class GWPredictions(GWPredictionsBase)

TypedDict of the output given when calling `GlobalWorkspaceBase.predict`.

demi_cycles: dict[str, torch.Tensor]
    Demi-cycle predictions of the model for each domain. Only computed on domain groups with only one domain.

cycles: dict[tuple[str, str], torch.Tensor]
    Cycle predictions of the model from one domain through another one. Only computed on domain groups with more than one domain. The keys are tuples with the start domain and the intermediary domain.

translations: dict[tuple[str, str], torch.Tensor]
    Translation predictions of the model from one domain through another one. Only computed on domain groups with more than one domain. The keys are tuples with the start domain and the target domain.
class GlobalWorkspace2Domains(GlobalWorkspaceBase[GWModule, SingleDomainSelection, GWLosses2Domains])
A simple 2-domains max flavor of GlobalWorkspaceBase.

This is used to simplify a Global Workspace instantiation and only overrides the `__init__` method.
def __init__(
    self,
    domain_mods: Mapping[str, DomainModule],
    gw_encoders: Mapping[str, Module],
    gw_decoders: Mapping[str, Module],
    workspace_dim: int,
    loss_coefs: LossCoefs,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0.0,
    scheduler_args: SchedulerArgs | None = None,
    learn_logit_scale: bool = False,
    contrastive_loss: ContrastiveLossType | None = None,
) -> None
Initializes a Global Workspace.

Arguments:
- domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`s.
- gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to encode a unimodal latent representation into a GW representation (pre-fusion).
- gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to decode a GW representation into a unimodal latent representation.
- workspace_dim (`int`): dimension of the GW.
- loss_coefs (`LossCoefs`): loss coefficients.
- optim_lr (`float`): learning rate.
- optim_weight_decay (`float`): weight decay.
- scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments.
- learn_logit_scale (`bool`): whether to learn the logit scale of the default contrastive loss.
- contrastive_loss (`ContrastiveLossType | None`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses.
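To make the argument layout concrete, here is a minimal sketch of instantiating this class. The `DummyDomain` module, the latent sizes, and the loss-coefficient values are illustrative assumptions, not part of shimmer; the coefficient names are meant to follow the `LossCoefs` TypedDict (check it for the exact fields).

```python
import torch.nn as nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import GlobalWorkspace2Domains


class DummyDomain(DomainModule):
    """Toy stand-in for a real unimodal domain module (hypothetical)."""

    def encode(self, x):
        return x  # pretend inputs are already latent vectors

    def decode(self, z):
        return z


workspace_dim = 16
domain_mods = {
    "image": DummyDomain(latent_dim=64),
    "text": DummyDomain(latent_dim=32),
}

# One encoder/decoder per domain; plain Linear layers are enough to satisfy
# the `torch.nn.Module` interface, real projects typically use small MLPs.
gw_encoders = {
    "image": nn.Linear(64, workspace_dim),
    "text": nn.Linear(32, workspace_dim),
}
gw_decoders = {
    "image": nn.Linear(workspace_dim, 64),
    "text": nn.Linear(workspace_dim, 32),
}

# Coefficient values are arbitrary; see `LossCoefs` for the real field names.
loss_coefs = {"demi_cycles": 1.0, "cycles": 1.0, "translations": 1.0, "contrastives": 0.1}

gw = GlobalWorkspace2Domains(
    domain_mods, gw_encoders, gw_decoders, workspace_dim, loss_coefs
)
```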
538 def forward( # type: ignore 539 self, 540 latent_domains: LatentsDomainGroupsT, 541 ) -> GWPredictions: 542 """ 543 Computes demi-cycles, cycles, and translations. 544 545 Args: 546 latent_domains (`LatentsT`): Groups of domains for the computation. 547 548 Returns: 549 `GWPredictions`: the predictions on the batch. 550 """ 551 return GWPredictions( 552 demi_cycles=batch_demi_cycles( 553 self.gw_mod, self.selection_mod, latent_domains 554 ), 555 cycles=batch_cycles( 556 self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() 557 ), 558 translations=batch_translations( 559 self.gw_mod, self.selection_mod, latent_domains 560 ), 561 **super().forward(latent_domains), 562 )
Computes demi-cycles, cycles, and translations.

Arguments:
- latent_domains (`LatentsT`): Groups of domains for the computation.

Returns:
- `GWPredictions`: the predictions on the batch.
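As a hedged sketch, `forward` can be called directly on pre-encoded unimodal latents grouped by domain sets, reusing the `gw` instance built above (tensor shapes and batch size are arbitrary assumptions):

```python
import torch

# Groups of unimodal latents keyed by frozensets of domain names
# (the `LatentsDomainGroupsT` layout).
latent_domains = {
    frozenset(["image"]): {"image": torch.randn(8, 64)},
    frozenset(["image", "text"]): {
        "image": torch.randn(8, 64),
        "text": torch.randn(8, 32),
    },
}

predictions = gw(latent_domains)
# `predictions` is a GWPredictions dict exposing (among other keys):
demi_cycles = predictions["demi_cycles"]
cycles = predictions["cycles"]
translations = predictions["translations"]
```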
565class GlobalWorkspace(GlobalWorkspaceBase[GWModule, RandomSelection, GWLosses]): 566 """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase. 567 568 This is used to simplify a Global Workspace instanciation and only overrides the 569 `__init__` method. 570 """ 571 572 def __init__( 573 self, 574 domain_mods: Mapping[str, DomainModule], 575 gw_encoders: Mapping[str, Module], 576 gw_decoders: Mapping[str, Module], 577 workspace_dim: int, 578 loss_coefs: BroadcastLossCoefs, 579 selection_temperature: float = 0.2, 580 optim_lr: float = 1e-3, 581 optim_weight_decay: float = 0.0, 582 scheduler_args: SchedulerArgs | None = None, 583 learn_logit_scale: bool = False, 584 contrastive_loss: ContrastiveLossType | None = None, 585 ) -> None: 586 """ 587 Initializes a Global Workspace 588 589 Args: 590 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains 591 connected to the GW. Keys are domain names, values are the 592 `DomainModule`. 593 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 594 name to a `torch.nn.Module` class which role is to encode a 595 unimodal latent representations into a GW representation (pre fusion). 596 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 597 name to a `torch.nn.Module` class which role is to decode a 598 GW representation into a unimodal latent representations. 599 workspace_dim (`int`): dimension of the GW. 600 loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses. 601 selection_temperature (`float`): temperature value for the RandomSelection 602 module. 603 optim_lr (`float`): learning rate 604 optim_weight_decay (`float`): weight decay 605 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments 606 learn_logit_scale (`bool`): whether to learn the contrastive learning 607 contrastive loss when using the default contrastive loss. 608 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss 609 function used for alignment. `learn_logit_scale` will not affect custom 610 contrastive losses. 611 """ 612 domain_mods = freeze_domain_modules(domain_mods) 613 gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders) 614 615 if contrastive_loss is None: 616 contrastive_loss = ContrastiveLoss( 617 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale 618 ) 619 620 selection_mod = RandomSelection(selection_temperature) 621 loss_mod = GWLosses( 622 gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss 623 ) 624 625 super().__init__( 626 gw_mod, 627 selection_mod, 628 loss_mod, 629 optim_lr, 630 optim_weight_decay, 631 scheduler_args, 632 ) 633 634 def forward( # type: ignore 635 self, 636 latent_domains: LatentsDomainGroupsT, 637 ) -> GWPredictions: 638 """ 639 Computes demi-cycles, cycles, and translations. 640 641 Args: 642 latent_domains (`LatentsT`): Groups of domains for the computation. 643 644 Returns: 645 `GWPredictions`: the predictions on the batch. 646 """ 647 return GWPredictions( 648 demi_cycles=batch_demi_cycles( 649 self.gw_mod, self.selection_mod, latent_domains 650 ), 651 cycles=batch_cycles( 652 self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() 653 ), 654 translations=batch_translations( 655 self.gw_mod, self.selection_mod, latent_domains 656 ), 657 # TODO: add other combinations 658 **super().forward(latent_domains), 659 )
The 2-domain fusion (with broadcast loss) flavor of `GlobalWorkspaceBase`.

This is used to simplify Global Workspace instantiation and only overrides the `__init__` method.
572 def __init__( 573 self, 574 domain_mods: Mapping[str, DomainModule], 575 gw_encoders: Mapping[str, Module], 576 gw_decoders: Mapping[str, Module], 577 workspace_dim: int, 578 loss_coefs: BroadcastLossCoefs, 579 selection_temperature: float = 0.2, 580 optim_lr: float = 1e-3, 581 optim_weight_decay: float = 0.0, 582 scheduler_args: SchedulerArgs | None = None, 583 learn_logit_scale: bool = False, 584 contrastive_loss: ContrastiveLossType | None = None, 585 ) -> None: 586 """ 587 Initializes a Global Workspace 588 589 Args: 590 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains 591 connected to the GW. Keys are domain names, values are the 592 `DomainModule`. 593 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 594 name to a `torch.nn.Module` class which role is to encode a 595 unimodal latent representations into a GW representation (pre fusion). 596 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 597 name to a `torch.nn.Module` class which role is to decode a 598 GW representation into a unimodal latent representations. 599 workspace_dim (`int`): dimension of the GW. 600 loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses. 601 selection_temperature (`float`): temperature value for the RandomSelection 602 module. 603 optim_lr (`float`): learning rate 604 optim_weight_decay (`float`): weight decay 605 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments 606 learn_logit_scale (`bool`): whether to learn the contrastive learning 607 contrastive loss when using the default contrastive loss. 608 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss 609 function used for alignment. `learn_logit_scale` will not affect custom 610 contrastive losses. 611 """ 612 domain_mods = freeze_domain_modules(domain_mods) 613 gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders) 614 615 if contrastive_loss is None: 616 contrastive_loss = ContrastiveLoss( 617 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale 618 ) 619 620 selection_mod = RandomSelection(selection_temperature) 621 loss_mod = GWLosses( 622 gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss 623 ) 624 625 super().__init__( 626 gw_mod, 627 selection_mod, 628 loss_mod, 629 optim_lr, 630 optim_weight_decay, 631 scheduler_args, 632 )
Initializes a Global Workspace.

Arguments:
- domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`s.
- gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to encode a unimodal latent representation into a GW representation (pre-fusion).
- gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to decode a GW representation into a unimodal latent representation.
- workspace_dim (`int`): dimension of the GW.
- loss_coefs (`BroadcastLossCoefs`): coefficients for the losses.
- selection_temperature (`float`): temperature value for the `RandomSelection` module.
- optim_lr (`float`): learning rate.
- optim_weight_decay (`float`): weight decay.
- scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments.
- learn_logit_scale (`bool`): whether to learn the logit scale of the default contrastive loss.
- contrastive_loss (`ContrastiveLossType | None`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses.
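Construction mirrors `GlobalWorkspace2Domains`; the differences are the `BroadcastLossCoefs`-typed coefficients and the temperature of the internal `RandomSelection` module. A sketch reusing the hypothetical modules defined in the earlier example (the coefficient keys below are placeholders only; consult `BroadcastLossCoefs` for the actual fields):

```python
from shimmer.modules.global_workspace import GlobalWorkspace

# Placeholder coefficients: the real field names are defined by BroadcastLossCoefs.
broadcast_coefs = {"contrastives": 0.1, "fused": 1.0}

fusion_gw = GlobalWorkspace(
    domain_mods,                 # hypothetical modules from the sketch above
    gw_encoders,
    gw_decoders,
    workspace_dim,
    loss_coefs=broadcast_coefs,
    selection_temperature=0.2,   # temperature of the internal RandomSelection
)
```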
634 def forward( # type: ignore 635 self, 636 latent_domains: LatentsDomainGroupsT, 637 ) -> GWPredictions: 638 """ 639 Computes demi-cycles, cycles, and translations. 640 641 Args: 642 latent_domains (`LatentsT`): Groups of domains for the computation. 643 644 Returns: 645 `GWPredictions`: the predictions on the batch. 646 """ 647 return GWPredictions( 648 demi_cycles=batch_demi_cycles( 649 self.gw_mod, self.selection_mod, latent_domains 650 ), 651 cycles=batch_cycles( 652 self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() 653 ), 654 translations=batch_translations( 655 self.gw_mod, self.selection_mod, latent_domains 656 ), 657 # TODO: add other combinations 658 **super().forward(latent_domains), 659 )
Computes demi-cycles, cycles, and translations.

Arguments:
- latent_domains (`LatentsT`): Groups of domains for the computation.

Returns:
- `GWPredictions`: the predictions on the batch.
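Since every Global Workspace flavor is a `LightningModule`, training is a matter of handing it to a Lightning `Trainer` together with a data module that yields the expected grouped domain batches. In the sketch below, `MyDomainDataModule` is a hypothetical placeholder, not part of shimmer:

```python
from lightning.pytorch import Trainer

# MyDomainDataModule stands in for your own LightningDataModule that yields
# batches of grouped raw domain data expected by the training step.
datamodule = MyDomainDataModule()

trainer = Trainer(max_steps=10_000)
trainer.fit(fusion_gw, datamodule)
```

You will typically also want to pass `scheduler_args` when constructing the module so that the learning-rate schedule spans the whole training run.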
662class GlobalWorkspaceBayesian( 663 GlobalWorkspaceBase[GWModuleBayesian, FixedSharedSelection, GWLossesBayesian] 664): 665 """ 666 A simple 2-domains max GlobalWorkspaceBase with a Bayesian base uncertainty 667 prediction. 668 669 This is used to simplify a Global Workspace instanciation and only overrides the 670 `__init__` method. 671 """ 672 673 def __init__( 674 self, 675 domain_mods: Mapping[str, DomainModule], 676 gw_encoders: Mapping[str, Module], 677 gw_decoders: Mapping[str, Module], 678 workspace_dim: int, 679 loss_coefs: BroadcastLossCoefs, 680 sensitivity_selection: float = 1, 681 sensitivity_precision: float = 1, 682 optim_lr: float = 1e-3, 683 optim_weight_decay: float = 0.0, 684 scheduler_args: SchedulerArgs | None = None, 685 learn_logit_scale: bool = False, 686 use_normalized_constrastive: bool = True, 687 contrastive_loss: ContrastiveLossType | None = None, 688 precision_softmax_temp: float = 0.01, 689 ) -> None: 690 """ 691 Initializes a Global Workspace 692 693 Args: 694 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains 695 connected to the GW. Keys are domain names, values are the 696 `DomainModule`. 697 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 698 name to a `torch.nn.Module` class which role is to encode a 699 unimodal latent representations into a GW representation (pre fusion). 700 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 701 name to a `torch.nn.Module` class which role is to decode a 702 GW representation into a unimodal latent representations. 703 workspace_dim (`int`): dimension of the GW. 704 loss_coefs (`LossCoefs`): loss coefficients 705 sensitivity_selection (`float`): sensivity coef $c'_1$ 706 sensitivity_precision (`float`): sensitivity coef $c'_2$ 707 optim_lr (`float`): learning rate 708 optim_weight_decay (`float`): weight decay 709 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments 710 learn_logit_scale (`bool`): whether to learn the contrastive learning 711 contrastive loss when using the default contrastive loss. 712 use_normalized_constrastive (`bool`): whether to use the normalized cont 713 loss by the precision coefs 714 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss 715 function used for alignment. `learn_logit_scale` will not affect custom 716 contrastive losses. 717 precision_softmax_temp (`float`): temperature to use in softmax of 718 precision 719 """ 720 domain_mods = freeze_domain_modules(domain_mods) 721 722 gw_mod = GWModuleBayesian( 723 domain_mods, 724 workspace_dim, 725 gw_encoders, 726 gw_decoders, 727 sensitivity_selection, 728 sensitivity_precision, 729 precision_softmax_temp, 730 ) 731 732 selection_mod = FixedSharedSelection() 733 734 contrastive_loss = ContrastiveLoss( 735 torch.tensor([1]).log(), "mean", learn_logit_scale 736 ) 737 738 loss_mod = GWLossesBayesian( 739 gw_mod, 740 selection_mod, 741 domain_mods, 742 loss_coefs, 743 contrastive_loss, 744 use_normalized_constrastive, 745 ) 746 747 super().__init__( 748 gw_mod, 749 selection_mod, 750 loss_mod, 751 optim_lr, 752 optim_weight_decay, 753 scheduler_args, 754 ) 755 756 def forward( # type: ignore 757 self, 758 latent_domains: LatentsDomainGroupsT, 759 ) -> GWPredictions: 760 """ 761 Computes demi-cycles, cycles, and translations. 762 763 Args: 764 latent_domains (`LatentsT`): Groups of domains for the computation. 765 766 Returns: 767 `GWPredictions`: the predictions on the batch. 
768 """ 769 return GWPredictions( 770 demi_cycles=batch_demi_cycles( 771 self.gw_mod, self.selection_mod, latent_domains 772 ), 773 cycles=batch_cycles( 774 self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() 775 ), 776 translations=batch_translations( 777 self.gw_mod, self.selection_mod, latent_domains 778 ), 779 **super().forward(latent_domains), 780 )
A simple two-domain `GlobalWorkspaceBase` with Bayesian uncertainty prediction.

This is used to simplify Global Workspace instantiation and only overrides the `__init__` method.
673 def __init__( 674 self, 675 domain_mods: Mapping[str, DomainModule], 676 gw_encoders: Mapping[str, Module], 677 gw_decoders: Mapping[str, Module], 678 workspace_dim: int, 679 loss_coefs: BroadcastLossCoefs, 680 sensitivity_selection: float = 1, 681 sensitivity_precision: float = 1, 682 optim_lr: float = 1e-3, 683 optim_weight_decay: float = 0.0, 684 scheduler_args: SchedulerArgs | None = None, 685 learn_logit_scale: bool = False, 686 use_normalized_constrastive: bool = True, 687 contrastive_loss: ContrastiveLossType | None = None, 688 precision_softmax_temp: float = 0.01, 689 ) -> None: 690 """ 691 Initializes a Global Workspace 692 693 Args: 694 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains 695 connected to the GW. Keys are domain names, values are the 696 `DomainModule`. 697 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 698 name to a `torch.nn.Module` class which role is to encode a 699 unimodal latent representations into a GW representation (pre fusion). 700 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 701 name to a `torch.nn.Module` class which role is to decode a 702 GW representation into a unimodal latent representations. 703 workspace_dim (`int`): dimension of the GW. 704 loss_coefs (`LossCoefs`): loss coefficients 705 sensitivity_selection (`float`): sensivity coef $c'_1$ 706 sensitivity_precision (`float`): sensitivity coef $c'_2$ 707 optim_lr (`float`): learning rate 708 optim_weight_decay (`float`): weight decay 709 scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments 710 learn_logit_scale (`bool`): whether to learn the contrastive learning 711 contrastive loss when using the default contrastive loss. 712 use_normalized_constrastive (`bool`): whether to use the normalized cont 713 loss by the precision coefs 714 contrastive_loss (`ContrastiveLossType | None`): a contrastive loss 715 function used for alignment. `learn_logit_scale` will not affect custom 716 contrastive losses. 717 precision_softmax_temp (`float`): temperature to use in softmax of 718 precision 719 """ 720 domain_mods = freeze_domain_modules(domain_mods) 721 722 gw_mod = GWModuleBayesian( 723 domain_mods, 724 workspace_dim, 725 gw_encoders, 726 gw_decoders, 727 sensitivity_selection, 728 sensitivity_precision, 729 precision_softmax_temp, 730 ) 731 732 selection_mod = FixedSharedSelection() 733 734 contrastive_loss = ContrastiveLoss( 735 torch.tensor([1]).log(), "mean", learn_logit_scale 736 ) 737 738 loss_mod = GWLossesBayesian( 739 gw_mod, 740 selection_mod, 741 domain_mods, 742 loss_coefs, 743 contrastive_loss, 744 use_normalized_constrastive, 745 ) 746 747 super().__init__( 748 gw_mod, 749 selection_mod, 750 loss_mod, 751 optim_lr, 752 optim_weight_decay, 753 scheduler_args, 754 )
Initializes a Global Workspace.

Arguments:
- domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`s.
- gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to encode a unimodal latent representation into a GW representation (pre-fusion).
- gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to decode a GW representation into a unimodal latent representation.
- workspace_dim (`int`): dimension of the GW.
- loss_coefs (`BroadcastLossCoefs`): loss coefficients.
- sensitivity_selection (`float`): sensitivity coefficient $c'_1$.
- sensitivity_precision (`float`): sensitivity coefficient $c'_2$.
- optim_lr (`float`): learning rate.
- optim_weight_decay (`float`): weight decay.
- scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments.
- learn_logit_scale (`bool`): whether to learn the logit scale of the default contrastive loss.
- use_normalized_constrastive (`bool`): whether to normalize the contrastive loss by the precision coefficients.
- contrastive_loss (`ContrastiveLossType | None`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses.
- precision_softmax_temp (`float`): temperature to use in the softmax of the precision.
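The Bayesian flavor is built the same way; the extra arguments tune how precision/uncertainty is handled. A hedged sketch, again reusing the hypothetical modules and placeholder coefficient dict from the earlier examples:

```python
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian

bayesian_gw = GlobalWorkspaceBayesian(
    domain_mods,                   # hypothetical modules from the earlier sketch
    gw_encoders,
    gw_decoders,
    workspace_dim,
    loss_coefs=broadcast_coefs,    # placeholder BroadcastLossCoefs dict
    sensitivity_selection=1.0,     # c'_1
    sensitivity_precision=1.0,     # c'_2
    precision_softmax_temp=0.01,
)
```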
756 def forward( # type: ignore 757 self, 758 latent_domains: LatentsDomainGroupsT, 759 ) -> GWPredictions: 760 """ 761 Computes demi-cycles, cycles, and translations. 762 763 Args: 764 latent_domains (`LatentsT`): Groups of domains for the computation. 765 766 Returns: 767 `GWPredictions`: the predictions on the batch. 768 """ 769 return GWPredictions( 770 demi_cycles=batch_demi_cycles( 771 self.gw_mod, self.selection_mod, latent_domains 772 ), 773 cycles=batch_cycles( 774 self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys() 775 ), 776 translations=batch_translations( 777 self.gw_mod, self.selection_mod, latent_domains 778 ), 779 **super().forward(latent_domains), 780 )
Computes demi-cycles, cycles, and translations.

Arguments:
- latent_domains (`LatentsT`): Groups of domains for the computation.

Returns:
- `GWPredictions`: the predictions on the batch.
783def pretrained_global_workspace( 784 checkpoint_path: str | Path, 785 domain_mods: Mapping[str, DomainModule], 786 gw_encoders: Mapping[str, Module], 787 gw_decoders: Mapping[str, Module], 788 workspace_dim: int, 789 loss_coefs: LossCoefs, 790 contrastive_fn: ContrastiveLossType, 791 **kwargs, 792) -> GlobalWorkspace2Domains: 793 """ 794 Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint. 795 796 Args: 797 checkpoint_path (`str | Path`): path to checkpoint 798 domain_mods (`Mapping[str, DomainModule]`): mapping of the domains 799 connected to the GW. Keys are domain names, values are the 800 `DomainModule`. 801 gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 802 name to a `torch.nn.Module` class which role is to encode a 803 unimodal latent representations into a GW representation (pre fusion). 804 gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain 805 name to a `torch.nn.Module` class which role is to decode a 806 GW representation into a unimodal latent representations. 807 workspace_dim (`int`): dimension of the GW. 808 loss_coefs (`LossCoefs`): loss coefficients 809 contrastive_loss (`ContrastiveLossType`): a contrastive loss 810 function used for alignment. `learn_logit_scale` will not affect custom 811 contrastive losses. 812 **kwargs: additional arguments to pass to 813 `GlobalWorkspace.load_from_checkpoint`. 814 815 Returns: 816 `GlobalWorkspace`: the pretrained `GlobalWorkspace`. 817 818 Raises: 819 `TypeError`: if loaded type is not `GlobalWorkspace`. 820 """ 821 domain_mods = freeze_domain_modules(domain_mods) 822 gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders) 823 selection_mod = SingleDomainSelection() 824 loss_mod = GWLosses2Domains( 825 gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn 826 ) 827 828 gw = GlobalWorkspace2Domains.load_from_checkpoint( 829 checkpoint_path, 830 gw_mod=gw_mod, 831 selection_mid=selection_mod, 832 loss_coefs=loss_coefs, 833 loss_mod=loss_mod, 834 **kwargs, 835 ) 836 if not isinstance(gw, GlobalWorkspace2Domains): 837 raise TypeError("model should be of type GlobalWorkspace") 838 return gw
Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint.

Arguments:
- checkpoint_path (`str | Path`): path to the checkpoint.
- domain_mods (`Mapping[str, DomainModule]`): mapping of the domains connected to the GW. Keys are domain names, values are the `DomainModule`s.
- gw_encoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to encode a unimodal latent representation into a GW representation (pre-fusion).
- gw_decoders (`Mapping[str, torch.nn.Module]`): mapping from each domain name to a `torch.nn.Module` whose role is to decode a GW representation into a unimodal latent representation.
- workspace_dim (`int`): dimension of the GW.
- loss_coefs (`LossCoefs`): loss coefficients.
- contrastive_fn (`ContrastiveLossType`): the contrastive loss function used for alignment.
- **kwargs: additional arguments passed to `GlobalWorkspace2Domains.load_from_checkpoint`.

Returns:
- `GlobalWorkspace2Domains`: the pretrained Global Workspace.

Raises:
- `TypeError`: if the loaded model is not a `GlobalWorkspace2Domains`.
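A sketch of loading a pretrained two-domain workspace; the checkpoint path is a placeholder, and the other arguments reuse the hypothetical objects from the earlier examples. The logit scale mirrors the default used by the GW classes, but any `ContrastiveLossType` works here:

```python
import torch

from shimmer.modules.contrastive_loss import ContrastiveLoss
from shimmer.modules.global_workspace import pretrained_global_workspace

# Default-style logit scale; the third argument disables learning it.
contrastive_fn = ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", False)

gw = pretrained_global_workspace(
    "checkpoints/gw.ckpt",   # placeholder path
    domain_mods,
    gw_encoders,
    gw_decoders,
    workspace_dim,
    loss_coefs,              # LossCoefs dict from the first example
    contrastive_fn=contrastive_fn,
)
gw.eval()
```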