shimmer.modules.domain      
                
                        
                        
                        # Source of the `shimmer.modules.domain` module.
from dataclasses import dataclass, field
from typing import Any

import lightning.pytorch as pl
import torch


@dataclass
class LossOutput:
    """
    Python dataclass used as the return value of losses.

    It keeps track of what is used for training (`loss`) and what is used
    only for logging (`metrics`).
    """

    loss: torch.Tensor
    """Loss used during training."""

    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
    """Some additional metrics to log (not used during training)."""

    def __post_init__(self) -> None:
        """
        Validates the metrics dict.

        Raises:
            ValueError: if `metrics` contains a `"loss"` key. `all` stores the
                training loss under that key, so such a metric would be
                silently overwritten.
        """
        if "loss" in self.metrics:
            raise ValueError("'loss' cannot be a key of metrics.")

    @property
    def all(self) -> dict[str, torch.Tensor]:
        """
        Returns a dict with every metric plus the training loss under the
        `"loss"` key.
        """
        return {**self.metrics, "loss": self.loss}


class DomainModule(pl.LightningModule):
    """
    Base class for a DomainModule that defines domain specific modules of the GW.

    `encode`, `decode` and `compute_loss` raise `NotImplementedError` here and
    are meant to be overridden by subclasses. The demi-cycle, cycle,
    translation and broadcast losses fall back to `compute_loss`; override
    them only when they differ from the generic loss.
    """

    def __init__(
        self,
        latent_dim: int,
    ) -> None:
        """
        Initializes a DomainModule.

        Args:
            latent_dim (`int`): latent dimension of the unimodal module
        """
        super().__init__()

        self.latent_dim = latent_dim
        """The latent dimension of the module."""

    def encode(self, x: Any) -> torch.Tensor:
        """
        Encode the domain data into a unimodal representation.

        Args:
            x (`Any`): data of the domain.

        Returns:
            `torch.Tensor`: a unimodal representation.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `encode`."
        )

    def decode(self, z: torch.Tensor) -> Any:
        """
        Decode data from unimodal representation back to the domain data.

        Args:
            z (`torch.Tensor`): unimodal representation of the domain.

        Returns:
            `Any`: the original domain data.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `decode`."
        )

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Generic loss computation for the modality.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `compute_loss`."
        )

    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a cycle. Override if the cycle loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a translation. Override if the translation loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_broadcast_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> LossOutput:
        """
        Computes the loss for a broadcast (fusion). Override if the broadcast
        loss is different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)
@dataclass
class LossOutput:
    """
    Python dataclass used as the return value of losses.

    It keeps track of what is used for training (`loss`) and what is used
    only for logging (`metrics`).
    """

    loss: torch.Tensor
    """Loss used during training."""

    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
    """Some additional metrics to log (not used during training)."""

    def __post_init__(self) -> None:
        """
        Validates the metrics dict.

        Raises:
            ValueError: if `metrics` contains a `"loss"` key. `all` stores the
                training loss under that key, so such a metric would be
                silently overwritten.
        """
        if "loss" in self.metrics:
            raise ValueError("'loss' cannot be a key of metrics.")

    @property
    def all(self) -> dict[str, torch.Tensor]:
        """
        Returns a dict with every metric plus the training loss under the
        `"loss"` key.
        """
        return {**self.metrics, "loss": self.loss}
class DomainModule(pl.LightningModule):
    """
    Base class for a DomainModule that defines domain specific modules of the GW.

    `encode`, `decode` and `compute_loss` raise `NotImplementedError` here and
    are meant to be overridden by subclasses. The demi-cycle, cycle,
    translation and broadcast losses fall back to `compute_loss`; override
    them only when they differ from the generic loss.
    """

    def __init__(
        self,
        latent_dim: int,
    ) -> None:
        """
        Initializes a DomainModule.

        Args:
            latent_dim (`int`): latent dimension of the unimodal module
        """
        super().__init__()

        self.latent_dim = latent_dim
        """The latent dimension of the module."""

    def encode(self, x: Any) -> torch.Tensor:
        """
        Encode the domain data into a unimodal representation.

        Args:
            x (`Any`): data of the domain.

        Returns:
            `torch.Tensor`: a unimodal representation.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `encode`."
        )

    def decode(self, z: torch.Tensor) -> Any:
        """
        Decode data from unimodal representation back to the domain data.

        Args:
            z (`torch.Tensor`): unimodal representation of the domain.

        Returns:
            `Any`: the original domain data.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `decode`."
        )

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Generic loss computation for the modality.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.

        Raises:
            NotImplementedError: if the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `compute_loss`."
        )

    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a cycle. Override if the cycle loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a translation. Override if the translation loss is
        different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)

    def compute_broadcast_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> LossOutput:
        """
        Computes the loss for a broadcast (fusion). Override if the broadcast
        loss is different than the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model
            target (`torch.Tensor`): target tensor

        Returns:
            `LossOutput`: LossOutput with training loss and additional metrics.
        """
        return self.compute_loss(pred, target)
Base class for a DomainModule that defines domain specific modules of the GW.
def __init__(
    self,
    latent_dim: int,
) -> None:
    """
    Initializes a DomainModule.

    Args:
        latent_dim (`int`): latent dimension of the unimodal module
    """
    super().__init__()

    self.latent_dim = latent_dim
    """The latent dimension of the module."""
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
def encode(self, x: Any) -> torch.Tensor:
    """
    Encode the domain data into a unimodal representation.

    Args:
        x (`Any`): data of the domain.

    Returns:
        `torch.Tensor`: a unimodal representation.

    Raises:
        NotImplementedError: if the subclass does not override this method.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement `encode`."
    )
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
def decode(self, z: torch.Tensor) -> Any:
    """
    Decode data from unimodal representation back to the domain data.

    Args:
        z (`torch.Tensor`): unimodal representation of the domain.

    Returns:
        `Any`: the original domain data.

    Raises:
        NotImplementedError: if the subclass does not override this method.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement `decode`."
    )
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Generic loss computation for the modality.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput`: LossOutput with training loss and additional metrics.

    Raises:
        NotImplementedError: if the subclass does not override this method.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement `compute_loss`."
    )
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a demi-cycle. Override if the demi-cycle loss is
    different than the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput`: LossOutput with training loss and additional metrics.
    """
    return self.compute_loss(pred, target)
