
Commit f6bfaee

Merge pull request #7 from jeremymanning/main
Add flexible device support: CPU, MPS, and unlimited GPU scaling
2 parents e2e16f9 + 40920c2 commit f6bfaee

File tree

3 files changed: +147 −46 lines changed


code/generate_figures.py

Lines changed: 26 additions & 3 deletions
@@ -21,14 +21,26 @@
 from llm_stylometry.cli_utils import safe_print, format_header, is_windows
 
 
-def train_models():
+def train_models(max_gpus=None):
     """Train all models from scratch."""
     safe_print("\n" + "=" * 60)
     safe_print("Training Models from Scratch")
     safe_print("=" * 60)
     warning = "[WARNING]" if is_windows() else "⚠️"
+    # Check device availability
+    import torch
+    device_info = ""
+    if torch.cuda.is_available():
+        gpu_count = torch.cuda.device_count()
+        device_info = f"CUDA GPUs available: {gpu_count}"
+    elif torch.backends.mps.is_available():
+        device_info = "Apple Metal Performance Shaders (MPS) available"
+    else:
+        device_info = "CPU only (training will be slow)"
+
     safe_print(f"\n{warning} Warning: This will train 80 models (8 authors × 10 seeds)")
-    safe_print(" This requires a CUDA GPU and will take several hours.")
+    safe_print(f" Device: {device_info}")
+    safe_print(" Training time depends on hardware (hours on GPU, days on CPU)")
 
     response = input("\nProceed with training? [y/N]: ")
     if response.lower() != 'y':
@@ -66,6 +78,10 @@ def train_models():
     env['NO_MULTIPROCESSING'] = '1'
     # Set PyTorch memory management for better GPU memory usage
     env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+    # Pass through max GPUs limit if specified
+    if max_gpus:
+        env['MAX_GPUS'] = str(max_gpus)
+        safe_print(f"Limiting to {max_gpus} GPU(s)")
     # Run without capturing output so we can see progress
     result = subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)
     if result.returncode != 0:
@@ -182,6 +198,13 @@ def main():
         help='List available figures'
     )
 
+    parser.add_argument(
+        '--max-gpus', '-g',
+        type=int,
+        help='Maximum number of GPUs to use for training (default: all available)',
+        default=None
+    )
+
     args = parser.parse_args()
 
     if args.list:

@@ -199,7 +222,7 @@ def main():
 
     # Train models if requested
     if args.train:
-        if not train_models():
+        if not train_models(max_gpus=args.max_gpus):
             return 1
         # Update data path to use newly generated results
         args.data = 'data/model_results.pkl'
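
Note: the handoff between the new --max-gpus flag and the training run boils down to an environment variable that code/main.py reads back. A minimal standalone sketch of that plumbing, under the assumptions visible in this diff (the launch_training helper below is hypothetical, not repo code):

# Minimal sketch of the --max-gpus -> MAX_GPUS handoff (illustrative only).
import os
import subprocess
import sys

def launch_training(max_gpus=None):
    # Copy the current environment, mirroring what train_models() does in this patch.
    env = os.environ.copy()
    env['NO_MULTIPROCESSING'] = '1'          # run the trainer sequentially in the subprocess
    if max_gpus:
        env['MAX_GPUS'] = str(max_gpus)      # read back by code/main.py (see diff below)
    return subprocess.run([sys.executable, 'code/main.py'], env=env, check=False)

if __name__ == "__main__":
    result = launch_training(max_gpus=2)
    print("exit code:", result.returncode)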

code/main.py

Lines changed: 93 additions & 42 deletions
@@ -36,8 +36,20 @@ def tqdm(iterable, *args, **kwargs):
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-if not torch.cuda.is_available():
-    raise Exception("No GPU available")
+# Detect available devices
+def get_device_info():
+    """Detect and return device configuration."""
+    if torch.cuda.is_available():
+        device_count = torch.cuda.device_count()
+        return "cuda", device_count
+    elif torch.backends.mps.is_available():
+        # Apple Metal Performance Shaders (MPS) backend
+        return "mps", 1
+    else:
+        return "cpu", 1
+
+device_type, device_count = get_device_info()
+logger.info(f"Device type: {device_type}, Count: {device_count}")
 
 experiments = []
 for seed in range(10):
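
Note: get_device_info() prefers CUDA, then MPS, then CPU. To see what a given machine will report before starting a long training run, the same availability checks can be run standalone (a quick probe, not repo code):

# Quick probe of the availability checks used by get_device_info() above.
import torch

if torch.cuda.is_available():
    names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    print("cuda", torch.cuda.device_count(), names)
elif torch.backends.mps.is_available():
    print("mps", 1)
else:
    print("cpu", 1)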
@@ -51,16 +63,26 @@ def tqdm(iterable, *args, **kwargs):
     )
 
 
-def run_experiment(exp: Experiment, gpu_queue):
+def run_experiment(exp: Experiment, device_queue, device_type="cuda"):
     try:
         logging.basicConfig(level=logging.INFO)
         logger = logging.getLogger(__name__)
 
-        # Get an available GPU id
-        gpu_id = gpu_queue.get()
+        # Get an available device id
+        device_id = device_queue.get() if device_queue else 0
         logger.info(f"Starting experiment: {exp.name}")
-        torch.cuda.set_device(gpu_id)
-        device = torch.device("cuda", index=gpu_id)
+
+        # Set up device based on type
+        if device_type == "cuda":
+            torch.cuda.set_device(device_id)
+            device = torch.device("cuda", index=device_id)
+            device_label = f"GPU {device_id}"
+        elif device_type == "mps":
+            device = torch.device("mps")
+            device_label = "MPS"
+        else:
+            device = torch.device("cpu")
+            device_label = "CPU"
 
         # Initialize tokenizer directly using get_tokenizer
         tokenizer = get_tokenizer(exp.tokenizer_name)
@@ -82,7 +104,7 @@ def run_experiment(exp: Experiment, gpu_queue):
            excluded_train_path=exp.excluded_train_path,
         )
         logger.info(
-            f"[GPU {gpu_id}] Number of training batches: {len(train_dataloader)}"
+            f"[{device_label}] Number of training batches: {len(train_dataloader)}"
         )
 
         # Set up eval dataloaders
@@ -130,7 +152,7 @@ def run_experiment(exp: Experiment, gpu_queue):
            start_epoch = 0
 
         logger.info(
-            f"[GPU {gpu_id}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
+            f"[{device_label}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
         )
 
         # Initial evaluation (epochs_complete = 0)
@@ -151,15 +173,16 @@ def run_experiment(exp: Experiment, gpu_queue):
            train_author=exp.train_author,
         )
 
-        # Set up mixed precision training for memory efficiency
-        scaler = torch.amp.GradScaler('cuda')
+        # Set up mixed precision training if supported
+        use_amp = device_type == "cuda"
+        scaler = torch.amp.GradScaler('cuda') if use_amp else None
 
         # Enable gradient checkpointing to save memory (if supported)
         try:
             model.gradient_checkpointing_enable()
-            logger.info(f"[GPU {gpu_id}] Gradient checkpointing enabled for memory efficiency")
+            logger.info(f"[{device_label}] Gradient checkpointing enabled for memory efficiency")
         except AttributeError:
-            logger.info(f"[GPU {gpu_id}] Model does not support gradient checkpointing")
+            logger.info(f"[{device_label}] Model does not support gradient checkpointing")
 
         # Training loop
         for epoch in tqdm(range(start_epoch, max_epochs)):
@@ -171,16 +194,24 @@ def run_experiment(exp: Experiment, gpu_queue):
 
                 input_ids = batch["input_ids"].to(device)
 
-                # Forward pass with mixed precision
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                # Forward pass with or without mixed precision
+                if use_amp:
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                        outputs = model(input_ids=input_ids, labels=input_ids)
+                        loss = outputs.loss
+                else:
                     outputs = model(input_ids=input_ids, labels=input_ids)
                     loss = outputs.loss
 
-                # Backward pass with scaled gradients
+                # Backward pass with or without mixed precision
                 optimizer.zero_grad()
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
+                if use_amp:
+                    scaler.scale(loss).backward()
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    loss.backward()
+                    optimizer.step()
 
                 # Accumulate training loss
                 total_train_loss += loss.item()
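
Note: the step above only uses autocast and GradScaler when use_amp is true (i.e. on CUDA); MPS and CPU fall back to a plain fp32 update. A self-contained sketch of that pattern — the tiny linear model, random data, and SGD optimizer are stand-ins for illustration, not code from this repo:

# Sketch of the conditional mixed-precision step; model/data/optimizer are placeholders.
import torch

device_type = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda:0" if device_type == "cuda" else device_type)
use_amp = device_type == "cuda"
scaler = torch.amp.GradScaler('cuda') if use_amp else None

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(8, 16, device=device)
y = torch.randn(8, 1, device=device)

optimizer.zero_grad()
if use_amp:
    # fp16 autocast + gradient scaling, CUDA only
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    # Plain fp32 step on MPS or CPU
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
print(f"loss: {loss.item():.4f}")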
@@ -230,11 +261,12 @@ def run_experiment(exp: Experiment, gpu_queue):
                train_author=exp.train_author,
             )
 
-            # Force memory cleanup between evaluations
-            torch.cuda.empty_cache()
+            # Force memory cleanup between evaluations (CUDA only)
+            if device_type == "cuda":
+                torch.cuda.empty_cache()
 
             # Build log message for console output
-            log_message = f"[GPU {gpu_id}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
+            log_message = f"[{device_label}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
             for name, loss in eval_losses.items():
                 log_message += f", {name}: {loss:.4f}"
             logger.info(log_message)
@@ -249,13 +281,14 @@ def run_experiment(exp: Experiment, gpu_queue):
             # Early stopping after completing epoch (retain logs and checkpoints)
             if train_loss <= stop_train_loss and min_epochs <= epochs_completed:
                 logger.info(
-                    f"[GPU {gpu_id}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
+                    f"[{device_label}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
                 )
                 break
-        logger.info(f"[GPU {gpu_id}] Training complete for {modelname}")
+        logger.info(f"[{device_label}] Training complete for {modelname}")
 
         # Return the GPU id to the queue
-        gpu_queue.put(gpu_id)
+        if device_queue:
+            device_queue.put(device_id)
     except Exception:
         logger.exception(f"Error in experiment {exp.name}")
         raise
@@ -265,16 +298,29 @@ def run_experiment(exp: Experiment, gpu_queue):
 # Check if we should run sequentially (for subprocess compatibility)
 USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'
 
-device_count = torch.cuda.device_count()
-gpu_count = min(device_count, 4)
-print(f"Using {gpu_count} GPUs out of {device_count} available")
+# Use already detected device configuration
+if device_type == "cuda":
+    # Check for MAX_GPUS environment variable to optionally limit GPU usage
+    max_gpus = int(os.environ.get('MAX_GPUS', '0')) or device_count
+    gpu_count = min(device_count, max_gpus)
+    if gpu_count < device_count:
+        print(f"Using {gpu_count} GPUs (limited by MAX_GPUS) out of {device_count} available")
+    else:
+        print(f"Using all {gpu_count} available GPUs")
+elif device_type == "mps":
+    gpu_count = 1
+    print("Using Apple Metal Performance Shaders (MPS)")
+else:
+    gpu_count = 1
+    print("Using CPU for training (this will be slow)")
 
-if USE_MULTIPROCESSING:
+if USE_MULTIPROCESSING and device_type == "cuda" and gpu_count > 1:
+    # Only use multiprocessing for multiple CUDA GPUs
     mp.set_start_method("spawn", force=True)
     manager = mp.Manager()
-    gpu_queue = manager.Queue()
+    device_queue = manager.Queue()
     for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+        device_queue.put(gpu)
 
     pool = mp.Pool(processes=gpu_count)
     logger = logging.getLogger(__name__)
@@ -286,22 +332,27 @@ def error_callback(e):
 
     for exp in experiments:
         pool.apply_async(
-            run_experiment, (exp, gpu_queue), error_callback=error_callback
+            run_experiment, (exp, device_queue, device_type), error_callback=error_callback
         )
     pool.close()
     pool.join()
 else:
-    # Sequential mode for subprocess compatibility
+    # Sequential mode for subprocess compatibility or single device
     print("Running in sequential mode (multiprocessing disabled)")
-    import queue
-    gpu_queue = queue.Queue()
-    for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+    if device_type == "cuda" and gpu_count > 1:
+        # Multiple GPUs but running sequentially
+        import queue
+        device_queue = queue.Queue()
+        for gpu in range(gpu_count):
+            device_queue.put(gpu)
+    else:
+        # Single device or non-CUDA
+        device_queue = None
 
     for i, exp in enumerate(experiments):
         print(f"Training model {i+1}/{len(experiments)}: {exp.name}")
-        run_experiment(exp, gpu_queue)
-        # Put GPU back in queue for next experiment
-        if not gpu_queue.empty():
-            gpu_id = gpu_queue.get()
-            gpu_queue.put(gpu_id)
+        run_experiment(exp, device_queue, device_type)
+        # For multi-GPU sequential mode, rotate through GPUs
+        if device_queue and not device_queue.empty():
+            device_id = device_queue.get()
+            device_queue.put(device_id)
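
Note: multi-GPU scheduling here works by seeding a queue with one id per GPU; each worker process takes an id, runs its experiment on that device, and returns the id when done. Reduced to a standalone toy (do_work and the fake experiments are placeholders, not code from this repo):

# Standalone toy of the device-queue scheduling pattern used above.
import multiprocessing as mp
import time

def do_work(name, device_queue):
    device_id = device_queue.get()          # claim a free GPU slot
    try:
        time.sleep(0.1)                     # stand-in for training on that GPU
        print(f"{name} ran on GPU {device_id}")
    finally:
        device_queue.put(device_id)         # release the slot for the next task

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    manager = mp.Manager()
    device_queue = manager.Queue()
    for gpu in range(2):                    # pretend 2 GPUs are available
        device_queue.put(gpu)

    with mp.Pool(processes=2) as pool:
        for i in range(6):                  # 6 fake experiments
            pool.apply_async(do_work, (f"exp-{i}", device_queue))
        pool.close()
        pool.join()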

run_llm_stylometry.sh

Lines changed: 28 additions & 1 deletion
@@ -35,6 +35,7 @@ OPTIONS:
     -h, --help           Show this help message
     -f, --figure FIGURE  Generate specific figure (1a, 1b, 2a, 2b, 3, 4, 5)
     -t, --train          Train models from scratch before generating figures
+    -g, --max-gpus NUM   Maximum number of GPUs to use for training (default: all)
     -d, --data PATH      Path to model_results.pkl (default: data/model_results.pkl)
     -o, --output DIR     Output directory for figures (default: paper/figs/source)
     -l, --list           List available figures
@@ -48,7 +49,8 @@ EXAMPLES:
     $0                   # Setup environment and generate all figures
     $0 -f 1a             # Generate only Figure 1A
     $0 -f 4              # Generate only Figure 4 (MDS plot)
-    $0 -t                # Train models from scratch, then generate figures
+    $0 -t                # Train models from scratch using all GPUs
+    $0 -t -g 2           # Train models using only 2 GPUs
     $0 -l                # List available figures
     $0 --setup-only      # Only setup the environment
     $0 --clean           # Remove environment and reinstall from scratch
@@ -278,6 +280,7 @@ setup_environment() {
 # Parse command line arguments
 FIGURE=""
 TRAIN=false
+MAX_GPUS=""
 DATA_PATH="data/model_results.pkl"
 OUTPUT_DIR="paper/figs/source"
 LIST_FIGURES=false
@@ -301,6 +304,10 @@ while [[ $# -gt 0 ]]; do
            TRAIN=true
            shift
            ;;
+        -g|--max-gpus)
+            MAX_GPUS="$2"
+            shift 2
+            ;;
         -d|--data)
            DATA_PATH="$2"
            shift 2
@@ -381,6 +388,22 @@ if [ "$SETUP_ONLY" = true ]; then
381388
exit 0
382389
fi
383390

391+
# Detect available compute devices
392+
print_info "Detecting available compute devices..."
393+
DEVICE_INFO=$(python -c "
394+
import torch
395+
if torch.cuda.is_available():
396+
n = torch.cuda.device_count()
397+
names = [torch.cuda.get_device_name(i) for i in range(n)]
398+
print(f'CUDA GPUs: {n} device(s) - {names[0] if n > 0 else \"Unknown\"}')
399+
elif torch.backends.mps.is_available():
400+
print('Apple Metal Performance Shaders (MPS)')
401+
else:
402+
import multiprocessing
403+
print(f'CPU only ({multiprocessing.cpu_count()} cores)')
404+
" 2>/dev/null || echo "Could not detect device")
405+
print_info "Device: $DEVICE_INFO"
406+
384407
# Build the Python command
385408
PYTHON_CMD="python code/generate_figures.py"
386409

@@ -394,6 +417,10 @@ if [ "$TRAIN" = true ]; then
394417
PYTHON_CMD="$PYTHON_CMD --train"
395418
fi
396419

420+
if [ -n "$MAX_GPUS" ]; then
421+
PYTHON_CMD="$PYTHON_CMD --max-gpus $MAX_GPUS"
422+
fi
423+
397424
if [ "$DATA_PATH" != "data/model_results.pkl" ]; then
398425
PYTHON_CMD="$PYTHON_CMD --data $DATA_PATH"
399426
fi
