Skip to content

Commit c895213

Browse files
Merge pull request #191 from stochasticai/dev
dev
2 parents 21f7902 + 34f68c8 commit c895213

File tree

5 files changed

+8
-7
lines changed

5 files changed

+8
-7
lines changed

src/xturing/datasets/instruction_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@ def __init__(
5959
else:
6060
path = Path(path)
6161
assert Path(path).exists(), "path does not exist"
62-
6362
if path.is_dir():
6463
self.data = load_from_disk(str(path))
6564
elif path.suffix == ".jsonl":
@@ -123,7 +122,7 @@ def __getitem__(self, idx):
123122
return self.data["train"][idx]
124123

125124
def save(self, path):
126-
return self.data.save_to_disk(path)
125+
return self.data["train"].save_to_disk(path)
127126

128127
@classmethod
129128
def generate_dataset(

src/xturing/engines/causal.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,9 @@ def __init__(
175175

176176
self.loss_fct = CrossEntropyLoss()
177177

178+
def set_from_state_dict(self, state_dict, strict=False):
179+
self.model.load_state_dict(state_dict, strict=strict)
180+
178181
def save(self, saving_path: Union[str, Path]):
179182
# Save HF config file
180183
self.model.config.save_pretrained(str(saving_path))

src/xturing/engines/llama_engine.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -161,10 +161,6 @@ def noop(*args, **kwargs):
161161
state_dict = torch.load(
162162
weights_path / Path("pytorch_model.bin"), map_location="cpu"
163163
)
164-
new_state_dict = {}
165-
for key, value in state_dict.items():
166-
new_state_dict[key[6:]] = value
167-
model.load_state_dict(new_state_dict, strict=False)
168164

169165
if warmup_autotune:
170166
autotune_warmup(model)
@@ -192,3 +188,5 @@ def noop(*args, **kwargs):
192188
torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_
193189
torch.nn.init.uniform_ = saved_uniform_
194190
torch.nn.init.normal_ = saved_normal_
191+
192+
self.set_from_state_dict(state_dict)

src/xturing/utils/hub.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def bar_progress(current, total, width=80):
4444
sys.stdout.flush()
4545

4646
try:
47-
wget.download(url, str(zip_filename), bar=bar_progress)
47+
wget.download(url, str(zip_filename))
4848

4949
with ZipFile(zip_filename, "r") as zip_ref:
5050
zip_ref.extractall(path=model_dir)

tests/xturing/datasets/test_instruction_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,3 +42,4 @@ def test_features_dataset():
4242
"target": "second text",
4343
"instruction": "second instruction",
4444
}
45+
dataset.save(".")

0 commit comments

Comments
 (0)