shimmer_ssd.modules.domains.visual
```python
from collections.abc import Mapping
from typing import Any

import torch
from shimmer import LossOutput
from shimmer.modules.domain import DomainModule
from shimmer.modules.vae import VAE, gaussian_nll, kl_divergence_loss
from torch.nn.functional import mse_loss
from torch.optim.lr_scheduler import OneCycleLR

from shimmer_ssd import LOGGER
from shimmer_ssd.modules.vae import RAEDecoder, RAEEncoder


class VisualDomainModule(DomainModule):
    def __init__(
        self,
        num_channels: int,
        latent_dim: int,
        ae_dim: int,
        beta: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: Mapping[str, Any] | None = None,
    ):
        """
        Visual domain module. This defines shimmer's `DomainModule` for the vision
        side with a VAE.

        Args:
            num_channels (`int`): number of input channels (for RGB images, use 3)
            latent_dim (`int`): latent dimension of the vision domain
            ae_dim (`int`): internal auto-encoder dimension of the VAE
            beta (`float`): beta value for the beta-VAE (defaults to 1.0)
            optim_lr (`float`): training learning rate
            optim_weight_decay (`float`): training weight decay
            scheduler_args (`Mapping[str, Any] | None`): args for the scheduler
        """

        super().__init__(latent_dim)
        self.save_hyperparameters()

        vae_encoder = RAEEncoder(num_channels, ae_dim, latent_dim, use_batchnorm=True)
        vae_decoder = RAEDecoder(num_channels, latent_dim, ae_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)
        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay
        # Defaults are placeholders: pass the real `total_steps` via `scheduler_args`.
        self.scheduler_args: dict[str, Any] = {
            "max_lr": optim_lr,
            "total_steps": 1,
        }
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        return LossOutput(mse_loss(pred, target, reduction="mean"))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.vae.decode(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: torch.Tensor,
        mode: str = "train",
    ) -> torch.Tensor:
        # Standard beta-VAE objective: reconstruction NLL + beta * KL.
        (mean, logvar), reconstruction = self.vae(x)

        reconstruction_loss = gaussian_nll(reconstruction, torch.tensor(0), x).sum()

        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def validation_step(  # type: ignore
        self,
        batch: Mapping[str, torch.Tensor],
        batch_idx: int,
    ) -> torch.Tensor:
        x = batch["v"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
        batch_idx: int,
    ) -> torch.Tensor:
        x = batch[frozenset(["v"])]["v"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }


class VisualLatentDomainModule(DomainModule):
    def __init__(self, visual_module: VisualDomainModule):
        super().__init__(visual_module.latent_dim)
        self.visual_module = visual_module

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        return LossOutput(mse_loss(pred, target, reduction="mean"))

    def decode_images(self, z: torch.Tensor) -> torch.Tensor:
        LOGGER.debug(f"VisualLatentDomainModule.decode_images: z.shape = {z.size()}")
        return self.visual_module.decode(z)


class VisualLatentDomainWithUnpairedModule(DomainModule):
    def __init__(self, visual_module: VisualDomainModule, coef_unpaired: float = 0.5):
        # One extra "unpaired" dimension on top of the paired VAE latent.
        super().__init__(visual_module.latent_dim + 1)

        if coef_unpaired < 0 or coef_unpaired > 1:
            raise ValueError("coef_unpaired should be in [0, 1]")

        self.visual_module = visual_module
        self.paired_dim = self.visual_module.latent_dim
        self.coef_unpaired = coef_unpaired

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        paired_loss = mse_loss(pred[:, : self.paired_dim], target[:, : self.paired_dim])
        unpaired_loss = mse_loss(
            pred[:, self.paired_dim :], target[:, self.paired_dim :]
        )
        total_loss = (
            self.coef_unpaired * unpaired_loss + (1 - self.coef_unpaired) * paired_loss
        )
        return LossOutput(
            loss=total_loss,
            metrics={
                "unpaired": unpaired_loss,
                "paired": paired_loss,
            },
        )

    def decode_images(self, z: torch.Tensor) -> torch.Tensor:
        # Drop the unpaired dimension before decoding to image space.
        LOGGER.debug(
            f"VisualLatentDomainWithUnpairedModule.decode_images: z.shape = {z.size()}"
        )
        return self.visual_module.decode(z[:, :-1])
```
class VisualDomainModule(DomainModule)

Visual domain module. This defines shimmer's `DomainModule` for the vision side with a VAE. (`DomainModule` is shimmer's base class for the domain-specific modules of the GW.)
__init__(num_channels, latent_dim, ae_dim, beta=1, optim_lr=1e-3, optim_weight_decay=0, scheduler_args=None)
Builds the VAE from an `RAEEncoder`/`RAEDecoder` pair and stores the optimizer settings.

Arguments:
- num_channels (`int`): number of input channels (for RGB images, use 3)
- latent_dim (`int`): latent dimension of the vision domain
- ae_dim (`int`): internal auto-encoder dimension of the VAE
- beta (`float`): beta value for the beta-VAE (defaults to 1.0)
- optim_lr (`float`): training learning rate
- optim_weight_decay (`float`): training weight decay
- scheduler_args (`Mapping[str, Any] | None`): arguments for the `OneCycleLR` scheduler, merged into the defaults `{"max_lr": optim_lr, "total_steps": 1}`
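A minimal construction sketch; the hyperparameter values below are illustrative, not the project's defaults:

```python
from shimmer_ssd.modules.domains.visual import VisualDomainModule

module = VisualDomainModule(
    num_channels=3,            # RGB input
    latent_dim=12,             # hypothetical latent size
    ae_dim=256,                # hypothetical auto-encoder width
    beta=0.5,                  # weight on the KL term (beta-VAE)
    optim_lr=3e-4,
    optim_weight_decay=1e-6,
    scheduler_args={"total_steps": 10_000},  # forwarded to OneCycleLR
)
```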
compute_loss(pred, target, raw_target)
Generic loss computation for the modality.

Arguments:
- pred (`torch.Tensor`): prediction of the model
- target (`torch.Tensor`): target tensor
- raw_target (`Any`): raw data from the input

Returns:
- `LossOutput | None`: `LossOutput` with the training loss and additional metrics. If `None` is returned, this loss is ignored and does not participate in the total loss.
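For illustration, this is what the returned value looks like with toy tensors (shapes are arbitrary):

```python
import torch
from torch.nn.functional import mse_loss
from shimmer import LossOutput

pred = torch.randn(8, 12)    # batch of predicted latent vectors
target = torch.randn(8, 12)  # batch of target latent vectors

out = LossOutput(mse_loss(pred, target, reduction="mean"))
print(out.loss)  # scalar mean-squared error used in the GW's total loss
```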
encode(x)

Encode the domain data into a unimodal representation.

Arguments:
- x (`Any`): data of the domain.

Returns:
- `torch.Tensor`: a unimodal representation.
decode(z)

Decode data from the unimodal representation back to the domain data.

Arguments:
- z (`torch.Tensor`): unimodal representation of the domain.

Returns:
- `Any`: the original domain data.
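An encode/decode round trip, reusing `module` from the construction sketch above (the image size is an assumption; it depends on the RAE architecture and dataset):

```python
import torch

imgs = torch.randn(4, 3, 32, 32)  # hypothetical 32x32 RGB batch
z = module.encode(imgs)           # (4, latent_dim) unimodal representation
recon = module.decode(z)          # images reconstructed from the latents
```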
forward(x)
Same as `torch.nn.Module.forward()`. Here it is equivalent to `decode(encode(x))`: the batch is passed through the VAE and reconstructed.

Arguments:
- *args: whatever you decide to pass into the forward method.
- **kwargs: keyword arguments are also possible.

Return: your model's output.
generic_step(x, mode="train")

Shared training/validation step: runs the VAE on the batch, computes the Gaussian reconstruction negative log-likelihood and the KL divergence, logs `{mode}/reconstruction_loss`, `{mode}/kl_loss` and `{mode}/loss`, and returns the total loss.
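For reference, the quantity generic_step computes is the usual beta-VAE objective. A sketch in math form, under the assumption that `gaussian_nll` with a zero second argument amounts to a fixed-variance Gaussian reconstruction term:

```latex
\mathcal{L}(x) = -\log p\!\left(x \mid \hat{x}\right)
  + \beta \, D_{\mathrm{KL}}\!\left(
      \mathcal{N}\!\left(\mu(x), \operatorname{diag} e^{\mathrm{logvar}(x)}\right)
      \,\middle\|\, \mathcal{N}(0, I)
    \right)
```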
validation_step(batch, batch_idx)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Arguments:
- batch: The output of your data iterable, normally a `~torch.utils.data.DataLoader`.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders are used).

Return:
- `~torch.Tensor`: the loss tensor
- `dict`: a dictionary; can include any keys, but must include the key `'loss'`
- `None`: skip to the next batch

```python
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
```

Examples:

```python
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
```

If you pass in multiple val dataloaders, `validation_step()` will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

```python
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
```

Note: If you don't need to validate you don't need to implement this method.

Note: When `validation_step()` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
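In this subclass, the validation batch is a plain mapping from domain name to tensor, and the images are read from the `"v"` key. A minimal sketch of the expected structure (the image size is an assumption; it depends on the dataset):

```python
import torch

# Hypothetical 32x32 RGB batch keyed by the visual domain name "v".
val_batch = {"v": torch.randn(4, 3, 32, 32)}
x = val_batch["v"]  # exactly what validation_step extracts before generic_step
```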
training_step(batch, batch_idx)
Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Arguments:
- batch: The output of your data iterable, normally a `~torch.utils.data.DataLoader`.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders are used).

Return:
- `~torch.Tensor`: the loss tensor
- `dict`: a dictionary which can include any keys, but must include the key `'loss'` in the case of automatic optimization
- `None`: in automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

```python
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
```

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

```python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
```

Note: When `accumulate_grad_batches` > 1, the loss returned here will be automatically normalized by `accumulate_grad_batches` internally.
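In this subclass, training batches follow shimmer's domain-grouped format: the outer key is a frozenset of domain names, the inner key the domain name itself. A sketch of the minimal unimodal case (again with an assumed image size):

```python
import torch

# One domain group containing only the visual domain "v".
train_batch = {frozenset(["v"]): {"v": torch.randn(4, 3, 32, 32)}}
x = train_batch[frozenset(["v"])]["v"]  # what training_step extracts
```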
configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one, but in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return: any of these options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists: the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple `lr_scheduler_config`).
- Dictionary, with an `"optimizer"` key, and (optionally) a `"lr_scheduler"` key whose value is a single LR scheduler or `lr_scheduler_config`.
- None: fit will run without any optimizer.

The `lr_scheduler_config` is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

```python
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
```

When there are schedulers in which the `.step()` method is conditioned on a value, such as the `torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the `lr_scheduler_config` contains the keyword `"monitor"` set to the metric name that the scheduler should be conditioned on.

```python
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
```

Metrics can be made available to monitor by simply logging them using `self.log('metric_to_track', metric_val)` in your `~lightning.pytorch.core.LightningModule`.

Note: Some things to know:
- Lightning calls `.backward()` and `.step()` automatically in case of automatic optimization.
- If a learning rate scheduler is specified in `configure_optimizers()` with key `"interval"` (default "epoch") in the scheduler configuration, Lightning will call the scheduler's `.step()` method automatically in case of automatic optimization.
- If you use 16-bit precision (`precision=16`), Lightning will automatically handle the optimizer.
- If you use `torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the `optimizer_step()` hook.
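Because this module hard-wires `OneCycleLR` with interval `"step"`, the `scheduler_args` passed at construction should normally carry the real `total_steps`; the default of 1 is only a placeholder. A sketch of one common way to size it (numbers are illustrative):

```python
epochs = 100
steps_per_epoch = 500  # e.g. len(train_dataloader)

module = VisualDomainModule(
    num_channels=3,
    latent_dim=12,
    ae_dim=256,
    scheduler_args={
        "max_lr": 3e-4,                           # overrides the optim_lr default
        "total_steps": epochs * steps_per_epoch,  # OneCycleLR needs the full count
    },
)
```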
class VisualLatentDomainModule(DomainModule)
Domain module that works directly on pre-computed visual latent vectors: `encode` and `decode` are identities, and `decode_images` maps latents back to images through the wrapped `VisualDomainModule`.
__init__(visual_module)
Initializes the domain module around an existing `VisualDomainModule`; the latent dimension is taken from `visual_module.latent_dim`.

Arguments:
- visual_module (`VisualDomainModule`): the underlying visual module whose latent space this domain exposes.
encode(x); decode(z)

Encode the domain data into a unimodal representation, and decode it back. Both are the identity here: the domain data already is a latent vector of the wrapped VAE.
compute_loss(pred, target, raw_target)
Generic loss computation for the modality (see `compute_loss` above): the mean-squared error between prediction and target, wrapped in a `LossOutput`.
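A sketch of the intended wrapping: load a previously trained `VisualDomainModule` (the checkpoint path here is hypothetical) and expose its latent space as a domain of its own:

```python
import torch

visual = VisualDomainModule.load_from_checkpoint("checkpoints/visual.ckpt")  # hypothetical path
latent_domain = VisualLatentDomainModule(visual)

z = torch.randn(4, visual.latent_dim)  # pre-computed VAE latents
imgs = latent_domain.decode_images(z)  # back to image space via the wrapped VAE
```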
class VisualLatentDomainWithUnpairedModule(DomainModule)
Latent visual domain with one extra "unpaired" dimension appended to the paired VAE latent (so the domain's latent dimension is `visual_module.latent_dim + 1`). The loss balances the paired and unpaired parts with `coef_unpaired`.
__init__(visual_module, coef_unpaired=0.5)
Arguments:
- visual_module (`VisualDomainModule`): the underlying visual module.
- coef_unpaired (`float`): weight of the unpaired part of the loss, in [0, 1] (defaults to 0.5); a `ValueError` is raised otherwise.
encode(x); decode(z)

Encode the domain data into a unimodal representation, and decode it back. Both are the identity here as well: the domain data is the latent vector with its extra unpaired dimension.
compute_loss(pred, target, raw_target)
Generic loss computation for the modality.

Arguments:
- pred (`torch.Tensor`): prediction of the model
- target (`torch.Tensor`): target tensor
- raw_target (`Any`): raw data from the input

Returns:
- `LossOutput | None`: `LossOutput` with the training loss and additional metrics. If `None` is returned, this loss is ignored and does not participate in the total loss.

Here the prediction and target are split at `paired_dim`: the MSE over the unpaired tail is weighted by `coef_unpaired`, the MSE over the paired head by `1 - coef_unpaired`, and both parts are also reported in `metrics` under `"unpaired"` and `"paired"`.
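In math form, writing c for `coef_unpaired` and d for `paired_dim`, the prediction and target (both of width d + 1) are split into a paired head and an unpaired tail:

```latex
\mathcal{L} = c \cdot \mathrm{MSE}\!\left(\hat{z}_{:,\,d:},\; z_{:,\,d:}\right)
  + (1 - c) \cdot \mathrm{MSE}\!\left(\hat{z}_{:,\,:d},\; z_{:,\,:d}\right)
```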