Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ from vggt.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if device == "cuda":
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
dtype = torch.float32

# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
Expand Down
12 changes: 8 additions & 4 deletions demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def run_model(target_dir, model) -> dict:

# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
if not torch.cuda.is_available():
raise ValueError("CUDA is not available. Check your environment.")

# Move model to device
model = model.to(device)
Expand All @@ -68,10 +66,16 @@ def run_model(target_dir, model) -> dict:

# Run inference
print("Running inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if device == "cuda":
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
dtype = torch.float32

with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
if device == "cuda":
with torch.cuda.amp.autocast(dtype=dtype):
predictions = model(images)
else:
predictions = model(images)

# Convert pose encoding to extrinsic and intrinsic matrices
Expand Down
5 changes: 4 additions & 1 deletion demo_viser.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def main():
print(f"Preprocessed images shape: {images.shape}")

print("Running inference...")
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
if torch.cuda.is_available():
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
dtype = torch.float32

with torch.no_grad():
with torch.cuda.amp.autocast(dtype=dtype):
Expand Down