'ssm_d': 'State space model skip connection',
'ssm_dt': 'State space model time step',
'ssm_out': 'State space model output projection',
- 'blk': 'Block'
+ 'blk': 'Block',
+ 'enc': 'Encoder',
+ 'dec': 'Decoder',
}
expanded_words = []
tensor_group_name = "base"
if tensor_components[0] == 'blk':
tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}"
+ elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk':
+ tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}"
+ elif tensor_components[0] in ['enc', 'dec']:
+ tensor_group_name = f"{tensor_components[0]}"
# Check if new Tensor Group
if tensor_group_name not in tensor_groups: