shimmer.modules.domain

  1from dataclasses import dataclass, field
  2from typing import Any
  3
  4import lightning.pytorch as pl
  5import torch
  6
  7
  8@dataclass
  9class LossOutput:
 10    """
 11    This is a python dataclass use as a returned value for losses.
 12    It keeps track of what is used for training (`loss`) and what is used
 13    only for logging (`metrics`).
 14    """
 15
 16    loss: torch.Tensor
 17    """Loss used during training."""
 18
 19    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
 20    """Some additional metrics to log (not used during training)."""
 21
 22    def __post_init__(self):
 23        if "loss" in self.metrics:
 24            raise ValueError("'loss' cannot be a key of metrics.")
 25
 26    @property
 27    def all(self) -> dict[str, torch.Tensor]:
 28        """
 29        Returns a dict with all metrics and loss with "loss" key.
 30        """
 31        return {**self.metrics, "loss": self.loss}
 32
 33
 34class DomainModule(pl.LightningModule):
 35    """
 36    Base class for a DomainModule that defines domain specific modules of the GW.
 37    """
 38
 39    def __init__(
 40        self,
 41        latent_dim: int,
 42    ) -> None:
 43        """
 44        Initializes a DomainModule.
 45
 46        Args:
 47            latent_dim (`int`): latent dimension of the unimodal module
 48        """
 49        super().__init__()
 50
 51        self.latent_dim = latent_dim
 52        """The latent dimension of the module."""
 53
 54    def encode(self, x: Any) -> torch.Tensor:
 55        """
 56        Encode the domain data into a unimodal representation.
 57
 58        Args:
 59            x (`Any`): data of the domain.
 60        Returns:
 61            `torch.Tensor`: a unimodal representation.
 62        """
 63        raise NotImplementedError
 64
 65    def decode(self, z: torch.Tensor) -> Any:
 66        """
 67        Decode data from unimodal representation back to the domain data.
 68
 69        Args:
 70            z (`torch.Tensor`): unimodal representation of the domain.
 71        Returns:
 72            `Any`: the original domain data.
 73        """
 74        raise NotImplementedError
 75
 76    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 77        """
 78        Generic loss computation  the modality.
 79
 80        Args:
 81            pred (`torch.Tensor`): prediction of the model
 82            target (`torch.Tensor`): target tensor
 83        Results:
 84            `LossOutput`: LossOuput with training loss and additional metrics.
 85        """
 86        raise NotImplementedError
 87
 88    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 89        """
 90        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
 91        different that the generic loss.
 92
 93        Args:
 94            pred (`torch.Tensor`): prediction of the model
 95            target (`torch.Tensor`): target tensor
 96        Results:
 97            `LossOutput`: LossOuput with training loss and additional metrics.
 98        """
 99        return self.compute_loss(pred, target)
100
101    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
102        """
103        Computes the loss for a cycle. Override if the cycle loss is
104        different that the generic loss.
105
106        Args:
107            pred (`torch.Tensor`): prediction of the model
108            target (`torch.Tensor`): target tensor
109        Results:
110            `LossOutput`: LossOuput with training loss and additional metrics.
111        """
112        return self.compute_loss(pred, target)
113
114    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
115        """
116        Computes the loss for a translation. Override if the translation loss is
117        different that the generic loss.
118
119        Args:
120            pred (`torch.Tensor`): prediction of the model
121            target (`torch.Tensor`): target tensor
122        Results:
123            `LossOutput`: LossOuput with training loss and additional metrics.
124        """
125        return self.compute_loss(pred, target)
126
127    def compute_broadcast_loss(
128        self, pred: torch.Tensor, target: torch.Tensor
129    ) -> LossOutput:
130        """
131        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
132        different that the generic loss.
133
134        Args:
135            pred (`torch.Tensor`): prediction of the model
136            target (`torch.Tensor`): target tensor
137        Results:
138            `LossOutput`: LossOuput with training loss and additional metrics.
139        """
140        return self.compute_loss(pred, target)
@dataclass
class LossOutput:
 9@dataclass
10class LossOutput:
11    """
12    This is a python dataclass use as a returned value for losses.
13    It keeps track of what is used for training (`loss`) and what is used
14    only for logging (`metrics`).
15    """
16
17    loss: torch.Tensor
18    """Loss used during training."""
19
20    metrics: dict[str, torch.Tensor] = field(default_factory=dict)
21    """Some additional metrics to log (not used during training)."""
22
23    def __post_init__(self):
24        if "loss" in self.metrics:
25            raise ValueError("'loss' cannot be a key of metrics.")
26
27    @property
28    def all(self) -> dict[str, torch.Tensor]:
29        """
30        Returns a dict with all metrics and loss with "loss" key.
31        """
32        return {**self.metrics, "loss": self.loss}

