Skip to content

Commit a869570

Browse files
tests(mm): attempt to fix windows issue w/ mm tests
1 parent 40601d1 commit a869570

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

invokeai/backend/model_manager/util/model_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
4343
"F64": torch.float64,
4444
}[info["dtype"]]
4545

46-
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
46+
checkpoint[key] = torch.empty(tuple(info["shape"]), dtype=dtype, device=device)
4747

4848
return checkpoint
4949

tests/backend/patches/lora_conversions/lora_state_dicts/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
55
state_dict: dict[str, torch.Tensor] = {}
66
for k, shape in keys.items():
7-
state_dict[k] = torch.empty(shape)
7+
state_dict[k] = torch.empty(tuple(shape))
88
return state_dict

tests/model_identification/stripped_model_on_disk.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def dress(cls, v: Any):
6767
)
6868
case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
6969
dtype = cls.STR_TO_DTYPE[dtype_str]
70-
return torch.empty(shape, dtype=dtype)
70+
return torch.empty(tuple(shape), dtype=dtype)
7171
case dict():
7272
return {k: cls.dress(v) for k, v in v.items()}
7373
case list() | tuple():

0 commit comments

Comments
 (0)