shimmer_ssd.logging
1import io 2from abc import ABC, abstractmethod 3from collections.abc import Mapping, Sequence 4from typing import Any, Generic, Literal, TypeVar, cast 5 6import lightning.pytorch as pl 7import matplotlib 8import matplotlib.pyplot as plt 9import numpy as np 10import torch 11from lightning.pytorch.loggers import Logger, TensorBoardLogger 12from lightning.pytorch.loggers.wandb import WandbLogger 13from matplotlib import gridspec 14from matplotlib.figure import Figure 15from PIL import Image 16from shimmer.modules.global_workspace import GlobalWorkspaceBase, GWPredictionsBase 17from simple_shapes_dataset import ( 18 UnnormalizeAttributes, 19 tensor_to_attribute, 20) 21from simple_shapes_dataset.cli import generate_image 22from tokenizers.implementations import ByteLevelBPETokenizer 23from torchvision.transforms.functional import to_tensor 24from torchvision.utils import make_grid 25 26from shimmer_ssd import LOGGER 27from shimmer_ssd.modules.domains.text import GRUTextDomainModule, Text2Attr 28from shimmer_ssd.modules.domains.visual import VisualLatentDomainModule 29 30matplotlib.use("Agg") 31 32_T = TypeVar("_T") 33 34 35def log_image( 36 logger: Logger, 37 key: str, 38 image: torch.Tensor | Image.Image, 39 tensorboard_step: int | None = None, 40): 41 if isinstance(logger, WandbLogger): 42 logger.log_image(key, [image]) 43 elif isinstance(logger, TensorBoardLogger): 44 torch_image = to_tensor(image) if isinstance(image, Image.Image) else image 45 logger.experiment.add_image(key, torch_image, tensorboard_step) 46 else: 47 LOGGER.warning( 48 "[Sample Logger] Only logging to tensorboard or wandb is supported" 49 ) 50 return 51 52 53def log_text( 54 logger: Logger, 55 key: str, 56 columns: list[str], 57 data: list[list[str]], 58 tensorboard_step: int | None = None, 59): 60 if isinstance(logger, WandbLogger): 61 logger.log_text(key, columns, data) 62 elif isinstance(logger, TensorBoardLogger): 63 text = ", ".join(columns) + "\n" 64 text += "\n".join([", ".join(d) for d in 
data]) 65 logger.experiment.add_text(key, text, tensorboard_step) 66 else: 67 LOGGER.warning( 68 "[Sample Logger] Only logging to tensorboard or wandb is supported" 69 ) 70 return 71 72 73class LogSamplesCallback(Generic[_T], ABC, pl.Callback): 74 def __init__( 75 self, 76 reference_samples: _T, 77 log_key: str, 78 mode: Literal["train", "val", "test"], 79 every_n_epochs: int | None = 1, 80 ) -> None: 81 super().__init__() 82 self.reference_samples = reference_samples 83 self.every_n_epochs = every_n_epochs 84 self.log_key = log_key 85 self.mode = mode 86 self._global_step = 0 87 88 def get_step(self) -> int: 89 self._global_step += 1 90 return self._global_step - 1 91 92 def to(self, samples: _T, device: torch.device) -> _T: 93 raise NotImplementedError 94 95 def setup( 96 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 97 ) -> None: 98 if stage != "fit": 99 return 100 device = trainer.strategy.root_device 101 self.reference_samples = self.to(self.reference_samples, device) 102 for logger in trainer.loggers: 103 self.log_samples(logger, self.reference_samples, "reference") 104 105 def on_callback( 106 self, 107 loggers: Sequence[Logger], 108 pl_module: pl.LightningModule, 109 ) -> None: 110 if not len(loggers): 111 LOGGER.debug("[LOGGER] No logger found.") 112 return 113 114 samples = self.to(self.reference_samples, pl_module.device) 115 116 with torch.no_grad(): 117 pl_module.eval() 118 generated_samples = pl_module(samples) 119 pl_module.train() 120 121 for logger in loggers: 122 self.log_samples(logger, generated_samples, "prediction") 123 124 def on_train_epoch_end( 125 self, trainer: pl.Trainer, pl_module: pl.LightningModule 126 ) -> None: 127 if self.mode != "train": 128 return 129 130 if ( 131 self.every_n_epochs is None 132 or trainer.current_epoch % self.every_n_epochs != 0 133 ): 134 LOGGER.debug("[LOGGER] on_train_epoch_end") 135 return 136 137 LOGGER.debug("[LOGGER] on_train_epoch_end called") 138 return 
self.on_callback(trainer.loggers, pl_module) 139 140 def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 141 if self.mode == "test": 142 return 143 144 return self.on_callback(trainer.loggers, pl_module) 145 146 def on_validation_epoch_end( 147 self, trainer: pl.Trainer, pl_module: pl.LightningModule 148 ) -> None: 149 if self.mode != "val": 150 return 151 152 if ( 153 self.every_n_epochs is None 154 or trainer.current_epoch % self.every_n_epochs != 0 155 ): 156 return 157 158 return self.on_callback(trainer.loggers, pl_module) 159 160 def on_test_epoch_end( 161 self, trainer: pl.Trainer, pl_module: pl.LightningModule 162 ) -> None: 163 if self.mode != "test": 164 return 165 166 return self.on_callback(trainer.loggers, pl_module) 167 168 @abstractmethod 169 def log_samples(self, logger: Logger, samples: _T, mode: str) -> None: ... 170 171 172def get_pil_image(figure: Figure) -> Image.Image: 173 buf = io.BytesIO() 174 figure.savefig(buf) 175 buf.seek(0) 176 return Image.open(buf) 177 178 179def get_attribute_figure_grid( 180 categories: np.ndarray, 181 locations: np.ndarray, 182 sizes: np.ndarray, 183 rotations: np.ndarray, 184 colors: np.ndarray, 185 image_size: int, 186 ncols: int = 8, 187 padding: float = 2, 188) -> Image.Image: 189 reminder = 1 if categories.shape[0] % ncols else 0 190 nrows = categories.shape[0] // ncols + reminder 191 192 width = ncols * (image_size + padding) + padding 193 height = nrows * (image_size + padding) + padding 194 dpi = 1 195 196 figure = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi, facecolor="white") 197 gs = gridspec.GridSpec( 198 nrows, 199 ncols, 200 wspace=padding / image_size, 201 hspace=padding / image_size, 202 left=padding / width, 203 right=1 - padding / width, 204 bottom=padding / height, 205 top=1 - padding / height, 206 ) 207 for i in range(nrows): 208 for j in range(ncols): 209 k = i * ncols + j 210 if k >= categories.shape[0]: 211 break 212 ax = plt.subplot(gs[i, j]) 213 
generate_image( 214 ax, 215 categories[k], 216 locations[k], 217 sizes[k], 218 rotations[k], 219 colors[k], 220 image_size, 221 ) 222 ax.set_facecolor("black") 223 image = get_pil_image(figure) 224 plt.close(figure) 225 return image 226 227 228def attribute_image_grid( 229 samples: Sequence[torch.Tensor], 230 image_size: int, 231 ncols: int, 232) -> Image.Image: 233 unnormalizer = UnnormalizeAttributes(image_size=image_size) 234 attributes = unnormalizer(tensor_to_attribute(samples)) 235 236 categories = attributes.category.detach().cpu().numpy() 237 locations = torch.stack([attributes.x, attributes.y], dim=1).detach().cpu().numpy() 238 colors = ( 239 ( 240 torch.stack( 241 [ 242 attributes.color_r, 243 attributes.color_g, 244 attributes.color_b, 245 ], 246 dim=1, 247 ) 248 ) 249 .cpu() 250 .numpy() 251 ) 252 sizes = attributes.size.detach().cpu().numpy() 253 rotations = attributes.rotation.detach().cpu().numpy() 254 255 return get_attribute_figure_grid( 256 categories, 257 locations, 258 sizes, 259 rotations, 260 colors, 261 image_size, 262 ncols, 263 padding=2, 264 ) 265 266 267class LogAttributesCallback(LogSamplesCallback[Sequence[torch.Tensor]]): 268 def __init__( 269 self, 270 reference_samples: Sequence[torch.Tensor], 271 log_key: str, 272 mode: Literal["train", "val", "test"], 273 image_size: int, 274 every_n_epochs: int | None = 1, 275 ncols: int = 8, 276 ) -> None: 277 super().__init__(reference_samples, log_key, mode, every_n_epochs) 278 self.image_size = image_size 279 self.ncols = ncols 280 281 def to( 282 self, samples: Sequence[torch.Tensor], device: torch.device 283 ) -> list[torch.Tensor]: 284 return [x.to(device) for x in samples] 285 286 def log_samples( 287 self, logger: Logger, samples: Sequence[torch.Tensor], mode: str 288 ) -> None: 289 image = attribute_image_grid( 290 samples, image_size=self.image_size, ncols=self.ncols 291 ) 292 log_image(logger, f"{self.log_key}_{mode}", image, self.get_step()) 293 294 295class 
LogTextCallback(LogSamplesCallback[Mapping[str, torch.Tensor]]): 296 def __init__( 297 self, 298 reference_samples: Mapping[str, torch.Tensor], 299 log_key: str, 300 mode: Literal["train", "val", "test"], 301 image_size: int, 302 vocab: str, 303 merges: str, 304 every_n_epochs: int | None = 1, 305 ncols: int = 8, 306 ) -> None: 307 super().__init__(reference_samples, log_key, mode, every_n_epochs) 308 self.image_size = image_size 309 self.ncols = ncols 310 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 311 312 def to( 313 self, samples: Mapping[str, torch.Tensor], device: torch.device 314 ) -> dict[str, torch.Tensor]: 315 return {x: samples[x].to(device) for x in samples} 316 317 def setup( 318 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 319 ) -> None: 320 if stage != "fit": 321 return 322 assert isinstance(pl_module, GRUTextDomainModule) 323 device = trainer.strategy.root_device 324 self.reference_samples = self.to(self.reference_samples, device) 325 for logger in trainer.loggers: 326 self.log_samples(logger, self.reference_samples, "reference") 327 328 def on_callback( 329 self, 330 loggers: Sequence[Logger], 331 pl_module: pl.LightningModule, 332 ) -> None: 333 assert isinstance(pl_module, GRUTextDomainModule) 334 335 samples = self.to(self.reference_samples, pl_module.device) 336 337 if not len(loggers): 338 LOGGER.debug("[LOGGER] No logger found.") 339 return 340 341 with torch.no_grad(): 342 pl_module.eval() 343 generated_samples = pl_module(samples) 344 pl_module.train() 345 346 for logger in loggers: 347 self.log_samples(logger, generated_samples, "prediction") 348 349 def log_samples( 350 self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str 351 ) -> None: 352 if not isinstance(logger, WandbLogger): 353 LOGGER.warning("Only logging to wandb is supported") 354 return 355 356 assert self.tokenizer is not None 357 text = self.tokenizer.decode_batch( 358 samples["tokens"].detach().cpu().tolist(), 
skip_special_tokens=True 359 ) 360 text = [[t.replace("<pad>", "")] for t in text] 361 log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step()) 362 363 364class LogVisualCallback(LogSamplesCallback[torch.Tensor]): 365 def __init__( 366 self, 367 reference_samples: torch.Tensor, 368 log_key: str, 369 mode: Literal["train", "val", "test"], 370 every_n_epochs: int | None = 1, 371 ncols: int = 8, 372 ) -> None: 373 super().__init__(reference_samples, log_key, mode, every_n_epochs) 374 self.ncols = ncols 375 376 def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor: 377 return samples.to(device) 378 379 def log_samples(self, logger: Logger, samples: torch.Tensor, mode: str) -> None: 380 images = make_grid(samples, nrow=self.ncols, pad_value=1) 381 log_image(logger, f"{self.log_key}_{mode}", images) 382 383 384class LogText2AttrCallback( 385 LogSamplesCallback[ 386 Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]] 387 ] 388): 389 def __init__( 390 self, 391 reference_samples: Mapping[ 392 str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor] 393 ], 394 log_key: str, 395 mode: Literal["train", "val", "test"], 396 every_n_epochs: int | None = 1, 397 image_size: int = 32, 398 ncols: int = 8, 399 vocab: str | None = None, 400 merges: str | None = None, 401 ) -> None: 402 super().__init__(reference_samples, log_key, mode, every_n_epochs) 403 self.image_size = image_size 404 self.ncols = ncols 405 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 406 self.reference_samples = reference_samples 407 408 def to( 409 self, 410 samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], 411 device: torch.device, 412 ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]: 413 latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {} 414 for domain_name, domain in samples.items(): 415 if isinstance(domain, dict): 416 latents[domain_name] = {k: x.to(device) for k, x in domain.items()} 
417 elif isinstance(domain, list): 418 latents[domain_name] = [x.to(device) for x in domain] 419 return latents 420 421 def setup( 422 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 423 ) -> None: 424 if stage != "fit": 425 return 426 assert isinstance(pl_module, Text2Attr) 427 device = trainer.strategy.root_device 428 self.reference_samples = self.to(self.reference_samples, device) 429 for logger in trainer.loggers: 430 self.log_samples(logger, self.reference_samples, "reference") 431 432 def on_callback( 433 self, 434 loggers: Sequence[Logger], 435 pl_module: pl.LightningModule, 436 ) -> None: 437 assert isinstance(pl_module, Text2Attr) 438 439 samples = self.to(self.reference_samples, pl_module.device) 440 441 if not len(loggers): 442 LOGGER.debug("[LOGGER] No logger found.") 443 return 444 445 with torch.no_grad(): 446 pl_module.eval() 447 generated_samples = pl_module(samples["t"]) 448 pl_module.train() 449 450 for logger in loggers: 451 self.log_samples(logger, generated_samples, "prediction") 452 453 def log_samples( 454 self, 455 logger: Logger, 456 samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], 457 mode: str, 458 ) -> None: 459 for domain_name, domain in samples.items(): 460 if domain_name == "t": 461 assert self.tokenizer is not None 462 assert isinstance(domain, dict) 463 text = self.tokenizer.decode_batch( 464 domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True 465 ) 466 text = [[t.replace("<pad>", "")] for t in text] 467 log_text( 468 logger, 469 f"{self.log_key}_{mode}_str", 470 ["text"], 471 text, 472 self.get_step(), 473 ) 474 elif domain_name == "attr": 475 assert isinstance(domain, list) 476 image = attribute_image_grid( 477 domain, 478 image_size=self.image_size, 479 ncols=self.ncols, 480 ) 481 log_image(logger, f"{self.log_key}_{mode}", image, self.get_step()) 482 483 484def batch_to_device( 485 samples: Mapping[ 486 frozenset[str], 487 Mapping[str, Any], 488 ], 489 device: 
torch.device, 490) -> dict[frozenset[str], dict[str, Any]]: 491 out: dict[frozenset[str], dict[str, Any]] = {} 492 for domain_names, domains in samples.items(): 493 latents: dict[str, Any] = {} 494 for domain_name, domain in domains.items(): 495 if isinstance(domain, torch.Tensor): 496 latents[domain_name] = domain.to(device) 497 elif ( 498 isinstance(domain, Mapping) 499 and len(domain) 500 and isinstance(next(iter(domain.values())), torch.Tensor) 501 ): 502 latents[domain_name] = {k: x.to(device) for k, x in domain.items()} 503 elif ( 504 isinstance(domain, Sequence) 505 and len(domain) 506 and isinstance(domain[0], torch.Tensor) 507 ): 508 latents[domain_name] = [x.to(device) for x in domain] 509 else: 510 latents[domain_name] = domain 511 out[domain_names] = latents 512 return out 513 514 515class LogGWImagesCallback(pl.Callback): 516 def __init__( 517 self, 518 reference_samples: Mapping[frozenset[str], Mapping[str, Any]], 519 log_key: str, 520 mode: Literal["train", "val", "test"], 521 every_n_epochs: int | None = 1, 522 image_size: int = 32, 523 ncols: int = 8, 524 filter: Sequence[str] | None = None, 525 vocab: str | None = None, 526 merges: str | None = None, 527 ) -> None: 528 super().__init__() 529 self.mode = mode 530 self.reference_samples = reference_samples 531 self.every_n_epochs = every_n_epochs 532 self.log_key = log_key 533 self.image_size = image_size 534 self.ncols = ncols 535 self.filter = filter 536 self.tokenizer = None 537 if vocab is not None and merges is not None: 538 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 539 self._global_step = 0 540 541 def get_step(self): 542 self._global_step += 1 543 return self._global_step - 1 544 545 def to( 546 self, 547 samples: Mapping[ 548 frozenset[str], 549 Mapping[ 550 str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor] 551 ], 552 ], 553 device: torch.device, 554 ) -> dict[ 555 frozenset[str], 556 dict[str, torch.Tensor | list[torch.Tensor] | dict[Any, torch.Tensor]], 
557 ]: 558 return batch_to_device(samples, device) 559 560 def setup( 561 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 562 ) -> None: 563 if stage != "fit": 564 return 565 assert isinstance(pl_module, GlobalWorkspaceBase) 566 device = trainer.strategy.root_device 567 self.reference_samples = self.to(self.reference_samples, device) 568 569 for domain_names, domains in self.reference_samples.items(): 570 for domain_name, domain_tensor in domains.items(): 571 for logger in trainer.loggers: 572 self.log_samples( 573 logger, 574 pl_module, 575 domain_tensor, 576 domain_name, 577 f"ref_{'-'.join(domain_names)}_{domain_name}", 578 ) 579 580 def on_callback( 581 self, 582 loggers: Sequence[Logger], 583 pl_module: GlobalWorkspaceBase, 584 ) -> None: 585 if not (len(loggers)): 586 return 587 588 with torch.no_grad(): 589 latent_groups = pl_module.encode_domains(self.reference_samples) 590 predictions = cast(GWPredictionsBase, pl_module(latent_groups)) 591 592 for logger in loggers: 593 for domains, preds in predictions["broadcasts"].items(): 594 domain_from = ",".join(domains) 595 for domain, pred in preds.items(): 596 log_name = f"pred_trans_{domain_from}_to_{domain}" 597 if self.filter is not None and log_name not in self.filter: 598 continue 599 self.log_samples( 600 logger, 601 pl_module, 602 pl_module.decode_domain(pred, domain), 603 domain, 604 log_name, 605 ) 606 for domains, preds in predictions["cycles"].items(): 607 domain_from = ",".join(domains) 608 for domain, pred in preds.items(): 609 log_name = f"pred_cycle_{domain_from}_to_{domain}" 610 if self.filter is not None and log_name not in self.filter: 611 continue 612 self.log_samples( 613 logger, 614 pl_module, 615 pl_module.decode_domain(pred, domain), 616 domain, 617 log_name, 618 ) 619 620 def on_train_epoch_end( 621 self, 622 trainer: pl.Trainer, 623 pl_module: pl.LightningModule, 624 ) -> None: 625 if self.mode != "train": 626 return 627 628 if not isinstance(pl_module, 
GlobalWorkspaceBase): 629 return 630 631 if ( 632 self.every_n_epochs is None 633 or trainer.current_epoch % self.every_n_epochs != 0 634 ): 635 return 636 637 return self.on_callback(trainer.loggers, pl_module) 638 639 def on_validation_epoch_end( 640 self, 641 trainer: pl.Trainer, 642 pl_module: pl.LightningModule, 643 ) -> None: 644 if self.mode != "val": 645 return 646 647 if not isinstance(pl_module, GlobalWorkspaceBase): 648 return 649 650 if ( 651 self.every_n_epochs is None 652 or trainer.current_epoch % self.every_n_epochs != 0 653 ): 654 return 655 656 return self.on_callback(trainer.loggers, pl_module) 657 658 def on_test_epoch_end( 659 self, 660 trainer: pl.Trainer, 661 pl_module: pl.LightningModule, 662 ) -> None: 663 if self.mode != "test": 664 return 665 666 if not isinstance(pl_module, GlobalWorkspaceBase): 667 return 668 669 return self.on_callback(trainer.loggers, pl_module) 670 671 def on_train_end( 672 self, 673 trainer: pl.Trainer, 674 pl_module: pl.LightningModule, 675 ) -> None: 676 if self.mode == "test": 677 return 678 679 if not isinstance(pl_module, GlobalWorkspaceBase): 680 return 681 682 return self.on_callback(trainer.loggers, pl_module) 683 684 def log_samples( 685 self, 686 logger: Logger, 687 pl_module: GlobalWorkspaceBase, 688 samples: Any, 689 domain: str, 690 mode: str, 691 ) -> None: 692 match domain: 693 case "v": 694 self.log_visual_samples(logger, samples, mode) 695 case "v_latents": 696 assert "v_latents" in pl_module.domain_mods 697 698 module = cast( 699 VisualLatentDomainModule, 700 pl_module.domain_mods["v_latents"], 701 ) 702 self.log_visual_samples(logger, module.decode_images(samples), mode) 703 case "attr": 704 self.log_attribute_samples(logger, samples, mode) 705 case "t": 706 self.log_text_samples(logger, samples, mode) 707 if "attr" in samples: 708 self.log_attribute_samples(logger, samples["attr"], mode + "_attr") 709 710 def log_visual_samples( 711 self, 712 logger: Logger, 713 samples: Any, 714 mode: str, 715 ) 
-> None: 716 images = make_grid(samples, nrow=self.ncols, pad_value=1) 717 log_image(logger, f"{self.log_key}/{mode}", images, self.get_step()) 718 719 def log_attribute_samples( 720 self, 721 logger: Logger, 722 samples: Any, 723 mode: str, 724 ) -> None: 725 image = attribute_image_grid( 726 samples, 727 image_size=self.image_size, 728 ncols=self.ncols, 729 ) 730 log_image(logger, f"{self.log_key}/{mode}", image, self.get_step()) 731 732 def log_text_samples( 733 self, 734 logger: Logger, 735 samples: Any, 736 mode: str, 737 ) -> None: 738 assert self.tokenizer is not None 739 text = self.tokenizer.decode_batch( 740 samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True 741 ) 742 text = [[t.replace("<pad>", "")] for t in text] 743 log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())
36def log_image( 37 logger: Logger, 38 key: str, 39 image: torch.Tensor | Image.Image, 40 tensorboard_step: int | None = None, 41): 42 if isinstance(logger, WandbLogger): 43 logger.log_image(key, [image]) 44 elif isinstance(logger, TensorBoardLogger): 45 torch_image = to_tensor(image) if isinstance(image, Image.Image) else image 46 logger.experiment.add_image(key, torch_image, tensorboard_step) 47 else: 48 LOGGER.warning( 49 "[Sample Logger] Only logging to tensorboard or wandb is supported" 50 ) 51 return
54def log_text( 55 logger: Logger, 56 key: str, 57 columns: list[str], 58 data: list[list[str]], 59 tensorboard_step: int | None = None, 60): 61 if isinstance(logger, WandbLogger): 62 logger.log_text(key, columns, data) 63 elif isinstance(logger, TensorBoardLogger): 64 text = ", ".join(columns) + "\n" 65 text += "\n".join([", ".join(d) for d in data]) 66 logger.experiment.add_text(key, text, tensorboard_step) 67 else: 68 LOGGER.warning( 69 "[Sample Logger] Only logging to tensorboard or wandb is supported" 70 ) 71 return
74class LogSamplesCallback(Generic[_T], ABC, pl.Callback): 75 def __init__( 76 self, 77 reference_samples: _T, 78 log_key: str, 79 mode: Literal["train", "val", "test"], 80 every_n_epochs: int | None = 1, 81 ) -> None: 82 super().__init__() 83 self.reference_samples = reference_samples 84 self.every_n_epochs = every_n_epochs 85 self.log_key = log_key 86 self.mode = mode 87 self._global_step = 0 88 89 def get_step(self) -> int: 90 self._global_step += 1 91 return self._global_step - 1 92 93 def to(self, samples: _T, device: torch.device) -> _T: 94 raise NotImplementedError 95 96 def setup( 97 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 98 ) -> None: 99 if stage != "fit": 100 return 101 device = trainer.strategy.root_device 102 self.reference_samples = self.to(self.reference_samples, device) 103 for logger in trainer.loggers: 104 self.log_samples(logger, self.reference_samples, "reference") 105 106 def on_callback( 107 self, 108 loggers: Sequence[Logger], 109 pl_module: pl.LightningModule, 110 ) -> None: 111 if not len(loggers): 112 LOGGER.debug("[LOGGER] No logger found.") 113 return 114 115 samples = self.to(self.reference_samples, pl_module.device) 116 117 with torch.no_grad(): 118 pl_module.eval() 119 generated_samples = pl_module(samples) 120 pl_module.train() 121 122 for logger in loggers: 123 self.log_samples(logger, generated_samples, "prediction") 124 125 def on_train_epoch_end( 126 self, trainer: pl.Trainer, pl_module: pl.LightningModule 127 ) -> None: 128 if self.mode != "train": 129 return 130 131 if ( 132 self.every_n_epochs is None 133 or trainer.current_epoch % self.every_n_epochs != 0 134 ): 135 LOGGER.debug("[LOGGER] on_train_epoch_end") 136 return 137 138 LOGGER.debug("[LOGGER] on_train_epoch_end called") 139 return self.on_callback(trainer.loggers, pl_module) 140 141 def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 142 if self.mode == "test": 143 return 144 145 return 
self.on_callback(trainer.loggers, pl_module) 146 147 def on_validation_epoch_end( 148 self, trainer: pl.Trainer, pl_module: pl.LightningModule 149 ) -> None: 150 if self.mode != "val": 151 return 152 153 if ( 154 self.every_n_epochs is None 155 or trainer.current_epoch % self.every_n_epochs != 0 156 ): 157 return 158 159 return self.on_callback(trainer.loggers, pl_module) 160 161 def on_test_epoch_end( 162 self, trainer: pl.Trainer, pl_module: pl.LightningModule 163 ) -> None: 164 if self.mode != "test": 165 return 166 167 return self.on_callback(trainer.loggers, pl_module) 168 169 @abstractmethod 170 def log_samples(self, logger: Logger, samples: _T, mode: str) -> None: ...
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
106 def on_callback( 107 self, 108 loggers: Sequence[Logger], 109 pl_module: pl.LightningModule, 110 ) -> None: 111 if not len(loggers): 112 LOGGER.debug("[LOGGER] No logger found.") 113 return 114 115 samples = self.to(self.reference_samples, pl_module.device) 116 117 with torch.no_grad(): 118 pl_module.eval() 119 generated_samples = pl_module(samples) 120 pl_module.train() 121 122 for logger in loggers: 123 self.log_samples(logger, generated_samples, "prediction")
180def get_attribute_figure_grid( 181 categories: np.ndarray, 182 locations: np.ndarray, 183 sizes: np.ndarray, 184 rotations: np.ndarray, 185 colors: np.ndarray, 186 image_size: int, 187 ncols: int = 8, 188 padding: float = 2, 189) -> Image.Image: 190 reminder = 1 if categories.shape[0] % ncols else 0 191 nrows = categories.shape[0] // ncols + reminder 192 193 width = ncols * (image_size + padding) + padding 194 height = nrows * (image_size + padding) + padding 195 dpi = 1 196 197 figure = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi, facecolor="white") 198 gs = gridspec.GridSpec( 199 nrows, 200 ncols, 201 wspace=padding / image_size, 202 hspace=padding / image_size, 203 left=padding / width, 204 right=1 - padding / width, 205 bottom=padding / height, 206 top=1 - padding / height, 207 ) 208 for i in range(nrows): 209 for j in range(ncols): 210 k = i * ncols + j 211 if k >= categories.shape[0]: 212 break 213 ax = plt.subplot(gs[i, j]) 214 generate_image( 215 ax, 216 categories[k], 217 locations[k], 218 sizes[k], 219 rotations[k], 220 colors[k], 221 image_size, 222 ) 223 ax.set_facecolor("black") 224 image = get_pil_image(figure) 225 plt.close(figure) 226 return image
229def attribute_image_grid( 230 samples: Sequence[torch.Tensor], 231 image_size: int, 232 ncols: int, 233) -> Image.Image: 234 unnormalizer = UnnormalizeAttributes(image_size=image_size) 235 attributes = unnormalizer(tensor_to_attribute(samples)) 236 237 categories = attributes.category.detach().cpu().numpy() 238 locations = torch.stack([attributes.x, attributes.y], dim=1).detach().cpu().numpy() 239 colors = ( 240 ( 241 torch.stack( 242 [ 243 attributes.color_r, 244 attributes.color_g, 245 attributes.color_b, 246 ], 247 dim=1, 248 ) 249 ) 250 .cpu() 251 .numpy() 252 ) 253 sizes = attributes.size.detach().cpu().numpy() 254 rotations = attributes.rotation.detach().cpu().numpy() 255 256 return get_attribute_figure_grid( 257 categories, 258 locations, 259 sizes, 260 rotations, 261 colors, 262 image_size, 263 ncols, 264 padding=2, 265 )
268class LogAttributesCallback(LogSamplesCallback[Sequence[torch.Tensor]]): 269 def __init__( 270 self, 271 reference_samples: Sequence[torch.Tensor], 272 log_key: str, 273 mode: Literal["train", "val", "test"], 274 image_size: int, 275 every_n_epochs: int | None = 1, 276 ncols: int = 8, 277 ) -> None: 278 super().__init__(reference_samples, log_key, mode, every_n_epochs) 279 self.image_size = image_size 280 self.ncols = ncols 281 282 def to( 283 self, samples: Sequence[torch.Tensor], device: torch.device 284 ) -> list[torch.Tensor]: 285 return [x.to(device) for x in samples] 286 287 def log_samples( 288 self, logger: Logger, samples: Sequence[torch.Tensor], mode: str 289 ) -> None: 290 image = attribute_image_grid( 291 samples, image_size=self.image_size, ncols=self.ncols 292 ) 293 log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
269 def __init__( 270 self, 271 reference_samples: Sequence[torch.Tensor], 272 log_key: str, 273 mode: Literal["train", "val", "test"], 274 image_size: int, 275 every_n_epochs: int | None = 1, 276 ncols: int = 8, 277 ) -> None: 278 super().__init__(reference_samples, log_key, mode, every_n_epochs) 279 self.image_size = image_size 280 self.ncols = ncols
296class LogTextCallback(LogSamplesCallback[Mapping[str, torch.Tensor]]): 297 def __init__( 298 self, 299 reference_samples: Mapping[str, torch.Tensor], 300 log_key: str, 301 mode: Literal["train", "val", "test"], 302 image_size: int, 303 vocab: str, 304 merges: str, 305 every_n_epochs: int | None = 1, 306 ncols: int = 8, 307 ) -> None: 308 super().__init__(reference_samples, log_key, mode, every_n_epochs) 309 self.image_size = image_size 310 self.ncols = ncols 311 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 312 313 def to( 314 self, samples: Mapping[str, torch.Tensor], device: torch.device 315 ) -> dict[str, torch.Tensor]: 316 return {x: samples[x].to(device) for x in samples} 317 318 def setup( 319 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 320 ) -> None: 321 if stage != "fit": 322 return 323 assert isinstance(pl_module, GRUTextDomainModule) 324 device = trainer.strategy.root_device 325 self.reference_samples = self.to(self.reference_samples, device) 326 for logger in trainer.loggers: 327 self.log_samples(logger, self.reference_samples, "reference") 328 329 def on_callback( 330 self, 331 loggers: Sequence[Logger], 332 pl_module: pl.LightningModule, 333 ) -> None: 334 assert isinstance(pl_module, GRUTextDomainModule) 335 336 samples = self.to(self.reference_samples, pl_module.device) 337 338 if not len(loggers): 339 LOGGER.debug("[LOGGER] No logger found.") 340 return 341 342 with torch.no_grad(): 343 pl_module.eval() 344 generated_samples = pl_module(samples) 345 pl_module.train() 346 347 for logger in loggers: 348 self.log_samples(logger, generated_samples, "prediction") 349 350 def log_samples( 351 self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str 352 ) -> None: 353 if not isinstance(logger, WandbLogger): 354 LOGGER.warning("Only logging to wandb is supported") 355 return 356 357 assert self.tokenizer is not None 358 text = self.tokenizer.decode_batch( 359 samples["tokens"].detach().cpu().tolist(), 
skip_special_tokens=True 360 ) 361 text = [[t.replace("<pad>", "")] for t in text] 362 log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
297 def __init__( 298 self, 299 reference_samples: Mapping[str, torch.Tensor], 300 log_key: str, 301 mode: Literal["train", "val", "test"], 302 image_size: int, 303 vocab: str, 304 merges: str, 305 every_n_epochs: int | None = 1, 306 ncols: int = 8, 307 ) -> None: 308 super().__init__(reference_samples, log_key, mode, every_n_epochs) 309 self.image_size = image_size 310 self.ncols = ncols 311 self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
318 def setup( 319 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 320 ) -> None: 321 if stage != "fit": 322 return 323 assert isinstance(pl_module, GRUTextDomainModule) 324 device = trainer.strategy.root_device 325 self.reference_samples = self.to(self.reference_samples, device) 326 for logger in trainer.loggers: 327 self.log_samples(logger, self.reference_samples, "reference")
Called when fit, validate, test, predict, or tune begins.
329 def on_callback( 330 self, 331 loggers: Sequence[Logger], 332 pl_module: pl.LightningModule, 333 ) -> None: 334 assert isinstance(pl_module, GRUTextDomainModule) 335 336 samples = self.to(self.reference_samples, pl_module.device) 337 338 if not len(loggers): 339 LOGGER.debug("[LOGGER] No logger found.") 340 return 341 342 with torch.no_grad(): 343 pl_module.eval() 344 generated_samples = pl_module(samples) 345 pl_module.train() 346 347 for logger in loggers: 348 self.log_samples(logger, generated_samples, "prediction")
350 def log_samples( 351 self, logger: Logger, samples: Mapping[str, torch.Tensor], mode: str 352 ) -> None: 353 if not isinstance(logger, WandbLogger): 354 LOGGER.warning("Only logging to wandb is supported") 355 return 356 357 assert self.tokenizer is not None 358 text = self.tokenizer.decode_batch( 359 samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True 360 ) 361 text = [[t.replace("<pad>", "")] for t in text] 362 log_text(logger, f"{self.log_key}_{mode}_str", ["text"], text, self.get_step())
365class LogVisualCallback(LogSamplesCallback[torch.Tensor]): 366 def __init__( 367 self, 368 reference_samples: torch.Tensor, 369 log_key: str, 370 mode: Literal["train", "val", "test"], 371 every_n_epochs: int | None = 1, 372 ncols: int = 8, 373 ) -> None: 374 super().__init__(reference_samples, log_key, mode, every_n_epochs) 375 self.ncols = ncols 376 377 def to(self, samples: torch.Tensor, device: torch.device) -> torch.Tensor: 378 return samples.to(device) 379 380 def log_samples(self, logger: Logger, samples: torch.Tensor, mode: str) -> None: 381 images = make_grid(samples, nrow=self.ncols, pad_value=1) 382 log_image(logger, f"{self.log_key}_{mode}", images)
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
class LogText2AttrCallback(
    LogSamplesCallback[
        Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]]
    ]
):
    """Log text inputs and attribute predictions of a ``Text2Attr`` module.

    The reference batch maps domain names (``"t"`` for text, ``"attr"`` for
    attributes) to either a dict of tensors or a list of tensors. Text is
    decoded with a BPE tokenizer and logged as a table; attributes are
    rendered into an image grid.
    """

    def __init__(
        self,
        reference_samples: Mapping[
            str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]
        ],
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
        image_size: int = 32,
        ncols: int = 8,
        vocab: str | None = None,
        merges: str | None = None,
    ) -> None:
        """Store display options and optionally build the BPE tokenizer.

        Args:
            reference_samples: fixed batch of per-domain samples to track.
            log_key: prefix used to name logged artifacts.
            mode: which stage's epoch-end hook triggers logging.
            every_n_epochs: logging cadence, forwarded to the base class.
            image_size: side length of each rendered attribute image.
            ncols: number of images per row in the attribute grid.
            vocab: path to the BPE vocabulary file, or None.
            merges: path to the BPE merges file, or None.
        """
        super().__init__(reference_samples, log_key, mode, every_n_epochs)
        self.image_size = image_size
        self.ncols = ncols
        # Fix: vocab/merges default to None but the tokenizer was previously
        # built unconditionally, so constructing the callback without
        # vocabulary files failed. Guard like LogGWImagesCallback does;
        # log_samples still asserts a tokenizer exists before decoding text.
        self.tokenizer = None
        if vocab is not None and merges is not None:
            self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
        # (Removed the redundant `self.reference_samples = reference_samples`:
        # super().__init__ already stores it.)

    def to(
        self,
        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
        device: torch.device,
    ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]:
        """Move each domain's tensors onto ``device``, preserving structure."""
        latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {}
        for domain_name, domain in samples.items():
            if isinstance(domain, dict):
                latents[domain_name] = {k: x.to(device) for k, x in domain.items()}
            elif isinstance(domain, list):
                latents[domain_name] = [x.to(device) for x in domain]
        return latents

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """Log the reference samples once, when fitting starts."""
        if stage != "fit":
            return
        assert isinstance(pl_module, Text2Attr)
        device = trainer.strategy.root_device
        # Cache the device-resident copy so later calls reuse it.
        self.reference_samples = self.to(self.reference_samples, device)
        for logger in trainer.loggers:
            self.log_samples(logger, self.reference_samples, "reference")

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: pl.LightningModule,
    ) -> None:
        """Predict attributes from the reference text and log the result."""
        assert isinstance(pl_module, Text2Attr)

        samples = self.to(self.reference_samples, pl_module.device)

        if not len(loggers):
            LOGGER.debug("[LOGGER] No logger found.")
            return

        with torch.no_grad():
            # NOTE(review): the module is put back into train mode
            # unconditionally — assumes it was training before; confirm.
            pl_module.eval()
            generated_samples = pl_module(samples["t"])
            pl_module.train()

        for logger in loggers:
            self.log_samples(logger, generated_samples, "prediction")

    def log_samples(
        self,
        logger: Logger,
        samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]],
        mode: str,
    ) -> None:
        """Log the text domain as a table and the attr domain as an image grid."""
        for domain_name, domain in samples.items():
            if domain_name == "t":
                assert self.tokenizer is not None
                assert isinstance(domain, dict)
                text = self.tokenizer.decode_batch(
                    domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True
                )
                # Strip literal "<pad>" markers; one-column rows for log_text.
                text = [[t.replace("<pad>", "")] for t in text]
                log_text(
                    logger,
                    f"{self.log_key}_{mode}_str",
                    ["text"],
                    text,
                    self.get_step(),
                )
            elif domain_name == "attr":
                assert isinstance(domain, list)
                image = attribute_image_grid(
                    domain,
                    image_size=self.image_size,
                    ncols=self.ncols,
                )
                log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::
