shimmer_ssd.modules.domains.pretrained

  1from collections.abc import Mapping, Sequence
  2from pathlib import Path
  3
  4from shimmer import DomainModule, GWDecoder, GWEncoder
  5from torch.nn import Linear, Module
  6
  7from shimmer_ssd import PROJECT_DIR
  8from shimmer_ssd.ckpt_migrations import (
  9    migrate_model,
 10)
 11from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig
 12from shimmer_ssd.errors import ConfigurationError
 13from shimmer_ssd.modules.domains.attribute import (
 14    AttributeDomainModule,
 15    AttributeLegacyDomainModule,
 16    AttributeWithUnpairedDomainModule,
 17)
 18from shimmer_ssd.modules.domains.text import GRUTextDomainModule, Text2Attr
 19from shimmer_ssd.modules.domains.visual import (
 20    VisualDomainModule,
 21    VisualLatentDomainModule,
 22    VisualLatentDomainWithUnpairedModule,
 23)
 24
 25
 26def load_pretrained_module(domain: LoadedDomainConfig) -> DomainModule:
 27    domain_checkpoint = Path(domain.checkpoint_path)
 28    module: DomainModule
 29    match domain.domain_type:
 30        case DomainModuleVariant.v:
 31            migrate_model(
 32                domain_checkpoint,
 33                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 34            )
 35            module = VisualDomainModule.load_from_checkpoint(
 36                domain_checkpoint, **domain.args
 37            )
 38
 39        case DomainModuleVariant.v_latents:
 40            migrate_model(
 41                domain_checkpoint,
 42                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 43            )
 44            v_module = VisualDomainModule.load_from_checkpoint(
 45                domain_checkpoint, **domain.args
 46            )
 47            module = VisualLatentDomainModule(v_module)
 48
 49        case DomainModuleVariant.v_latents_unpaired:
 50            migrate_model(
 51                domain_checkpoint,
 52                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 53            )
 54            v_module = VisualDomainModule.load_from_checkpoint(
 55                domain_checkpoint, **domain.args
 56            )
 57            module = VisualLatentDomainWithUnpairedModule(v_module)
 58
 59        case DomainModuleVariant.attr:
 60            migrate_model(
 61                domain_checkpoint,
 62                PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod",
 63            )
 64            module = AttributeDomainModule.load_from_checkpoint(
 65                domain_checkpoint, **domain.args
 66            )
 67
 68        case DomainModuleVariant.attr_unpaired:
 69            migrate_model(
 70                domain_checkpoint,
 71                PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod",
 72            )
 73            module = AttributeWithUnpairedDomainModule.load_from_checkpoint(
 74                domain_checkpoint, **domain.args
 75            )
 76
 77        case DomainModuleVariant.attr_legacy:
 78            module = AttributeLegacyDomainModule()
 79
 80        case DomainModuleVariant.t:
 81            module = GRUTextDomainModule.load_from_checkpoint(
 82                domain_checkpoint, **domain.args, strict=False
 83            )
 84            # Freezes the projector
 85            # module.embeddings.requires_grad_(False)
 86            # module.projector.requires_grad_(False)
 87
 88        case DomainModuleVariant.t_attr:
 89            assert (
 90                "text_model_path" in domain.args
 91            ), 'add "text_model_path" to the domain\'s args.'
 92            text_model = GRUTextDomainModule.load_from_checkpoint(
 93                domain.args["text_model_path"],
 94                **domain.args.get("t_args", {}),
 95            )
 96            module = Text2Attr.load_from_checkpoint(
 97                domain_checkpoint,
 98                text_model=text_model,
 99                **domain.args.get("model_args", {}),
100            )
101
102        case _:
103            raise ConfigurationError(f"Unknown domain type {domain.domain_type.name}")
104    return module
105
106
def get_from_dict_or_val(
    val: int | Mapping[DomainModuleVariant, int], key: DomainModuleVariant, log: str
) -> int:
    """
    Resolve a per-domain integer setting.

    Args:
        val: either a single int applied to every domain, or a mapping from
            domain variant to int.
        key: the domain variant to look up when `val` is a mapping.
        log: config path used in the assertion message when `key` is missing.

    Returns:
        `val` itself when it is an int, otherwise `val[key]`.
    """
    if isinstance(val, int):
        return val

    assert key in val, f"{key} should be defined in {log}."
    return val[key]
