shimmer_ssd.modules.domains.attribute

  1from collections.abc import Mapping, Sequence
  2from typing import Any
  3
  4import torch
  5import torch.nn.functional as F
  6from shimmer import LossOutput
  7from shimmer.modules.domain import DomainModule
  8from shimmer.modules.global_workspace import SchedulerArgs
  9from shimmer.modules.vae import (
 10    VAE,
 11    VAEDecoder,
 12    VAEEncoder,
 13    gaussian_nll,
 14    kl_divergence_loss,
 15)
 16from torch import nn
 17from torch.optim.lr_scheduler import OneCycleLR
 18
 19
 20class Encoder(VAEEncoder):
 21    def __init__(
 22        self,
 23        hidden_dim: int,
 24        out_dim: int,
 25    ):
 26        super().__init__()
 27
 28        self.hidden_dim = hidden_dim
 29        self.out_dim = out_dim
 30
 31        self.encoder = nn.Sequential(
 32            nn.Linear(11, hidden_dim),
 33            nn.ReLU(),
 34            nn.Linear(hidden_dim, hidden_dim),
 35            nn.ReLU(),
 36            nn.Linear(hidden_dim, out_dim),
 37            nn.ReLU(),
 38        )
 39
 40        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
 41        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
 42
 43    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
 44        out = torch.cat(list(x), dim=-1)
 45        out = self.encoder(out)
 46        return self.q_mean(out), self.q_logvar(out)
 47
 48
 49class Decoder(VAEDecoder):
 50    def __init__(
 51        self,
 52        in_dim: int,
 53        hidden_dim: int,
 54    ):
 55        super().__init__()
 56
 57        self.in_dim = in_dim
 58        self.hidden_dim = hidden_dim
 59
 60        self.decoder = nn.Sequential(
 61            nn.Linear(self.in_dim, self.hidden_dim),
 62            nn.ReLU(),
 63            nn.Linear(self.hidden_dim, self.hidden_dim),
 64            nn.ReLU(),
 65        )
 66
 67        self.decoder_categories = nn.Sequential(
 68            nn.Linear(self.hidden_dim, 3),
 69        )
 70
 71        self.decoder_attributes = nn.Sequential(
 72            nn.Linear(self.hidden_dim, 8),
 73            nn.Tanh(),
 74        )
 75
 76    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
 77        out = self.decoder(x)
 78        return [self.decoder_categories(out), self.decoder_attributes(out)]
 79
 80
 81class AttributeDomainModule(DomainModule):
 82    in_dim = 11
 83
 84    def __init__(
 85        self,
 86        latent_dim: int,
 87        hidden_dim: int,
 88        beta: float = 1,
 89        coef_categories: float = 1,
 90        coef_attributes: float = 1,
 91        optim_lr: float = 1e-3,
 92        optim_weight_decay: float = 0,
 93        scheduler_args: SchedulerArgs | None = None,
 94    ):
 95        """
 96        Defines the Attribute domain module.
 97
 98        Args:
 99            latent_dim (`int`): the latent dimension of the module
100            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
101            beta (`float`): for beta-VAE
102            coef_categories (`float`): loss coefficient attributed to the category
103                (Defaults to 1.0)
104            coef_attributes (`float`): loss coefficient attributed to the rest of the
105                attributes (Defaults to 1.0)
106            optim_lr (`float`): learning rate for the optimizer
107            optim_weight_decay (`float`): weight decay for the optimizer
108            scheduler_args (`SchedulerArgs | None`): Scheduler arguments
109        """
110        super().__init__(latent_dim)
111        self.save_hyperparameters()
112
113        self.hidden_dim = hidden_dim
114        self.coef_categories = coef_categories
115        self.coef_attributes = coef_attributes
116
117        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
118        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
119        self.vae = VAE(vae_encoder, vae_decoder, beta)
120
121        self.optim_lr = optim_lr
122        self.optim_weight_decay = optim_weight_decay
123
124        self.scheduler_args = SchedulerArgs(
125            max_lr=optim_lr,
126            total_steps=1,
127        )
128        self.scheduler_args.update(scheduler_args or {})
129
130    def compute_loss(
131        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
132    ) -> LossOutput:
133        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
134
135    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
136        """
137        x must contain 2 items:
138        - the class
139        - the attributes
140        """
141        assert (
142            len(x) == 2
143        ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
144        return self.vae.encode(x)
145
146    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
147        out = list(self.vae.decode(z))
148        if not isinstance(out, Sequence):
149            raise ValueError("The output of vae.decode should be a sequence.")
150        return out
151
152    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
153        return self.decode(self.encode(x))
154
155    def generic_step(
156        self,
157        x: Sequence[torch.Tensor],
158        mode: str = "train",
159    ) -> torch.Tensor:
160        x_categories, x_attributes = x[0], x[1]
161
162        (mean, logvar), reconstruction = self.vae(x)
163        reconstruction_categories = reconstruction[0]
164        reconstruction_attributes = reconstruction[1]
165
166        reconstruction_loss_categories = F.cross_entropy(
167            reconstruction_categories,
168            x_categories.argmax(dim=1),
169            reduction="sum",
170        )
171        reconstruction_loss_attributes = gaussian_nll(
172            reconstruction_attributes, torch.tensor(0), x_attributes
173        ).sum()
174
175        reconstruction_loss = (
176            self.coef_categories * reconstruction_loss_categories
177            + self.coef_attributes * reconstruction_loss_attributes
178        )
179        kl_loss = kl_divergence_loss(mean, logvar)
180        total_loss = reconstruction_loss + self.vae.beta * kl_loss
181
182        self.log(
183            f"{mode}/reconstruction_loss_categories",
184            reconstruction_loss_categories,
185        )
186        self.log(
187            f"{mode}/reconstruction_loss_attributes",
188            reconstruction_loss_attributes,
189        )
190        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
191        self.log(f"{mode}/kl_loss", kl_loss)
192        self.log(f"{mode}/loss", total_loss)
193        return total_loss
194
195    def validation_step(  # type: ignore
196        self, batch: Mapping[str, Sequence[torch.Tensor]], _
197    ) -> torch.Tensor:
198        x = batch["attr"]
199        return self.generic_step(x, "val")
200
201    def training_step(  # type: ignore
202        self,
203        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
204        _,
205    ) -> torch.Tensor:
206        x = batch[frozenset(["attr"])]["attr"]
207        return self.generic_step(x, "train")
208
209    def configure_optimizers(  # type: ignore
210        self,
211    ) -> dict[str, Any]:
212        optimizer = torch.optim.AdamW(
213            self.parameters(),
214            lr=self.optim_lr,
215            weight_decay=self.optim_weight_decay,
216        )
217        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
218
219        return {
220            "optimizer": optimizer,
221            "lr_scheduler": {
222                "scheduler": lr_scheduler,
223                "interval": "step",
224            },
225        }
226
227
228class AttributeWithUnpairedDomainModule(DomainModule):
229    in_dim = 11
230
231    def __init__(
232        self,
233        latent_dim: int,
234        hidden_dim: int,
235        beta: float = 1,
236        coef_categories: float = 1,
237        coef_attributes: float = 1,
238        n_unpaired: int = 1,
239        optim_lr: float = 1e-3,
240        optim_weight_decay: float = 0,
241        scheduler_args: SchedulerArgs | None = None,
242        coef_unpaired: float = 0.5,
243    ):
244        super().__init__(latent_dim + n_unpaired)
245
246        if coef_categories < 0 or coef_categories > 1:
247            raise ValueError("coef_categories should be in [0, 1]")
248        if coef_attributes < 0 or coef_attributes > 1:
249            raise ValueError("coef_attributes should be in [0, 1]")
250        if coef_unpaired < 0 or coef_unpaired > 1:
251            raise ValueError("coef_unpaired should be in [0, 1]")
252
253        self.save_hyperparameters()
254        self.paired_dim = latent_dim
255        self.n_unpaired = n_unpaired
256        self.hidden_dim = hidden_dim
257        self.coef_categories = coef_categories
258        self.coef_attributes = coef_attributes
259        self.coef_unpaired = coef_unpaired
260
261        vae_encoder = Encoder(self.hidden_dim, self.latent_dim - self.n_unpaired)
262        vae_decoder = Decoder(self.latent_dim - self.n_unpaired, self.hidden_dim)
263        self.vae = VAE(vae_encoder, vae_decoder, beta)
264
265        self.optim_lr = optim_lr
266        self.optim_weight_decay = optim_weight_decay
267
268        self.scheduler_args = SchedulerArgs(
269            max_lr=optim_lr,
270            total_steps=1,
271        )
272        self.scheduler_args.update(scheduler_args or {})
273
274    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
275        """
276        x must contain 3 items:
277        - the class
278        - the attributes
279        - the unpaired value
280        """
281        assert len(x) == 3, (
282            "x must have the unpaired value "
283            "(use `attr` instead of `attr_unpaired` otherwise)."
284        )
285        z = self.vae.encode(x[:-1])
286        return torch.cat([z, x[-1]], dim=-1)
287
288    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
289        paired = z[:, : self.paired_dim]
290        unpaired = z[:, self.paired_dim :]
291        out = list(self.vae.decode(paired))
292        out.append(unpaired)
293        return out
294
295    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
296        return self.decode(self.encode(x))
297
298    def compute_loss(
299        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
300    ) -> LossOutput:
301        paired_loss = F.mse_loss(
302            pred[:, : self.paired_dim], target[:, : self.paired_dim]
303        )
304        unpaired_loss = F.mse_loss(
305            pred[:, self.paired_dim :], target[:, self.paired_dim :]
306        )
307        total_loss = unpaired_loss + paired_loss
308        return LossOutput(
309            loss=total_loss,
310            metrics={
311                "unpaired": unpaired_loss,
312                "paired": paired_loss,
313            },
314        )
315
316
317class AttributeLegacyDomainModule(DomainModule):
318    latent_dim = 11
319
320    def __init__(self):
321        super().__init__(self.latent_dim)
322        self.save_hyperparameters()
323
324    def compute_loss(
325        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
326    ) -> LossOutput:
 327        pred_cat, pred_attr = self.decode(pred)
 328        target_cat, target_attr = self.decode(target)
329
330        loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
331        loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
332        loss = loss_attr + loss_cat
333
334        return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})
335
336    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
337        assert len(x) == 2, "This must have only 2 items."
338        return torch.cat(list(x), dim=-1)
339
340    def decode(self, z: torch.Tensor) -> list:
341        categories = z[:, :3]
342        attr = z[:, 3:11]
343        return [categories, attr]
344
345    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
346        return self.decode(self.encode(x))
class Encoder(shimmer.modules.vae.VAEEncoder):
21class Encoder(VAEEncoder):
22    def __init__(
23        self,
24        hidden_dim: int,
25        out_dim: int,
26    ):
27        super().__init__()
28
29        self.hidden_dim = hidden_dim
30        self.out_dim = out_dim
31
32        self.encoder = nn.Sequential(
33            nn.Linear(11, hidden_dim),
34            nn.ReLU(),
35            nn.Linear(hidden_dim, hidden_dim),
36            nn.ReLU(),
37            nn.Linear(hidden_dim, out_dim),
38            nn.ReLU(),
39        )
40
41        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
42        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)
43
44    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
45        out = torch.cat(list(x), dim=-1)
46        out = self.encoder(out)
47        return self.q_mean(out), self.q_logvar(out)

Base class for a VAE encoder.

Encoder(hidden_dim: int, out_dim: int)
22    def __init__(
23        self,
24        hidden_dim: int,
25        out_dim: int,
26    ):
27        super().__init__()
28
29        self.hidden_dim = hidden_dim
30        self.out_dim = out_dim
31
32        self.encoder = nn.Sequential(
33            nn.Linear(11, hidden_dim),
34            nn.ReLU(),
35            nn.Linear(hidden_dim, hidden_dim),
36            nn.ReLU(),
37            nn.Linear(hidden_dim, out_dim),
38            nn.ReLU(),
39        )
40
41        self.q_mean = nn.Linear(self.out_dim, self.out_dim)
42        self.q_logvar = nn.Linear(self.out_dim, self.out_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

hidden_dim
out_dim
encoder
q_mean
q_logvar
def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
44    def forward(self, x: Sequence[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
45        out = torch.cat(list(x), dim=-1)
46        out = self.encoder(out)
47        return self.q_mean(out), self.q_logvar(out)

Encode representation with VAE.

Arguments:
  • x (Sequence[torch.Tensor]): the input tensors, concatenated along the last dimension before encoding
Returns:

tuple[torch.Tensor, torch.Tensor]: the mean and log variance
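
A minimal usage sketch (shapes and hyperparameter values are illustrative assumptions, following the `nn.Linear(11, hidden_dim)` input layer: a 3-dimensional one-hot category plus 8 attributes):

import torch

from shimmer_ssd.modules.domains.attribute import Encoder

# Hypothetical sizes: batch of 4, hidden_dim=64, out_dim (latent) = 12.
encoder = Encoder(hidden_dim=64, out_dim=12)
categories = torch.eye(3)[torch.randint(0, 3, (4,))]  # one-hot class over 3 categories
attributes = torch.rand(4, 8) * 2 - 1                 # 8 attributes, assumed in [-1, 1]
mean, logvar = encoder([categories, attributes])      # each of shape (4, 12)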

class Decoder(shimmer.modules.vae.VAEDecoder):
50class Decoder(VAEDecoder):
51    def __init__(
52        self,
53        in_dim: int,
54        hidden_dim: int,
55    ):
56        super().__init__()
57
58        self.in_dim = in_dim
59        self.hidden_dim = hidden_dim
60
61        self.decoder = nn.Sequential(
62            nn.Linear(self.in_dim, self.hidden_dim),
63            nn.ReLU(),
64            nn.Linear(self.hidden_dim, self.hidden_dim),
65            nn.ReLU(),
66        )
67
68        self.decoder_categories = nn.Sequential(
69            nn.Linear(self.hidden_dim, 3),
70        )
71
72        self.decoder_attributes = nn.Sequential(
73            nn.Linear(self.hidden_dim, 8),
74            nn.Tanh(),
75        )
76
77    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
78        out = self.decoder(x)
79        return [self.decoder_categories(out), self.decoder_attributes(out)]

Base class for a VAE decoder.

Decoder(in_dim: int, hidden_dim: int)
51    def __init__(
52        self,
53        in_dim: int,
54        hidden_dim: int,
55    ):
56        super().__init__()
57
58        self.in_dim = in_dim
59        self.hidden_dim = hidden_dim
60
61        self.decoder = nn.Sequential(
62            nn.Linear(self.in_dim, self.hidden_dim),
63            nn.ReLU(),
64            nn.Linear(self.hidden_dim, self.hidden_dim),
65            nn.ReLU(),
66        )
67
68        self.decoder_categories = nn.Sequential(
69            nn.Linear(self.hidden_dim, 3),
70        )
71
72        self.decoder_attributes = nn.Sequential(
73            nn.Linear(self.hidden_dim, 8),
74            nn.Tanh(),
75        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_dim
hidden_dim
decoder
decoder_categories
decoder_attributes
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
77    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
78        out = self.decoder(x)
79        return [self.decoder_categories(out), self.decoder_attributes(out)]

Decode representation with VAE.

Arguments:
  • x (torch.Tensor): VAE latent representation
Returns:

Any: the reconstructed input
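
A matching decoder sketch (sizes are again illustrative assumptions): the category head returns 3 raw logits, the attribute head returns 8 Tanh-bounded values.

import torch

from shimmer_ssd.modules.domains.attribute import Decoder

decoder = Decoder(in_dim=12, hidden_dim=64)  # hypothetical latent size of 12
z = torch.randn(4, 12)
category_logits, attributes = decoder(z)     # shapes (4, 3) and (4, 8)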

class AttributeDomainModule(shimmer.modules.domain.DomainModule):
 82class AttributeDomainModule(DomainModule):
 83    in_dim = 11
 84
 85    def __init__(
 86        self,
 87        latent_dim: int,
 88        hidden_dim: int,
 89        beta: float = 1,
 90        coef_categories: float = 1,
 91        coef_attributes: float = 1,
 92        optim_lr: float = 1e-3,
 93        optim_weight_decay: float = 0,
 94        scheduler_args: SchedulerArgs | None = None,
 95    ):
 96        """
 97        Defines the Attribute domain module.
 98
 99        Args:
100            latent_dim (`int`): the latent dimension of the module
101            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
102            beta (`float`): for beta-VAE
103            coef_categories (`float`): loss coefficient attributed to the category
104                (Defaults to 1.0)
105            coef_attributes (`float`): loss coefficient attributed to the rest of the
106                attributes (Defaults to 1.0)
107            optim_lr (`float`): learning rate for the optimizer
108            optim_weight_decay (`float`): weight decay for the optimizer
109            scheduler_args (`SchedulerArgs | None`): Scheduler arguments
110        """
111        super().__init__(latent_dim)
112        self.save_hyperparameters()
113
114        self.hidden_dim = hidden_dim
115        self.coef_categories = coef_categories
116        self.coef_attributes = coef_attributes
117
118        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
119        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
120        self.vae = VAE(vae_encoder, vae_decoder, beta)
121
122        self.optim_lr = optim_lr
123        self.optim_weight_decay = optim_weight_decay
124
125        self.scheduler_args = SchedulerArgs(
126            max_lr=optim_lr,
127            total_steps=1,
128        )
129        self.scheduler_args.update(scheduler_args or {})
130
131    def compute_loss(
132        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
133    ) -> LossOutput:
134        return LossOutput(F.mse_loss(pred, target, reduction="mean"))
135
136    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
137        """
138        x must contain 2 items:
139        - the class
140        - the attributes
141        """
142        assert (
143            len(x) == 2
144        ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
145        return self.vae.encode(x)
146
147    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
148        out = list(self.vae.decode(z))
149        if not isinstance(out, Sequence):
150            raise ValueError("The output of vae.decode should be a sequence.")
151        return out
152
153    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
154        return self.decode(self.encode(x))
155
156    def generic_step(
157        self,
158        x: Sequence[torch.Tensor],
159        mode: str = "train",
160    ) -> torch.Tensor:
161        x_categories, x_attributes = x[0], x[1]
162
163        (mean, logvar), reconstruction = self.vae(x)
164        reconstruction_categories = reconstruction[0]
165        reconstruction_attributes = reconstruction[1]
166
167        reconstruction_loss_categories = F.cross_entropy(
168            reconstruction_categories,
169            x_categories.argmax(dim=1),
170            reduction="sum",
171        )
172        reconstruction_loss_attributes = gaussian_nll(
173            reconstruction_attributes, torch.tensor(0), x_attributes
174        ).sum()
175
176        reconstruction_loss = (
177            self.coef_categories * reconstruction_loss_categories
178            + self.coef_attributes * reconstruction_loss_attributes
179        )
180        kl_loss = kl_divergence_loss(mean, logvar)
181        total_loss = reconstruction_loss + self.vae.beta * kl_loss
182
183        self.log(
184            f"{mode}/reconstruction_loss_categories",
185            reconstruction_loss_categories,
186        )
187        self.log(
188            f"{mode}/reconstruction_loss_attributes",
189            reconstruction_loss_attributes,
190        )
191        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
192        self.log(f"{mode}/kl_loss", kl_loss)
193        self.log(f"{mode}/loss", total_loss)
194        return total_loss
195
196    def validation_step(  # type: ignore
197        self, batch: Mapping[str, Sequence[torch.Tensor]], _
198    ) -> torch.Tensor:
199        x = batch["attr"]
200        return self.generic_step(x, "val")
201
202    def training_step(  # type: ignore
203        self,
204        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
205        _,
206    ) -> torch.Tensor:
207        x = batch[frozenset(["attr"])]["attr"]
208        return self.generic_step(x, "train")
209
210    def configure_optimizers(  # type: ignore
211        self,
212    ) -> dict[str, Any]:
213        optimizer = torch.optim.AdamW(
214            self.parameters(),
215            lr=self.optim_lr,
216            weight_decay=self.optim_weight_decay,
217        )
218        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
219
220        return {
221            "optimizer": optimizer,
222            "lr_scheduler": {
223                "scheduler": lr_scheduler,
224                "interval": "step",
225            },
226        }

Base class for a DomainModule that defines domain specific modules of the GW.

AttributeDomainModule( latent_dim: int, hidden_dim: int, beta: float = 1, coef_categories: float = 1, coef_attributes: float = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None)
 85    def __init__(
 86        self,
 87        latent_dim: int,
 88        hidden_dim: int,
 89        beta: float = 1,
 90        coef_categories: float = 1,
 91        coef_attributes: float = 1,
 92        optim_lr: float = 1e-3,
 93        optim_weight_decay: float = 0,
 94        scheduler_args: SchedulerArgs | None = None,
 95    ):
 96        """
 97        Defines the Attribute domain module.
 98
 99        Args:
100            latent_dim (`int`): the latent dimension of the module
101            hidden_dim (`int`): hidden dimension of the VAE encoders and decoders
102            beta (`float`): for beta-VAE
103            coef_categories (`float`): loss coefficient attributed to the category
104                (Defaults to 1.0)
105            coef_attributes (`float`): loss coefficient attributed to the rest of the
106                attributes (Defaults to 1.0)
107            optim_lr (`float`): learning rate for the optimizer
108            optim_weight_decay (`float`): weight decay for the optimizer
109            scheduler_args (`SchedulerArgs | None`): Scheduler arguments
110        """
111        super().__init__(latent_dim)
112        self.save_hyperparameters()
113
114        self.hidden_dim = hidden_dim
115        self.coef_categories = coef_categories
116        self.coef_attributes = coef_attributes
117
118        vae_encoder = Encoder(self.hidden_dim, self.latent_dim)
119        vae_decoder = Decoder(self.latent_dim, self.hidden_dim)
120        self.vae = VAE(vae_encoder, vae_decoder, beta)
121
122        self.optim_lr = optim_lr
123        self.optim_weight_decay = optim_weight_decay
124
125        self.scheduler_args = SchedulerArgs(
126            max_lr=optim_lr,
127            total_steps=1,
128        )
129        self.scheduler_args.update(scheduler_args or {})

Defines the Attribute domain module.

Arguments:
  • latent_dim (int): the latent dimension of the module
  • hidden_dim (int): hidden dimension of the VAE encoders and decoders
  • beta (float): for beta-VAE
  • coef_categories (float): loss coefficient attributed to the category (Defaults to 1.0)
  • coef_attributes (float): loss coefficient attributed to the rest of the attributes (Defaults to 1.0)
  • optim_lr (float): learning rate for the optimizer
  • optim_weight_decay (float): weight decay for the optimizer
  • scheduler_args (SchedulerArgs | None): Scheduler arguments
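
A construction sketch, assuming arbitrary hyperparameter values (not recommendations); in practice `total_steps` should be set to the real number of optimizer steps so that `OneCycleLR` spans the whole training run:

from shimmer.modules.global_workspace import SchedulerArgs

from shimmer_ssd.modules.domains.attribute import AttributeDomainModule

domain = AttributeDomainModule(
    latent_dim=12,
    hidden_dim=64,
    beta=0.5,
    optim_lr=1e-3,
    scheduler_args=SchedulerArgs(max_lr=1e-3, total_steps=10_000),
)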
in_dim = 11
hidden_dim
coef_categories
coef_attributes
vae
optim_lr
optim_weight_decay
scheduler_args
def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
131    def compute_loss(
132        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
133    ) -> LossOutput:
134        return LossOutput(F.mse_loss(pred, target, reduction="mean"))

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
136    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
137        """
138        x must contain 2 items:
139        - the class
140        - the attributes
141        """
142        assert (
143            len(x) == 2
144        ), "x must only contain 2 items (use attr_unpaired to add an unpaired value)"
145        return self.vae.encode(x)

x must contain 2 items:

  • the class
  • the attributes
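
For example (a hypothetical batch whose shapes follow the `in_dim = 11` layout of 3 category values plus 8 attributes):

import torch

from shimmer_ssd.modules.domains.attribute import AttributeDomainModule

domain = AttributeDomainModule(latent_dim=12, hidden_dim=64)
categories = torch.eye(3)[torch.randint(0, 3, (4,))]
attributes = torch.rand(4, 8) * 2 - 1
z = domain.encode([categories, attributes])                # unimodal latent, shape (4, 12)
decoded_categories, decoded_attributes = domain.decode(z)  # shapes (4, 3) and (4, 8)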
def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
147    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
148        out = list(self.vae.decode(z))
149        if not isinstance(out, Sequence):
150            raise ValueError("The output of vae.decode should be a sequence.")
151        return out

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
153    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
154        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def generic_step(self, x: Sequence[torch.Tensor], mode: str = 'train') -> torch.Tensor:
156    def generic_step(
157        self,
158        x: Sequence[torch.Tensor],
159        mode: str = "train",
160    ) -> torch.Tensor:
161        x_categories, x_attributes = x[0], x[1]
162
163        (mean, logvar), reconstruction = self.vae(x)
164        reconstruction_categories = reconstruction[0]
165        reconstruction_attributes = reconstruction[1]
166
167        reconstruction_loss_categories = F.cross_entropy(
168            reconstruction_categories,
169            x_categories.argmax(dim=1),
170            reduction="sum",
171        )
172        reconstruction_loss_attributes = gaussian_nll(
173            reconstruction_attributes, torch.tensor(0), x_attributes
174        ).sum()
175
176        reconstruction_loss = (
177            self.coef_categories * reconstruction_loss_categories
178            + self.coef_attributes * reconstruction_loss_attributes
179        )
180        kl_loss = kl_divergence_loss(mean, logvar)
181        total_loss = reconstruction_loss + self.vae.beta * kl_loss
182
183        self.log(
184            f"{mode}/reconstruction_loss_categories",
185            reconstruction_loss_categories,
186        )
187        self.log(
188            f"{mode}/reconstruction_loss_attributes",
189            reconstruction_loss_attributes,
190        )
191        self.log(f"{mode}/reconstruction_loss", reconstruction_loss)
192        self.log(f"{mode}/kl_loss", kl_loss)
193        self.log(f"{mode}/loss", total_loss)
194        return total_loss
def validation_step(self, batch: Mapping[str, Sequence[torch.Tensor]], _) -> torch.Tensor:
196    def validation_step(  # type: ignore
197        self, batch: Mapping[str, Sequence[torch.Tensor]], _
198    ) -> torch.Tensor:
199        x = batch["attr"]
200        return self.generic_step(x, "val")

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Arguments:
  • batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • torch.Tensor - The loss tensor
  • dict - A dictionary. Can include any keys, but must include the key 'loss'.
  • None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    ...
Note:

If you don't need to validate you don't need to implement this method.

Note:

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]], _) -> torch.Tensor:
202    def training_step(  # type: ignore
203        self,
204        batch: Mapping[frozenset[str], Mapping[str, Sequence[torch.Tensor]]],
205        _,
206    ) -> torch.Tensor:
207        x = batch[frozenset(["attr"])]["attr"]
208        return self.generic_step(x, "train")

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Arguments:
  • batch: The output of your data iterable, normally a torch.utils.data.DataLoader.
  • batch_idx: The index of this batch.
  • dataloader_idx: The index of the dataloader that produced this batch. (only if multiple dataloaders used)
Return:
  • torch.Tensor - The loss tensor
  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note:

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

def configure_optimizers(self) -> dict[str, typing.Any]:
210    def configure_optimizers(  # type: ignore
211        self,
212    ) -> dict[str, Any]:
213        optimizer = torch.optim.AdamW(
214            self.parameters(),
215            lr=self.optim_lr,
216            weight_decay=self.optim_weight_decay,
217        )
218        lr_scheduler = OneCycleLR(optimizer, **self.scheduler_args)
219
220        return {
221            "optimizer": optimizer,
222            "lr_scheduler": {
223                "scheduler": lr_scheduler,
224                "interval": "step",
225            },
226        }

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

Any of these 6 options.

  • Single optimizer.
  • List or Tuple of optimizers.
  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.


# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(...)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, ...),
            "monitor": "metric_to_track",
            "frequency": "indicates how often the metric is updated",
            # If "monitor" references validation metrics, then "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
        },
    }


# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(...)
    optimizer2 = SGD(...)
    scheduler1 = ReduceLROnPlateau(optimizer1, ...)
    scheduler2 = LambdaLR(optimizer2, ...)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your lightning.pytorch.core.LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.
  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
  • If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
  • If you need to control how often the optimizer steps, override the optimizer_step() hook.
class AttributeWithUnpairedDomainModule(shimmer.modules.domain.DomainModule):
229class AttributeWithUnpairedDomainModule(DomainModule):
230    in_dim = 11
231
232    def __init__(
233        self,
234        latent_dim: int,
235        hidden_dim: int,
236        beta: float = 1,
237        coef_categories: float = 1,
238        coef_attributes: float = 1,
239        n_unpaired: int = 1,
240        optim_lr: float = 1e-3,
241        optim_weight_decay: float = 0,
242        scheduler_args: SchedulerArgs | None = None,
243        coef_unpaired: float = 0.5,
244    ):
245        super().__init__(latent_dim + n_unpaired)
246
247        if coef_categories < 0 or coef_categories > 1:
248            raise ValueError("coef_categories should be in [0, 1]")
249        if coef_attributes < 0 or coef_attributes > 1:
250            raise ValueError("coef_attributes should be in [0, 1]")
251        if coef_unpaired < 0 or coef_unpaired > 1:
252            raise ValueError("coef_unpaired should be in [0, 1]")
253
254        self.save_hyperparameters()
255        self.paired_dim = latent_dim
256        self.n_unpaired = n_unpaired
257        self.hidden_dim = hidden_dim
258        self.coef_categories = coef_categories
259        self.coef_attributes = coef_attributes
260        self.coef_unpaired = coef_unpaired
261
262        vae_encoder = Encoder(self.hidden_dim, self.latent_dim - self.n_unpaired)
263        vae_decoder = Decoder(self.latent_dim - self.n_unpaired, self.hidden_dim)
264        self.vae = VAE(vae_encoder, vae_decoder, beta)
265
266        self.optim_lr = optim_lr
267        self.optim_weight_decay = optim_weight_decay
268
269        self.scheduler_args = SchedulerArgs(
270            max_lr=optim_lr,
271            total_steps=1,
272        )
273        self.scheduler_args.update(scheduler_args or {})
274
275    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
276        """
277        x must contain 3 items:
278        - the class
279        - the attributes
280        - the unpaired value
281        """
282        assert len(x) == 3, (
283            "x must have the unpaired value "
284            "(use `attr` instead of `attr_unpaired` otherwise)."
285        )
286        z = self.vae.encode(x[:-1])
287        return torch.cat([z, x[-1]], dim=-1)
288
289    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
290        paired = z[:, : self.paired_dim]
291        unpaired = z[:, self.paired_dim :]
292        out = list(self.vae.decode(paired))
293        out.append(unpaired)
294        return out
295
296    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
297        return self.decode(self.encode(x))
298
299    def compute_loss(
300        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
301    ) -> LossOutput:
302        paired_loss = F.mse_loss(
303            pred[:, : self.paired_dim], target[:, : self.paired_dim]
304        )
305        unpaired_loss = F.mse_loss(
306            pred[:, self.paired_dim :], target[:, self.paired_dim :]
307        )
308        total_loss = unpaired_loss + paired_loss
309        return LossOutput(
310            loss=total_loss,
311            metrics={
312                "unpaired": unpaired_loss,
313                "paired": paired_loss,
314            },
315        )

Base class for a DomainModule that defines domain specific modules of the GW.

AttributeWithUnpairedDomainModule( latent_dim: int, hidden_dim: int, beta: float = 1, coef_categories: float = 1, coef_attributes: float = 1, n_unpaired: int = 1, optim_lr: float = 0.001, optim_weight_decay: float = 0, scheduler_args: shimmer.modules.global_workspace.SchedulerArgs | None = None, coef_unpaired: float = 0.5)
232    def __init__(
233        self,
234        latent_dim: int,
235        hidden_dim: int,
236        beta: float = 1,
237        coef_categories: float = 1,
238        coef_attributes: float = 1,
239        n_unpaired: int = 1,
240        optim_lr: float = 1e-3,
241        optim_weight_decay: float = 0,
242        scheduler_args: SchedulerArgs | None = None,
243        coef_unpaired: float = 0.5,
244    ):
245        super().__init__(latent_dim + n_unpaired)
246
247        if coef_categories < 0 or coef_categories > 1:
248            raise ValueError("coef_categories should be in [0, 1]")
249        if coef_attributes < 0 or coef_attributes > 1:
250            raise ValueError("coef_attributes should be in [0, 1]")
251        if coef_unpaired < 0 or coef_unpaired > 1:
252            raise ValueError("coef_unpaired should be in [0, 1]")
253
254        self.save_hyperparameters()
255        self.paired_dim = latent_dim
256        self.n_unpaired = n_unpaired
257        self.hidden_dim = hidden_dim
258        self.coef_categories = coef_categories
259        self.coef_attributes = coef_attributes
260        self.coef_unpaired = coef_unpaired
261
262        vae_encoder = Encoder(self.hidden_dim, self.latent_dim - self.n_unpaired)
263        vae_decoder = Decoder(self.latent_dim - self.n_unpaired, self.hidden_dim)
264        self.vae = VAE(vae_encoder, vae_decoder, beta)
265
266        self.optim_lr = optim_lr
267        self.optim_weight_decay = optim_weight_decay
268
269        self.scheduler_args = SchedulerArgs(
270            max_lr=optim_lr,
271            total_steps=1,
272        )
273        self.scheduler_args.update(scheduler_args or {})

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
in_dim = 11
paired_dim
n_unpaired
hidden_dim
coef_categories
coef_attributes
coef_unpaired
vae
optim_lr
optim_weight_decay
scheduler_args
def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
275    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
276        """
277        x must contain 3 items:
278        - the class
279        - the attributes
280        - the unpaired value
281        """
282        assert len(x) == 3, (
283            "x must have the unpaired value "
284            "(use `attr` instead of `attr_unpaired` otherwise)."
285        )
286        z = self.vae.encode(x[:-1])
287        return torch.cat([z, x[-1]], dim=-1)

x must contain 3 items:

  • the class
  • the attributes
  • the unpaired value
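
A sketch of the unpaired variant, assuming 11 paired latent dimensions plus 1 unpaired dimension (so the module's total latent dimension is 12):

import torch

from shimmer_ssd.modules.domains.attribute import AttributeWithUnpairedDomainModule

domain = AttributeWithUnpairedDomainModule(latent_dim=11, hidden_dim=64, n_unpaired=1)
categories = torch.eye(3)[torch.randint(0, 3, (4,))]
attributes = torch.rand(4, 8) * 2 - 1
unpaired = torch.randn(4, 1)                           # concatenated as-is to the VAE latent
z = domain.encode([categories, attributes, unpaired])  # shape (4, 12)
category_logits, decoded_attributes, decoded_unpaired = domain.decode(z)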
def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
289    def decode(self, z: torch.Tensor) -> list[torch.Tensor]:
290        paired = z[:, : self.paired_dim]
291        unpaired = z[:, self.paired_dim :]
292        out = list(self.vae.decode(paired))
293        out.append(unpaired)
294        return out

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
296    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
297        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output

def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
299    def compute_loss(
300        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
301    ) -> LossOutput:
302        paired_loss = F.mse_loss(
303            pred[:, : self.paired_dim], target[:, : self.paired_dim]
304        )
305        unpaired_loss = F.mse_loss(
306            pred[:, self.paired_dim :], target[:, self.paired_dim :]
307        )
308        total_loss = unpaired_loss + paired_loss
309        return LossOutput(
310            loss=total_loss,
311            metrics={
312                "unpaired": unpaired_loss,
313                "paired": paired_loss,
314            },
315        )

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.
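
A rough sketch of how the two terms are reported (tensor values here are random placeholders):

import torch

from shimmer_ssd.modules.domains.attribute import AttributeWithUnpairedDomainModule

domain = AttributeWithUnpairedDomainModule(latent_dim=11, hidden_dim=64, n_unpaired=1)
pred, target = torch.randn(4, 12), torch.randn(4, 12)  # 11 paired dims + 1 unpaired dim
out = domain.compute_loss(pred, target, raw_target=None)
# out.loss sums the paired and unpaired MSE terms listed in out.metrics.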

class AttributeLegacyDomainModule(shimmer.modules.domain.DomainModule):
318class AttributeLegacyDomainModule(DomainModule):
319    latent_dim = 11
320
321    def __init__(self):
322        super().__init__(self.latent_dim)
323        self.save_hyperparameters()
324
325    def compute_loss(
326        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
327    ) -> LossOutput:
328        pred_cat, pred_attr = self.decode(pred)
329        target_cat, target_attr = self.decode(target)
330
331        loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
332        loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
333        loss = loss_attr + loss_cat
334
335        return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})
336
337    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
338        assert len(x) == 2, "This must have only 2 items."
339        return torch.cat(list(x), dim=-1)
340
341    def decode(self, z: torch.Tensor) -> list:
342        categories = z[:, :3]
343        attr = z[:, 3:11]
344        return [categories, attr]
345
346    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
347        return self.decode(self.encode(x))

Base class for a DomainModule that defines domain specific modules of the GW.

AttributeLegacyDomainModule()
321    def __init__(self):
322        super().__init__(self.latent_dim)
323        self.save_hyperparameters()

Initializes a DomainModule.

Arguments:
  • latent_dim (int): latent dimension of the unimodal module
latent_dim = 11

The latent dimension of the module.

def compute_loss( self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any) -> shimmer.modules.domain.LossOutput:
325    def compute_loss(
326        self, pred: torch.Tensor, target: torch.Tensor, raw_target: Any
327    ) -> LossOutput:
328        pred_cat, pred_attr = self.decode(pred)
329        target_cat, target_attr = self.decode(target)
330
331        loss_attr = F.mse_loss(pred_attr, target_attr, reduction="mean")
332        loss_cat = F.nll_loss(pred_cat, torch.argmax(target_cat, 1))
333        loss = loss_attr + loss_cat
334
335        return LossOutput(loss, metrics={"loss_attr": loss_attr, "loss_cat": loss_cat})

Generic loss computation for the modality.

Arguments:
  • pred (torch.Tensor): prediction of the model
  • target (torch.Tensor): target tensor
  • raw_target (Any): raw data from the input
Returns:

LossOutput | None: LossOutput with training loss and additional metrics. If None is returned, this loss will be ignored and will not participate in the total loss.

def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
337    def encode(self, x: Sequence[torch.Tensor]) -> torch.Tensor:
338        assert len(x) == 2, "This must have only 2 items."
339        return torch.cat(list(x), dim=-1)

Encode the domain data into a unimodal representation.

Arguments:
  • x (Any): data of the domain.
Returns:

torch.Tensor: a unimodal representation.

def decode(self, z: torch.Tensor) -> list:
341    def decode(self, z: torch.Tensor) -> list:
342        categories = z[:, :3]
343        attr = z[:, 3:11]
344        return [categories, attr]

Decode data from unimodal representation back to the domain data.

Arguments:
  • z (torch.Tensor): unimodal representation of the domain.
Returns:

Any: the original domain data.
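
A round-trip sketch for the legacy module (hypothetical batch; the 11-dimensional latent is simply the 3 category values concatenated with the 8 attributes):

import torch

from shimmer_ssd.modules.domains.attribute import AttributeLegacyDomainModule

domain = AttributeLegacyDomainModule()
categories = torch.eye(3)[torch.randint(0, 3, (4,))]
attributes = torch.rand(4, 8) * 2 - 1
z = domain.encode([categories, attributes])                # shape (4, 11)
decoded_categories, decoded_attributes = domain.decode(z)  # shapes (4, 3) and (4, 8)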

def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:
346    def forward(self, x: Sequence[torch.Tensor]) -> list[torch.Tensor]:  # type: ignore
347        return self.decode(self.encode(x))

Same as torch.nn.Module.forward().

Arguments:
  • *args: Whatever you decide to pass into the forward method.
  • **kwargs: Keyword arguments are also possible.
Return:

Your model's output