shimmer_ssd.modules.domains.text
```python
from collections.abc import Mapping, Sequence
from typing import Any

import torch
import torch.nn.functional as F
from shimmer import LossOutput
from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
    reparameterize,
)
from simple_shapes_dataset.text import composer
from simple_shapes_dataset.text.utils import inspect_all_choices
from torch import nn
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import OneCycleLR


class Encoder(VAEEncoder):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.encoder = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
        )

        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        out = torch.cat(list(x), dim=-1)
        out = self.encoder(out)
        return self.q_mean(out), self.q_logvar(out)


class Decoder(VAEDecoder):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.decoder = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.out_dim),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
        return [self.decoder(z)]


class TextDomainModule(DomainModule):
    in_dim = 768

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ):
        super().__init__(latent_dim)
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim

        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.attribute_cls = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )
        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
        self.attribute_cls_attr = nn.Sequential(
            nn.Linear(self.hidden_dim, 8), nn.Tanh()
        )

        self.composer_grammar_options = inspect_all_choices(composer)

        self.grammar_cls = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )
        self.grammar_heads = nn.ModuleDict(
            {
                name: nn.Linear(self.hidden_dim, n_outputs)
                for name, n_outputs in self.composer_grammar_options.items()
            }
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.vae.encode((x["bert"],))

    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
        text["cls"] = attr_pred_cat
        text["attr"] = attr_pred_attr
        text["unpaired"] = torch.zeros_like(z[:, -1])
        text.update(self.predict_grammar(z))
        return text

    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        attr_pred = self.attribute_cls(mean)
        attr_pred_cat = self.attribute_cls_cat(attr_pred)
        attr_pred_attr = self.attribute_cls_attr(attr_pred)
        return attr_pred_cat, attr_pred_attr

    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
        grammar_pred = self.grammar_cls(mean)
        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}

    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
        grammar_pred = self.predict_grammar(mean)
        return {
            f"{name}_ce": F.cross_entropy(
                pred, targets[name][:, 0].long(), reduction="sum"
            )
            for name, pred in grammar_pred.items()
        }

    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Mapping[str, torch.Tensor],
        mode: str = "train",
    ) -> torch.Tensor:
        (mean, logvar), reconstruction = self.vae((x["bert"],))

        reconstruction_loss = gaussian_nll(
            reconstruction[0], torch.tensor(0), x["bert"]
        ).sum()

        kl_loss = kl_divergence_loss(mean, logvar)

        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)

        loss_attr_cat = F.cross_entropy(
            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
        )
        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
        grammar_losses = self.grammar_losses(mean, grammar_targets)

        total_loss = (
            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
        )

        for grammar_loss_name, grammar_loss in grammar_losses.items():
            total_loss += grammar_loss
            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)

        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/attr_category", loss_attr_cat)
        self.log(f"{mode}/attr_attr", loss_attr)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        x = batch["t"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        x = batch[frozenset(["t"])]["t"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }


class GRUEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        out = torch.cat(list(x), dim=-1)
        out = self.encoder(out)
        return out


class GRUTextDomainModule(DomainModule):
    in_dim = 768

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        vocab_size: int,
        seq_length: int,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
        padding_token: int = 0,
        reconstruction_coef: float = 0.5,
        kl_coef: float = 0.05,
    ):
        super().__init__(latent_dim)
        self.is_frozen = False
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size
        self.seq_length = seq_length

        self._padding_token = padding_token

        self.reconstruction_coef = reconstruction_coef
        self.kl_coef = kl_coef

        self.projector = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.latent_dim),
            nn.ReLU(),
        )
        self.q_mean = nn.Linear(self.latent_dim, self.latent_dim)
        self.q_logvar = nn.Linear(self.latent_dim, self.latent_dim)

        self.projector_dec = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.in_dim),
        )

        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)

        self.num_layers = 2
        self.decoder = nn.GRU(
            self.latent_dim,
            self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        mse_loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")

        return LossOutput(mse_loss, {"mse_loss": mse_loss})

    def compute_domain_loss(self, domain: Any) -> LossOutput:
        z = self.encode(domain)
        loss, acc = self.text_token_loss(z, domain)
        return LossOutput(loss, {"acc": acc})

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        z = self.projector(x["bert"])
        return self.q_mean(z)

    def encode_dist(
        self, x: Mapping[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.projector(x["bert"])
        return self.q_mean(z), self.q_logvar(z)

    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        context = z.unsqueeze(1)
        hidden = None
        outputs = []

        for _ in range(self.seq_length):
            out, hidden = self.decoder(context, hidden)
            token_dist = self.text_head(out)
            tokens = torch.argmax(token_dist, dim=-1)
            outputs.append(token_dist)

            context = self.embeddings(tokens)

        token_dists = torch.cat(outputs, dim=1)
        tokens = torch.argmax(token_dists, dim=-1)
        return {"token_dist": token_dists, "tokens": tokens}

    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        out, _ = self.decoder(z)
        tokens_dist = self.text_head(out)
        tokens = torch.argmax(tokens_dist, -1)
        return {"token_dist": tokens_dist, "tokens": tokens}

    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        z, _ = self.encode_dist(x)
        return self.decode(z)

    def text_token_loss(
        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        context = z.unsqueeze(1)
        real_tokens = self.embeddings(target["tokens"][:, :-1])
        seq = torch.cat([context, real_tokens], dim=1)

        out = self.decode_one(seq)

        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
        padding_mask = target["tokens"] != self._padding_token

        padded_out = out["tokens"][padding_mask]
        padded_target = target["tokens"][padding_mask]

        acc = (padded_out == padded_target).sum() / padded_out.size(0)
        return loss, acc

    def generic_step(
        self, x: Mapping[str, torch.Tensor], mode: str = "train"
    ) -> torch.Tensor:
        mean, logvar = self.encode_dist(x)
        z = reparameterize(mean, logvar)
        noise = torch.randn_like(z) * 0.5
        z_text = z + noise
        loss, acc = self.text_token_loss(z_text, x)

        x_hat = self.projector_dec(z)
        reconstruction_loss = gaussian_nll(x_hat, torch.tensor(0), x["bert"]).sum()
        kl_loss = kl_divergence_loss(mean, logvar)

        total_loss = (
            loss
            + self.reconstruction_coef * reconstruction_loss
            + self.kl_coef * kl_loss
        )

        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/total_loss", total_loss)
        self.log(f"{mode}/loss", loss)
        self.log(f"{mode}/acc", acc)
        return total_loss  # loss

    def validation_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        x = batch[frozenset(["t"])]["t"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        x = batch[frozenset(["t"])]["t"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}


class Text2Attr(DomainModule):
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        text_model: GRUTextDomainModule,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ) -> None:
        super().__init__(latent_dim)
        self.save_hyperparameters(ignore=["text_model"])

        self.text_model = text_model

        self.pred_attr = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 8),
            nn.Tanh(),
        )
        self.pred_cat = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=1),
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        pred_attr = self.pred_attr(pred)
        pred_cat = self.pred_cat(pred)
        target_attr = self.pred_attr(target)
        target_cat = self.pred_cat(target)
        loss_attr = F.mse_loss(pred_attr, target_attr)
        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
        loss = loss_attr + loss_cat
        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.text_model.encode(x)

    def decode(self, z: torch.Tensor) -> dict[str, Any]:
        out: dict[str, Any] = {}
        out.update(self.text_model.decode(z))
        pred_attr = self.pred_attr(z)
        pred_cat = self.pred_cat(z)
        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
        return out

    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Mapping[str, Any],
        mode: str = "train",
    ) -> torch.Tensor:
        text: Mapping[str, torch.Tensor] = x["t"]
        attr: Sequence[torch.Tensor] = x["attr"]
        text_l = self.text_model.encode(text)
        pred_attr = self.pred_attr(text_l)
        pred_cat = self.pred_cat(text_l)
        cats = torch.argmax(attr[0], dim=1)
        loss_cat = F.cross_entropy(pred_cat, cats)
        loss_attr = F.mse_loss(pred_attr, attr[1])
        total_loss = loss_cat + loss_attr
        pred_cats = pred_cat.argmax(dim=1)
        acc = (cats == pred_cats).sum() / cats.size(0)

        self.log(f"{mode}/loss_cat", loss_cat)
        self.log(f"{mode}/loss_attr", loss_attr)
        self.log(f"{mode}/acc_cat", acc)
        self.log(f"{mode}/loss", total_loss)

        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        return self.generic_step(batch, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        data = batch[frozenset(["t", "attr"])]
        return self.generic_step(data, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}
```
```python
class Encoder(VAEEncoder)
```
Base class for a VAE encoder.
```python
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
```python
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]
```
Encode a representation with the VAE.

Arguments:
- x (Sequence[torch.Tensor]): input tensors; they are concatenated along the last dimension before being encoded.

Returns:
- tuple[torch.Tensor, torch.Tensor]: the mean and log variance.
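Not part of the documented source: a minimal usage sketch for `Encoder`, assuming the module is importable under the path shown above and that `shimmer` is installed; all sizes below are made up.

```python
import torch

from shimmer_ssd.modules.domains.text import Encoder

encoder = Encoder(in_dim=768, hidden_dim=256, out_dim=64)  # hypothetical sizes
bert = torch.randn(8, 768)        # a batch of 8 BERT sentence embeddings
mean, logvar = encoder((bert,))   # forward expects a sequence of tensors
# mean and logvar both have shape (8, 64) with these sizes
```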
```python
class Decoder(VAEDecoder)
```
Base class for a VAE decoder.
```python
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
```python
def forward(self, z: torch.Tensor) -> list[torch.Tensor]
```
Decode a representation with the VAE.

Arguments:
- z (torch.Tensor): VAE latent representation.

Returns:
- list[torch.Tensor]: the reconstructed input.
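Likewise, a hedged sketch for `Decoder` (same assumptions, made-up sizes); note that `forward` returns a one-element list.

```python
import torch

from shimmer_ssd.modules.domains.text import Decoder

decoder = Decoder(in_dim=64, hidden_dim=256, out_dim=768)
(bert_hat,) = decoder(torch.randn(8, 64))
# bert_hat has shape (8, 768), with values in (-1, 1) because of the final Tanh
```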
```python
class TextDomainModule(DomainModule):
    in_dim = 768
```
Base class for a DomainModule that defines domain specific modules of the GW.
```python
def __init__(
    self,
    latent_dim: int,
    hidden_dim: int,
    beta: float = 1,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0,
    scheduler_args: SchedulerArgs | None = None,
)
```
Initializes a DomainModule.

Arguments:
- latent_dim (int): latent dimension of the unimodal module.
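Not from the source: a minimal construction sketch, assuming `shimmer` and `simple_shapes_dataset` are installed (the constructor calls `inspect_all_choices(composer)`); every hyperparameter value below is made up. `scheduler_args` is merged into the defaults used by `configure_optimizers` for `OneCycleLR`.

```python
from shimmer.modules.global_workspace import SchedulerArgs

from shimmer_ssd.modules.domains.text import TextDomainModule

text_domain = TextDomainModule(
    latent_dim=64,
    hidden_dim=256,
    beta=0.1,
    optim_lr=1e-3,
    scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=10_000),
)
```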
```python
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput
```
Generic loss computation for the modality.

Arguments:
- pred (torch.Tensor): prediction of the model.
- target (torch.Tensor): target tensor.
- raw_target (Any): raw data from the input.

Returns:
- LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss is ignored and will not participate in the total loss.
```python
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor
```
Encode the domain data into a unimodal representation.

Arguments:
- x (Mapping[str, torch.Tensor]): data of the domain; the 768-dimensional BERT embedding is read from the "bert" key.

Returns:
- torch.Tensor: a unimodal representation.
```python
def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]
```
Decode data from the unimodal representation back to the domain data.

Arguments:
- z (torch.Tensor): unimodal representation of the domain.

Returns:
- dict[str, torch.Tensor]: the original domain data ("bert" reconstruction, "cls" and "attr" predictions, "unpaired" zeros, plus one entry per grammar head).
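A hedged round-trip sketch (same installation assumptions as above, sizes made up) showing the keys produced by `decode`:

```python
import torch

from shimmer_ssd.modules.domains.text import TextDomainModule

text_domain = TextDomainModule(latent_dim=64, hidden_dim=256)
bert = torch.randn(4, TextDomainModule.in_dim)   # (4, 768)
z = text_domain.encode({"bert": bert})           # unimodal latent, (4, 64) here
out = text_domain.decode(z)
# out maps "bert" to the reconstruction, "cls" to 3 category logits, "attr" to
# 8 attribute values in (-1, 1), "unpaired" to zeros, and adds one logit tensor
# per grammar head discovered by inspect_all_choices(composer).
```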
```python
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]
```
Same as torch.nn.Module.forward().

Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.

Return:
Your model's output
```python
def generic_step(
    self,
    x: Mapping[str, torch.Tensor],
    mode: str = "train",
) -> torch.Tensor
```
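`generic_step` has no docstring; summarizing its implementation from the module source above (our notation, not from the source):

```python
# total_loss = gaussian_nll(bert_hat, 0, bert).sum()       # reconstruction term
#            + beta * kl_divergence_loss(mean, logvar)     # weighted KL term
#            + cross_entropy(cls_logits, cls_target)       # category head (sum)
#            + mse(attr_pred, attr_target)                 # attribute head (sum)
#            + sum of cross-entropy terms, one per grammar head
```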
```python
def validation_step(
    self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
) -> torch.Tensor
```
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Arguments:
- batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders are used).

Return:
- torch.Tensor - the loss tensor
- dict - a dictionary; can include any keys, but must include the key 'loss'
- None - skip to the next batch

```python
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
```

Examples:

```python
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
```

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

```python
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
```

Note:
If you don't need to validate you don't need to implement this method.

Note:
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
```python
def training_step(
    self,
    batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
    _,
) -> torch.Tensor
```
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
- batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch (only if multiple dataloaders are used).

Return:
- torch.Tensor - the loss tensor
- dict - a dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization
- None - in automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed); for manual optimization, this has no special meaning, as returning the loss is not required

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

```python
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
```

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

```python
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
```

Note:
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
```python
def configure_optimizers(self) -> dict[str, Any]
```
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

```python
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
```

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

```python
# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )
```

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your lightning.pytorch.core.LightningModule.

Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
```python
class GRUEncoder(nn.Module)
```
Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:

```python
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables:
- training (bool): whether this module is in training or evaluation mode.
```python
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
```python
def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor
```
Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
```python
class GRUTextDomainModule(DomainModule):
    in_dim = 768
```
Base class for a DomainModule that defines domain specific modules of the GW.
```python
def __init__(
    self,
    latent_dim: int,
    hidden_dim: int,
    vocab_size: int,
    seq_length: int,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0,
    scheduler_args: SchedulerArgs | None = None,
    padding_token: int = 0,
    reconstruction_coef: float = 0.5,
    kl_coef: float = 0.05,
)
```
Initializes a DomainModule.

Arguments:
- latent_dim (int): latent dimension of the unimodal module.
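Not from the source: a construction sketch with made-up sizes, assuming `shimmer` is installed. `encode` returns only the projected mean, while `encode_dist` also returns the log variance.

```python
import torch

from shimmer_ssd.modules.domains.text import GRUTextDomainModule

gru_domain = GRUTextDomainModule(
    latent_dim=64,
    hidden_dim=256,
    vocab_size=1000,
    seq_length=32,
    padding_token=0,
)
bert = torch.randn(4, GRUTextDomainModule.in_dim)     # (4, 768)
mean, logvar = gru_domain.encode_dist({"bert": bert})
z = gru_domain.encode({"bert": bert})                 # (4, 64), the mean only
```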
```python
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput
```
Generic loss computation for the modality.

Arguments:
- pred (torch.Tensor): prediction of the model.
- target (torch.Tensor): target tensor.
- raw_target (Any): raw data from the input.

Returns:
- LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss is ignored and will not participate in the total loss.
```python
def compute_domain_loss(self, domain: Any) -> LossOutput
```
Compute the unimodal domain loss.

Arguments:
- domain (Any): domain input.

Returns:
- LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss is ignored and will not participate in the total loss.
```python
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor
```
Encode the domain data into a unimodal representation.

Arguments:
- x (Mapping[str, torch.Tensor]): data of the domain; the 768-dimensional BERT embedding is read from the "bert" key.

Returns:
- torch.Tensor: a unimodal representation.
```python
def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]
```
Decode data from the unimodal representation back to the domain data.

Arguments:
- z (torch.Tensor): unimodal representation of the domain.

Returns:
- dict[str, torch.Tensor]: the decoded domain data; "token_dist" holds the per-step token logits and "tokens" the greedy token ids.
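Continuing the sketch above (same made-up sizes and `gru_domain` instance), `decode` unrolls the GRU greedily for `seq_length` steps, feeding back the embedding of the argmax token at each step:

```python
out = gru_domain.decode(torch.randn(4, 64))
# out["token_dist"]: (4, 32, 1000) logits, i.e. (batch, seq_length, vocab_size)
# out["tokens"]:     (4, 32) greedy argmax token ids
```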
```python
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]
```
Same as torch.nn.Module.forward().

Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.

Return:
Your model's output
```python
def text_token_loss(
    self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]
```
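`text_token_loss` is undocumented; from its implementation in the module source above, it applies teacher forcing: the latent is the first GRU input and the ground-truth tokens, shifted right by one, are the remaining inputs. A hedged sketch, continuing the instance and sizes from the earlier sketch:

```python
tokens = torch.randint(0, 1000, (4, 32))                  # fake token ids
batch = {"bert": torch.randn(4, 768), "tokens": tokens}
z = gru_domain.encode(batch)
ce_loss, token_acc = gru_domain.text_token_loss(z, batch)
# token_acc is the accuracy over non-padding positions
```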
```python
def generic_step(
    self, x: Mapping[str, torch.Tensor], mode: str = "train"
) -> torch.Tensor
```
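`generic_step` is likewise undocumented; summarizing its implementation from the module source above (our notation), Gaussian noise scaled by 0.5 is added to the reparameterized latent before the token loss, and the auxiliary VAE terms are weighted by the constructor coefficients:

```python
# total_loss = text_token_loss(z + 0.5 * noise, x)                       # token CE
#            + reconstruction_coef * gaussian_nll(x_hat, 0, bert).sum()  # recon
#            + kl_coef * kl_divergence_loss(mean, logvar)                # KL
```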
```python
def validation_step(
    self,
    batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
    _,
) -> torch.Tensor
```
Same inherited LightningModule documentation as shown for TextDomainModule.validation_step above.
```python
def training_step(
    self,
    batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
    _,
) -> torch.Tensor
```
Same inherited LightningModule documentation as shown for TextDomainModule.training_step above.
```python
def configure_optimizers(self) -> dict[str, Any]
```
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple
lr_scheduler_config
).- Dictionary, with an
"optimizer"
key, and (optionally) a"lr_scheduler"
key whose value is a single LR scheduler orlr_scheduler_config
.- None - Fit will run without any optimizer.
The lr_scheduler_config
is a dictionary which contains the scheduler and its associated configuration.
The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified in 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
.. testcode::

    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }
    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.
Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
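The method above returns only the optimizer even though the module stores scheduler_args. If you wanted those arguments to drive a OneCycleLR schedule, a possible variant is sketched below; this is an assumption about intent, not the shipped implementation, and it treats scheduler_args as a mapping with OneCycleLR's max_lr and total_steps keys::

    from typing import Any

    from torch.optim.adamw import AdamW
    from torch.optim.lr_scheduler import OneCycleLR


    def configure_optimizers(self) -> dict[str, Any]:
        # Sketch only: reuse the stored scheduler_args to build a OneCycleLR
        # schedule that steps after every optimizer update.
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        scheduler = OneCycleLR(optimizer, **self.scheduler_args)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }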
class Text2Attr(DomainModule):
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        text_model: GRUTextDomainModule,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ) -> None:
        super().__init__(latent_dim)
        self.save_hyperparameters(ignore=["text_model"])

        self.text_model = text_model

        self.pred_attr = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 8),
            nn.Tanh(),
        )
        self.pred_cat = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=1),
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        pred_attr = self.pred_attr(pred)
        pred_cat = self.pred_cat(pred)
        target_attr = self.pred_attr(target)
        target_cat = self.pred_cat(target)
        loss_attr = F.mse_loss(pred_attr, target_attr)
        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
        loss = loss_attr + loss_cat
        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.text_model.encode(x)

    def decode(self, z: torch.Tensor) -> dict[str, Any]:
        out: dict[str, Any] = {}
        out.update(self.text_model.decode(z))
        pred_attr = self.pred_attr(z)
        pred_cat = self.pred_cat(z)
        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
        return out

    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Mapping[str, Any],
        mode: str = "train",
    ) -> torch.Tensor:
        text: Mapping[str, torch.Tensor] = x["t"]
        attr: Sequence[torch.Tensor] = x["attr"]
        text_l = self.text_model.encode(text)
        pred_attr = self.pred_attr(text_l)
        pred_cat = self.pred_cat(text_l)
        cats = torch.argmax(attr[0], dim=1)
        loss_cat = F.cross_entropy(pred_cat, cats)
        loss_attr = F.mse_loss(pred_attr, attr[1])
        total_loss = loss_cat + loss_attr
        pred_cats = pred_cat.argmax(dim=1)
        acc = (cats == pred_cats).sum() / cats.size(0)

        self.log(f"{mode}/loss_cat", loss_cat)
        self.log(f"{mode}/loss_attr", loss_attr)
        self.log(f"{mode}/acc_cat", acc)
        self.log(f"{mode}/loss", total_loss)

        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        return self.generic_step(batch, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        data = batch[frozenset(["t", "attr"])]
        return self.generic_step(data, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}
Base class for a DomainModule that defines domain-specific modules of the GW.
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        text_model: GRUTextDomainModule,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ) -> None:
        super().__init__(latent_dim)
        self.save_hyperparameters(ignore=["text_model"])

        self.text_model = text_model

        self.pred_attr = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 8),
            nn.Tanh(),
        )
        self.pred_cat = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=1),
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
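A hypothetical construction of Text2Attr, assuming a GRUTextDomainModule checkpoint exists and that the wrapped module exposes the latent_dim stored by DomainModule (both assumptions for illustration)::

    # Illustrative only: wrap a pretrained text module so the attribute and
    # category heads are trained on top of its latent space.
    text_module = GRUTextDomainModule.load_from_checkpoint("path/to/text.ckpt")
    text2attr = Text2Attr(
        latent_dim=text_module.latent_dim,
        hidden_dim=256,          # illustrative value
        text_model=text_module,
        optim_lr=1e-3,
    )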
    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        pred_attr = self.pred_attr(pred)
        pred_cat = self.pred_cat(pred)
        target_attr = self.pred_attr(target)
        target_cat = self.pred_cat(target)
        loss_attr = F.mse_loss(pred_attr, target_attr)
        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
        loss = loss_attr + loss_cat
        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Returns:
LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
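Continuing the construction sketch above, a call to compute_loss might look as follows; the shapes are illustrative and the loss/metrics attribute names are assumed from shimmer's LossOutput::

    # Illustrative only: both arguments are latent vectors of size latent_dim;
    # compute_loss runs each through the prediction heads and compares the
    # head outputs.
    import torch

    pred = torch.randn(32, text2attr.latent_dim)
    target = torch.randn(32, text2attr.latent_dim)
    out = text2attr.compute_loss(pred, target, raw_target=None)
    print(out.loss)             # loss_attr + loss_cat
    print(out.metrics["attr"])  # MSE between attribute-head outputs
    print(out.metrics["cat"])   # cross-entropy between category-head outputs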
    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.text_model.encode(x)
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
    def decode(self, z: torch.Tensor) -> dict[str, Any]:
        out: dict[str, Any] = {}
        out.update(self.text_model.decode(z))
        pred_attr = self.pred_attr(z)
        pred_cat = self.pred_cat(z)
        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
        return out
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
        return self.decode(self.encode(x))
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
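A decoding sketch that ties the methods above together, assuming a constructed text2attr as in the earlier example; the latent is random, and the shapes of the added "attr" entry follow from the heads defined in __init__::

    # Illustrative only: decode returns the wrapped text model's outputs plus
    # the "attr" predictions [softmax categories, tanh attributes, attributes].
    import torch

    z = torch.randn(4, text2attr.latent_dim)
    out = text2attr.decode(z)
    cat_probs, attrs, _ = out["attr"]
    print(cat_probs.shape, attrs.shape)  # expected: (4, 3) and (4, 8)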
    def generic_step(
        self,
        x: Mapping[str, Any],
        mode: str = "train",
    ) -> torch.Tensor:
        text: Mapping[str, torch.Tensor] = x["t"]
        attr: Sequence[torch.Tensor] = x["attr"]
        text_l = self.text_model.encode(text)
        pred_attr = self.pred_attr(text_l)
        pred_cat = self.pred_cat(text_l)
        cats = torch.argmax(attr[0], dim=1)
        loss_cat = F.cross_entropy(pred_cat, cats)
        loss_attr = F.mse_loss(pred_attr, attr[1])
        total_loss = loss_cat + loss_attr
        pred_cats = pred_cat.argmax(dim=1)
        acc = (cats == pred_cats).sum() / cats.size(0)

        self.log(f"{mode}/loss_cat", loss_cat)
        self.log(f"{mode}/loss_attr", loss_attr)
        self.log(f"{mode}/acc_cat", acc)
        self.log(f"{mode}/loss", total_loss)

        return total_loss
    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        return self.generic_step(batch, "val")
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor.
- dict - A dictionary. Can include any keys, but must include the key 'loss'.
- None - Skip to the next batch.
    # if you have one val dataloader:
    def validation_step(self, batch, batch_idx): ...

    # if you have multiple val dataloaders:
    def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples::

    # CASE 1: A single validation dataset
    def validation_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
    # CASE 2: multiple validation dataloaders
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx tells you which dataset this is.
        ...
Note:
If you don't need to validate you don't need to implement this method.
Note:
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        data = batch[frozenset(["t", "attr"])]
        return self.generic_step(data, "train")
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor.
- dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
- None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example::

    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping::

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()

        # do training_step with encoder
        ...
        opt1.step()

        # do training_step with decoder
        ...
        opt2.step()
Note:
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
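Taken together with validation_step above, the two hooks expect slightly different batch layouts: validation batches are already per-domain mappings, while training batches are keyed by the frozenset of domain names. A sketch follows; the inner "bert" key, the one-hot categories, and the shapes are assumptions::

    # Illustrative batch layouts for Text2Attr: generic_step reads x["t"]
    # (text data) and x["attr"] = [categories, attributes].
    import torch
    import torch.nn.functional as F

    val_batch = {
        "t": {"bert": torch.randn(32, 768)},                   # assumed key/shape
        "attr": [
            F.one_hot(torch.randint(0, 3, (32,)), 3).float(),  # one-hot categories
            torch.rand(32, 8) * 2 - 1,                         # attributes in [-1, 1]
        ],
    }
    train_batch = {frozenset(["t", "attr"]): val_batch}

    # text2attr.validation_step(val_batch, 0)
    # text2attr.training_step(train_batch, 0)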
    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
Any of the following options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified in 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
.. testcode::

    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }
    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.
Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
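Since both modules documented here store scheduler_args following OneCycleLR's signature, a typical setup would replace the placeholder total_steps=1 with the real number of training steps at construction time. The values below are illustrative, text_module refers to the hypothetical pretrained module from the earlier sketch, and whether the stored arguments are actually consumed by a scheduler depends on configure_optimizers::

    # Illustrative only: max_lr mirrors the optimizer learning rate and
    # total_steps should match the planned number of optimizer updates.
    module = Text2Attr(
        latent_dim=64,
        hidden_dim=256,
        text_model=text_module,
        optim_lr=1e-3,
        scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=50_000),
    )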