# shimmer_ssd.modules.domains.pretrained
1from collections.abc import Mapping, Sequence 2from pathlib import Path 3 4from shimmer import DomainModule, GWDecoder, GWEncoder 5from torch.nn import Linear, Module 6 7from shimmer_ssd import PROJECT_DIR 8from shimmer_ssd.ckpt_migrations import ( 9 migrate_model, 10) 11from shimmer_ssd.config import DomainModuleVariant, LoadedDomainConfig 12from shimmer_ssd.errors import ConfigurationError 13from shimmer_ssd.modules.domains.attribute import ( 14 AttributeDomainModule, 15 AttributeLegacyDomainModule, 16 AttributeWithUnpairedDomainModule, 17) 18from shimmer_ssd.modules.domains.text import GRUTextDomainModule, Text2Attr 19from shimmer_ssd.modules.domains.visual import ( 20 VisualDomainModule, 21 VisualLatentDomainModule, 22 VisualLatentDomainWithUnpairedModule, 23) 24 25 26def load_pretrained_module(domain: LoadedDomainConfig) -> DomainModule: 27 domain_checkpoint = Path(domain.checkpoint_path) 28 module: DomainModule 29 match domain.domain_type: 30 case DomainModuleVariant.v: 31 migrate_model( 32 domain_checkpoint, 33 PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod", 34 ) 35 module = VisualDomainModule.load_from_checkpoint( 36 domain_checkpoint, **domain.args 37 ) 38 39 case DomainModuleVariant.v_latents: 40 migrate_model( 41 domain_checkpoint, 42 PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod", 43 ) 44 v_module = VisualDomainModule.load_from_checkpoint( 45 domain_checkpoint, **domain.args 46 ) 47 module = VisualLatentDomainModule(v_module) 48 49 case DomainModuleVariant.v_latents_unpaired: 50 migrate_model( 51 domain_checkpoint, 52 PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod", 53 ) 54 v_module = VisualDomainModule.load_from_checkpoint( 55 domain_checkpoint, **domain.args 56 ) 57 module = VisualLatentDomainWithUnpairedModule(v_module) 58 59 case DomainModuleVariant.attr: 60 migrate_model( 61 domain_checkpoint, 62 PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod", 63 ) 64 module = AttributeDomainModule.load_from_checkpoint( 65 
domain_checkpoint, **domain.args 66 ) 67 68 case DomainModuleVariant.attr_unpaired: 69 migrate_model( 70 domain_checkpoint, 71 PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod", 72 ) 73 module = AttributeWithUnpairedDomainModule.load_from_checkpoint( 74 domain_checkpoint, **domain.args 75 ) 76 77 case DomainModuleVariant.attr_legacy: 78 module = AttributeLegacyDomainModule() 79 80 case DomainModuleVariant.t: 81 module = GRUTextDomainModule.load_from_checkpoint( 82 domain_checkpoint, **domain.args, strict=False 83 ) 84 # Freezes the projector 85 # module.embeddings.requires_grad_(False) 86 # module.projector.requires_grad_(False) 87 88 case DomainModuleVariant.t_attr: 89 assert ( 90 "text_model_path" in domain.args 91 ), 'add "text_model_path" to the domain\'s args.' 92 text_model = GRUTextDomainModule.load_from_checkpoint( 93 domain.args["text_model_path"], 94 **domain.args.get("t_args", {}), 95 ) 96 module = Text2Attr.load_from_checkpoint( 97 domain_checkpoint, 98 text_model=text_model, 99 **domain.args.get("model_args", {}), 100 ) 101 102 case _: 103 raise ConfigurationError(f"Unknown domain type {domain.domain_type.name}") 104 return module 105 106 107def get_from_dict_or_val( 108 val: int | Mapping[DomainModuleVariant, int], key: DomainModuleVariant, log: str 109) -> int: 110 """ 111 If val is int, return val, otherwise return val[key] 112 """ 113 if isinstance(val, int): 114 return val 115 116 assert key in val, f"{key} should be defined in {log}." 
117 return val[key] 118 119 120def load_pretrained_domain( 121 domain: LoadedDomainConfig, 122 workspace_dim: int, 123 encoders_hidden_dim: int | Mapping[DomainModuleVariant, int], 124 encoders_n_layers: int | Mapping[DomainModuleVariant, int], 125 decoders_hidden_dim: int | Mapping[DomainModuleVariant, int], 126 decoders_n_layers: int | Mapping[DomainModuleVariant, int], 127 is_linear: bool = False, 128 bias: bool = False, 129) -> tuple[DomainModule, Module, Module]: 130 module = load_pretrained_module(domain) 131 encoder_hidden_dim = get_from_dict_or_val( 132 encoders_hidden_dim, domain.domain_type, "global_workspace.encoders.hidden_dim" 133 ) 134 decoder_hidden_dim = get_from_dict_or_val( 135 decoders_hidden_dim, domain.domain_type, "global_workspace.decoders.hidden_dim" 136 ) 137 138 encoder_n_layers = get_from_dict_or_val( 139 encoders_n_layers, domain.domain_type, "global_workspace.encoder.n_layers" 140 ) 141 142 decoder_n_layers = get_from_dict_or_val( 143 decoders_n_layers, domain.domain_type, "global_workspace.decoders.n_layers" 144 ) 145 146 gw_encoder: Module 147 gw_decoder: Module 148 if is_linear: 149 gw_encoder = Linear(module.latent_dim, workspace_dim, bias=bias) 150 gw_decoder = Linear(workspace_dim, module.latent_dim, bias=bias) 151 else: 152 gw_encoder = GWEncoder( 153 module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers 154 ) 155 gw_decoder = GWDecoder( 156 workspace_dim, decoder_hidden_dim, module.latent_dim, decoder_n_layers 157 ) 158 159 return module, gw_encoder, gw_decoder 160 161 162def load_pretrained_domains( 163 domains: Sequence[LoadedDomainConfig], 164 workspace_dim: int, 165 encoders_hidden_dim: int | Mapping[DomainModuleVariant, int], 166 encoders_n_layers: int | Mapping[DomainModuleVariant, int], 167 decoders_hidden_dim: int | Mapping[DomainModuleVariant, int], 168 decoders_n_layers: int | Mapping[DomainModuleVariant, int], 169 is_linear: bool = False, 170 bias: bool = False, 171) -> tuple[dict[str, DomainModule], 
dict[str, Module], dict[str, Module]]: 172 modules: dict[str, DomainModule] = {} 173 gw_encoders: dict[str, Module] = {} 174 gw_decoders: dict[str, Module] = {} 175 for domain in domains: 176 if domain.domain_type.kind.value.kind in modules: 177 raise ConfigurationError("Cannot load multiple domains of the same kind.") 178 model, encoder, decoder = load_pretrained_domain( 179 domain, 180 workspace_dim, 181 encoders_hidden_dim, 182 encoders_n_layers, 183 decoders_hidden_dim, 184 decoders_n_layers, 185 is_linear, 186 bias, 187 ) 188 modules[domain.domain_type.kind.value.kind] = model 189 gw_encoders[domain.domain_type.kind.value.kind] = encoder 190 gw_decoders[domain.domain_type.kind.value.kind] = decoder 191 return modules, gw_encoders, gw_decoders
def load_pretrained_module(domain: LoadedDomainConfig) -> DomainModule:
    """Instantiate the pretrained domain module selected by ``domain``.

    Checkpoint-backed variants are migrated first, then restored with
    ``load_from_checkpoint``; ``domain.args`` is passed through as keyword
    arguments. Raises ``ConfigurationError`` for an unknown variant.
    """
    ckpt = Path(domain.checkpoint_path)
    visual_migrations = PROJECT_DIR / "shimmer_ssd" / "migrations" / "visual_mod"
    attr_migrations = PROJECT_DIR / "shimmer_ssd" / "migrations" / "attr_mod"
    variant = domain.domain_type
    module: DomainModule

    if variant == DomainModuleVariant.v:
        migrate_model(ckpt, visual_migrations)
        module = VisualDomainModule.load_from_checkpoint(ckpt, **domain.args)
    elif variant == DomainModuleVariant.v_latents:
        migrate_model(ckpt, visual_migrations)
        backbone = VisualDomainModule.load_from_checkpoint(ckpt, **domain.args)
        module = VisualLatentDomainModule(backbone)
    elif variant == DomainModuleVariant.v_latents_unpaired:
        migrate_model(ckpt, visual_migrations)
        backbone = VisualDomainModule.load_from_checkpoint(ckpt, **domain.args)
        module = VisualLatentDomainWithUnpairedModule(backbone)
    elif variant == DomainModuleVariant.attr:
        migrate_model(ckpt, attr_migrations)
        module = AttributeDomainModule.load_from_checkpoint(ckpt, **domain.args)
    elif variant == DomainModuleVariant.attr_unpaired:
        migrate_model(ckpt, attr_migrations)
        module = AttributeWithUnpairedDomainModule.load_from_checkpoint(
            ckpt, **domain.args
        )
    elif variant == DomainModuleVariant.attr_legacy:
        # Built fresh — no checkpoint involved for the legacy module.
        module = AttributeLegacyDomainModule()
    elif variant == DomainModuleVariant.t:
        # strict=False tolerates missing keys in the text checkpoint.
        module = GRUTextDomainModule.load_from_checkpoint(
            ckpt, **domain.args, strict=False
        )
        # Freezes the projector
        # module.embeddings.requires_grad_(False)
        # module.projector.requires_grad_(False)
    elif variant == DomainModuleVariant.t_attr:
        assert (
            "text_model_path" in domain.args
        ), 'add "text_model_path" to the domain\'s args.'
        # The text-to-attribute head wraps a separately-checkpointed text model.
        text_model = GRUTextDomainModule.load_from_checkpoint(
            domain.args["text_model_path"],
            **domain.args.get("t_args", {}),
        )
        module = Text2Attr.load_from_checkpoint(
            ckpt,
            text_model=text_model,
            **domain.args.get("model_args", {}),
        )
    else:
        raise ConfigurationError(f"Unknown domain type {domain.domain_type.name}")
    return module
def get_from_dict_or_val(
    val: int | Mapping[DomainModuleVariant, int], key: DomainModuleVariant, log: str
) -> int:
    """Resolve a per-domain config value.

    A plain int applies to every domain and is returned as-is; a mapping is
    indexed by ``key``. ``log`` names the config entry for the assertion
    message only.
    """
    if not isinstance(val, int):
        assert key in val, f"{key} should be defined in {log}."
        return val[key]
    return val
def load_pretrained_domain(
    domain: LoadedDomainConfig,
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[DomainModule, Module, Module]:
    """Load a pretrained domain and build its global-workspace encoder/decoder.

    Each dim/layer argument is either a single int shared by all domains or
    a per-variant mapping resolved via :func:`get_from_dict_or_val`.

    Args:
        domain: config of the pretrained domain module to load.
        workspace_dim: dimension of the global workspace.
        encoders_hidden_dim: hidden dimension of the GW encoder.
        encoders_n_layers: number of layers of the GW encoder.
        decoders_hidden_dim: hidden dimension of the GW decoder.
        decoders_n_layers: number of layers of the GW decoder.
        is_linear: use plain ``Linear`` maps instead of GW MLPs.
        bias: bias flag for the linear maps (only when ``is_linear``).

    Returns:
        The domain module, its GW encoder, and its GW decoder.
    """
    module = load_pretrained_module(domain)
    encoder_hidden_dim = get_from_dict_or_val(
        encoders_hidden_dim, domain.domain_type, "global_workspace.encoders.hidden_dim"
    )
    decoder_hidden_dim = get_from_dict_or_val(
        decoders_hidden_dim, domain.domain_type, "global_workspace.decoders.hidden_dim"
    )

    # Log label fixed to the plural "encoders" so it matches the actual
    # config section named by the three sibling lookups.
    encoder_n_layers = get_from_dict_or_val(
        encoders_n_layers, domain.domain_type, "global_workspace.encoders.n_layers"
    )

    decoder_n_layers = get_from_dict_or_val(
        decoders_n_layers, domain.domain_type, "global_workspace.decoders.n_layers"
    )

    gw_encoder: Module
    gw_decoder: Module
    if is_linear:
        gw_encoder = Linear(module.latent_dim, workspace_dim, bias=bias)
        gw_decoder = Linear(workspace_dim, module.latent_dim, bias=bias)
    else:
        gw_encoder = GWEncoder(
            module.latent_dim, encoder_hidden_dim, workspace_dim, encoder_n_layers
        )
        gw_decoder = GWDecoder(
            workspace_dim, decoder_hidden_dim, module.latent_dim, decoder_n_layers
        )

    return module, gw_encoder, gw_decoder
def load_pretrained_domains(
    domains: Sequence[LoadedDomainConfig],
    workspace_dim: int,
    encoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    encoders_n_layers: int | Mapping[DomainModuleVariant, int],
    decoders_hidden_dim: int | Mapping[DomainModuleVariant, int],
    decoders_n_layers: int | Mapping[DomainModuleVariant, int],
    is_linear: bool = False,
    bias: bool = False,
) -> tuple[dict[str, DomainModule], dict[str, Module], dict[str, Module]]:
    """Load several pretrained domains, keyed by their domain-kind name.

    Delegates each domain to :func:`load_pretrained_domain` and collects the
    results into three dicts sharing the same keys. Raises
    ``ConfigurationError`` when two configs resolve to the same kind.
    """
    modules: dict[str, DomainModule] = {}
    gw_encoders: dict[str, Module] = {}
    gw_decoders: dict[str, Module] = {}
    for domain in domains:
        # Hoist the repeated key lookup for readability.
        kind_name = domain.domain_type.kind.value.kind
        if kind_name in modules:
            raise ConfigurationError("Cannot load multiple domains of the same kind.")
        loaded_module, gw_encoder, gw_decoder = load_pretrained_domain(
            domain,
            workspace_dim,
            encoders_hidden_dim,
            encoders_n_layers,
            decoders_hidden_dim,
            decoders_n_layers,
            is_linear,
            bias,
        )
        modules[kind_name] = loaded_module
        gw_encoders[kind_name] = gw_encoder
        gw_decoders[kind_name] = gw_decoder
    return modules, gw_encoders, gw_decoders