Computes the loss for a demi-cycle. Override if the demi-cycle loss is different than the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a cycle. Override if the cycle loss is
    different than the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput`: LossOutput with training loss and additional metrics.
    """
    return self.compute_loss(pred, target)
Computes the loss for a cycle. Override if the cycle loss is different than the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a translation. Override if the translation loss is
    different than the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput`: LossOutput with training loss and additional metrics.
    """
    return self.compute_loss(pred, target)
Computes the loss for a translation. Override if the translation loss is different than the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_broadcast_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput:
    """
    Computes the loss for a broadcast (fusion). Override if the broadcast
    loss is different than the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput`: LossOutput with training loss and additional metrics.
    """
    return self.compute_loss(pred, target)
Computes the loss for a broadcast (fusion). Override if the broadcast loss is different than the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
Inherited Members
- lightning.pytorch.core.module.LightningModule
 - CHECKPOINT_HYPER_PARAMS_KEY
 - CHECKPOINT_HYPER_PARAMS_NAME
 - CHECKPOINT_HYPER_PARAMS_TYPE
 - optimizers
 - lr_schedulers
 - trainer
 - fabric
 - example_input_array
 - current_epoch
 - global_step
 - global_rank
 - local_rank
 - on_gpu
 - automatic_optimization
 - strict_loading
 - logger
 - loggers
 - log
 - log_dict
 - all_gather
 - forward
 - training_step
 - validation_step
 - test_step
 - predict_step
 - configure_callbacks
 - configure_optimizers
 - manual_backward
 - backward
 - toggle_optimizer
 - untoggle_optimizer
 - clip_gradients
 - configure_gradient_clipping
 - lr_scheduler_step
 - optimizer_step
 - optimizer_zero_grad
 - freeze
 - unfreeze
 - to_onnx
 - to_torchscript
 - load_from_checkpoint
 - lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
 - dtype
 - device
 - to
 - cuda
 - cpu
 - type
 - float
 - double
 - half
 - lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
 - save_hyperparameters
 - hparams
 - hparams_initial
 - lightning.pytorch.core.hooks.ModelHooks
 - on_fit_start
 - on_fit_end
 - on_train_start
 - on_train_end
 - on_validation_start
 - on_validation_end
 - on_test_start
 - on_test_end
 - on_predict_start
 - on_predict_end
 - on_train_batch_start
 - on_train_batch_end
 - on_validation_batch_start
 - on_validation_batch_end
 - on_test_batch_start
 - on_test_batch_end
 - on_predict_batch_start
 - on_predict_batch_end
 - on_validation_model_zero_grad
 - on_validation_model_eval
 - on_validation_model_train
 - on_test_model_eval
 - on_test_model_train
 - on_predict_model_eval
 - on_train_epoch_start
 - on_train_epoch_end
 - on_validation_epoch_start
 - on_validation_epoch_end
 - on_test_epoch_start
 - on_test_epoch_end
 - on_predict_epoch_start
 - on_predict_epoch_end
 - on_before_zero_grad
 - on_before_backward
 - on_after_backward
 - on_before_optimizer_step
 - configure_sharded_model
 - configure_model
 - lightning.pytorch.core.hooks.DataHooks
 - prepare_data_per_node
 - allow_zero_length_dataloader_with_multiple_devices
 - prepare_data
 - setup
 - teardown
 - train_dataloader
 - test_dataloader
 - val_dataloader
 - predict_dataloader
 - transfer_batch_to_device
 - on_before_batch_transfer
 - on_after_batch_transfer
 - lightning.pytorch.core.hooks.CheckpointHooks
 - on_load_checkpoint
 - on_save_checkpoint
 - torch.nn.modules.module.Module
 - dump_patches
 - training
 - call_super_init
 - register_buffer
 - register_parameter
 - add_module
 - register_module
 - get_submodule
 - get_parameter
 - get_buffer
 - get_extra_state
 - set_extra_state
 - apply
 - ipu
 - xpu
 - bfloat16
 - to_empty
 - register_full_backward_pre_hook
 - register_backward_hook
 - register_full_backward_hook
 - register_forward_pre_hook
 - register_forward_hook
 - register_state_dict_pre_hook
 - state_dict
 - register_load_state_dict_post_hook
 - load_state_dict
 - parameters
 - named_parameters
 - buffers
 - named_buffers
 - children
 - named_children
 - modules
 - named_modules
 - train
 - eval
 - requires_grad_
 - zero_grad
 - extra_repr
 - compile