class Mapping(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: ... # Etc.
This class can then be used as follows::
def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default
390 def __init__( 391 self, 392 reference_samples: Mapping[ 393 str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor] 394 ], 395 log_key: str, 396 mode: Literal["train", "val", "test"], 397 every_n_epochs: int | None = 1, 398 image_size: int = 32, 399 ncols: int = 8, 400 vocab: str | None = None, 401 merges: str | None = None, 402 ) -> None: 403 super().__init__(reference_samples, log_key, mode, every_n_epochs) 404 self.image_size = image_size 405 self.ncols = ncols 406 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 407 self.reference_samples = reference_samples
409 def to( 410 self, 411 samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], 412 device: torch.device, 413 ) -> dict[str, dict[str, torch.Tensor] | list[torch.Tensor]]: 414 latents: dict[str, dict[str, torch.Tensor] | list[torch.Tensor]] = {} 415 for domain_name, domain in samples.items(): 416 if isinstance(domain, dict): 417 latents[domain_name] = {k: x.to(device) for k, x in domain.items()} 418 elif isinstance(domain, list): 419 latents[domain_name] = [x.to(device) for x in domain] 420 return latents
422 def setup( 423 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 424 ) -> None: 425 if stage != "fit": 426 return 427 assert isinstance(pl_module, Text2Attr) 428 device = trainer.strategy.root_device 429 self.reference_samples = self.to(self.reference_samples, device) 430 for logger in trainer.loggers: 431 self.log_samples(logger, self.reference_samples, "reference")
Called when fit, validate, test, predict, or tune begins.
433 def on_callback( 434 self, 435 loggers: Sequence[Logger], 436 pl_module: pl.LightningModule, 437 ) -> None: 438 assert isinstance(pl_module, Text2Attr) 439 440 samples = self.to(self.reference_samples, pl_module.device) 441 442 if not len(loggers): 443 LOGGER.debug("[LOGGER] No logger found.") 444 return 445 446 with torch.no_grad(): 447 pl_module.eval() 448 generated_samples = pl_module(samples["t"]) 449 pl_module.train() 450 451 for logger in loggers: 452 self.log_samples(logger, generated_samples, "prediction")
454 def log_samples( 455 self, 456 logger: Logger, 457 samples: Mapping[str, Mapping[str, torch.Tensor] | Sequence[torch.Tensor]], 458 mode: str, 459 ) -> None: 460 for domain_name, domain in samples.items(): 461 if domain_name == "t": 462 assert self.tokenizer is not None 463 assert isinstance(domain, dict) 464 text = self.tokenizer.decode_batch( 465 domain["tokens"].detach().cpu().tolist(), skip_special_tokens=True 466 ) 467 text = [[t.replace("<pad>", "")] for t in text] 468 log_text( 469 logger, 470 f"{self.log_key}_{mode}_str", 471 ["text"], 472 text, 473 self.get_step(), 474 ) 475 elif domain_name == "attr": 476 assert isinstance(domain, list) 477 image = attribute_image_grid( 478 domain, 479 image_size=self.image_size, 480 ncols=self.ncols, 481 ) 482 log_image(logger, f"{self.log_key}_{mode}", image, self.get_step())
485def batch_to_device( 486 samples: Mapping[ 487 frozenset[str], 488 Mapping[str, Any], 489 ], 490 device: torch.device, 491) -> dict[frozenset[str], dict[str, Any]]: 492 out: dict[frozenset[str], dict[str, Any]] = {} 493 for domain_names, domains in samples.items(): 494 latents: dict[str, Any] = {} 495 for domain_name, domain in domains.items(): 496 if isinstance(domain, torch.Tensor): 497 latents[domain_name] = domain.to(device) 498 elif ( 499 isinstance(domain, Mapping) 500 and len(domain) 501 and isinstance(next(iter(domain.values())), torch.Tensor) 502 ): 503 latents[domain_name] = {k: x.to(device) for k, x in domain.items()} 504 elif ( 505 isinstance(domain, Sequence) 506 and len(domain) 507 and isinstance(domain[0], torch.Tensor) 508 ): 509 latents[domain_name] = [x.to(device) for x in domain] 510 else: 511 latents[domain_name] = domain 512 out[domain_names] = latents 513 return out
class LogGWImagesCallback(pl.Callback):
    """Log per-domain decodings of a global workspace's predictions.

    At ``setup`` (fit stage) the raw reference samples are logged once per
    domain. On each configured epoch end, the workspace encodes the
    reference batch, and every broadcast ("translation") and cycle
    prediction is decoded back to its domain and logged (image grid,
    attribute grid, or decoded text, depending on the domain).
    """

    def __init__(
        self,
        reference_samples: Mapping[frozenset[str], Mapping[str, Any]],
        log_key: str,
        mode: Literal["train", "val", "test"],
        every_n_epochs: int | None = 1,
        image_size: int = 32,
        ncols: int = 8,
        filter: Sequence[str] | None = None,
        vocab: str | None = None,
        merges: str | None = None,
    ) -> None:
        """Store logging options.

        Args:
            reference_samples: fixed batch, keyed by domain-group frozensets.
            log_key: prefix for every logged artifact name.
            mode: which stage's epoch-end hook triggers logging.
            every_n_epochs: cadence; None disables periodic logging.
            image_size: side length of rendered attribute images.
            ncols: images per row in grids.
            filter: if given, only prediction names in this sequence are
                logged. (Shadows the builtin `filter`; kept for API
                compatibility.)
            vocab: path to BPE vocabulary file, or None to skip text logging.
            merges: path to BPE merges file, or None to skip text logging.
        """
        super().__init__()
        self.mode = mode
        self.reference_samples = reference_samples
        self.every_n_epochs = every_n_epochs
        self.log_key = log_key
        self.image_size = image_size
        self.ncols = ncols
        self.filter = filter
        # Tokenizer is optional: only built when both vocab files are given.
        self.tokenizer = None
        if vocab is not None and merges is not None:
            self.tokenizer = ByteLevelBPETokenizer(vocab, merges)
        # Manual step counter shared by all log_* methods (see get_step).
        self._global_step = 0

    def get_step(self):
        """Return the current step counter, then advance it by one."""
        self._global_step += 1
        return self._global_step - 1

    def to(
        self,
        samples: Mapping[
            frozenset[str],
            Mapping[
                str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor]
            ],
        ],
        device: torch.device,
    ) -> dict[
        frozenset[str],
        dict[str, torch.Tensor | list[torch.Tensor] | dict[Any, torch.Tensor]],
    ]:
        """Move the whole reference batch onto ``device``."""
        return batch_to_device(samples, device)

    def setup(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str
    ) -> None:
        """When fitting starts, log each raw reference domain once."""
        if stage != "fit":
            return
        assert isinstance(pl_module, GlobalWorkspaceBase)
        device = trainer.strategy.root_device
        # Cache the device-resident copy so on_callback reuses it.
        self.reference_samples = self.to(self.reference_samples, device)

        for domain_names, domains in self.reference_samples.items():
            for domain_name, domain_tensor in domains.items():
                for logger in trainer.loggers:
                    self.log_samples(
                        logger,
                        pl_module,
                        domain_tensor,
                        domain_name,
                        f"ref_{'-'.join(domain_names)}_{domain_name}",
                    )

    def on_callback(
        self,
        loggers: Sequence[Logger],
        pl_module: GlobalWorkspaceBase,
    ) -> None:
        """Encode the reference batch and log every decoded prediction."""
        if not (len(loggers)):
            return

        with torch.no_grad():
            latent_groups = pl_module.encode_domains(self.reference_samples)
            predictions = cast(GWPredictionsBase, pl_module(latent_groups))

        for logger in loggers:
            # Translations: broadcast from a source domain group to a target.
            for domains, preds in predictions["broadcasts"].items():
                domain_from = ",".join(domains)
                for domain, pred in preds.items():
                    log_name = f"pred_trans_{domain_from}_to_{domain}"
                    # Optional allow-list of prediction names.
                    if self.filter is not None and log_name not in self.filter:
                        continue
                    self.log_samples(
                        logger,
                        pl_module,
                        pl_module.decode_domain(pred, domain),
                        domain,
                        log_name,
                    )
            # Cycle consistency: encode -> decode back to the same domain.
            for domains, preds in predictions["cycles"].items():
                domain_from = ",".join(domains)
                for domain, pred in preds.items():
                    log_name = f"pred_cycle_{domain_from}_to_{domain}"
                    if self.filter is not None and log_name not in self.filter:
                        continue
                    self.log_samples(
                        logger,
                        pl_module,
                        pl_module.decode_domain(pred, domain),
                        domain,
                        log_name,
                    )

    def on_train_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        """Log at the end of a training epoch (mode == "train" only)."""
        if self.mode != "train":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_validation_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        """Log at the end of a validation epoch (mode == "val" only)."""
        if self.mode != "val":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        if (
            self.every_n_epochs is None
            or trainer.current_epoch % self.every_n_epochs != 0
        ):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_test_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        """Log at the end of a test epoch (mode == "test" only); no cadence check."""
        if self.mode != "test":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def on_train_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
    ) -> None:
        """Log one final time when training finishes (train/val modes)."""
        if self.mode == "test":
            return

        if not isinstance(pl_module, GlobalWorkspaceBase):
            return

        return self.on_callback(trainer.loggers, pl_module)

    def log_samples(
        self,
        logger: Logger,
        pl_module: GlobalWorkspaceBase,
        samples: Any,
        domain: str,
        mode: str,
    ) -> None:
        """Dispatch ``samples`` to the logger appropriate for ``domain``."""
        match domain:
            case "v":
                self.log_visual_samples(logger, samples, mode)
            case "v_latents":
                assert "v_latents" in pl_module.domain_mods

                # Latent visual samples must be decoded to pixels first.
                module = cast(
                    VisualLatentDomainModule,
                    pl_module.domain_mods["v_latents"],
                )
                self.log_visual_samples(logger, module.decode_images(samples), mode)
            case "attr":
                self.log_attribute_samples(logger, samples, mode)
            case "t":
                self.log_text_samples(logger, samples, mode)
                # Text predictions may carry attribute predictions alongside.
                if "attr" in samples:
                    self.log_attribute_samples(logger, samples["attr"], mode + "_attr")

    def log_visual_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        """Tile an image batch into one grid and log it."""
        images = make_grid(samples, nrow=self.ncols, pad_value=1)
        log_image(logger, f"{self.log_key}/{mode}", images, self.get_step())

    def log_attribute_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        """Render attribute vectors into an image grid and log it."""
        image = attribute_image_grid(
            samples,
            image_size=self.image_size,
            ncols=self.ncols,
        )
        log_image(logger, f"{self.log_key}/{mode}", image, self.get_step())

    def log_text_samples(
        self,
        logger: Logger,
        samples: Any,
        mode: str,
    ) -> None:
        """Decode ``samples["tokens"]`` with the BPE tokenizer and log as text."""
        assert self.tokenizer is not None
        text = self.tokenizer.decode_batch(
            samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True
        )
        # Strip literal "<pad>" markers; one-column rows for log_text.
        text = [[t.replace("<pad>", "")] for t in text]
        log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())
