shimmer_ssd.modules.domains.text
from collections.abc import Mapping, Sequence
from typing import Any

import torch
import torch.nn.functional as F
from shimmer import LossOutput
from shimmer.modules.domain import DomainModule
from shimmer.modules.global_workspace import SchedulerArgs
from shimmer.modules.vae import (
    VAE,
    VAEDecoder,
    VAEEncoder,
    gaussian_nll,
    kl_divergence_loss,
)
from simple_shapes_dataset.text import composer
from simple_shapes_dataset.text.utils import inspect_all_choices
from torch import nn
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import OneCycleLR


# VAE encoder: concatenates the inputs and produces a posterior mean and log-variance.
class Encoder(VAEEncoder):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.encoder = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
        )

        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        out = torch.cat(list(x), dim=-1)
        out = self.encoder(out)
        return self.q_mean(out), self.q_logvar(out)


# VAE decoder: maps a latent vector back to the input space (Tanh-bounded output).
class Decoder(VAEDecoder):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        self.decoder = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.out_dim),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
        return [self.decoder(z)]


# Text domain module based on a VAE over 768-d BERT features, with auxiliary heads
# predicting the category, the attributes and the grammar choices from the latent.
class TextDomainModule(DomainModule):
    in_dim = 768

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        beta: float = 1,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ):
        super().__init__(latent_dim)
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim

        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
        self.vae = VAE(vae_encoder, vae_decoder, beta)

        self.attribute_cls = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )
        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
        self.attribute_cls_attr = nn.Sequential(
            nn.Linear(self.hidden_dim, 8), nn.Tanh()
        )

        self.composer_grammar_options = inspect_all_choices(composer)

        self.grammar_cls = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )
        self.grammar_heads = nn.ModuleDict(
            {
                name: nn.Linear(self.hidden_dim, n_outputs)
                for name, n_outputs in self.composer_grammar_options.items()
            }
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.vae.encode((x["bert"],))

    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
        text["cls"] = attr_pred_cat
        text["attr"] = attr_pred_attr
        text["unpaired"] = torch.zeros_like(z[:, -1])
        text.update(self.predict_grammar(z))
        return text

    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        attr_pred = self.attribute_cls(mean)
        attr_pred_cat = self.attribute_cls_cat(attr_pred)
        attr_pred_attr = self.attribute_cls_attr(attr_pred)
        return attr_pred_cat, attr_pred_attr

    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
        grammar_pred = self.grammar_cls(mean)
        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}

    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
        grammar_pred = self.predict_grammar(mean)
        return {
            f"{name}_ce": F.cross_entropy(
                pred, targets[name][:, 0].long(), reduction="sum"
            )
            for name, pred in grammar_pred.items()
        }

    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Mapping[str, torch.Tensor],
        mode: str = "train",
    ) -> torch.Tensor:
        (mean, logvar), reconstruction = self.vae((x["bert"],))

        reconstruction_loss = gaussian_nll(
            reconstruction[0], torch.tensor(0), x["bert"]
        ).sum()

        kl_loss = kl_divergence_loss(mean, logvar)

        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)

        loss_attr_cat = F.cross_entropy(
            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
        )
        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
        grammar_losses = self.grammar_losses(mean, grammar_targets)

        total_loss = (
            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
        )

        for grammar_loss_name, grammar_loss in grammar_losses.items():
            total_loss += grammar_loss
            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)

        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
        self.log(f"{mode}/kl_loss", kl_loss)
        self.log(f"{mode}/attr_category", loss_attr_cat)
        self.log(f"{mode}/attr_attr", loss_attr)
        self.log(f"{mode}/loss", total_loss)
        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        x = batch["t"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        x = batch[frozenset(["t"])]["t"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }


class GRUEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        super().__init__()

        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        out = torch.cat(list(x), dim=-1)
        # NOTE: `self.encoder` is never defined in this class, so calling `forward`
        # as written raises an AttributeError.
        out = self.encoder(out)
        return out


# Text domain module that projects BERT features to a latent vector and decodes
# token sequences with a GRU.
class GRUTextDomainModule(DomainModule):
    in_dim = 768

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        vocab_size: int,
        seq_length: int,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
        padding_token: int = 0,
    ):
        super().__init__(latent_dim)
        self.is_frozen = False
        self.save_hyperparameters()

        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size
        self.seq_length = seq_length

        self._padding_token = padding_token

        self.projector = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.latent_dim),
            nn.Tanh(),
        )

        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)

        self.num_layers = 2
        self.decoder = nn.GRU(
            self.latent_dim,
            self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
        )
        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        text_token_loss, acc = self.text_token_loss(pred, raw_target)

        return LossOutput(
            2 * text_token_loss,
            {"loss_tokens": text_token_loss, "pred_t_acc": acc},
        )

    def compute_domain_loss(self, domain: Any) -> LossOutput:
        z = self.encode(domain)
        loss, acc = self.text_token_loss(z, domain)
        return LossOutput(loss, {"acc": acc})

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.projector(x["bert"])

    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        # Greedy autoregressive decoding: the latent is the first GRU input and each
        # predicted token is embedded and fed back as the next input.
        context = z.unsqueeze(1)
        pad_tokens = self.embeddings(
            torch.zeros(
                z.size(0), self.seq_length - 1, dtype=torch.long, device=z.device
            )
        )
        seq = torch.cat([context, pad_tokens], dim=1)
        for k in range(0, self.seq_length - 1):
            out = self.decode_one(seq)
            seq[:, k + 1] = self.embeddings(out["tokens"][:, k])
        return self.decode_one(seq)

    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
        out, _ = self.decoder(z)
        tokens_dist = self.text_head(out)
        tokens = torch.argmax(tokens_dist, -1)
        return {"token_dist": tokens_dist, "tokens": tokens}

    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        return self.decode(self.encode(x))

    def text_token_loss(
        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Teacher forcing: the ground-truth tokens, shifted right by one, are used as
        # GRU inputs; accuracy is only measured on non-padding positions.
        context = z.unsqueeze(1)
        real_tokens = self.embeddings(target["tokens"][:, :-1])
        seq = torch.cat([context, real_tokens], dim=1)

        out = self.decode_one(seq)
        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
        padding_mask = target["tokens"] != self._padding_token

        padded_out = out["tokens"][padding_mask]
        padded_target = target["tokens"][padding_mask]

        acc = (padded_out == padded_target).sum() / padded_out.size(0)
        return loss, acc

    def generic_step(
        self, x: Mapping[str, torch.Tensor], mode: str = "train"
    ) -> torch.Tensor:
        z = self.encode(x)
        loss, acc = self.text_token_loss(z, x)
        self.log(f"{mode}/loss", loss)
        self.log(f"{mode}/acc", acc)
        return loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        x = batch["t"]
        return self.generic_step(x, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        x = batch[frozenset(["t"])]["t"]
        return self.generic_step(x, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}


# Maps the text latent of a GRUTextDomainModule to attribute-domain predictions
# (category probabilities and an 8-d attribute vector).
class Text2Attr(DomainModule):
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        text_model: GRUTextDomainModule,
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0,
        scheduler_args: SchedulerArgs | None = None,
    ) -> None:
        super().__init__(latent_dim)
        self.save_hyperparameters(ignore=["text_model"])

        self.text_model = text_model

        self.pred_attr = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 8),
            nn.Tanh(),
        )
        self.pred_cat = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3),
            nn.Softmax(dim=1),
        )

        self.optim_lr = optim_lr
        self.optim_weight_decay = optim_weight_decay

        self.scheduler_args = SchedulerArgs(
            max_lr=optim_lr,
            total_steps=1,
        )
        self.scheduler_args.update(scheduler_args or {})

    def compute_loss(
        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
    ) -> LossOutput:
        pred_attr = self.pred_attr(pred)
        pred_cat = self.pred_cat(pred)
        target_attr = self.pred_attr(target)
        target_cat = self.pred_cat(target)
        loss_attr = F.mse_loss(pred_attr, target_attr)
        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
        loss = loss_attr + loss_cat
        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})

    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
        return self.text_model.encode(x)

    def decode(self, z: torch.Tensor) -> dict[str, Any]:
        out: dict[str, Any] = {}
        out.update(self.text_model.decode(z))
        pred_attr = self.pred_attr(z)
        pred_cat = self.pred_cat(z)
        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
        return out

    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
        return self.decode(self.encode(x))

    def generic_step(
        self,
        x: Mapping[str, Any],
        mode: str = "train",
    ) -> torch.Tensor:
        text: Mapping[str, torch.Tensor] = x["t"]
        attr: Sequence[torch.Tensor] = x["attr"]
        text_l = self.text_model.encode(text)
        pred_attr = self.pred_attr(text_l)
        pred_cat = self.pred_cat(text_l)
        cats = torch.argmax(attr[0], dim=1)
        loss_cat = F.cross_entropy(pred_cat, cats)
        loss_attr = F.mse_loss(pred_attr, attr[1])
        total_loss = loss_cat + loss_attr
        pred_cats = pred_cat.argmax(dim=1)
        acc = (cats == pred_cats).sum() / cats.size(0)

        self.log(f"{mode}/loss_cat", loss_cat)
        self.log(f"{mode}/loss_attr", loss_attr)
        self.log(f"{mode}/acc_cat", acc)
        self.log(f"{mode}/loss", total_loss)

        return total_loss

    def validation_step(  # type: ignore
        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
    ) -> torch.Tensor:
        return self.generic_step(batch, "val")

    def training_step(  # type: ignore
        self,
        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
        _,
    ) -> torch.Tensor:
        data = batch[frozenset(["t", "attr"])]
        return self.generic_step(data, "train")

    def configure_optimizers(  # type: ignore
        self,
    ) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )

        return {"optimizer": optimizer}
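A minimal usage sketch of TextDomainModule, assuming a batch of precomputed 768-d BERT sentence embeddings; the latent and hidden sizes (and the batch size) below are illustrative, not values prescribed by the package::

    import torch

    from shimmer_ssd.modules.domains.text import TextDomainModule

    module = TextDomainModule(latent_dim=12, hidden_dim=256, beta=1.0)

    bert_features = torch.randn(4, TextDomainModule.in_dim)  # 4 sentences x 768 features
    z = module.encode({"bert": bert_features})               # unimodal latent
    out = module.decode(z)                                    # reconstruction + auxiliary predictions
    print(sorted(out.keys()))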
class Encoder(VAEEncoder) (source in the module listing above)
Base class for a VAE encoder.
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]
Encode representation with VAE.
Arguments:
- x (Any): Some input value
Returns:
tuple[torch.Tensor, torch.Tensor]: the mean and log variance
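As a quick illustration (the dimensions are arbitrary here), the encoder concatenates the given tensors on the last dimension and returns the posterior mean and log-variance::

    import torch

    from shimmer_ssd.modules.domains.text import Encoder

    enc = Encoder(in_dim=768, hidden_dim=256, out_dim=12)  # illustrative sizes
    x = torch.randn(8, 768)
    q_mean, q_logvar = enc((x,))  # input is a sequence of tensors, concatenated on dim=-1
    print(q_mean.shape, q_logvar.shape)  # torch.Size([8, 12]) torch.Size([8, 12])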
class Decoder(VAEDecoder) (source in the module listing above)
Base class for a VAE decoder.
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, z: torch.Tensor) -> list[torch.Tensor]
Decode representation with VAE.
Arguments:
- x (torch.Tensor): VAE latent representation
Returns:
Any: the reconstructed input
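A matching sketch for the decoder (same illustrative sizes as above); note that forward returns a one-element list::

    import torch

    from shimmer_ssd.modules.domains.text import Decoder

    dec = Decoder(in_dim=12, hidden_dim=256, out_dim=768)
    z = torch.randn(8, 12)
    (reconstruction,) = dec(z)
    print(reconstruction.shape)  # torch.Size([8, 768]), values in (-1, 1) due to the final Tanh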
class TextDomainModule(DomainModule) (source in the module listing above)
Base class for a DomainModule that defines domain specific modules of the GW.
def __init__(self, latent_dim: int, hidden_dim: int, beta: float = 1, optim_lr: float = 1e-3, optim_weight_decay: float = 0, scheduler_args: SchedulerArgs | None = None)
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Returns:
LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
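In this module the domain data is a mapping with a "bert" entry holding the 768-d sentence embeddings; a sketch with illustrative sizes::

    import torch

    from shimmer_ssd.modules.domains.text import TextDomainModule

    module = TextDomainModule(latent_dim=12, hidden_dim=256)
    z = module.encode({"bert": torch.randn(32, 768)})
    print(z.shape)  # torch.Size([32, 12]): one latent_dim vector per sentence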
def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
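The decoded mapping contains the reconstructed BERT features plus the auxiliary predictions; a sketch (sizes illustrative)::

    import torch

    from shimmer_ssd.modules.domains.text import TextDomainModule

    module = TextDomainModule(latent_dim=12, hidden_dim=256)
    out = module.decode(torch.randn(5, 12))
    # "bert": reconstructed embeddings, "cls": category logits, "attr": attribute vector,
    # "unpaired": zeros, plus one logit tensor per grammar head.
    print(sorted(out.keys()))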
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
def generic_step(self, x: Mapping[str, torch.Tensor], mode: str = "train") -> torch.Tensor
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor
- dict - A dictionary. Can include any keys, but must include the key 'loss'.
- None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples::

    # CASE 1: A single validation dataset
    def validation_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step()
will have an additional argument. We recommend
setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
    # CASE 2: multiple validation dataloaders
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx tells you which dataset this is.
        ...
Note:
If you don't need to validate you don't need to implement this method.
Note:
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
def training_step(self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor
- dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
- None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example::

    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False


    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()

        # do training_step with encoder
        ...
        opt1.step()

        # do training_step with decoder
        ...
        opt2.step()
Note:
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
def configure_optimizers(self) -> dict[str, Any]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
Any of these 6 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
.. testcode::

    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }


    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.
Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
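Putting the TextDomainModule pieces together, a training sketch. It assumes the dataloader yields batches keyed by frozenset(["t"]) as expected by training_step, and that total_steps matches the real number of optimizer steps so that OneCycleLR is scheduled over the whole run; the data loader and all sizes here are hypothetical::

    import lightning.pytorch as pl

    from shimmer.modules.global_workspace import SchedulerArgs
    from shimmer_ssd.modules.domains.text import TextDomainModule

    module = TextDomainModule(
        latent_dim=12,
        hidden_dim=256,
        optim_lr=1e-3,
        scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=10_000),
    )

    trainer = pl.Trainer(max_steps=10_000)
    # `train_dataloader` (hypothetical) yields
    # {frozenset(["t"]): {"t": {"bert": ..., "cls": ..., "attr": ..., <grammar targets>}}}
    # trainer.fit(module, train_dataloader)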
class GRUEncoder(nn.Module) (source in the module listing above)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F


    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor
(Note: self.encoder is never defined in this class; see the note in the module listing above.)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class GRUTextDomainModule(DomainModule) (source in the module listing above)
Base class for a DomainModule that defines domain specific modules of the GW.
def __init__(self, latent_dim: int, hidden_dim: int, vocab_size: int, seq_length: int, optim_lr: float = 1e-3, optim_weight_decay: float = 0, scheduler_args: SchedulerArgs | None = None, padding_token: int = 0)
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
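An instantiation sketch; vocab_size, seq_length and the other sizes must match the tokenized captions of the dataset and are purely illustrative here::

    import torch

    from shimmer_ssd.modules.domains.text import GRUTextDomainModule

    module = GRUTextDomainModule(
        latent_dim=64,
        hidden_dim=256,
        vocab_size=1000,
        seq_length=32,
        padding_token=0,
    )
    z = module.encode({"bert": torch.randn(4, 768)})  # project BERT features to the latent
    print(z.shape)  # torch.Size([4, 64])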
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> LossOutput
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Returns:
LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
def compute_domain_loss(self, domain: Any) -> LossOutput
Compute the unimodal domain loss.
Arguments:
- domain (Any): domain input
Returns:
LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
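decode performs greedy autoregressive generation: the latent is the first GRU input and each predicted token is embedded and fed back for the next position. A sketch with the illustrative sizes used above::

    import torch

    from shimmer_ssd.modules.domains.text import GRUTextDomainModule

    module = GRUTextDomainModule(latent_dim=64, hidden_dim=256, vocab_size=1000, seq_length=32)
    module.eval()

    with torch.no_grad():
        z = module.encode({"bert": torch.randn(2, 768)})
        out = module.decode(z)

    print(out["tokens"].shape)      # torch.Size([2, 32]): one token id per position
    print(out["token_dist"].shape)  # torch.Size([2, 32, 1000]): logits over the vocabulary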
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
def text_token_loss(self, z: torch.Tensor, target: Mapping[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]
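text_token_loss uses teacher forcing: position 0 of the GRU input is the latent and positions 1 to seq_length - 1 are the embeddings of the ground-truth tokens shifted right by one. The cross-entropy is averaged over all positions, while the reported accuracy only counts non-padding positions. A sketch on random data (sizes illustrative)::

    import torch

    from shimmer_ssd.modules.domains.text import GRUTextDomainModule

    module = GRUTextDomainModule(latent_dim=64, hidden_dim=256, vocab_size=1000, seq_length=32)

    tokens = torch.randint(1, 1000, (4, 32))          # ground-truth token ids (no padding here)
    z = module.encode({"bert": torch.randn(4, 768)})
    loss, acc = module.text_token_loss(z, {"tokens": tokens})
    print(loss.item(), acc.item())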
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor
Inherited from LightningModule; the description is identical to the validation_step entry under TextDomainModule above.
def training_step(self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor
Inherited from LightningModule; the description is identical to the training_step entry under TextDomainModule above.
def configure_optimizers(self) -> dict[str, Any]
Inherited from LightningModule; the description is identical to the configure_optimizers entry under TextDomainModule above. Unlike TextDomainModule, this implementation returns only the AdamW optimizer, without an LR scheduler.
class Text2Attr(DomainModule) (source in the module listing above)
Base class for a DomainModule that defines domain-specific modules of the GW.
def __init__(
    self,
    latent_dim: int,
    hidden_dim: int,
    text_model: GRUTextDomainModule,
    optim_lr: float = 1e-3,
    optim_weight_decay: float = 0,
    scheduler_args: SchedulerArgs | None = None,
) -> None:
    super().__init__(latent_dim)
    self.save_hyperparameters(ignore=["text_model"])

    self.text_model = text_model

    self.pred_attr = nn.Sequential(
        nn.Linear(latent_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 8),
        nn.Tanh(),
    )
    self.pred_cat = nn.Sequential(
        nn.Linear(latent_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 3),
        nn.Softmax(dim=1),
    )

    self.optim_lr = optim_lr
    self.optim_weight_decay = optim_weight_decay

    self.scheduler_args = SchedulerArgs(
        max_lr=optim_lr,
        total_steps=1,
    )
    self.scheduler_args.update(scheduler_args or {})
Initializes a DomainModule.
Arguments:
- latent_dim (int): latent dimension of the unimodal module
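A construction sketch may help orient readers (a sketch only: the sizes are illustrative, load_pretrained_text_module() is a hypothetical helper, and the wrapped module is assumed to expose a latent_dim attribute as DomainModule subclasses in this package do):

    # Hypothetical setup: the GRU text module is assumed to be trained elsewhere.
    text_module: GRUTextDomainModule = load_pretrained_text_module()  # hypothetical helper
    model = Text2Attr(
        latent_dim=text_module.latent_dim,  # shared unimodal latent dimension
        hidden_dim=256,                     # illustrative value
        text_model=text_module,
        optim_lr=1e-3,
    )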
def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
) -> LossOutput:
    pred_attr = self.pred_attr(pred)
    pred_cat = self.pred_cat(pred)
    target_attr = self.pred_attr(target)
    target_cat = self.pred_cat(target)
    loss_attr = F.mse_loss(pred_attr, target_attr)
    loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
    loss = loss_attr + loss_cat
    return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
Generic loss computation for the modality.
Arguments:
- pred (torch.Tensor): prediction of the model
- target (torch.Tensor): target tensor
- raw_target (Any): raw data from the input
Returns:
LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
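For orientation, a shape-level sketch of calling compute_loss on random latents (illustrative only: the Text2Attr instance "module" and the batch size are assumptions, and the metrics field follows the LossOutput construction shown in the code above):

    import torch

    # "module" is assumed to be an already-constructed Text2Attr instance.
    pred = torch.randn(8, module.latent_dim)    # e.g. a workspace prediction of the text latent
    target = torch.randn(8, module.latent_dim)  # the encoded ground-truth text latent
    out = module.compute_loss(pred, target, raw_target=None)
    # out.loss adds the MSE on the 8 predicted attributes to the cross-entropy on the
    # 3 categories; the per-term values are the ones returned under "attr" and "cat" above.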
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
    return self.text_model.encode(x)
Encode the domain data into a unimodal representation.
Arguments:
- x (Any): data of the domain.
Returns:
torch.Tensor: a unimodal representation.
def decode(self, z: torch.Tensor) -> dict[str, Any]:
    out: dict[str, Any] = {}
    out.update(self.text_model.decode(z))
    pred_attr = self.pred_attr(z)
    pred_cat = self.pred_cat(z)
    out.update({"attr": [pred_cat, pred_attr, pred_attr]})
    return out
Decode data from unimodal representation back to the domain data.
Arguments:
- z (torch.Tensor): unimodal representation of the domain.
Returns:
Any: the original domain data.
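Concretely, the decode() above merges the wrapped text model's decoded output with an "attr" entry holding [pred_cat, pred_attr, pred_attr]. A small unpacking sketch (shapes assume a batch of size B; "model" is an existing Text2Attr and batch_text is whatever its text model's encode expects):

    z = model.encode(batch_text)          # (B, latent_dim) unimodal text latent
    out = model.decode(z)
    pred_cat, pred_attr, _ = out["attr"]  # (B, 3) softmax over categories, (B, 8) attributes in [-1, 1]
    # The remaining keys in `out` come from the wrapped text model's own decode().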
def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
    return self.decode(self.encode(x))
Same as torch.nn.Module.forward().
Arguments:
- *args: Whatever you decide to pass into the forward method.
- **kwargs: Keyword arguments are also possible.
Return:
Your model's output
def generic_step(
    self,
    x: Mapping[str, Any],
    mode: str = "train",
) -> torch.Tensor:
    text: Mapping[str, torch.Tensor] = x["t"]
    attr: Sequence[torch.Tensor] = x["attr"]
    text_l = self.text_model.encode(text)
    pred_attr = self.pred_attr(text_l)
    pred_cat = self.pred_cat(text_l)
    cats = torch.argmax(attr[0], dim=1)
    loss_cat = F.cross_entropy(pred_cat, cats)
    loss_attr = F.mse_loss(pred_attr, attr[1])
    total_loss = loss_cat + loss_attr
    pred_cats = pred_cat.argmax(dim=1)
    acc = (cats == pred_cats).sum() / cats.size(0)

    self.log(f"{mode}/loss_cat", loss_cat)
    self.log(f"{mode}/loss_attr", loss_attr)
    self.log(f"{mode}/acc_cat", acc)
    self.log(f"{mode}/loss", total_loss)

    return total_loss
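generic_step() has no docstring, but the code above fixes the mapping it expects: tokenized text under "t" and an attribute pair under "attr". A sketch of that structure (contents and sizes are illustrative assumptions; "module" is an existing Text2Attr):

    x = {
        "t": tokens,                          # whatever the wrapped GRUTextDomainModule.encode consumes
        "attr": [one_hot_categories, attrs],  # (B, 3) one-hot category, (B, 8) attribute vector
    }
    loss = module.generic_step(x, mode="train")
    # Logs train/loss_cat, train/loss_attr, train/acc_cat and train/loss,
    # and returns the summed category + attribute loss.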
def validation_step(  # type: ignore
    self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
) -> torch.Tensor:
    return self.generic_step(batch, "val")
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor
- dict - A dictionary. Can include any keys, but must include the key 'loss'.
- None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples::
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:
If you don't need to validate you don't need to implement this method.
Note:
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
def training_step(  # type: ignore
    self,
    batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
    _,
) -> torch.Tensor:
    data = batch[frozenset(["t", "attr"])]
    return self.generic_step(data, "train")
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
Arguments:
- batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
- batch_idx: The index of this batch.
- dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
- ~torch.Tensor - The loss tensor
- dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
- None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example::
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()

    # do training_step with decoder
    ...
    opt2.step()
Note:
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
def configure_optimizers(  # type: ignore
    self,
) -> dict[str, Any]:
    optimizer = AdamW(
        self.parameters(),
        lr=self.optim_lr,
        weight_decay=self.optim_weight_decay,
    )

    return {"optimizer": optimizer}
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
Return:
Any of these 5 options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
.. testcode::

    # The ReduceLROnPlateau scheduler requires a monitor
    def configure_optimizers(self):
        optimizer = Adam(...)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": ReduceLROnPlateau(optimizer, ...),
                "monitor": "metric_to_track",
                "frequency": "indicates how often the metric is updated",
                # If "monitor" references validation metrics, then "frequency" should be set to a
                # multiple of "trainer.check_val_every_n_epoch".
            },
        }

    # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
    def configure_optimizers(self):
        optimizer1 = Adam(...)
        optimizer2 = SGD(...)
        scheduler1 = ReduceLROnPlateau(optimizer1, ...)
        scheduler2 = LambdaLR(optimizer2, ...)
        return (
            {
                "optimizer": optimizer1,
                "lr_scheduler": {
                    "scheduler": scheduler1,
                    "monitor": "metric_to_track",
                },
            },
            {"optimizer": optimizer2, "lr_scheduler": scheduler2},
        )
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.
Note:
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
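Relating this to the class above: Text2Attr.configure_optimizers() returns only the AdamW optimizer, even though __init__ stores scheduler_args (max_lr, total_steps). A minimal sketch of wiring those stored arguments into the OneCycleLR scheduler already imported at the top of this module, using the dictionary format described in this docstring (an illustration, not what the class currently does):

    def configure_optimizers(self) -> dict[str, Any]:
        optimizer = AdamW(
            self.parameters(),
            lr=self.optim_lr,
            weight_decay=self.optim_weight_decay,
        )
        # OneCycleLR is stepped once per optimizer update, so "interval" is set to "step".
        # Note: total_steps defaults to 1 here, so scheduler_args must set it for a real run.
        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
        }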