shimmer_ssd.logging

import io
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import Any, Generic, Literal, TypeVar, cast

import lightning.pytorch as pl
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.wandb import WandbLogger
from matplotlib import gridspec
from matplotlib.figure import Figure
from PIL import Image
from shimmer.modules.global_workspace import GlobalWorkspaceBase, GWPredictionsBase
from simple_shapes_dataset import (
    UnnormalizeAttributes,
    tensor_to_attribute,
)
from simple_shapes_dataset.cli import generate_image
from tokenizers.implementations import ByteLevelBPETokenizer
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid

from shimmer_ssd import LOGGER
from shimmer_ssd.modules.domains.text import GRUTextDomainModule, Text2Attr
from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule

matplotlib.use("Agg")

_T = TypeVar("_T")

def log_image(
    logger: Logger,
    key: str,
    image: torch.Tensor | Image.Image,
    tensorboard_step: int | None = None,
):
    if isinstance(logger, WandbLogger):
        logger.log_image(key, [image])
    elif isinstance(logger, TensorBoardLogger):
        torch_image = to_tensor(image) if isinstance(image, Image.Image) else image
        logger.experiment.add_image(key, torch_image, tensorboard_step)
    else:
        LOGGER.warning(
            "[Sample Logger] Only logging to tensorboard or wandb is supported"
        )

def log_text(
    logger: Logger,
    key: str,
    columns: list[str],
    data: list[list[str]],
    tensorboard_step: int | None = None,
):
    if isinstance(logger, WandbLogger):
        logger.log_text(key, columns, data)
    elif isinstance(logger, TensorBoardLogger):
        text = ", ".join(columns) + "\n"
        text += "\n".join([", ".join(d) for d in data])
        logger.experiment.add_text(key, text, tensorboard_step)
    else:
        LOGGER.warning(
            "[Sample Logger] Only logging to tensorboard or wandb is supported"
        )

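For instance, a minimal usage sketch of both helpers against a TensorBoard backend (the keys and tensors here are purely illustrative); any Logger other than WandbLogger or TensorBoardLogger triggers the warning branch and nothing is logged:

# Illustrative only: a random CHW tensor and a one-row text table.
from lightning.pytorch.loggers import TensorBoardLogger

tb_logger = TensorBoardLogger("logs", name="samples")
log_image(tb_logger, "debug/random_image", torch.rand(3, 32, 32), tensorboard_step=0)
log_text(tb_logger, "debug/captions", ["text"], [["a red circle"]], tensorboard_step=0)
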
class LogSamplesCallback(Generic[_T], ABC, pl.Callback):
    def __init__(
        self,
        reference_samples: _T,
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
    ) -> None:
        super().__init__()
        self.reference_samples = reference_samples
        self.every_n_epochs = every_n_epochs
        self.log_key = log_key
        self.mode = mode
        self._global_step = 0

    def get_step(self) -> int:
        self._global_step += 1
        return self._global_step - 1

    def to(self, samples: _T, device: torch.device) -> _T:
        raise NotImplementedError

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        if stage != "fit":
            return
        device = trainer.strategy.root_device
        self.reference_samples = self.to(self.reference_samples, device)
        for logger in trainer.loggers:
            self.log_samples(logger, self.reference_samples, "reference")

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: pl.LightningModule,
    ) -> None:
        if not len(loggers):
            LOGGER.debug("[LOGGER] No logger found.")
            return

        samples = self.to(self.reference_samples, pl_module.device)

        with torch.no_grad():
            pl_module.eval()
            generated_samples = pl_module(samples)
            pl_module.train()

        for logger in loggers:
            self.log_samples(logger, generated_samples, "prediction")

    def on_train_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        if self.mode != "train":
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            LOGGER.debug("[LOGGER] on_train_epoch_end skipped (every_n_epochs)")
            return

        LOGGER.debug("[LOGGER] on_train_epoch_end called")
        return self.on_callback(trainer.loggers, pl_module)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        if self.mode == "test":
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        if self.mode != "val":
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_test_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ) -> None:
        if self.mode != "test":
            return

        return self.on_callback(trainer.loggers, pl_module)

    @abstractmethod
    def log_samples(self, logger: Logger, samples: _T, mode: str) -> None: ...

def get_pil_image(figure: Figure) -> Image.Image:
    buf = io.BytesIO()
    figure.savefig(buf)
    buf.seek(0)
    return Image.open(buf)

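For example, any matplotlib figure can be rasterized into a PIL image that log_image accepts:

# Sketch: rasterize a figure, then hand it to log_image.
fig = plt.figure()
plt.plot([0, 1], [1, 0])
pil_img = get_pil_image(fig)  # PIL.Image.Image, usable with log_image
plt.close(fig)
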

def get_attribute_figure_grid(
    categories: np.ndarray,
    locations: np.ndarray,
    sizes: np.ndarray,
    rotations: np.ndarray,
    colors: np.ndarray,
    image_size: int,
    ncols: int = 8,
    padding: float = 2,
) -> Image.Image:
    remainder = 1 if categories.shape[0] % ncols else 0
    nrows = categories.shape[0] // ncols + remainder

    width = ncols * (image_size + padding) + padding
    height = nrows * (image_size + padding) + padding
    dpi = 1

    figure = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi, facecolor="white")
    gs = gridspec.GridSpec(
        nrows,
        ncols,
        wspace=padding / image_size,
        hspace=padding / image_size,
        left=padding / width,
        right=1 - padding / width,
        bottom=padding / height,
        top=1 - padding / height,
    )
    for i in range(nrows):
        for j in range(ncols):
            k = i * ncols + j
            if k >= categories.shape[0]:
                break
            ax = plt.subplot(gs[i, j])
            generate_image(
                ax,
                categories[k],
                locations[k],
                sizes[k],
                rotations[k],
                colors[k],
                image_size,
            )
            ax.set_facecolor("black")
    image = get_pil_image(figure)
    plt.close(figure)
    return image

def attribute_image_grid(
    samples: Sequence[torch.Tensor],
    image_size: int,
    ncols: int,
) -> Image.Image:
    unnormalizer = UnnormalizeAttributes(image_size=image_size)
    attributes = unnormalizer(tensor_to_attribute(samples))

    categories = attributes.category.detach().cpu().numpy()
    locations = torch.stack([attributes.x, attributes.y], dim=1).detach().cpu().numpy()
    colors = (
        torch.stack(
            [
                attributes.color_r,
                attributes.color_g,
                attributes.color_b,
            ],
            dim=1,
        )
        .detach()
        .cpu()
        .numpy()
    )
    sizes = attributes.size.detach().cpu().numpy()
    rotations = attributes.rotation.detach().cpu().numpy()

    return get_attribute_figure_grid(
        categories,
        locations,
        sizes,
        rotations,
        colors,
        image_size,
        ncols,
        padding=2,
    )

class LogAttributesCallback(LogSamplesCallback[Sequence[torch.Tensor]]):
    def __init__(
        self,
        reference_samples: Sequence[torch.Tensor],
        log_key: str,
        mode: Literal["train", "val", "test"],
        image_size: int,
        every_n_epochs: int | None = 1,
        ncols: int = 8,
    ) -> None:
        super().__init__(reference_samples, log_key, mode, every_n_epochs)
        self.image_size = image_size
        self.ncols = ncols

    def to(
        self, samples: Sequence[torch.Tensor], device: torch.device
    ) -> list[torch.Tensor]:
        return [x.to(device) for x in samples]

    def log_samples(
        self, logger: Logger, samples: Sequence[torch.Tensor], mode: str
    ) -> None:
        image = attribute_image_grid(
            samples, image_size=self.image_size, ncols=self.ncols
        )
        log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())

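A hedged wiring sketch for this callback; the reference batch below is a hypothetical stand-in for an attribute batch taken from a real validation dataloader (shapes are illustrative, not the actual tensor_to_attribute layout):

# Hypothetical: `attr_samples` would normally come from your datamodule.
attr_samples = [torch.randint(0, 3, (8,)), torch.rand(8, 8)]  # illustrative shapes
callback = LogAttributesCallback(
    reference_samples=attr_samples,
    log_key="attributes",
    mode="val",
    image_size=32,
)
trainer = pl.Trainer(callbacks=[callback], max_epochs=1)
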
class LogTextCallback(LogSamplesCallback[Mapping[str, torch.Tensor]]):
    def __init__(
        self,
        reference_samples: Mapping[str, torch.Tensor],
        log_key: str,
        mode: Literal["train", "val", "test"],
        image_size: int,
        vocab: str,
        merges: str,
        every_n_epochs: int | None = 1,
        ncols: int = 8,
    ) -> None:
        super().__init__(reference_samples, log_key, mode, every_n_epochs)
        self.image_size = image_size
        self.ncols = ncols
        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)

    def to(
        self, samples: Mapping[str, torch.Tensor], device: torch.device
    ) -> dict[str, torch.Tensor]:
        return {x: samples[x].to(device) for x in samples}

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        if stage != "fit":
            return
        assert isinstance(pl_module, GRUTextDomainModule)
        device = trainer.strategy.root_device
        self.reference_samples = self.to(self.reference_samples, device)
        for logger in trainer.loggers:
            self.log_samples(logger, self.reference_samples, "reference")

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: pl.LightningModule,
    ) -> None:
        assert isinstance(pl_module, GRUTextDomainModule)

        samples = self.to(self.reference_samples, pl_module.device)

        if not len(loggers):
            LOGGER.debug("[LOGGER] No logger found.")
            return

        with torch.no_grad():
            pl_module.eval()
            generated_samples = pl_module(samples)
            pl_module.train()

        for logger in loggers:
            self.log_samples(logger, generated_samples, "prediction")

    def log_samples(
        self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str
    ) -> None:
        if not isinstance(logger, WandbLogger):
            LOGGER.warning("Only logging to wandb is supported")
            return

        assert self.tokenizer is not None
        text = self.tokenizer.decode_batch(
            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
        )
        text = [[t.replace("<pad>", "")] for t in text]
        log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())

class LogVisualCallback(LogSamplesCallback[torch.Tensor]):
    def __init__(
        self,
        reference_samples: torch.Tensor,
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
        ncols: int = 8,
    ) -> None:
        super().__init__(reference_samples, log_key, mode, every_n_epochs)
        self.ncols = ncols

    def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor:
        return samples.to(device)

    def log_samples(self, logger: Logger, samples: torch.Tensor, mode: str) -> None:
        images = make_grid(samples, nrow=self.ncols, pad_value=1)
        log_image(logger, f"{self.log_key}_{mode}", images, self.get_step())

class LogText2AttrCallback(
    LogSamplesCallback[
        Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]]
    ]
):
    def __init__(
        self,
        reference_samples: Mapping[
            str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]
        ],
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
        image_size: int = 32,
        ncols: int = 8,
        vocab: str | None = None,
        merges: str | None = None,
    ) -> None:
        super().__init__(reference_samples, log_key, mode, every_n_epochs)
        self.image_size = image_size
        self.ncols = ncols
        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)

    def to(
        self,
        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
        device: torch.device,
    ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]:
        latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {}
        for domain_name, domain in samples.items():
            if isinstance(domain, dict):
                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
            elif isinstance(domain, list):
                latents[domain_name] = [x.to(device) for x in domain]
        return latents

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        if stage != "fit":
            return
        assert isinstance(pl_module, Text2Attr)
        device = trainer.strategy.root_device
        self.reference_samples = self.to(self.reference_samples, device)
        for logger in trainer.loggers:
            self.log_samples(logger, self.reference_samples, "reference")

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: pl.LightningModule,
    ) -> None:
        assert isinstance(pl_module, Text2Attr)

        samples = self.to(self.reference_samples, pl_module.device)

        if not len(loggers):
            LOGGER.debug("[LOGGER] No logger found.")
            return

        with torch.no_grad():
            pl_module.eval()
            generated_samples = pl_module(samples["t"])
            pl_module.train()

        for logger in loggers:
            self.log_samples(logger, generated_samples, "prediction")

    def log_samples(
        self,
        logger: Logger,
        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
        mode: str,
    ) -> None:
        for domain_name, domain in samples.items():
            if domain_name == "t":
                assert self.tokenizer is not None
                assert isinstance(domain, dict)
                text = self.tokenizer.decode_batch(
                    domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True
                )
                text = [[t.replace("<pad>", "")] for t in text]
                log_text(
                    logger,
                    f"{self.log_key}_{mode}_str",
                    ["text"],
                    text,
                    self.get_step(),
                )
            elif domain_name == "attr":
                assert isinstance(domain, list)
                image = attribute_image_grid(
                    domain,
                    image_size=self.image_size,
                    ncols=self.ncols,
                )
                log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())

def batch_to_device(
    samples: Mapping[
        frozenset[str],
        Mapping[str, Any],
    ],
    device: torch.device,
) -> dict[frozenset[str], dict[str, Any]]:
    out: dict[frozenset[str], dict[str, Any]] = {}
    for domain_names, domains in samples.items():
        latents: dict[str, Any] = {}
        for domain_name, domain in domains.items():
            if isinstance(domain, torch.Tensor):
                latents[domain_name] = domain.to(device)
            elif (
                isinstance(domain, Mapping)
                and len(domain)
                and isinstance(next(iter(domain.values())), torch.Tensor)
            ):
                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
            elif (
                isinstance(domain, Sequence)
                and len(domain)
                and isinstance(domain[0], torch.Tensor)
            ):
                latents[domain_name] = [x.to(device) for x in domain]
            else:
                latents[domain_name] = domain
        out[domain_names] = latents
    return out

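A small sketch of the three container shapes this helper handles (plain tensor, mapping of tensors, sequence of tensors); anything else is passed through unchanged. The batch below is illustrative, keyed by frozensets of domain names as in the real pipeline:

# Illustrative nested batch in the Mapping[frozenset[str], Mapping[str, Any]] layout.
batch = {
    frozenset({"v", "attr", "t"}): {
        "v": torch.rand(4, 3, 32, 32),                 # plain tensor
        "attr": [torch.rand(4, 3), torch.rand(4, 8)],  # sequence of tensors
        "t": {"tokens": torch.ones(4, 16)},            # mapping of tensors
    }
}
moved = batch_to_device(batch, torch.device("cpu"))
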
class LogGWImagesCallback(pl.Callback):
    def __init__(
        self,
        reference_samples: Mapping[frozenset[str], Mapping[str, Any]],
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
        image_size: int = 32,
        ncols: int = 8,
        filter: Sequence[str] | None = None,
        vocab: str | None = None,
        merges: str | None = None,
    ) -> None:
        super().__init__()
        self.mode = mode
        self.reference_samples = reference_samples
        self.every_n_epochs = every_n_epochs
        self.log_key = log_key
        self.image_size = image_size
        self.ncols = ncols
        self.filter = filter
        self.tokenizer = None
        if vocab is not None and merges is not None:
            self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
        self._global_step = 0

    def get_step(self) -> int:
        self._global_step += 1
        return self._global_step - 1

    def to(
        self,
        samples: Mapping[
            frozenset[str],
            Mapping[
                str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor]
            ],
        ],
        device: torch.device,
    ) -> dict[
        frozenset[str],
        dict[str, torch.Tensor | list[torch.Tensor] | dict[Any, torch.Tensor]],
    ]:
        return batch_to_device(samples, device)

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        if stage != "fit":
            return
        assert isinstance(pl_module, GlobalWorkspaceBase)
        device = trainer.strategy.root_device
        self.reference_samples = self.to(self.reference_samples, device)

        for domain_names, domains in self.reference_samples.items():
            for domain_name, domain_tensor in domains.items():
                for logger in trainer.loggers:
                    self.log_samples(
                        logger,
                        pl_module,
                        domain_tensor,
                        domain_name,
                        f"ref_{'-'.join(domain_names)}_{domain_name}",
                    )

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: GlobalWorkspaceBase,
    ) -> None:
        if not len(loggers):
            return

        with torch.no_grad():
            latent_groups = pl_module.encode_domains(self.reference_samples)
            predictions = cast(GWPredictionsBase, pl_module(latent_groups))

            for logger in loggers:
                for domains, preds in predictions["broadcasts"].items():
                    domain_from = ",".join(domains)
                    for domain, pred in preds.items():
                        log_name = f"pred_trans_{domain_from}_to_{domain}"
                        if self.filter is not None and log_name not in self.filter:
                            continue
                        self.log_samples(
                            logger,
                            pl_module,
                            pl_module.decode_domain(pred, domain),
                            domain,
                            log_name,
                        )
                for domains, preds in predictions["cycles"].items():
                    domain_from = ",".join(domains)
                    for domain, pred in preds.items():
                        log_name = f"pred_cycle_{domain_from}_to_{domain}"
                        if self.filter is not None and log_name not in self.filter:
                            continue
                        self.log_samples(
                            logger,
                            pl_module,
                            pl_module.decode_domain(pred, domain),
                            domain,
                            log_name,
                        )

    def on_train_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        if self.mode != "train":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_validation_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        if self.mode != "val":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_test_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        if self.mode != "test":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_train_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        if self.mode == "test":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def log_samples(
        self,
        logger: Logger,
        pl_module: GlobalWorkspaceBase,
        samples: Any,
        domain: str,
        mode: str,
    ) -> None:
        match domain:
            case "v":
                self.log_visual_samples(logger, samples, mode)
            case "v_latents":
                assert "v_latents" in pl_module.domain_mods

                module = cast(
                    VisualLatentDomainModule,
                    pl_module.domain_mods["v_latents"],
                )
                self.log_visual_samples(logger, module.decode_images(samples), mode)
            case "attr":
                self.log_attribute_samples(logger, samples, mode)
            case "t":
                self.log_text_samples(logger, samples, mode)
                if "attr" in samples:
                    self.log_attribute_samples(logger, samples["attr"], mode + "_attr")

    def log_visual_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        images = make_grid(samples, nrow=self.ncols, pad_value=1)
        log_image(logger, f"{self.log_key}/{mode}", images, self.get_step())

    def log_attribute_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        image = attribute_image_grid(
            samples,
            image_size=self.image_size,
            ncols=self.ncols,
        )
        log_image(logger, f"{self.log_key}/{mode}", image, self.get_step())

    def log_text_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        assert self.tokenizer is not None
        text = self.tokenizer.decode_batch(
            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
        )
        text = [[t.replace("<pad>", "")] for t in text]
        log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())
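
Finally, a hedged end-to-end sketch of attaching LogGWImagesCallback to a training run. Here `reference_batch`, `gw_module`, and `data_module` are hypothetical placeholders for objects built elsewhere in shimmer_ssd; the reference batch must follow the Mapping[frozenset[str], Mapping[str, Any]] layout consumed by batch_to_device.

# Hypothetical wiring: reference_batch would typically be one validation batch.
callback = LogGWImagesCallback(
    reference_samples=reference_batch,
    log_key="images",
    mode="val",
    every_n_epochs=5,
    image_size=32,
    filter=["pred_trans_v_latents_to_attr"],  # optional: keep only selected panels
)
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(gw_module, data_module)  # gw_module: a GlobalWorkspaceBase instance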
def log_image( logger: lightning.pytorch.loggers.logger.Logger, key: str, image: torch.Tensor | PIL.Image.Image, tensorboard_step: int | None = None):
36def log_image(
37    logger: Logger,
38    key: str,
39    image: torch.Tensor | Image.Image,
40    tensorboard_step: int | None = None,
41):
42    if isinstance(logger, WandbLogger):
43        logger.log_image(key, [image])
44    elif isinstance(logger, TensorBoardLogger):
45        torch_image = to_tensor(image) if isinstance(image, Image.Image) else image
46        logger.experiment.add_image(key, torch_image, tensorboard_step)
47    else:
48        LOGGER.warning(
49            "[Sample Logger] Only logging to tensorboard or wandb is supported"
50        )
51        return
def log_text( logger: lightning.pytorch.loggers.logger.Logger, key: str, columns: list[str], data: list[list[str]], tensorboard_step: int | None = None):
54def log_text(
55    logger: Logger,
56    key: str,
57    columns: list[str],
58    data: list[list[str]],
59    tensorboard_step: int | None = None,
60):
61    if isinstance(logger, WandbLogger):
62        logger.log_text(key, columns, data)
63    elif isinstance(logger, TensorBoardLogger):
64        text = ", ".join(columns) + "\n"
65        text += "\n".join([", ".join(d) for d in data])
66        logger.experiment.add_text(key, text, tensorboard_step)
67    else:
68        LOGGER.warning(
69            "[Sample Logger] Only logging to tensorboard or wandb is supported"
70        )
71        return
class LogSamplesCallback(typing.Generic[~_T], abc.ABC, lightning.pytorch.callbacks.callback.Callback):
 74class LogSamplesCallback(Generic[_T], ABC, pl.Callback):
 75    def __init__(
 76        self,
 77        reference_samples: _T,
 78        log_key: str,
 79        mode: Literal["train", "val", "test"],
 80        every_n_epochs: int | None = 1,
 81    ) -> None:
 82        super().__init__()
 83        self.reference_samples = reference_samples
 84        self.every_n_epochs = every_n_epochs
 85        self.log_key = log_key
 86        self.mode = mode
 87        self._global_step = 0
 88
 89    def get_step(self) -> int:
 90        self._global_step += 1
 91        return self._global_step - 1
 92
 93    def to(self, samples: _T, device: torch.device) -> _T:
 94        raise NotImplementedError
 95
 96    def setup(
 97        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
 98    ) -> None:
 99        if stage != "fit":
100            return
101        device = trainer.strategy.root_device
102        self.reference_samples = self.to(self.reference_samples, device)
103        for logger in trainer.loggers:
104            self.log_samples(logger, self.reference_samples, "reference")
105
106    def on_callback(
107        self,
108        loggers: Sequence[Logger],
109        pl_module: pl.LightningModule,
110    ) -> None:
111        if not len(loggers):
112            LOGGER.debug("[LOGGER] No logger found.")
113            return
114
115        samples = self.to(self.reference_samples, pl_module.device)
116
117        with torch.no_grad():
118            pl_module.eval()
119            generated_samples = pl_module(samples)
120            pl_module.train()
121
122        for logger in loggers:
123            self.log_samples(logger, generated_samples, "prediction")
124
125    def on_train_epoch_end(
126        self, trainer: pl.Trainer, pl_module: pl.LightningModule
127    ) -> None:
128        if self.mode != "train":
129            return
130
131        if (
132            self.every_n_epochs is None
133            or trainer.current_epoch % self.every_n_epochs != 0
134        ):
135            LOGGER.debug("[LOGGER] on_train_epoch_end")
136            return
137
138        LOGGER.debug("[LOGGER] on_train_epoch_end called")
139        return self.on_callback(trainer.loggers, pl_module)
140
141    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
142        if self.mode == "test":
143            return
144
145        return self.on_callback(trainer.loggers, pl_module)
146
147    def on_validation_epoch_end(
148        self, trainer: pl.Trainer, pl_module: pl.LightningModule
149    ) -> None:
150        if self.mode != "val":
151            return
152
153        if (
154            self.every_n_epochs is None
155            or trainer.current_epoch % self.every_n_epochs != 0
156        ):
157            return
158
159        return self.on_callback(trainer.loggers, pl_module)
160
161    def on_test_epoch_end(
162        self, trainer: pl.Trainer, pl_module: pl.LightningModule
163    ) -> None:
164        if self.mode != "test":
165            return
166
167        return self.on_callback(trainer.loggers, pl_module)
168
169    @abstractmethod
170    def log_samples(self, logger: Logger, samples: _T, mode: str) -> None: ...

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

reference_samples
every_n_epochs
log_key
mode
def get_step(self) -> int:
89    def get_step(self) -> int:
90        self._global_step += 1
91        return self._global_step - 1
def to(self, samples: ~_T, device: torch.device) -> ~_T:
93    def to(self, samples: _T, device: torch.device) -> _T:
94        raise NotImplementedError
def on_callback( self, loggers: Sequence[lightning.pytorch.loggers.logger.Logger], pl_module: lightning.pytorch.core.module.LightningModule) -> None:
106    def on_callback(
107        self,
108        loggers: Sequence[Logger],
109        pl_module: pl.LightningModule,
110    ) -> None:
111        if not len(loggers):
112            LOGGER.debug("[LOGGER] No logger found.")
113            return
114
115        samples = self.to(self.reference_samples, pl_module.device)
116
117        with torch.no_grad():
118            pl_module.eval()
119            generated_samples = pl_module(samples)
120            pl_module.train()
121
122        for logger in loggers:
123            self.log_samples(logger, generated_samples, "prediction")
@abstractmethod
def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: ~_T, mode: str) -> None:
169    @abstractmethod
170    def log_samples(self, logger: Logger, samples: _T, mode: str) -> None: ...
def get_pil_image(figure: matplotlib.figure.Figure) -> PIL.Image.Image:
173def get_pil_image(figure: Figure) -> Image.Image:
174    buf = io.BytesIO()
175    figure.savefig(buf)
176    buf.seek(0)
177    return Image.open(buf)
def get_attribute_figure_grid( categories: numpy.ndarray, locations: numpy.ndarray, sizes: numpy.ndarray, rotations: numpy.ndarray, colors: numpy.ndarray, image_size: int, ncols: int = 8, padding: float = 2) -> PIL.Image.Image:
180def get_attribute_figure_grid(
181    categories: np.ndarray,
182    locations: np.ndarray,
183    sizes: np.ndarray,
184    rotations: np.ndarray,
185    colors: np.ndarray,
186    image_size: int,
187    ncols: int = 8,
188    padding: float = 2,
189) -> Image.Image:
190    reminder = 1 if categories.shape[0] % ncols else 0
191    nrows = categories.shape[0] // ncols + reminder
192
193    width = ncols * (image_size + padding) + padding
194    height = nrows * (image_size + padding) + padding
195    dpi = 1
196
197    figure = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi, facecolor="white")
198    gs = gridspec.GridSpec(
199        nrows,
200        ncols,
201        wspace=padding / image_size,
202        hspace=padding / image_size,
203        left=padding / width,
204        right=1 - padding / width,
205        bottom=padding / height,
206        top=1 - padding / height,
207    )
208    for i in range(nrows):
209        for j in range(ncols):
210            k = i * ncols + j
211            if k >= categories.shape[0]:
212                break
213            ax = plt.subplot(gs[i, j])
214            generate_image(
215                ax,
216                categories[k],
217                locations[k],
218                sizes[k],
219                rotations[k],
220                colors[k],
221                image_size,
222            )
223            ax.set_facecolor("black")
224    image = get_pil_image(figure)
225    plt.close(figure)
226    return image
def attribute_image_grid( samples: Sequence[torch.Tensor], image_size: int, ncols: int) -> PIL.Image.Image:
229def attribute_image_grid(
230    samples: Sequence[torch.Tensor],
231    image_size: int,
232    ncols: int,
233) -> Image.Image:
234    unnormalizer = UnnormalizeAttributes(image_size=image_size)
235    attributes = unnormalizer(tensor_to_attribute(samples))
236
237    categories = attributes.category.detach().cpu().numpy()
238    locations = torch.stack([attributes.x, attributes.y], dim=1).detach().cpu().numpy()
239    colors = (
240        (
241            torch.stack(
242                [
243                    attributes.color_r,
244                    attributes.color_g,
245                    attributes.color_b,
246                ],
247                dim=1,
248            )
249        )
250        .cpu()
251        .numpy()
252    )
253    sizes = attributes.size.detach().cpu().numpy()
254    rotations = attributes.rotation.detach().cpu().numpy()
255
256    return get_attribute_figure_grid(
257        categories,
258        locations,
259        sizes,
260        rotations,
261        colors,
262        image_size,
263        ncols,
264        padding=2,
265    )
268class LogAttributesCallback(LogSamplesCallback[Sequence[torch.Tensor]]):
269    def __init__(
270        self,
271        reference_samples: Sequence[torch.Tensor],
272        log_key: str,
273        mode: Literal["train", "val", "test"],
274        image_size: int,
275        every_n_epochs: int | None = 1,
276        ncols: int = 8,
277    ) -> None:
278        super().__init__(reference_samples, log_key, mode, every_n_epochs)
279        self.image_size = image_size
280        self.ncols = ncols
281
282    def to(
283        self, samples: Sequence[torch.Tensor], device: torch.device
284    ) -> list[torch.Tensor]:
285        return [x.to(device) for x in samples]
286
287    def log_samples(
288        self, logger: Logger, samples: Sequence[torch.Tensor], mode: str
289    ) -> None:
290        image = attribute_image_grid(
291            samples, image_size=self.image_size, ncols=self.ncols
292        )
293        log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

LogAttributesCallback( reference_samples: Sequence[torch.Tensor], log_key: str, mode: Literal['train', 'val', 'test'], image_size: int, every_n_epochs: int | None = 1, ncols: int = 8)
269    def __init__(
270        self,
271        reference_samples: Sequence[torch.Tensor],
272        log_key: str,
273        mode: Literal["train", "val", "test"],
274        image_size: int,
275        every_n_epochs: int | None = 1,
276        ncols: int = 8,
277    ) -> None:
278        super().__init__(reference_samples, log_key, mode, every_n_epochs)
279        self.image_size = image_size
280        self.ncols = ncols
image_size
ncols
def to( self, samples: Sequence[torch.Tensor], device: torch.device) -> list[torch.Tensor]:
282    def to(
283        self, samples: Sequence[torch.Tensor], device: torch.device
284    ) -> list[torch.Tensor]:
285        return [x.to(device) for x in samples]
def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Sequence[torch.Tensor], mode: str) -> None:
287    def log_samples(
288        self, logger: Logger, samples: Sequence[torch.Tensor], mode: str
289    ) -> None:
290        image = attribute_image_grid(
291            samples, image_size=self.image_size, ncols=self.ncols
292        )
293        log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())
296class LogTextCallback(LogSamplesCallback[Mapping[str, torch.Tensor]]):
297    def __init__(
298        self,
299        reference_samples: Mapping[str, torch.Tensor],
300        log_key: str,
301        mode: Literal["train", "val", "test"],
302        image_size: int,
303        vocab: str,
304        merges: str,
305        every_n_epochs: int | None = 1,
306        ncols: int = 8,
307    ) -> None:
308        super().__init__(reference_samples, log_key, mode, every_n_epochs)
309        self.image_size = image_size
310        self.ncols = ncols
311        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
312
313    def to(
314        self, samples: Mapping[str, torch.Tensor], device: torch.device
315    ) -> dict[str, torch.Tensor]:
316        return {x: samples[x].to(device) for x in samples}
317
318    def setup(
319        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
320    ) -> None:
321        if stage != "fit":
322            return
323        assert isinstance(pl_module, GRUTextDomainModule)
324        device = trainer.strategy.root_device
325        self.reference_samples = self.to(self.reference_samples, device)
326        for logger in trainer.loggers:
327            self.log_samples(logger, self.reference_samples, "reference")
328
329    def on_callback(
330        self,
331        loggers: Sequence[Logger],
332        pl_module: pl.LightningModule,
333    ) -> None:
334        assert isinstance(pl_module, GRUTextDomainModule)
335
336        samples = self.to(self.reference_samples, pl_module.device)
337
338        if not len(loggers):
339            LOGGER.debug("[LOGGER] No logger found.")
340            return
341
342        with torch.no_grad():
343            pl_module.eval()
344            generated_samples = pl_module(samples)
345            pl_module.train()
346
347        for logger in loggers:
348            self.log_samples(logger, generated_samples, "prediction")
349
350    def log_samples(
351        self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str
352    ) -> None:
353        if not isinstance(logger, WandbLogger):
354            LOGGER.warning("Only logging to wandb is supported")
355            return
356
357        assert self.tokenizer is not None
358        text = self.tokenizer.decode_batch(
359            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
360        )
361        text = [[t.replace("<pad>", "")] for t in text]
362        log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

LogTextCallback( reference_samples: Mapping[str, torch.Tensor], log_key: str, mode: Literal['train', 'val', 'test'], image_size: int, vocab: str, merges: str, every_n_epochs: int | None = 1, ncols: int = 8)
297    def __init__(
298        self,
299        reference_samples: Mapping[str, torch.Tensor],
300        log_key: str,
301        mode: Literal["train", "val", "test"],
302        image_size: int,
303        vocab: str,
304        merges: str,
305        every_n_epochs: int | None = 1,
306        ncols: int = 8,
307    ) -> None:
308        super().__init__(reference_samples, log_key, mode, every_n_epochs)
309        self.image_size = image_size
310        self.ncols = ncols
311        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
image_size
ncols
tokenizer
def to( self, samples: Mapping[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
313    def to(
314        self, samples: Mapping[str, torch.Tensor], device: torch.device
315    ) -> dict[str, torch.Tensor]:
316        return {x: samples[x].to(device) for x in samples}
def setup( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, stage: str) -> None:
318    def setup(
319        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
320    ) -> None:
321        if stage != "fit":
322            return
323        assert isinstance(pl_module, GRUTextDomainModule)
324        device = trainer.strategy.root_device
325        self.reference_samples = self.to(self.reference_samples, device)
326        for logger in trainer.loggers:
327            self.log_samples(logger, self.reference_samples, "reference")

Called when fit, validate, test, predict, or tune begins.

def on_callback( self, loggers: Sequence[lightning.pytorch.loggers.logger.Logger], pl_module: lightning.pytorch.core.module.LightningModule) -> None:
329    def on_callback(
330        self,
331        loggers: Sequence[Logger],
332        pl_module: pl.LightningModule,
333    ) -> None:
334        assert isinstance(pl_module, GRUTextDomainModule)
335
336        samples = self.to(self.reference_samples, pl_module.device)
337
338        if not len(loggers):
339            LOGGER.debug("[LOGGER] No logger found.")
340            return
341
342        with torch.no_grad():
343            pl_module.eval()
344            generated_samples = pl_module(samples)
345            pl_module.train()
346
347        for logger in loggers:
348            self.log_samples(logger, generated_samples, "prediction")
def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Mapping[str, torch.Tensor], mode: str) -> None:
350    def log_samples(
351        self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str
352    ) -> None:
353        if not isinstance(logger, WandbLogger):
354            LOGGER.warning("Only logging to wandb is supported")
355            return
356
357        assert self.tokenizer is not None
358        text = self.tokenizer.decode_batch(
359            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
360        )
361        text = [[t.replace("<pad>", "")] for t in text]
362        log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())
class LogVisualCallback(shimmer_ssd.logging.LogSamplesCallback[torch.Tensor]):
365class LogVisualCallback(LogSamplesCallback[torch.Tensor]):
366    def __init__(
367        self,
368        reference_samples: torch.Tensor,
369        log_key: str,
370        mode: Literal["train", "val", "test"],
371        every_n_epochs: int | None = 1,
372        ncols: int = 8,
373    ) -> None:
374        super().__init__(reference_samples, log_key, mode, every_n_epochs)
375        self.ncols = ncols
376
377    def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor:
378        return samples.to(device)
379
380    def log_samples(self, logger: Logger, samples: torch.Tensor, mode: str) -> None:
381        images = make_grid(samples, nrow=self.ncols, pad_value=1)
382        log_image(logger, f"{self.log_key}_{mode}", images)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

LogVisualCallback( reference_samples: torch.Tensor, log_key: str, mode: Literal['train', 'val', 'test'], every_n_epochs: int | None = 1, ncols: int = 8)
366    def __init__(
367        self,
368        reference_samples: torch.Tensor,
369        log_key: str,
370        mode: Literal["train", "val", "test"],
371        every_n_epochs: int | None = 1,
372        ncols: int = 8,
373    ) -> None:
374        super().__init__(reference_samples, log_key, mode, every_n_epochs)
375        self.ncols = ncols
ncols
def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor:
377    def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor:
378        return samples.to(device)
def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: torch.Tensor, mode: str) -> None:
380    def log_samples(self, logger: Logger, samples: torch.Tensor, mode: str) -> None:
381        images = make_grid(samples, nrow=self.ncols, pad_value=1)
382        log_image(logger, f"{self.log_key}_{mode}", images)
385class LogText2AttrCallback(
386    LogSamplesCallback[
387        Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]]
388    ]
389):
390    def __init__(
391        self,
392        reference_samples: Mapping[
393            str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]
394        ],
395        log_key: str,
396        mode: Literal["train", "val", "test"],
397        every_n_epochs: int | None = 1,
398        image_size: int = 32,
399        ncols: int = 8,
400        vocab: str | None = None,
401        merges: str | None = None,
402    ) -> None:
403        super().__init__(reference_samples, log_key, mode, every_n_epochs)
404        self.image_size = image_size
405        self.ncols = ncols
406        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
407        self.reference_samples = reference_samples
408
409    def to(
410        self,
411        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
412        device: torch.device,
413    ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]:
414        latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {}
415        for domain_name, domain in samples.items():
416            if isinstance(domain, dict):
417                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
418            elif isinstance(domain, list):
419                latents[domain_name] = [x.to(device) for x in domain]
420        return latents
421
422    def setup(
423        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
424    ) -> None:
425        if stage != "fit":
426            return
427        assert isinstance(pl_module, Text2Attr)
428        device = trainer.strategy.root_device
429        self.reference_samples = self.to(self.reference_samples, device)
430        for logger in trainer.loggers:
431            self.log_samples(logger, self.reference_samples, "reference")
432
433    def on_callback(
434        self,
435        loggers: Sequence[Logger],
436        pl_module: pl.LightningModule,
437    ) -> None:
438        assert isinstance(pl_module, Text2Attr)
439
440        samples = self.to(self.reference_samples, pl_module.device)
441
442        if not len(loggers):
443            LOGGER.debug("[LOGGER] No logger found.")
444            return
445
446        with torch.no_grad():
447            pl_module.eval()
448            generated_samples = pl_module(samples["t"])
449            pl_module.train()
450
451        for logger in loggers:
452            self.log_samples(logger, generated_samples, "prediction")
453
454    def log_samples(
455        self,
456        logger: Logger,
457        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
458        mode: str,
459    ) -> None:
460        for domain_name, domain in samples.items():
461            if domain_name == "t":
462                assert self.tokenizer is not None
463                assert isinstance(domain, dict)
464                text = self.tokenizer.decode_batch(
465                    domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True
466                )
467                text = [[t.replace("<pad>", "")] for t in text]
468                log_text(
469                    logger,
470                    f"{self.log_key}_{mode}_str",
471                    ["text"],
472                    text,
473                    self.get_step(),
474                )
475            elif domain_name == "attr":
476                assert isinstance(domain, list)
477                image = attribute_image_grid(
478                    domain,
479                    image_size=self.image_size,
480                    ncols=self.ncols,
481                )
482                log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

LogText2AttrCallback( reference_samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], log_key: str, mode: Literal['train', 'val', 'test'], every_n_epochs: int | None = 1, image_size: int = 32, ncols: int = 8, vocab: str | None = None, merges: str | None = None)
390    def __init__(
391        self,
392        reference_samples: Mapping[
393            str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]
394        ],
395        log_key: str,
396        mode: Literal["train", "val", "test"],
397        every_n_epochs: int | None = 1,
398        image_size: int = 32,
399        ncols: int = 8,
400        vocab: str | None = None,
401        merges: str | None = None,
402    ) -> None:
403        super().__init__(reference_samples, log_key, mode, every_n_epochs)
404        self.image_size = image_size
405        self.ncols = ncols
406        self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
407        self.reference_samples = reference_samples
image_size
ncols
tokenizer
reference_samples
def to( self, samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], device: torch.device) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]:
409    def to(
410        self,
411        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
412        device: torch.device,
413    ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]:
414        latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {}
415        for domain_name, domain in samples.items():
416            if isinstance(domain, dict):
417                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
418            elif isinstance(domain, list):
419                latents[domain_name] = [x.to(device) for x in domain]
420        return latents
def setup( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, stage: str) -> None:
422    def setup(
423        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
424    ) -> None:
425        if stage != "fit":
426            return
427        assert isinstance(pl_module, Text2Attr)
428        device = trainer.strategy.root_device
429        self.reference_samples = self.to(self.reference_samples, device)
430        for logger in trainer.loggers:
431            self.log_samples(logger, self.reference_samples, "reference")

Called when fit, validate, test, predict, or tune begins.

def on_callback( self, loggers: Sequence[lightning.pytorch.loggers.logger.Logger], pl_module: lightning.pytorch.core.module.LightningModule) -> None:
433    def on_callback(
434        self,
435        loggers: Sequence[Logger],
436        pl_module: pl.LightningModule,
437    ) -> None:
438        assert isinstance(pl_module, Text2Attr)
439
440        samples = self.to(self.reference_samples, pl_module.device)
441
442        if not len(loggers):
443            LOGGER.debug("[LOGGER] No logger found.")
444            return
445
446        with torch.no_grad():
447            pl_module.eval()
448            generated_samples = pl_module(samples["t"])
449            pl_module.train()
450
451        for logger in loggers:
452            self.log_samples(logger, generated_samples, "prediction")
def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], mode: str) -> None:
454    def log_samples(
455        self,
456        logger: Logger,
457        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
458        mode: str,
459    ) -> None:
460        for domain_name, domain in samples.items():
461            if domain_name == "t":
462                assert self.tokenizer is not None
463                assert isinstance(domain, dict)
464                text = self.tokenizer.decode_batch(
465                    domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True
466                )
467                text = [[t.replace("<pad>", "")] for t in text]
468                log_text(
469                    logger,
470                    f"{self.log_key}_{mode}_str",
471                    ["text"],
472                    text,
473                    self.get_step(),
474                )
475            elif domain_name == "attr":
476                assert isinstance(domain, list)
477                image = attribute_image_grid(
478                    domain,
479                    image_size=self.image_size,
480                    ncols=self.ncols,
481                )
482                log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())
def batch_to_device( samples: Mapping[frozenset[str], Mapping[str, typing.Any]], device: torch.device) -> dict[frozenset[str], dict[str, typing.Any]]:
485def batch_to_device(
486    samples: Mapping[
487        frozenset[str],
488        Mapping[str, Any],
489    ],
490    device: torch.device,
491) -> dict[frozenset[str], dict[str, Any]]:
492    out: dict[frozenset[str], dict[str, Any]] = {}
493    for domain_names, domains in samples.items():
494        latents: dict[str, Any] = {}
495        for domain_name, domain in domains.items():
496            if isinstance(domain, torch.Tensor):
497                latents[domain_name] = domain.to(device)
498            elif (
499                isinstance(domain, Mapping)
500                and len(domain)
501                and isinstance(next(iter(domain.values())), torch.Tensor)
502            ):
503                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
504            elif (
505                isinstance(domain, Sequence)
506                and len(domain)
507                and isinstance(domain[0], torch.Tensor)
508            ):
509                latents[domain_name] = [x.to(device) for x in domain]
510            else:
511                latents[domain_name] = domain
512        out[domain_names] = latents
513    return out
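A usage sketch for batch_to_device with a nested batch shaped like the reference samples (domain names and shapes below are illustrative): tensors, mappings of tensors, and sequences of tensors are all moved to the target device; anything else is passed through unchanged.

import torch

batch = {
    frozenset({"v", "attr", "t"}): {
        "v": torch.randn(4, 3, 32, 32),                  # plain tensor
        "attr": [torch.randn(4, 3), torch.randn(4, 8)],  # sequence of tensors
        "t": {"tokens": torch.randint(0, 100, (4, 16))}, # mapping of tensors
    }
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
moved = batch_to_device(batch, device)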
class LogGWImagesCallback(lightning.pytorch.callbacks.callback.Callback):

Logs reference samples and the model's predictions (broadcast translations and cycles between domains) as image grids and text tables for a GlobalWorkspaceBase model, through the loggers attached to the trainer.
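
A hypothetical wiring of the callback into a Lightning Trainer; reference_samples, model (a GlobalWorkspaceBase) and dm (a LightningDataModule) are assumed to exist, and the tokenizer files are placeholder paths:

import lightning.pytorch as pl

callback = LogGWImagesCallback(
    reference_samples,
    log_key="images",
    mode="val",          # log at the end of validation epochs
    every_n_epochs=5,
    vocab="vocab.json",  # vocab + merges enable text decoding
    merges="merges.txt",
)
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(model, datamodule=dm)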

LogGWImagesCallback( reference_samples: Mapping[frozenset[str], Mapping[str, typing.Any]], log_key: str, mode: Literal['train', 'val', 'test'], every_n_epochs: int | None = 1, image_size: int = 32, ncols: int = 8, filter: Sequence[str] | None = None, vocab: str | None = None, merges: str | None = None)
517    def __init__(
518        self,
519        reference_samples: Mapping[frozenset[str], Mapping[str, Any]],
520        log_key: str,
521        mode: Literal["train", "val", "test"],
522        every_n_epochs: int | None = 1,
523        image_size: int = 32,
524        ncols: int = 8,
525        filter: Sequence[str] | None = None,
526        vocab: str | None = None,
527        merges: str | None = None,
528    ) -> None:
529        super().__init__()
530        self.mode = mode
531        self.reference_samples = reference_samples
532        self.every_n_epochs = every_n_epochs
533        self.log_key = log_key
534        self.image_size = image_size
535        self.ncols = ncols
536        self.filter = filter
537        self.tokenizer = None
538        if vocab is not None and merges is not None:
539            self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
540        self._global_step = 0
mode: stage during which samples are logged ("train", "val" or "test")
reference_samples: reference batch, keyed by frozensets of domain names
every_n_epochs: logging period in epochs; None disables per-epoch logging
log_key: prefix used for the logged keys
image_size: side length, in pixels, of each rendered attribute image
ncols: number of columns in the logged image grids
filter: optional whitelist of prediction log names
tokenizer: ByteLevelBPETokenizer built from vocab and merges, else None
def get_step(self):
542    def get_step(self):
543        self._global_step += 1
544        return self._global_step - 1
def to( self, samples: Mapping[frozenset[str], Mapping[str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor]]], device: torch.device) -> dict[frozenset[str], dict[str, torch.Tensor | list[torch.Tensor] | dict[typing.Any, torch.Tensor]]]:
546    def to(
547        self,
548        samples: Mapping[
549            frozenset[str],
550            Mapping[
551                str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor]
552            ],
553        ],
554        device: torch.device,
555    ) -> dict[
556        frozenset[str],
557        dict[str, torch.Tensor | list[torch.Tensor] | dict[Any, torch.Tensor]],
558    ]:
559        return batch_to_device(samples, device)
def setup( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule, stage: str) -> None:
561    def setup(
562        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
563    ) -> None:
564        if stage != "fit":
565            return
566        assert isinstance(pl_module, GlobalWorkspaceBase)
567        device = trainer.strategy.root_device
568        self.reference_samples = self.to(self.reference_samples, device)
569
570        for domain_names, domains in self.reference_samples.items():
571            for domain_name, domain_tensor in domains.items():
572                for logger in trainer.loggers:
573                    self.log_samples(
574                        logger,
575                        pl_module,
576                        domain_tensor,
577                        domain_name,
578                        f"ref_{'-'.join(domain_names)}_{domain_name}",
579                    )

Called when fit, validate, test, predict, or tune begins.
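The reference keys built in setup join the group's domain names with dashes. An illustration of the construction (domain names are examples; frozenset iteration order is not guaranteed, so the joined order may vary between runs):

domain_names = frozenset({"v_latents", "attr"})
key = f"ref_{'-'.join(domain_names)}_attr"
# e.g. "ref_attr-v_latents_attr" or "ref_v_latents-attr_attr"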

def on_callback( self, loggers: Sequence[lightning.pytorch.loggers.logger.Logger], pl_module: shimmer.modules.global_workspace.GlobalWorkspaceBase) -> None:
581    def on_callback(
582        self,
583        loggers: Sequence[Logger],
584        pl_module: GlobalWorkspaceBase,
585    ) -> None:
586        if not (len(loggers)):
587            return
588
589        with torch.no_grad():
590            latent_groups = pl_module.encode_domains(self.reference_samples)
591            predictions = cast(GWPredictionsBase, pl_module(latent_groups))
592
593            for logger in loggers:
594                for domains, preds in predictions["broadcasts"].items():
595                    domain_from = ",".join(domains)
596                    for domain, pred in preds.items():
597                        log_name = f"pred_trans_{domain_from}_to_{domain}"
598                        if self.filter is not None and log_name not in self.filter:
599                            continue
600                        self.log_samples(
601                            logger,
602                            pl_module,
603                            pl_module.decode_domain(pred, domain),
604                            domain,
605                            log_name,
606                        )
607                for domains, preds in predictions["cycles"].items():
608                    domain_from = ",".join(domains)
609                    for domain, pred in preds.items():
610                        log_name = f"pred_cycle_{domain_from}_to_{domain}"
611                        if self.filter is not None and log_name not in self.filter:
612                            continue
613                        self.log_samples(
614                            logger,
615                            pl_module,
616                            pl_module.decode_domain(pred, domain),
617                            domain,
618                            log_name,
619                        )
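The filter whitelist is matched against the generated log names, which follow pred_trans_{sources}_to_{target} for broadcasts and pred_cycle_{sources}_to_{target} for cycles, with multiple source domains joined by commas. An illustrative whitelist for a model with "attr" and "v_latents" domains:

filter = [
    "pred_trans_attr_to_v_latents",       # single-source translation
    "pred_trans_attr,v_latents_to_t",     # multi-source broadcast (comma-joined)
    "pred_cycle_v_latents_to_v_latents",  # cycle back to the source domain
]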
def on_train_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule) -> None:
621    def on_train_epoch_end(
622        self,
623        trainer: pl.Trainer,
624        pl_module: pl.LightningModule,
625    ) -> None:
626        if self.mode != "train":
627            return
628
629        if not isinstance(pl_module, GlobalWorkspaceBase):
630            return
631
632        if (
633            self.every_n_epochs is None
634            or trainer.current_epoch % self.every_n_epochs != 0
635        ):
636            return
637
638        return self.on_callback(trainer.loggers, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
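
Back in this callback, the every_n_epochs gate above means predictions are logged on epochs 0, N, 2N, and so on, while None disables per-epoch logging entirely. A quick check of the modulo logic (illustrative values):

every_n_epochs = 5
logged_epochs = [e for e in range(12) if e % every_n_epochs == 0]
# -> [0, 5, 10]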
def on_validation_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule) -> None:
640    def on_validation_epoch_end(
641        self,
642        trainer: pl.Trainer,
643        pl_module: pl.LightningModule,
644    ) -> None:
645        if self.mode != "val":
646            return
647
648        if not isinstance(pl_module, GlobalWorkspaceBase):
649            return
650
651        if (
652            self.every_n_epochs is None
653            or trainer.current_epoch % self.every_n_epochs != 0
654        ):
655            return
656
657        return self.on_callback(trainer.loggers, pl_module)

Called when the val epoch ends.

def on_test_epoch_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule) -> None:
659    def on_test_epoch_end(
660        self,
661        trainer: pl.Trainer,
662        pl_module: pl.LightningModule,
663    ) -> None:
664        if self.mode != "test":
665            return
666
667        if not isinstance(pl_module, GlobalWorkspaceBase):
668            return
669
670        return self.on_callback(trainer.loggers, pl_module)

Called when the test epoch ends.

def on_train_end( self, trainer: lightning.pytorch.trainer.trainer.Trainer, pl_module: lightning.pytorch.core.module.LightningModule) -> None:
672    def on_train_end(
673        self,
674        trainer: pl.Trainer,
675        pl_module: pl.LightningModule,
676    ) -> None:
677        if self.mode == "test":
678            return
679
680        if not isinstance(pl_module, GlobalWorkspaceBase):
681            return
682
683        return self.on_callback(trainer.loggers, pl_module)

Called when training ends.

def log_samples( self, logger: lightning.pytorch.loggers.logger.Logger, pl_module: shimmer.modules.global_workspace.GlobalWorkspaceBase, samples: Any, domain: str, mode: str) -> None:
685    def log_samples(
686        self,
687        logger: Logger,
688        pl_module: GlobalWorkspaceBase,
689        samples: Any,
690        domain: str,
691        mode: str,
692    ) -> None:
693        match domain:
694            case "v":
695                self.log_visual_samples(logger, samples, mode)
696            case "v_latents":
697                assert "v_latents" in pl_module.domain_mods
698
699                module = cast(
700                    VisualLatentDomainModule,
701                    pl_module.domain_mods["v_latents"],
702                )
703                self.log_visual_samples(logger, module.decode_images(samples), mode)
704            case "attr":
705                self.log_attribute_samples(logger, samples, mode)
706            case "t":
707                self.log_text_samples(logger, samples, mode)
708                if "attr" in samples:
709                    self.log_attribute_samples(logger, samples["attr"], mode + "_attr")
def log_visual_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Any, mode: str) -> None:
711    def log_visual_samples(
712        self,
713        logger: Logger,
714        samples: Any,
715        mode: str,
716    ) -> None:
717        images = make_grid(samples, nrow=self.ncols, pad_value=1)
718        log_image(logger, f"{self.log_key}/{mode}", images, self.get_step())
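log_visual_samples relies on torchvision's make_grid to tile the batch into a single image, ncols per row, with white (pad_value=1) separators. A standalone sketch with an illustrative batch and a hypothetical output path:

import torch
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_pil_image

samples = torch.rand(16, 3, 32, 32)  # batch of 16 RGB 32x32 images
grid = make_grid(samples, nrow=8, pad_value=1)
to_pil_image(grid).save("grid.png")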
def log_attribute_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Any, mode: str) -> None:
720    def log_attribute_samples(
721        self,
722        logger: Logger,
723        samples: Any,
724        mode: str,
725    ) -> None:
726        image = attribute_image_grid(
727            samples,
728            image_size=self.image_size,
729            ncols=self.ncols,
730        )
731        log_image(logger, f"{self.log_key}/{mode}", image, self.get_step())
def log_text_samples( self, logger: lightning.pytorch.loggers.logger.Logger, samples: Any, mode: str) -> None:
733    def log_text_samples(
734        self,
735        logger: Logger,
736        samples: Any,
737        mode: str,
738    ) -> None:
739        assert self.tokenizer is not None
740        text = self.tokenizer.decode_batch(
741            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
742        )
743        text = [[t.replace("<pad>", "")] for t in text]
744        log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())