shimmer_ssd.modules.domains.text

  1from collections.abc import Mapping, Sequence
  2from typing import Any
  3
  4import torch
  5import torch.nn.functional as F
  6from shimmer import LossOutput
  7from shimmer.modules.domain import DomainModule
  8from shimmer.modules.global_workspace import SchedulerArgs
  9from shimmer.modules.vae import (
 10    VAE,
 11    VAEDecoder,
 12    VAEEncoder,
 13    gaussian_nll,
 14    kl_divergence_loss,
 15)
 16from simple_shapes_dataset.text import composer
 17from simple_shapes_dataset.text.utils import inspect_all_choices
 18from torch import nn
 19from torch.optim.adamw import AdamW
 20from torch.optim.lr_scheduler import OneCycleLR
 21
 22
 23class Encoder(VAEEncoder):
 24    def __init__(
 25        self,
 26        in_dim: int,
 27        hidden_dim: int,
 28        out_dim: int,
 29    ):
 30        super().__init__()
 31
 32        self.in_dim = in_dim
 33        self.hidden_dim = hidden_dim
 34        self.out_dim = out_dim
 35
 36        self.encoder = nn.Sequential(
 37            nn.Linear(self.in_dim, hidden_dim),
 38            nn.ReLU(),
 39            nn.Linear(hidden_dim, hidden_dim),
 40            nn.ReLU(),
 41            nn.Linear(hidden_dim, out_dim),
 42            nn.ReLU(),
 43        )
 44
 45        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
 46        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
 47
 48    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
 49        out = torch.cat(list(x), dim=-1)
 50        out = self.encoder(out)
 51        return self.q_mean(out), self.q_logvar(out)
 52
 53
 54class Decoder(VAEDecoder):
 55    def __init__(
 56        self,
 57        in_dim: int,
 58        hidden_dim: int,
 59        out_dim: int,
 60    ):
 61        super().__init__()
 62
 63        self.in_dim = in_dim
 64        self.hidden_dim = hidden_dim
 65        self.out_dim = out_dim
 66
 67        self.decoder = nn.Sequential(
 68            nn.Linear(self.in_dim, self.hidden_dim),
 69            nn.ReLU(),
 70            nn.Linear(self.hidden_dim, self.hidden_dim),
 71            nn.ReLU(),
 72            nn.Linear(self.hidden_dim, self.out_dim),
 73            nn.Tanh(),
 74        )
 75
 76    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
 77        return [self.decoder(z)]
 78
 79
 80class TextDomainModule(DomainModule):
 81    in_dim = 768
 82
 83    def __init__(
 84        self,
 85        latent_dim: int,
 86        hidden_dim: int,
 87        beta: float = 1,
 88        optim_lr: float = 1e-3,
 89        optim_weight_decay: float = 0,
 90        scheduler_args: SchedulerArgs | None = None,
 91    ):
 92        super().__init__(latent_dim)
 93        self.save_hyperparameters()
 94
 95        self.hidden_dim = hidden_dim
 96
 97        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
 98        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
 99        self.vae = VAE(vae_encoder, vae_decoder, beta)
100
101        self.attribute_cls = nn.Sequential(
102            nn.Linear(self.latent_dim, self.hidden_dim),
103            nn.ReLU(),
104            nn.Linear(self.hidden_dim, self.hidden_dim),
105            nn.ReLU(),
106        )
107        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
108        self.attribute_cls_attr = nn.Sequential(
109            nn.Linear(self.hidden_dim, 8), nn.Tanh()
110        )
111
112        self.composer_grammar_options = inspect_all_choices(composer)
113
114        self.grammar_cls = nn.Sequential(
115            nn.Linear(self.latent_dim, self.hidden_dim),
116            nn.ReLU(),
117            nn.Linear(self.hidden_dim, self.hidden_dim),
118            nn.ReLU(),
119        )
120        self.grammar_heads = nn.ModuleDict(
121            {
122                name: nn.Linear(self.hidden_dim, n_outputs)
123                for name, n_outputs in self.composer_grammar_options.items()
124            }
125        )
126
127        self.optim_lr = optim_lr
128        self.optim_weight_decay = optim_weight_decay
129
130        self.scheduler_args = SchedulerArgs(
131            max_lr=optim_lr,
132            total_steps=1,
133        )
134        self.scheduler_args.update(scheduler_args or {})
135
136    def compute_loss(
137        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
138    ) -> LossOutput:
139        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
140
141    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
142        return self.vae.encode((x["bert"],))
143
144    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
145        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
146        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
147        text["cls"] = attr_pred_cat
148        text["attr"] = attr_pred_attr
149        text["unpaired"] = torch.zeros_like(z[:, -1])
150        text.update(self.predict_grammar(z))
151        return text
152
153    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
154        attr_pred = self.attribute_cls(mean)
155        attr_pred_cat = self.attribute_cls_cat(attr_pred)
156        attr_pred_attr = self.attribute_cls_attr(attr_pred)
157        return attr_pred_cat, attr_pred_attr
158
159    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
160        grammar_pred = self.grammar_cls(mean)
161        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
162
163    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
164        grammar_pred = self.predict_grammar(mean)
165        return {
166            f"{name}_ce": F.cross_entropy(
167                pred, targets[name][:, 0].long(), reduction="sum"
168            )
169            for name, pred in grammar_pred.items()
170        }
171
172    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
173        return self.decode(self.encode(x))
174
175    def generic_step(
176        self,
177        x: Mapping[str, torch.Tensor],
178        mode: str = "train",
179    ) -> torch.Tensor:
180        (mean, logvar), reconstruction = self.vae((x["bert"],))
181
182        reconstruction_loss = gaussian_nll(
183            reconstruction[0], torch.tensor(0), x["bert"]
184        ).sum()
185
186        kl_loss = kl_divergence_loss(mean, logvar)
187
188        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
189
190        loss_attr_cat = F.cross_entropy(
191            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
192        )
193        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
194        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
195        grammar_losses = self.grammar_losses(mean, grammar_targets)
196
197        total_loss = (
198            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
199        )
200
201        for grammar_loss_name, grammar_loss in grammar_losses.items():
202            total_loss += grammar_loss
203            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
204
205        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
206        self.log(f"{mode}/kl_loss", kl_loss)
207        self.log(f"{mode}/attr_category", loss_attr_cat)
208        self.log(f"{mode}/attr_attr", loss_attr)
209        self.log(f"{mode}/loss", total_loss)
210        return total_loss
211
212    def validation_step(  # type: ignore
213        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
214    ) -> torch.Tensor:
215        x = batch["t"]
216        return self.generic_step(x, "val")
217
218    def training_step(  # type: ignore
219        self,
220        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
221        _,
222    ) -> torch.Tensor:
223        x = batch[frozenset(["t"])]["t"]
224        return self.generic_step(x, "train")
225
226    def configure_optimizers(  # type: ignore
227        self,
228    ) -> dict[str, Any]:
229        optimizer = torch.optim.AdamW(
230            self.parameters(),
231            lr=self.optim_lr,
232            weight_decay=self.optim_weight_decay,
233        )
234        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
235
236        return {
237            "optimizer": optimizer,
238            "lr_scheduler": {
239                "scheduler": lr_scheduler,
240                "interval": "step",
241            },
242        }
243
244
245class GRUEncoder(nn.Module):
246    def __init__(
247        self,
248        in_dim: int,
249        hidden_dim: int,
250        out_dim: int,
251    ):
252        super().__init__()
253
254        self.in_dim = in_dim
255        self.hidden_dim = hidden_dim
256        self.out_dim = out_dim
257
258    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
259        out = torch.cat(list(x), dim=-1)
260        out = self.encoder(out)
261        return out
262
263
264class GRUTextDomainModule(DomainModule):
265    in_dim = 768
266
267    def __init__(
268        self,
269        latent_dim: int,
270        hidden_dim: int,
271        vocab_size: int,
272        seq_length: int,
273        optim_lr: float = 1e-3,
274        optim_weight_decay: float = 0,
275        scheduler_args: SchedulerArgs | None = None,
276        padding_token: int = 0,
277    ):
278        super().__init__(latent_dim)
279        self.is_frozen = False
280        self.save_hyperparameters()
281
282        self.hidden_dim = hidden_dim
283        self.latent_dim = latent_dim
284        self.vocab_size = vocab_size
285        self.seq_length = seq_length
286
287        self._padding_token = padding_token
288
289        self.projector = nn.Sequential(
290            nn.Linear(self.in_dim, self.hidden_dim),
291            nn.ReLU(),
292            nn.Linear(self.hidden_dim, self.latent_dim),
293            nn.Tanh(),
294        )
295
296        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
297
298        self.num_layers = 2
299        self.decoder = nn.GRU(
300            self.latent_dim,
301            self.hidden_dim,
302            num_layers=self.num_layers,
303            batch_first=True,
304        )
305        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
306
307        self.optim_lr = optim_lr
308        self.optim_weight_decay = optim_weight_decay
309
310        self.scheduler_args = SchedulerArgs(
311            max_lr=optim_lr,
312            total_steps=1,
313        )
314        self.scheduler_args.update(scheduler_args or {})
315
316    def compute_loss(
317        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
318    ) -> LossOutput:
319        text_token_loss, acc = self.text_token_loss(pred, raw_target)
320
321        return LossOutput(
322            2 * text_token_loss,
323            {"loss_tokens": text_token_loss, "pred_t_acc": acc},
324        )
325
326    def compute_domain_loss(self, domain: Any) -> LossOutput:
327        z = self.encode(domain)
328        loss, acc = self.text_token_loss(z, domain)
329        return LossOutput(loss, {"acc": acc})
330
331    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
332        return self.projector(x["bert"])
333
334    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
335        context = z.unsqueeze(1)
336        pad_tokens = self.embeddings(
337            torch.zeros(
338                z.size(0), self.seq_length - 1, dtype=torch.long, device=z.device
339            )
340        )
341        seq = torch.cat([context, pad_tokens], dim=1)
342        for k in range(0, self.seq_length - 1):
343            out = self.decode_one(seq)
344            seq[:, k + 1] = self.embeddings(out["tokens"][:, k])
345        return self.decode_one(seq)
346
347    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
348        out, _ = self.decoder(z)
349        tokens_dist = self.text_head(out)
350        tokens = torch.argmax(tokens_dist, -1)
351        return {"token_dist": tokens_dist, "tokens": tokens}
352
353    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
354        return self.decode(self.encode(x))
355
356    def text_token_loss(
357        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
358    ) -> tuple[torch.Tensor, torch.Tensor]:
359        context = z.unsqueeze(1)
360        real_tokens = self.embeddings(target["tokens"][:, :-1])
361        seq = torch.cat([context, real_tokens], dim=1)
362
363        out = self.decode_one(seq)
364        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
365        padding_mask = target["tokens"] != self._padding_token
366
367        padded_out = out["tokens"][padding_mask]
368        padded_target = target["tokens"][padding_mask]
369
370        acc = (padded_out == padded_target).sum() / padded_out.size(0)
371        return loss, acc
372
373    def generic_step(
374        self, x: Mapping[str, torch.Tensor], mode: str = "train"
375    ) -> torch.Tensor:
376        z = self.encode(x)
377        loss, acc = self.text_token_loss(z, x)
378        self.log(f"{mode}/loss", loss)
379        self.log(f"{mode}/acc", acc)
380        return loss
381
382    def validation_step(  # type: ignore
383        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
384    ) -> torch.Tensor:
385        x = batch["t"]
386        return self.generic_step(x, "val")
387
388    def training_step(  # type: ignore
389        self,
390        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
391        _,
392    ) -> torch.Tensor:
393        x = batch[frozenset(["t"])]["t"]
394        return self.generic_step(x, "train")
395
396    def configure_optimizers(  # type: ignore
397        self,
398    ) -> dict[str, Any]:
399        optimizer = AdamW(
400            self.parameters(),
401            lr=self.optim_lr,
402            weight_decay=self.optim_weight_decay,
403        )
404
405        return {"optimizer": optimizer}
406
407
408class Text2Attr(DomainModule):
409    def __init__(
410        self,
411        latent_dim: int,
412        hidden_dim: int,
413        text_model: GRUTextDomainModule,
414        optim_lr: float = 1e-3,
415        optim_weight_decay: float = 0,
416        scheduler_args: SchedulerArgs | None = None,
417    ) -> None:
418        super().__init__(latent_dim)
419        self.save_hyperparameters(ignore=["text_model"])
420
421        self.text_model = text_model
422
423        self.pred_attr = nn.Sequential(
424            nn.Linear(latent_dim, hidden_dim),
425            nn.ReLU(),
426            nn.Linear(hidden_dim, hidden_dim),
427            nn.ReLU(),
428            nn.Linear(hidden_dim, hidden_dim),
429            nn.ReLU(),
430            nn.Linear(hidden_dim, 8),
431            nn.Tanh(),
432        )
433        self.pred_cat = nn.Sequential(
434            nn.Linear(latent_dim, hidden_dim),
435            nn.ReLU(),
436            nn.Linear(hidden_dim, hidden_dim),
437            nn.ReLU(),
438            nn.Linear(hidden_dim, hidden_dim),
439            nn.ReLU(),
440            nn.Linear(hidden_dim, 3),
441            nn.Softmax(dim=1),
442        )
443
444        self.optim_lr = optim_lr
445        self.optim_weight_decay = optim_weight_decay
446
447        self.scheduler_args = SchedulerArgs(
448            max_lr=optim_lr,
449            total_steps=1,
450        )
451        self.scheduler_args.update(scheduler_args or {})
452
453    def compute_loss(
454        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
455    ) -> LossOutput:
456        pred_attr = self.pred_attr(pred)
457        pred_cat = self.pred_cat(pred)
458        target_attr = self.pred_attr(target)
459        target_cat = self.pred_cat(target)
460        loss_attr = F.mse_loss(pred_attr, target_attr)
461        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
462        loss = loss_attr + loss_cat
463        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
464
465    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
466        return self.text_model.encode(x)
467
468    def decode(self, z: torch.Tensor) -> dict[str, Any]:
469        out: dict[str, Any] = {}
470        out.update(self.text_model.decode(z))
471        pred_attr = self.pred_attr(z)
472        pred_cat = self.pred_cat(z)
473        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
474        return out
475
476    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
477        return self.decode(self.encode(x))
478
479    def generic_step(
480        self,
481        x: Mapping[str, Any],
482        mode: str = "train",
483    ) -> torch.Tensor:
484        text: Mapping[str, torch.Tensor] = x["t"]
485        attr: Sequence[torch.Tensor] = x["attr"]
486        text_l = self.text_model.encode(text)
487        pred_attr = self.pred_attr(text_l)
488        pred_cat = self.pred_cat(text_l)
489        cats = torch.argmax(attr[0], dim=1)
490        loss_cat = F.cross_entropy(pred_cat, cats)
491        loss_attr = F.mse_loss(pred_attr, attr[1])
492        total_loss = loss_cat + loss_attr
493        pred_cats = pred_cat.argmax(dim=1)
494        acc = (cats == pred_cats).sum() / cats.size(0)
495
496        self.log(f"{mode}/loss_cat", loss_cat)
497        self.log(f"{mode}/loss_attr", loss_attr)
498        self.log(f"{mode}/acc_cat", acc)
499        self.log(f"{mode}/loss", total_loss)
500
501        return total_loss
502
503    def validation_step(  # type: ignore
504        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
505    ) -> torch.Tensor:
506        return self.generic_step(batch, "val")
507
508    def training_step(  # type: ignore
509        self,
510        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
511        _,
512    ) -> torch.Tensor:
513        data = batch[frozenset(["t", "attr"])]
514        return self.generic_step(data, "train")
515
516    def configure_optimizers(  # type: ignore
517        self,
518    ) -> dict[str, Any]:
519        optimizer = AdamW(
520            self.parameters(),
521            lr=self.optim_lr,
522            weight_decay=self.optim_weight_decay,
523        )
524
525        return {"optimizer": optimizer}
class Encoder(shimmer.modules.vae.VAEEncoder):
24class Encoder(VAEEncoder):
25    def __init__(
26        self,
27        in_dim: int,
28        hidden_dim: int,
29        out_dim: int,
30    ):
31        super().__init__()
32
33        self.in_dim = in_dim
34        self.hidden_dim = hidden_dim
35        self.out_dim = out_dim
36
37        self.encoder = nn.Sequential(
38            nn.Linear(self.in_dim, hidden_dim),
39            nn.ReLU(),
40            nn.Linear(hidden_dim, hidden_dim),
41            nn.ReLU(),
42            nn.Linear(hidden_dim, out_dim),
43            nn.ReLU(),
44        )
45
46        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
47        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
48
49    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
50        out = torch.cat(list(x), dim=-1)
51        out = self.encoder(out)
52        return self.q_mean(out), self.q_logvar(out)

Base class for a VAE encoder.

Encoder(in_dim: int, hidden_dim: int, out_dim: int)
25    def __init__(
26        self,
27        in_dim: int,
28        hidden_dim: int,
29        out_dim: int,
30    ):
31        super().__init__()
32
33        self.in_dim = in_dim
34        self.hidden_dim = hidden_dim
35        self.out_dim = out_dim
36
37        self.encoder = nn.Sequential(
38            nn.Linear(self.in_dim, hidden_dim),
39            nn.ReLU(),
40            nn.Linear(hidden_dim, hidden_dim),
41            nn.ReLU(),
42            nn.Linear(hidden_dim, out_dim),
43            nn.ReLU(),
44        )
45
46        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
47        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
encoder
q_mean
q_logvar
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
49    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
50        out = torch.cat(list(x), dim=-1)
51        out = self.encoder(out)
52        return self.q_mean(out), self.q_logvar(out)

Encode representation with VAE.

Arguments:
  • x (Any): Some input value
Returns:

tuple[torch.Tensor, torch.Tensor]: the mean and log variance

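A minimal usage sketch for this encoder, assuming a 768-dimensional BERT embedding as input; the hidden and latent sizes below (256 and 64) are illustrative placeholders, not values prescribed by the module.

import torch

encoder = Encoder(in_dim=768, hidden_dim=256, out_dim=64)
bert = torch.randn(32, 768)          # batch of 32 sentence embeddings
mean, logvar = encoder((bert,))      # forward concatenates the given tensors on the last dim
print(mean.shape, logvar.shape)      # torch.Size([32, 64]) torch.Size([32, 64])
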
class Decoder(shimmer.modules.vae.VAEDecoder):
55class Decoder(VAEDecoder):
56    def __init__(
57        self,
58        in_dim: int,
59        hidden_dim: int,
60        out_dim: int,
61    ):
62        super().__init__()
63
64        self.in_dim = in_dim
65        self.hidden_dim = hidden_dim
66        self.out_dim = out_dim
67
68        self.decoder = nn.Sequential(
69            nn.Linear(self.in_dim, self.hidden_dim),
70            nn.ReLU(),
71            nn.Linear(self.hidden_dim, self.hidden_dim),
72            nn.ReLU(),
73            nn.Linear(self.hidden_dim, self.out_dim),
74            nn.Tanh(),
75        )
76
77    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
78        return [self.decoder(z)]

Base class for a VAE decoder.

Decoder(in_dim: int, hidden_dim: int, out_dim: int)
56    def __init__(
57        self,
58        in_dim: int,
59        hidden_dim: int,
60        out_dim: int,
61    ):
62        super().__init__()
63
64        self.in_dim = in_dim
65        self.hidden_dim = hidden_dim
66        self.out_dim = out_dim
67
68        self.decoder = nn.Sequential(
69            nn.Linear(self.in_dim, self.hidden_dim),
70            nn.ReLU(),
71            nn.Linear(self.hidden_dim, self.hidden_dim),
72            nn.ReLU(),
73            nn.Linear(self.hidden_dim, self.out_dim),
74            nn.Tanh(),
75        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
decoder
def forward(self, z: torch.Tensor) -> list[torch.Tensor]:
77    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:  # type: ignore
78        return [self.decoder(z)]

Decode representation with VAE.

Arguments:
  • z (torch.Tensor): VAE latent representation
Returns:

Any: the reconstructed input

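A hedged sketch of wiring this decoder together with the Encoder above into shimmer's VAE, mirroring what TextDomainModule does internally; the dimensions and beta value are illustrative.

import torch
from shimmer.modules.vae import VAE

vae = VAE(
    Encoder(in_dim=768, hidden_dim=256, out_dim=64),
    Decoder(in_dim=64, hidden_dim=256, out_dim=768),
    beta=1.0,
)
bert = torch.randn(8, 768)
z = vae.encode((bert,))              # latent representation, shape [8, 64]
reconstruction = vae.decode(z)[0]    # reconstructed embedding, shape [8, 768]
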
class TextDomainModule(shimmer.modules.domain.DomainModule):
 81class TextDomainModule(DomainModule):
 82    in_dim = 768
 83
 84    def __init__(
 85        self,
 86        latent_dim: int,
 87        hidden_dim: int,
 88        beta: float = 1,
 89        optim_lr: float = 1e-3,
 90        optim_weight_decay: float = 0,
 91        scheduler_args: SchedulerArgs | None = None,
 92    ):
 93        super().__init__(latent_dim)
 94        self.save_hyperparameters()
 95
 96        self.hidden_dim = hidden_dim
 97
 98        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
 99        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
100        self.vae = VAE(vae_encoder, vae_decoder, beta)
101
102        self.attribute_cls = nn.Sequential(
103            nn.Linear(self.latent_dim, self.hidden_dim),
104            nn.ReLU(),
105            nn.Linear(self.hidden_dim, self.hidden_dim),
106            nn.ReLU(),
107        )
108        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
109        self.attribute_cls_attr = nn.Sequential(
110            nn.Linear(self.hidden_dim, 8), nn.Tanh()
111        )
112
113        self.composer_grammar_options = inspect_all_choices(composer)
114
115        self.grammar_cls = nn.Sequential(
116            nn.Linear(self.latent_dim, self.hidden_dim),
117            nn.ReLU(),
118            nn.Linear(self.hidden_dim, self.hidden_dim),
119            nn.ReLU(),
120        )
121        self.grammar_heads = nn.ModuleDict(
122            {
123                name: nn.Linear(self.hidden_dim, n_outputs)
124                for name, n_outputs in self.composer_grammar_options.items()
125            }
126        )
127
128        self.optim_lr = optim_lr
129        self.optim_weight_decay = optim_weight_decay
130
131        self.scheduler_args = SchedulerArgs(
132            max_lr=optim_lr,
133            total_steps=1,
134        )
135        self.scheduler_args.update(scheduler_args or {})
136
137    def compute_loss(
138        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
139    ) -> LossOutput:
140        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
141
142    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
143        return self.vae.encode((x["bert"],))
144
145    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
146        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
147        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
148        text["cls"] = attr_pred_cat
149        text["attr"] = attr_pred_attr
150        text["unpaired"] = torch.zeros_like(z[:, -1])
151        text.update(self.predict_grammar(z))
152        return text
153
154    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
155        attr_pred = self.attribute_cls(mean)
156        attr_pred_cat = self.attribute_cls_cat(attr_pred)
157        attr_pred_attr = self.attribute_cls_attr(attr_pred)
158        return attr_pred_cat, attr_pred_attr
159
160    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
161        grammar_pred = self.grammar_cls(mean)
162        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
163
164    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
165        grammar_pred = self.predict_grammar(mean)
166        return {
167            f"{name}_ce": F.cross_entropy(
168                pred, targets[name][:, 0].long(), reduction="sum"
169            )
170            for name, pred in grammar_pred.items()
171        }
172
173    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
174        return self.decode(self.encode(x))
175
176    def generic_step(
177        self,
178        x: Mapping[str, torch.Tensor],
179        mode: str = "train",
180    ) -> torch.Tensor:
181        (mean, logvar), reconstruction = self.vae((x["bert"],))
182
183        reconstruction_loss = gaussian_nll(
184            reconstruction[0], torch.tensor(0), x["bert"]
185        ).sum()
186
187        kl_loss = kl_divergence_loss(mean, logvar)
188
189        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
190
191        loss_attr_cat = F.cross_entropy(
192            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
193        )
194        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
195        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
196        grammar_losses = self.grammar_losses(mean, grammar_targets)
197
198        total_loss = (
199            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
200        )
201
202        for grammar_loss_name, grammar_loss in grammar_losses.items():
203            total_loss += grammar_loss
204            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
205
206        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
207        self.log(f"{mode}/kl_loss", kl_loss)
208        self.log(f"{mode}/attr_category", loss_attr_cat)
209        self.log(f"{mode}/attr_attr", loss_attr)
210        self.log(f"{mode}/loss", total_loss)
211        return total_loss
212
213    def validation_step(  # type: ignore
214        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
215    ) -> torch.Tensor:
216        x = batch["t"]
217        return self.generic_step(x, "val")
218
219    def training_step(  # type: ignore
220        self,
221        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
222        _,
223    ) -> torch.Tensor:
224        x = batch[frozenset(["t"])]["t"]
225        return self.generic_step(x, "train")
226
227    def configure_optimizers(  # type: ignore
228        self,
229    ) -> dict[str, Any]:
230        optimizer = torch.optim.AdamW(
231            self.parameters(),
232            lr=self.optim_lr,
233            weight_decay=self.optim_weight_decay,
234        )
235        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
236
237        return {
238            "optimizer": optimizer,
239            "lr_scheduler": {
240                "scheduler": lr_scheduler,
241                "interval": "step",
242            },
243        }

Base class for a DomainModule that defines domain specific modules of the GW.

TextDomainModule( latent_dim: int, hidden_dim: int, beta: float = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None)
 84    def __init__(
 85        self,
 86        latent_dim: int,
 87        hidden_dim: int,
 88        beta: float = 1,
 89        optim_lr: float = 1e-3,
 90        optim_weight_decay: float = 0,
 91        scheduler_args: SchedulerArgs | None = None,
 92    ):
 93        super().__init__(latent_dim)
 94        self.save_hyperparameters()
 95
 96        self.hidden_dim = hidden_dim
 97
 98        vae_encoder = Encoder(self.in_dim, self.hidden_dim, self.latent_dim)
 99        vae_decoder = Decoder(self.latent_dim, self.hidden_dim, self.in_dim)
100        self.vae = VAE(vae_encoder, vae_decoder, beta)
101
102        self.attribute_cls = nn.Sequential(
103            nn.Linear(self.latent_dim, self.hidden_dim),
104            nn.ReLU(),
105            nn.Linear(self.hidden_dim, self.hidden_dim),
106            nn.ReLU(),
107        )
108        self.attribute_cls_cat = nn.Linear(self.hidden_dim, 3)
109        self.attribute_cls_attr = nn.Sequential(
110            nn.Linear(self.hidden_dim, 8), nn.Tanh()
111        )
112
113        self.composer_grammar_options = inspect_all_choices(composer)
114
115        self.grammar_cls = nn.Sequential(
116            nn.Linear(self.latent_dim, self.hidden_dim),
117            nn.ReLU(),
118            nn.Linear(self.hidden_dim, self.hidden_dim),
119            nn.ReLU(),
120        )
121        self.grammar_heads = nn.ModuleDict(
122            {
123                name: nn.Linear(self.hidden_dim, n_outputs)
124                for name, n_outputs in self.composer_grammar_options.items()
125            }
126        )
127
128        self.optim_lr = optim_lr
129        self.optim_weight_decay = optim_weight_decay
130
131        self.scheduler_args = SchedulerArgs(
132            max_lr=optim_lr,
133            total_steps=1,
134        )
135        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
in_dim = 768
hidden_dim
vae
attribute_cls
attribute_cls_cat
attribute_cls_attr
composer_grammar_options
grammar_cls
grammar_heads
optim_lr
optim_weight_decay
scheduler_args
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
137    def compute_loss(
138        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
139    ) -> LossOutput:
140        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
142    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
143        return self.vae.encode((x["bert"],))

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
145    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
146        text: dict[str, torch.Tensor] = {"bert": self.vae.decode(z)[0]}
147        attr_pred_cat, attr_pred_attr = self.predict_attr(z)
148        text["cls"] = attr_pred_cat
149        text["attr"] = attr_pred_attr
150        text["unpaired"] = torch.zeros_like(z[:, -1])
151        text.update(self.predict_grammar(z))
152        return text

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.

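A hedged end-to-end sketch of the encode/decode round trip on a dummy batch; the latent and hidden sizes are placeholders, and the grammar heads come from simple_shapes_dataset's composer.

import torch

module = TextDomainModule(latent_dim=64, hidden_dim=256)
batch = {"bert": torch.randn(4, 768)}
z = module.encode(batch)   # unimodal latent, shape [4, 64]
out = module.decode(z)     # dict with "bert", "cls", "attr", "unpaired",
                           # plus one logit tensor per grammar option
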
def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
154    def predict_attr(self, mean: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
155        attr_pred = self.attribute_cls(mean)
156        attr_pred_cat = self.attribute_cls_cat(attr_pred)
157        attr_pred_attr = self.attribute_cls_attr(attr_pred)
158        return attr_pred_cat, attr_pred_attr
def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
160    def predict_grammar(self, mean: torch.Tensor) -> dict[str, torch.Tensor]:
161        grammar_pred = self.grammar_cls(mean)
162        return {name: head(grammar_pred) for name, head in self.grammar_heads.items()}
def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
164    def grammar_losses(self, mean: torch.Tensor, targets) -> dict[str, torch.Tensor]:
165        grammar_pred = self.predict_grammar(mean)
166        return {
167            f"{name}_ce": F.cross_entropy(
168                pred, targets[name][:, 0].long(), reduction="sum"
169            )
170            for name, pred in grammar_pred.items()
171        }
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
173    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
174        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def generic_step(self, x: Mapping[str, torch.Tensor], mode: str = 'train') -> torch.Tensor:
176    def generic_step(
177        self,
178        x: Mapping[str, torch.Tensor],
179        mode: str = "train",
180    ) -> torch.Tensor:
181        (mean, logvar), reconstruction = self.vae((x["bert"],))
182
183        reconstruction_loss = gaussian_nll(
184            reconstruction[0], torch.tensor(0), x["bert"]
185        ).sum()
186
187        kl_loss = kl_divergence_loss(mean, logvar)
188
189        attr_pred_cat, attr_pred_attr = self.predict_attr(mean)
190
191        loss_attr_cat = F.cross_entropy(
192            attr_pred_cat, x["cls"].argmax(dim=1), reduction="sum"
193        )
194        loss_attr = F.mse_loss(attr_pred_attr, x["attr"], reduction="sum")
195        grammar_targets = {name: x[name] for name in self.composer_grammar_options}
196        grammar_losses = self.grammar_losses(mean, grammar_targets)
197
198        total_loss = (
199            reconstruction_loss + self.vae.beta * kl_loss + loss_attr_cat + loss_attr
200        )
201
202        for grammar_loss_name, grammar_loss in grammar_losses.items():
203            total_loss += grammar_loss
204            self.log(f"{mode}/{grammar_loss_name}", grammar_loss)
205
206        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
207        self.log(f"{mode}/kl_loss", kl_loss)
208        self.log(f"{mode}/attr_category", loss_attr_cat)
209        self.log(f"{mode}/attr_attr", loss_attr)
210        self.log(f"{mode}/loss", total_loss)
211        return total_loss
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor:
213    def validation_step(  # type: ignore
214        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
215    ) -> torch.Tensor:
216        x = batch["t"]
217        return self.generic_step(x, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
219    def training_step(  # type: ignore
220        self,
221        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
222        _,
223    ) -> torch.Tensor:
224        x = batch[frozenset(["t"])]["t"]
225        return self.generic_step(x, "train")

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
227    def configure_optimizers(  # type: ignore
228        self,
229    ) -> dict[str, Any]:
230        optimizer = torch.optim.AdamW(
231            self.parameters(),
232            lr=self.optim_lr,
233            weight_decay=self.optim_weight_decay,
234        )
235        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
236
237        return {
238            "optimizer": optimizer,
239            "lr_scheduler": {
240                "scheduler": lr_scheduler,
241                "interval": "step",
242            },
243        }

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

.. testcode::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
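
For this module specifically, the OneCycleLR scheduler is built from self.scheduler_args, which defaults to total_steps=1; a hedged sketch of overriding it at construction time (the step count below is illustrative):

from shimmer.modules.global_workspace import SchedulerArgs

module = TextDomainModule(
    latent_dim=64,
    hidden_dim=256,
    optim_lr=1e-3,
    scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=50_000),
)
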
class GRUEncoder(torch.nn.modules.module.Module):
246class GRUEncoder(nn.Module):
247    def __init__(
248        self,
249        in_dim: int,
250        hidden_dim: int,
251        out_dim: int,
252    ):
253        super().__init__()
254
255        self.in_dim = in_dim
256        self.hidden_dim = hidden_dim
257        self.out_dim = out_dim
258
259    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
260        out = torch.cat(list(x), dim=-1)
261        out = self.encoder(out)
262        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

GRUEncoder(in_dim: int, hidden_dim: int, out_dim: int)
247    def __init__(
248        self,
249        in_dim: int,
250        hidden_dim: int,
251        out_dim: int,
252    ):
253        super().__init__()
254
255        self.in_dim = in_dim
256        self.hidden_dim = hidden_dim
257        self.out_dim = out_dim

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
out_dim
def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
259    def forward(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
260        out = torch.cat(list(x), dim=-1)
261        out = self.encoder(out)
262        return out

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

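The forward method above relies on a self.encoder submodule that the listed constructor does not define; a minimal sketch of one possible way to supply it (the MLP below is a hypothetical stand-in, not the library's actual implementation):

import torch
from torch import nn

class PatchedGRUEncoder(GRUEncoder):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__(in_dim, hidden_dim, out_dim)
        # hypothetical: provide the submodule that forward() expects
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

encoder = PatchedGRUEncoder(768, 256, 64)
out = encoder((torch.randn(2, 768),))  # shape [2, 64]
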
class GRUTextDomainModule(shimmer.modules.domain.DomainModule):
265class GRUTextDomainModule(DomainModule):
266    in_dim = 768
267
268    def __init__(
269        self,
270        latent_dim: int,
271        hidden_dim: int,
272        vocab_size: int,
273        seq_length: int,
274        optim_lr: float = 1e-3,
275        optim_weight_decay: float = 0,
276        scheduler_args: SchedulerArgs | None = None,
277        padding_token: int = 0,
278    ):
279        super().__init__(latent_dim)
280        self.is_frozen = False
281        self.save_hyperparameters()
282
283        self.hidden_dim = hidden_dim
284        self.latent_dim = latent_dim
285        self.vocab_size = vocab_size
286        self.seq_length = seq_length
287
288        self._padding_token = padding_token
289
290        self.projector = nn.Sequential(
291            nn.Linear(self.in_dim, self.hidden_dim),
292            nn.ReLU(),
293            nn.Linear(self.hidden_dim, self.latent_dim),
294            nn.Tanh(),
295        )
296
297        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
298
299        self.num_layers = 2
300        self.decoder = nn.GRU(
301            self.latent_dim,
302            self.hidden_dim,
303            num_layers=self.num_layers,
304            batch_first=True,
305        )
306        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
307
308        self.optim_lr = optim_lr
309        self.optim_weight_decay = optim_weight_decay
310
311        self.scheduler_args = SchedulerArgs(
312            max_lr=optim_lr,
313            total_steps=1,
314        )
315        self.scheduler_args.update(scheduler_args or {})
316
317    def compute_loss(
318        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
319    ) -> LossOutput:
320        text_token_loss, acc = self.text_token_loss(pred, raw_target)
321
322        return LossOutput(
323            2 * text_token_loss,
324            {"loss_tokens": text_token_loss, "pred_t_acc": acc},
325        )
326
327    def compute_domain_loss(self, domain: Any) -> LossOutput:
328        z = self.encode(domain)
329        loss, acc = self.text_token_loss(z, domain)
330        return LossOutput(loss, {"acc": acc})
331
332    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
333        return self.projector(x["bert"])
334
335    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
336        context = z.unsqueeze(1)
337        pad_tokens = self.embeddings(
338            torch.zeros(
339                z.size(0), self.seq_length - 1, dtype=torch.long, device=z.device
340            )
341        )
342        seq = torch.cat([context, pad_tokens], dim=1)
343        for k in range(0, self.seq_length - 1):
344            out = self.decode_one(seq)
345            seq[:, k + 1] = self.embeddings(out["tokens"][:, k])
346        return self.decode_one(seq)
347
348    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
349        out, _ = self.decoder(z)
350        tokens_dist = self.text_head(out)
351        tokens = torch.argmax(tokens_dist, -1)
352        return {"token_dist": tokens_dist, "tokens": tokens}
353
354    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
355        return self.decode(self.encode(x))
356
357    def text_token_loss(
358        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
359    ) -> tuple[torch.Tensor, torch.Tensor]:
360        context = z.unsqueeze(1)
361        real_tokens = self.embeddings(target["tokens"][:, :-1])
362        seq = torch.cat([context, real_tokens], dim=1)
363
364        out = self.decode_one(seq)
365        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
366        padding_mask = target["tokens"] != self._padding_token
367
368        padded_out = out["tokens"][padding_mask]
369        padded_target = target["tokens"][padding_mask]
370
371        acc = (padded_out == padded_target).sum() / padded_out.size(0)
372        return loss, acc
373
374    def generic_step(
375        self, x: Mapping[str, torch.Tensor], mode: str = "train"
376    ) -> torch.Tensor:
377        z = self.encode(x)
378        loss, acc = self.text_token_loss(z, x)
379        self.log(f"{mode}/loss", loss)
380        self.log(f"{mode}/acc", acc)
381        return loss
382
383    def validation_step(  # type: ignore
384        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
385    ) -> torch.Tensor:
386        x = batch["t"]
387        return self.generic_step(x, "val")
388
389    def training_step(  # type: ignore
390        self,
391        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
392        _,
393    ) -> torch.Tensor:
394        x = batch[frozenset(["t"])]["t"]
395        return self.generic_step(x, "train")
396
397    def configure_optimizers(  # type: ignore
398        self,
399    ) -> dict[str, Any]:
400        optimizer = AdamW(
401            self.parameters(),
402            lr=self.optim_lr,
403            weight_decay=self.optim_weight_decay,
404        )
405
406        return {"optimizer": optimizer}

Base class for a DomainModule that defines domain specific modules of the GW.

GRUTextDomainModule( latent_dim: int, hidden_dim: int, vocab_size: int, seq_length: int, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None, padding_token: int = 0)
268    def __init__(
269        self,
270        latent_dim: int,
271        hidden_dim: int,
272        vocab_size: int,
273        seq_length: int,
274        optim_lr: float = 1e-3,
275        optim_weight_decay: float = 0,
276        scheduler_args: SchedulerArgs | None = None,
277        padding_token: int = 0,
278    ):
279        super().__init__(latent_dim)
280        self.is_frozen = False
281        self.save_hyperparameters()
282
283        self.hidden_dim = hidden_dim
284        self.latent_dim = latent_dim
285        self.vocab_size = vocab_size
286        self.seq_length = seq_length
287
288        self._padding_token = padding_token
289
290        self.projector = nn.Sequential(
291            nn.Linear(self.in_dim, self.hidden_dim),
292            nn.ReLU(),
293            nn.Linear(self.hidden_dim, self.latent_dim),
294            nn.Tanh(),
295        )
296
297        self.embeddings = nn.Embedding(vocab_size, self.latent_dim)
298
299        self.num_layers = 2
300        self.decoder = nn.GRU(
301            self.latent_dim,
302            self.hidden_dim,
303            num_layers=self.num_layers,
304            batch_first=True,
305        )
306        self.text_head = nn.Linear(self.hidden_dim, self.vocab_size)
307
308        self.optim_lr = optim_lr
309        self.optim_weight_decay = optim_weight_decay
310
311        self.scheduler_args = SchedulerArgs(
312            max_lr=optim_lr,
313            total_steps=1,
314        )
315        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
in_dim = 768
is_frozen

Whether the module is frozen. If None, it is frozen by default.

hidden_dim
latent_dim

The latent dimension of the module.

vocab_size
seq_length
projector
embeddings
num_layers
decoder
text_head
optim_lr
optim_weight_decay
scheduler_args
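
A hedged construction and encoding sketch; the vocabulary size, sequence length, and dimensions below are placeholders, and the input dictionary follows the "bert" key used by encode.

import torch

module = GRUTextDomainModule(
    latent_dim=64,
    hidden_dim=256,
    vocab_size=1000,
    seq_length=32,
)
z = module.encode({"bert": torch.randn(4, 768)})  # projected latent, shape [4, 64]
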
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
317    def compute_loss(
318        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
319    ) -> LossOutput:
320        text_token_loss, acc = self.text_token_loss(pred, raw_target)
321
322        return LossOutput(
323            2 * text_token_loss,
324            {"loss_tokens": text_token_loss, "pred_t_acc": acc},
325        )

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def compute_domain_loss(self, domain: Any) -> shimmer.modules.domain.LossOutput:
327    def compute_domain_loss(self, domain: Any) -> LossOutput:
328        z = self.encode(domain)
329        loss, acc = self.text_token_loss(z, domain)
330        return LossOutput(loss, {"acc": acc})

Compute the unimodal domain loss.

Arguments:
  • domain (Any): domain input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
332    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
333        return self.projector(x["bert"])

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
335    def decode(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
336        context = z.unsqueeze(1)
337        pad_tokens = self.embeddings(
338            torch.zeros(
339                z.size(0), self.seq_length - 1, dtype=torch.long, device=z.device
340            )
341        )
342        seq = torch.cat([context, pad_tokens], dim=1)
343        for k in range(0, self.seq_length - 1):
344            out = self.decode_one(seq)
345            seq[:, k + 1] = self.embeddings(out["tokens"][:, k])
346        return self.decode_one(seq)

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.
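
Decoding is greedy and autoregressive: z seeds the first GRU step, and each predicted token is embedded and fed back for the next position (decode_one below runs a single pass over an embedded sequence). Continuing the sketch above, the returned keys, taken from decode_one, are:

out = module.decode(z)
out["tokens"]      # (batch, seq_length) greedy token ids
out["token_dist"]  # (batch, seq_length, vocab_size) unnormalized logits from text_head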

def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
348    def decode_one(self, z: torch.Tensor) -> dict[str, torch.Tensor]:
349        out, _ = self.decoder(z)
350        tokens_dist = self.text_head(out)
351        tokens = torch.argmax(tokens_dist, -1)
352        return {"token_dist": tokens_dist, "tokens": tokens}
def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
354    def forward(self, x: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
355        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def text_token_loss( self, z: torch.Tensor, target: Mapping[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
357    def text_token_loss(
358        self, z: torch.Tensor, target: Mapping[str, torch.Tensor]
359    ) -> tuple[torch.Tensor, torch.Tensor]:
360        context = z.unsqueeze(1)
361        real_tokens = self.embeddings(target["tokens"][:, :-1])
362        seq = torch.cat([context, real_tokens], dim=1)
363
364        out = self.decode_one(seq)
365        loss = F.cross_entropy(out["token_dist"].transpose(1, 2), target["tokens"])
366        padding_mask = target["tokens"] != self._padding_token
367
368        padded_out = out["tokens"][padding_mask]
369        padded_target = target["tokens"][padding_mask]
370
371        acc = (padded_out == padded_target).sum() / padded_out.size(0)
372        return loss, acc
def generic_step(self, x: Mapping[str, torch.Tensor], mode: str = 'train') -> torch.Tensor:
374    def generic_step(
375        self, x: Mapping[str, torch.Tensor], mode: str = "train"
376    ) -> torch.Tensor:
377        z = self.encode(x)
378        loss, acc = self.text_token_loss(z, x)
379        self.log(f"{mode}/loss", loss)
380        self.log(f"{mode}/acc", acc)
381        return loss
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor:
383    def validation_step(  # type: ignore
384        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
385    ) -> torch.Tensor:
386        x = batch["t"]
387        return self.generic_step(x, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
389    def training_step(  # type: ignore
390        self,
391        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
392        _,
393    ) -> torch.Tensor:
394        x = batch[frozenset(["t"])]["t"]
395        return self.generic_step(x, "train")

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
397    def configure_optimizers(  # type: ignore
398        self,
399    ) -> dict[str, Any]:
400        optimizer = AdamW(
401            self.parameters(),
402            lr=self.optim_lr,
403            weight_decay=self.optim_weight_decay,
404        )
405
406        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

.. testcode::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
class Text2Attr(shimmer.modules.domain.DomainModule):
409class Text2Attr(DomainModule):
410    def __init__(
411        self,
412        latent_dim: int,
413        hidden_dim: int,
414        text_model: GRUTextDomainModule,
415        optim_lr: float = 1e-3,
416        optim_weight_decay: float = 0,
417        scheduler_args: SchedulerArgs | None = None,
418    ) -> None:
419        super().__init__(latent_dim)
420        self.save_hyperparameters(ignore=["text_model"])
421
422        self.text_model = text_model
423
424        self.pred_attr = nn.Sequential(
425            nn.Linear(latent_dim, hidden_dim),
426            nn.ReLU(),
427            nn.Linear(hidden_dim, hidden_dim),
428            nn.ReLU(),
429            nn.Linear(hidden_dim, hidden_dim),
430            nn.ReLU(),
431            nn.Linear(hidden_dim, 8),
432            nn.Tanh(),
433        )
434        self.pred_cat = nn.Sequential(
435            nn.Linear(latent_dim, hidden_dim),
436            nn.ReLU(),
437            nn.Linear(hidden_dim, hidden_dim),
438            nn.ReLU(),
439            nn.Linear(hidden_dim, hidden_dim),
440            nn.ReLU(),
441            nn.Linear(hidden_dim, 3),
442            nn.Softmax(dim=1),
443        )
444
445        self.optim_lr = optim_lr
446        self.optim_weight_decay = optim_weight_decay
447
448        self.scheduler_args = SchedulerArgs(
449            max_lr=optim_lr,
450            total_steps=1,
451        )
452        self.scheduler_args.update(scheduler_args or {})
453
454    def compute_loss(
455        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
456    ) -> LossOutput:
457        pred_attr = self.pred_attr(pred)
458        pred_cat = self.pred_cat(pred)
459        target_attr = self.pred_attr(target)
460        target_cat = self.pred_cat(target)
461        loss_attr = F.mse_loss(pred_attr, target_attr)
462        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
463        loss = loss_attr + loss_cat
464        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})
465
466    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
467        return self.text_model.encode(x)
468
469    def decode(self, z: torch.Tensor) -> dict[str, Any]:
470        out: dict[str, Any] = {}
471        out.update(self.text_model.decode(z))
472        pred_attr = self.pred_attr(z)
473        pred_cat = self.pred_cat(z)
474        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
475        return out
476
477    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
478        return self.decode(self.encode(x))
479
480    def generic_step(
481        self,
482        x: Mapping[str, Any],
483        mode: str = "train",
484    ) -> torch.Tensor:
485        text: Mapping[str, torch.Tensor] = x["t"]
486        attr: Sequence[torch.Tensor] = x["attr"]
487        text_l = self.text_model.encode(text)
488        pred_attr = self.pred_attr(text_l)
489        pred_cat = self.pred_cat(text_l)
490        cats = torch.argmax(attr[0], dim=1)
491        loss_cat = F.cross_entropy(pred_cat, cats)
492        loss_attr = F.mse_loss(pred_attr, attr[1])
493        total_loss = loss_cat + loss_attr
494        pred_cats = pred_cat.argmax(dim=1)
495        acc = (cats == pred_cats).sum() / cats.size(0)
496
497        self.log(f"{mode}/loss_cat", loss_cat)
498        self.log(f"{mode}/loss_attr", loss_attr)
499        self.log(f"{mode}/acc_cat", acc)
500        self.log(f"{mode}/loss", total_loss)
501
502        return total_loss
503
504    def validation_step(  # type: ignore
505        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
506    ) -> torch.Tensor:
507        return self.generic_step(batch, "val")
508
509    def training_step(  # type: ignore
510        self,
511        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
512        _,
513    ) -> torch.Tensor:
514        data = batch[frozenset(["t", "attr"])]
515        return self.generic_step(data, "train")
516
517    def configure_optimizers(  # type: ignore
518        self,
519    ) -> dict[str, Any]:
520        optimizer = AdamW(
521            self.parameters(),
522            lr=self.optim_lr,
523            weight_decay=self.optim_weight_decay,
524        )
525
526        return {"optimizer": optimizer}

Base class for a DomainModule that defines domain-specific modules of the GW.

Text2Attr( latent_dim: int, hidden_dim: int, text_model: GRUTextDomainModule, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None)
410    def __init__(
411        self,
412        latent_dim: int,
413        hidden_dim: int,
414        text_model: GRUTextDomainModule,
415        optim_lr: float = 1e-3,
416        optim_weight_decay: float = 0,
417        scheduler_args: SchedulerArgs | None = None,
418    ) -> None:
419        super().__init__(latent_dim)
420        self.save_hyperparameters(ignore=["text_model"])
421
422        self.text_model = text_model
423
424        self.pred_attr = nn.Sequential(
425            nn.Linear(latent_dim, hidden_dim),
426            nn.ReLU(),
427            nn.Linear(hidden_dim, hidden_dim),
428            nn.ReLU(),
429            nn.Linear(hidden_dim, hidden_dim),
430            nn.ReLU(),
431            nn.Linear(hidden_dim, 8),
432            nn.Tanh(),
433        )
434        self.pred_cat = nn.Sequential(
435            nn.Linear(latent_dim, hidden_dim),
436            nn.ReLU(),
437            nn.Linear(hidden_dim, hidden_dim),
438            nn.ReLU(),
439            nn.Linear(hidden_dim, hidden_dim),
440            nn.ReLU(),
441            nn.Linear(hidden_dim, 3),
442            nn.Softmax(dim=1),
443        )
444
445        self.optim_lr = optim_lr
446        self.optim_weight_decay = optim_weight_decay
447
448        self.scheduler_args = SchedulerArgs(
449            max_lr=optim_lr,
450            total_steps=1,
451        )
452        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
text_model
pred_attr
pred_cat
optim_lr
optim_weight_decay
scheduler_args
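
A construction sketch reusing the hypothetical GRUTextDomainModule from earlier; latent_dim must match the wrapped text module, and the dimensions are illustrative. decode() returns the wrapped text decoding plus an "attr" entry produced by the category and attribute heads.

from shimmer_ssd.modules.domains.text import Text2Attr

text2attr = Text2Attr(
    latent_dim=64,      # must match module.latent_dim
    hidden_dim=256,
    text_model=module,  # the GRUTextDomainModule sketched earlier
)

with torch.no_grad():
    out = text2attr({"bert": torch.randn(4, 768)})
out["attr"]    # [softmax category probabilities, attributes, attributes]
out["tokens"]  # token ids decoded by the wrapped text model
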
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
454    def compute_loss(
455        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
456    ) -> LossOutput:
457        pred_attr = self.pred_attr(pred)
458        pred_cat = self.pred_cat(pred)
459        target_attr = self.pred_attr(target)
460        target_cat = self.pred_cat(target)
461        loss_attr = F.mse_loss(pred_attr, target_attr)
462        loss_cat = F.cross_entropy(pred_cat, target_cat.argmax(dim=1))
463        loss = loss_attr + loss_cat
464        return LossOutput(loss, {"attr": loss_attr, "cat": loss_cat})

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with the training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
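
Here both pred and target are latent vectors: each is pushed through the attribute heads, the attribute outputs are compared with MSE, and the predicted category distribution is scored with cross-entropy against the argmax of the target head's output; raw_target is unused. A minimal sketch, continuing the example above:

pred_z = torch.randn(8, 64)    # latent predicted by the GW
target_z = torch.randn(8, 64)  # latent encoded from the ground-truth text
loss_out = text2attr.compute_loss(pred_z, target_z, raw_target=None)
loss_out.metrics  # {"attr": attribute MSE, "cat": category cross-entropy}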

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
466    def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
467        return self.text_model.encode(x)

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> dict[str, typing.Any]:
469    def decode(self, z: torch.Tensor) -> dict[str, Any]:
470        out: dict[str, Any] = {}
471        out.update(self.text_model.decode(z))
472        pred_attr = self.pred_attr(z)
473        pred_cat = self.pred_cat(z)
474        out.update({"attr": [pred_cat, pred_attr, pred_attr]})
475        return out

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.

def forward(self, x: Mapping[str, typing.Any]) -> dict[str, list[torch.Tensor]]:
477    def forward(self, x: Mapping[str, Any]) -> dict[str, list[torch.Tensor]]:
478        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def generic_step(self, x: Mapping[str, typing.Any], mode: str = 'train') -> torch.Tensor:
480    def generic_step(
481        self,
482        x: Mapping[str, Any],
483        mode: str = "train",
484    ) -> torch.Tensor:
485        text: Mapping[str, torch.Tensor] = x["t"]
486        attr: Sequence[torch.Tensor] = x["attr"]
487        text_l = self.text_model.encode(text)
488        pred_attr = self.pred_attr(text_l)
489        pred_cat = self.pred_cat(text_l)
490        cats = torch.argmax(attr[0], dim=1)
491        loss_cat = F.cross_entropy(pred_cat, cats)
492        loss_attr = F.mse_loss(pred_attr, attr[1])
493        total_loss = loss_cat + loss_attr
494        pred_cats = pred_cat.argmax(dim=1)
495        acc = (cats == pred_cats).sum() / cats.size(0)
496
497        self.log(f"{mode}/loss_cat", loss_cat)
498        self.log(f"{mode}/loss_attr", loss_attr)
499        self.log(f"{mode}/acc_cat", acc)
500        self.log(f"{mode}/loss", total_loss)
501
502        return total_loss
def validation_step(self, batch: Mapping[str, Mapping[str, torch.Tensor]], _) -> torch.Tensor:
504    def validation_step(  # type: ignore
505        self, batch: Mapping[str, Mapping[str, torch.Tensor]], _
506    ) -> torch.Tensor:
507        return self.generic_step(batch, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, like accuracy.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples::

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]], _) -> torch.Tensor:
509    def training_step(  # type: ignore
510        self,
511        batch: Mapping[frozenset[str], Mapping[str, Mapping[str, torch.Tensor]]],
512        _,
513    ) -> torch.Tensor:
514        data = batch[frozenset(["t", "attr"])]
515        return self.generic_step(data, "train")

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a ~torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • ~torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example::

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
517    def configure_optimizers(  # type: ignore
518        self,
519    ) -> dict[str, Any]:
520        optimizer = AdamW(
521            self.parameters(),
522            lr=self.optim_lr,
523            weight_decay=self.optim_weight_decay,
524        )
525
526        return {"optimizer": optimizer}

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

.. testcode::

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your ~lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.