This is a Python dataclass used as the return value for losses. It keeps track of what is used for training (loss) and what is used only for logging (metrics).

LossOutput(loss: torch.Tensor, metrics: dict[str, torch.Tensor] = <factory>)
loss: torch.Tensor

Loss used during training.

metrics: dict[str, torch.Tensor]

Some additional metrics to log (not used during training).

all: dict[str, torch.Tensor]
27    @property
28    def all(self) -> dict[str, torch.Tensor]:
29        """
30        Returns a dict with all metrics and loss with "loss" key.
31        """
32        return {**self.metrics, "loss": self.loss}

Returns a dict with all metrics and loss with "loss" key.

class DomainModule(lightning.pytorch.core.module.LightningModule):
 35class DomainModule(pl.LightningModule):
 36    """
 37    Base class for a DomainModule that defines domain specific modules of the GW.
 38    """
 39
 40    def __init__(
 41        self,
 42        latent_dim: int,
 43    ) -> None:
 44        """
 45        Initializes a DomainModule.
 46
 47        Args:
 48            latent_dim (`int`): latent dimension of the unimodal module
 49        """
 50        super().__init__()
 51
 52        self.latent_dim = latent_dim
 53        """The latent dimension of the module."""
 54
 55    def encode(self, x: Any) -> torch.Tensor:
 56        """
 57        Encode the domain data into a unimodal representation.
 58
 59        Args:
 60            x (`Any`): data of the domain.
 61        Returns:
 62            `torch.Tensor`: a unimodal representation.
 63        """
 64        raise NotImplementedError
 65
 66    def decode(self, z: torch.Tensor) -> Any:
 67        """
 68        Decode data from unimodal representation back to the domain data.
 69
 70        Args:
 71            z (`torch.Tensor`): unimodal representation of the domain.
 72        Returns:
 73            `Any`: the original domain data.
 74        """
 75        raise NotImplementedError
 76
 77    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 78        """
 79        Generic loss computation  the modality.
 80
 81        Args:
 82            pred (`torch.Tensor`): prediction of the model
 83            target (`torch.Tensor`): target tensor
 84        Results:
 85            `LossOutput`: LossOuput with training loss and additional metrics.
 86        """
 87        raise NotImplementedError
 88
 89    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 90        """
 91        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
 92        different that the generic loss.
 93
 94        Args:
 95            pred (`torch.Tensor`): prediction of the model
 96            target (`torch.Tensor`): target tensor
 97        Results:
 98            `LossOutput`: LossOuput with training loss and additional metrics.
 99        """
100        return self.compute_loss(pred, target)
101
102    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
103        """
104        Computes the loss for a cycle. Override if the cycle loss is
105        different that the generic loss.
106
107        Args:
108            pred (`torch.Tensor`): prediction of the model
109            target (`torch.Tensor`): target tensor
110        Results:
111            `LossOutput`: LossOuput with training loss and additional metrics.
112        """
113        return self.compute_loss(pred, target)
114
115    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
116        """
117        Computes the loss for a translation. Override if the translation loss is
118        different that the generic loss.
119
120        Args:
121            pred (`torch.Tensor`): prediction of the model
122            target (`torch.Tensor`): target tensor
123        Results:
124            `LossOutput`: LossOuput with training loss and additional metrics.
125        """
126        return self.compute_loss(pred, target)
127
128    def compute_broadcast_loss(
129        self, pred: torch.Tensor, target: torch.Tensor
130    ) -> LossOutput:
131        """
132        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
133        different that the generic loss.
134
135        Args:
136            pred (`torch.Tensor`): prediction of the model
137            target (`torch.Tensor`): target tensor
138        Results:
139            `LossOutput`: LossOuput with training loss and additional metrics.
140        """
141        return self.compute_loss(pred, target)

Base class for a DomainModule that defines domain specific modules of the GW.

DomainModule(latent_dim: int)
40    def __init__(
41        self,
42        latent_dim: int,
43    ) -> None:
44        """
45        Initializes a DomainModule.
46
47        Args:
48            latent_dim (`int`): latent dimension of the unimodal module
49        """
50        super().__init__()
51
52        self.latent_dim = latent_dim
53        """The latent dimension of the module."""

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
latent_dim

The latent dimension of the module.

def encode(self, x: Any) -> torch.Tensor:
55    def encode(self, x: Any) -> torch.Tensor:
56        """
57        Encode the domain data into a unimodal representation.
58
59        Args:
60            x (`Any`): data of the domain.
61        Returns:
62            `torch.Tensor`: a unimodal representation.
63        """
64        raise NotImplementedError

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> Any:
66    def decode(self, z: torch.Tensor) -> Any:
67        """
68        Decode data from unimodal representation back to the domain data.
69
70        Args:
71            z (`torch.Tensor`): unimodal representation of the domain.
72        Returns:
73            `Any`: the original domain data.
74        """
75        raise NotImplementedError

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.

