shimmer.modules.selection
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch
import torch.nn as nn

from shimmer.types import LatentsDomainGroupT
from shimmer.utils import group_batch_size, group_device


class SelectionBase(torch.nn.Module, ABC):
    """
    This is the base class for the selection mechanism.
    The selection mechanisms handles the "competition" between modules and *selects*
    fusion coefficients for the domains.
    """

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        """
        Update the internal copy of the previous GW state.
        By default, this is not implemented and will raise an error if used.

        :note..
            This is not defined as an abstractmethod as some selection method may
            not need it.

        Args:
            gw_state (`torch.Tensor`): the previous GW state
        """
        pass

    @abstractmethod
    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the selection method.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.

        Example:
            >>> SomeSelectionImplementation().forward(
            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
            ... )
            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
        """
        ...

    # This is just for proper auto-completion...
    def __call__(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        return super().__call__(domains, encodings_pre_fusion)


class SingleDomainSelection(SelectionBase):
    """
    This selection mechanism handles groups that can have multiple domains, but always
    return a selection of 1 domain from the group with a uniform distribution.

    For example, if the group has 2 domains, there is a 50% chance of selecting each
    domain.
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: whether the domain is selected for each input
            in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
        for k, domain in enumerate(domains.keys()):
            selection[domain] = (choice == k).to(torch.float32)
        return selection


class FixedSharedSelection(SelectionBase):
    """
    This selection mechanism is deterministic and always shares the weights equally
    between domains.

    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: whether the domain is selected for each input
            in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
        for domain in domains:
            selection[domain] = coef.clone()
        return selection


def _calculate_attention_dict(
    domains: LatentsDomainGroupT,
    keys: dict[str, torch.Tensor],
    query: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Args:
        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
        keys (`dict[str, torch.Tensor]`): The keys for each domain.
        query (`torch.Tensor`): The query tensor.

    Returns:
        `dict[str, torch.Tensor]`: The attention scores for each domain.
    """
    dot_products = {
        domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
        for domain, key in keys.items()
    }

    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

    attention_scores = torch.softmax(dot_products_tensor, dim=1)

    attention_dict = {
        domain: attention_scores[:, i] for i, domain in enumerate(domains)
    }
    return attention_dict


class RandomSelection(SelectionBase):
    """
    Modified random attention to only utilize uniform-softmax scores across modalities.
    This version omits the binary scaling factors and focuses on generating attention
    coefficients using a uniform distribution followed by a domain-wise softmax.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to uniform
                scaling factors.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Generate uniform-then-domain-wise-softmaxed samples for each domain.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
                This is not used in the function directly but determines the structure
                of the returned attention coefficients.

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on
            uniform-softmax scores.
        """
        num_domains = len(domains)
        batch_size = group_batch_size(domains)
        device = group_device(domains)

        # Generate uniform scores
        uniform_scores = torch.rand(batch_size, num_domains, device=device)

        # Apply softmax across domains with temperature scaling
        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
        # Create attention dictionary for each domain
        attention_dict = {
            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
        }

        return attention_dict


class DynamicQueryAttention(SelectionBase):
    """
    Key-Query attention with a dynamic gw vector.
    The query is updated based on the scaled gw vector.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`) : dimension of the key and query vectors.
            domain_dim (`int`) : dimension of the input dims (assumed to be the same
                for now)
            domain_names (`Iterable[str]`) : list of input domains
            n_steps (`int`) : number of steps to update the query vector
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
        )
        self.n_steps = n_steps
        self.step_limit = n_steps  # Default step limit is n_steps
        # Start with a random gw state
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Sets the step limit for the dynamic attention update loop.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.
        """
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the weighted encodings using the attention scores.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representation
            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
                domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Apply attention scores to the encodings
        weighted_encodings = {}
        for key in attention_dict:
            if key in encodings:
                # Perform element-wise multiplication
                weighted_encodings[key] = (
                    attention_dict[key].unsqueeze(1) * encodings[key]
                )

        # Stack the tensors along a new dimension (dimension 0)
        stacked_tensors = torch.stack(list(weighted_encodings.values()))

        # Apply fusion by summing along the newly created dimension
        summed_tensor = torch.sum(stacked_tensors, dim=0)
        return summed_tensor

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries, match them with dot product and softmax.
        Does this twice, once with the static query and once with a dynamic query.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in the
            group.
        """

        keys = {
            domain: self.key_layers[domain](encoding)
            for domain, encoding in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Retrieve random query
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

        # Calculate the attention scores
        attention_dict = _calculate_attention_dict(domains, keys, query)

        if self.n_steps > 0:
            # Update the query based on the static attention scores
            for _ in range(min(self.step_limit, self.n_steps)):
                # Apply the attention scores to the encodings
                summed_tensor = self.fuse_weighted_encodings(
                    encodings_pre_fusion, attention_dict
                )

                # Retrieve query (now it is dependent on the new gw state)
                query = self.query_layer(summed_tensor)

                # Calculate the attention scores again
                attention_dict = _calculate_attention_dict(domains, keys, query)

        return attention_dict
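As a quick orientation, the sketch below (not part of the module itself) shows the calling convention shared by every selection mechanism: a group of unimodal latents goes in, and a per-domain coefficient vector of batch size comes out. It assumes a plain dict[str, torch.Tensor] is a valid LatentsDomainGroupT; the domain names "v" and "t" and the tensor sizes are arbitrary.

import torch

from shimmer.modules.selection import FixedSharedSelection

# Two domains, batch of 4, 12-dimensional latents (illustrative sizes).
domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}

selection = FixedSharedSelection()
coefficients = selection(domains, domains)  # encodings_pre_fusion is unused here
# coefficients["v"] and coefficients["t"] are both tensors of shape (4,) filled with 0.5.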
class SelectionBase(torch.nn.Module, ABC):
This is the base class for selection mechanisms. A selection mechanism handles the "competition" between modules and selects fusion coefficients for the domains.
def update_gw_state(self, gw_state: torch.Tensor) -> None:
Update the internal copy of the previous GW state. The default implementation is a no-op; subclasses that depend on the previous GW state should override it.
Note: this is not defined as an abstract method, since some selection methods may not need it.
Arguments:
- gw_state (`torch.Tensor`): the previous GW state
@abstractmethod
def forward(
    self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
Forward pass of the selection method.
Arguments:
- domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
- encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent representation.
Returns:
`dict[str, torch.Tensor]`: for each domain in the group, the fusion coefficient for each item in the batch.
Example:
>>> SomeSelectionImplementation().forward(
...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
... )
{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
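To make the contract concrete, here is a minimal, hypothetical subclass sketch (the class name and behaviour are illustrative, not part of shimmer): it stores the previous GW state via update_gw_state and returns equal coefficients for every domain, mirroring the pattern used by the built-in mechanisms.

import torch

from shimmer.modules.selection import SelectionBase
from shimmer.types import LatentsDomainGroupT
from shimmer.utils import group_batch_size, group_device


class ConstantSelection(SelectionBase):
    """Toy selection mechanism: equal coefficients, remembers the last GW state."""

    def __init__(self) -> None:
        super().__init__()
        self.previous_gw_state: torch.Tensor | None = None

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        # Keep a copy of the previous GW state (unused by this toy example).
        self.previous_gw_state = gw_state

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        # One coefficient per batch item, identical for every domain.
        bs = group_batch_size(domains)
        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
        return {domain: coef.clone() for domain in domains}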
class SingleDomainSelection(SelectionBase):
This selection mechanism handles groups that can have multiple domains, but always returns a selection of exactly one domain from the group, chosen uniformly at random.
For example, if the group has 2 domains, there is a 50% chance of selecting each domain.
def forward(
    self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
Forward pass of the module.
Arguments:
- domains (`LatentsDomainGroupT`): input unimodal latent representations
- encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent representation.
Returns:
`dict[str, torch.Tensor]`: whether the domain is selected for each input in the batch.
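A small usage sketch (domain names and tensor sizes are arbitrary assumptions): for each batch item, exactly one domain receives a coefficient of 1.0 and the others 0.0.

import torch

from shimmer.modules.selection import SingleDomainSelection

selection = SingleDomainSelection()
domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}

coefficients = selection(domains, domains)  # encodings_pre_fusion is ignored here
# coefficients["v"] + coefficients["t"] is a tensor of ones: one domain is picked per item.
print(coefficients["v"], coefficients["t"])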
class RandomSelection(SelectionBase):
Random attention that relies only on uniform-softmax scores across modalities. This version omits the binary scaling factors and generates attention coefficients by drawing from a uniform distribution and applying a domain-wise softmax.
def __init__(self, temperature: float):
Arguments:
- temperature (`float`): Temperature of the softmax applied to uniform scaling factors.
def forward(
    self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
Generate uniform random scores and softmax them domain-wise to obtain a coefficient for each domain.
Arguments:
- domains (`LatentsDomainGroupT`): Group of unimodal latent representations. This is not used in the function directly but determines the structure of the returned attention coefficients.
Returns:
`dict[str, torch.Tensor]`: For each domain in the group, the fusion coefficient for each item in the batch, based solely on uniform-softmax scores.
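A usage sketch (sizes and the temperature value are illustrative assumptions): lower temperatures push the coefficients toward a near one-hot selection, higher temperatures toward an even split across domains.

import torch

from shimmer.modules.selection import RandomSelection

selection = RandomSelection(temperature=0.1)
domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}

coefficients = selection(domains, domains)
# Each value has shape (4,); for every batch item the coefficients sum to 1 across domains.
total = coefficients["v"] + coefficients["t"]  # approximately a tensor of ones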
class DynamicQueryAttention(SelectionBase):
Key-query attention with a dynamic GW vector. The query is updated based on the scaled GW vector.
def __init__(
    self,
    head_size: int,
    domain_dim: int,
    domain_names: Iterable[str],
    n_steps: int = 1,
):
Arguments:
- head_size (`int`): dimension of the key and query vectors.
- domain_dim (`int`): dimension of the input domain latents (assumed to be the same for all domains for now)
- domain_names (`Iterable[str]`): list of input domains
- n_steps (`int`): number of steps used to update the query vector
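For example (the dimensions and domain names are illustrative assumptions), a two-domain instance could be built as:

from shimmer.modules.selection import DynamicQueryAttention

attention = DynamicQueryAttention(
    head_size=8,               # key/query dimension
    domain_dim=12,             # dimension of every domain's latent vector
    domain_names=["v", "t"],
    n_steps=2,                 # number of dynamic query updates
)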
def set_step_limit(self, step_limit: int):
Sets the step limit for the dynamic attention update loop.
Arguments:
- step_limit (`int`): Maximum number of steps to run the loop.
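Continuing the same illustrative setup, the update loop can be shortened after construction, but never extended beyond the `n_steps` given at construction:

from shimmer.modules.selection import DynamicQueryAttention

attention = DynamicQueryAttention(head_size=8, domain_dim=12, domain_names=["v", "t"], n_steps=2)
attention.set_step_limit(1)      # run a single dynamic update instead of two
# attention.set_step_limit(5)    # would raise ValueError: the limit cannot exceed n_steps (2)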
def fuse_weighted_encodings(
    self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
) -> torch.Tensor:
Fuse the weighted encodings using the attention scores.
Arguments:
- encodings (`LatentsDomainGroupT`): Unimodal latent representation
- attention_dict (`dict[str, torch.Tensor]`): The attention scores for each domain in the group.
Returns:
`torch.Tensor`: The fused tensor.
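In effect this computes a weighted sum of the encodings. A hand-rolled equivalent for two domains with matching shapes (the values below are hypothetical stand-ins, for illustration only) would be:

import torch

encodings = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}
attention = {"v": torch.full((4,), 0.3), "t": torch.full((4,), 0.7)}

fused = (
    attention["v"].unsqueeze(1) * encodings["v"]
    + attention["t"].unsqueeze(1) * encodings["t"]
)  # shape (4, 12); same result as fuse_weighted_encodings(encodings, attention)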
def forward(
    self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
Compute keys and a query, match them with a dot product and softmax. This is done first with the static initial query, then repeated up to `step_limit` times with a dynamic query derived from the fused encodings.
Arguments:
- domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
- encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent representation.
Returns:
`dict[str, torch.Tensor]`: the attention scores for each domain in the group.
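Putting the pieces together, an end-to-end sketch (the sizes, the domain names, and the reuse of the same group for both arguments are illustrative assumptions; in a full global workspace setup the pre-fusion encodings would come from the GW encoders):

import torch

from shimmer.modules.selection import DynamicQueryAttention

attention = DynamicQueryAttention(
    head_size=8, domain_dim=12, domain_names=["v", "t"], n_steps=2
)

domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}

# Here the pre-fusion encodings are taken to be the same group, for simplicity.
coefficients = attention(domains, domains)
# coefficients["v"] and coefficients["t"] have shape (4,) and sum to 1 per batch item.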