diff --git a/README.md b/README.md
index 85e26d08..6dfb9568 100644
--- a/README.md
+++ b/README.md
@@ -57,7 +57,10 @@ from vggt.utils.load_fn import load_and_preprocess_images
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
-dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+if device == "cuda":
+    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+else:
+    dtype = torch.float32
 
 # Initialize the model and load the pretrained weights.
 # This will automatically download the model weights the first time it's run, which may take a while.
diff --git a/demo_gradio.py b/demo_gradio.py
index 9b83acfe..73fc32bc 100644
--- a/demo_gradio.py
+++ b/demo_gradio.py
@@ -49,8 +49,6 @@ def run_model(target_dir, model) -> dict:
 
     # Device check
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    if not torch.cuda.is_available():
-        raise ValueError("CUDA is not available. Check your environment.")
 
     # Move model to device
     model = model.to(device)
@@ -68,10 +66,16 @@ def run_model(target_dir, model) -> dict:
 
     # Run inference
    print("Running inference...")
-    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+    if device == "cuda":
+        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+    else:
+        dtype = torch.float32
 
     with torch.no_grad():
-        with torch.cuda.amp.autocast(dtype=dtype):
+        if device == "cuda":
+            with torch.cuda.amp.autocast(dtype=dtype):
+                predictions = model(images)
+        else:
             predictions = model(images)
 
     # Convert pose encoding to extrinsic and intrinsic matrices
diff --git a/demo_viser.py b/demo_viser.py
index c52726b6..0d75fc78 100644
--- a/demo_viser.py
+++ b/demo_viser.py
@@ -374,7 +374,10 @@ def main():
     print(f"Preprocessed images shape: {images.shape}")
 
     print("Running inference...")
-    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+    if torch.cuda.is_available():
+        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+    else:
+        dtype = torch.float32
 
     with torch.no_grad():
         with torch.cuda.amp.autocast(dtype=dtype):
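
The same device/dtype selection now appears in three places (README, `demo_gradio.py`, `demo_viser.py`). A minimal sketch of how it could be factored into one shared helper; the `select_device_and_dtype` name and its placement are hypothetical, not part of this PR:

```python
import torch


def select_device_and_dtype() -> tuple[str, torch.dtype]:
    """Pick the inference device and a matching autocast dtype.

    bfloat16 requires Ampere or newer (Compute Capability 8.0+);
    older GPUs fall back to float16, and CPU runs in float32.
    """
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return "cuda", torch.bfloat16 if major >= 8 else torch.float16
    return "cpu", torch.float32


# Usage mirroring the pattern introduced in this PR: autocast is
# only entered on the CUDA path, and CPU runs a plain forward pass.
device, dtype = select_device_and_dtype()
with torch.no_grad():
    if device == "cuda":
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)
    else:
        predictions = model(images)
```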