Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dust3r/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .staticthings3d import StaticThings3D # noqa
from .waymo import Waymo # noqa
from .wildrgbd import WildRGBD # noqa

from .freiburgDataset import freiburgDataset

def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
import torch
Expand Down
124 changes: 124 additions & 0 deletions dust3r/datasets/freiburgDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import pickle
import torch
from torch.utils.data import Dataset
import numpy as np
import os.path as osp
from PIL import Image
import cv2

import glob

import sys
sys.path.append("/home/user/elwakeely1/dust3r")
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import resize_img,preprocess_ir_rgb
from dust3r.datasets.base.base_stereo_view_dataset import view_name
from dust3r.viz import SceneViz, auto_cam_size
from dust3r.utils.image import rgb

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class freiburgDataset(BaseStereoViewDataset):
    """Two-view infrared dataset read from pre-packed ``.npz`` archives.

    For the ``"Train"`` split, every ``dataset_seq_*.npz`` archive found under
    ``ROOT/<split>/<method>/`` is loaded, and two frames stored under
    consecutive keys are paired when their ``img_number`` values are also
    consecutive. For the ``"Test"`` split, the single archive
    ``ROOT/<split>/dataset_test.npz`` is loaded and each sample is the same
    frame returned twice.

    Args:
        ROOT: directory that contains one sub-directory per split.
        method: sub-directory name inside the split directory (only used for
            the ``"Train"`` split).
        *args, **kwargs: forwarded unchanged to ``BaseStereoViewDataset``.

    Raises:
        ValueError: if ``self.split`` is neither ``"Train"`` nor ``"Test"``.
    """

    def __init__(self, *args, ROOT, method, **kwargs):
        self.ROOT = ROOT
        self.method = method
        super().__init__(*args, **kwargs)
        self.scenes = []
        self.pairs = []
        self.frames = []
        self._load_data()

    def _unsupported_split(self):
        """Build the error raised whenever the split name is not handled."""
        return ValueError(
            f"Unsupported split {self.split!r}; expected 'Train' or 'Test'")

    def load_train(self):
        """Load every training sequence and index the consecutive frame pairs.

        Fills ``self.scene_files``, ``self.scenes``, ``self.frames`` (one dict
        per sequence) and ``self.pairs`` with tuples
        ``(scene_file, scene_id, key1, key2)``, where ``key1``/``key2`` are the
        integer frame keys inside the sequence archive.
        """
        pattern = osp.join(self.ROOT, self.split, self.method, "dataset_seq_*.npz")
        self.scene_files = sorted(glob.glob(pattern))
        for scene_id, scene_file in enumerate(self.scene_files):
            # NOTE: allow_pickle=True unpickles arbitrary objects; only load
            # archives that come from a trusted source.
            with np.load(scene_file, allow_pickle=True) as data:
                frames = dict(data)
            self.frames.append(frames)
            self.scenes.append(scene_file)
            for i in range(len(frames) - 1):
                if f"{i}" not in frames or f"{i + 1}" not in frames:
                    continue
                fm1_id = frames[f"{i}"].item()["img_number"]
                fm2_id = frames[f"{i + 1}"].item()["img_number"]
                if fm2_id == fm1_id + 1:
                    # Store the archive keys that were just checked to exist,
                    # not the image numbers: get_view_train looks frames up by
                    # key, and the two are not guaranteed to coincide.
                    self.pairs.append((scene_file, scene_id, i, i + 1))

    def load_test(self):
        """Load the test archive into ``self.frames`` (a dict keyed by frame index)."""
        test_file = osp.join(self.ROOT, self.split, "dataset_test.npz")
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # archives that come from a trusted source.
        # dict(data) reads every entry, so the file can be closed right away.
        with np.load(test_file, allow_pickle=True) as data:
            self.frames = dict(data)

    def _load_data(self):
        """Dispatch to the loader that matches ``self.split``."""
        if self.split == "Test":
            self.load_test()
        elif self.split == "Train":
            self.load_train()
        else:
            raise self._unsupported_split()

    def _make_view(self, data, label, with_pose):
        """Turn one frame record into a view dict.

        Args:
            data: frame record (mapping) holding ``IR_aligned_path``,
                ``RGB_path``, ``Depth``, ``Camera_intrinsic`` and, when
                ``with_pose`` is true, ``camera_pose``.
            label: value stored under the ``label`` key of the view.
            with_pose: whether to add ``camera_pose`` to the view.
        """
        ir_img_path = data["IR_aligned_path"]
        ir_img = Image.open(str(ir_img_path))
        rgb_image = Image.open(str(data["RGB_path"]))
        # Only the processed IR image is kept; the RGB output is discarded.
        _, ir_img = preprocess_ir_rgb(rgb_image, ir_img)
        ir_img = resize_img(ir_img, size=224)

        view = dict(img=ir_img, depthmap=data["Depth"])
        if with_pose:
            view["camera_pose"] = np.float32(data["camera_pose"])
        view.update(
            camera_intrinsics=np.float32(data["Camera_intrinsic"]),
            dataset='freiburg',
            label=label,
            instance=str(ir_img_path))
        return view

    def get_view_train(self, pair_idx, resolution, rng):
        """Return the two views of training pair ``pair_idx``.

        ``resolution`` and ``rng`` are accepted for interface compatibility
        and are not used.
        """
        seq_path, seq, fm1, fm2 = self.pairs[pair_idx]
        seq_frames = self.frames[seq]
        return [
            self._make_view(seq_frames[f"{view_index}"].item(),
                            label=str(seq_path), with_pose=True)
            for view_index in (fm1, fm2)
        ]

    def get_view_test(self, pair_idx, resolution, rng):
        """Return test frame ``pair_idx`` twice, as two views without a pose.

        ``resolution`` and ``rng`` are accepted for interface compatibility
        and are not used.
        """
        return [
            self._make_view(self.frames[f"{pair_idx}"].item(),
                            label=0, with_pose=False)
            for _ in range(2)
        ]

    def _get_views(self, pair_idx, resolution, rng):
        """Return the list of views for sample ``pair_idx`` of the current split."""
        if self.split == "Test":
            return self.get_view_test(pair_idx, resolution, rng)
        if self.split == "Train":
            return self.get_view_train(pair_idx, resolution, rng)
        raise self._unsupported_split()

    def __len__(self):
        """Returns the number of samples in the dataset."""
        if self.split == "Train":
            return len(self.pairs)
        if self.split == "Test":
            return len(self.frames)
        raise self._unsupported_split()

if __name__ == "__main__":
    # Manual smoke test: build the "Test" split, fetch the first sample and
    # print the shape of the image held by its first view.
    smoke_kwargs = dict(
        ROOT="/home/user/elwakeely1/DataParam",
        method="RANSAC",
        split="Test",
        resolution=224,
        aug_crop=16,
    )
    dataset = freiburgDataset(**smoke_kwargs)
    first_sample = dataset[0]
    print(first_sample[0]["img"].shape)
Binary file added dust3r/datasets/output_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions dust3r/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, u
if symmetrize_batch:
view1, view2 = make_batch_symmetric(batch)

with torch.cuda.amp.autocast(enabled=bool(use_amp)):
with torch.amp.autocast('cuda',enabled=bool(use_amp)):
pred1, pred2 = model(view1, view2)

# loss is supposed to be symmetric
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast('cuda',enabled=False):
loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
Expand Down
33 changes: 33 additions & 0 deletions dust3r/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):

def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
# everything is normalized w.r.t. camera of view1
# print(gt1.keys())--dict_keys(['img', 'depthmap', 'camera_intrinsics', 'dataset', 'label', 'instance', 'idx', 'true_shape', 'camera_pose', 'pts3d', 'valid_mask', 'rng'])
# print(pred1.keys())-->dict_keys(['pts3d', 'conf'])

in_camera1 = inv(gt1['camera_pose'])
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3
gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3
Expand Down Expand Up @@ -238,6 +241,36 @@ def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)


class RMSE_loss(MultiLoss):
    """ Root-mean-square error between the z-coordinate of the view-1
    predicted points and the view-1 ground-truth depthmap.
    Only gt1 and pred1 are read; gt2 and pred2 are ignored.
    """

    def __init__(self):
        super().__init__()

    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
        """ Returns (rmse, details) where details holds the loss as a python float.
        """
        # the last channel of pts3d is (x, y, z): z is taken as the predicted depth
        predicted_z = pred1['pts3d'][..., 2]
        target_depth = gt1["depthmap"].to(predicted_z.device)

        # NOTE(review): every element contributes to the mean; confirm whether
        # gt1['valid_mask'] should be applied to skip invalid depth values.
        residual = predicted_z - target_depth
        rmse = residual.pow(2).mean().sqrt()

        return rmse, {"RMSE Loss": rmse.item()}


class Regr3D_ShiftInv (Regr3D):
""" Same than Regr3D but invariant to depth shift.
"""
Expand Down
4 changes: 2 additions & 2 deletions dust3r/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def load_model(model_path, device, verbose=True):
if verbose:
print('... loading model from', model_path)
ckpt = torch.load(model_path, map_location='cpu')
ckpt = torch.load(model_path, map_location='cpu',weights_only=False)
args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
if 'landscape_only' not in args:
args = args[:-1] + ', landscape_only=False)'
Expand Down Expand Up @@ -202,7 +202,7 @@ def forward(self, view1, view2):
# combine all ref images into object-centric representation
dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast('cuda',enabled=False):
res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

Expand Down
14 changes: 7 additions & 7 deletions dust3r/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
def get_args_parser():
parser = argparse.ArgumentParser('DUST3R training', add_help=False)
# model and criterion
parser.add_argument('--model', default="AsymmetricCroCo3DStereo(patch_embed_cls='ManyAR_PatchEmbed')",
parser.add_argument('--model', default="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)",
type=str, help="string containing the model to build")
parser.add_argument('--pretrained', default=None, help='path of a starting checkpoint')
parser.add_argument('--pretrained', default="checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth", help='path of a starting checkpoint')
parser.add_argument('--train_criterion', default="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)",
type=str, help="train criterion")
parser.add_argument('--test_criterion', default=None, type=str, help="test criterion")
Expand All @@ -52,19 +52,19 @@ def get_args_parser():

# training
parser.add_argument('--seed', default=0, type=int, help="Random seed")
parser.add_argument('--batch_size', default=64, type=int,
parser.add_argument('--batch_size', default=1, type=int,
help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus")
parser.add_argument('--accum_iter', default=1, type=int,
help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")

parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
parser.add_argument('--lr', type=float, default=0.00001, metavar='LR', help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
parser.add_argument('--min_lr', type=float, default=1e-06, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR')

parser.add_argument('--amp', type=int, default=0,
choices=[0, 1], help="Use Automatic Mixed Precision for pretraining")
Expand Down Expand Up @@ -137,7 +137,7 @@ def train(args):

if args.pretrained and not args.resume:
print('Loading pretrained: ', args.pretrained)
ckpt = torch.load(args.pretrained, map_location=device)
ckpt = torch.load(args.pretrained, map_location=device, weights_only=False)
print(model.load_state_dict(ckpt['model'], strict=False))
del ckpt # in case it occupies memory

Expand Down
Loading