tensor_name = tensor_names.get(tensor)
if tensor_name is None:
continue
+ mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
mapping[key] = (tensor, tensor_name)
for bid in range(n_blocks):
if tensor_name is None:
continue
tensor_name = tensor_name.format(bid = bid)
+ mapping[tensor_name] = (tensor, tensor_name)
for key in keys:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
- def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
+ def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
if result is not None:
return result
return (result[0], result[1] + suffix)
return None
- def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
+ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
- def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
+ def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None