shimmer.modules.domain
from dataclasses import dataclass, field
from typing import Any

import lightning.pytorch as pl
import torch


@dataclass
class LossOutput:
    """
    Dataclass used as the return value of every loss computation.

    It separates the value optimized during training (`loss`) from values
    that are only logged (`metrics`).
    """

    loss: torch.Tensor
    """Loss used during training."""

    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
    """Some additional metrics to log (not used during training)."""

    def __post_init__(self):
        # "loss" is reserved: `all` exposes the training loss under that key.
        if "loss" not in self.metrics:
            return
        raise ValueError("'loss' cannot be a key of metrics.")

    @property
    def all(self) -> dict[str, torch.Tensor]:
        """
        Returns a new dict holding every metric plus the training loss
        stored under the "loss" key.
        """
        merged = dict(self.metrics)
        merged["loss"] = self.loss
        return merged


class DomainModule(pl.LightningModule):
    """
    Base class for a DomainModule that defines domain-specific modules of the GW.

    Subclasses provide `encode`, `decode` and `compute_loss`. The demi-cycle,
    cycle, translation and broadcast losses all delegate to `compute_loss`
    unless a subclass overrides them.
    """

    def __init__(self, latent_dim: int) -> None:
        """
        Initializes a DomainModule.

        Args:
            latent_dim (`int`): latent dimension of the unimodal module.
        """
        super().__init__()

        self.latent_dim = latent_dim
        """The latent dimension of the module."""

    def encode(self, x: Any) -> torch.Tensor:
        """
        Encodes data of the domain into a unimodal representation.

        Subclasses must override this method.

        Args:
            x (`Any`): data of the domain.

        Returns:
            `torch.Tensor`: the unimodal representation of `x`.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def decode(self, z: torch.Tensor) -> Any:
        """
        Decodes a unimodal representation back into data of the domain.

        Subclasses must override this method.

        Args:
            z (`torch.Tensor`): unimodal representation of the domain.

        Returns:
            `Any`: the domain data decoded from `z`.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Generic loss computation for the modality.

        Subclasses must override this method; the specialised losses below
        delegate to it by default.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a demi-cycle.

        Defaults to `compute_loss`; override if the demi-cycle loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a cycle.

        Defaults to `compute_loss`; override if the cycle loss differs from
        the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a translation.

        Defaults to `compute_loss`; override if the translation loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_broadcast_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> LossOutput:
        """
        Computes the loss for a broadcast (fusion).

        Defaults to `compute_loss`; override if the broadcast loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)
@dataclass
class LossOutput:
    """
    Dataclass used as the return value of every loss computation.

    It separates the value optimized during training (`loss`) from values
    that are only logged (`metrics`).
    """

    loss: torch.Tensor
    """Loss used during training."""

    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
    """Some additional metrics to log (not used during training)."""

    def __post_init__(self):
        # "loss" is reserved: `all` exposes the training loss under that key.
        if "loss" not in self.metrics:
            return
        raise ValueError("'loss' cannot be a key of metrics.")

    @property
    def all(self) -> dict[str, torch.Tensor]:
        """
        Returns a new dict holding every metric plus the training loss
        stored under the "loss" key.
        """
        merged = dict(self.metrics)
        merged["loss"] = self.loss
        return merged
class DomainModule(pl.LightningModule):
    """
    Base class for a DomainModule that defines domain-specific modules of the GW.

    Subclasses provide `encode`, `decode` and `compute_loss`. The demi-cycle,
    cycle, translation and broadcast losses all delegate to `compute_loss`
    unless a subclass overrides them.
    """

    def __init__(self, latent_dim: int) -> None:
        """
        Initializes a DomainModule.

        Args:
            latent_dim (`int`): latent dimension of the unimodal module.
        """
        super().__init__()

        self.latent_dim = latent_dim
        """The latent dimension of the module."""

    def encode(self, x: Any) -> torch.Tensor:
        """
        Encodes data of the domain into a unimodal representation.

        Subclasses must override this method.

        Args:
            x (`Any`): data of the domain.

        Returns:
            `torch.Tensor`: the unimodal representation of `x`.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def decode(self, z: torch.Tensor) -> Any:
        """
        Decodes a unimodal representation back into data of the domain.

        Subclasses must override this method.

        Args:
            z (`torch.Tensor`): unimodal representation of the domain.

        Returns:
            `Any`: the domain data decoded from `z`.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Generic loss computation for the modality.

        Subclasses must override this method; the specialised losses below
        delegate to it by default.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a demi-cycle.

        Defaults to `compute_loss`; override if the demi-cycle loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a cycle.

        Defaults to `compute_loss`; override if the cycle loss differs from
        the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
        """
        Computes the loss for a translation.

        Defaults to `compute_loss`; override if the translation loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)

    def compute_broadcast_loss(
        self, pred: torch.Tensor, target: torch.Tensor
    ) -> LossOutput:
        """
        Computes the loss for a broadcast (fusion).

        Defaults to `compute_loss`; override if the broadcast loss differs
        from the generic loss.

        Args:
            pred (`torch.Tensor`): prediction of the model.
            target (`torch.Tensor`): target tensor.

        Returns:
            `LossOutput`: training loss and additional metrics.
        """
        generic_loss = self.compute_loss
        return generic_loss(pred, target)
