shimmer_ssd.modules.domains.attribute
"""Attribute domain modules.

An attribute sample is a sequence of two tensors: a category vector with 3
entries and an attribute vector with 8 entries. Concatenated on the last
dimension they form the 11 input features expected by :class:`Encoder`.
"""

from collections.abc import Mapping, Sequence
from typing import Any

import torch
import torch.nn.functional as F
from shimmer import LossOutput
from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
)
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR


class Encoder(VAEEncoder):
    """VAE encoder for the attribute domain.

    Concatenates its inputs on the last dimension (11 features in total), runs
    them through a 3-layer ReLU MLP and projects the result to the mean and
    log-variance of the latent distribution.
    """

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Args:
            hidden_dim (`int`): width of the hidden layers.
            out_dim (`int`): latent dimension (size of the mean and log-variance).
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # 11 input features: 3 category entries + 8 attribute values.
        self.encoder = nn.Sequential(
            nn.Linear(11, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
        )

        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x (`Sequence[torch.Tensor]`): tensors concatenated on the last
                dimension (the category vector and the attribute vector).

        Returns:
            `tuple[torch.Tensor, torch.Tensor]`: the mean and the log-variance.
        """
        out = torch.cat(list(x), dim=-1)
        out = self.encoder(out)
        return self.q_mean(out), self.q_logvar(out)


class Decoder(VAEDecoder):
    """VAE decoder for the attribute domain.

    A shared 2-layer ReLU trunk feeds two heads: 3 category logits and 8
    attribute values squashed into (-1, 1) by a ``tanh``.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
    ):
        """
        Args:
            in_dim (`int`): latent dimension given as input.
            hidden_dim (`int`): width of the hidden layers.
        """
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        self.decoder = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )

        self.decoder_categories = nn.Sequential(
            nn.Linear(self.hidden_dim, 3),
        )

        self.decoder_attributes = nn.Sequential(
            nn.Linear(self.hidden_dim, 8),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Args:
            x (`torch.Tensor`): latent representation.

        Returns:
            `list[torch.Tensor]`: ``[category_logits, attributes]``.
        """
        out = self.decoder(x)
        return [self.decoder_categories(out), self.decoder_attributes(out)]


class AttributeDomainModule(DomainModule):
    """Attribute domain whose representation is produced by a beta-VAE."""

    in_dim = 11

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        coef_categories: float = 1,
        coef_attributes: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ):
        """
        Defines the Attribute domain module.

        Args:
            latent_dim (`int`): the latent dimension of the module
            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
            beta (`float`): weight of the KL term (beta-VAE)
            coef_categories (`float`): loss coefficient attributed to the category
                (Defaults to 1.0)
            coef_attributes (`float`): loss coefficient attributed to the rest of the
                attributes (Defaults to 1.0)
            optim_lr (`float`): learning rate for the optimizer
            optim_weight_decay (`float`): weight decay for the optimizer
            scheduler_args (`SchedulerArgs | None`): Scheduler arguments, merged
                over the defaults ``max_lr=optim_lr, total_steps=1``
        """
        super().__init__(latent_dim)
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.coef_categories = coef_categories
        self.coef_attributes = coef_attributes

        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """Mean-squared error between predicted and target representations.

        ``raw_target`` is unused.
        """
        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """
        Encode an attribute sample with the VAE.

        x must contain 2 items:
        - the class
        - the attributes
        """
        assert (
            len(x) == 2
        ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """
        Decode a representation with the VAE.

        Args:
            z (`torch.Tensor`): latent representation.

        Returns:
            `list[torch.Tensor]`: the VAE decoder outputs, as a list.

        Raises:
            ValueError: if the VAE decoder does not return a sequence.
        """
        decoded = self.vae.decode(z)
        # Validate the raw output: once converted with `list(...)` the value is
        # always a list, so the check has to happen before the conversion.
        if not isinstance(decoded, Sequence):
            raise ValueError("The output of vae.decode should be a sequence.")
        return list(decoded)

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Reconstruct ``x`` by encoding then decoding it."""
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Sequence[torch.Tensor],
        mode: str = "train",
    ) -> torch.Tensor:
        """
        Compute, log and return the VAE loss for one batch.

        Args:
            x (`Sequence[torch.Tensor]`): the category tensor and the attribute tensor.
            mode (`str`): prefix of the logged metric names (e.g. "train", "val").

        Returns:
            `torch.Tensor`: weighted reconstruction loss + beta * KL divergence.
        """
        x_categories, x_attributes = x[0], x[1]

        (mean, logvar), reconstruction = self.vae(x)
        reconstruction_categories = reconstruction[0]
        reconstruction_attributes = reconstruction[1]

        # Category reconstruction: cross-entropy against the argmax of the input.
        reconstruction_loss_categories = F.cross_entropy(
            reconstruction_categories,
            x_categories.argmax(dim=1),
            reduction="sum",
        )
        # Second argument held constant at 0 (presumably a unit-variance
        # log-sigma in shimmer's gaussian_nll — TODO confirm).
        reconstruction_loss_attributes = gaussian_nll(
            reconstruction_attributes, torch.tensor(0), x_attributes
        ).sum()

        reconstruction_loss = (
            self.coef_categories * reconstruction_loss_categories
            + self.coef_attributes * reconstruction_loss_attributes
        )
        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        self.log(
            f"{mode}/reconstruction_loss_categories",
            reconstruction_loss_categories,
        )
        self.log(
            f"{mode}/reconstruction_loss_attributes",
            reconstruction_loss_attributes,
        )
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Sequence[torch.Tensor]], _
    ) -> torch.Tensor:
        """Validation loss on the "attr" entry of the batch."""
        x = batch["attr"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        """Training loss on the attribute-only group of the batch."""
        x = batch[frozenset(["attr"])]["attr"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        """AdamW with a OneCycleLR schedule stepped after every optimizer step."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }


class AttributeWithUnpairedDomainModule(DomainModule):
    """Attribute domain with ``n_unpaired`` extra, unpaired dimensions.

    The first ``latent_dim`` dimensions of the representation come from the
    attribute VAE; the remaining ``n_unpaired`` dimensions are the unpaired
    value, appended by :meth:`encode` and split off unchanged by :meth:`decode`.
    """

    in_dim = 11

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        coef_categories: float = 1,
        coef_attributes: float = 1,
        n_unpaired: int = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
        coef_unpaired: float = 0.5,
    ):
        """
        Args:
            latent_dim (`int`): latent dimension of the attribute VAE.
            hidden_dim (`int`): hidden dimension of the VAE encoder and decoder.
            beta (`float`): beta passed to the VAE.
            coef_categories (`float`): category loss coefficient, in [0, 1].
            coef_attributes (`float`): attribute loss coefficient, in [0, 1].
            n_unpaired (`int`): number of unpaired dimensions appended to the
                VAE latent.
            optim_lr (`float`): learning rate for the optimizer.
            optim_weight_decay (`float`): weight decay for the optimizer.
            scheduler_args (`SchedulerArgs | None`): scheduler arguments.
            coef_unpaired (`float`): unpaired loss coefficient, in [0, 1].

        Raises:
            ValueError: if a coefficient lies outside [0, 1].
        """
        # The domain representation is the VAE latent plus the unpaired dims.
        super().__init__(latent_dim + n_unpaired)

        if coef_categories < 0 or coef_categories > 1:
            raise ValueError("coef_categories should be in [0, 1]")
        if coef_attributes < 0 or coef_attributes > 1:
            raise ValueError("coef_attributes should be in [0, 1]")
        if coef_unpaired < 0 or coef_unpaired > 1:
            raise ValueError("coef_unpaired should be in [0, 1]")

        self.save_hyperparameters()
        self.paired_dim = latent_dim
        self.n_unpaired = n_unpaired
        self.hidden_dim = hidden_dim
        # NOTE(review): the three coefficients below are validated and stored
        # but not read by any method of this class; compute_loss sums the
        # paired and unpaired losses unweighted. Confirm whether coef_unpaired
        # was meant to weight compute_loss.
        self.coef_categories = coef_categories
        self.coef_attributes = coef_attributes
        self.coef_unpaired = coef_unpaired

        vae_encoder = Encoder(self.hidden_dim, self.latent_dim - self.n_unpaired)
        vae_decoder = Decoder(self.latent_dim - self.n_unpaired, self.hidden_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """
        x must contain 3 items:
        - the class
        - the attributes
        - the unpaired value

        The VAE encodes the first two; the unpaired value is concatenated
        to the result on the last dimension.
        """
        assert len(x) == 3, (
            "x must have the unpaired value "
            "(use `attr` instead of `attr_unpaired` otherwise)."
        )
        z = self.vae.encode(x[:-1])
        return torch.cat([z, x[-1]], dim=-1)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """
        Split ``z`` on its last dimension into the VAE part and the unpaired part.

        Returns:
            `list[torch.Tensor]`: the VAE decoder outputs followed by the
            unpaired value (returned unchanged).
        """
        # Slice the last dimension, mirroring the `dim=-1` concat in `encode`.
        paired = z[..., : self.paired_dim]
        unpaired = z[..., self.paired_dim :]
        out = list(self.vae.decode(paired))
        out.append(unpaired)
        return out

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Reconstruct ``x`` by encoding then decoding it."""
        return self.decode(self.encode(x))

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """MSE on the paired part plus MSE on the unpaired part.

        ``raw_target`` is unused. Both terms are also reported as metrics.
        """
        paired_loss = F.mse_loss(
            pred[..., : self.paired_dim], target[..., : self.paired_dim]
        )
        unpaired_loss = F.mse_loss(
            pred[..., self.paired_dim :], target[..., self.paired_dim :]
        )
        total_loss = unpaired_loss + paired_loss
        return LossOutput(
            loss=total_loss,
            metrics={
                "unpaired": unpaired_loss,
                "paired": paired_loss,
            },
        )


class AttributeLegacyDomainModule(DomainModule):
    """Attribute domain without a learned encoder.

    The representation is the plain concatenation of the 3 category entries
    and the 8 attribute values (11 dimensions).
    """

    latent_dim = 11

    def __init__(self):
        super().__init__(self.latent_dim)
        self.save_hyperparameters()

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """MSE on the attributes plus NLL on the categories.

        ``raw_target`` is unused.
        """
        # `decode` returns exactly two tensors: categories and attributes.
        pred_cat, pred_attr = self.decode(pred)
        target_cat, target_attr = self.decode(target)

        loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
        # NOTE(review): F.nll_loss expects log-probabilities, but pred_cat is a
        # raw slice of the representation; confirm whether a log_softmax (or
        # F.cross_entropy) was intended.
        loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
        loss = loss_attr + loss_cat

        return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """Concatenate the category and attribute tensors on the last dimension."""
        assert len(x) == 2, "This must have only 2 items."
        return torch.cat(list(x), dim=-1)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """Split ``z`` into ``[categories (first 3 columns), attributes (3:11)]``."""
        categories = z[:, :3]
        attr = z[:, 3:11]
        return [categories, attr]

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Round-trip ``x`` through `encode` and `decode`."""
        return self.decode(self.encode(x))
class Encoder(VAEEncoder):
    """Attribute-domain VAE encoder producing a posterior mean and log-variance."""

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # 11 input features: the category vector concatenated with the attributes.
        widths = [11, hidden_dim, hidden_dim, out_dim]
        layers: list[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.encoder = nn.Sequential(*layers)

        self.q_mean = nn.Linear(out_dim, out_dim)
        self.q_logvar = nn.Linear(out_dim, out_dim)

    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, logvar)`` for the last-dim concatenation of ``x``."""
        hidden = self.encoder(torch.cat([*x], dim=-1))
        return self.q_mean(hidden), self.q_logvar(hidden)
Base class for a VAE encoder.
def __init__(
    self,
    hidden_dim: int,
    out_dim: int,
):
    """Build the encoder MLP and the mean / log-variance heads.

    Args:
        hidden_dim (`int`): width of the two hidden layers.
        out_dim (`int`): latent size; also the width of both output heads.
    """
    super().__init__()

    self.hidden_dim = hidden_dim
    self.out_dim = out_dim

    # Input width 11: category vector + attribute vector, concatenated.
    widths = [11, hidden_dim, hidden_dim, out_dim]
    layers: list[nn.Module] = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    self.encoder = nn.Sequential(*layers)

    self.q_mean = nn.Linear(out_dim, out_dim)
    self.q_logvar = nn.Linear(out_dim, out_dim)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode the concatenated inputs into posterior parameters.

    Args:
        x (`Sequence[torch.Tensor]`): tensors concatenated along the last dimension.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: the mean and the log-variance.
    """
    hidden = self.encoder(torch.cat([*x], dim=-1))
    mean = self.q_mean(hidden)
    logvar = self.q_logvar(hidden)
    return mean, logvar
Encode representation with VAE.
Arguments:
- x (`Any`): Some input value
Returns:
`tuple[torch.Tensor, torch.Tensor]`: the mean and log variance
class Decoder(VAEDecoder):
    """Attribute-domain VAE decoder with a category head and an attribute head."""

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim

        # Shared trunk: two ReLU-activated linear layers.
        self.decoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # 3 unnormalized category scores.
        self.decoder_categories = nn.Sequential(nn.Linear(hidden_dim, 3))
        # 8 attribute values squashed into (-1, 1).
        self.decoder_attributes = nn.Sequential(nn.Linear(hidden_dim, 8), nn.Tanh())

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Decode ``x`` into ``[category_scores, attributes]``."""
        shared = self.decoder(x)
        heads = (self.decoder_categories, self.decoder_attributes)
        return [head(shared) for head in heads]
Base class for a VAE decoder.
def __init__(
    self,
    in_dim: int,
    hidden_dim: int,
):
    """Build the shared trunk and the two output heads.

    Args:
        in_dim (`int`): size of the latent input.
        hidden_dim (`int`): width of the trunk layers.
    """
    super().__init__()

    self.in_dim = in_dim
    self.hidden_dim = hidden_dim

    # Shared trunk: two ReLU-activated linear layers.
    self.decoder = nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
    )
    # 3 unnormalized category scores.
    self.decoder_categories = nn.Sequential(nn.Linear(hidden_dim, 3))
    # 8 attribute values squashed into (-1, 1).
    self.decoder_attributes = nn.Sequential(nn.Linear(hidden_dim, 8), nn.Tanh())
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    """Decode a latent vector.

    Args:
        x (`torch.Tensor`): VAE latent representation.

    Returns:
        `list[torch.Tensor]`: the category head output followed by the
        attribute head output, both computed from the shared trunk.
    """
    shared = self.decoder(x)
    heads = (self.decoder_categories, self.decoder_attributes)
    return [head(shared) for head in heads]
Decode representation with VAE
Arguments:
- x (`torch.Tensor`): VAE latent representation
Returns:
`Any`: the reconstructed input
class AttributeDomainModule(DomainModule):
    """Attribute domain whose representation is produced by a beta-VAE."""

    in_dim = 11

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        coef_categories: float = 1,
        coef_attributes: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ):
        """
        Defines the Attribute domain module.

        Args:
            latent_dim (`int`): the latent dimension of the module
            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
            beta (`float`): weight of the KL term (beta-VAE)
            coef_categories (`float`): loss coefficient attributed to the category
                (Defaults to 1.0)
            coef_attributes (`float`): loss coefficient attributed to the rest of the
                attributes (Defaults to 1.0)
            optim_lr (`float`): learning rate for the optimizer
            optim_weight_decay (`float`): weight decay for the optimizer
            scheduler_args (`SchedulerArgs | None`): Scheduler arguments, merged
                over the defaults ``max_lr=optim_lr, total_steps=1``
        """
        super().__init__(latent_dim)
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.coef_categories = coef_categories
        self.coef_attributes = coef_attributes

        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """Mean-squared error between predicted and target representations.

        ``raw_target`` is unused.
        """
        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """
        Encode an attribute sample with the VAE.

        x must contain 2 items:
        - the class
        - the attributes
        """
        assert (
            len(x) == 2
        ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """
        Decode a representation with the VAE.

        Args:
            z (`torch.Tensor`): latent representation.

        Returns:
            `list[torch.Tensor]`: the VAE decoder outputs, as a list.

        Raises:
            ValueError: if the VAE decoder does not return a sequence.
        """
        decoded = self.vae.decode(z)
        # Validate the raw output: once converted with `list(...)` the value is
        # always a list, so the check has to happen before the conversion.
        if not isinstance(decoded, Sequence):
            raise ValueError("The output of vae.decode should be a sequence.")
        return list(decoded)

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Reconstruct ``x`` by encoding then decoding it."""
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Sequence[torch.Tensor],
        mode: str = "train",
    ) -> torch.Tensor:
        """
        Compute, log and return the VAE loss for one batch.

        Args:
            x (`Sequence[torch.Tensor]`): the category tensor and the attribute tensor.
            mode (`str`): prefix of the logged metric names (e.g. "train", "val").

        Returns:
            `torch.Tensor`: weighted reconstruction loss + beta * KL divergence.
        """
        x_categories, x_attributes = x[0], x[1]

        (mean, logvar), reconstruction = self.vae(x)
        reconstruction_categories = reconstruction[0]
        reconstruction_attributes = reconstruction[1]

        # Category reconstruction: cross-entropy against the argmax of the input.
        reconstruction_loss_categories = F.cross_entropy(
            reconstruction_categories,
            x_categories.argmax(dim=1),
            reduction="sum",
        )
        # Second argument held constant at 0 (presumably a unit-variance
        # log-sigma in shimmer's gaussian_nll — TODO confirm).
        reconstruction_loss_attributes = gaussian_nll(
            reconstruction_attributes, torch.tensor(0), x_attributes
        ).sum()

        reconstruction_loss = (
            self.coef_categories * reconstruction_loss_categories
            + self.coef_attributes * reconstruction_loss_attributes
        )
        kl_loss = kl_divergence_loss(mean, logvar)
        total_loss = reconstruction_loss + self.vae.beta * kl_loss

        self.log(
            f"{mode}/reconstruction_loss_categories",
            reconstruction_loss_categories,
        )
        self.log(
            f"{mode}/reconstruction_loss_attributes",
            reconstruction_loss_attributes,
        )
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Sequence[torch.Tensor]], _
    ) -> torch.Tensor:
        """Validation loss on the "attr" entry of the batch."""
        x = batch["attr"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        """Training loss on the attribute-only group of the batch."""
        x = batch[frozenset(["attr"])]["attr"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        """AdamW with a OneCycleLR schedule stepped after every optimizer step."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
Base class for a DomainModule that defines domain specific modules of the GW.
def __init__(
    self,
    latent_dim: int,
    hidden_dim: int,
    beta: float = 1,
    coef_categories: float = 1,
    coef_attributes: float = 1,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0,
    scheduler_args: SchedulerArgs | None = None,
):
    """
    Build the attribute domain module around a beta-VAE.

    Args:
        latent_dim (`int`): size of the VAE latent space
        hidden_dim (`int`): width of the hidden layers of the VAE encoder/decoder
        beta (`float`): beta passed to the VAE
        coef_categories (`float`): weight of the category reconstruction loss
            (Defaults to 1.0)
        coef_attributes (`float`): weight of the attribute reconstruction loss
            (Defaults to 1.0)
        optim_lr (`float`): learning rate for the optimizer
        optim_weight_decay (`float`): weight decay for the optimizer
        scheduler_args (`SchedulerArgs | None`): overrides for the scheduler
            defaults ``max_lr=optim_lr, total_steps=1``
    """
    super().__init__(latent_dim)
    self.save_hyperparameters()

    self.hidden_dim = hidden_dim
    self.coef_categories = coef_categories
    self.coef_attributes = coef_attributes

    # Arguments are evaluated left to right: encoder first, then decoder.
    self.vae = VAE(
        Encoder(hidden_dim, self.latent_dim),
        Decoder(self.latent_dim, hidden_dim),
        beta,
    )

    self.optim_lr = optim_lr
    self.optim_weight_decay = optim_weight_decay

    # Start from the defaults, then apply any user-provided overrides.
    self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
    self.scheduler_args.update(scheduler_args or {})
Defines the Attribute domain module.
Arguments:
- latent_dim (`int`): the latent dimension of the module
- hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
- beta (`float`): weight of the KL term (beta-VAE)
- coef_categories (`float`): loss coefficient attributed to the category (Defaults to 1.0)
- coef_attributes (`float`): loss coefficient attributed to the rest of the attributes (Defaults to 1.0)
- optim_lr (`float`): learning rate for the optimizer
- optim_weight_decay (`float`): weight decay for the optimizer
- scheduler_args (`SchedulerArgs | None`): Scheduler arguments
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput:
    """Mean-squared error between predicted and target representations.

    ``raw_target`` is accepted for interface compatibility and ignored.
    """
    mse = F.mse_loss(pred, target, reduction="mean")
    return LossOutput(mse)
Generic loss computation for the modality.
Arguments:
- pred (`torch.Tensor`): prediction of the model
- target (`torch.Tensor`): target tensor
- raw_target (`Any`): raw data from the input
Returns:
`LossOutput | None`: LossOutput with training loss and additional metrics. If `None`
is returned, this loss will be ignored and will not participate in the total loss.
def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
    """
    Encode an attribute sample with the VAE.

    x must contain 2 items:
    - the class
    - the attributes
    """
    n_items = len(x)
    assert (
        n_items == 2
    ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
    return self.vae.encode(x)
x must contain 2 items:
- the class
- the attributes
def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
    """
    Decode a representation with the VAE.

    Args:
        z (`torch.Tensor`): latent representation.

    Returns:
        `list[torch.Tensor]`: the VAE decoder outputs, as a list.

    Raises:
        ValueError: if the VAE decoder does not return a sequence.
    """
    decoded = self.vae.decode(z)
    # Validate the raw output: once converted with `list(...)` the value is
    # always a list, so the check has to happen before the conversion.
    if not isinstance(decoded, Sequence):
        raise ValueError("The output of vae.decode should be a sequence.")
    return list(decoded)
Decode data from unimodal representation back to the domain data.
Arguments:
- z (`torch.Tensor`): unimodal representation of the domain.
Returns:
`Any`: the original domain data.
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
    """Reconstruct ``x``: encode it to a latent, then decode that latent."""
    latent = self.encode(x)
    return self.decode(latent)
Same as `torch.nn.Module.forward()`.
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
def generic_step(
    self,
    x: Sequence[torch.Tensor],
    mode: str = "train",
) -> torch.Tensor:
    """
    Compute, log and return the VAE loss for one batch.

    Args:
        x (`Sequence[torch.Tensor]`): the category tensor and the attribute tensor.
        mode (`str`): prefix of the logged metric names (e.g. "train", "val").

    Returns:
        `torch.Tensor`: weighted reconstruction loss + beta * KL divergence.
    """
    categories, attributes = x[0], x[1]

    (mean, logvar), reconstruction = self.vae(x)
    recon_categories, recon_attributes = reconstruction[0], reconstruction[1]

    loss_categories = F.cross_entropy(
        recon_categories, categories.argmax(dim=1), reduction="sum"
    )
    # Second argument held constant at 0 (presumably a unit-variance log-sigma
    # in shimmer's gaussian_nll — TODO confirm).
    loss_attributes = gaussian_nll(recon_attributes, torch.tensor(0), attributes).sum()

    loss_reconstruction = (
        self.coef_categories * loss_categories
        + self.coef_attributes * loss_attributes
    )
    loss_kl = kl_divergence_loss(mean, logvar)
    loss_total = loss_reconstruction + self.vae.beta * loss_kl

    # Logged in this exact order, one `self.log` call per metric.
    metrics = {
        f"{mode}/reconstruction_loss_categories": loss_categories,
        f"{mode}/reconstruction_loss_attributes": loss_attributes,
        f"{mode}/reconstruction_loss": loss_reconstruction,
        f"{mode}/kl_loss": loss_kl,
        f"{mode}/loss": loss_total,
    }
    for name, value in metrics.items():
        self.log(name, value)
    return loss_total
def validation_step(  # type: ignore
    self, batch: Mapping[str, Sequence[torch.Tensor]], _
) -> torch.Tensor:
    """Run `generic_step` in "val" mode on the "attr" entry of the batch."""
    return self.generic_step(batch["attr"], "val")
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
Arguments:
- batch: The output of your data iterable, normally a `~torch.utils.data.DataLoader`.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- `~torch.Tensor` - The loss tensor
- `dict` - A dictionary. Can include any keys, but must include the key `'loss'`.
- `None` - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples::
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
x, y = batch
# implement your own
out = self(x)
loss = self.loss(out, y)
# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# log the outputs!
self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step()
will have an additional argument. We recommend
setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
# dataloader_idx tells you which dataset this is.
...
Note:
If you don't need to validate you don't need to implement this method.
Note:
When `validation_step()` is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
def training_step(  # type: ignore
    self,
    batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
    _,
) -> torch.Tensor:
    """Run `generic_step` in "train" mode on the attribute-only group.

    The batch is keyed by sets of domain names; this reads the
    ``frozenset({"attr"})`` group and, inside it, the "attr" entry.
    """
    attr_only = batch[frozenset({"attr"})]
    return self.generic_step(attr_only["attr"], "train")
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
Arguments:
- batch: The output of your data iterable, normally a `~torch.utils.data.DataLoader`.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- `~torch.Tensor` - The loss tensor
- `dict` - A dictionary which can include any keys, but must include the key `'loss'` in the case of automatic optimization.
- `None` - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example::
def training_step(self, batch, batch_idx):
x, y, z = batch
out = self.encoder(x)
loss = self.loss(out, x)
return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
def __init__(self):
super().__init__()
self.automatic_optimization = False
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
opt1, opt2 = self.optimizers()
# do training_step with encoder
...
opt1.step()
# do training_step with decoder
...
opt2.step()
Note:
When `accumulate_grad_batches` > 1, the loss returned here will be automatically normalized by `accumulate_grad_batches` internally.
def configure_optimizers(  # type: ignore
    self,
) -> dict[str, Any]:
    """AdamW with a OneCycleLR schedule stepped after every optimizer step."""
    optimizer = torch.optim.AdamW(
        self.parameters(),
        lr=self.optim_lr,
        weight_decay=self.optim_weight_decay,
    )
    scheduler_config = {
        "scheduler": OneCycleLR(optimizer, **self.scheduler_args),
        "interval": "step",
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple `lr_scheduler_config`).
- Dictionary, with an `"optimizer"` key, and (optionally) a `"lr_scheduler"` key whose value is a single LR scheduler or `lr_scheduler_config`.
- None - Fit will run without any optimizer.
The lr_scheduler_config
is a dictionary which contains the scheduler and its associated configuration.
The default configuration is shown below.
lr_scheduler_config = {
# REQUIRED: The scheduler instance
"scheduler": lr_scheduler,
# The unit of the scheduler's step size, could also be 'step'.
# 'epoch' updates the scheduler on epoch end whereas 'step'
# updates it after an optimizer update.
"interval": "epoch",
# How many epochs/steps should pass between calls to
# `scheduler.step()`. 1 corresponds to updating the learning
# rate after every epoch/step.
"frequency": 1,
# Metric to monitor for schedulers like `ReduceLROnPlateau`
"monitor": "val_loss",
# If set to `True`, will enforce that the value specified 'monitor'
# is available when the scheduler is updated, thus stopping
# training if not found. If set to `False`, it will only produce a warning
"strict": True,
# If using the `LearningRateMonitor` callback to monitor the
# learning rate progress, this keyword can be used to specify
# a custom logged name
"name": None,
}
When there are schedulers in which the .step()
method is conditioned on a value, such as the
torch.optim.lr_scheduler.ReduceLROnPlateau
scheduler, Lightning requires that the
lr_scheduler_config
contains the keyword "monitor"
set to the metric name that the scheduler
should be conditioned on.
.. testcode::
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
optimizer = Adam(...)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": ReduceLROnPlateau(optimizer, ...),
"monitor": "metric_to_track",
"frequency": "indicates how often the metric is updated",
# If "monitor" references validation metrics, then "frequency" should be set to a
# multiple of "trainer.check_val_every_n_epoch".
},
}
# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
optimizer1 = Adam(...)
optimizer2 = SGD(...)
scheduler1 = ReduceLROnPlateau(optimizer1, ...)
scheduler2 = LambdaLR(optimizer2, ...)
return (
{
"optimizer": optimizer1,
"lr_scheduler": {
"scheduler": scheduler1,
"monitor": "metric_to_track",
},
},
{"optimizer": optimizer2, "lr_scheduler": scheduler2},
)
Metrics can be made available to monitor by simply logging them using
self.log('metric_to_track', metric_val)
in your ~lightning.pytorch.core.LightningModule.
Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
class AttributeWithUnpairedDomainModule(DomainModule):
    """Attribute domain whose latent is a VAE latent plus raw unpaired values.

    The domain latent has ``latent_dim + n_unpaired`` entries. The first
    ``latent_dim`` entries come from the attribute VAE, which encodes the class
    and the attributes. The trailing ``n_unpaired`` entries are the unpaired
    values, which ``encode`` and ``decode`` pass through unchanged.
    """

    in_dim = 11

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        coef_categories: float = 1,
        coef_attributes: float = 1,
        n_unpaired: int = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
        coef_unpaired: float = 0.5,
    ):
        """Build the attribute VAE and store the training hyper-parameters.

        Args:
            latent_dim: size of the VAE latent, i.e. the "paired" part of the
                domain latent.
            hidden_dim: hidden size of the VAE encoder and decoder.
            beta: ``beta`` passed to the ``VAE``.
            coef_categories: stored as ``self.coef_categories``; must be in
                [0, 1].
            coef_attributes: stored as ``self.coef_attributes``; must be in
                [0, 1].
            n_unpaired: number of unpaired values appended after the VAE
                latent.
            optim_lr: learning rate; also the default ``max_lr`` of the
                scheduler arguments.
            optim_weight_decay: stored as ``self.optim_weight_decay``.
            scheduler_args: overrides merged on top of the default
                ``SchedulerArgs(max_lr=optim_lr, total_steps=1)``.
            coef_unpaired: weight of the unpaired part in ``compute_loss``
                (the paired part is weighted by ``1 - coef_unpaired``);
                must be in [0, 1].

        Raises:
            ValueError: if one of the coefficients lies outside [0, 1].
        """
        # The domain latent holds the VAE latent followed by the unpaired values.
        super().__init__(latent_dim + n_unpaired)

        if coef_categories < 0 or coef_categories > 1:
            raise ValueError("coef_categories should be in [0, 1]")
        if coef_attributes < 0 or coef_attributes > 1:
            raise ValueError("coef_attributes should be in [0, 1]")
        if coef_unpaired < 0 or coef_unpaired > 1:
            raise ValueError("coef_unpaired should be in [0, 1]")

        self.save_hyperparameters()
        self.paired_dim = latent_dim
        self.n_unpaired = n_unpaired
        self.hidden_dim = hidden_dim
        self.coef_categories = coef_categories
        self.coef_attributes = coef_attributes
        self.coef_unpaired = coef_unpaired

        # self.latent_dim is presumably latent_dim + n_unpaired here (set by
        # super().__init__), so the VAE works on the paired latent_dim only.
        vae_encoder = Encoder(self.hidden_dim, self.latent_dim - self.n_unpaired)
        vae_decoder = Decoder(self.latent_dim - self.n_unpaired, self.hidden_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """Encode ``x`` into the domain latent.

        ``x`` must contain 3 items:
        - the class
        - the attributes
        - the unpaired value

        The class and attributes go through the VAE encoder. The unpaired
        value is concatenated unchanged after the VAE latent on the last
        dimension.
        """
        assert len(x) == 3, (
            "x must have the unpaired value "
            "(use `attr` instead of `attr_unpaired` otherwise)."
        )
        z = self.vae.encode(x[:-1])
        return torch.cat([z, x[-1]], dim=-1)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """Decode a domain latent into the VAE outputs plus the unpaired slice.

        Returns the list produced by the VAE decoder for ``z[:, :paired_dim]``,
        with ``z[:, paired_dim:]`` (the unpaired values) appended as its last
        element.
        """
        paired = z[:, : self.paired_dim]
        unpaired = z[:, self.paired_dim :]
        out = list(self.vae.decode(paired))
        out.append(unpaired)
        return out

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Round-trip ``x`` through ``encode`` and ``decode``."""
        return self.decode(self.encode(x))

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """MSE loss on the paired and unpaired parts of the latent.

        Args:
            pred: predicted domain latent.
            target: target domain latent.
            raw_target: unused.

        Returns:
            A ``LossOutput`` whose loss is
            ``coef_unpaired * unpaired + (1 - coef_unpaired) * paired``. Both
            unweighted terms are reported as the ``"unpaired"`` and
            ``"paired"`` metrics.
        """
        paired_loss = F.mse_loss(
            pred[:, : self.paired_dim], target[:, : self.paired_dim]
        )
        unpaired_loss = F.mse_loss(
            pred[:, self.paired_dim :], target[:, self.paired_dim :]
        )
        # coef_unpaired is validated to lie in [0, 1] in __init__: it mixes the
        # two terms instead of being silently ignored.
        total_loss = (
            self.coef_unpaired * unpaired_loss
            + (1 - self.coef_unpaired) * paired_loss
        )
        return LossOutput(
            loss=total_loss,
            metrics={
                "unpaired": unpaired_loss,
                "paired": paired_loss,
            },
        )
Base class for a DomainModule that defines domain-specific modules of the GW.
def __init__(
    self,
    latent_dim: int,
    hidden_dim: int,
    beta: float = 1,
    coef_categories: float = 1,
    coef_attributes: float = 1,
    n_unpaired: int = 1,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0,
    scheduler_args: SchedulerArgs | None = None,
    coef_unpaired: float = 0.5,
):
    """Set up the attribute VAE plus ``n_unpaired`` pass-through latent slots.

    Raises:
        ValueError: if ``coef_categories``, ``coef_attributes`` or
            ``coef_unpaired`` lies outside [0, 1] (checked in that order).
    """
    # The domain latent holds the VAE latent followed by the unpaired values.
    super().__init__(latent_dim + n_unpaired)

    for coef_name, coef_value in (
        ("coef_categories", coef_categories),
        ("coef_attributes", coef_attributes),
        ("coef_unpaired", coef_unpaired),
    ):
        if coef_value < 0 or coef_value > 1:
            raise ValueError(f"{coef_name} should be in [0, 1]")

    # NOTE(review): kept directly in __init__ on the assumption that
    # save_hyperparameters() reads the init arguments from the calling frame.
    self.save_hyperparameters()

    self.paired_dim = latent_dim
    self.n_unpaired = n_unpaired
    self.hidden_dim = hidden_dim
    self.coef_categories = coef_categories
    self.coef_attributes = coef_attributes
    self.coef_unpaired = coef_unpaired

    # Width of the VAE latent: the domain latent minus the unpaired slots.
    vae_latent_dim = self.latent_dim - self.n_unpaired
    self.vae = VAE(
        Encoder(self.hidden_dim, vae_latent_dim),
        Decoder(vae_latent_dim, self.hidden_dim),
        beta,
    )

    self.optim_lr = optim_lr
    self.optim_weight_decay = optim_weight_decay

    # Defaults first, then any caller-supplied overrides on top.
    self.scheduler_args = SchedulerArgs(max_lr=optim_lr, total_steps=1)
    self.scheduler_args.update(scheduler_args or {})
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
    """Turn ``(class, attributes, unpaired)`` into a single domain latent.

    ``x`` must contain exactly 3 items: the class, the attributes and the
    unpaired value. The first two are encoded by the VAE. The unpaired value is
    appended unchanged along the last dimension.
    """
    assert len(x) == 3, (
        "x must have the unpaired value "
        "(use `attr` instead of `attr_unpaired` otherwise)."
    )
    vae_inputs, unpaired_value = x[:-1], x[-1]
    vae_latent = self.vae.encode(vae_inputs)
    return torch.cat((vae_latent, unpaired_value), dim=-1)
x must contain 3 items:
- the class
- the attributes
- the unpaired value
def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
    """Split ``z`` at ``paired_dim``, decode the paired part with the VAE and
    append the unpaired slice as the last element of the returned list."""
    split_at = self.paired_dim
    decoded = [*self.vae.decode(z[:, :split_at])]
    decoded += [z[:, split_at:]]
    return decoded
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
    """Encode ``x`` into the domain latent and immediately decode it back."""
    latent = self.encode(x)
    return self.decode(latent)
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput:
    """MSE loss on the paired and unpaired parts of the latent.

    Args:
        pred: predicted domain latent.
        target: target domain latent.
        raw_target: unused.

    Returns:
        A ``LossOutput`` whose loss is
        ``coef_unpaired * unpaired + (1 - coef_unpaired) * paired``. Both
        unweighted terms are reported as the ``"unpaired"`` and ``"paired"``
        metrics.
    """
    paired_loss = F.mse_loss(
        pred[:, : self.paired_dim], target[:, : self.paired_dim]
    )
    unpaired_loss = F.mse_loss(
        pred[:, self.paired_dim :], target[:, self.paired_dim :]
    )
    # coef_unpaired is validated to lie in [0, 1] at construction: it mixes
    # the two terms instead of being silently ignored.
    total_loss = (
        self.coef_unpaired * unpaired_loss
        + (1 - self.coef_unpaired) * paired_loss
    )
    return LossOutput(
        loss=total_loss,
        metrics={
            "unpaired": unpaired_loss,
            "paired": paired_loss,
        },
    )
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Results:
LossOutput | None: LossOutput with training loss and additional metrics. If None
is returned, this loss will be ignored and will not participate in the total loss.
class AttributeLegacyDomainModule(DomainModule):
    """Legacy attribute domain whose latent is the raw attribute vector.

    There is no learned encoder or decoder. ``encode`` concatenates the
    category and attribute tensors, and ``decode`` slices them back out: the
    first 3 entries are categories and the next 8 are attributes.
    """

    latent_dim = 11

    def __init__(self):
        super().__init__(self.latent_dim)
        self.save_hyperparameters()

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        """Attribute MSE plus category NLL loss.

        Args:
            pred: predicted latent.
            target: target latent; its categories are converted to class
                indices with ``argmax`` over dim 1.
            raw_target: unused.

        Returns:
            A ``LossOutput`` whose loss is ``loss_attr + loss_cat``. Both terms
            are also reported as metrics.
        """
        # decode() returns [categories, attributes]. Only these first two
        # outputs are used; the [:2] slice also tolerates extra outputs.
        pred_cat, pred_attr = self.decode(pred)[:2]
        target_cat, target_attr = self.decode(target)[:2]

        loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
        # NOTE(review): F.nll_loss expects log-probabilities as input; this
        # assumes the category entries of ``pred`` already hold them - confirm.
        loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
        loss = loss_attr + loss_cat

        return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})

    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        """Concatenate the (categories, attributes) pair on the last dimension."""
        assert len(x) == 2, "This must have only 2 items."
        return torch.cat(list(x), dim=-1)

    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
        """Split ``z`` into categories (``z[:, :3]``) and attributes (``z[:, 3:11]``)."""
        categories = z[:, :3]
        attr = z[:, 3:11]
        return [categories, attr]

    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
        """Round-trip ``x`` through ``encode`` and ``decode``."""
        return self.decode(self.encode(x))
Base class for a DomainModule that defines domain-specific modules of the GW.
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput:
    """Attribute MSE plus category NLL loss.

    Args:
        pred: predicted latent.
        target: target latent; its categories are converted to class indices
            with ``argmax`` over dim 1.
        raw_target: unused.

    Returns:
        A ``LossOutput`` whose loss is ``loss_attr + loss_cat``. Both terms are
        also reported as metrics.
    """
    # decode() returns [categories, attributes]. Only these first two outputs
    # are used; the [:2] slice also tolerates extra outputs.
    pred_cat, pred_attr = self.decode(pred)[:2]
    target_cat, target_attr = self.decode(target)[:2]

    loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
    # NOTE(review): F.nll_loss expects log-probabilities as input; this assumes
    # the category entries of ``pred`` already hold them - confirm.
    loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
    loss = loss_attr + loss_cat

    return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Results:
LossOutput | None: LossOutput with training loss and additional metrics. If None
is returned, this loss will be ignored and will not participate in the total loss.
def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
    """Join the categories and attributes tensors along the last dimension.

    ``x`` must hold exactly two tensors: categories first, then attributes.
    """
    assert len(x) == 2, "This must have only 2 items."
    categories, attributes = x
    return torch.cat([categories, attributes], dim=-1)
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
def decode(self, z: torch.Tensor) -> list:
    """Slice ``z`` back into ``[categories, attributes]``.

    Columns 0-2 are the categories and columns 3-10 are the attributes; any
    column past index 10 is dropped.
    """
    n_categories = 3
    return [z[:, :n_categories], z[:, n_categories:11]]
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
    """Pass ``x`` through ``encode`` and return the result of ``decode``."""
    encoded = self.encode(x)
    return self.decode(encoded)
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output