@@ -36,8 +36,20 @@ def tqdm(iterable, *args, **kwargs):
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-if not torch.cuda.is_available():
-    raise Exception("No GPU available")
+# Detect available devices
+def get_device_info():
+    """Detect and return device configuration."""
+    if torch.cuda.is_available():
+        device_count = torch.cuda.device_count()
+        return "cuda", device_count
+    elif torch.backends.mps.is_available():
+        # Apple Metal Performance Shaders (MPS) backend
+        return "mps", 1
+    else:
+        return "cpu", 1
+
+device_type, device_count = get_device_info()
+logger.info(f"Device type: {device_type}, Count: {device_count}")

 experiments = []
 for seed in range(10):
@@ -51,16 +63,26 @@ def tqdm(iterable, *args, **kwargs):
     )


-def run_experiment(exp: Experiment, gpu_queue):
+def run_experiment(exp: Experiment, device_queue, device_type="cuda"):
     try:
         logging.basicConfig(level=logging.INFO)
         logger = logging.getLogger(__name__)

-        # Get an available GPU id
-        gpu_id = gpu_queue.get()
+        # Get an available device id
+        device_id = device_queue.get() if device_queue else 0
         logger.info(f"Starting experiment: {exp.name}")
-        torch.cuda.set_device(gpu_id)
-        device = torch.device("cuda", index=gpu_id)
+
+        # Set up device based on type
+        if device_type == "cuda":
+            torch.cuda.set_device(device_id)
+            device = torch.device("cuda", index=device_id)
+            device_label = f"GPU {device_id}"
+        elif device_type == "mps":
+            device = torch.device("mps")
+            device_label = "MPS"
+        else:
+            device = torch.device("cpu")
+            device_label = "CPU"

         # Initialize tokenizer directly using get_tokenizer
         tokenizer = get_tokenizer(exp.tokenizer_name)
@@ -82,7 +104,7 @@ def run_experiment(exp: Experiment, gpu_queue):
             excluded_train_path=exp.excluded_train_path,
         )
         logger.info(
-            f"[GPU {gpu_id}] Number of training batches: {len(train_dataloader)}"
+            f"[{device_label}] Number of training batches: {len(train_dataloader)}"
         )

         # Set up eval dataloaders
@@ -130,7 +152,7 @@ def run_experiment(exp: Experiment, gpu_queue):
         start_epoch = 0

         logger.info(
-            f"[GPU {gpu_id}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
+            f"[{device_label}] Total number of non-embedding parameters: {count_non_embedding_params(model)}"
         )

         # Initial evaluation (epochs_complete = 0)
@@ -151,15 +173,16 @@ def run_experiment(exp: Experiment, gpu_queue):
             train_author=exp.train_author,
         )

-        # Set up mixed precision training for memory efficiency
-        scaler = torch.amp.GradScaler('cuda')
+        # Set up mixed precision training if supported
+        use_amp = device_type == "cuda"
+        scaler = torch.amp.GradScaler('cuda') if use_amp else None

         # Enable gradient checkpointing to save memory (if supported)
         try:
             model.gradient_checkpointing_enable()
-            logger.info(f"[GPU {gpu_id}] Gradient checkpointing enabled for memory efficiency")
+            logger.info(f"[{device_label}] Gradient checkpointing enabled for memory efficiency")
         except AttributeError:
-            logger.info(f"[GPU {gpu_id}] Model does not support gradient checkpointing")
+            logger.info(f"[{device_label}] Model does not support gradient checkpointing")

         # Training loop
         for epoch in tqdm(range(start_epoch, max_epochs)):
@@ -171,16 +194,24 @@ def run_experiment(exp: Experiment, gpu_queue):

                 input_ids = batch["input_ids"].to(device)

-                # Forward pass with mixed precision
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                # Forward pass with or without mixed precision
+                if use_amp:
+                    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                        outputs = model(input_ids=input_ids, labels=input_ids)
+                        loss = outputs.loss
+                else:
                     outputs = model(input_ids=input_ids, labels=input_ids)
                     loss = outputs.loss

-                # Backward pass with scaled gradients
+                # Backward pass with or without mixed precision
                 optimizer.zero_grad()
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
+                if use_amp:
+                    scaler.scale(loss).backward()
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    loss.backward()
+                    optimizer.step()

                 # Accumulate training loss
                 total_train_loss += loss.item()
@@ -230,11 +261,12 @@ def run_experiment(exp: Experiment, gpu_queue):
                 train_author=exp.train_author,
             )

-            # Force memory cleanup between evaluations
-            torch.cuda.empty_cache()
+            # Force memory cleanup between evaluations (CUDA only)
+            if device_type == "cuda":
+                torch.cuda.empty_cache()

             # Build log message for console output
-            log_message = f"[GPU {gpu_id}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
+            log_message = f"[{device_label}] Epoch {epochs_completed}/{max_epochs}: training loss = {train_loss:.4f}"
             for name, loss in eval_losses.items():
                 log_message += f", {name}: {loss:.4f}"
             logger.info(log_message)
@@ -249,13 +281,14 @@ def run_experiment(exp: Experiment, gpu_queue):
             # Early stopping after completing epoch (retain logs and checkpoints)
             if train_loss <= stop_train_loss and min_epochs <= epochs_completed:
                 logger.info(
-                    f"[GPU {gpu_id}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
+                    f"[{device_label}] Training loss {train_loss:.4f} below threshold {stop_train_loss}. Stopping training."
                 )
                 break
-        logger.info(f"[GPU {gpu_id}] Training complete for {modelname}")
+        logger.info(f"[{device_label}] Training complete for {modelname}")

         # Return the GPU id to the queue
-        gpu_queue.put(gpu_id)
+        if device_queue:
+            device_queue.put(device_id)
     except Exception:
         logger.exception(f"Error in experiment {exp.name}")
         raise
@@ -265,16 +298,29 @@ def run_experiment(exp: Experiment, gpu_queue):
 # Check if we should run sequentially (for subprocess compatibility)
 USE_MULTIPROCESSING = os.environ.get('NO_MULTIPROCESSING', '0') != '1'

-device_count = torch.cuda.device_count()
-gpu_count = min(device_count, 4)
-print(f"Using {gpu_count} GPUs out of {device_count} available")
+# Use already detected device configuration
+if device_type == "cuda":
+    # Check for MAX_GPUS environment variable to optionally limit GPU usage
+    max_gpus = int(os.environ.get('MAX_GPUS', '0')) or device_count
+    gpu_count = min(device_count, max_gpus)
+    if gpu_count < device_count:
+        print(f"Using {gpu_count} GPUs (limited by MAX_GPUS) out of {device_count} available")
+    else:
+        print(f"Using all {gpu_count} available GPUs")
+elif device_type == "mps":
+    gpu_count = 1
+    print("Using Apple Metal Performance Shaders (MPS)")
+else:
+    gpu_count = 1
+    print("Using CPU for training (this will be slow)")

-if USE_MULTIPROCESSING:
+if USE_MULTIPROCESSING and device_type == "cuda" and gpu_count > 1:
+    # Only use multiprocessing for multiple CUDA GPUs
     mp.set_start_method("spawn", force=True)
     manager = mp.Manager()
-    gpu_queue = manager.Queue()
+    device_queue = manager.Queue()
     for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+        device_queue.put(gpu)

     pool = mp.Pool(processes=gpu_count)
     logger = logging.getLogger(__name__)
@@ -286,22 +332,27 @@ def error_callback(e):

     for exp in experiments:
         pool.apply_async(
-            run_experiment, (exp, gpu_queue), error_callback=error_callback
+            run_experiment, (exp, device_queue, device_type), error_callback=error_callback
         )
     pool.close()
     pool.join()
 else:
-    # Sequential mode for subprocess compatibility
+    # Sequential mode for subprocess compatibility or single device
     print("Running in sequential mode (multiprocessing disabled)")
-    import queue
-    gpu_queue = queue.Queue()
-    for gpu in range(gpu_count):
-        gpu_queue.put(gpu)
+    if device_type == "cuda" and gpu_count > 1:
+        # Multiple GPUs but running sequentially
+        import queue
+        device_queue = queue.Queue()
+        for gpu in range(gpu_count):
+            device_queue.put(gpu)
+    else:
+        # Single device or non-CUDA
+        device_queue = None

     for i, exp in enumerate(experiments):
         print(f"Training model {i + 1}/{len(experiments)}: {exp.name}")
-        run_experiment(exp, gpu_queue)
-        # Put GPU back in queue for next experiment
-        if not gpu_queue.empty():
-            gpu_id = gpu_queue.get()
-            gpu_queue.put(gpu_id)
+        run_experiment(exp, device_queue, device_type)
+        # For multi-GPU sequential mode, rotate through GPUs
+        if device_queue and not device_queue.empty():
+            device_id = device_queue.get()
+            device_queue.put(device_id)
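
For reference, a minimal standalone sketch of the device-selection and CUDA-only mixed-precision fallback pattern this diff introduces. It is illustrative and not part of the repo: the tiny linear model and random data are placeholders, while `get_device_info`, `use_amp`, and the `GradScaler`/`autocast` fallback mirror the changed lines above.

```python
# Illustrative sketch (not from this repo): one training step using the
# CUDA/MPS/CPU detection and CUDA-only AMP fallback shown in the diff.
import torch


def get_device_info():
    """Return (device_type, device_count) for CUDA, MPS, or CPU."""
    if torch.cuda.is_available():
        return "cuda", torch.cuda.device_count()
    if torch.backends.mps.is_available():
        return "mps", 1
    return "cpu", 1


device_type, device_count = get_device_info()
device = torch.device(device_type)

# AMP with GradScaler only on CUDA; MPS and CPU fall back to full precision.
use_amp = device_type == "cuda"
scaler = torch.amp.GradScaler("cuda") if use_amp else None

# Placeholder model and data standing in for the real model and dataloader.
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

optimizer.zero_grad()
if use_amp:
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

print(f"{device_type}: one training step done, loss = {loss.item():.4f}")
```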