Base class for a DomainModule that defines domain-specific modules of the GW.
def __init__(self, latent_dim: int) -> None:
    """
    Initializes a DomainModule.

    Args:
        latent_dim (`int`): latent dimension of the unimodal module.
    """
    super().__init__()

    self.latent_dim = latent_dim
    """The latent dimension of the module."""
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
def encode(self, x: Any) -> torch.Tensor:
    """
    Encodes data of the domain into a unimodal representation.

    Subclasses must override this method.

    Args:
        x (`Any`): data of the domain.

    Returns:
        `torch.Tensor`: the unimodal representation of `x`.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
def decode(self, z: torch.Tensor) -> Any:
    """
    Decodes a unimodal representation back into data of the domain.

    Subclasses must override this method.

    Args:
        z (`torch.Tensor`): unimodal representation of the domain.

    Returns:
        `Any`: the domain data decoded from `z`.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the domain data decoded from z.
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Generic loss computation for the modality.

    Subclasses must override this method; the demi-cycle, cycle,
    translation and broadcast losses delegate to it by default.

    Args:
        pred (`torch.Tensor`): prediction of the model.
        target (`torch.Tensor`): target tensor.

    Returns:
        `LossOutput`: training loss and additional metrics.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError()
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a demi-cycle.

    Defaults to `compute_loss`; override if the demi-cycle loss differs
    from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model.
        target (`torch.Tensor`): target tensor.

    Returns:
        `LossOutput`: training loss and additional metrics.
    """
    generic_loss = self.compute_loss
    return generic_loss(pred, target)
Computes the loss for a demi-cycle. Override if the demi-cycle loss is different from the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a cycle.

    Defaults to `compute_loss`; override if the cycle loss differs from
    the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model.
        target (`torch.Tensor`): target tensor.

    Returns:
        `LossOutput`: training loss and additional metrics.
    """
    generic_loss = self.compute_loss
    return generic_loss(pred, target)
Computes the loss for a cycle. Override if the cycle loss is different from the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
    """
    Computes the loss for a translation.

    Defaults to `compute_loss`; override if the translation loss differs
    from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model.
        target (`torch.Tensor`): target tensor.

    Returns:
        `LossOutput`: training loss and additional metrics.
    """
    generic_loss = self.compute_loss
    return generic_loss(pred, target)
Computes the loss for a translation. Override if the translation loss is different from the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
def compute_broadcast_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput:
    """
    Computes the loss for a broadcast (fusion).

    Defaults to `compute_loss`; override if the broadcast loss differs
    from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model.
        target (`torch.Tensor`): target tensor.

    Returns:
        `LossOutput`: training loss and additional metrics.
    """
    generic_loss = self.compute_loss
    return generic_loss(pred, target)
Computes the loss for a broadcast (fusion). Override if the broadcast loss is different from the generic loss.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
Returns:
LossOutput: LossOutput with training loss and additional metrics.
Inherited Members
- lightning.pytorch.core.module.LightningModule
- CHECKPOINT_HYPER_PARAMS_KEY
- CHECKPOINT_HYPER_PARAMS_NAME
- CHECKPOINT_HYPER_PARAMS_TYPE
- optimizers
- lr_schedulers
- trainer
- fabric
- example_input_array
- current_epoch
- global_step
- global_rank
- local_rank
- on_gpu
- automatic_optimization
- strict_loading
- logger
- loggers
- log
- log_dict
- all_gather
- forward
- training_step
- validation_step
- test_step
- predict_step
- configure_callbacks
- configure_optimizers
- manual_backward
- backward
- toggle_optimizer
- untoggle_optimizer
- clip_gradients
- configure_gradient_clipping
- lr_scheduler_step
- optimizer_step
- optimizer_zero_grad
- freeze
- unfreeze
- to_onnx
- to_torchscript
- load_from_checkpoint
- lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
- dtype
- device
- to
- cuda
- cpu
- type
- float
- double
- half
- lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
- save_hyperparameters
- hparams
- hparams_initial
- lightning.pytorch.core.hooks.ModelHooks
- on_fit_start
- on_fit_end
- on_train_start
- on_train_end
- on_validation_start
- on_validation_end
- on_test_start
- on_test_end
- on_predict_start
- on_predict_end
- on_train_batch_start
- on_train_batch_end
- on_validation_batch_start
- on_validation_batch_end
- on_test_batch_start
- on_test_batch_end
- on_predict_batch_start
- on_predict_batch_end
- on_validation_model_zero_grad
- on_validation_model_eval
- on_validation_model_train
- on_test_model_eval
- on_test_model_train
- on_predict_model_eval
- on_train_epoch_start
- on_train_epoch_end
- on_validation_epoch_start
- on_validation_epoch_end
- on_test_epoch_start
- on_test_epoch_end
- on_predict_epoch_start
- on_predict_epoch_end
- on_before_zero_grad
- on_before_backward
- on_after_backward
- on_before_optimizer_step
- configure_sharded_model
- configure_model
- lightning.pytorch.core.hooks.DataHooks
- prepare_data_per_node
- allow_zero_length_dataloader_with_multiple_devices
- prepare_data
- setup
- teardown
- train_dataloader
- test_dataloader
- val_dataloader
- predict_dataloader
- transfer_batch_to_device
- on_before_batch_transfer
- on_after_batch_transfer
- lightning.pytorch.core.hooks.CheckpointHooks
- on_load_checkpoint
- on_save_checkpoint
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- ipu
- xpu
- bfloat16
- to_empty
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile