shimmer_ssd.modules.domains.text

  1from collections.abc import Mapping, Sequence
  2from typing import Any
  3
  4import torch
  5import torch.nn.functional as F
  6from shimmer import LossOutput
  7from shimmer.modules.domain import DomainModule
  8from shimmer.modules.global_workspace import SchedulerArgs
  9from shimmer.modules.vae import (
 10    VAE,
 11    VAEDecoder,
 12    VAEEncoder,
 13    gaussian_nll,
 14    kl_divergence_loss,
 15    reparameterize,
 16)
 17from simple_shapes_dataset.text import composer
 18from simple_shapes_dataset.text.utils import inspect_all_choices
 19from torch import nn
 20from torch.optim.adamw import AdamW
 21from torch.optim.lr_scheduler import OneCycleLR
 22
 23
 24class Encoder(VAEEncoder):
 25    def __init__(
 26        self,
 27        in_dim: int,
 28        hidden_dim: int,
 29        out_dim: int,
 30    ):
 31        super().__init__()
 32
 33        self.in_dim = in_dim
 34        self.hidden_dim = hidden_dim
 35        self.out_dim = out_dim
 36
 37        self.encoder = nn.Sequential(
 38            nn.Linear(self.in_dim, hidden_dim),
 39            nn.ReLU(),
 40            nn.Linear(hidden_dim, hidden_dim),
 41            nn.ReLU(),
 42            nn.Linear(hidden_dim, out_dim),
 43            nn.ReLU(),
 44        )
 45
 46        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
 47        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
 48
 49    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
 50        out = torch.cat(list(x), dim=-1)
 51        out = self.encoder(out)
 52        return self.q_mean(out), self.q_logvar(out)
 53
 54
 55class Decoder(VAEDecoder):
 56    def __init__(
 57        self,
 58        in_dim: int,
 59        hidden_dim: int,
 60        out_dim: int,
 61    ):
 62        super().__init__()
 63
 64        self.in_dim = in_dim
 65        self.hidden_dim = hidden_dim
 66        self.out_dim = out_dim
 67
 68        self.decoder = nn.Sequential(
 69            nn.Linear(self.in_dim, self.hidden_dim),
 70            nn.ReLU(),
 71            nn.Linear(self.hidden_dim, self.hidden_dim),
 72            nn.ReLU(),
 73            nn.Linear(self.hidden_dim, self.out_dim),
 74            nn.Tanh(),
 75        )
 76
 77    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
 78        return [self.decoder(z)]
 79
 80
 81class TextDomainModule(DomainModule):
 82    in_dim = 768
 83
 84    def __init__(
 85        self,
 86        latent_dim: int,
 87        hidden_dim: int,
 88        beta: float = 1,
 89        optim_lr: float = 1e-3,
 90        optim_weight_decay: float = 0,
 91        scheduler_args: SchedulerArgs | None = None,
 92    ):
 93        super().__init__(latent_dim)
 94        self.save_hyperparameters()
 95
 96        self.hidden_dim = hidden_dim
 97
 98        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
 99        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
100        self.vae = VAE(vae_encoder, vae_decoder, beta)
101
102        self.attribute_cls = nn.Sequential(
103            nn.Linear(self.latent_dim, self.hidden_dim),
104            nn.ReLU(),
105            nn.Linear(self.hidden_dim, self.hidden_dim),
106            nn.ReLU(),
107        )
108        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
109        self.attribute_cls_attr = nn.Sequential(
110            nn.Linear(self.hidden_dim, 8), nn.Tanh()
111        )
112
113        self.composer_grammar_options = inspect_all_choices(composer)
114
115        self.grammar_cls = nn.Sequential(
116            nn.Linear(self.latent_dim, self.hidden_dim),
117            nn.ReLU(),
118            nn.Linear(self.hidden_dim, self.hidden_dim),
119            nn.ReLU(),
120        )
121        self.grammar_heads = nn.ModuleDict(
122            {
123                name: nn.Linear(self.hidden_dim, n_outputs)
124                for name, n_outputs in self.composer_grammar_options.items()
125            }
126        )
127
128        self.optim_lr = optim_lr
129        self.optim_weight_decay = optim_weight_decay
130
131        self.scheduler_args = SchedulerArgs(
132            max_lr=optim_lr,
133            total_steps=1,
134        )
135        self.scheduler_args.update(scheduler_args or {})
136
137    def compute_loss(
138        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
139    ) -> LossOutput:
140        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
141
142    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
143        return self.vae.encode((x["bert"],))
144
145    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
146        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
147        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
148        text["cls"] = attr_pred_cat
149        text["attr"] = attr_pred_attr
150        text["unpaired"] = torch.zeros_like(z[:, -1])
151        text.update(self.predict_grammar(z))
152        return text
153
154    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
155        attr_pred = self.attribute_cls(mean)
156        attr_pred_cat = self.attribute_cls_cat(attr_pred)
157        attr_pred_attr = self.attribute_cls_attr(attr_pred)
158        return attr_pred_cat, attr_pred_attr
159
160    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
161        grammar_pred = self.grammar_cls(mean)
162        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
163
164    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
165        grammar_pred = self.predict_grammar(mean)
166        return {
167            f"{name}_ce": F.cross_entropy(
168                pred, targets[name][:, 0].long(), reduction="sum"
169            )
170            for name, pred in grammar_pred.items()
171        }
172
173    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
174        return self.decode(self.encode(x))
175
176    def generic_step(
177        self,
178        x: Mapping[str, torch.Tensor],
179        mode: str = "train",
180    ) -> torch.Tensor:
181        (mean, logvar), reconstruction = self.vae((x["bert"],))
182
183        reconstruction_loss = gaussian_nll(
184            reconstruction[0], torch.tensor(0), x["bert"]
185        ).sum()
186
187        kl_loss = kl_divergence_loss(mean, logvar)
188
189        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
190
191        loss_attr_cat = F.cross_entropy(
192            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
193        )
194        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
195        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
196        grammar_losses = self.grammar_losses(mean, grammar_targets)
197
198        total_loss = (
199            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
200        )
201
202        for grammar_loss_name, grammar_loss in grammar_losses.items():
203            total_loss += grammar_loss
204            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
205
206        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
207        self.log(f"{mode}/kl_loss", kl_loss)
208        self.log(f"{mode}/attr_category", loss_attr_cat)
209        self.log(f"{mode}/attr_attr", loss_attr)
210        self.log(f"{mode}/loss", total_loss)
211        return total_loss
212
213    def validation_step(  # type: ignore
214        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
215    ) -> torch.Tensor:
216        x = batch["t"]
217        return self.generic_step(x, "val")
218
219    def training_step(  # type: ignore
220        self,
221        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
222        _,
223    ) -> torch.Tensor:
224        x = batch[frozenset(["t"])]["t"]
225        return self.generic_step(x, "train")
226
227    def configure_optimizers(  # type: ignore
228        self,
229    ) -> dict[str, Any]:
230        optimizer = torch.optim.AdamW(
231            self.parameters(),
232            lr=self.optim_lr,
233            weight_decay=self.optim_weight_decay,
234        )
235        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
236
237        return {
238            "optimizer": optimizer,
239            "lr_scheduler": {
240                "scheduler": lr_scheduler,
241                "interval": "step",
242            },
243        }
244
245
246class GRUEncoder(nn.Module):
247    def __init__(
248        self,
249        in_dim: int,
250        hidden_dim: int,
251        out_dim: int,
252    ):
253        super().__init__()
254
255        self.in_dim = in_dim
256        self.hidden_dim = hidden_dim
257        self.out_dim = out_dim
258
259    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
260        out = torch.cat(list(x), dim=-1)
261        out = self.encoder(out)
262        return out
263
264
265class GRUTextDomainModule(DomainModule):
266    in_dim = 768
267
268    def __init__(
269        self,
270        latent_dim: int,
271        hidden_dim: int,
272        vocab_size: int,
273        seq_length: int,
274        optim_lr: float = 1e-3,
275        optim_weight_decay: float = 0,
276        scheduler_args: SchedulerArgs | None = None,
277        padding_token: int = 0,
278        reconstruction_coef: float = 0.5,
279        kl_coef: float = 0.05,
280    ):
281        super().__init__(latent_dim)
282        self.is_frozen = False
283        self.save_hyperparameters()
284
285        self.hidden_dim = hidden_dim
286        self.latent_dim = latent_dim
287        self.vocab_size = vocab_size
288        self.seq_length = seq_length
289
290        self._padding_token = padding_token
291
292        self.reconstruction_coef = reconstruction_coef
293        self.kl_coef = kl_coef
294
295        self.projector = nn.Sequential(
296            nn.Linear(self.in_dim, self.hidden_dim),
297            nn.ReLU(),
298            nn.Linear(self.hidden_dim, self.hidden_dim),
299            nn.ReLU(),
300            nn.Linear(self.hidden_dim, self.latent_dim),
301            nn.ReLU(),
302        )
303        self.q_mean = nn.Linear(self.latent_dim, self.latent_dim)
304        self.q_logvar = nn.Linear(self.latent_dim, self.latent_dim)
305
306        self.projector_dec = nn.Sequential(
307            nn.Linear(self.latent_dim, self.hidden_dim),
308            nn.ReLU(),
309            nn.Linear(self.hidden_dim, self.hidden_dim),
310            nn.ReLU(),
311            nn.Linear(self.hidden_dim, self.in_dim),
312        )
313
314        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
315
316        self.num_layers = 2
317        self.decoder = nn.GRU(
318            self.latent_dim,
319            self.hidden_dim,
320            num_layers=self.num_layers,
321            batch_first=True,
322        )
323        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
324
325        self.optim_lr = optim_lr
326        self.optim_weight_decay = optim_weight_decay
327
328        self.scheduler_args = SchedulerArgs(
329            max_lr=optim_lr,
330            total_steps=1,
331        )
332        self.scheduler_args.update(scheduler_args or {})
333
334    def compute_loss(
335        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
336    ) -> LossOutput:
337        mse_loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
338
339        return LossOutput(mse_loss, {"mse_loss": mse_loss})
340
341    def compute_domain_loss(self, domain: Any) -> LossOutput:
342        z = self.encode(domain)
343        loss, acc = self.text_token_loss(z, domain)
344        return LossOutput(loss, {"acc": acc})
345
346    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
347        z = self.projector(x["bert"])
348        return self.q_mean(z)
349
350    def encode_dist(
351        self, x: Mapping[str, torch.Tensor]
352    ) -> tuple[torch.Tensor, torch.Tensor]:
353        z = self.projector(x["bert"])
354        return self.q_mean(z), self.q_logvar(z)
355
356    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
357        context = z.unsqueeze(1)
358        hidden = None
359        outputs = []
360
361        for _ in range(self.seq_length):
362            out, hidden = self.decoder(context, hidden)
363            token_dist = self.text_head(out)
364            tokens = torch.argmax(token_dist, dim=-1)
365            outputs.append(token_dist)
366
367            context = self.embeddings(tokens)
368
369        token_dists = torch.cat(outputs, dim=1)
370        tokens = torch.argmax(token_dists, dim=-1)
371        return {"token_dist": token_dists, "tokens": tokens}
372
373    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
374        out, _ = self.decoder(z)
375        tokens_dist = self.text_head(out)
376        tokens = torch.argmax(tokens_dist, -1)
377        return {"token_dist": tokens_dist, "tokens": tokens}
378
379    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
380        z, _ = self.encode_dist(x)
381        return self.decode(z)
382
383    def text_token_loss(
384        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
385    ) -> tuple[torch.Tensor, torch.Tensor]:
386        context = z.unsqueeze(1)
387        real_tokens = self.embeddings(target["tokens"][:, :-1])
388        seq = torch.cat([context, real_tokens], dim=1)
389
390        out = self.decode_one(seq)
391
392        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
393        padding_mask = target["tokens"] != self._padding_token
394
395        padded_out = out["tokens"][padding_mask]
396        padded_target = target["tokens"][padding_mask]
397
398        acc = (padded_out == padded_target).sum() / padded_out.size(0)
399        return loss, acc
400
401    def generic_step(
402        self, x: Mapping[str, torch.Tensor], mode: str = "train"
403    ) -> torch.Tensor:
404        mean, logvar = self.encode_dist(x)
405        z = reparameterize(mean, logvar)
406        noise = torch.randn_like(z) * 0.5
407        z_text = z + noise
408        loss, acc = self.text_token_loss(z_text, x)
409
410        x_hat = self.projector_dec(z)
411        reconstruction_loss = gaussian_nll(x_hat, torch.tensor(0), x["bert"]).sum()
412        kl_loss = kl_divergence_loss(mean, logvar)
413
414        total_loss = (
415            loss
416            + self.reconstruction_coef * reconstruction_loss
417            + self.kl_coef * kl_loss
418        )
419
420        self.log(f"{mode}/kl_loss", kl_loss)
421        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
422        self.log(f"{mode}/total_loss", total_loss)
423        self.log(f"{mode}/loss", loss)
424        self.log(f"{mode}/acc", acc)
425        return total_loss  # loss
426
427    def validation_step(  # type: ignore
428        self,
429        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
430        _,
431    ) -> torch.Tensor:
432        x = batch[frozenset(["t"])]["t"]
433        return self.generic_step(x, "val")
434
435    def training_step(  # type: ignore
436        self,
437        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
438        _,
439    ) -> torch.Tensor:
440        x = batch[frozenset(["t"])]["t"]
441        return self.generic_step(x, "train")
442
443    def configure_optimizers(  # type: ignore
444        self,
445    ) -> dict[str, Any]:
446        optimizer = AdamW(
447            self.parameters(),
448            lr=self.optim_lr,
449            weight_decay=self.optim_weight_decay,
450        )
451
452        return {"optimizer": optimizer}
453
454
455class Text2Attr(DomainModule):
456    def __init__(
457        self,
458        latent_dim: int,
459        hidden_dim: int,
460        text_model: GRUTextDomainModule,
461        optim_lr: float = 1e-3,
462        optim_weight_decay: float = 0,
463        scheduler_args: SchedulerArgs | None = None,
464    ) -> None:
465        super().__init__(latent_dim)
466        self.save_hyperparameters(ignore=["text_model"])
467
468        self.text_model = text_model
469
470        self.pred_attr = nn.Sequential(
471            nn.Linear(latent_dim, hidden_dim),
472            nn.ReLU(),
473            nn.Linear(hidden_dim, hidden_dim),
474            nn.ReLU(),
475            nn.Linear(hidden_dim, hidden_dim),
476            nn.ReLU(),
477            nn.Linear(hidden_dim, 8),
478            nn.Tanh(),
479        )
480        self.pred_cat = nn.Sequential(
481            nn.Linear(latent_dim, hidden_dim),
482            nn.ReLU(),
483            nn.Linear(hidden_dim, hidden_dim),
484            nn.ReLU(),
485            nn.Linear(hidden_dim, hidden_dim),
486            nn.ReLU(),
487            nn.Linear(hidden_dim, 3),
488            nn.Softmax(dim=1),
489        )
490
491        self.optim_lr = optim_lr
492        self.optim_weight_decay = optim_weight_decay
493
494        self.scheduler_args = SchedulerArgs(
495            max_lr=optim_lr,
496            total_steps=1,
497        )
498        self.scheduler_args.update(scheduler_args or {})
499
500    def compute_loss(
501        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
502    ) -> LossOutput:
503        pred_attr = self.pred_attr(pred)
504        pred_cat = self.pred_cat(pred)
505        target_attr = self.pred_attr(target)
506        target_cat = self.pred_cat(target)
507        loss_attr = F.mse_loss(pred_attr, target_attr)
508        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
509        loss = loss_attr + loss_cat
510        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
511
512    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
513        return self.text_model.encode(x)
514
515    def decode(self, z: torch.Tensor) -> dict[str, Any]:
516        out: dict[str, Any] = {}
517        out.update(self.text_model.decode(z))
518        pred_attr = self.pred_attr(z)
519        pred_cat = self.pred_cat(z)
520        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
521        return out
522
523    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
524        return self.decode(self.encode(x))
525
526    def generic_step(
527        self,
528        x: Mapping[str, Any],
529        mode: str = "train",
530    ) -> torch.Tensor:
531        text: Mapping[str, torch.Tensor] = x["t"]
532        attr: Sequence[torch.Tensor] = x["attr"]
533        text_l = self.text_model.encode(text)
534        pred_attr = self.pred_attr(text_l)
535        pred_cat = self.pred_cat(text_l)
536        cats = torch.argmax(attr[0], dim=1)
537        loss_cat = F.cross_entropy(pred_cat, cats)
538        loss_attr = F.mse_loss(pred_attr, attr[1])
539        total_loss = loss_cat + loss_attr
540        pred_cats = pred_cat.argmax(dim=1)
541        acc = (cats == pred_cats).sum() / cats.size(0)
542
543        self.log(f"{mode}/loss_cat", loss_cat)
544        self.log(f"{mode}/loss_attr", loss_attr)
545        self.log(f"{mode}/acc_cat", acc)
546        self.log(f"{mode}/loss", total_loss)
547
548        return total_loss
549
550    def validation_step(  # type: ignore
551        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
552    ) -> torch.Tensor:
553        return self.generic_step(batch, "val")
554
555    def training_step(  # type: ignore
556        self,
557        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
558        _,
559    ) -> torch.Tensor:
560        data = batch[frozenset(["t", "attr"])]
561        return self.generic_step(data, "train")
562
563    def configure_optimizers(  # type: ignore
564        self,
565    ) -> dict[str, Any]:
566        optimizer = AdamW(
567            self.parameters(),
568            lr=self.optim_lr,
569            weight_decay=self.optim_weight_decay,
570        )
571
572        return {"optimizer": optimizer}
class Encoder(shimmer.modules.vae.VAEEncoder):
25class Encoder(VAEEncoder):
26    def __init__(
27        self,
28        in_dim: int,
29        hidden_dim: int,
30        out_dim: int,
31    ):
32        super().__init__()
33
34        self.in_dim = in_dim
35        self.hidden_dim = hidden_dim
36        self.out_dim = out_dim
37
38        self.encoder = nn.Sequential(
39            nn.Linear(self.in_dim, hidden_dim),
40            nn.ReLU(),
41            nn.Linear(hidden_dim, hidden_dim),
42            nn.ReLU(),
43            nn.Linear(hidden_dim, out_dim),
44            nn.ReLU(),
45        )
46
47        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
48        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
49
50    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
51        out = torch.cat(list(x), dim=-1)
52        out = self.encoder(out)
53        return self.q_mean(out), self.q_logvar(out)

Base class for a VAE encoder.

Encoder(in_dim: int, hidden_dim: int, out_dim: int)
26    def __init__(
27        self,
28        in_dim: int,
29        hidden_dim: int,
30        out_dim: int,
31    ):
32        super().__init__()
33
34        self.in_dim = in_dim
35        self.hidden_dim = hidden_dim
36        self.out_dim = out_dim
37
38        self.encoder = nn.Sequential(
39            nn.Linear(self.in_dim, hidden_dim),
40            nn.ReLU(),
41            nn.Linear(hidden_dim, hidden_dim),
42            nn.ReLU(),
43            nn.Linear(hidden_dim, out_dim),
44            nn.ReLU(),
45        )
46
47        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
48        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
encoder
q_mean
q_logvar
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
50    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
51        out = torch.cat(list(x), dim=-1)
52        out = self.encoder(out)
53        return self.q_mean(out), self.q_logvar(out)

Encode representation with VAE.

Arguments:
  • x (Sequence[torch.Tensor]): input tensors, concatenated along the last dimension before encoding
Returns:

tuple[torch.Tensor, torch.Tensor]: the mean and log variance
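
A minimal usage sketch; the sizes are hypothetical, with 768 matching the BERT embedding dimension used by TextDomainModule:

import torch
from shimmer_ssd.modules.domains.text import Encoder

encoder = Encoder(in_dim=768, hidden_dim=256, out_dim=12)
mean, logvar = encoder((torch.randn(32, 768),))  # forward takes a sequence of tensors
assert mean.shape == logvar.shape == (32, 12)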

class Decoder(shimmer.modules.vae.VAEDecoder):
56class Decoder(VAEDecoder):
57    def __init__(
58        self,
59        in_dim: int,
60        hidden_dim: int,
61        out_dim: int,
62    ):
63        super().__init__()
64
65        self.in_dim = in_dim
66        self.hidden_dim = hidden_dim
67        self.out_dim = out_dim
68
69        self.decoder = nn.Sequential(
70            nn.Linear(self.in_dim, self.hidden_dim),
71            nn.ReLU(),
72            nn.Linear(self.hidden_dim, self.hidden_dim),
73            nn.ReLU(),
74            nn.Linear(self.hidden_dim, self.out_dim),
75            nn.Tanh(),
76        )
77
78    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
79        return [self.decoder(z)]

Base class for a VAE decoder.

Decoder(in_dim: int, hidden_dim: int, out_dim: int)
57    def __init__(
58        self,
59        in_dim: int,
60        hidden_dim: int,
61        out_dim: int,
62    ):
63        super().__init__()
64
65        self.in_dim = in_dim
66        self.hidden_dim = hidden_dim
67        self.out_dim = out_dim
68
69        self.decoder = nn.Sequential(
70            nn.Linear(self.in_dim, self.hidden_dim),
71            nn.ReLU(),
72            nn.Linear(self.hidden_dim, self.hidden_dim),
73            nn.ReLU(),
74            nn.Linear(self.hidden_dim, self.out_dim),
75            nn.Tanh(),
76        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
decoder
def forward(self, z: torch.Tensor) -> list[torch.Tensor]:
78    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
79        return [self.decoder(z)]

Decode representation with VAE

Arguments:
  • z (torch.Tensor): VAE latent representation
Returns:

Any: the reconstructed input
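
Together with the Encoder above, this is the decoder half of the VAE used by TextDomainModule; a rough roundtrip sketch with hypothetical sizes:

import torch
from shimmer.modules.vae import reparameterize
from shimmer_ssd.modules.domains.text import Decoder, Encoder

encoder = Encoder(in_dim=768, hidden_dim=256, out_dim=12)
decoder = Decoder(in_dim=12, hidden_dim=256, out_dim=768)

mean, logvar = encoder((torch.randn(8, 768),))
z = reparameterize(mean, logvar)  # sample a latent from the posterior
(reconstruction,) = decoder(z)    # Decoder returns a one-element list
assert reconstruction.shape == (8, 768)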

class TextDomainModule(shimmer.modules.domain.DomainModule):
 82class TextDomainModule(DomainModule):
 83    in_dim = 768
 84
 85    def __init__(
 86        self,
 87        latent_dim: int,
 88        hidden_dim: int,
 89        beta: float = 1,
 90        optim_lr: float = 1e-3,
 91        optim_weight_decay: float = 0,
 92        scheduler_args: SchedulerArgs | None = None,
 93    ):
 94        super().__init__(latent_dim)
 95        self.save_hyperparameters()
 96
 97        self.hidden_dim = hidden_dim
 98
 99        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
100        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
101        self.vae = VAE(vae_encoder, vae_decoder, beta)
102
103        self.attribute_cls = nn.Sequential(
104            nn.Linear(self.latent_dim, self.hidden_dim),
105            nn.ReLU(),
106            nn.Linear(self.hidden_dim, self.hidden_dim),
107            nn.ReLU(),
108        )
109        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
110        self.attribute_cls_attr = nn.Sequential(
111            nn.Linear(self.hidden_dim, 8), nn.Tanh()
112        )
113
114        self.composer_grammar_options = inspect_all_choices(composer)
115
116        self.grammar_cls = nn.Sequential(
117            nn.Linear(self.latent_dim, self.hidden_dim),
118            nn.ReLU(),
119            nn.Linear(self.hidden_dim, self.hidden_dim),
120            nn.ReLU(),
121        )
122        self.grammar_heads = nn.ModuleDict(
123            {
124                name: nn.Linear(self.hidden_dim, n_outputs)
125                for name, n_outputs in self.composer_grammar_options.items()
126            }
127        )
128
129        self.optim_lr = optim_lr
130        self.optim_weight_decay = optim_weight_decay
131
132        self.scheduler_args = SchedulerArgs(
133            max_lr=optim_lr,
134            total_steps=1,
135        )
136        self.scheduler_args.update(scheduler_args or {})
137
138    def compute_loss(
139        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
140    ) -> LossOutput:
141        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
142
143    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
144        return self.vae.encode((x["bert"],))
145
146    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
147        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
148        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
149        text["cls"] = attr_pred_cat
150        text["attr"] = attr_pred_attr
151        text["unpaired"] = torch.zeros_like(z[:, -1])
152        text.update(self.predict_grammar(z))
153        return text
154
155    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
156        attr_pred = self.attribute_cls(mean)
157        attr_pred_cat = self.attribute_cls_cat(attr_pred)
158        attr_pred_attr = self.attribute_cls_attr(attr_pred)
159        return attr_pred_cat, attr_pred_attr
160
161    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
162        grammar_pred = self.grammar_cls(mean)
163        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
164
165    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
166        grammar_pred = self.predict_grammar(mean)
167        return {
168            f"{name}_ce": F.cross_entropy(
169                pred, targets[name][:, 0].long(), reduction="sum"
170            )
171            for name, pred in grammar_pred.items()
172        }
173
174    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
175        return self.decode(self.encode(x))
176
177    def generic_step(
178        self,
179        x: Mapping[str, torch.Tensor],
180        mode: str = "train",
181    ) -> torch.Tensor:
182        (mean, logvar), reconstruction = self.vae((x["bert"],))
183
184        reconstruction_loss = gaussian_nll(
185            reconstruction[0], torch.tensor(0), x["bert"]
186        ).sum()
187
188        kl_loss = kl_divergence_loss(mean, logvar)
189
190        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
191
192        loss_attr_cat = F.cross_entropy(
193            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
194        )
195        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
196        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
197        grammar_losses = self.grammar_losses(mean, grammar_targets)
198
199        total_loss = (
200            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
201        )
202
203        for grammar_loss_name, grammar_loss in grammar_losses.items():
204            total_loss += grammar_loss
205            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
206
207        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
208        self.log(f"{mode}/kl_loss", kl_loss)
209        self.log(f"{mode}/attr_category", loss_attr_cat)
210        self.log(f"{mode}/attr_attr", loss_attr)
211        self.log(f"{mode}/loss", total_loss)
212        return total_loss
213
214    def validation_step(  # type: ignore
215        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
216    ) -> torch.Tensor:
217        x = batch["t"]
218        return self.generic_step(x, "val")
219
220    def training_step(  # type: ignore
221        self,
222        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
223        _,
224    ) -> torch.Tensor:
225        x = batch[frozenset(["t"])]["t"]
226        return self.generic_step(x, "train")
227
228    def configure_optimizers(  # type: ignore
229        self,
230    ) -> dict[str, Any]:
231        optimizer = torch.optim.AdamW(
232            self.parameters(),
233            lr=self.optim_lr,
234            weight_decay=self.optim_weight_decay,
235        )
236        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
237
238        return {
239            "optimizer": optimizer,
240            "lr_scheduler": {
241                "scheduler": lr_scheduler,
242                "interval": "step",
243            },
244        }

Base class for a DomainModule that defines domain specific modules of the GW.

TextDomainModule( latent_dim: int, hidden_dim: int, beta: float = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None)
 85    def __init__(
 86        self,
 87        latent_dim: int,
 88        hidden_dim: int,
 89        beta: float = 1,
 90        optim_lr: float = 1e-3,
 91        optim_weight_decay: float = 0,
 92        scheduler_args: SchedulerArgs | None = None,
 93    ):
 94        super().__init__(latent_dim)
 95        self.save_hyperparameters()
 96
 97        self.hidden_dim = hidden_dim
 98
 99        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
100        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
101        self.vae = VAE(vae_encoder, vae_decoder, beta)
102
103        self.attribute_cls = nn.Sequential(
104            nn.Linear(self.latent_dim, self.hidden_dim),
105            nn.ReLU(),
106            nn.Linear(self.hidden_dim, self.hidden_dim),
107            nn.ReLU(),
108        )
109        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
110        self.attribute_cls_attr = nn.Sequential(
111            nn.Linear(self.hidden_dim, 8), nn.Tanh()
112        )
113
114        self.composer_grammar_options = inspect_all_choices(composer)
115
116        self.grammar_cls = nn.Sequential(
117            nn.Linear(self.latent_dim, self.hidden_dim),
118            nn.ReLU(),
119            nn.Linear(self.hidden_dim, self.hidden_dim),
120            nn.ReLU(),
121        )
122        self.grammar_heads = nn.ModuleDict(
123            {
124                name: nn.Linear(self.hidden_dim, n_outputs)
125                for name, n_outputs in self.composer_grammar_options.items()
126            }
127        )
128
129        self.optim_lr = optim_lr
130        self.optim_weight_decay = optim_weight_decay
131
132        self.scheduler_args = SchedulerArgs(
133            max_lr=optim_lr,
134            total_steps=1,
135        )
136        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
in_dim = 768
hidden_dim
vae
attribute_cls
attribute_cls_cat
attribute_cls_attr
composer_grammar_options
grammar_cls
grammar_heads
optim_lr
optim_weight_decay
scheduler_args
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
138    def compute_loss(
139        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
140    ) -> LossOutput:
141        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
143    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
144        return self.vae.encode((x["bert"],))

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.
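
The domain data is a mapping that must contain the precomputed BERT sentence embedding under the "bert" key; only that key is read here. A sketch with hypothetical sizes:

import torch
from shimmer_ssd.modules.domains.text import TextDomainModule

module = TextDomainModule(latent_dim=12, hidden_dim=256)
z = module.encode({"bert": torch.randn(16, 768)})
# z: unimodal latent representation of shape (16, 12)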

def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
146    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
147        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
148        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
149        text["cls"] = attr_pred_cat
150        text["attr"] = attr_pred_attr
151        text["unpaired"] = torch.zeros_like(z[:, -1])
152        text.update(self.predict_grammar(z))
153        return text

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.
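
The returned mapping mirrors the keys of the training data: the reconstructed BERT embedding plus the attribute and grammar predictions read off the latent. Continuing the encode sketch above:

out = module.decode(z)
# out["bert"]: (16, 768) reconstruction, out["cls"]: (16, 3) category logits,
# out["attr"]: (16, 8) attributes in [-1, 1], out["unpaired"]: (16,) zeros,
# plus one logits tensor per grammar choice returned by predict_grammar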

def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
155    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
156        attr_pred = self.attribute_cls(mean)
157        attr_pred_cat = self.attribute_cls_cat(attr_pred)
158        attr_pred_attr = self.attribute_cls_attr(attr_pred)
159        return attr_pred_cat, attr_pred_attr
def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
161    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
162        grammar_pred = self.grammar_cls(mean)
163        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
165    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
166        grammar_pred = self.predict_grammar(mean)
167        return {
168            f"{name}_ce": F.cross_entropy(
169                pred, targets[name][:, 0].long(), reduction="sum"
170            )
171            for name, pred in grammar_pred.items()
172        }
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
174    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
175        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def generic_step(self, x: Mapping[str, torch.Tensor], mode: str = 'train') -> torch.Tensor:
177    def generic_step(
178        self,
179        x: Mapping[str, torch.Tensor],
180        mode: str = "train",
181    ) -> torch.Tensor:
182        (mean, logvar), reconstruction = self.vae((x["bert"],))
183
184        reconstruction_loss = gaussian_nll(
185            reconstruction[0], torch.tensor(0), x["bert"]
186        ).sum()
187
188        kl_loss = kl_divergence_loss(mean, logvar)
189
190        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
191
192        loss_attr_cat = F.cross_entropy(
193            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
194        )
195        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
196        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
197        grammar_losses = self.grammar_losses(mean, grammar_targets)
198
199        total_loss = (
200            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
201        )
202
203        for grammar_loss_name, grammar_loss in grammar_losses.items():
204            total_loss += grammar_loss
205            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
206
207        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
208        self.log(f"{mode}/kl_loss", kl_loss)
209        self.log(f"{mode}/attr_category", loss_attr_cat)
210        self.log(f"{mode}/attr_attr", loss_attr)
211        self.log(f"{mode}/loss", total_loss)
212        return total_loss
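
The objective assembled above combines the β-VAE loss on the BERT embedding with the supervised attribute and grammar heads; as a single equation (LaTeX notation):

\mathcal{L}_{\text{total}} =
    \mathrm{NLL}\!\left(\hat{x}_{\text{bert}}, x_{\text{bert}}\right)
    + \beta \, D_{\mathrm{KL}}\!\left(q(z \mid x) \,\|\, \mathcal{N}(0, I)\right)
    + \mathrm{CE}\!\left(\hat{y}_{\text{cat}}, y_{\text{cat}}\right)
    + \mathrm{MSE}\!\left(\hat{y}_{\text{attr}}, y_{\text{attr}}\right)
    + \sum_{g \in \text{grammar}} \mathrm{CE}\!\left(\hat{y}_{g}, y_{g}\right)

where the reconstruction, attribute, and grammar terms are summed over the batch (reduction="sum") and β is the beta passed to the constructor.
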
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor:
214    def validation_step(  # type: ignore
215        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
216    ) -> torch.Tensor:
217        x = batch["t"]
218        return self.generic_step(x, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Arguments:
  • batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
220    def training_step(  # type: ignore
221        self,
222        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
223        _,
224    ) -> torch.Tensor:
225        x = batch[frozenset(["t"])]["t"]
226        return self.generic_step(x, "train")

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
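
Note the asymmetry in the batch formats the two steps above expect: validation batches are keyed directly by domain name, while training batches follow the combined-loader convention and are keyed by a frozenset of domain names. Schematically (values elided):

# Batch formats expected by TextDomainModule (values elided with ...):
val_batch = {"t": {"bert": ..., "cls": ..., "attr": ...}}
train_batch = {frozenset(["t"]): {"t": {"bert": ..., "cls": ..., "attr": ...}}}
# validation_step reads batch["t"]; training_step reads batch[frozenset(["t"])]["t"]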

def configure_optimizers(self) -> dict[str, typing.Any]:
228    def configure_optimizers(  # type: ignore
229        self,
230    ) -> dict[str, Any]:
231        optimizer = torch.optim.AdamW(
232            self.parameters(),
233            lr=self.optim_lr,
234            weight_decay=self.optim_weight_decay,
235        )
236        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
237
238        return {
239            "optimizer": optimizer,
240            "lr_scheduler": {
241                "scheduler": lr_scheduler,
242                "interval": "step",
243            },
244        }

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

For example:

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
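
The OneCycleLR built here takes its max_lr and total_steps from self.scheduler_args, whose default total_steps=1 is only a placeholder; the caller is expected to override it with the real number of training steps. A hypothetical construction (all numbers made up):

from shimmer.modules.global_workspace import SchedulerArgs
from shimmer_ssd.modules.domains.text import TextDomainModule

module = TextDomainModule(
    latent_dim=12,
    hidden_dim=256,
    optim_lr=1e-3,
    scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=50_000),  # e.g. epochs * batches_per_epoch
)
optim_config = module.configure_optimizers()
# optim_config["lr_scheduler"]["scheduler"] is a OneCycleLR stepped on every optimizer update ("interval": "step")
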
class GRUEncoder(torch.nn.modules.module.Module):
247class GRUEncoder(nn.Module):
248    def __init__(
249        self,
250        in_dim: int,
251        hidden_dim: int,
252        out_dim: int,
253    ):
254        super().__init__()
255
256        self.in_dim = in_dim
257        self.hidden_dim = hidden_dim
258        self.out_dim = out_dim
259
260    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
261        out = torch.cat(list(x), dim=-1)
262        out = self.encoder(out)
263        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

GRUEncoder(in_dim: int, hidden_dim: int, out_dim: int)
248    def __init__(
249        self,
250        in_dim: int,
251        hidden_dim: int,
252        out_dim: int,
253    ):
254        super().__init__()
255
256        self.in_dim = in_dim
257        self.hidden_dim = hidden_dim
258        self.out_dim = out_dim

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
260    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
261        out = torch.cat(list(x), dim=-1)
262        out = self.encoder(out)
263        return out

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
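
Note that GRUEncoder.forward references self.encoder, which __init__ never defines, so calling the module as written raises AttributeError. A minimal sketch of a usable variant, assuming an MLP like the VAE Encoder above (the layer layout is an assumption, not the library's actual implementation):

from collections.abc import Sequence

import torch
from torch import nn


class PatchedGRUEncoder(nn.Module):
    """Hypothetical GRUEncoder variant that actually defines the `encoder` it calls."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        self.encoder = nn.Sequential(  # assumed architecture
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
        # Concatenate the inputs along the last dimension, as the original forward does
        return self.encoder(torch.cat(list(x), dim=-1))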

class GRUTextDomainModule(shimmer.modules.domain.DomainModule):
266class GRUTextDomainModule(DomainModule):
267    in_dim = 768
268
269    def __init__(
270        self,
271        latent_dim: int,
272        hidden_dim: int,
273        vocab_size: int,
274        seq_length: int,
275        optim_lr: float = 1e-3,
276        optim_weight_decay: float = 0,
277        scheduler_args: SchedulerArgs | None = None,
278        padding_token: int = 0,
279        reconstruction_coef: float = 0.5,
280        kl_coef: float = 0.05,
281    ):
282        super().__init__(latent_dim)
283        self.is_frozen = False
284        self.save_hyperparameters()
285
286        self.hidden_dim = hidden_dim
287        self.latent_dim = latent_dim
288        self.vocab_size = vocab_size
289        self.seq_length = seq_length
290
291        self._padding_token = padding_token
292
293        self.reconstruction_coef = reconstruction_coef
294        self.kl_coef = kl_coef
295
296        self.projector = nn.Sequential(
297            nn.Linear(self.in_dim, self.hidden_dim),
298            nn.ReLU(),
299            nn.Linear(self.hidden_dim, self.hidden_dim),
300            nn.ReLU(),
301            nn.Linear(self.hidden_dim, self.latent_dim),
302            nn.ReLU(),
303        )
304        self.q_mean = nn.Linear(self.latent_dim, self.latent_dim)
305        self.q_logvar = nn.Linear(self.latent_dim, self.latent_dim)
306
307        self.projector_dec = nn.Sequential(
308            nn.Linear(self.latent_dim, self.hidden_dim),
309            nn.ReLU(),
310            nn.Linear(self.hidden_dim, self.hidden_dim),
311            nn.ReLU(),
312            nn.Linear(self.hidden_dim, self.in_dim),
313        )
314
315        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
316
317        self.num_layers = 2
318        self.decoder = nn.GRU(
319            self.latent_dim,
320            self.hidden_dim,
321            num_layers=self.num_layers,
322            batch_first=True,
323        )
324        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
325
326        self.optim_lr = optim_lr
327        self.optim_weight_decay = optim_weight_decay
328
329        self.scheduler_args = SchedulerArgs(
330            max_lr=optim_lr,
331            total_steps=1,
332        )
333        self.scheduler_args.update(scheduler_args or {})
334
335    def compute_loss(
336        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
337    ) -> LossOutput:
338        mse_loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
339
340        return LossOutput(mse_loss, {"mse_loss": mse_loss})
341
342    def compute_domain_loss(self, domain: Any) -> LossOutput:
343        z = self.encode(domain)
344        loss, acc = self.text_token_loss(z, domain)
345        return LossOutput(loss, {"acc": acc})
346
347    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
348        z = self.projector(x["bert"])
349        return self.q_mean(z)
350
351    def encode_dist(
352        self, x: Mapping[str, torch.Tensor]
353    ) -> tuple[torch.Tensor, torch.Tensor]:
354        z = self.projector(x["bert"])
355        return self.q_mean(z), self.q_logvar(z)
356
357    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
358        context = z.unsqueeze(1)
359        hidden = None
360        outputs = []
361
362        for _ in range(self.seq_length):
363            out, hidden = self.decoder(context, hidden)
364            token_dist = self.text_head(out)
365            tokens = torch.argmax(token_dist, dim=-1)
366            outputs.append(token_dist)
367
368            context = self.embeddings(tokens)
369
370        token_dists = torch.cat(outputs, dim=1)
371        tokens = torch.argmax(token_dists, dim=-1)
372        return {"token_dist": token_dists, "tokens": tokens}
373
374    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
375        out, _ = self.decoder(z)
376        tokens_dist = self.text_head(out)
377        tokens = torch.argmax(tokens_dist, -1)
378        return {"token_dist": tokens_dist, "tokens": tokens}
379
380    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
381        z, _ = self.encode_dist(x)
382        return self.decode(z)
383
384    def text_token_loss(
385        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
386    ) -> tuple[torch.Tensor, torch.Tensor]:
387        context = z.unsqueeze(1)
388        real_tokens = self.embeddings(target["tokens"][:, :-1])
389        seq = torch.cat([context, real_tokens], dim=1)
390
391        out = self.decode_one(seq)
392
393        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
394        padding_mask = target["tokens"] != self._padding_token
395
396        padded_out = out["tokens"][padding_mask]
397        padded_target = target["tokens"][padding_mask]
398
399        acc = (padded_out == padded_target).sum() / padded_out.size(0)
400        return loss, acc
401
402    def generic_step(
403        self, x: Mapping[str, torch.Tensor], mode: str = "train"
404    ) -> torch.Tensor:
405        mean, logvar = self.encode_dist(x)
406        z = reparameterize(mean, logvar)
407        noise = torch.randn_like(z) * 0.5
408        z_text = z + noise
409        loss, acc = self.text_token_loss(z_text, x)
410
411        x_hat = self.projector_dec(z)
412        reconstruction_loss = gaussian_nll(x_hat, torch.tensor(0), x["bert"]).sum()
413        kl_loss = kl_divergence_loss(mean, logvar)
414
415        total_loss = (
416            loss
417            + self.reconstruction_coef * reconstruction_loss
418            + self.kl_coef * kl_loss
419        )
420
421        self.log(f"{mode}/kl_loss", kl_loss)
422        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
423        self.log(f"{mode}/total_loss", total_loss)
424        self.log(f"{mode}/loss", loss)
425        self.log(f"{mode}/acc", acc)
426        return total_loss  # loss
427
428    def validation_step(  # type: ignore
429        self,
430        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
431        _,
432    ) -> torch.Tensor:
433        x = batch[frozenset(["t"])]["t"]
434        return self.generic_step(x, "val")
435
436    def training_step(  # type: ignore
437        self,
438        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
439        _,
440    ) -> torch.Tensor:
441        x = batch[frozenset(["t"])]["t"]
442        return self.generic_step(x, "train")
443
444    def configure_optimizers(  # type: ignore
445        self,
446    ) -> dict[str, Any]:
447        optimizer = AdamW(
448            self.parameters(),
449            lr=self.optim_lr,
450            weight_decay=self.optim_weight_decay,
451        )
452
453        return {"optimizer": optimizer}

Base class for a DomainModule that defines domain specific modules of the GW.

GRUTextDomainModule( latent_dim: int, hidden_dim: int, vocab_size: int, seq_length: int, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None, padding_token: int = 0, reconstruction_coef: float = 0.5, kl_coef: float = 0.05)
269    def __init__(
270        self,
271        latent_dim: int,
272        hidden_dim: int,
273        vocab_size: int,
274        seq_length: int,
275        optim_lr: float = 1e-3,
276        optim_weight_decay: float = 0,
277        scheduler_args: SchedulerArgs | None = None,
278        padding_token: int = 0,
279        reconstruction_coef: float = 0.5,
280        kl_coef: float = 0.05,
281    ):
282        super().__init__(latent_dim)
283        self.is_frozen = False
284        self.save_hyperparameters()
285
286        self.hidden_dim = hidden_dim
287        self.latent_dim = latent_dim
288        self.vocab_size = vocab_size
289        self.seq_length = seq_length
290
291        self._padding_token = padding_token
292
293        self.reconstruction_coef = reconstruction_coef
294        self.kl_coef = kl_coef
295
296        self.projector = nn.Sequential(
297            nn.Linear(self.in_dim, self.hidden_dim),
298            nn.ReLU(),
299            nn.Linear(self.hidden_dim, self.hidden_dim),
300            nn.ReLU(),
301            nn.Linear(self.hidden_dim, self.latent_dim),
302            nn.ReLU(),
303        )
304        self.q_mean = nn.Linear(self.latent_dim, self.latent_dim)
305        self.q_logvar = nn.Linear(self.latent_dim, self.latent_dim)
306
307        self.projector_dec = nn.Sequential(
308            nn.Linear(self.latent_dim, self.hidden_dim),
309            nn.ReLU(),
310            nn.Linear(self.hidden_dim, self.hidden_dim),
311            nn.ReLU(),
312            nn.Linear(self.hidden_dim, self.in_dim),
313        )
314
315        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
316
317        self.num_layers = 2
318        self.decoder = nn.GRU(
319            self.latent_dim,
320            self.hidden_dim,
321            num_layers=self.num_layers,
322            batch_first=True,
323        )
324        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
325
326        self.optim_lr = optim_lr
327        self.optim_weight_decay = optim_weight_decay
328
329        self.scheduler_args = SchedulerArgs(
330            max_lr=optim_lr,
331            total_steps=1,
332        )
333        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
in_dim = 768
is_frozen

Whether the module is frozen. If None, it is frozen by default.

hidden_dim
latent_dim

The latent dimension of the module.

vocab_size
seq_length
reconstruction_coef
kl_coef
projector
q_mean
q_logvar
projector_dec
embeddings
num_layers
decoder
text_head
optim_lr
optim_weight_decay
scheduler_args
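
A minimal usage sketch. Every size below is an illustrative assumption (only in_dim = 768 is fixed by the module), and the BERT vectors and token ids would normally come from the simple-shapes text pipeline rather than from random tensors:

import torch
from shimmer_ssd.modules.domains.text import GRUTextDomainModule

module = GRUTextDomainModule(
    latent_dim=64, hidden_dim=256, vocab_size=32, seq_length=16
)

batch = {
    "bert": torch.randn(8, 768),              # pre-computed BERT embeddings (in_dim = 768)
    "tokens": torch.randint(0, 32, (8, 16)),  # token ids, shape (batch, seq_length)
}

z = module.encode(batch)   # (8, 64): posterior mean used as the unimodal latent
out = module.decode(z)     # {"token_dist": (8, 16, 32), "tokens": (8, 16)}
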
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
335    def compute_loss(
336        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
337    ) -> LossOutput:
338        mse_loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
339
340        return LossOutput(mse_loss, {"mse_loss": mse_loss})

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not contribute to the total loss.

def compute_domain_loss(self, domain: Any) -> shimmer.modules.domain.LossOutput:
342    def compute_domain_loss(self, domain: Any) -> LossOutput:
343        z = self.encode(domain)
344        loss, acc = self.text_token_loss(z, domain)
345        return LossOutput(loss, {"acc": acc})

Compute the unimodal domain loss.

Arguments:
  • domain (Any): domain input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not contribute to the total loss.

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
347    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
348        z = self.projector(x["bert"])
349        return self.q_mean(z)

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def encode_dist(self, x: Mapping[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
351    def encode_dist(
352        self, x: Mapping[str, torch.Tensor]
353    ) -> tuple[torch.Tensor, torch.Tensor]:
354        z = self.projector(x["bert"])
355        return self.q_mean(z), self.q_logvar(z)
def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
357    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
358        context = z.unsqueeze(1)
359        hidden = None
360        outputs = []
361
362        for _ in range(self.seq_length):
363            out, hidden = self.decoder(context, hidden)
364            token_dist = self.text_head(out)
365            tokens = torch.argmax(token_dist, dim=-1)
366            outputs.append(token_dist)
367
368            context = self.embeddings(tokens)
369
370        token_dists = torch.cat(outputs, dim=1)
371        tokens = torch.argmax(token_dists, dim=-1)
372        return {"token_dist": token_dists, "tokens": tokens}

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.
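
decode is greedy and free-running: the latent is the first GRU input, and at each of the seq_length steps the embedding of the argmax token is fed back as the next input. As a small follow-up to the earlier sketch, the resulting ids could be turned back into words with a hypothetical id_to_word mapping taken from the dataset's tokenizer (the mapping itself is not part of this module):

decoded = module.decode(module.encode(batch))
sentences = [
    " ".join(id_to_word[int(t)] for t in row if int(t) != 0)  # 0 is the default padding token
    for row in decoded["tokens"]
]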

def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
374    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
375        out, _ = self.decoder(z)
376        tokens_dist = self.text_head(out)
377        tokens = torch.argmax(tokens_dist, -1)
378        return {"token_dist": tokens_dist, "tokens": tokens}
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
380    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
381        z, _ = self.encode_dist(x)
382        return self.decode(z)

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def text_token_loss( self, z: torch.Tensor, target: Mapping[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
384    def text_token_loss(
385        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
386    ) -> tuple[torch.Tensor, torch.Tensor]:
387        context = z.unsqueeze(1)
388        real_tokens = self.embeddings(target["tokens"][:, :-1])
389        seq = torch.cat([context, real_tokens], dim=1)
390
391        out = self.decode_one(seq)
392
393        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
394        padding_mask = target["tokens"] != self._padding_token
395
396        padded_out = out["tokens"][padding_mask]
397        padded_target = target["tokens"][padding_mask]
398
399        acc = (padded_out == padded_target).sum() / padded_out.size(0)
400        return loss, acc
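
text_token_loss trains the decoder with teacher forcing: the latent supplies the first input step, the embeddings of the ground-truth tokens (shifted right by one) supply the remaining steps, and the outputs are scored against the full target sequence; accuracy is then measured only on non-padding positions. A shape-level sketch, reusing the module and the illustrative sizes from the earlier example (batch 8, seq_length 16, latent_dim 64, vocab_size 32):

z = torch.randn(8, 64)
tokens = torch.randint(0, 32, (8, 16))
context = z.unsqueeze(1)                         # (8, 1, 64): the latent is the first GRU input
real_tokens = module.embeddings(tokens[:, :-1])  # (8, 15, 64): ground-truth inputs, shifted right
seq = torch.cat([context, real_tokens], dim=1)   # (8, 16, 64)
out = module.decode_one(seq)                     # out["token_dist"]: (8, 16, 32)
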
def generic_step(self, x: Mapping[str, torch.Tensor], mode: str = 'train') -> torch.Tensor:
402    def generic_step(
403        self, x: Mapping[str, torch.Tensor], mode: str = "train"
404    ) -> torch.Tensor:
405        mean, logvar = self.encode_dist(x)
406        z = reparameterize(mean, logvar)
407        noise = torch.randn_like(z) * 0.5
408        z_text = z + noise
409        loss, acc = self.text_token_loss(z_text, x)
410
411        x_hat = self.projector_dec(z)
412        reconstruction_loss = gaussian_nll(x_hat, torch.tensor(0), x["bert"]).sum()
413        kl_loss = kl_divergence_loss(mean, logvar)
414
415        total_loss = (
416            loss
417            + self.reconstruction_coef * reconstruction_loss
418            + self.kl_coef * kl_loss
419        )
420
421        self.log(f"{mode}/kl_loss", kl_loss)
422        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
423        self.log(f"{mode}/total_loss", total_loss)
424        self.log(f"{mode}/loss", loss)
425        self.log(f"{mode}/acc", acc)
426        return total_loss  # loss
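
Spelling out the objective assembled above (this only restates the code with the default coefficients from the constructor, reconstruction_coef = 0.5 and kl_coef = 0.05): the token loss is computed on a noised sample of the latent (z plus Gaussian noise of standard deviation 0.5), while the reconstruction and KL terms use the noiseless posterior statistics.

total_loss = (
    loss                          # token cross-entropy on the noised latent sample
    + 0.5 * reconstruction_loss   # summed Gaussian NLL of the reconstructed BERT vector
    + 0.05 * kl_loss              # KL term from kl_divergence_loss(mean, logvar)
)
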
def validation_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
428    def validation_step(  # type: ignore
429        self,
430        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
431        _,
432    ) -> torch.Tensor:
433        x = batch[frozenset(["t"])]["t"]
434        return self.generic_step(x, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
436    def training_step(  # type: ignore
437        self,
438        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
439        _,
440    ) -> torch.Tensor:
441        x = batch[frozenset(["t"])]["t"]
442        return self.generic_step(x, "train")

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
444    def configure_optimizers(  # type: ignore
445        self,
446    ) -> dict[str, Any]:
447        optimizer = AdamW(
448            self.parameters(),
449            lr=self.optim_lr,
450            weight_decay=self.optim_weight_decay,
451        )
452
453        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Example::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
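
Note that the configure_optimizers above returns only the AdamW optimizer; the scheduler_args stored in __init__ are not wired into it. A hypothetical variant, not the module's current behaviour, showing how they could drive a torch.optim.lr_scheduler.OneCycleLR schedule through the lr_scheduler_config format described above (this assumes scheduler_args carries OneCycleLR keyword arguments such as max_lr and total_steps):

def configure_optimizers(self):
    optimizer = AdamW(
        self.parameters(),
        lr=self.optim_lr,
        weight_decay=self.optim_weight_decay,
    )
    # OneCycleLR steps once per optimizer update, hence interval="step"
    scheduler = OneCycleLR(optimizer, **self.scheduler_args)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
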
class Text2Attr(shimmer.modules.domain.DomainModule):
456class Text2Attr(DomainModule):
457    def __init__(
458        self,
459        latent_dim: int,
460        hidden_dim: int,
461        text_model: GRUTextDomainModule,
462        optim_lr: float = 1e-3,
463        optim_weight_decay: float = 0,
464        scheduler_args: SchedulerArgs | None = None,
465    ) -> None:
466        super().__init__(latent_dim)
467        self.save_hyperparameters(ignore=["text_model"])
468
469        self.text_model = text_model
470
471        self.pred_attr = nn.Sequential(
472            nn.Linear(latent_dim, hidden_dim),
473            nn.ReLU(),
474            nn.Linear(hidden_dim, hidden_dim),
475            nn.ReLU(),
476            nn.Linear(hidden_dim, hidden_dim),
477            nn.ReLU(),
478            nn.Linear(hidden_dim, 8),
479            nn.Tanh(),
480        )
481        self.pred_cat = nn.Sequential(
482            nn.Linear(latent_dim, hidden_dim),
483            nn.ReLU(),
484            nn.Linear(hidden_dim, hidden_dim),
485            nn.ReLU(),
486            nn.Linear(hidden_dim, hidden_dim),
487            nn.ReLU(),
488            nn.Linear(hidden_dim, 3),
489            nn.Softmax(dim=1),
490        )
491
492        self.optim_lr = optim_lr
493        self.optim_weight_decay = optim_weight_decay
494
495        self.scheduler_args = SchedulerArgs(
496            max_lr=optim_lr,
497            total_steps=1,
498        )
499        self.scheduler_args.update(scheduler_args or {})
500
501    def compute_loss(
502        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
503    ) -> LossOutput:
504        pred_attr = self.pred_attr(pred)
505        pred_cat = self.pred_cat(pred)
506        target_attr = self.pred_attr(target)
507        target_cat = self.pred_cat(target)
508        loss_attr = F.mse_loss(pred_attr, target_attr)
509        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
510        loss = loss_attr + loss_cat
511        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
512
513    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
514        return self.text_model.encode(x)
515
516    def decode(self, z: torch.Tensor) -> dict[str, Any]:
517        out: dict[str, Any] = {}
518        out.update(self.text_model.decode(z))
519        pred_attr = self.pred_attr(z)
520        pred_cat = self.pred_cat(z)
521        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
522        return out
523
524    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
525        return self.decode(self.encode(x))
526
527    def generic_step(
528        self,
529        x: Mapping[str, Any],
530        mode: str = "train",
531    ) -> torch.Tensor:
532        text: Mapping[str, torch.Tensor] = x["t"]
533        attr: Sequence[torch.Tensor] = x["attr"]
534        text_l = self.text_model.encode(text)
535        pred_attr = self.pred_attr(text_l)
536        pred_cat = self.pred_cat(text_l)
537        cats = torch.argmax(attr[0], dim=1)
538        loss_cat = F.cross_entropy(pred_cat, cats)
539        loss_attr = F.mse_loss(pred_attr, attr[1])
540        total_loss = loss_cat + loss_attr
541        pred_cats = pred_cat.argmax(dim=1)
542        acc = (cats == pred_cats).sum() / cats.size(0)
543
544        self.log(f"{mode}/loss_cat", loss_cat)
545        self.log(f"{mode}/loss_attr", loss_attr)
546        self.log(f"{mode}/acc_cat", acc)
547        self.log(f"{mode}/loss", total_loss)
548
549        return total_loss
550
551    def validation_step(  # type: ignore
552        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
553    ) -> torch.Tensor:
554        return self.generic_step(batch, "val")
555
556    def training_step(  # type: ignore
557        self,
558        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
559        _,
560    ) -> torch.Tensor:
561        data = batch[frozenset(["t", "attr"])]
562        return self.generic_step(data, "train")
563
564    def configure_optimizers(  # type: ignore
565        self,
566    ) -> dict[str, Any]:
567        optimizer = AdamW(
568            self.parameters(),
569            lr=self.optim_lr,
570            weight_decay=self.optim_weight_decay,
571        )
572
573        return {"optimizer": optimizer}

Base class for a DomainModule that defines domain-specific modules of the GW.

Text2Attr( latent_dim: int, hidden_dim: int, text_model: GRUTextDomainModule, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None)
457    def __init__(
458        self,
459        latent_dim: int,
460        hidden_dim: int,
461        text_model: GRUTextDomainModule,
462        optim_lr: float = 1e-3,
463        optim_weight_decay: float = 0,
464        scheduler_args: SchedulerArgs | None = None,
465    ) -> None:
466        super().__init__(latent_dim)
467        self.save_hyperparameters(ignore=["text_model"])
468
469        self.text_model = text_model
470
471        self.pred_attr = nn.Sequential(
472            nn.Linear(latent_dim, hidden_dim),
473            nn.ReLU(),
474            nn.Linear(hidden_dim, hidden_dim),
475            nn.ReLU(),
476            nn.Linear(hidden_dim, hidden_dim),
477            nn.ReLU(),
478            nn.Linear(hidden_dim, 8),
479            nn.Tanh(),
480        )
481        self.pred_cat = nn.Sequential(
482            nn.Linear(latent_dim, hidden_dim),
483            nn.ReLU(),
484            nn.Linear(hidden_dim, hidden_dim),
485            nn.ReLU(),
486            nn.Linear(hidden_dim, hidden_dim),
487            nn.ReLU(),
488            nn.Linear(hidden_dim, 3),
489            nn.Softmax(dim=1),
490        )
491
492        self.optim_lr = optim_lr
493        self.optim_weight_decay = optim_weight_decay
494
495        self.scheduler_args = SchedulerArgs(
496            max_lr=optim_lr,
497            total_steps=1,
498        )
499        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
text_model
pred_attr
pred_cat
optim_lr
optim_weight_decay
scheduler_args
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
501    def compute_loss(
502        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
503    ) -> LossOutput:
504        pred_attr = self.pred_attr(pred)
505        pred_cat = self.pred_cat(pred)
506        target_attr = self.pred_attr(target)
507        target_cat = self.pred_cat(target)
508        loss_attr = F.mse_loss(pred_attr, target_attr)
509        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
510        loss = loss_attr + loss_cat
511        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not contribute to the total loss.

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
513    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
514        return self.text_model.encode(x)

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> dict[str, typing.Any]:
516    def decode(self, z: torch.Tensor) -> dict[str, Any]:
517        out: dict[str, Any] = {}
518        out.update(self.text_model.decode(z))
519        pred_attr = self.pred_attr(z)
520        pred_cat = self.pred_cat(z)
521        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
522        return out

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.
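
The decoded dictionary combines the wrapped text model's output with the attribute predictions, with "attr" holding [pred_cat, pred_attr, pred_attr]. A shape sketch under illustrative, assumed sizes (latent_dim 64, batch 8, and a GRUTextDomainModule built as in the earlier example):

text2attr = Text2Attr(latent_dim=64, hidden_dim=256, text_model=module)
z = torch.randn(8, 64)
out = text2attr.decode(z)
# out["tokens"]     : (8, seq_length)             from the wrapped text model
# out["token_dist"] : (8, seq_length, vocab_size)
# out["attr"]       : [(8, 3) category softmax, (8, 8) attributes, (8, 8) attributes]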

def forward(self, x: Mapping[str, typing.Any]) -> dict[str, list[torch.Tensor]]:
524    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
525        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def generic_step(self, x: Mapping[str, typing.Any], mode: str = 'train') -> torch.Tensor:
527    def generic_step(
528        self,
529        x: Mapping[str, Any],
530        mode: str = "train",
531    ) -> torch.Tensor:
532        text: Mapping[str, torch.Tensor] = x["t"]
533        attr: Sequence[torch.Tensor] = x["attr"]
534        text_l = self.text_model.encode(text)
535        pred_attr = self.pred_attr(text_l)
536        pred_cat = self.pred_cat(text_l)
537        cats = torch.argmax(attr[0], dim=1)
538        loss_cat = F.cross_entropy(pred_cat, cats)
539        loss_attr = F.mse_loss(pred_attr, attr[1])
540        total_loss = loss_cat + loss_attr
541        pred_cats = pred_cat.argmax(dim=1)
542        acc = (cats == pred_cats).sum() / cats.size(0)
543
544        self.log(f"{mode}/loss_cat", loss_cat)
545        self.log(f"{mode}/loss_attr", loss_attr)
546        self.log(f"{mode}/acc_cat", acc)
547        self.log(f"{mode}/loss", total_loss)
548
549        return total_loss
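
generic_step expects a mapping with a "t" entry holding the text-domain inputs (at least the "bert" vectors consumed by the text model's encoder) and an "attr" entry whose first element is the one-hot category and whose second element is the attribute vector. An illustrative dummy layout, with assumed sizes, arranged the way training_step extracts it:

data = {
    "t": {"bert": torch.randn(8, 768)},
    "attr": [
        torch.nn.functional.one_hot(torch.randint(0, 3, (8,)), num_classes=3).float(),
        torch.empty(8, 8).uniform_(-1, 1),  # attribute vector target
    ],
}
batch = {frozenset(["t", "attr"]): data}  # what training_step receives
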
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor:
551    def validation_step(  # type: ignore
552        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
553    ) -> torch.Tensor:
554        return self.generic_step(batch, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
556    def training_step(  # type: ignore
557        self,
558        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
559        _,
560    ) -> torch.Tensor:
561        data = batch[frozenset(["t", "attr"])]
562        return self.generic_step(data, "train")

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
564    def configure_optimizers(  # type: ignore
565        self,
566    ) -> dict[str, Any]:
567        optimizer = AdamW(
568            self.parameters(),
569            lr=self.optim_lr,
570            weight_decay=self.optim_weight_decay,
571        )
572
573        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Example::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.