shimmer.modules.selection

from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch
import torch.nn as nn

from shimmer.types import LatentsDomainGroupT
from shimmer.utils import group_batch_size, group_device


class SelectionBase(torch.nn.Module, ABC):
    """
    This is the base class for the selection mechanism.
    The selection mechanism handles the "competition" between modules and *selects*
    fusion coefficients for the domains.
    """

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        """
        Update the internal copy of the previous GW state.
        By default, this does nothing.

        Note:
            This is not defined as an abstractmethod as some selection methods may
            not need it.

        Args:
            gw_state (`torch.Tensor`): the previous GW state
        """
        pass

    @abstractmethod
    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the selection method.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representations.

        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.

        Example:
            >>> SomeSelectionImplementation().forward(
            ...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)},
            ...     {"v": torch.randn(3, 12), "t": torch.randn(3, 12)},
            ... )
            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
        """
        ...

    # This is just for proper auto-completion...
    def __call__(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        return super().__call__(domains, encodings_pre_fusion)


class SingleDomainSelection(SelectionBase):
    """
    This selection mechanism handles groups that can have multiple domains, but it
    always selects a single domain from the group, chosen uniformly at random.

    For example, if the group has 2 domains, there is a 50% chance of selecting each
    domain.
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representations.

        Returns:
            `dict[str, torch.Tensor]`: whether the domain is selected for each input
            in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        choice = torch.randint(len(domains), size=(bs,), device=group_device(domains))
        for k, domain in enumerate(domains.keys()):
            selection[domain] = (choice == k).to(torch.float32)
        return selection


class FixedSharedSelection(SelectionBase):
    """
    This selection mechanism is deterministic and always shares the weights equally
    between domains.

    For example, with 2 domains it gives 0.5 to each; with 3 domains, 1/3 to each,
    and so on.
    """

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            domains (`LatentsDomainGroupT`): input unimodal latent representations
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representations.

        Returns:
            `dict[str, torch.Tensor]`: the fusion coefficient (1 / number of domains)
            for each domain and each input in the batch.
        """
        selection: dict[str, torch.Tensor] = {}
        bs = group_batch_size(domains)
        coef = torch.full((bs,), 1.0 / len(domains), device=group_device(domains))
        for domain in domains:
            selection[domain] = coef.clone()
        return selection


def _calculate_attention_dict(
    domains: LatentsDomainGroupT,
    keys: dict[str, torch.Tensor],
    query: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """
    Compute softmax attention scores between a query and per-domain keys.

    Args:
        domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
        keys (`dict[str, torch.Tensor]`): The keys for each domain.
        query (`torch.Tensor`): The query tensor.

    Returns:
        `dict[str, torch.Tensor]`: The attention scores for each domain.
    """
    dot_products = {
        domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
        for domain, key in keys.items()
    }

    dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

    attention_scores = torch.softmax(dot_products_tensor, dim=1)

    attention_dict = {
        domain: attention_scores[:, i] for i, domain in enumerate(domains)
    }
    return attention_dict


class RandomSelection(SelectionBase):
    """
    Modified random attention that only uses uniform-softmax scores across modalities.
    This version omits the binary scaling factors and generates attention
    coefficients by sampling from a uniform distribution and applying a domain-wise
    softmax.
    """

    def __init__(self, temperature: float):
        """
        Args:
            temperature (`float`): Temperature of the softmax applied to the uniform
                scaling factors.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Generate random fusion coefficients by sampling uniform scores and applying a
        domain-wise softmax.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
                This is not used directly but determines the structure of the
                returned attention coefficients.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representations (unused).

        Returns:
            `dict[str, torch.Tensor]`: For each domain in the group, the fusion
            coefficient for each item in the batch, based solely on the
            uniform-softmax scores.
        """
        num_domains = len(domains)
        batch_size = group_batch_size(domains)
        device = group_device(domains)

        # Generate uniform scores
        uniform_scores = torch.rand(batch_size, num_domains, device=device)

        # Apply softmax across domains with temperature scaling
        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)
        # Create attention dictionary for each domain
        attention_dict = {
            domain: softmax_scores[:, i] for i, domain in enumerate(domains)
        }

        return attention_dict


class DynamicQueryAttention(SelectionBase):
    """
    Key-query attention with a dynamic GW vector.
    The query is updated based on the scaled GW vector.
    """

    def __init__(
        self,
        head_size: int,
        domain_dim: int,
        domain_names: Iterable[str],
        n_steps: int = 1,
    ):
        """
        Args:
            head_size (`int`): dimension of the key and query vectors.
            domain_dim (`int`): dimension of the domain latent vectors (assumed to be
                the same for every domain for now).
            domain_names (`Iterable[str]`): list of input domains.
            n_steps (`int`): number of steps used to update the query vector.
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
        )
        self.n_steps = n_steps
        self.step_limit = n_steps  # Default step limit is n_steps
        # Start with a random gw state
        self.register_buffer("initial_gw_state", torch.rand(domain_dim))

    def set_step_limit(self, step_limit: int):
        """
        Sets the step limit for the dynamic attention update loop.

        Args:
            step_limit (`int`): Maximum number of steps to run the loop.
        """
        if step_limit > self.n_steps:
            raise ValueError(
                f"Step limit cannot exceed the maximum n_steps ({self.n_steps})."
            )
        self.step_limit = step_limit

    def fuse_weighted_encodings(
        self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Fuse the weighted encodings using the attention scores.

        Args:
            encodings (`LatentsDomainGroupT`): Unimodal latent representations.
            attention_dict (`dict[str, torch.Tensor]`): The attention scores for each
                domain in the group.

        Returns:
            `torch.Tensor`: The fused tensor.
        """
        # Apply attention scores to the encodings
        weighted_encodings = {}
        for key in attention_dict:
            if key in encodings:
                # Perform element-wise multiplication
                weighted_encodings[key] = (
                    attention_dict[key].unsqueeze(1) * encodings[key]
                )

        # Stack the tensors along a new dimension (dimension 0)
        stacked_tensors = torch.stack(list(weighted_encodings.values()))

        # Apply fusion by summing along the newly created dimension
        summed_tensor = torch.sum(stacked_tensors, dim=0)
        return summed_tensor

    def forward(
        self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
    ) -> dict[str, torch.Tensor]:
        """
        Compute keys and a query, and match them with a dot product and softmax.
        This is done once with the static (initial) query, then repeated up to
        `step_limit` times with a query derived from the fused GW state.

        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
            encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
                representations.

        Returns:
            `dict[str, torch.Tensor]`: the attention scores for each domain in the
            group.
        """

        keys = {
            domain: self.key_layers[domain](encoding)
            for domain, encoding in domains.items()
        }

        batch_size = group_batch_size(domains)

        # Retrieve the query from the initial (random) gw state
        query = self.query_layer(self.initial_gw_state.expand(batch_size, -1))

        # Calculate the attention scores
        attention_dict = _calculate_attention_dict(domains, keys, query)

        if self.n_steps > 0:
            # Update the query based on the previous attention scores
            for _ in range(min(self.step_limit, self.n_steps)):
                # Apply the attention scores to the encodings
                summed_tensor = self.fuse_weighted_encodings(
                    encodings_pre_fusion, attention_dict
                )

                # Retrieve the query (now it depends on the new gw state)
                query = self.query_layer(summed_tensor)

                # Calculate the attention scores again
                attention_dict = _calculate_attention_dict(domains, keys, query)

        return attention_dict
class SelectionBase(torch.nn.modules.module.Module, abc.ABC):

This is the base class for the selection mechanism. The selection mechanism handles the "competition" between modules and selects fusion coefficients for the domains.

def update_gw_state(self, gw_state: torch.Tensor) -> None:

Update the internal copy of the previous GW state. By default, this does nothing.

Note: This is not defined as an abstractmethod because some selection methods may not need it.

Arguments:
  • gw_state (torch.Tensor): the previous GW state
@abstractmethod
def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:

Forward pass of the selection method.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representations.
Returns:
  dict[str, torch.Tensor]: for each domain in the group, the fusion coefficient for each item in the batch.

Example:
>>> SomeSelectionImplementation().forward(
...     {"v": torch.randn(3, 4), "t": torch.randn(3, 8)},
...     {"v": torch.randn(3, 12), "t": torch.randn(3, 12)},
... )
{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
class SingleDomainSelection(SelectionBase):

This selection mechanism handles groups that can have multiple domains, but it always selects a single domain from the group, chosen uniformly at random.

For example, if the group has 2 domains, there is a 50% chance of selecting each domain.

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:

Forward pass of the module.

Arguments:
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representations.
Returns:
  dict[str, torch.Tensor]: whether the domain is selected for each input in the batch.
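
A small usage sketch with made-up tensor shapes; since this mechanism ignores the pre-fusion encodings, the same group is passed twice:

import torch

from shimmer.modules.selection import SingleDomainSelection

# Illustrative shapes: a batch of 4 items with two domains "v" and "t"
domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}
selection = SingleDomainSelection()
coefficients = selection(domains, domains)

# For every batch item, exactly one domain gets coefficient 1.0 and the other 0.0.
assert torch.all(coefficients["v"] + coefficients["t"] == 1.0)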

class FixedSharedSelection(SelectionBase):

This selection mechanism is deterministic and always shares the weights equally between domains.

For example, with 2 domains it gives 0.5 to each; with 3 domains, 1/3 to each, and so on.

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:

Forward pass of the module.

Arguments:
  • domains (LatentsDomainGroupT): input unimodal latent representations
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representations.
Returns:
  dict[str, torch.Tensor]: the fusion coefficient (1 / number of domains) for each domain and each input in the batch.
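
A small usage sketch with made-up shapes, showing the constant 1/3 coefficients for a three-domain group:

import torch

from shimmer.modules.selection import FixedSharedSelection

# Illustrative shapes: a batch of 4 items with three domains
domains = {
    "v": torch.randn(4, 12),
    "t": torch.randn(4, 12),
    "a": torch.randn(4, 12),
}
selection = FixedSharedSelection()
coefficients = selection(domains, domains)

# Every domain gets the same constant coefficient, 1/3 here.
assert torch.allclose(coefficients["v"], torch.full((4,), 1.0 / 3.0))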

class RandomSelection(SelectionBase):

Modified random attention that only uses uniform-softmax scores across modalities. This version omits the binary scaling factors and generates attention coefficients by sampling from a uniform distribution and applying a domain-wise softmax.

RandomSelection(temperature: float)
Arguments:
  • temperature (float): Temperature of the softmax applied to uniform scaling factors.
def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:

Generate random fusion coefficients by sampling uniform scores and applying a domain-wise softmax.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations. This is not used in the function directly but determines the structure of the returned attention coefficients.
Returns:
  dict[str, torch.Tensor]: For each domain in the group, the fusion coefficient for each item in the batch, based solely on the uniform-softmax scores.
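
A small usage sketch; the temperature value and tensor shapes are arbitrary choices for illustration:

import torch

from shimmer.modules.selection import RandomSelection

domains = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}

# A low temperature sharpens the softmax, pushing coefficients towards 0/1;
# a high temperature flattens them towards 1 / n_domains.
selection = RandomSelection(temperature=0.1)
coefficients = selection(domains, domains)

# The coefficients are random but always sum to one over the domains.
total = coefficients["v"] + coefficients["t"]
assert torch.allclose(total, torch.ones(4))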

class DynamicQueryAttention(SelectionBase):

Key-query attention with a dynamic GW vector. The query is updated based on the scaled GW vector.

DynamicQueryAttention( head_size: int, domain_dim: int, domain_names: Iterable[str], n_steps: int = 1)
Arguments:
  • head_size (int): dimension of the key and query vectors.
  • domain_dim (int): dimension of the domain latent vectors (assumed to be the same for all domains for now).
  • domain_names (Iterable[str]): list of input domains.
  • n_steps (int): number of steps used to update the query vector.
def set_step_limit(self, step_limit: int):

Sets the step limit for the dynamic attention update loop.

Arguments:
  • step_limit (int): Maximum number of steps to run the loop.
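
A brief usage sketch; the constructor values below are arbitrary and chosen only for illustration:

from shimmer.modules.selection import DynamicQueryAttention

# Allow up to 3 query-update steps at construction, then cap the loop to a
# single step, e.g. to trade accuracy for speed at inference time.
attention = DynamicQueryAttention(
    head_size=8, domain_dim=12, domain_names=["v", "t"], n_steps=3
)
attention.set_step_limit(1)
# Asking for more steps than n_steps raises a ValueError:
# attention.set_step_limit(5)
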
def fuse_weighted_encodings( self, encodings: Mapping[str, torch.Tensor], attention_dict: dict[str, torch.Tensor]) -> torch.Tensor:

Fuse the weighted encodings using the attention scores.

Arguments:
  • encodings (LatentsDomainGroupT): Unimodal latent representations.
  • attention_dict (dict[str, torch.Tensor]): The attention scores for each domain in the group.
Returns:
  torch.Tensor: The fused tensor.
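
In other words, with coefficients of shape (batch,) and encodings of shape (batch, dim), the fused state is the coefficient-weighted sum over domains. A small sketch with made-up shapes that mirrors this computation:

import torch

# Made-up shapes: batch of 4, workspace dimension 12
encodings = {"v": torch.randn(4, 12), "t": torch.randn(4, 12)}
attention = {"v": torch.full((4,), 0.25), "t": torch.full((4,), 0.75)}

fused = sum(attention[d].unsqueeze(1) * encodings[d] for d in encodings)
# `fused` has shape (4, 12), matching fuse_weighted_encodings(encodings, attention)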

def forward( self, domains: Mapping[str, torch.Tensor], encodings_pre_fusion: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:

Compute keys and a query, and match them with a dot product and softmax. This is done once with the static (initial) query, then repeated up to step_limit times with a query derived from the fused GW state.

Arguments:
  • domains (LatentsDomainGroupT): Group of unimodal latent representations.
  • encodings_pre_fusion (LatentsDomainGroupT): pre-fusion domain latent representations.
Returns:
  dict[str, torch.Tensor]: the attention scores for each domain in the group.
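
An end-to-end usage sketch with arbitrary dimensions. It assumes the pre-fusion encodings share the (batch_size, domain_dim) shape that the query layer expects:

import torch

from shimmer.modules.selection import DynamicQueryAttention

# Illustrative dimensions only
batch_size, domain_dim, head_size = 4, 12, 8
attention = DynamicQueryAttention(
    head_size=head_size,
    domain_dim=domain_dim,
    domain_names=["v", "t"],
    n_steps=2,
)

domains = {
    "v": torch.randn(batch_size, domain_dim),
    "t": torch.randn(batch_size, domain_dim),
}
# Stand-ins for the pre-fusion encodings, with the same (batch, domain_dim) shape.
encodings_pre_fusion = {
    "v": torch.randn(batch_size, domain_dim),
    "t": torch.randn(batch_size, domain_dim),
}

coefficients = attention(domains, encodings_pre_fusion)

# One coefficient per batch item and domain, softmax-normalized across domains.
assert coefficients["v"].shape == (batch_size,)
assert torch.allclose(coefficients["v"] + coefficients["t"], torch.ones(batch_size))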