118
119
def load_pretrained_domain(
    domain: LoadedDomainConfig,
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[DomainModule, Module, Module]:
    """
    Load a pretrained domain module with its global-workspace encoder/decoder.

    Args:
        domain: config for the pretrained domain module to load.
        workspace_dim: dimensionality of the global workspace.
        encoders_hidden_dim: hidden dim of the GW encoder (single int or
            per-domain mapping).
        encoders_n_layers: number of layers of the GW encoder.
        decoders_hidden_dim: hidden dim of the GW decoder.
        decoders_n_layers: number of layers of the GW decoder.
        is_linear: if True, use plain `Linear` layers instead of
            `GWEncoder`/`GWDecoder` (hidden dims / n_layers are then ignored).
        bias: whether the linear layers use a bias (only when `is_linear`).

    Returns:
        `(module, gw_encoder, gw_decoder)`.
    """
    module = load_pretrained_module(domain)
    encoder_hidden_dim = get_from_dict_or_val(
        encoders_hidden_dim, domain.domain_type, "global_workspace.encoders.hidden_dim"
    )
    decoder_hidden_dim = get_from_dict_or_val(
        decoders_hidden_dim, domain.domain_type, "global_workspace.decoders.hidden_dim"
    )

    encoder_n_layers = get_from_dict_or_val(
        encoders_n_layers, domain.domain_type, "global_workspace.encoder.n_layers"
    )

    decoder_n_layers = get_from_dict_or_val(
        decoders_n_layers, domain.domain_type, "global_workspace.decoders.n_layers"
    )

    gw_encoder: Module
    gw_decoder: Module
    if is_linear:
        gw_encoder = Linear(module.latent_dim, workspace_dim, bias=bias)
        gw_decoder = Linear(workspace_dim, module.latent_dim, bias=bias)
    else:
        gw_encoder = GWEncoder(
            module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers
        )
        gw_decoder = GWDecoder(
            workspace_dim, decoder_hidden_dim, module.latent_dim, decoder_n_layers
        )

    return module, gw_encoder, gw_decoder
160
161
def load_pretrained_domains(
    domains: Sequence[LoadedDomainConfig],
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[dict[str, DomainModule], dict[str, Module], dict[str, Module]]:
    """
    Load several pretrained domains with their GW encoders/decoders.

    Args:
        domains: configs of the pretrained domain modules to load.
        workspace_dim: dimensionality of the global workspace.
        encoders_hidden_dim: hidden dim of the GW encoders (single int or
            per-domain mapping).
        encoders_n_layers: number of layers of the GW encoders.
        decoders_hidden_dim: hidden dim of the GW decoders.
        decoders_n_layers: number of layers of the GW decoders.
        is_linear: if True, use plain `Linear` layers for encoders/decoders.
        bias: whether the linear layers use a bias (only when `is_linear`).

    Returns:
        Three dicts keyed by domain kind: modules, GW encoders, GW decoders.

    Raises:
        ConfigurationError: if two configs resolve to the same domain kind.
    """
    modules: dict[str, DomainModule] = {}
    gw_encoders: dict[str, Module] = {}
    gw_decoders: dict[str, Module] = {}
    for domain in domains:
        # String key identifying the domain kind, shared by the three dicts.
        kind = domain.domain_type.kind.value.kind
        if kind in modules:
            raise ConfigurationError("Cannot load multiple domains of the same kind.")
        model, encoder, decoder = load_pretrained_domain(
            domain,
            workspace_dim,
            encoders_hidden_dim,
            encoders_n_layers,
            decoders_hidden_dim,
            decoders_n_layers,
            is_linear,
            bias,
        )
        modules[kind] = model
        gw_encoders[kind] = encoder
        gw_decoders[kind] = decoder
    return modules, gw_encoders, gw_decoders
def load_pretrained_module( domain: shimmer_ssd.config.LoadedDomainConfig) -> shimmer.modules.domain.DomainModule:
 27def load_pretrained_module(domain: LoadedDomainConfig) -> DomainModule:
 28    domain_checkpoint = Path(domain.checkpoint_path)
 29    module: DomainModule
 30    match domain.domain_type:
 31        case DomainModuleVariant.v:
 32            migrate_model(
 33                domain_checkpoint,
 34                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 35            )
 36            module = VisualDomainModule.load_from_checkpoint(
 37                domain_checkpoint, **domain.args
 38            )
 39
 40        case DomainModuleVariant.v_latents:
 41            migrate_model(
 42                domain_checkpoint,
 43                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 44            )
 45            v_module = VisualDomainModule.load_from_checkpoint(
 46                domain_checkpoint, **domain.args
 47            )
 48            module = VisualLatentDomainModule(v_module)
 49
 50        case DomainModuleVariant.v_latents_unpaired:
 51            migrate_model(
 52                domain_checkpoint,
 53                PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod",
 54            )
 55            v_module = VisualDomainModule.load_from_checkpoint(
 56                domain_checkpoint, **domain.args
 57            )
 58            module = VisualLatentDomainWithUnpairedModule(v_module)
 59
 60        case DomainModuleVariant.attr:
 61            migrate_model(
 62                domain_checkpoint,
 63                PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod",
 64            )
 65            module = AttributeDomainModule.load_from_checkpoint(
 66                domain_checkpoint, **domain.args
 67            )
 68
 69        case DomainModuleVariant.attr_unpaired:
 70            migrate_model(
 71                domain_checkpoint,
 72                PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod",
 73            )
 74            module = AttributeWithUnpairedDomainModule.load_from_checkpoint(
 75                domain_checkpoint, **domain.args
 76            )
 77
 78        case DomainModuleVariant.attr_legacy:
 79            module = AttributeLegacyDomainModule()
 80
 81        case DomainModuleVariant.t:
 82            module = GRUTextDomainModule.load_from_checkpoint(
 83                domain_checkpoint, **domain.args, strict=False
 84            )
 85            # Freezes the projector
 86            # module.embeddings.requires_grad_(False)
 87            # module.projector.requires_grad_(False)
 88
 89        case DomainModuleVariant.t_attr:
 90            assert (
 91                "text_model_path" in domain.args
 92            ), 'add "text_model_path" to the domain\'s args.'
 93            text_model = GRUTextDomainModule.load_from_checkpoint(
 94                domain.args["text_model_path"],
 95                **domain.args.get("t_args", {}),
 96            )
 97            module = Text2Attr.load_from_checkpoint(
 98                domain_checkpoint,
 99                text_model=text_model,
100                **domain.args.get("model_args", {}),
101            )
102
103        case _:
104            raise ConfigurationError(f"Unknown domain type {domain.domain_type.name}")
105    return module
def get_from_dict_or_val(
    val: int | Mapping[DomainModuleVariant, int], key: DomainModuleVariant, log: str
) -> int:
    """
    Resolve a per-domain integer setting.

    Args:
        val: either a single int applied to every domain, or a mapping from
            domain variant to int.
        key: the domain variant to look up when `val` is a mapping.
        log: config path used in the assertion message when `key` is missing.

    Returns:
        `val` itself when it is an int, otherwise `val[key]`.
    """
    if isinstance(val, int):
        return val

    assert key in val, f"{key} should be defined in {log}."
    return val[key]
def load_pretrained_domain(
    domain: LoadedDomainConfig,
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[DomainModule, Module, Module]:
    """
    Load a pretrained domain module with its global-workspace encoder/decoder.

    Args:
        domain: config for the pretrained domain module to load.
        workspace_dim: dimensionality of the global workspace.
        encoders_hidden_dim: hidden dim of the GW encoder (single int or
            per-domain mapping).
        encoders_n_layers: number of layers of the GW encoder.
        decoders_hidden_dim: hidden dim of the GW decoder.
        decoders_n_layers: number of layers of the GW decoder.
        is_linear: if True, use plain `Linear` layers instead of
            `GWEncoder`/`GWDecoder` (hidden dims / n_layers are then ignored).
        bias: whether the linear layers use a bias (only when `is_linear`).

    Returns:
        `(module, gw_encoder, gw_decoder)`.
    """
    module = load_pretrained_module(domain)
    encoder_hidden_dim = get_from_dict_or_val(
        encoders_hidden_dim, domain.domain_type, "global_workspace.encoders.hidden_dim"
    )
    decoder_hidden_dim = get_from_dict_or_val(
        decoders_hidden_dim, domain.domain_type, "global_workspace.decoders.hidden_dim"
    )

    encoder_n_layers = get_from_dict_or_val(
        encoders_n_layers, domain.domain_type, "global_workspace.encoder.n_layers"
    )

    decoder_n_layers = get_from_dict_or_val(
        decoders_n_layers, domain.domain_type, "global_workspace.decoders.n_layers"
    )

    gw_encoder: Module
    gw_decoder: Module
    if is_linear:
        gw_encoder = Linear(module.latent_dim, workspace_dim, bias=bias)
        gw_decoder = Linear(workspace_dim, module.latent_dim, bias=bias)
    else:
        gw_encoder = GWEncoder(
            module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers
        )
        gw_decoder = GWDecoder(
            workspace_dim, decoder_hidden_dim, module.latent_dim, decoder_n_layers
        )

    return module, gw_encoder, gw_decoder
def load_pretrained_domains(
    domains: Sequence[LoadedDomainConfig],
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[dict[str, DomainModule], dict[str, Module], dict[str, Module]]:
    """
    Load several pretrained domains with their GW encoders/decoders.

    Args:
        domains: configs of the pretrained domain modules to load.
        workspace_dim: dimensionality of the global workspace.
        encoders_hidden_dim: hidden dim of the GW encoders (single int or
            per-domain mapping).
        encoders_n_layers: number of layers of the GW encoders.
        decoders_hidden_dim: hidden dim of the GW decoders.
        decoders_n_layers: number of layers of the GW decoders.
        is_linear: if True, use plain `Linear` layers for encoders/decoders.
        bias: whether the linear layers use a bias (only when `is_linear`).

    Returns:
        Three dicts keyed by domain kind: modules, GW encoders, GW decoders.

    Raises:
        ConfigurationError: if two configs resolve to the same domain kind.
    """
    modules: dict[str, DomainModule] = {}
    gw_encoders: dict[str, Module] = {}
    gw_decoders: dict[str, Module] = {}
    for domain in domains:
        # String key identifying the domain kind, shared by the three dicts.
        kind = domain.domain_type.kind.value.kind
        if kind in modules:
            raise ConfigurationError("Cannot load multiple domains of the same kind.")
        model, encoder, decoder = load_pretrained_domain(
            domain,
            workspace_dim,
            encoders_hidden_dim,
            encoders_n_layers,
            decoders_hidden_dim,
            decoders_n_layers,
            is_linear,
            bias,
        )
        modules[kind] = model
        gw_encoders[kind] = encoder
        gw_decoders[kind] = decoder
    return modules, gw_encoders, gw_decoders