Abstract base class used to build new callbacks.
Subclass this class and override any of the relevant hooks
517 def __init__( 518 self, 519 reference_samples: Mapping[frozenset[str], Mapping[str, Any]], 520 log_key: str, 521 mode: Literal["train", "val", "test"], 522 every_n_epochs: int | None = 1, 523 image_size: int = 32, 524 ncols: int = 8, 525 filter: Sequence[str] | None = None, 526 vocab: str | None = None, 527 merges: str | None = None, 528 ) -> None: 529 super().__init__() 530 self.mode = mode 531 self.reference_samples = reference_samples 532 self.every_n_epochs = every_n_epochs 533 self.log_key = log_key 534 self.image_size = image_size 535 self.ncols = ncols 536 self.filter = filter 537 self.tokenizer = None 538 if vocab is not None and merges is not None: 539 self.tokenizer = ByteLevelBPETokenizer(vocab, merges) 540 self._global_step = 0
546 def to( 547 self, 548 samples: Mapping[ 549 frozenset[str], 550 Mapping[ 551 str, torch.Tensor | Sequence[torch.Tensor] | Mapping[str, torch.Tensor] 552 ], 553 ], 554 device: torch.device, 555 ) -> dict[ 556 frozenset[str], 557 dict[str, torch.Tensor | list[torch.Tensor] | dict[Any, torch.Tensor]], 558 ]: 559 return batch_to_device(samples, device)
561 def setup( 562 self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str 563 ) -> None: 564 if stage != "fit": 565 return 566 assert isinstance(pl_module, GlobalWorkspaceBase) 567 device = trainer.strategy.root_device 568 self.reference_samples = self.to(self.reference_samples, device) 569 570 for domain_names, domains in self.reference_samples.items(): 571 for domain_name, domain_tensor in domains.items(): 572 for logger in trainer.loggers: 573 self.log_samples( 574 logger, 575 pl_module, 576 domain_tensor, 577 domain_name, 578 f"ref_{'-'.join(domain_names)}_{domain_name}", 579 )
Called when fit, validate, test, predict, or tune begins.
581 def on_callback( 582 self, 583 loggers: Sequence[Logger], 584 pl_module: GlobalWorkspaceBase, 585 ) -> None: 586 if not (len(loggers)): 587 return 588 589 with torch.no_grad(): 590 latent_groups = pl_module.encode_domains(self.reference_samples) 591 predictions = cast(GWPredictionsBase, pl_module(latent_groups)) 592 593 for logger in loggers: 594 for domains, preds in predictions["broadcasts"].items(): 595 domain_from = ",".join(domains) 596 for domain, pred in preds.items(): 597 log_name = f"pred_trans_{domain_from}_to_{domain}" 598 if self.filter is not None and log_name not in self.filter: 599 continue 600 self.log_samples( 601 logger, 602 pl_module, 603 pl_module.decode_domain(pred, domain), 604 domain, 605 log_name, 606 ) 607 for domains, preds in predictions["cycles"].items(): 608 domain_from = ",".join(domains) 609 for domain, pred in preds.items(): 610 log_name = f"pred_cycle_{domain_from}_to_{domain}" 611 if self.filter is not None and log_name not in self.filter: 612 continue 613 self.log_samples( 614 logger, 615 pl_module, 616 pl_module.decode_domain(pred, domain), 617 domain, 618 log_name, 619 )
621 def on_train_epoch_end( 622 self, 623 trainer: pl.Trainer, 624 pl_module: pl.LightningModule, 625 ) -> None: 626 if self.mode != "train": 627 return 628 629 if not isinstance(pl_module, GlobalWorkspaceBase): 630 return 631 632 if ( 633 self.every_n_epochs is None 634 or trainer.current_epoch % self.every_n_epochs != 0 635 ): 636 return 637 638 return self.on_callback(trainer.loggers, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
lightning.pytorch.core.LightningModule
and access them in this hook:
class MyLightningModule(L.LightningModule):
def __init__(self):
super().__init__()
self.training_step_outputs = []
def training_step(self):
loss = ...
self.training_step_outputs.append(loss)
return loss
class MyCallback(L.Callback):
def on_train_epoch_end(self, trainer, pl_module):
# do something with all training_step outputs, for example:
epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
pl_module.log("training_epoch_mean", epoch_mean)
# free up the memory
pl_module.training_step_outputs.clear()
640 def on_validation_epoch_end( 641 self, 642 trainer: pl.Trainer, 643 pl_module: pl.LightningModule, 644 ) -> None: 645 if self.mode != "val": 646 return 647 648 if not isinstance(pl_module, GlobalWorkspaceBase): 649 return 650 651 if ( 652 self.every_n_epochs is None 653 or trainer.current_epoch % self.every_n_epochs != 0 654 ): 655 return 656 657 return self.on_callback(trainer.loggers, pl_module)
Called when the val epoch ends.
659 def on_test_epoch_end( 660 self, 661 trainer: pl.Trainer, 662 pl_module: pl.LightningModule, 663 ) -> None: 664 if self.mode != "test": 665 return 666 667 if not isinstance(pl_module, GlobalWorkspaceBase): 668 return 669 670 return self.on_callback(trainer.loggers, pl_module)
Called when the test epoch ends.
672 def on_train_end( 673 self, 674 trainer: pl.Trainer, 675 pl_module: pl.LightningModule, 676 ) -> None: 677 if self.mode == "test": 678 return 679 680 if not isinstance(pl_module, GlobalWorkspaceBase): 681 return 682 683 return self.on_callback(trainer.loggers, pl_module)
Called when the train ends.
685 def log_samples( 686 self, 687 logger: Logger, 688 pl_module: GlobalWorkspaceBase, 689 samples: Any, 690 domain: str, 691 mode: str, 692 ) -> None: 693 match domain: 694 case "v": 695 self.log_visual_samples(logger, samples, mode) 696 case "v_latents": 697 assert "v_latents" in pl_module.domain_mods 698 699 module = cast( 700 VisualLatentDomainModule, 701 pl_module.domain_mods["v_latents"], 702 ) 703 self.log_visual_samples(logger, module.decode_images(samples), mode) 704 case "attr": 705 self.log_attribute_samples(logger, samples, mode) 706 case "t": 707 self.log_text_samples(logger, samples, mode) 708 if "attr" in samples: 709 self.log_attribute_samples(logger, samples["attr"], mode + "_attr")
733 def log_text_samples( 734 self, 735 logger: Logger, 736 samples: Any, 737 mode: str, 738 ) -> None: 739 assert self.tokenizer is not None 740 text = self.tokenizer.decode_batch( 741 samples["tokens"].detach().cpu().tolist(), skip_special_tokens=True 742 ) 743 text = [[t.replace("<pad>", "")] for t in text] 744 log_text(logger, f"{self.log_key}/{mode}", ["text"], text, self.get_step())