Commit 21f7902

Merge pull request #180 from cnbeining/wandb
[WIP] feat: Add generic Wandb integration
2 parents 07cb039 + 5bf9d33 · commit 21f7902

File tree

6 files changed: +50 −26 lines changed


examples/int4_finetuning/LLaMA_lora_int4.ipynb

Lines changed: 21 additions & 17 deletions
@@ -57,6 +57,10 @@
    "source": [
     "from xturing.datasets.instruction_dataset import InstructionDataset\n",
     "from xturing.models import BaseModel\n",
+    "from pytorch_lightning.loggers import WandbLogger\n",
+    "\n",
+    "# Initializes WandB integration\n",
+    "wandb_logger = WandbLogger()\n",
     "\n",
     "instruction_dataset = InstructionDataset(\"../llama/alpaca_data\")\n",
     "# Initializes the model\n",
@@ -65,46 +69,46 @@
   },
   {
    "cell_type": "markdown",
-   "source": [
-    "## 3. Start the finetuning"
-   ],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "source": [
+    "## 3. Start the finetuning"
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
     "# Finetuned the model\n",
-    "model.finetune(dataset=instruction_dataset)"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+    "model.finetune(dataset=instruction_dataset, logger=wandb_logger)"
+   ]
   },
   {
    "cell_type": "markdown",
-   "source": [
-    "## 4. Generate an output text with the fine-tuned model"
-   ],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "source": [
+    "## 4. Generate an output text with the fine-tuned model"
+   ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [],
    "source": [
     "# Once the model has been finetuned, you can start doing inferences\n",
     "output = model.generate(texts=[\"Why LLM models are becoming so important?\"])\n",
     "print(\"Generated output by the model: {}\".format(output))"
-   ],
-   "metadata": {
-    "collapsed": false
-   }
+   ]
   }
  ],
 "metadata": {

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ dependencies = [
     "pydantic >= 1.10.0",
     "rouge-score >= 0.1.2",
     "accelerate",
+    "wandb",
 ]
 
 [project.scripts]
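
The new wandb dependency still needs credentials before WandbLogger can stream anything; that setup is not part of this commit, but a minimal sketch of the usual step is:

import wandb

# Picks up the WANDB_API_KEY environment variable if set, otherwise prompts
# interactively; running `wandb login` once on the machine works as well.
wandb.login()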

src/xturing/models/causal.py

Lines changed: 11 additions & 5 deletions
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Iterable, List, Optional, Union
+from typing import Iterable, List, Optional, Union, Type
 
 import torch
 from torch.utils.data import DataLoader
@@ -18,6 +18,7 @@
 from xturing.trainers.base import BaseTrainer
 from xturing.trainers.lightning_trainer import LightningTrainer
 from xturing.utils.logging import configure_logger
+from pytorch_lightning.loggers import Logger
 
 logger = configure_logger(__name__)
 
@@ -63,7 +64,8 @@ def _make_collate_fn(self, dataset: Union[TextDataset, InstructionDataset]):
             dataset.meta,
         )
 
-    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
+    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
+                      logger: Union[Logger, Iterable[Logger], bool] = True):
         return BaseTrainer.create(
             LightningTrainer.config_name,
             self.engine,
@@ -73,14 +75,16 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
             int(self.finetuning_args.batch_size),
             float(self.finetuning_args.learning_rate),
             self.finetuning_args.optimizer_name,
+            logger=logger,
         )
 
-    def finetune(self, dataset: Union[TextDataset, InstructionDataset]):
+    def finetune(self, dataset: Union[TextDataset, InstructionDataset],
+                 logger: Union[Logger, Iterable[Logger], bool] = True):
         assert dataset.config_name in [
             "text_dataset",
             "instruction_dataset",
         ], "Please make sure the dataset_type is text_dataset or instruction_dataset"
-        trainer = self._make_trainer(dataset)
+        trainer = self._make_trainer(dataset, logger)
         trainer.fit()
 
     def evaluate(self, dataset: Union[TextDataset, InstructionDataset]):
@@ -188,7 +192,8 @@ class CausalLoraModel(CausalModel):
     def __init__(self, engine: str, weights_path: Optional[str] = None):
         super().__init__(engine, weights_path)
 
-    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
+    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
+                      logger: Union[Logger, Iterable[Logger], bool] = True):
         return BaseTrainer.create(
             LightningTrainer.config_name,
             self.engine,
@@ -200,6 +205,7 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
             self.finetuning_args.optimizer_name,
             True,
             True,
+            logger=logger,
         )
 
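The net effect of these signature changes is that finetune() now forwards an optional logger straight through _make_trainer into the trainer, and the Union[Logger, Iterable[Logger], bool] annotation mirrors what PyTorch Lightning accepts. A sketch of the call patterns this enables; the config name, dataset path, and logger choices below are illustrative, not prescribed by this commit:

# Sketch of the accepted logger values; names and paths are placeholders.
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

dataset = InstructionDataset("../llama/alpaca_data")
model = BaseModel.create("llama_lora")  # assumed config name for the sketch

model.finetune(dataset=dataset)                        # logger=True: Lightning's default logger
model.finetune(dataset=dataset, logger=False)          # disable logging entirely
model.finetune(dataset=dataset, logger=WandbLogger())  # a single W&B logger
model.finetune(
    dataset=dataset,
    logger=[WandbLogger(), TensorBoardLogger("logs/")],  # or several loggers at once
)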

src/xturing/models/llama.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,5 @@
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Union
+from pytorch_lightning.loggers import Logger
 
 from xturing.engines.llama_engine import (
     LLamaEngine,
@@ -50,7 +51,8 @@ def __init__(self, weights_path: Optional[str] = None):
 class LlamaLoraInt4(CausalLoraInt8Model):
     config_name: str = "llama_lora_int4"
 
-    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
+    def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset],
+                      logger: Union[Logger, Iterable[Logger], bool] = True):
         return BaseTrainer.create(
             LightningTrainer.config_name,
             self.engine,
@@ -63,6 +65,7 @@ def _make_trainer(self, dataset: Union[TextDataset, InstructionDataset]):
             True,
             True,
             lora_type=32,
+            logger=logger,
         )
 
     def __init__(self, weights_path: Optional[str] = None):

src/xturing/models/stable_diffusion.py

Lines changed: 6 additions & 1 deletion
@@ -10,7 +10,12 @@ class StableDiffusion:
     def __init__(self, weights_path: str):
         pass
 
-    def finetune(self, dataset: Text2ImageDataset):
+    def finetune(self, dataset: Text2ImageDataset, logger=True):
+        """Finetune Stable Diffusion model on a given dataset.
+
+        Args:
+            dataset (Text2ImageDataset): Dataset to finetune on.
+            logger (bool, optional): To be setup with a Pytorch Lightning logger when implemented."""
         pass
 
     def generate(

src/xturing/trainers/lightning_trainer.py

Lines changed: 6 additions & 1 deletion
@@ -3,13 +3,14 @@
 import tempfile
 import uuid
 from pathlib import Path
-from typing import Optional, Union
+from typing import Iterable, Optional, Union, Type
 
 import pytorch_lightning as pl
 import torch
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from pytorch_lightning import callbacks
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.loggers import Logger
 
 from xturing.config import DEFAULT_DEVICE, IS_INTERACTIVE
 from xturing.datasets.base import BaseDataset
@@ -101,6 +102,7 @@ def __init__(
         use_deepspeed: bool = False,
         max_training_time_in_secs: Optional[int] = None,
         lora_type: int = 16,
+        logger: Union[Logger, Iterable[Logger], bool] = True,
     ):
         self.lightning_model = TuringLightningModule(
             model_engine=model_engine,
@@ -145,6 +147,7 @@ def __init__(
                 callbacks=training_callbacks,
                 enable_checkpointing=False,
                 log_every_n_steps=50,
+                logger=logger,
             )
         elif not use_lora and not use_deepspeed:
             self.trainer = Trainer(
@@ -154,6 +157,7 @@ def __init__(
                 callbacks=training_callbacks,
                 enable_checkpointing=True,
                 log_every_n_steps=50,
+                logger=logger,
             )
         else:
             training_callbacks = [
@@ -179,6 +183,7 @@ def __init__(
                 callbacks=training_callbacks,
                 enable_checkpointing=True,
                 log_every_n_steps=50,
+                logger=logger,
             )
 
     def fit(self):
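
At the bottom of the stack the value is handed unchanged to PyTorch Lightning's Trainer, which is why a single Logger, a list of loggers, or a plain bool all work. A sketch of what the Trainer ultimately receives; the values not set by this commit are assumptions for illustration:

# Illustration of the Trainer call the logger argument ends up in.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

trainer = Trainer(
    max_epochs=3,            # assumed value, purely for illustration
    log_every_n_steps=50,    # matches the value used in LightningTrainer
    logger=WandbLogger(),    # also accepts True (default logger), False, or a list of loggers
)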
