-
Notifications
You must be signed in to change notification settings - Fork 989
Description
问题在于 pointnet2_utils 模块中的 球形领域查询函数:query_ball_point,其存在严重 BUG!!
完整代码如下:
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # ❌
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] # 🔴
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
其中 group_idx 存放的是索引,不包含任何距离信息,
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) # ❌
对于 group_idx 进行的排序,在pointnet++官方中使用的是距离排序,而这里:
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] # 🔴
由于 group_idx 存放的是索引,因此这里完全是针对索引进行的排序
由此导致的的后果是,在对临时替换为 越界索引 N 的数据,后期对其进行修正时,本应该用到 排在第一位,距离最近的邻居点进行替换,结果排在第一位的并非是距离最近的点,由此导致经过替换后,仍无法保证索引恢复合法状态