shimmer_ssd.modules.vae
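Convolutional encoder and decoder modules for shimmer's VAE (shimmer.modules.vae.VAE), plus a helper that renders a figure exploring the latent dimensions of a trained model.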
from collections.abc import Callable, Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from PIL.Image import Image
from shimmer.modules.vae import VAE, VAEDecoder, VAEEncoder
from torch import nn


class RAEEncoder(VAEEncoder):
    def __init__(
        self,
        num_channels: int,
        ae_dim: int = 1028,
        z_dim: int = 16,
        kernel_dim: int = 4,
        padding: int = 1,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        self.dims = [
            ae_dim // (2**i) for i in reversed(range(4))
        ]  # [ae_dim // 8, ae_dim // 4, ae_dim // 2, ae_dim]

        self.kernel_dim = kernel_dim
        self.padding = padding
        self.dims[-1] = ae_dim
        self.use_batchnorm = use_batchnorm

        self.out_dim = self.dims[3] * 2 * 2
        self.z_dim = z_dim

        self.layers = nn.Sequential(
            nn.Conv2d(
                num_channels,
                self.dims[0],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
            nn.Conv2d(
                self.dims[0],
                self.dims[1],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
            nn.Conv2d(
                self.dims[1],
                self.dims[2],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
            nn.Conv2d(
                self.dims[2],
                self.dims[3],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[3]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
        )

        self.q_mean = nn.Linear(self.out_dim, self.z_dim)
        self.q_logvar = nn.Linear(self.out_dim, self.z_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out = self.layers(x).view(x.size(0), -1)
        out = out.view(out.size(0), -1)

        return self.q_mean(out), self.q_logvar(out)


class RAEDecoder(VAEDecoder):
    def __init__(
        self,
        num_channels: int,
        z_dim: int,
        ae_dim: int = 1028,
        kernel_dim: int = 4,
        padding: int = 1,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        self.num_channels = num_channels
        self.dims = [ae_dim // (2**i) for i in reversed(range(3))]
        self.dims[-1] = ae_dim

        self.kernel_dim = kernel_dim
        self.padding = padding
        self.use_batchnorm = use_batchnorm

        self.layers = nn.Sequential(
            nn.ConvTranspose2d(
                z_dim,
                self.dims[2],
                kernel_size=8,
                stride=1,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[2]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
            nn.ConvTranspose2d(
                self.dims[2],
                self.dims[1],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[1]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
            nn.ConvTranspose2d(
                self.dims[1],
                self.dims[0],
                kernel_size=self.kernel_dim,
                stride=2,
                padding=self.padding,
                bias=not self.use_batchnorm,
            ),
            nn.BatchNorm2d(self.dims[0]) if self.use_batchnorm else nn.Identity(),
            nn.ReLU(),
        )

        self.out_layer = nn.Sequential(
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.Conv2d(
                self.dims[0],
                self.num_channels,
                kernel_size=self.kernel_dim,
                stride=1,
                padding=self.padding,
            ),
            nn.Sigmoid(),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:  # type: ignore
        return self.out_layer(self.layers(z[:, :, None, None]))


def dim_exploration_figure(
    vae: VAE,
    z_size: int,
    device: torch.device,
    ax_from_tensors: Callable[[Any, int, int], Image],
    num_samples: int = 5,
    range_start: int = -6,
    range_end: int = 6,
    image_size: int = 32,
    plot_dims: Sequence[int] | None = None,
    fig_dim: int = 5,
) -> Figure:
    possible_dims = plot_dims or np.arange(z_size)

    fig_size = (len(possible_dims) - 1) * fig_dim

    fig = plt.figure(constrained_layout=True, figsize=(fig_size, fig_size), dpi=1)

    gs = GridSpec(len(possible_dims), len(possible_dims), figure=fig)
    done_dims: list[set[int]] = []

    for i, dim_i in enumerate(possible_dims):
        for j, dim_j in enumerate(possible_dims):
            # Plot each unordered pair of distinct dimensions once.
            if dim_i == dim_j or {dim_i, dim_j} in done_dims:
                continue

            done_dims.append({dim_i, dim_j})

            ax = fig.add_subplot(gs[j, i])

            # Grid of latents, all dimensions at 0 except dim_i and dim_j.
            # repeat (rather than expand) so every grid cell owns its memory
            # and the in-place writes below are valid.
            z = (
                torch.zeros(z_size)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(num_samples, num_samples, 1)
                .to(device)
            )

            # Sweep dim_i along rows and dim_j along columns of the grid.
            for p in range(num_samples):
                step = range_start + (range_end - range_start) * float(p) / float(
                    num_samples - 1
                )
                z[p, :, dim_i] = step
            for q in range(num_samples):
                step = range_start + (range_end - range_start) * float(q) / float(
                    num_samples - 1
                )
                z[:, q, dim_j] = step

            decoded_x = vae.decoder(z.reshape(-1, z_size))

            img_grid = ax_from_tensors(decoded_x, image_size, num_samples)

            ax.imshow(img_grid)
            ax.set_xlabel(f"dim {dim_j}")
            ax.set_ylabel(f"dim {dim_i}")
            ax.set_xticks(image_size * np.arange(num_samples) + image_size // 2)
            ax.set_xticklabels(
                [f"{x:.1f}" for x in np.linspace(range_start, range_end, num_samples)]
            )
            ax.set_yticks(image_size * np.arange(num_samples) + image_size // 2)
            ax.set_yticklabels(
                [f"{x:.1f}" for x in np.linspace(range_start, range_end, num_samples)]
            )

    return fig
class RAEEncoder(shimmer.modules.vae.VAEEncoder):
Convolutional VAE encoder. Four stride-2 convolution blocks (optionally with batch normalization) downsample the input image; two linear heads then map the flattened features to the mean and log-variance of the latent distribution.
RAEEncoder(
    num_channels: int,
    ae_dim: int = 1028,
    z_dim: int = 16,
    kernel_dim: int = 4,
    padding: int = 1,
    use_batchnorm: bool = True,
)
Initializes the encoder.

Arguments:
- num_channels (int): number of channels of the input images
- ae_dim (int): number of channels of the deepest convolution block; the earlier blocks use ae_dim // 8, ae_dim // 4 and ae_dim // 2
- z_dim (int): dimension of the latent space
- kernel_dim (int): kernel size of every convolution
- padding (int): padding of every convolution
- use_batchnorm (bool): whether each convolution is followed by BatchNorm2d (in which case the convolutions carry no bias)
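Note that out_dim = dims[3] * 2 * 2 implicitly assumes 32×32 inputs: with the default kernel_dim=4, stride 2 and padding=1, every block halves the spatial resolution (32 → 16 → 8 → 4 → 2), so the linear heads receive a flattened 2 × 2 × ae_dim feature map.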
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
Encode an input batch with the VAE encoder.

Arguments:
- x (torch.Tensor): batch of input images

Returns:
- tuple[torch.Tensor, torch.Tensor]: the mean and log-variance of the latent distribution
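A minimal usage sketch (shapes assume 32×32 inputs; the reparameterization step is the standard VAE sampling trick, which the encoder itself does not perform):

    import torch

    from shimmer_ssd.modules.vae import RAEEncoder

    encoder = RAEEncoder(num_channels=3, z_dim=16)
    x = torch.randn(8, 3, 32, 32)  # hypothetical image batch
    mean, logvar = encoder(x)  # both of shape (8, 16)
    # Sample a latent with the reparameterization trick.
    z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)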
class RAEDecoder(shimmer.modules.vae.VAEDecoder):
Convolutional VAE decoder. A transposed convolution projects the latent vector to an 8×8 feature map, two stride-2 transposed convolutions upsample it, and a final convolution with a sigmoid activation produces the reconstructed image.
RAEDecoder(
    num_channels: int,
    z_dim: int,
    ae_dim: int = 1028,
    kernel_dim: int = 4,
    padding: int = 1,
    use_batchnorm: bool = True,
)
Initializes the decoder.

Arguments:
- num_channels (int): number of channels of the reconstructed images
- z_dim (int): dimension of the latent space
- ae_dim (int): number of channels of the first (deepest) transposed-convolution block; the later blocks use ae_dim // 2 and ae_dim // 4
- kernel_dim (int): kernel size of the stride-2 transposed convolutions and of the output convolution
- padding (int): padding of those convolutions
- use_batchnorm (bool): whether each transposed convolution is followed by BatchNorm2d (in which case it carries no bias)
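The decoder mirrors the encoder's geometry and emits 32×32 images: the first transposed convolution (kernel 8, stride 1) maps the 1×1 latent to an 8×8 feature map, the two stride-2 transposed convolutions upsample it to 16×16 and 32×32, and the output stage zero-pads to 33×33 before a stride-1 convolution (kernel 4, padding 1) brings it back to 32×32 with num_channels channels.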
def forward(self, z: torch.Tensor) -> torch.Tensor:
Decode a latent representation with the VAE decoder.

Arguments:
- z (torch.Tensor): batch of VAE latent representations

Returns:
- torch.Tensor: the reconstructed images
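And the matching decode sketch (hypothetical shapes):

    import torch

    from shimmer_ssd.modules.vae import RAEDecoder

    decoder = RAEDecoder(num_channels=3, z_dim=16)
    z = torch.randn(8, 16)  # hypothetical latent batch
    x_hat = decoder(z)  # (8, 3, 32, 32), values in (0, 1) thanks to the sigmoid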
def dim_exploration_figure(
    vae: shimmer.modules.vae.VAE,
    z_size: int,
    device: torch.device,
    ax_from_tensors: Callable[[typing.Any, int, int], PIL.Image.Image],
    num_samples: int = 5,
    range_start: int = -6,
    range_end: int = 6,
    image_size: int = 32,
    plot_dims: Sequence[int] | None = None,
    fig_dim: int = 5,
) -> matplotlib.figure.Figure:
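For every pair of latent dimensions, the function decodes a num_samples × num_samples grid of latents in which the two dimensions sweep linearly over [range_start, range_end] while all other dimensions are held at zero, and assembles the resulting grids into one matplotlib figure. The ax_from_tensors callback is responsible for turning the batch of decoded tensors into a single image for one grid cell.

A hedged usage sketch (my_ax_from_tensors is a hypothetical helper, not part of this module; it assumes torchvision is installed, and vae stands for a trained shimmer VAE):

    import torch
    from torchvision.transforms.functional import to_pil_image
    from torchvision.utils import make_grid

    from shimmer_ssd.modules.vae import dim_exploration_figure


    def my_ax_from_tensors(decoded, image_size, num_samples):
        # Tile the (num_samples**2, C, H, W) batch into one image,
        # num_samples tiles per row, with no padding between tiles.
        grid = make_grid(decoded.detach().cpu(), nrow=num_samples, padding=0)
        return to_pil_image(grid)


    fig = dim_exploration_figure(
        vae,  # trained shimmer VAE, assumed available
        z_size=16,
        device=torch.device("cpu"),
        ax_from_tensors=my_ax_from_tensors,
        plot_dims=[0, 1, 2],  # restrict the figure to a few dimensions
    )
    fig.savefig("dim_exploration.png")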