
Pytorch3d #16

Open
wants to merge 5 commits into base: main
32 changes: 26 additions & 6 deletions models/deepmapping.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
from .networks import LocNetRegKITTI, MLP
from utils import transform_to_global_KITTI, compose_pose_diff, euler_pose_to_quaternion, quaternion_to_euler_pose, qmul_torch, matrix_to_rotation_6d, rotation_6d_to_matrix, euler_pose_to_6d_pose

def get_M_net_inputs_labels(occupied_points, unoccupited_points):
"""
@@ -52,6 +52,8 @@ def __init__(self, n_points, loss_fn, rotation_representation='quaternion', n_sa
self.rotation = rotation_representation
if self.rotation == 'quaternion':
self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=7) # <x, y, z, qw, qx, qy, qz>
elif self.rotation == '6d':
self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=9) # <x, y, z> + 6D rotation
else:
self.loc_net = LocNetRegKITTI(n_points=n_points, out_dims=6) # <x, y, z, roll, pitch, yaw>
self.occup_net = MLP(dim)
@@ -60,32 +62,50 @@ def __init__(self, n_points, loss_fn, rotation_representation='quaternion', n_sa


def forward(self, obs_local, sensor_pose, valid_points=None, pairwise_pose=None):
# obs_local: <GxNx3>
# sensor_pose: <Gx6> <x, y, z, roll, pitch, yaw>
G = obs_local.shape[0]
self.obs_local = obs_local
if self.rotation == 'quaternion':
sensor_pose = euler_pose_to_quaternion(sensor_pose)
# sensor_pose: <Gx7>
elif self.rotation == '6d':
sensor_pose = euler_pose_to_6d_pose(sensor_pose)
# sensor_pose: <Gx9>

self.obs_initial = transform_to_global_KITTI(sensor_pose, self.obs_local, rotation_representation=self.rotation)
# obs_initial: <GxNx3>
self.l_net_out = self.loc_net(self.obs_initial)
# l_net_out: <Gx7> for quaternion, <Gx9> for 6d, <Gx6> for euler angles
if self.rotation == 'quaternion':
original_shape = list(sensor_pose.shape)
xyz = self.l_net_out[:, :3] + sensor_pose[:, :3]
wxyz = qmul_torch(self.l_net_out[:,3:], sensor_pose[:,3:])
self.pose_est = torch.cat((xyz, wxyz), dim=1).view(original_shape)
elif self.rotation == 'euler_angle':
self.pose_est = self.l_net_out + sensor_pose
elif self.rotation == '6d':
original_shape = list(sensor_pose.shape)
xyz = self.l_net_out[:, :3] + sensor_pose[:, :3]
# compose the two rotations as matrices, then re-encode as 6D
l_net_rot = rotation_6d_to_matrix(self.l_net_out[:, 3:])
sensor_rot = rotation_6d_to_matrix(sensor_pose[:, 3:])
rotation_6d = matrix_to_rotation_6d(torch.matmul(l_net_rot, sensor_rot))
self.pose_est = torch.cat((xyz, rotation_6d), dim=1).view(original_shape)
# l_net_out[:, -1] = 0
# self.pose_est = cat_pose_KITTI(sensor_pose, self.loc_net(self.obs_initial))
# self.bs = obs_local.shape[0]
# self.obs_local = self.obs_local.reshape(self.bs,-1,3)
self.obs_global_est = transform_to_global_KITTI(self.pose_est, self.obs_local, rotation_representation=self.rotation)

if self.training:
self.valid_points = valid_points
if self.rotation == 'quaternion':
pairwise_pose = euler_pose_to_quaternion(pairwise_pose)
elif self.rotation == '6d':
pairwise_pose = euler_pose_to_6d_pose(pairwise_pose)

if self.loss_fn.__name__ == "pose":
self.t_src, self.t_dst, self.r_src, self.r_dst = compose_pose_diff(self.pose_est, pairwise_pose, rotation_representation=self.rotation)
else:
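The '6d' branch of forward composes the regressed correction with the initial pose by multiplying rotation matrices and adding translations. A minimal standalone sketch of that composition (the imports match this PR's utils; compose_6d_pose itself is the editor's illustration, not code from the diff):

import torch
from utils import rotation_6d_to_matrix, matrix_to_rotation_6d

def compose_6d_pose(delta, initial):
    # delta, initial: <Gx9> poses laid out as <x, y, z> + 6D rotation
    xyz = delta[:, :3] + initial[:, :3]
    R = torch.matmul(rotation_6d_to_matrix(delta[:, 3:]),
                     rotation_6d_to_matrix(initial[:, 3:]))
    return torch.cat((xyz, matrix_to_rotation_6d(R)), dim=1)
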
92 changes: 52 additions & 40 deletions script/train.py
@@ -5,7 +5,8 @@
import time
import argparse
import functools
print = functools.partial(print, flush=True)

import numpy as np
import torch
@@ -22,20 +23,20 @@
torch.manual_seed(42)

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, default='test', help='experiment name')
parser.add_argument('-e', '--n_epochs', type=int, default=1000, help='number of epochs')
parser.add_argument('-l', '--loss', type=str, default='bce_ch', help='loss function')
parser.add_argument('-n', '--n_samples', type=int, default=35, help='number of sampled unoccupied points along rays')
parser.add_argument('-v', '--voxel_size', type=float, default=1, help='size of downsampling voxel grid')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--dataset', type=str, default="KITTI", help="Type of dataset to use")
parser.add_argument('-d', '--data_dir', type=str, default='../data/ActiveVisionDataset/', help='dataset path')
parser.add_argument('-t', '--traj', type=str, default='2011_09_30_drive_0018_sync_full', help='trajectory file folder')
parser.add_argument('-m', '--model', type=str, default=None, help='pretrained model name')
parser.add_argument('-i', '--init', type=str, default=None, help='path to initial pose')
parser.add_argument('-p', '--pairwise', type=str, default=None, help='path to pairwise pose')
parser.add_argument('--log_interval', type=int, default=10, help='logging interval of saving results')
parser.add_argument('--group_size', type=int, default=8, help='group size')
parser.add_argument('--resume', action='store_true',
help='If present, restore checkpoint and resume training')
parser.add_argument('--alpha', type=float, default=0.1, help='weight for chamfer loss')
@@ -46,52 +47,64 @@

opt = parser.parse_args()

checkpoint_dir = os.path.join('../results/' + opt.dataset, opt.name)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
if not os.path.exists(os.path.join(checkpoint_dir, "pose_ests")):
os.makedirs(os.path.join(checkpoint_dir, "pose_ests"))
utils.save_opt(checkpoint_dir, opt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# opt.init: INIT=$DATA_DIR/$TRAJ/prior/init_pose.npy
# init_pose.npy should be an Nx6 numpy array, where N is the number of frames.
# Each row is the initial pose of a frame, given as x, y, z, roll, pitch, yaw.
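# (Editor's hedged illustration, not part of the diff: a compatible all-zero
# prior for N frames could be written as
#     np.save("init_pose.npy", np.zeros((N, 6), dtype="float32"))
# where the file name is a placeholder.)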
# Convert the initial pose to a tensor
init_pose_np = np.load(opt.init).astype("float32")
init_pose = torch.from_numpy(init_pose_np)
pairwise_pose = np.load(opt.pairwise).astype("float32")

print('loading dataset')
if opt.dataset == "KITTI":
train_dataset = Kitti(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size, pairwise_pose=pairwise_pose)
eval_dataset = KittiEval(train_dataset)
# eval_dataset is roughly the same as train_dataset, but it does not include gt_pose

elif opt.dataset == "NCLT" or "Nebula":
train_dataset = Nclt(opt.data_dir, opt.traj, opt.voxel_size, init_pose=init_pose, group_size=opt.group_size, pairwise_pose=pairwise_pose)
eval_dataset = NcltEval(train_dataset)
else:
assert 0, "Unsupported dataset"


train_loader = DataLoader(train_dataset, batch_size=None, num_workers=4, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=64, num_workers=4)
loss_fn = getattr(loss, opt.loss)  # look up the loss function (e.g. bce_ch or bce_ch_eu) by name

if opt.rotation not in ['quaternion', 'euler_angle', '6d']:
assert 0, "Unsupported rotation representation"

print('creating model')
model = DeepMapping2(n_points=train_dataset.n_points, loss_fn=loss_fn, n_samples=opt.n_samples, alpha=opt.alpha, beta=opt.beta, rotation_representation=opt.rotation).to(device)

if opt.optimizer == "Adam":
optimizer = optim.Adam(model.parameters(), lr=opt.lr)
elif opt.optimizer == "SGD":
optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
else:
print("Unsupported optimizer")
assert()
assert ()

scaler = torch.cuda.amp.GradScaler()

if opt.model is not None:
utils.load_checkpoint(opt.model, model, optimizer)

if opt.resume:
resume_filename = os.path.join(checkpoint_dir, "model_best.pth")
@@ -112,9 +125,8 @@
ch_loss = 0
eu_loss = 0
model.train()

time_start = time.time()
for index, (obs, valid_pt, init_global_pose, pairwise_pose) in enumerate(train_loader):
obs = obs.to(device)
valid_pt = valid_pt.to(device)
init_global_pose = init_global_pose.to(device)
@@ -141,11 +153,11 @@
ch_loss += ch
if loss == "bce_ch_eu" or loss == "pose":
eu_loss += eu

time_end = time.time()
# print(model.parameters().grad)
print("Training time: {:.2f}s".format(time_end - time_start))
training_loss_epoch = training_loss / len(train_loader)
bce_epoch = bce_loss / len(train_loader)
ch_epoch = ch_loss / len(train_loader)
eu_epoch = eu_loss / len(train_loader)
@@ -154,12 +166,12 @@
ch_losses.append(ch_epoch)
eu_losses.append(eu_epoch)

print('[{}/{}], training loss: {:.4f}'.format(epoch + 1, opt.n_epochs, training_loss_epoch))
obs_global_est_np = []
pose_est_np = []
with torch.no_grad():
model.eval()
for index, (obs, init_global_pose) in enumerate(eval_loader):
obs = obs.to(device)
init_global_pose = init_global_pose.to(device)
model(obs, init_global_pose)
@@ -168,13 +180,13 @@
pose_est = model.pose_est
obs_global_est_np.append(obs_global_est.cpu().detach().numpy())
pose_est_np.append(pose_est.cpu().detach().numpy())

pose_est_np = np.concatenate(pose_est_np)

save_name = os.path.join(checkpoint_dir, "pose_ests", str(epoch+1))
np.save(save_name,pose_est_np)
save_name = os.path.join(checkpoint_dir, "pose_ests", str(epoch + 1))
np.save(save_name, pose_est_np)

utils.plot_global_pose(checkpoint_dir, opt.dataset, epoch + 1, rotation_representation=opt.rotation)

try:
trans_ate, rot_ate = utils.compute_ate(pose_est_np, train_dataset.gt_pose, rotation_representation=opt.rotation)
@@ -193,17 +205,17 @@
if training_loss_epoch < best_loss:
print("lowest loss:", training_loss_epoch)
best_loss = training_loss_epoch

# Visualize global point clouds
obs_global_est_np = np.concatenate(obs_global_est_np)
save_name = os.path.join(checkpoint_dir, 'obs_global_est.npy')
np.save(save_name, obs_global_est_np)

# Save checkpoint
save_name = os.path.join(checkpoint_dir, 'model_best.pth')
utils.save_checkpoint(save_name, model, optimizer, epoch)

print()

training_losses = np.array(training_losses)
np.save(os.path.join(checkpoint_dir, "loss.npy"), training_losses)
np.save(os.path.join(checkpoint_dir, "loss.npy"), training_losses)
32 changes: 32 additions & 0 deletions utils/geometry_utils.py
@@ -77,6 +77,27 @@ def euler_pose_to_quaternion(euler_pose):
quaternion_pose = torch.cat((xyz, quaternion), dim=1)
return quaternion_pose

def euler_pose_to_6d_pose(euler_pose):
"""
Convert an Euler-angle pose to a 6D-rotation pose.
:param euler_pose: <Bx6> <x, y, z, roll, pitch, yaw>
:return 6d_pose: <Bx9> <x, y, z, r00, r01, r02, r10, r11, r12>
"""
xyz = euler_pose[:, :3]
e = euler_pose[:, 3:]
assert e.shape[-1] == 3

# Convert euler angles to rotation matrix
rotation_matrix = euler_angles_to_matrix(e, convention="XYZ")

# Convert rotation matrix to 6D pose representation
six_d = matrix_to_rotation_6d(rotation_matrix)

# Concatenate xyz and six_d along dimension 1
six_d_pose = torch.cat((xyz, six_d), dim=1)

return six_d_pose
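
# Editor's hedged sketch (not part of the diff): round-trip a pose through the
# 6D representation; assumes matrix_to_euler_angles is available alongside the
# pytorch3d helpers this module already imports.
def _check_euler_to_6d_round_trip():
    xyz = torch.randn(4, 3)
    rpy = 0.5 * torch.rand(4, 3)  # small angles keep the XYZ decomposition unique
    pose_6d = euler_pose_to_6d_pose(torch.cat((xyz, rpy), dim=1))  # <Bx9>
    R = rotation_6d_to_matrix(pose_6d[:, 3:])
    rpy_back = matrix_to_euler_angles(R, convention="XYZ")
    assert torch.allclose(rpy_back, rpy, atol=1e-5)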

def transform_to_global_KITTI(pose, obs_local, rotation_representation):
"""
transform obs from the local coordinate frame to the global coordinate frame
@@ -92,6 +113,9 @@ def transform_to_global_KITTI(pose, obs_local, rotation_representation):
elif rotation_representation == "quaternion":
quat = pose[:, 3:]
rotation_matrix = quaternion_to_matrix(quat)
elif rotation_representation == "6d":
sixd = pose[:, 3:]
rotation_matrix = rotation_6d_to_matrix(sixd)
obs_global = torch.bmm(obs_local, rotation_matrix.transpose(1, 2))
# obs_global[:, :, 0] = obs_global[:, :, 0] + pose[:, [0]]
# obs_global[:, :, 1] = obs_global[:, :, 1] + pose[:, [1]]
@@ -123,6 +147,9 @@ def compose_pose_diff(pose_est, pairwise, rotation_representation):
elif rotation_representation == "quaternion":
rotation_est = quaternion_to_matrix(rpy_est)
rotation_pairwise = quaternion_to_matrix(rpy_pairwise)
elif rotation_representation == "6d":
rotation_est = rotation_6d_to_matrix(rpy_est)
rotation_pairwise = rotation_6d_to_matrix(rpy_pairwise)
r_dst = torch.bmm(rotation_est, rotation_pairwise)
# rpy = matrix_to_euler_angles(rotation, convention="XYZ")
# dst = torch.concat((xyz, rpy), dim=1)
@@ -223,6 +250,11 @@ def compute_ate(output, target, rotation_representation):
q = output[:,3:]
output_quat = q[:, [1, 2, 3, 0]]
rpy = Rot.from_quat(output_quat).as_euler("XYZ")
elif rotation_representation == "6d":
r = torch.tensor(output[:, 3:])
output_r = rotation_6d_to_matrix(r)
rpy = Rot.from_matrix(output_r.numpy()).as_euler("XYZ")

yaw_aligned = rpy[:, -1] + rotation[-1]
yaw_gt = target[:, -1]
while np.any(yaw_aligned > np.pi):
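A side note on the yaw alignment in compute_ate: the while-loop wrap above can be replaced by one vectorized step (a sketch assuming yaw_aligned is a NumPy array of radians; it wraps to [-pi, pi), which is equivalent for the ATE comparison):

yaw_aligned = (yaw_aligned + np.pi) % (2 * np.pi) - np.pi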
44 changes: 43 additions & 1 deletion utils/pytorch3d_utils.py
@@ -249,4 +249,46 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
"""
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
using Gram--Schmidt orthogonalization per Section B of [1].
Args:
d6: 6D rotation representation, of size (*, 6)

Returns:
batch of rotation matrices of size (*, 3, 3)

[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
On the Continuity of Rotation Representations in Neural Networks.
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""

a1, a2 = d6[..., :3], d6[..., 3:]
b1 = F.normalize(a1, dim=-1)
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
b2 = F.normalize(b2, dim=-1)
b3 = torch.cross(b1, b2, dim=-1)
return torch.stack((b1, b2, b3), dim=-2)
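
# Editor's hedged sanity check (not part of the diff); torch and F are assumed
# imported at the top of this module, as the function bodies imply.
def _check_rotation_6d_to_matrix():
    # Any 6D input should decode to a proper rotation: orthonormal, det +1.
    d6 = torch.randn(8, 6)
    R = rotation_6d_to_matrix(d6)
    eye = torch.eye(3).expand(8, 3, 3)
    assert torch.allclose(torch.bmm(R, R.transpose(1, 2)), eye, atol=1e-5)
    assert torch.allclose(torch.linalg.det(R), torch.ones(8), atol=1e-5)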


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
"""
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
by dropping the last row. Note that 6D representation is not unique.
Args:
matrix: batch of rotation matrices of size (*, 3, 3)

Returns:
6D rotation representation, of size (*, 6)

[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
On the Continuity of Rotation Representations in Neural Networks.
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
batch_dim = matrix.size()[:-2]
return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
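
# Editor's hedged round-trip check (not part of the diff); euler_angles_to_matrix
# is assumed available in this module.
def _check_6d_round_trip():
    # For a valid rotation matrix the 6D encoding is lossless: Gram-Schmidt
    # re-derives the dropped third row exactly.
    R = euler_angles_to_matrix(torch.randn(8, 3), convention="XYZ")
    assert torch.allclose(rotation_6d_to_matrix(matrix_to_rotation_6d(R)), R, atol=1e-5)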
6 changes: 6 additions & 0 deletions utils/vis_utils.py
@@ -61,6 +61,12 @@ def plot_global_pose(checkpoint_dir, dataset="kitti", epoch=None, mode=None, rot
q = q[:, [1, 2, 3, 0]]
rpy = Rot.from_quat(q).as_euler("XYZ")
location = np.concatenate((location[:, :3], rpy), axis=1)
elif rotation_representation == "6d":
rotation_6d = torch.tensor(location[:, 3:])
rotation_matrix = rotation_6d_to_matrix(rotation_6d)
rpy = Rot.from_matrix(rotation_matrix.numpy()).as_euler("XYZ")
location = np.concatenate((location[:, :3], rpy), axis=1)

t = np.arange(location.shape[0]) / location.shape[0]
# location[:, 0] = location[:, 0] - np.mean(location[:, 0])
# location[:, 1] = location[:, 1] - np.mean(location[:, 1])