shimmer.modules.selection
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from typing import cast

import torch
import torch.nn as nn

from shimmer.types import LatentsDomainGroupT
from shimmer.utils import group_batch_size, group_device


class SelectionBase(torch.nn.Module, ABC):
    """
    Base class for the selection mechanism.

    A selection mechanism handles the "competition" between modules and
    *selects* fusion coefficients for the domains.
    """

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        """
        Update the internal copy of the previous GW state.

        By default this is a no-op: subclasses that condition their selection
        on the previous GW state override it.

        .. note::
            This is not defined as an abstractmethod as some selection methods
            may not need it.

        Args:
            gw_state (`torch.Tensor`): the previous GW state
        """
        pass

    @abstractmethod
    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the selection method.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.

        Example:
            >>> SomeSelectionImplementation().forward(
            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
            ... )
            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
        """
        ...

    # This is just for proper auto-completion...
    def __call__(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        return super().__call__(domains, encodings_pre_fusion)


class SingleDomainSelection(SelectionBase):
    """
    This selection mechanism handles groups that can have multiple domains, but
    always returns a selection of 1 domain from the group with a uniform
    distribution.

    For example, if the group has 2 domains, there is a 50% chance of selecting
    each domain.
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation (unused by this mechanism).

        Returns:
            `dict[str, torch.Tensor]`: whether the domain is selected (1.0 / 0.0)
            for each input in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        # One uniformly-drawn winner index per batch item.
        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
        for k, domain in enumerate(domains.keys()):
            selection[domain] = (choice == k).to(torch.float32)
        return selection


class FixedSharedSelection(SelectionBase):
    """
    This selection mechanism is deterministic and always shares the weights
    equally between domains.

    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation (unused by this mechanism).

        Returns:
            `dict[str, torch.Tensor]`: the constant coefficient (1 / n_domains)
            for each input in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
        for domain in domains:
            # Clone so downstream in-place ops on one domain's coefficients
            # cannot alias another domain's.
            selection[domain] = coef.clone()
        return selection


def _calculate_attention_dict(
    domains: LatentsDomainGroupT,
    keys: dict[str, torch.Tensor],
    query: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Compute softmax attention scores of a query against per-domain keys.

    Args:
        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
        keys (`dict[str, torch.Tensor]`): The keys for each domain, shape (B, H).
        query (`torch.Tensor`): The query tensor, shape (B, H).

    Returns:
        `dict[str, torch.Tensor]`: The attention scores for each domain, each of
        shape (B,), summing to 1 across domains.
    """
    # Batched dot product key·query -> (B,). Using an elementwise product and
    # sum (instead of bmm(...).squeeze()) keeps the batch dimension intact even
    # when the batch size is 1; a bare .squeeze() would collapse it to a 0-d
    # tensor and break the stack below.
    dot_products = {
        domain: (key * query).sum(dim=1) for domain, key in keys.items()
    }

    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

    attention_scores = torch.softmax(dot_products_tensor, dim=1)

    attention_dict = {
        domain: attention_scores[:, i] for i, domain in enumerate(domains)
    }
    return attention_dict


class LearnedAttention(SelectionBase):
    """
    Content-based single-step attention over GW latents with configurable toggles.

    Design:
    - Query is the mean of available GW latents (content-q0 seed)
    - Single-step dot-product attention over domains (no refinement loop)
    - Optional per-domain keys

    Toggles:
    - per_domain_keys: use per-domain key projections instead of a shared one
    - stopgrad: detach GW latents before computing keys/query
    - key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains
    - domain_dims: required when key_on_prefusion=False to size per-domain key layers
    """

    def __init__(
        self,
        gw_dim: int,
        domain_names: Iterable[str],
        head_size: int = 64,
        per_domain_keys: bool = False,
        stopgrad: bool = True,
        key_on_prefusion: bool = True,
        domain_dims: Mapping[str, int] | None = None,
    ):
        """
        Build the query projection and (shared or per-domain) key projections
        according to the configured toggles.

        Args:
            gw_dim (`int`): dimension of GW latents.
            domain_names (`Iterable[str]`): names of the domains to attend over.
            head_size (`int`): dimension of key/query vectors.
            per_domain_keys (`bool`): use one key projection per domain.
            stopgrad (`bool`): detach latents before key/query projection.
            key_on_prefusion (`bool`): compute keys on pre-fusion latents.
            domain_dims (`Mapping[str, int] | None`): per-domain input dims,
                required when ``key_on_prefusion=False``.

        Raises:
            ValueError: if the toggle combination is inconsistent (see messages).
        """
        super().__init__()
        self.gw_dim = int(gw_dim)
        self.head_size = int(head_size)
        self.domain_names = list(domain_names)

        # Toggles
        self.per_domain_keys = bool(per_domain_keys)
        self.stopgrad = bool(stopgrad)
        self.key_on_prefusion = bool(key_on_prefusion)
        self.domain_dims = dict(domain_dims) if domain_dims is not None else None

        # Projections
        self.query_layer = nn.Linear(self.gw_dim, self.head_size)
        self.per_key_layers: nn.ModuleDict | None
        self.shared_key_layer: nn.Linear | None
        if self.key_on_prefusion:
            if self.per_domain_keys:
                self.per_key_layers = nn.ModuleDict(
                    {
                        d: nn.Linear(self.gw_dim, self.head_size)
                        for d in self.domain_names
                    }
                )
                self.shared_key_layer = None
            else:
                self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size)
                self.per_key_layers = None
        else:
            # Raw domain latents may have different sizes, so a single shared
            # projection cannot work here.
            if not self.per_domain_keys:
                raise ValueError(
                    "key_on_prefusion=False requires per_domain_keys=True because "
                    "domain latent dimensions can differ."
                )
            if self.domain_dims is None:
                raise ValueError(
                    "key_on_prefusion=False requires domain_dims for key projections."
                )
            missing_dims = [d for d in self.domain_names if d not in self.domain_dims]
            if missing_dims:
                raise ValueError(
                    f"Missing domain_dims for: {', '.join(sorted(missing_dims))}"
                )
            self.per_key_layers = nn.ModuleDict(
                {
                    d: nn.Linear(self.domain_dims[d], self.head_size)
                    for d in self.domain_names
                }
            )
            self.shared_key_layer = None

    @staticmethod
    def _calc_attention(
        keys: dict[str, torch.Tensor],
        query: torch.Tensor,
        order: Iterable[str],
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention over domains.

        Args:
            keys: mapping of domain -> key tensor (B, H)
            query: query tensor (B, H)
            order: iterable of domain names to fix output ordering

        Returns:
            dict[str, torch.Tensor]: per-domain attention scores that sum to 1.

        Raises:
            ValueError: if no name in ``order`` has a key.
        """
        names = [d for d in order if d in keys]
        if not names:
            raise ValueError("LearnedAttention: no keys provided.")

        logits = torch.stack(
            [(keys[d] * query).sum(dim=1) for d in names], dim=1
        )  # (B, D)

        probs = torch.softmax(logits, dim=1)

        return {d: probs[:, i] for i, d in enumerate(names)}

    def forward(
        self,
        domains: LatentsDomainGroupT,
        encodings_pre_fusion: LatentsDomainGroupT | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Single-step dot-product attention over the domains present in the group.

        Args:
            domains: mapping from domain name to GW latent (B, gw_dim)
            encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)

        Returns:
            dict[str, torch.Tensor]: per-domain attention weights.

        Raises:
            ValueError: if no known domain is present, or a required latent is
                missing from the key/query source.
            RuntimeError: if the configured projection layers are missing
                (should not happen when constructed via ``__init__``).
        """
        domain_latents: Mapping[str, torch.Tensor] = domains

        # Keep only configured domains, preserving self.domain_names order.
        present = [d for d in self.domain_names if d in domain_latents]
        if not present:
            raise ValueError(
                "LearnedAttention: no known domains present in gw_latents."
            )

        if self.key_on_prefusion:
            if encodings_pre_fusion is None:
                raise ValueError(
                    "key_on_prefusion=True requires encodings_pre_fusion inputs."
                )
            key_source = encodings_pre_fusion
        else:
            key_source = domain_latents

        missing_keys = [d for d in present if d not in key_source]
        if missing_keys:
            raise ValueError(
                f"Missing key latents for: {', '.join(sorted(missing_keys))}"
            )

        # Query is always seeded from pre-fusion latents when available.
        if encodings_pre_fusion is None:
            query_source = domain_latents
        else:
            query_source = encodings_pre_fusion

        missing_query = [d for d in present if d not in query_source]
        if missing_query:
            raise ValueError(
                f"Missing query latents for: {', '.join(sorted(missing_query))}"
            )

        if self.stopgrad:
            # Detach so attention gradients do not flow into the latents.
            key_latents = {d: key_source[d].detach() for d in present}
            query_latents = {d: query_source[d].detach() for d in present}
        else:
            key_latents = {d: key_source[d] for d in present}
            query_latents = {d: query_source[d] for d in present}

        if self.per_domain_keys:
            if self.per_key_layers is None:
                raise RuntimeError(
                    "per_domain_keys=True but per-domain key layers are missing."
                )
            keys = {
                d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d])
                for d in present
            }
        else:
            if self.shared_key_layer is None:
                raise RuntimeError(
                    "per_domain_keys=False but shared key layer is missing."
                )
            proj = self.shared_key_layer
            keys = {d: proj(key_latents[d]) for d in present}

        # Content-q0 seed: query is projected from the mean of present latents.
        stacked = torch.stack([query_latents[d] for d in present], dim=0)  # (D, B, F)
        query = self.query_layer(stacked.mean(0))  # (B, H)

        return self._calc_attention(
            keys=keys,
            query=query,
            order=self.domain_names,
        )


class RandomSelection(SelectionBase):
    """
    Modified random attention to only utilize uniform-softmax scores across
    modalities. This version omits the binary scaling factors and focuses on
    generating attention coefficients using a uniform distribution followed by a
    domain-wise softmax.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to uniform
                scaling factors. Must be strictly positive.

        Raises:
            ValueError: if ``temperature`` is not strictly positive (a zero or
                negative temperature would produce NaN/inf softmax scores).
        """
        super().__init__()
        if temperature <= 0:
            raise ValueError("temperature must be strictly positive.")
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Generate uniform-then-domain-wise-softmaxed samples for each domain.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
                This is not used in the function directly but determines the structure
                of the returned attention coefficients.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation (unused by this mechanism).

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on
            uniform-softmax scores.
        """
        num_domains = len(domains)
        batch_size = group_batch_size(domains)
        device = group_device(domains)

        # Generate uniform scores
        uniform_scores = torch.rand(batch_size, num_domains, device=device)

        # Apply softmax across domains with temperature scaling
        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
        # Create attention dictionary for each domain
        attention_dict = {
            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
        }

        return attention_dict


class DynamicQueryAttention(SelectionBase):
    """
    Key-Query attention with a dynamic gw vector.
    The query is updated based on the scaled gw vector.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`) : dimension of the key and query vectors.
            domain_dim (`int`) : dimension of the input dims (assumed to be the same
                for now)
            domain_names (`Iterable[str]`) : list of input domains
            n_steps (`int`) : number of steps to update the query vector
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
        )
        self.n_steps = n_steps
        self.step_limit = n_steps  # Default step limit is n_steps
        # Start with a random gw state (registered as a buffer so it moves with
        # the module's device and is saved in the state dict).
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Sets the step limit for the dynamic attention update loop.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.

        Raises:
            ValueError: if ``step_limit`` is negative or exceeds ``n_steps``.
        """
        if step_limit < 0:
            raise ValueError("Step limit cannot be negative.")
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the weighted encodings using the attention scores.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representation
            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
                domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Apply attention scores to the encodings
        weighted_encodings = {}
        for key in attention_dict:
            if key in encodings:
                # Perform element-wise multiplication
                weighted_encodings[key] = (
                    attention_dict[key].unsqueeze(1) * encodings[key]
                )

        # Stack the tensors along a new dimension (dimension 0)
        stacked_tensors = torch.stack(list(weighted_encodings.values()))

        # Apply fusion by summing along the newly created dimension
        summed_tensor = torch.sum(stacked_tensors, dim=0)
        return summed_tensor

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries, match them with dot product and softmax.
        Does this twice, once with the static query and once with a dynamic query.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in the
            group.
        """

        keys = {
            domain: self.key_layers[domain](encoding)
            for domain, encoding in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Retrieve random query
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

        # Calculate the attention scores
        attention_dict = _calculate_attention_dict(domains, keys, query)

        if self.n_steps > 0:
            # Update the query based on the static attention scores
            for _ in range(min(self.step_limit, self.n_steps)):
                # Apply the attention scores to the encodings
                summed_tensor = self.fuse_weighted_encodings(
                    encodings_pre_fusion, attention_dict
                )

                # Retrieve query (now it is dependent on the new gw state)
                query = self.query_layer(summed_tensor)

                # Calculate the attention scores again
                attention_dict = _calculate_attention_dict(domains, keys, query)

        return attention_dict
13class SelectionBase(torch.nn.Module, ABC): 14 """ 15 This is the base class for the selection mechanism. 16 The selection mechanisms handles the "competition" between modules and *selects* 17 fusion coefficients for the domains. 18 """ 19 20 def update_gw_state(self, gw_state: torch.Tensor) -> None: 21 """ 22 Update the internal copy of the previous GW state. 23 By default, this is not implemented and will raise an error if used. 24 25 :note.. 26 This is not defined as an abstractmethod as some selection method may 27 not need it. 28 29 Args: 30 gw_state (`torch.Tensor`): the previous GW state 31 """ 32 pass 33 34 @abstractmethod 35 def forward( 36 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 37 ) -> dict[str, torch.Tensor]: 38 """ 39 Forward pass of the selection method. 40 41 Args: 42 domains (`LatentsDomainGroupT`): Group of unimodal latent representations. 43 encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent 44 representation. 45 46 Returns: 47 `dict[str, torch.Tensor]`: for each domain in the group, the fusion 48 coefficient for each item in the batch. 49 50 Example: 51 >>> SomeSelectionImplementation().forward( 52 ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)} 53 ... ) 54 {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])} 55 """ 56 ... 57 58 # This is just for proper auto-completion... 59 def __call__( 60 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 61 ) -> dict[str, torch.Tensor]: 62 return super().__call__(domains, encodings_pre_fusion)
This is the base class for the selection mechanism. The selection mechanism handles the "competition" between modules and selects fusion coefficients for the domains.
20 def update_gw_state(self, gw_state: torch.Tensor) -> None: 21 """ 22 Update the internal copy of the previous GW state. 23 By default, this is not implemented and will raise an error if used. 24 25 :note.. 26 This is not defined as an abstractmethod as some selection method may 27 not need it. 28 29 Args: 30 gw_state (`torch.Tensor`): the previous GW state 31 """ 32 pass
Update the internal copy of the previous GW state. By default, this is a no-op; subclasses that condition their selection on the previous state override it.
.. note:: This is not defined as an abstractmethod, as some selection methods may not need it.
Arguments:
- gw_state (
torch.Tensor): the previous GW state
34 @abstractmethod 35 def forward( 36 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 37 ) -> dict[str, torch.Tensor]: 38 """ 39 Forward pass of the selection method. 40 41 Args: 42 domains (`LatentsDomainGroupT`): Group of unimodal latent representations. 43 encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent 44 representation. 45 46 Returns: 47 `dict[str, torch.Tensor]`: for each domain in the group, the fusion 48 coefficient for each item in the batch. 49 50 Example: 51 >>> SomeSelectionImplementation().forward( 52 ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)} 53 ... ) 54 {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])} 55 """ 56 ...
Forward pass of the selection method.
Arguments:
- domains (
LatentsDomainGroupT): Group of unimodal latent representations. - encodings_pre_fusion (
LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:
dict[str, torch.Tensor]: for each domain in the group, the fusion coefficient for each item in the batch.
Example:
>>> SomeSelectionImplementation().forward( ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)} ... ) {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
65class SingleDomainSelection(SelectionBase): 66 """ 67 This selection mechanism handles groups that can have multiple domains, but always 68 return a selection of 1 domain from the group with a uniform distribution. 69 70 For example, if the group has 2 domains, there is a 50% chance of selecting each 71 domain. 72 """ 73 74 def forward( 75 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 76 ) -> dict[str, torch.Tensor]: 77 """ 78 Forward pass of the module. 79 80 Args: 81 domains (`LatentsDomainGroupT`): input unimodal latent representations 82 encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent 83 representation. 84 85 Returns: 86 `dict[str, torch.Tensor]`: whether the domain is selected for each input 87 in the batch. 88 """ 89 selection: dict[str, torch.Tensor] = {} 90 bs = group_batch_size(domains) 91 choice = torch.randint(len(domains), size=(bs,), device=group_device(domains)) 92 for k, domain in enumerate(domains.keys()): 93 selection[domain] = (choice == k).to(torch.float32) 94 return selection
This selection mechanism handles groups that can have multiple domains, but always returns a selection of 1 domain from the group with a uniform distribution.
For example, if the group has 2 domains, there is a 50% chance of selecting each domain.
74 def forward( 75 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 76 ) -> dict[str, torch.Tensor]: 77 """ 78 Forward pass of the module. 79 80 Args: 81 domains (`LatentsDomainGroupT`): input unimodal latent representations 82 encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent 83 representation. 84 85 Returns: 86 `dict[str, torch.Tensor]`: whether the domain is selected for each input 87 in the batch. 88 """ 89 selection: dict[str, torch.Tensor] = {} 90 bs = group_batch_size(domains) 91 choice = torch.randint(len(domains), size=(bs,), device=group_device(domains)) 92 for k, domain in enumerate(domains.keys()): 93 selection[domain] = (choice == k).to(torch.float32) 94 return selection
Forward pass of the module.
Arguments:
- domains (
LatentsDomainGroupT): input unimodal latent representations - encodings_pre_fusion (
LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:
dict[str, torch.Tensor]: whether the domain is selected for each input in the batch.
Inherited Members
157class LearnedAttention(SelectionBase): 158 """ 159 Content-based single-step attention over GW latents with configurable toggles. 160 161 Design: 162 - Query is the mean of available GW latents (content-q0 seed) 163 - Single-step dot-product attention over domains (no refinement loop) 164 - Optional per-domain keys 165 166 Toggles: 167 - per_domain_keys: use per-domain key projections instead of a shared one 168 - stopgrad: detach GW latents before computing keys/query 169 - key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains 170 - domain_dims: required when key_on_prefusion=False to size per-domain key layers 171 """ 172 173 def __init__( 174 self, 175 gw_dim: int, 176 domain_names: Iterable[str], 177 head_size: int = 64, 178 per_domain_keys: bool = False, 179 stopgrad: bool = True, 180 key_on_prefusion: bool = True, 181 domain_dims: Mapping[str, int] | None = None, 182 ): 183 super().__init__() 184 self.gw_dim = int(gw_dim) 185 self.head_size = int(head_size) 186 self.domain_names = list(domain_names) 187 188 # Toggles 189 self.per_domain_keys = bool(per_domain_keys) 190 self.stopgrad = bool(stopgrad) 191 self.key_on_prefusion = bool(key_on_prefusion) 192 self.domain_dims = dict(domain_dims) if domain_dims is not None else None 193 194 # Projections 195 self.query_layer = nn.Linear(self.gw_dim, self.head_size) 196 self.per_key_layers: nn.ModuleDict | None 197 self.shared_key_layer: nn.Linear | None 198 if self.key_on_prefusion: 199 if self.per_domain_keys: 200 self.per_key_layers = nn.ModuleDict( 201 { 202 d: nn.Linear(self.gw_dim, self.head_size) 203 for d in self.domain_names 204 } 205 ) 206 self.shared_key_layer = None 207 else: 208 self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) 209 self.per_key_layers = None 210 else: 211 if not self.per_domain_keys: 212 raise ValueError( 213 "key_on_prefusion=False requires per_domain_keys=True because " 214 "domain latent dimensions can differ." 
215 ) 216 if self.domain_dims is None: 217 raise ValueError( 218 "key_on_prefusion=False requires domain_dims for key projections." 219 ) 220 missing_dims = [d for d in self.domain_names if d not in self.domain_dims] 221 if missing_dims: 222 raise ValueError( 223 f"Missing domain_dims for: {', '.join(sorted(missing_dims))}" 224 ) 225 self.per_key_layers = nn.ModuleDict( 226 { 227 d: nn.Linear(self.domain_dims[d], self.head_size) 228 for d in self.domain_names 229 } 230 ) 231 self.shared_key_layer = None 232 233 @staticmethod 234 def _calc_attention( 235 keys: dict[str, torch.Tensor], 236 query: torch.Tensor, 237 order: Iterable[str], 238 ) -> dict[str, torch.Tensor]: 239 """ 240 Compute attention over domains. 241 242 Args: 243 keys: mapping of domain -> key tensor (B, H) 244 query: query tensor (B, H) 245 order: iterable of domain names to fix output ordering 246 247 Returns: 248 dict[str, torch.Tensor]: per-domain attention scores that sum to 1. 249 """ 250 names = [d for d in order if d in keys] 251 if not names: 252 raise ValueError("LearnedAttention: no keys provided.") 253 254 logits = torch.stack( 255 [(keys[d] * query).sum(dim=1) for d in names], dim=1 256 ) # (B, D) 257 258 probs = torch.softmax(logits, dim=1) 259 260 return {d: probs[:, i] for i, d in enumerate(names)} 261 262 def forward( 263 self, 264 domains: LatentsDomainGroupT, 265 encodings_pre_fusion: LatentsDomainGroupT | None = None, 266 ) -> dict[str, torch.Tensor]: 267 """ 268 Args: 269 domains: mapping from domain name to GW latent (B, gw_dim) 270 encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion) 271 272 Returns: 273 dict[str, torch.Tensor]: per-domain attention weights. 274 """ 275 domain_latents: Mapping[str, torch.Tensor] = domains 276 277 present = [d for d in self.domain_names if d in domain_latents] 278 if not present: 279 raise ValueError( 280 "LearnedAttention: no known domains present in gw_latents." 
281 ) 282 283 if self.key_on_prefusion: 284 if encodings_pre_fusion is None: 285 raise ValueError( 286 "key_on_prefusion=True requires encodings_pre_fusion inputs." 287 ) 288 key_source = encodings_pre_fusion 289 else: 290 key_source = domain_latents 291 292 missing_keys = [d for d in present if d not in key_source] 293 if missing_keys: 294 raise ValueError( 295 f"Missing key latents for: {', '.join(sorted(missing_keys))}" 296 ) 297 298 if encodings_pre_fusion is None: 299 query_source = domain_latents 300 else: 301 query_source = encodings_pre_fusion 302 303 missing_query = [d for d in present if d not in query_source] 304 if missing_query: 305 raise ValueError( 306 f"Missing query latents for: {', '.join(sorted(missing_query))}" 307 ) 308 309 if self.stopgrad: 310 key_latents = {d: key_source[d].detach() for d in present} 311 query_latents = {d: query_source[d].detach() for d in present} 312 else: 313 key_latents = {d: key_source[d] for d in present} 314 query_latents = {d: query_source[d] for d in present} 315 316 if self.per_domain_keys: 317 if self.per_key_layers is None: 318 raise RuntimeError( 319 "per_domain_keys=True but per-domain key layers are missing." 320 ) 321 keys = { 322 d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d]) 323 for d in present 324 } 325 else: 326 if self.shared_key_layer is None: 327 raise RuntimeError( 328 "per_domain_keys=False but shared key layer is missing." 329 ) 330 proj = self.shared_key_layer 331 keys = {d: proj(key_latents[d]) for d in present} 332 333 stacked = torch.stack([query_latents[d] for d in present], dim=0) # (D, B, F) 334 query = self.query_layer(stacked.mean(0)) # (B, H) 335 336 return self._calc_attention( 337 keys=keys, 338 query=query, 339 order=self.domain_names, 340 )
Content-based single-step attention over GW latents with configurable toggles.
Design:
- Query is the mean of available GW latents (content-q0 seed)
- Single-step dot-product attention over domains (no refinement loop)
- Optional per-domain keys
Toggles:
- per_domain_keys: use per-domain key projections instead of a shared one
- stopgrad: detach GW latents before computing keys/query
- key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains
- domain_dims: required when key_on_prefusion=False to size per-domain key layers
173 def __init__( 174 self, 175 gw_dim: int, 176 domain_names: Iterable[str], 177 head_size: int = 64, 178 per_domain_keys: bool = False, 179 stopgrad: bool = True, 180 key_on_prefusion: bool = True, 181 domain_dims: Mapping[str, int] | None = None, 182 ): 183 super().__init__() 184 self.gw_dim = int(gw_dim) 185 self.head_size = int(head_size) 186 self.domain_names = list(domain_names) 187 188 # Toggles 189 self.per_domain_keys = bool(per_domain_keys) 190 self.stopgrad = bool(stopgrad) 191 self.key_on_prefusion = bool(key_on_prefusion) 192 self.domain_dims = dict(domain_dims) if domain_dims is not None else None 193 194 # Projections 195 self.query_layer = nn.Linear(self.gw_dim, self.head_size) 196 self.per_key_layers: nn.ModuleDict | None 197 self.shared_key_layer: nn.Linear | None 198 if self.key_on_prefusion: 199 if self.per_domain_keys: 200 self.per_key_layers = nn.ModuleDict( 201 { 202 d: nn.Linear(self.gw_dim, self.head_size) 203 for d in self.domain_names 204 } 205 ) 206 self.shared_key_layer = None 207 else: 208 self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size) 209 self.per_key_layers = None 210 else: 211 if not self.per_domain_keys: 212 raise ValueError( 213 "key_on_prefusion=False requires per_domain_keys=True because " 214 "domain latent dimensions can differ." 215 ) 216 if self.domain_dims is None: 217 raise ValueError( 218 "key_on_prefusion=False requires domain_dims for key projections." 219 ) 220 missing_dims = [d for d in self.domain_names if d not in self.domain_dims] 221 if missing_dims: 222 raise ValueError( 223 f"Missing domain_dims for: {', '.join(sorted(missing_dims))}" 224 ) 225 self.per_key_layers = nn.ModuleDict( 226 { 227 d: nn.Linear(self.domain_dims[d], self.head_size) 228 for d in self.domain_names 229 } 230 ) 231 self.shared_key_layer = None
Build the query projection and the shared or per-domain key projections according to the configured toggles.
262 def forward( 263 self, 264 domains: LatentsDomainGroupT, 265 encodings_pre_fusion: LatentsDomainGroupT | None = None, 266 ) -> dict[str, torch.Tensor]: 267 """ 268 Args: 269 domains: mapping from domain name to GW latent (B, gw_dim) 270 encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion) 271 272 Returns: 273 dict[str, torch.Tensor]: per-domain attention weights. 274 """ 275 domain_latents: Mapping[str, torch.Tensor] = domains 276 277 present = [d for d in self.domain_names if d in domain_latents] 278 if not present: 279 raise ValueError( 280 "LearnedAttention: no known domains present in gw_latents." 281 ) 282 283 if self.key_on_prefusion: 284 if encodings_pre_fusion is None: 285 raise ValueError( 286 "key_on_prefusion=True requires encodings_pre_fusion inputs." 287 ) 288 key_source = encodings_pre_fusion 289 else: 290 key_source = domain_latents 291 292 missing_keys = [d for d in present if d not in key_source] 293 if missing_keys: 294 raise ValueError( 295 f"Missing key latents for: {', '.join(sorted(missing_keys))}" 296 ) 297 298 if encodings_pre_fusion is None: 299 query_source = domain_latents 300 else: 301 query_source = encodings_pre_fusion 302 303 missing_query = [d for d in present if d not in query_source] 304 if missing_query: 305 raise ValueError( 306 f"Missing query latents for: {', '.join(sorted(missing_query))}" 307 ) 308 309 if self.stopgrad: 310 key_latents = {d: key_source[d].detach() for d in present} 311 query_latents = {d: query_source[d].detach() for d in present} 312 else: 313 key_latents = {d: key_source[d] for d in present} 314 query_latents = {d: query_source[d] for d in present} 315 316 if self.per_domain_keys: 317 if self.per_key_layers is None: 318 raise RuntimeError( 319 "per_domain_keys=True but per-domain key layers are missing." 
320 ) 321 keys = { 322 d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d]) 323 for d in present 324 } 325 else: 326 if self.shared_key_layer is None: 327 raise RuntimeError( 328 "per_domain_keys=False but shared key layer is missing." 329 ) 330 proj = self.shared_key_layer 331 keys = {d: proj(key_latents[d]) for d in present} 332 333 stacked = torch.stack([query_latents[d] for d in present], dim=0) # (D, B, F) 334 query = self.query_layer(stacked.mean(0)) # (B, H) 335 336 return self._calc_attention( 337 keys=keys, 338 query=query, 339 order=self.domain_names, 340 )
Arguments:
- domains: mapping from domain name to GW latent (B, gw_dim)
- encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)
Returns:
dict[str, torch.Tensor]: per-domain attention weights.
Inherited Members
class RandomSelection(SelectionBase):
    """
    Random selection mechanism based only on uniform-softmax scores.

    No binary scaling factors are involved: fusion coefficients are obtained by
    drawing uniform random scores per domain and applying a temperature-scaled
    softmax across the domains.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to
                the uniform scaling factors.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Draw uniform scores and softmax them domain-wise.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent
                representations. Only its structure (domain names, batch size,
                device) is used, not its values.

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on
            uniform-softmax scores.
        """
        batch_size = group_batch_size(domains)
        device = group_device(domains)

        # Uniform scores for every (item, domain) pair, softmaxed across
        # domains with temperature scaling so coefficients sum to 1 per item.
        scores = torch.rand(batch_size, len(domains), device=device)
        coefficients = torch.softmax(scores / self.temperature, dim=1)

        # One coefficient column per domain.
        return {name: coefficients[:, idx] for idx, name in enumerate(domains)}
Modified random attention to only utilize uniform-softmax scores across modalities. This version omits the binary scaling factors and focuses on generating attention coefficients using a uniform distribution followed by a domain-wise softmax.
350 def __init__(self, temperature: float): 351 """ 352 Args: 353 temperature (`float`): Temperature of the softmax applied to uniform 354 scaling factors. 355 """ 356 super().__init__() 357 self.temperature = temperature
Arguments:
- temperature (`float`): Temperature of the softmax applied to uniform scaling factors.
359 def forward( 360 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 361 ) -> dict[str, torch.Tensor]: 362 """ 363 Generate uniform-then-domain-wise-softmaxed samples for each domain. 364 365 Args: 366 domains (`LatentsDomainGroupT`): Group of unimodal latent representations. 367 This is not used in the function directly but determines the structure 368 of the returned attention coefficients. 369 370 Returns: 371 `dict[str, torch.Tensor]`: For each domain in the group, the fusion 372 coefficient for each item in the batch, based solely on 373 uniform-softmax scores. 374 """ 375 num_domains = len(domains) 376 batch_size = group_batch_size(domains) 377 device = group_device(domains) 378 379 # Generate uniform scores 380 uniform_scores = torch.rand(batch_size, num_domains, device=device) 381 382 # Apply softmax across domains with temperature scaling 383 softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1) 384 # Create attention dictionary for each domain 385 attention_dict = { 386 domain: softmax_scores[:, i] for i, domain in enumerate(domains) 387 } 388 389 return attention_dict
Generate uniform-then-domain-wise-softmaxed samples for each domain.
Arguments:
- domains (`LatentsDomainGroupT`): Group of unimodal latent representations. This is not used in the function directly but determines the structure of the returned attention coefficients.
Returns:
dict[str, torch.Tensor]: For each domain in the group, the fusion coefficient for each item in the batch, based solely on uniform-softmax scores.
Inherited Members
class DynamicQueryAttention(SelectionBase):
    """
    Key-Query attention with a dynamic gw vector.

    The query starts from a random initial GW state and is then iteratively
    re-derived from the fused (attention-weighted) pre-fusion encodings.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`): dimension of the key and query vectors.
            domain_dim (`int`): dimension of the input dims (assumed to be the
                same for now)
            domain_names (`Iterable[str]`): list of input domains
            n_steps (`int`): number of steps to update the query vector
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        # One key projection per domain.
        self.key_layers = nn.ModuleDict(
            {name: nn.Linear(domain_dim, head_size) for name in domain_names}
        )
        self.n_steps = n_steps
        # By default the refinement loop may run the full n_steps.
        self.step_limit = n_steps
        # Random GW state used to bootstrap the very first query; registered
        # as a buffer so it follows the module across devices/checkpoints.
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Set the maximum number of query-refinement iterations.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.

        Raises:
            ValueError: if `step_limit` exceeds `n_steps`.
        """
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the encodings into a single tensor using the attention scores.

        Domains present in `attention_dict` but absent from `encodings` are
        silently skipped.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representation
            attention_dict (`dict[str, torch.Tensor]`): The attention scores
                for each domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Scale each encoding by its per-item attention coefficient.
        weighted = [
            attention_dict[name].unsqueeze(1) * encodings[name]
            for name in attention_dict
            if name in encodings
        ]
        # Stack along a new leading dim, then sum it away to fuse the domains.
        return torch.stack(weighted).sum(dim=0)

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries and match them with dot product and softmax.

        First with a static query from the initial GW state, then (up to
        `step_limit` times) with a query re-derived from the fused state.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent
                representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain
                latent representation.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in
            the group.
        """
        # Project every domain latent to a key.
        keys = {
            name: self.key_layers[name](latent)
            for name, latent in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Bootstrap query from the fixed random initial GW state.
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
        attention_dict = _calculate_attention_dict(domains, keys, query)

        # Refinement loop: fuse with current scores, re-derive the query.
        if self.n_steps > 0:
            for _ in range(min(self.step_limit, self.n_steps)):
                fused_state = self.fuse_weighted_encodings(
                    encodings_pre_fusion, attention_dict
                )
                query = self.query_layer(fused_state)
                attention_dict = _calculate_attention_dict(domains, keys, query)

        return attention_dict
Key-Query attention with a dynamic gw vector. The query is updated based on the scaled gw vector.
398 def __init__( 399 self, 400 head_size: int, 401 domain_dim: int, 402 domain_names: Iterable[str], 403 n_steps: int = 1, 404 ): 405 """ 406 Args: 407 head_size (`int`) : dimension of the key and query vectors. 408 domain_dim (`int`) : dimension of the input dims (assumed to be the same 409 for now) 410 domain_names (`Iterable[str]`) : list of input domains 411 n_steps (`int`) : number of steps to update the query vector 412 """ 413 super().__init__() 414 self.head_size = head_size 415 self.query_layer = nn.Linear(domain_dim, head_size) 416 self.key_layers = nn.ModuleDict( 417 {domain: nn.Linear(domain_dim, head_size) for domain in domain_names} 418 ) 419 self.n_steps = n_steps 420 self.step_limit = n_steps # Default step limit is n_steps 421 # Start with a random gw state 422 self.register_buffer("initial_gw_state", torch.rand(domain_dim))
Arguments:
- head_size (`int`): dimension of the key and query vectors.
- domain_dim (`int`): dimension of the input dims (assumed to be the same for now)
- domain_names (`Iterable[str]`): list of input domains
- n_steps (`int`): number of steps to update the query vector
424 def set_step_limit(self, step_limit: int): 425 """ 426 Sets the step limit for the dynamic attention update loop. 427 428 Args: 429 step_limit (`int`): Maximum number of steps to run the loop. 430 """ 431 if step_limit > self.n_steps: 432 raise ValueError( 433 f"Step limit cannot exceed the maximum n_steps ({self.n_steps})." 434 ) 435 self.step_limit = step_limit
Sets the step limit for the dynamic attention update loop.
Arguments:
- step_limit (`int`): Maximum number of steps to run the loop.
437 def fuse_weighted_encodings( 438 self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor] 439 ) -> torch.Tensor: 440 """ 441 Fuse the weighted encodings using the attention scores. 442 443 Args: 444 encodings (`LatentsDomainGroupT`): Unimodal latent representation 445 attention_dict (`dict[str, torch.Tensor]`): The attention scores for each 446 domain in the group. 447 448 Returns: 449 `torch.Tensor`: The fused tensor. 450 """ 451 # Apply attention scores to the encodings 452 weighted_encodings = {} 453 for key in attention_dict: 454 if key in encodings: 455 # Perform element-wise multiplication 456 weighted_encodings[key] = ( 457 attention_dict[key].unsqueeze(1) * encodings[key] 458 ) 459 460 # Stack the tensors along a new dimension (dimension 0) 461 stacked_tensors = torch.stack(list(weighted_encodings.values())) 462 463 # Apply fusion by summing along the newly created dimension 464 summed_tensor = torch.sum(stacked_tensors, dim=0) 465 return summed_tensor
Fuse the weighted encodings using the attention scores.
Arguments:
- encodings (`LatentsDomainGroupT`): Unimodal latent representation
- attention_dict (`dict[str, torch.Tensor]`): The attention scores for each domain in the group.
Returns:
torch.Tensor: The fused tensor.
467 def forward( 468 self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT 469 ) -> dict[str, torch.Tensor]: 470 """ 471 Compute keys and queries, match them with dot product and softmax. 472 Does this twice, once with the static query and once with a dynamic query. 473 474 Args: 475 domains (`LatentsDomainGroupT`): Group of unimodal latent representations. 476 encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent 477 representation. 478 479 Returns: 480 `dict[str, torch.Tensor]`: the attention scores for each domain in the 481 group. 482 """ 483 484 keys = { 485 domain: self.key_layers[domain](encoding) 486 for domain, encoding in domains.items() 487 } 488 489 batch_size = group_batch_size(domains) 490 491 # Retrieve random query 492 query = self.query_layer(self.initial_gw_state.expand(batch_size, -1)) 493 494 # Calculate the attention scores 495 attention_dict = _calculate_attention_dict(domains, keys, query) 496 497 if self.n_steps > 0: 498 # Update the query based on the static attention scores 499 for _ in range(min(self.step_limit, self.n_steps)): 500 # Apply the attention scores to the encodings 501 summed_tensor = self.fuse_weighted_encodings( 502 encodings_pre_fusion, attention_dict 503 ) 504 505 # Retrieve query (now it is dependent on the new gw state) 506 query = self.query_layer(summed_tensor) 507 508 # Calculate the attention scores again 509 attention_dict = _calculate_attention_dict(domains, keys, query) 510 511 return attention_dict
Compute keys and queries and match them with dot product and softmax. Does this once with a static query from the initial GW state, then repeatedly (up to the step limit) with a query recomputed from the fused GW state.
Arguments:
- domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
- encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent representation.
Returns:
dict[str, torch.Tensor]: the attention scores for each domain in the group.