shimmer_ssd.modules.vae

  1from collections.abc import Callable, Sequence
  2from typing import Any
  3
  4import matplotlib.pyplot as plt
  5import numpy as np
  6import torch
  7from matplotlib.figure import Figure
  8from matplotlib.gridspec import GridSpec
  9from PIL.Image import Image
 10from shimmer.modules.vae import VAE, VAEDecoder, VAEEncoder
 11from torch import nn
 12
 13
 14class RAEEncoder(VAEEncoder):
 15    def __init__(
 16        self,
 17        num_channels: int,
 18        ae_dim: int = 1028,
 19        z_dim: int = 16,
 20        kernel_dim: int = 4,
 21        padding: int = 1,
 22        use_batchnorm: bool = True,
 23    ):
 24        super().__init__()
 25
 26        self.dims = [
 27            ae_dim // (2**i) for i in reversed(range(4))
 28        ]  # 1 2 4 8 # 32 64 128 256
 29
 30        self.kernel_dim = kernel_dim
 31        self.padding = padding
 32        self.dims[-1] = ae_dim
 33        self.use_batchnorm = use_batchnorm
 34
 35        self.out_dim = self.dims[3] * 2 * 2
 36        self.z_dim = z_dim
 37
 38        self.layers = nn.Sequential(
 39            nn.Conv2d(
 40                num_channels,
 41                self.dims[0],
 42                kernel_size=self.kernel_dim,
 43                stride=2,
 44                padding=self.padding,
 45                bias=not self.use_batchnorm,
 46            ),
 47            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
 48            nn.ReLU(),
 49            nn.Conv2d(
 50                self.dims[0],
 51                self.dims[1],
 52                kernel_size=self.kernel_dim,
 53                stride=2,
 54                padding=self.padding,
 55                bias=not self.use_batchnorm,
 56            ),
 57            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
 58            nn.ReLU(),
 59            nn.Conv2d(
 60                self.dims[1],
 61                self.dims[2],
 62                kernel_size=self.kernel_dim,
 63                stride=2,
 64                padding=self.padding,
 65                bias=not self.use_batchnorm,
 66            ),
 67            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
 68            nn.ReLU(),
 69            nn.Conv2d(
 70                self.dims[2],
 71                self.dims[3],
 72                kernel_size=self.kernel_dim,
 73                stride=2,
 74                padding=self.padding,
 75                bias=not self.use_batchnorm,
 76            ),
 77            nn.BatchNorm2d(self.dims[3]) if self.use_batchnorm else nn.Identity(),
 78            nn.ReLU(),
 79        )
 80
 81        self.q_mean = nn.Linear(self.out_dim, self.z_dim)
 82        self.q_logvar = nn.Linear(self.out_dim, self.z_dim)
 83
 84    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 85        out = self.layers(x).view(x.size(0), -1)
 86        out = out.view(out.size(0), -1)
 87
 88        return self.q_mean(out), self.q_logvar(out)
 89
 90
 91class RAEDecoder(VAEDecoder):
 92    def __init__(
 93        self,
 94        num_channels: int,
 95        z_dim: int,
 96        ae_dim: int = 1028,
 97        kernel_dim: int = 4,
 98        padding: int = 1,
 99        use_batchnorm: bool = True,
100    ):
101        super().__init__()
102
103        self.num_channels = num_channels
104        self.dims = [ae_dim // (2**i) for i in reversed(range(3))]
105        self.dims[-1] = ae_dim
106
107        self.kernel_dim = kernel_dim
108        self.padding = padding
109        self.use_batchnorm = use_batchnorm
110
111        self.layers = nn.Sequential(
112            nn.ConvTranspose2d(
113                z_dim,
114                self.dims[2],
115                kernel_size=8,
116                stride=1,
117                bias=not self.use_batchnorm,
118            ),
119            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
120            nn.ReLU(),
121            nn.ConvTranspose2d(
122                self.dims[2],
123                self.dims[1],
124                kernel_size=self.kernel_dim,
125                stride=2,
126                padding=self.padding,
127                bias=not self.use_batchnorm,
128            ),
129            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
130            nn.ReLU(),
131            nn.ConvTranspose2d(
132                self.dims[1],
133                self.dims[0],
134                kernel_size=self.kernel_dim,
135                stride=2,
136                padding=self.padding,
137                bias=not self.use_batchnorm,
138            ),
139            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
140            nn.ReLU(),
141        )
142
143        self.out_layer = nn.Sequential(
144            nn.ZeroPad2d((0, 1, 0, 1)),
145            nn.Conv2d(
146                self.dims[0],
147                self.num_channels,
148                kernel_size=self.kernel_dim,
149                stride=1,
150                padding=self.padding,
151            ),
152            nn.Sigmoid(),
153        )
154
155    def forward(self, z: torch.Tensor) -> torch.Tensor:  # type: ignore
156        return self.out_layer(self.layers(z[:, :, None, None]))
157
158
def dim_exploration_figure(
    vae: VAE,
    z_size: int,
    device: torch.device,
    ax_from_tensors: Callable[[Any, int, int], Image],
    num_samples: int = 5,
    range_start: int = -6,
    range_end: int = 6,
    image_size: int = 32,
    plot_dims: Sequence[int] | None = None,
    fig_dim: int = 5,
) -> Figure:
    """Plot decoder traversals for every pair of latent dimensions.

    For each unordered pair ``(dim_i, dim_j)`` of distinct dimensions, a
    ``num_samples x num_samples`` grid of latent vectors is built. All
    coordinates are 0 except ``dim_i`` (varying along rows) and ``dim_j``
    (varying along columns). Both are swept over
    ``np.linspace(range_start, range_end, num_samples)``. The grid is decoded
    with ``vae.decoder``, turned into an image by ``ax_from_tensors`` and
    drawn in the lower triangle of a square subplot grid.

    The decoder is run under ``torch.no_grad()``. This function does not
    change the train/eval mode of ``vae``.

    Args:
        vae: model whose ``decoder`` is applied to ``(N, z_size)`` latents.
        z_size: latent dimension.
        device: device on which the latents are created.
        ax_from_tensors: builds the image shown with ``ax.imshow`` from
            ``(decoded, image_size, num_samples)``.
        num_samples: number of points per axis of each traversal grid.
        range_start: first value of the sweep (inclusive).
        range_end: last value of the sweep (inclusive).
        image_size: side in pixels of one decoded image (used for ticks).
        plot_dims: dimensions to explore; all ``z_size`` dimensions when
            None or empty.
        fig_dim: the figure side is ``(len(plot_dims) - 1) * fig_dim`` inches.

    Returns:
        The matplotlib figure.
    """
    # `plot_dims or ...` raises on numpy arrays (ambiguous truth value); test
    # explicitly, keeping the "empty -> every dimension" fallback.
    if plot_dims is None or len(plot_dims) == 0:
        possible_dims = list(range(z_size))
    else:
        possible_dims = [int(d) for d in plot_dims]
    n_dims = len(possible_dims)

    # A single dimension yields no pairs; avoid a zero-sized figure anyway.
    fig_size = max(n_dims - 1, 1) * fig_dim

    # NOTE(review): dpi=1 makes the on-figure canvas tiny; presumably callers
    # re-render at a higher dpi when saving/logging -- confirm.
    fig = plt.figure(constrained_layout=True, figsize=(fig_size, fig_size), dpi=1)
    gs = GridSpec(n_dims, n_dims, figure=fig)

    # One sweep shared by the latent values and the tick labels. linspace
    # also handles num_samples == 1, where (num_samples - 1) would be 0.
    steps = np.linspace(range_start, range_end, num_samples)
    step_values = torch.as_tensor(
        steps, dtype=torch.get_default_dtype(), device=device
    )
    # Assumes ax_from_tensors tiles image_size-pixel images without padding;
    # ticks sit at the centre of each tile.
    tick_positions = image_size * np.arange(num_samples) + image_size // 2
    tick_labels = [f"{s:.1f}" for s in steps]

    done_dims: set[frozenset[int]] = set()
    for i, dim_i in enumerate(possible_dims):
        for j, dim_j in enumerate(possible_dims):
            pair = frozenset((dim_i, dim_j))
            if dim_i == dim_j or pair in done_dims:
                continue
            done_dims.add(pair)

            ax = fig.add_subplot(gs[j, i])

            # Allocate real storage. An `.expand(...)` view has stride 0, so
            # in-place writes into it alias every row. On CPU, `.to(device)`
            # returns that same view and the writes fail.
            z = torch.zeros(num_samples, num_samples, z_size, device=device)
            z[:, :, dim_i] = step_values[:, None]  # varies along rows
            z[:, :, dim_j] = step_values[None, :]  # varies along columns

            with torch.no_grad():
                decoded_x = vae.decoder(z.reshape(-1, z_size))

            img_grid = ax_from_tensors(decoded_x, image_size, num_samples)

            ax.imshow(img_grid)
            ax.set_xlabel(f"dim {dim_j}")
            ax.set_ylabel(f"dim {dim_i}")
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(tick_labels)
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(tick_labels)

    return fig
class RAEEncoder(shimmer.modules.vae.VAEEncoder):
class RAEEncoder(VAEEncoder):
    """Convolutional encoder producing the mean and log-variance of q(z|x).

    Four ``Conv2d(stride=2)`` stages, each followed by ``BatchNorm2d`` (or
    ``Identity`` when ``use_batchnorm`` is False) and ``ReLU``, then two
    linear heads (``q_mean``, ``q_logvar``) on the flattened feature map.

    Channel widths are ``ae_dim // 8, ae_dim // 4, ae_dim // 2, ae_dim``
    (e.g. 32, 64, 128, 256 for ``ae_dim=256``).

    The linear heads expect ``ae_dim * 2 * 2`` input features, i.e. the last
    feature map must be 2x2. With the default ``kernel_dim=4`` and
    ``padding=1`` every stage halves the spatial size, so 32x32 inputs work.

    NOTE(review): the default ``ae_dim=1028`` is not a power of two (widths
    become 128, 257, 514, 1028); possibly 1024 was intended. It is left
    unchanged because it determines the parameter shapes of saved models.

    Args:
        num_channels: number of channels of the input images.
        ae_dim: number of channels of the last convolutional stage.
        z_dim: dimension of the latent space.
        kernel_dim: kernel size of every convolution.
        padding: padding of every convolution.
        use_batchnorm: insert BatchNorm2d after each convolution (the
            convolution bias is disabled in that case).
    """

    def __init__(
        self,
        num_channels: int,
        ae_dim: int = 1028,
        z_dim: int = 16,
        kernel_dim: int = 4,
        padding: int = 1,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        # ae_dim // 8, ae_dim // 4, ae_dim // 2, ae_dim // 1 (== ae_dim).
        self.dims = [ae_dim // (2**i) for i in reversed(range(4))]

        self.kernel_dim = kernel_dim
        self.padding = padding
        self.use_batchnorm = use_batchnorm

        # Flattened size of the final (dims[3], 2, 2) feature map.
        self.out_dim = self.dims[3] * 2 * 2
        self.z_dim = z_dim

        # Built in a loop, but with exactly the same module order as the
        # hand-unrolled version (conv, norm, relu per stage), so Sequential
        # indices / state_dict keys and parameter-init order are unchanged.
        stages: list[nn.Module] = []
        for in_ch, out_ch in zip([num_channels, *self.dims[:-1]], self.dims):
            stages += [
                nn.Conv2d(
                    in_ch,
                    out_ch,
                    kernel_size=self.kernel_dim,
                    stride=2,
                    padding=self.padding,
                    bias=not self.use_batchnorm,
                ),
                nn.BatchNorm2d(out_ch) if self.use_batchnorm else nn.Identity(),
                nn.ReLU(),
            ]
        self.layers = nn.Sequential(*stages)

        self.q_mean = nn.Linear(self.out_dim, self.z_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.z_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` into the posterior mean and log-variance.

        Args:
            x: input batch; assumed (batch, num_channels, H, W) as required by
                the Conv2d stack, with H and W such that the final feature
                map is 2x2 (e.g. 32x32 with the default kernel_dim/padding).

        Returns:
            ``(mean, logvar)``, each of shape ``(batch, z_dim)``.
        """
        # One flatten is enough; a second `.view(out.size(0), -1)` is a no-op.
        features = self.layers(x).view(x.size(0), -1)
        return self.q_mean(features), self.q_logvar(features)

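Convolutional VAE encoder (a subclass of the shimmer VAEEncoder base class): four stride-2 Conv2d stages followed by two linear heads that output the mean and log-variance of the latent distribution.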

RAEEncoder( num_channels: int, ae_dim: int = 1028, z_dim: int = 16, kernel_dim: int = 4, padding: int = 1, use_batchnorm: bool = True)
16    def __init__(
17        self,
18        num_channels: int,
19        ae_dim: int = 1028,
20        z_dim: int = 16,
21        kernel_dim: int = 4,
22        padding: int = 1,
23        use_batchnorm: bool = True,
24    ):
25        super().__init__()
26
27        self.dims = [
28            ae_dim // (2**i) for i in reversed(range(4))
29        ]  # 1 2 4 8 # 32 64 128 256
30
31        self.kernel_dim = kernel_dim
32        self.padding = padding
33        self.dims[-1] = ae_dim
34        self.use_batchnorm = use_batchnorm
35
36        self.out_dim = self.dims[3] * 2 * 2
37        self.z_dim = z_dim
38
39        self.layers = nn.Sequential(
40            nn.Conv2d(
41                num_channels,
42                self.dims[0],
43                kernel_size=self.kernel_dim,
44                stride=2,
45                padding=self.padding,
46                bias=not self.use_batchnorm,
47            ),
48            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
49            nn.ReLU(),
50            nn.Conv2d(
51                self.dims[0],
52                self.dims[1],
53                kernel_size=self.kernel_dim,
54                stride=2,
55                padding=self.padding,
56                bias=not self.use_batchnorm,
57            ),
58            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
59            nn.ReLU(),
60            nn.Conv2d(
61                self.dims[1],
62                self.dims[2],
63                kernel_size=self.kernel_dim,
64                stride=2,
65                padding=self.padding,
66                bias=not self.use_batchnorm,
67            ),
68            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
69            nn.ReLU(),
70            nn.Conv2d(
71                self.dims[2],
72                self.dims[3],
73                kernel_size=self.kernel_dim,
74                stride=2,
75                padding=self.padding,
76                bias=not self.use_batchnorm,
77            ),
78            nn.BatchNorm2d(self.dims[3]) if self.use_batchnorm else nn.Identity(),
79            nn.ReLU(),
80        )
81
82        self.q_mean = nn.Linear(self.out_dim, self.z_dim)
83        self.q_logvar = nn.Linear(self.out_dim, self.z_dim)

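Build the encoder: four Conv2d stages (kernel_dim, stride 2, padding) with ae_dim // 8, ae_dim // 4, ae_dim // 2 and ae_dim output channels. Each stage is followed by BatchNorm2d (or Identity when use_batchnorm is False) and ReLU. Two linear heads (q_mean, q_logvar) then map the flattened ae_dim × 2 × 2 feature map to z_dim.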

dims
kernel_dim
padding
use_batchnorm
out_dim
z_dim
layers
q_mean
q_logvar
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
85    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
86        out = self.layers(x).view(x.size(0), -1)
87        out = out.view(out.size(0), -1)
88
89        return self.q_mean(out), self.q_logvar(out)

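Encode a batch of inputs with the convolutional encoder.

Arguments:
  • x (torch.Tensor): input batch of shape (batch, num_channels, H, W); the final feature map must be 2 × 2 (ae_dim × 2 × 2 features after flattening), e.g. 32 × 32 inputs with the default kernel_dim and padding.
Returns:

tuple[torch.Tensor, torch.Tensor]: the mean and log-variance, each of shape (batch, z_dim)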

class RAEDecoder(shimmer.modules.vae.VAEDecoder):
 92class RAEDecoder(VAEDecoder):
 93    def __init__(
 94        self,
 95        num_channels: int,
 96        z_dim: int,
 97        ae_dim: int = 1028,
 98        kernel_dim: int = 4,
 99        padding: int = 1,
100        use_batchnorm: bool = True,
101    ):
102        super().__init__()
103
104        self.num_channels = num_channels
105        self.dims = [ae_dim // (2**i) for i in reversed(range(3))]
106        self.dims[-1] = ae_dim
107
108        self.kernel_dim = kernel_dim
109        self.padding = padding
110        self.use_batchnorm = use_batchnorm
111
112        self.layers = nn.Sequential(
113            nn.ConvTranspose2d(
114                z_dim,
115                self.dims[2],
116                kernel_size=8,
117                stride=1,
118                bias=not self.use_batchnorm,
119            ),
120            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
121            nn.ReLU(),
122            nn.ConvTranspose2d(
123                self.dims[2],
124                self.dims[1],
125                kernel_size=self.kernel_dim,
126                stride=2,
127                padding=self.padding,
128                bias=not self.use_batchnorm,
129            ),
130            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
131            nn.ReLU(),
132            nn.ConvTranspose2d(
133                self.dims[1],
134                self.dims[0],
135                kernel_size=self.kernel_dim,
136                stride=2,
137                padding=self.padding,
138                bias=not self.use_batchnorm,
139            ),
140            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
141            nn.ReLU(),
142        )
143
144        self.out_layer = nn.Sequential(
145            nn.ZeroPad2d((0, 1, 0, 1)),
146            nn.Conv2d(
147                self.dims[0],
148                self.num_channels,
149                kernel_size=self.kernel_dim,
150                stride=1,
151                padding=self.padding,
152            ),
153            nn.Sigmoid(),
154        )
155
156    def forward(self, z: torch.Tensor) -> torch.Tensor:  # type: ignore
157        return self.out_layer(self.layers(z[:, :, None, None]))

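Convolutional VAE decoder (a subclass of the shimmer VAEDecoder base class): three ConvTranspose2d upsampling stages and a final Conv2d + Sigmoid output layer that map a latent vector to an image.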

RAEDecoder( num_channels: int, z_dim: int, ae_dim: int = 1028, kernel_dim: int = 4, padding: int = 1, use_batchnorm: bool = True)
 93    def __init__(
 94        self,
 95        num_channels: int,
 96        z_dim: int,
 97        ae_dim: int = 1028,
 98        kernel_dim: int = 4,
 99        padding: int = 1,
100        use_batchnorm: bool = True,
101    ):
102        super().__init__()
103
104        self.num_channels = num_channels
105        self.dims = [ae_dim // (2**i) for i in reversed(range(3))]
106        self.dims[-1] = ae_dim
107
108        self.kernel_dim = kernel_dim
109        self.padding = padding
110        self.use_batchnorm = use_batchnorm
111
112        self.layers = nn.Sequential(
113            nn.ConvTranspose2d(
114                z_dim,
115                self.dims[2],
116                kernel_size=8,
117                stride=1,
118                bias=not self.use_batchnorm,
119            ),
120            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
121            nn.ReLU(),
122            nn.ConvTranspose2d(
123                self.dims[2],
124                self.dims[1],
125                kernel_size=self.kernel_dim,
126                stride=2,
127                padding=self.padding,
128                bias=not self.use_batchnorm,
129            ),
130            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
131            nn.ReLU(),
132            nn.ConvTranspose2d(
133                self.dims[1],
134                self.dims[0],
135                kernel_size=self.kernel_dim,
136                stride=2,
137                padding=self.padding,
138                bias=not self.use_batchnorm,
139            ),
140            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
141            nn.ReLU(),
142        )
143
144        self.out_layer = nn.Sequential(
145            nn.ZeroPad2d((0, 1, 0, 1)),
146            nn.Conv2d(
147                self.dims[0],
148                self.num_channels,
149                kernel_size=self.kernel_dim,
150                stride=1,
151                padding=self.padding,
152            ),
153            nn.Sigmoid(),
154        )

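Build the decoder in three parts:
- A ConvTranspose2d (kernel 8, stride 1) expands the 1 × 1 latent map to 8 × 8 with ae_dim channels.
- Two ConvTranspose2d upsampling stages (kernel_dim, stride 2, padding) reduce it to ae_dim // 2 and then ae_dim // 4 channels. Each transposed convolution is followed by BatchNorm2d (or Identity) and ReLU.
- The output layer applies ZeroPad2d((0, 1, 0, 1)), a Conv2d to num_channels and a Sigmoid.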

num_channels
dims
kernel_dim
padding
use_batchnorm
layers
out_layer
def forward(self, z: torch.Tensor) -> torch.Tensor:
156    def forward(self, z: torch.Tensor) -> torch.Tensor:  # type: ignore
157        return self.out_layer(self.layers(z[:, :, None, None]))

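Decode latent vectors into reconstructed images.

Arguments:
  • z (torch.Tensor): latent vectors of shape (batch, z_dim)
Returns:

torch.Tensor: the reconstruction, of shape (batch, num_channels, 32, 32) with the default kernel_dim and padding, with values in (0, 1) from the final sigmoid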

def dim_exploration_figure( vae: shimmer.modules.vae.VAE, z_size: int, device: torch.device, ax_from_tensors: Callable[[typing.Any, int, int], PIL.Image.Image], num_samples: int = 5, range_start: int = -6, range_end: int = 6, image_size: int = 32, plot_dims: Sequence[int] | None = None, fig_dim: int = 5) -> matplotlib.figure.Figure:
def dim_exploration_figure(
    vae: VAE,
    z_size: int,
    device: torch.device,
    ax_from_tensors: Callable[[Any, int, int], Image],
    num_samples: int = 5,
    range_start: int = -6,
    range_end: int = 6,
    image_size: int = 32,
    plot_dims: Sequence[int] | None = None,
    fig_dim: int = 5,
) -> Figure:
    """Plot decoder traversals for every pair of latent dimensions.

    For each unordered pair ``(dim_i, dim_j)`` of distinct dimensions, a
    ``num_samples x num_samples`` grid of latent vectors is built. All
    coordinates are 0 except ``dim_i`` (varying along rows) and ``dim_j``
    (varying along columns). Both are swept over
    ``np.linspace(range_start, range_end, num_samples)``. The grid is decoded
    with ``vae.decoder``, turned into an image by ``ax_from_tensors`` and
    drawn in the lower triangle of a square subplot grid.

    The decoder is run under ``torch.no_grad()``. This function does not
    change the train/eval mode of ``vae``.

    Args:
        vae: model whose ``decoder`` is applied to ``(N, z_size)`` latents.
        z_size: latent dimension.
        device: device on which the latents are created.
        ax_from_tensors: builds the image shown with ``ax.imshow`` from
            ``(decoded, image_size, num_samples)``.
        num_samples: number of points per axis of each traversal grid.
        range_start: first value of the sweep (inclusive).
        range_end: last value of the sweep (inclusive).
        image_size: side in pixels of one decoded image (used for ticks).
        plot_dims: dimensions to explore; all ``z_size`` dimensions when
            None or empty.
        fig_dim: the figure side is ``(len(plot_dims) - 1) * fig_dim`` inches.

    Returns:
        The matplotlib figure.
    """
    # `plot_dims or ...` raises on numpy arrays (ambiguous truth value); test
    # explicitly, keeping the "empty -> every dimension" fallback.
    if plot_dims is None or len(plot_dims) == 0:
        possible_dims = list(range(z_size))
    else:
        possible_dims = [int(d) for d in plot_dims]
    n_dims = len(possible_dims)

    # A single dimension yields no pairs; avoid a zero-sized figure anyway.
    fig_size = max(n_dims - 1, 1) * fig_dim

    # NOTE(review): dpi=1 makes the on-figure canvas tiny; presumably callers
    # re-render at a higher dpi when saving/logging -- confirm.
    fig = plt.figure(constrained_layout=True, figsize=(fig_size, fig_size), dpi=1)
    gs = GridSpec(n_dims, n_dims, figure=fig)

    # One sweep shared by the latent values and the tick labels. linspace
    # also handles num_samples == 1, where (num_samples - 1) would be 0.
    steps = np.linspace(range_start, range_end, num_samples)
    step_values = torch.as_tensor(
        steps, dtype=torch.get_default_dtype(), device=device
    )
    # Assumes ax_from_tensors tiles image_size-pixel images without padding;
    # ticks sit at the centre of each tile.
    tick_positions = image_size * np.arange(num_samples) + image_size // 2
    tick_labels = [f"{s:.1f}" for s in steps]

    done_dims: set[frozenset[int]] = set()
    for i, dim_i in enumerate(possible_dims):
        for j, dim_j in enumerate(possible_dims):
            pair = frozenset((dim_i, dim_j))
            if dim_i == dim_j or pair in done_dims:
                continue
            done_dims.add(pair)

            ax = fig.add_subplot(gs[j, i])

            # Allocate real storage. An `.expand(...)` view has stride 0, so
            # in-place writes into it alias every row. On CPU, `.to(device)`
            # returns that same view and the writes fail.
            z = torch.zeros(num_samples, num_samples, z_size, device=device)
            z[:, :, dim_i] = step_values[:, None]  # varies along rows
            z[:, :, dim_j] = step_values[None, :]  # varies along columns

            with torch.no_grad():
                decoded_x = vae.decoder(z.reshape(-1, z_size))

            img_grid = ax_from_tensors(decoded_x, image_size, num_samples)

            ax.imshow(img_grid)
            ax.set_xlabel(f"dim {dim_j}")
            ax.set_ylabel(f"dim {dim_i}")
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(tick_labels)
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(tick_labels)

    return fig