def compute_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
77    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
78        """
79        Generic loss computation  the modality.
80
81        Args:
82            pred (`torch.Tensor`): prediction of the model
83            target (`torch.Tensor`): target tensor
84        Results:
85            `LossOutput`: LossOuput with training loss and additional metrics.
86        """
87        raise NotImplementedError

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
Returns:

LossOutput: LossOutput with training loss and additional metrics.

def compute_dcy_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 89    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
 90        """
 91        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
 92        different that the generic loss.
 93
 94        Args:
 95            pred (`torch.Tensor`): prediction of the model
 96            target (`torch.Tensor`): target tensor
 97        Results:
 98            `LossOutput`: LossOuput with training loss and additional metrics.
 99        """
100        return self.compute_loss(pred, target)

Computes the loss for a demi-cycle. Override if the demi-cycle loss is different from the generic loss.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
Returns:

LossOutput: LossOutput with training loss and additional metrics.

def compute_cy_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
102    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
103        """
104        Computes the loss for a cycle. Override if the cycle loss is
105        different that the generic loss.
106
107        Args:
108            pred (`torch.Tensor`): prediction of the model
109            target (`torch.Tensor`): target tensor
110        Results:
111            `LossOutput`: LossOuput with training loss and additional metrics.
112        """
113        return self.compute_loss(pred, target)

Computes the loss for a cycle. Override if the cycle loss is different from the generic loss.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
Returns:

LossOutput: LossOutput with training loss and additional metrics.

def compute_tr_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
115    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
116        """
117        Computes the loss for a translation. Override if the translation loss is
118        different that the generic loss.
119
120        Args:
121            pred (`torch.Tensor`): prediction of the model
122            target (`torch.Tensor`): target tensor
123        Results:
124            `LossOutput`: LossOuput with training loss and additional metrics.
125        """
126        return self.compute_loss(pred, target)

Computes the loss for a translation. Override if the translation loss is different from the generic loss.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
Returns:

LossOutput: LossOutput with training loss and additional metrics.

def compute_broadcast_loss( self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
128    def compute_broadcast_loss(
129        self, pred: torch.Tensor, target: torch.Tensor
130    ) -> LossOutput:
131        """
132        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
133        different that the generic loss.
134
135        Args:
136            pred (`torch.Tensor`): prediction of the model
137            target (`torch.Tensor`): target tensor
138        Results:
139            `LossOutput`: LossOuput with training loss and additional metrics.
140        """
141        return self.compute_loss(pred, target)

Computes the loss for a broadcast (fusion). Override if the broadcast loss is different from the generic loss.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
Returns:

LossOutput: LossOutput with training loss and additional metrics.

Inherited Members
lightning.pytorch.core.module.LightningModule
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
optimizers
lr_schedulers
trainer
fabric
example_input_array
current_epoch
global_step
global_rank
local_rank
on_gpu
automatic_optimization
strict_loading
logger
loggers
print
log
log_dict
all_gather
forward
training_step
validation_step
test_step
predict_step
configure_callbacks
configure_optimizers
manual_backward
backward
toggle_optimizer
untoggle_optimizer
clip_gradients
configure_gradient_clipping
lr_scheduler_step
optimizer_step
optimizer_zero_grad
freeze
unfreeze
to_onnx
to_torchscript
load_from_checkpoint
lightning.fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
dtype
device
to
cuda
cpu
type
float
double
half
lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin
save_hyperparameters
hparams
hparams_initial
lightning.pytorch.core.hooks.ModelHooks
on_fit_start
on_fit_end
on_train_start
on_train_end
on_validation_start
on_validation_end
on_test_start
on_test_end
on_predict_start
on_predict_end
on_train_batch_start
on_train_batch_end
on_validation_batch_start
on_validation_batch_end
on_test_batch_start
on_test_batch_end
on_predict_batch_start
on_predict_batch_end
on_validation_model_zero_grad
on_validation_model_eval
on_validation_model_train
on_test_model_eval
on_test_model_train
on_predict_model_eval
on_train_epoch_start
on_train_epoch_end
on_validation_epoch_start
on_validation_epoch_end
on_test_epoch_start
on_test_epoch_end
on_predict_epoch_start
on_predict_epoch_end
on_before_zero_grad
on_before_backward
on_after_backward
on_before_optimizer_step
configure_sharded_model
configure_model
lightning.pytorch.core.hooks.DataHooks
prepare_data_per_node
allow_zero_length_dataloader_with_multiple_devices
prepare_data
setup
teardown
train_dataloader
test_dataloader
val_dataloader
predict_dataloader
transfer_batch_to_device
on_before_batch_transfer
on_after_batch_transfer
lightning.pytorch.core.hooks.CheckpointHooks
on_load_checkpoint
on_save_checkpoint
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
ipu
xpu
bfloat16
to_empty
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile