shimmer.modules.selection

  1from abc import ABC, abstractmethod
  2from collections.abc import Iterable, Mapping
  3from typing import cast
  4
  5import torch
  6import torch.nn as nn
  7
  8from shimmer.types import LatentsDomainGroupT
  9from shimmer.utils import group_batch_size, group_device
 10
 11
class SelectionBase(torch.nn.Module, ABC):
    """
    This is the base class for the selection mechanism.
    The selection mechanisms handles the "competition" between modules and *selects*
    fusion coefficients for the domains.
    """

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        """
        Update the internal copy of the previous GW state.
        By default this is a no-op; subclasses that track the previous GW
        state should override it.

        Note:
            This is not defined as an abstractmethod as some selection method may
            not need it.

        Args:
            gw_state (`torch.Tensor`): the previous GW state
        """
        pass

    @abstractmethod
    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the selection method.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.

        Example:
            >>> SomeSelectionImplementation().forward(
            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
            ... )
            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
        """
        ...

    # Defined only so type-checkers and IDEs see the precise __call__ signature;
    # it simply forwards to nn.Module.__call__.
    def __call__(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        return super().__call__(domains, encodings_pre_fusion)
 62
 63
 64class SingleDomainSelection(SelectionBase):
 65    """
 66    This selection mechanism handles groups that can have multiple domains, but always
 67    return a selection of 1 domain from the group with a uniform distribution.
 68
 69    For example, if the group has 2 domains, there is a 50% chance of selecting each
 70    domain.
 71    """
 72
 73    def forward(
 74        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
 75    ) -> dict[str, torch.Tensor]:
 76        """
 77        Forward pass of the module.
 78
 79        Args:
 80            domains (`LatentsDomainGroupT`): input unimodal latent representations
 81            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
 82                representation.
 83
 84        Returns:
 85            `dict[str, torch.Tensor]`: whether the domain is selected for each input
 86            in the batch.
 87        """
 88        selection: dict[str, torch.Tensor] = {}
 89        bs = group_batch_size(domains)
 90        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
 91        for k, domain in enumerate(domains.keys()):
 92            selection[domain] = (choice == k).to(torch.float32)
 93        return selection
 94
 95
 96class FixedSharedSelection(SelectionBase):
 97    """
 98    This selection mechanism is deterministic and always shares the weights equally
 99    between domains.
100
101    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
102    """
103
104    def forward(
105        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
106    ) -> dict[str, torch.Tensor]:
107        """
108        Forward pass of the module.
109
110        Args:
111            domains (`LatentsDomainGroupT`): input unimodal latent representations
112            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
113                representation.
114
115        Returns:
116            `dict[str, torch.Tensor]`: whether the domain is selected for each input
117            in the batch.
118        """
119        selection: dict[str, torch.Tensor] = {}
120        bs = group_batch_size(domains)
121        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
122        for domain in domains:
123            selection[domain] = coef.clone()
124        return selection
125
126
def _calculate_attention_dict(
    domains: LatentsDomainGroupT,
    keys: dict[str, torch.Tensor],
    query: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Compute a softmax attention distribution over the domains of a group.

    Args:
        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
        keys (`dict[str, torch.Tensor]`): The keys for each domain, shape (B, H).
        query (`torch.Tensor`): The query tensor, shape (B, H).

    Returns:
        `dict[str, torch.Tensor]`: The attention scores for each domain, shape
        (B,); scores sum to 1 across domains for every batch item.
    """
    # (B, 1, H) @ (B, H, 1) -> (B, 1, 1). Squeeze the two trailing singleton
    # dims explicitly: a bare `.squeeze()` would also drop the batch dimension
    # when B == 1, producing 0-d tensors that break the stack along dim=1.
    dot_products = {
        domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2))
        .squeeze(-1)
        .squeeze(-1)
        for domain, key in keys.items()
    }

    # (B, D) logits, softmax over the domain axis.
    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

    attention_scores = torch.softmax(dot_products_tensor, dim=1)

    # NOTE(review): assumes `domains` and `keys` iterate in the same order —
    # true for the callers in this file, which build `keys` from `domains`.
    attention_dict = {
        domain: attention_scores[:, i] for i, domain in enumerate(domains)
    }
    return attention_dict
154
155
class LearnedAttention(SelectionBase):
    """
    Content-based single-step attention over GW latents with configurable toggles.

    Design:
    - Query is the mean of available GW latents (content-q0 seed)
    - Single-step dot-product attention over domains (no refinement loop)
    - Optional per-domain keys

    Toggles:
    - per_domain_keys: use per-domain key projections instead of a shared one
    - stopgrad: detach GW latents before computing keys/query
    - key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains
    - domain_dims: required when key_on_prefusion=False to size per-domain key layers
    """

    def __init__(
        self,
        gw_dim: int,
        domain_names: Iterable[str],
        head_size: int = 64,
        per_domain_keys: bool = False,
        stopgrad: bool = True,
        key_on_prefusion: bool = True,
        domain_dims: Mapping[str, int] | None = None,
    ):
        """
        Args:
            gw_dim (`int`): dimensionality of the GW latents used for the query
                (and for the keys when `key_on_prefusion=True`).
            domain_names (`Iterable[str]`): names of the domains to attend over.
            head_size (`int`): dimension of the key and query projections.
            per_domain_keys (`bool`): one key projection per domain instead of
                a shared one.
            stopgrad (`bool`): detach latents before computing keys/queries.
            key_on_prefusion (`bool`): compute keys on pre-fusion GW latents.
            domain_dims (`Mapping[str, int] | None`): per-domain input sizes;
                required (and only used) when `key_on_prefusion=False`.

        Raises:
            ValueError: when `key_on_prefusion=False` and `per_domain_keys` is
                False, `domain_dims` is missing, or `domain_dims` lacks an
                entry for one of `domain_names`.
        """
        super().__init__()
        self.gw_dim = int(gw_dim)
        self.head_size = int(head_size)
        self.domain_names = list(domain_names)

        # Toggles
        self.per_domain_keys = bool(per_domain_keys)
        self.stopgrad = bool(stopgrad)
        self.key_on_prefusion = bool(key_on_prefusion)
        self.domain_dims = dict(domain_dims) if domain_dims is not None else None

        # Projections. Exactly one of per_key_layers / shared_key_layer is set;
        # the other is None (forward() checks this invariant defensively).
        self.query_layer = nn.Linear(self.gw_dim, self.head_size)
        self.per_key_layers: nn.ModuleDict | None
        self.shared_key_layer: nn.Linear | None
        if self.key_on_prefusion:
            # Keys come from GW-space latents, so all key layers take gw_dim.
            if self.per_domain_keys:
                self.per_key_layers = nn.ModuleDict(
                    {
                        d: nn.Linear(self.gw_dim, self.head_size)
                        for d in self.domain_names
                    }
                )
                self.shared_key_layer = None
            else:
                self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size)
                self.per_key_layers = None
        else:
            # Keys come from raw domain latents, whose sizes may differ per
            # domain, so a shared projection is impossible here.
            if not self.per_domain_keys:
                raise ValueError(
                    "key_on_prefusion=False requires per_domain_keys=True because "
                    "domain latent dimensions can differ."
                )
            if self.domain_dims is None:
                raise ValueError(
                    "key_on_prefusion=False requires domain_dims for key projections."
                )
            missing_dims = [d for d in self.domain_names if d not in self.domain_dims]
            if missing_dims:
                raise ValueError(
                    f"Missing domain_dims for: {', '.join(sorted(missing_dims))}"
                )
            self.per_key_layers = nn.ModuleDict(
                {
                    d: nn.Linear(self.domain_dims[d], self.head_size)
                    for d in self.domain_names
                }
            )
            self.shared_key_layer = None

    @staticmethod
    def _calc_attention(
        keys: dict[str, torch.Tensor],
        query: torch.Tensor,
        order: Iterable[str],
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention over domains.

        Args:
            keys: mapping of domain -> key tensor (B, H)
            query: query tensor (B, H)
            order: iterable of domain names to fix output ordering

        Returns:
            dict[str, torch.Tensor]: per-domain attention scores that sum to 1.

        Raises:
            ValueError: if no name in `order` has a key in `keys`.
        """
        # Keep only domains that actually have a key, in `order` ordering.
        names = [d for d in order if d in keys]
        if not names:
            raise ValueError("LearnedAttention: no keys provided.")

        # Row-wise dot products key·query, one column per domain.
        logits = torch.stack(
            [(keys[d] * query).sum(dim=1) for d in names], dim=1
        )  # (B, D)

        probs = torch.softmax(logits, dim=1)

        return {d: probs[:, i] for i, d in enumerate(names)}

    def forward(
        self,
        domains: LatentsDomainGroupT,
        encodings_pre_fusion: LatentsDomainGroupT | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            domains: mapping from domain name to GW latent (B, gw_dim)
            encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)

        Returns:
            dict[str, torch.Tensor]: per-domain attention weights.

        Raises:
            ValueError: if no known domain is present, if pre-fusion encodings
                are required but missing, or if a present domain lacks a
                key/query latent.
            RuntimeError: if the configured key layers are inconsistent with
                the `per_domain_keys` toggle (should not happen after __init__).
        """
        domain_latents: Mapping[str, torch.Tensor] = domains

        # Restrict to domains this module was configured for AND that appear
        # in the input group.
        present = [d for d in self.domain_names if d in domain_latents]
        if not present:
            raise ValueError(
                "LearnedAttention: no known domains present in gw_latents."
            )

        if self.key_on_prefusion:
            if encodings_pre_fusion is None:
                raise ValueError(
                    "key_on_prefusion=True requires encodings_pre_fusion inputs."
                )
            key_source = encodings_pre_fusion
        else:
            key_source = domain_latents

        missing_keys = [d for d in present if d not in key_source]
        if missing_keys:
            raise ValueError(
                f"Missing key latents for: {', '.join(sorted(missing_keys))}"
            )

        # The query is seeded from pre-fusion encodings when available,
        # falling back to the raw domain latents otherwise.
        if encodings_pre_fusion is None:
            query_source = domain_latents
        else:
            query_source = encodings_pre_fusion

        missing_query = [d for d in present if d not in query_source]
        if missing_query:
            raise ValueError(
                f"Missing query latents for: {', '.join(sorted(missing_query))}"
            )

        # Optionally cut gradients so attention does not shape the latents.
        if self.stopgrad:
            key_latents = {d: key_source[d].detach() for d in present}
            query_latents = {d: query_source[d].detach() for d in present}
        else:
            key_latents = {d: key_source[d] for d in present}
            query_latents = {d: query_source[d] for d in present}

        if self.per_domain_keys:
            if self.per_key_layers is None:
                raise RuntimeError(
                    "per_domain_keys=True but per-domain key layers are missing."
                )
            keys = {
                d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d])
                for d in present
            }
        else:
            if self.shared_key_layer is None:
                raise RuntimeError(
                    "per_domain_keys=False but shared key layer is missing."
                )
            proj = self.shared_key_layer
            keys = {d: proj(key_latents[d]) for d in present}

        # Content-based query: mean of the present latents, then projected.
        # Assumes all query latents share the same feature size (gw_dim).
        stacked = torch.stack([query_latents[d] for d in present], dim=0)  # (D, B, F)
        query = self.query_layer(stacked.mean(0))  # (B, H)

        return self._calc_attention(
            keys=keys,
            query=query,
            order=self.domain_names,
        )
340
341
class RandomSelection(SelectionBase):
    """
    Random attention producing uniform-softmax scores across modalities.

    Binary scaling factors are deliberately omitted: the coefficients come
    from a uniform distribution followed by a domain-wise softmax.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to uniform
                scaling factors.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Generate uniform-then-domain-wise-softmaxed samples for each domain.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
                Only its size, batch size and device are used — the latent values
                themselves do not influence the coefficients.

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on
            uniform-softmax scores.
        """
        n_domains = len(domains)
        batch_size = group_batch_size(domains)
        device = group_device(domains)

        # Uniform draws, tempered and softmaxed across the domain axis so the
        # coefficients of each batch item sum to 1.
        coefficients = torch.softmax(
            torch.rand(batch_size, n_domains, device=device) / self.temperature,
            dim=1,
        )

        return {name: coefficients[:, col] for col, name in enumerate(domains)}
389
390
class DynamicQueryAttention(SelectionBase):
    """
    Key-Query attention with a dynamic gw vector.
    The query is updated based on the scaled gw vector.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`) : dimension of the key and query vectors.
            domain_dim (`int`) : dimension of the input dims (assumed to be the same
                for now)
            domain_names  (`Iterable[str]`) : list of input domains
            n_steps (`int`) : number of steps to update the query vector
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {name: nn.Linear(domain_dim, head_size) for name in domain_names}
        )
        self.n_steps = n_steps
        # The update loop may run for at most this many steps (defaults to
        # the full n_steps).
        self.step_limit = n_steps
        # Start with a random gw state
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Sets the step limit for the dynamic attention update loop.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.

        Raises:
            ValueError: if `step_limit` exceeds `n_steps`.
        """
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the weighted encodings using the attention scores.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representation
            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
                domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Scale each domain's encoding by its (B,) attention coefficients,
        # then sum the weighted encodings into a single fused tensor.
        weighted = [
            attention_dict[name].unsqueeze(1) * encodings[name]
            for name in attention_dict
            if name in encodings
        ]
        return torch.stack(weighted).sum(dim=0)

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries, match them with dot product and softmax.
        Does this twice, once with the static query and once with a dynamic query.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in the
            group.
        """
        keys = {
            name: self.key_layers[name](latent)
            for name, latent in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Seed the query from the fixed random initial gw state, broadcast
        # over the batch.
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

        # Static attention scores from the seed query.
        attention_dict = _calculate_attention_dict(domains, keys, query)

        # Dynamic refinement: re-derive the query from the attention-fused
        # pre-fusion encodings and recompute the scores. The range is empty
        # whenever n_steps (or step_limit) is <= 0, matching the static case.
        for _ in range(min(self.step_limit, self.n_steps)):
            fused_state = self.fuse_weighted_encodings(
                encodings_pre_fusion, attention_dict
            )
            query = self.query_layer(fused_state)
            attention_dict = _calculate_attention_dict(domains, keys, query)

        return attention_dict
class SelectionBase(torch.nn.modules.module.Module, abc.ABC):
13class SelectionBase(torch.nn.Module, ABC):
14    """
15    This is the base class for the selection mechanism.
16    The selection mechanisms handles the "competition" between modules and *selects*
17    fusion coefficients for the domains.
18    """
19
20    def update_gw_state(self, gw_state: torch.Tensor) -> None:
21        """
22        Update the internal copy of the previous GW state.
23        By default, this is not implemented and will raise an error if used.
24
25        :note..
26            This is not defined as an abstractmethod as some selection method may
27            not need it.
28
29        Args:
30            gw_state (`torch.Tensor`): the previous GW state
31        """
32        pass
33
34    @abstractmethod
35    def forward(
36        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
37    ) -> dict[str, torch.Tensor]:
38        """
39        Forward pass of the selection method.
40
41        Args:
42            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
43            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
44                representation.
45
46        Returns:
47            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
48            coefficient for each item in the batch.
49
50        Example:
51            >>> SomeSelectionImplementation().forward(
52            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
53            ... )
54            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
55        """
56        ...
57
58    # This is just for proper auto-completion...
59    def __call__(
60        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
61    ) -> dict[str, torch.Tensor]:
62        return super().__call__(domains, encodings_pre_fusion)

This is the base class for the selection mechanism. The selection mechanism handles the "competition" between modules and selects fusion coefficients for the domains.

def update_gw_state(self, gw_state: torch.Tensor) -> None:
20    def update_gw_state(self, gw_state: torch.Tensor) -> None:
21        """
22        Update the internal copy of the previous GW state.
23        By default, this is not implemented and will raise an error if used.
24
25        :note..
26            This is not defined as an abstractmethod as some selection method may
27            not need it.
28
29        Args:
30            gw_state (`torch.Tensor`): the previous GW state
31        """
32        pass

Update the internal copy of the previous GW state. By default, this is a no-op; subclasses that track the previous GW state should override it.

Note: this is not defined as an abstractmethod, as some selection methods may not need it.

Arguments:
  • gw_state (torch.Tensor): the previous GW state
@abstractmethod
def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
34    @abstractmethod
35    def forward(
36        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
37    ) -> dict[str, torch.Tensor]:
38        """
39        Forward pass of the selection method.
40
41        Args:
42            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
43            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
44                representation.
45
46        Returns:
47            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
48            coefficient for each item in the batch.
49
50        Example:
51            >>> SomeSelectionImplementation().forward(
52            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
53            ... )
54            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
55        """
56        ...

Forward pass of the selection method.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:

dict[str, torch.Tensor]: for each domain in the group, the fusion coefficient for each item in the batch.

Example:
>>> SomeSelectionImplementation().forward(
...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
... )
{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
class SingleDomainSelection(SelectionBase):
65class SingleDomainSelection(SelectionBase):
66    """
67    This selection mechanism handles groups that can have multiple domains, but always
68    return a selection of 1 domain from the group with a uniform distribution.
69
70    For example, if the group has 2 domains, there is a 50% chance of selecting each
71    domain.
72    """
73
74    def forward(
75        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
76    ) -> dict[str, torch.Tensor]:
77        """
78        Forward pass of the module.
79
80        Args:
81            domains (`LatentsDomainGroupT`): input unimodal latent representations
82            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
83                representation.
84
85        Returns:
86            `dict[str, torch.Tensor]`: whether the domain is selected for each input
87            in the batch.
88        """
89        selection: dict[str, torch.Tensor] = {}
90        bs = group_batch_size(domains)
91        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
92        for k, domain in enumerate(domains.keys()):
93            selection[domain] = (choice == k).to(torch.float32)
94        return selection

This selection mechanism handles groups that can have multiple domains, but always returns a selection of 1 domain from the group with a uniform distribution.

For example, if the group has 2 domains, there is a 50% chance of selecting each domain.

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
74    def forward(
75        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
76    ) -> dict[str, torch.Tensor]:
77        """
78        Forward pass of the module.
79
80        Args:
81            domains (`LatentsDomainGroupT`): input unimodal latent representations
82            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
83                representation.
84
85        Returns:
86            `dict[str, torch.Tensor]`: whether the domain is selected for each input
87            in the batch.
88        """
89        selection: dict[str, torch.Tensor] = {}
90        bs = group_batch_size(domains)
91        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
92        for k, domain in enumerate(domains.keys()):
93            selection[domain] = (choice == k).to(torch.float32)
94        return selection

Forward pass of the module.

Arguments:
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:

dict[str, torch.Tensor]: whether the domain is selected for each input in the batch.

Inherited Members
SelectionBase
update_gw_state
class FixedSharedSelection(SelectionBase):
 97class FixedSharedSelection(SelectionBase):
 98    """
 99    This selection mechanism is deterministic and always shares the weights equally
100    between domains.
101
102    For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...
103    """
104
105    def forward(
106        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
107    ) -> dict[str, torch.Tensor]:
108        """
109        Forward pass of the module.
110
111        Args:
112            domains (`LatentsDomainGroupT`): input unimodal latent representations
113            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
114                representation.
115
116        Returns:
117            `dict[str, torch.Tensor]`: whether the domain is selected for each input
118            in the batch.
119        """
120        selection: dict[str, torch.Tensor] = {}
121        bs = group_batch_size(domains)
122        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
123        for domain in domains:
124            selection[domain] = coef.clone()
125        return selection

This selection mechanism is deterministic and always shares the weights equally between domains.

For example, if 2 domains, it gives 0.5 for each; 3 domains, 1/3 for each...

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
105    def forward(
106        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
107    ) -> dict[str, torch.Tensor]:
108        """
109        Forward pass of the module.
110
111        Args:
112            domains (`LatentsDomainGroupT`): input unimodal latent representations
113            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
114                representation.
115
116        Returns:
117            `dict[str, torch.Tensor]`: whether the domain is selected for each input
118            in the batch.
119        """
120        selection: dict[str, torch.Tensor] = {}
121        bs = group_batch_size(domains)
122        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
123        for domain in domains:
124            selection[domain] = coef.clone()
125        return selection

Forward pass of the module.

Arguments:
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:

dict[str, torch.Tensor]: whether the domain is selected for each input in the batch.

Inherited Members
SelectionBase
update_gw_state
class LearnedAttention(SelectionBase):
class LearnedAttention(SelectionBase):
    """
    Content-based single-step attention over GW latents with configurable toggles.

    Design:
    - Query is the mean of available GW latents (content-q0 seed)
    - Single-step dot-product attention over domains (no refinement loop)
    - Optional per-domain keys

    Toggles:
    - per_domain_keys: use per-domain key projections instead of a shared one
    - stopgrad: detach GW latents before computing keys/query
    - key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains
    - domain_dims: required when key_on_prefusion=False to size per-domain key layers
    """

    def __init__(
        self,
        gw_dim: int,
        domain_names: Iterable[str],
        head_size: int = 64,
        per_domain_keys: bool = False,
        stopgrad: bool = True,
        key_on_prefusion: bool = True,
        domain_dims: Mapping[str, int] | None = None,
    ):
        super().__init__()
        self.gw_dim = int(gw_dim)
        self.head_size = int(head_size)
        self.domain_names = list(domain_names)

        # Behavior toggles (coerced to plain bools for clean state_dict/repr).
        self.per_domain_keys = bool(per_domain_keys)
        self.stopgrad = bool(stopgrad)
        self.key_on_prefusion = bool(key_on_prefusion)
        self.domain_dims = None if domain_dims is None else dict(domain_dims)

        # Query always projects from the GW space.
        self.query_layer = nn.Linear(self.gw_dim, self.head_size)
        self.per_key_layers: nn.ModuleDict | None
        self.shared_key_layer: nn.Linear | None

        if self.key_on_prefusion:
            # Keys come from pre-fusion GW latents, so all key inputs share gw_dim.
            if self.per_domain_keys:
                self.shared_key_layer = None
                self.per_key_layers = nn.ModuleDict(
                    {
                        name: nn.Linear(self.gw_dim, self.head_size)
                        for name in self.domain_names
                    }
                )
            else:
                self.per_key_layers = None
                self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size)
            return

        # Keys computed on raw domain latents: dimensions may differ per domain,
        # so a shared projection is impossible and per-domain sizes are required.
        if not self.per_domain_keys:
            raise ValueError(
                "key_on_prefusion=False requires per_domain_keys=True because "
                "domain latent dimensions can differ."
            )
        if self.domain_dims is None:
            raise ValueError(
                "key_on_prefusion=False requires domain_dims for key projections."
            )
        missing = [name for name in self.domain_names if name not in self.domain_dims]
        if missing:
            raise ValueError(
                f"Missing domain_dims for: {', '.join(sorted(missing))}"
            )
        self.shared_key_layer = None
        self.per_key_layers = nn.ModuleDict(
            {
                name: nn.Linear(self.domain_dims[name], self.head_size)
                for name in self.domain_names
            }
        )

    @staticmethod
    def _calc_attention(
        keys: dict[str, torch.Tensor],
        query: torch.Tensor,
        order: Iterable[str],
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention over domains.

        Args:
            keys: mapping of domain -> key tensor (B, H)
            query: query tensor (B, H)
            order: iterable of domain names to fix output ordering

        Returns:
            dict[str, torch.Tensor]: per-domain attention scores that sum to 1.
        """
        ordered = [name for name in order if name in keys]
        if not ordered:
            raise ValueError("LearnedAttention: no keys provided.")

        # Dot-product logit per domain, stacked to (B, D), then softmax over D.
        scores = [(query * keys[name]).sum(dim=1) for name in ordered]
        probs = torch.softmax(torch.stack(scores, dim=1), dim=1)

        result: dict[str, torch.Tensor] = {}
        for column, name in enumerate(ordered):
            result[name] = probs[:, column]
        return result

    def forward(
        self,
        domains: LatentsDomainGroupT,
        encodings_pre_fusion: LatentsDomainGroupT | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            domains: mapping from domain name to GW latent (B, gw_dim)
            encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)

        Returns:
            dict[str, torch.Tensor]: per-domain attention weights.
        """
        latents: Mapping[str, torch.Tensor] = domains

        present = [name for name in self.domain_names if name in latents]
        if not present:
            raise ValueError(
                "LearnedAttention: no known domains present in gw_latents."
            )

        # Pick where keys are computed from, per the key_on_prefusion toggle.
        if not self.key_on_prefusion:
            key_source = latents
        elif encodings_pre_fusion is not None:
            key_source = encodings_pre_fusion
        else:
            raise ValueError(
                "key_on_prefusion=True requires encodings_pre_fusion inputs."
            )

        missing_keys = [name for name in present if name not in key_source]
        if missing_keys:
            raise ValueError(
                f"Missing key latents for: {', '.join(sorted(missing_keys))}"
            )

        # The query seed prefers pre-fusion encodings when they are available.
        query_source = (
            latents if encodings_pre_fusion is None else encodings_pre_fusion
        )
        missing_query = [name for name in present if name not in query_source]
        if missing_query:
            raise ValueError(
                f"Missing query latents for: {', '.join(sorted(missing_query))}"
            )

        # Optionally cut gradients flowing back through key/query inputs.
        detach = self.stopgrad
        key_latents = {
            name: key_source[name].detach() if detach else key_source[name]
            for name in present
        }
        query_latents = {
            name: query_source[name].detach() if detach else query_source[name]
            for name in present
        }

        if self.per_domain_keys:
            if self.per_key_layers is None:
                raise RuntimeError(
                    "per_domain_keys=True but per-domain key layers are missing."
                )
            keys = {}
            for name in present:
                layer = cast(nn.Linear, self.per_key_layers[name])
                keys[name] = layer(key_latents[name])
        else:
            if self.shared_key_layer is None:
                raise RuntimeError(
                    "per_domain_keys=False but shared key layer is missing."
                )
            shared = self.shared_key_layer
            keys = {name: shared(key_latents[name]) for name in present}

        # Content-based query: mean of the available latents, projected to (B, H).
        mean_latent = torch.stack(
            [query_latents[name] for name in present], dim=0
        ).mean(0)
        query = self.query_layer(mean_latent)

        return self._calc_attention(keys=keys, query=query, order=self.domain_names)

Content-based single-step attention over GW latents with configurable toggles.

Design:

  • Query is the mean of available GW latents (content-q0 seed)
  • Single-step dot-product attention over domains (no refinement loop)
  • Optional per-domain keys

Toggles:

  • per_domain_keys: use per-domain key projections instead of a shared one
  • stopgrad: detach GW latents before computing keys/query
  • key_on_prefusion: compute keys on pre-fusion GW latents (True) or raw domains
  • domain_dims: required when key_on_prefusion=False to size per-domain key layers
LearnedAttention( gw_dim: int, domain_names: Iterable[str], head_size: int = 64, per_domain_keys: bool = False, stopgrad: bool = True, key_on_prefusion: bool = True, domain_dims: Mapping[str, int] | None = None)
173    def __init__(
174        self,
175        gw_dim: int,
176        domain_names: Iterable[str],
177        head_size: int = 64,
178        per_domain_keys: bool = False,
179        stopgrad: bool = True,
180        key_on_prefusion: bool = True,
181        domain_dims: Mapping[str, int] | None = None,
182    ):
183        super().__init__()
184        self.gw_dim = int(gw_dim)
185        self.head_size = int(head_size)
186        self.domain_names = list(domain_names)
187
188        # Toggles
189        self.per_domain_keys = bool(per_domain_keys)
190        self.stopgrad = bool(stopgrad)
191        self.key_on_prefusion = bool(key_on_prefusion)
192        self.domain_dims = dict(domain_dims) if domain_dims is not None else None
193
194        # Projections
195        self.query_layer = nn.Linear(self.gw_dim, self.head_size)
196        self.per_key_layers: nn.ModuleDict | None
197        self.shared_key_layer: nn.Linear | None
198        if self.key_on_prefusion:
199            if self.per_domain_keys:
200                self.per_key_layers = nn.ModuleDict(
201                    {
202                        d: nn.Linear(self.gw_dim, self.head_size)
203                        for d in self.domain_names
204                    }
205                )
206                self.shared_key_layer = None
207            else:
208                self.shared_key_layer = nn.Linear(self.gw_dim, self.head_size)
209                self.per_key_layers = None
210        else:
211            if not self.per_domain_keys:
212                raise ValueError(
213                    "key_on_prefusion=False requires per_domain_keys=True because "
214                    "domain latent dimensions can differ."
215                )
216            if self.domain_dims is None:
217                raise ValueError(
218                    "key_on_prefusion=False requires domain_dims for key projections."
219                )
220            missing_dims = [d for d in self.domain_names if d not in self.domain_dims]
221            if missing_dims:
222                raise ValueError(
223                    f"Missing domain_dims for: {', '.join(sorted(missing_dims))}"
224                )
225            self.per_key_layers = nn.ModuleDict(
226                {
227                    d: nn.Linear(self.domain_dims[d], self.head_size)
228                    for d in self.domain_names
229                }
230            )
231            self.shared_key_layer = None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

gw_dim
head_size
domain_names
per_domain_keys
stopgrad
key_on_prefusion
domain_dims
query_layer
per_key_layers: torch.nn.modules.container.ModuleDict | None
shared_key_layer: torch.nn.modules.linear.Linear | None
def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor] | None = None) -> dict[str, torch.Tensor]:
262    def forward(
263        self,
264        domains: LatentsDomainGroupT,
265        encodings_pre_fusion: LatentsDomainGroupT | None = None,
266    ) -> dict[str, torch.Tensor]:
267        """
268        Args:
269            domains: mapping from domain name to GW latent (B, gw_dim)
270            encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)
271
272        Returns:
273            dict[str, torch.Tensor]: per-domain attention weights.
274        """
275        domain_latents: Mapping[str, torch.Tensor] = domains
276
277        present = [d for d in self.domain_names if d in domain_latents]
278        if not present:
279            raise ValueError(
280                "LearnedAttention: no known domains present in gw_latents."
281            )
282
283        if self.key_on_prefusion:
284            if encodings_pre_fusion is None:
285                raise ValueError(
286                    "key_on_prefusion=True requires encodings_pre_fusion inputs."
287                )
288            key_source = encodings_pre_fusion
289        else:
290            key_source = domain_latents
291
292        missing_keys = [d for d in present if d not in key_source]
293        if missing_keys:
294            raise ValueError(
295                f"Missing key latents for: {', '.join(sorted(missing_keys))}"
296            )
297
298        if encodings_pre_fusion is None:
299            query_source = domain_latents
300        else:
301            query_source = encodings_pre_fusion
302
303        missing_query = [d for d in present if d not in query_source]
304        if missing_query:
305            raise ValueError(
306                f"Missing query latents for: {', '.join(sorted(missing_query))}"
307            )
308
309        if self.stopgrad:
310            key_latents = {d: key_source[d].detach() for d in present}
311            query_latents = {d: query_source[d].detach() for d in present}
312        else:
313            key_latents = {d: key_source[d] for d in present}
314            query_latents = {d: query_source[d] for d in present}
315
316        if self.per_domain_keys:
317            if self.per_key_layers is None:
318                raise RuntimeError(
319                    "per_domain_keys=True but per-domain key layers are missing."
320                )
321            keys = {
322                d: cast(nn.Linear, self.per_key_layers[d])(key_latents[d])
323                for d in present
324            }
325        else:
326            if self.shared_key_layer is None:
327                raise RuntimeError(
328                    "per_domain_keys=False but shared key layer is missing."
329                )
330            proj = self.shared_key_layer
331            keys = {d: proj(key_latents[d]) for d in present}
332
333        stacked = torch.stack([query_latents[d] for d in present], dim=0)  # (D, B, F)
334        query = self.query_layer(stacked.mean(0))  # (B, H)
335
336        return self._calc_attention(
337            keys=keys,
338            query=query,
339            order=self.domain_names,
340        )
Arguments:
  • domains: mapping from domain name to GW latent (B, gw_dim)
  • encodings_pre_fusion: pre-fusion encodings (used when key_on_prefusion)
Returns:

dict[str, torch.Tensor]: per-domain attention weights.

Inherited Members
SelectionBase
update_gw_state
class RandomSelection(SelectionBase):
class RandomSelection(SelectionBase):
    """
    Modified random attention to only utilize uniform-softmax scores across modalities.
    This version omits the binary scaling factors and focuses on generating attention
    coefficients using a uniform distribution followed by a domain-wise softmax.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to uniform
                scaling factors.
        """
        super().__init__()
        # Lower temperature sharpens the random weights; higher flattens them.
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Generate uniform-then-domain-wise-softmaxed samples for each domain.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
                This is not used in the function directly but determines the structure
                of the returned attention coefficients.

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on
            uniform-softmax scores.
        """
        batch_size = group_batch_size(domains)

        # Draw one uniform score per (item, domain) on the group's device,
        # then normalize across domains with a temperature-scaled softmax.
        raw_scores = torch.rand(
            batch_size, len(domains), device=group_device(domains)
        )
        weights = torch.softmax(raw_scores / self.temperature, dim=1)

        return {name: weights[:, column] for column, name in enumerate(domains)}

Modified random attention to only utilize uniform-softmax scores across modalities. This version omits the binary scaling factors and focuses on generating attention coefficients using a uniform distribution followed by a domain-wise softmax.

RandomSelection(temperature: float)
350    def __init__(self, temperature: float):
351        """
352        Args:
353            temperature (`float`): Temperature of the softmax applied to uniform
354                scaling factors.
355        """
356        super().__init__()
357        self.temperature = temperature
Arguments:
  • temperature (float): Temperature of the softmax applied to uniform scaling factors.
temperature
def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
359    def forward(
360        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
361    ) -> dict[str, torch.Tensor]:
362        """
363        Generate uniform-then-domain-wise-softmaxed samples for each domain.
364
365        Args:
366            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
367                This is not used in the function directly but determines the structure
368                of the returned attention coefficients.
369
370        Returns:
371            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
372            coefficient for each item in the batch, based solely on
373            uniform-softmax scores.
374        """
375        num_domains = len(domains)
376        batch_size = group_batch_size(domains)
377        device = group_device(domains)
378
379        # Generate uniform scores
380        uniform_scores = torch.rand(batch_size, num_domains, device=device)
381
382        # Apply softmax across domains with temperature scaling
383        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
384        # Create attention dictionary for each domain
385        attention_dict = {
386            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
387        }
388
389        return attention_dict

Generate uniform-then-domain-wise-softmaxed samples for each domain.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations. This is not used in the function directly but determines the structure of the returned attention coefficients.
Returns:

dict[str, torch.Tensor]: For each domain in the group, the fusion coefficient for each item in the batch, based solely on uniform-softmax scores.

Inherited Members
SelectionBase
update_gw_state
class DynamicQueryAttention(SelectionBase):
class DynamicQueryAttention(SelectionBase):
    """
    Key-Query attention with a dynamic gw vector.
    The query is updated based on the scaled gw vector.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`) : dimension of the key and query vectors.
            domain_dim (`int`) : dimension of the input dims (assumed to be the same
                for now)
            domain_names  (`Iterable[str]`) : list of input domains
            n_steps (`int`) : number of steps to update the query vector
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {name: nn.Linear(domain_dim, head_size) for name in domain_names}
        )
        self.n_steps = n_steps
        # By default the update loop runs the full n_steps iterations.
        self.step_limit = n_steps
        # Random initial gw state; registered as a buffer so it moves with the
        # module across devices and is saved in the state_dict.
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Sets the step limit for the dynamic attention update loop.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.
        """
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the weighted encodings using the attention scores.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representation
            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
                domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Scale each encoding by its per-item attention weight (broadcast over
        # the feature dim), keeping only domains present in both mappings.
        weighted = [
            attention_dict[name].unsqueeze(1) * encodings[name]
            for name in attention_dict
            if name in encodings
        ]
        # Fuse by summing the weighted encodings.
        return torch.stack(weighted).sum(dim=0)

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries, match them with dot product and softmax.
        Does this twice, once with the static query and once with a dynamic query.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representation.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in the
            group.
        """
        keys = {
            name: self.key_layers[name](latent)
            for name, latent in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Static pass: query derived from the fixed random gw state.
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
        attention_scores = _calculate_attention_dict(domains, keys, query)

        # Dynamic passes: refine the query from the fused gw state and rescore.
        n_updates = min(self.step_limit, self.n_steps) if self.n_steps > 0 else 0
        for _ in range(n_updates):
            gw_state = self.fuse_weighted_encodings(
                encodings_pre_fusion, attention_scores
            )
            query = self.query_layer(gw_state)
            attention_scores = _calculate_attention_dict(domains, keys, query)

        return attention_scores

Key-Query attention with a dynamic gw vector. The query is updated based on the scaled gw vector.

DynamicQueryAttention( head_size: int, domain_dim: int, domain_names: Iterable[str], n_steps: int = 1)
398    def __init__(
399        self,
400        head_size: int,
401        domain_dim: int,
402        domain_names: Iterable[str],
403        n_steps: int = 1,
404    ):
405        """
406        Args:
407            head_size (`int`) : dimension of the key and query vectors.
408            domain_dim (`int`) : dimension of the input dims (assumed to be the same
409                for now)
410            domain_names  (`Iterable[str]`) : list of input domains
411            n_steps (`int`) : number of steps to update the query vector
412        """
413        super().__init__()
414        self.head_size = head_size
415        self.query_layer = nn.Linear(domain_dim, head_size)
416        self.key_layers = nn.ModuleDict(
417            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
418        )
419        self.n_steps = n_steps
420        self.step_limit = n_steps  # Default step limit is n_steps
421        # Start with a random gw state
422        self.register_buffer("initial_gw_state", torch.rand(domain_dim))
Arguments:
  • head_size (int) : dimension of the key and query vectors.
  • domain_dim (int) : dimension of the input dims (assumed to be the same for now)
  • domain_names (Iterable[str]) : list of input domains
  • n_steps (int) : number of steps to update the query vector
head_size
query_layer
key_layers
n_steps
step_limit
def set_step_limit(self, step_limit: int):
424    def set_step_limit(self, step_limit: int):
425        """
426        Sets the step limit for the dynamic attention update loop.
427
428        Args:
429            step_limit (`int`): Maximum number of steps to run the loop.
430        """
431        if step_limit > self.n_steps:
432            raise ValueError(
433                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
434            )
435        self.step_limit = step_limit

Sets the step limit for the dynamic attention update loop.

Arguments:
  • step_limit (int): Maximum number of steps to run the loop.
def fuse_weighted_encodings( self, encodings: Mapping[str, torch.Tensor], attention_dict: dict[str, torch.Tensor]) -> torch.Tensor:
437    def fuse_weighted_encodings(
438        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
439    ) -> torch.Tensor:
440        """
441        Fuse the weighted encodings using the attention scores.
442
443        Args:
444            encodings (`LatentsDomainGroupT`): Unimodal latent representation
445            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
446                domain in the group.
447
448        Returns:
449            `torch.Tensor`: The fused tensor.
450        """
451        # Apply attention scores to the encodings
452        weighted_encodings = {}
453        for key in attention_dict:
454            if key in encodings:
455                # Perform element-wise multiplication
456                weighted_encodings[key] = (
457                    attention_dict[key].unsqueeze(1) * encodings[key]
458                )
459
460        # Stack the tensors along a new dimension (dimension 0)
461        stacked_tensors = torch.stack(list(weighted_encodings.values()))
462
463        # Apply fusion by summing along the newly created dimension
464        summed_tensor = torch.sum(stacked_tensors, dim=0)
465        return summed_tensor

Fuse the weighted encodings using the attention scores.

Arguments:
  • encodings (LatentsDomainGroupT): Unimodal latent representation
  • attention_dict (dict[str, torch.Tensor]): The attention scores for each domain in the group.
Returns:

torch.Tensor: The fused tensor.

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
467    def forward(
468        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
469    ) -> dict[str, torch.Tensor]:
470        """
471        Compute keys and queries, match them with dot product and softmax.
472        Does this twice, once with the static query and once with a dynamic query.
473
474        Args:
475            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
476            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
477                representation.
478
479        Returns:
480            `dict[str, torch.Tensor]`: the attention scores for each domain in the
481            group.
482        """
483
484        keys = {
485            domain: self.key_layers[domain](encoding)
486            for domain, encoding in domains.items()
487        }
488
489        batch_size = group_batch_size(domains)
490
491        # Retrieve random query
492        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))
493
494        # Calculate the attention scores
495        attention_dict = _calculate_attention_dict(domains, keys, query)
496
497        if self.n_steps > 0:
498            # Update the query based on the static attention scores
499            for _ in range(min(self.step_limit, self.n_steps)):
500                # Apply the attention scores to the encodings
501                summed_tensor = self.fuse_weighted_encodings(
502                    encodings_pre_fusion, attention_dict
503                )
504
505                # Retrieve query (now it is dependent on the new gw state)
506                query = self.query_layer(summed_tensor)
507
508                # Calculate the attention scores again
509                attention_dict = _calculate_attention_dict(domains, keys, query)
510
511        return attention_dict

Compute keys and queries, match them with dot product and softmax. Does this twice, once with the static query and once with a dynamic query.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representation.
Returns:

dict[str, torch.Tensor]: the attention scores for each domain in the group.

Inherited Members
SelectionBase
update_gw_state