13 changes: 13 additions & 0 deletions configs/datasets/annulus.yaml
@@ -0,0 +1,13 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: annulus
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
num_features: 10
num_classes: 3
num_points: 1000
dim: 2
13 changes: 13 additions & 0 deletions configs/datasets/random_pointcloud.yaml
@@ -0,0 +1,13 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: random_pointcloud
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
num_features: 10
num_classes: 3
num_points: 50
dim: 2
11 changes: 11 additions & 0 deletions configs/datasets/stanford_bunny.yaml
@@ -0,0 +1,11 @@
data_domain: pointcloud
data_type: toy_dataset
data_name: stanford_bunny
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph
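
For orientation, a minimal sketch of how these dataset configs resolve, assuming omegaconf is available and the file is read directly (in the framework itself, Hydra-style composition handles this step):

from omegaconf import OmegaConf

# Load one of the new dataset configs; the ${...} interpolations in
# data_dir resolve against keys defined in the same file.
cfg = OmegaConf.load("configs/datasets/annulus.yaml")
print(OmegaConf.to_container(cfg, resolve=True)["data_dir"])
# expected: datasets/pointcloud/toy_dataset/annulus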
@@ -0,0 +1,3 @@
transform_type: 'lifting'
transform_name: "CoverLifting"
feature_lifting: ProjectionSum
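
The CoverLifting class itself is not part of this diff (only its config and registry entry are). As a rough, hypothetical sketch of the shape a pointcloud-to-graph lifting could take: the base class name, the radius argument, and the overlap rule below are all assumptions, not the PR's implementation.

import torch
import torch_geometric

from modules.transforms.liftings.lifting import PointCloud2GraphLifting  # assumed base class


class CoverLifting(PointCloud2GraphLifting):
    """Hypothetical sketch: connect points whose cover balls overlap."""

    def __init__(self, radius: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.radius = radius

    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        # Two balls of radius r intersect iff their centers are closer
        # than 2r; use that as an illustrative edge rule.
        dist = torch.cdist(data.pos, data.pos)
        src, dst = torch.nonzero((dist < 2 * self.radius) & (dist > 0), as_tuple=True)
        return {"edge_index": torch.stack([src, dst]), "num_nodes": data.pos.size(0)}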
34 changes: 34 additions & 0 deletions modules/data/load/loaders.py
@@ -12,6 +12,7 @@
    load_cell_complex_dataset,
    load_hypergraph_pickle_dataset,
    load_manual_graph,
    load_pointcloud_dataset,
    load_simplicial_dataset,
)

@@ -204,3 +205,36 @@ def load(
            torch_geometric.data.Dataset object containing the loaded data.
        """
        return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
    r"""Loader for point cloud datasets.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters.
    """

    def __init__(self, parameters: DictConfig):
        super().__init__(parameters)
        self.parameters = parameters

    def load(
        self,
    ) -> torch_geometric.data.Dataset:
        r"""Load point cloud dataset.

        Returns
        -------
        torch_geometric.data.Dataset
            torch_geometric.data.Dataset object containing the loaded data.
        """
        # Define the path to the data directory
        root_folder = rootutils.find_root()
        self.data_dir = os.path.join(root_folder, self.parameters["data_dir"])

        # Load the point cloud dataset described by the config
        return load_pointcloud_dataset(self.parameters)
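
A usage sketch for the new loader, assuming it is run from the repository root with one of the configs above:

from omegaconf import OmegaConf

from modules.data.load.loaders import PointCloudLoader

cfg = OmegaConf.load("configs/datasets/random_pointcloud.yaml")
dataset = PointCloudLoader(cfg).load()
print(dataset[0])  # a Data object with pos of shape [50, 2]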
103 changes: 95 additions & 8 deletions modules/data/utils/utils.py
@@ -5,13 +5,17 @@
import networkx as nx
import numpy as np
import omegaconf
import rootutils
import toponetx.datasets.graph as graph
import torch
import torch_geometric
from gudhi.datasets.remote import fetch_bunny
from topomodelx.utils.sparse import from_sparse
from torch_geometric.data import Data
from torch_sparse import coalesce

from modules.data.utils.custom_dataset import CustomDataset


def get_complex_connectivity(complex, max_rank, signed=False):
r"""Gets the connectivity matrices for the complex.
@@ -50,16 +54,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
                )
            except ValueError:  # noqa: PERF203
                if connectivity_info == "incidence":
-                   connectivity[f"{connectivity_info}_{rank_idx}"] = (
-                       generate_zero_sparse_connectivity(
-                           m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
-                       )
+                   connectivity[
+                       f"{connectivity_info}_{rank_idx}"
+                   ] = generate_zero_sparse_connectivity(
+                       m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
                    )
                else:
-                   connectivity[f"{connectivity_info}_{rank_idx}"] = (
-                       generate_zero_sparse_connectivity(
-                           m=practical_shape[rank_idx], n=practical_shape[rank_idx]
-                       )
+                   connectivity[
+                       f"{connectivity_info}_{rank_idx}"
+                   ] = generate_zero_sparse_connectivity(
+                       m=practical_shape[rank_idx], n=practical_shape[rank_idx]
                    )
        connectivity["shape"] = practical_shape
        return connectivity
@@ -372,6 +376,89 @@ def get_TUDataset_pyg(cfg):
    return [data for data in dataset]


def load_pointcloud_dataset(cfg):
    r"""Loads point cloud datasets.

    Parameters
    ----------
    cfg : DictConfig
        Configuration parameters.

    Returns
    -------
    CustomDataset
        Dataset wrapping a single point cloud.
    """
    root_folder = rootutils.find_root()
    data_dir = osp.join(root_folder, cfg["data_dir"])

    if cfg["data_name"] == "random_pointcloud":
        num_points, dim = cfg["num_points"], cfg["dim"]
        pos = torch.rand((num_points, dim))
    elif cfg["data_name"] == "annulus":
        num_points, dim = cfg["num_points"], cfg["dim"]
        pos = sample_annulus(dim, num_points)
    elif cfg["data_name"] == "stanford_bunny":
        pos = fetch_bunny(
            file_path=osp.join(data_dir, "stanford_bunny.npy"),
            accept_license=False,
        )
        pos = torch.tensor(pos)
    else:
        # Guard against unknown dataset names instead of failing with a NameError
        raise ValueError(f"Unknown point cloud dataset: {cfg['data_name']}")

    return CustomDataset(
        [
            torch_geometric.data.Data(
                pos=pos,
            )
        ],
        data_dir,
    )


def sample_annulus(
    dimension: int, num_points: int, inner_radius: float = 0.8, outer_radius: float = 1
):
    """Sample points uniformly from an annulus of the given dimension.

    Parameters
    ----------
    dimension : int
        Ambient dimension of the annulus.
    num_points : int
        Number of points to sample.
    inner_radius : float, optional
        Inner radius, by default 0.8.
    outer_radius : float, optional
        Outer radius, by default 1.

    Returns
    -------
    torch.Tensor
        Tensor of sampled points with shape (num_points, dimension).
    """
    n = 0
    points = np.zeros((num_points, dimension))

    rng = np.random.default_rng()
    # Rejection sampling: draw uniformly from the bounding cube and keep
    # only points whose norm falls between the inner and outer radii.
    while n < num_points:
        p = rng.uniform(-outer_radius, outer_radius, dimension)
        if np.linalg.norm(p) > outer_radius or np.linalg.norm(p) < inner_radius:
            continue
        points[n] = p
        n += 1
    return torch.tensor(points)


def load_annulus():
    """Loads a 2D annulus point cloud.

    Returns
    -------
    torch_geometric.data.Data
        Point cloud data.
    """
    pos = sample_annulus(2, 1000)
    return torch_geometric.data.Data(pos=pos)


def ensure_serializable(obj):
    r"""Ensures that the object is serializable.

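
A quick sanity check for the rejection sampler above, assuming the module path used in this PR:

import torch

from modules.data.utils.utils import sample_annulus

pts = sample_annulus(dimension=2, num_points=500)
norms = torch.linalg.norm(pts, dim=1)
# Every sampled point should lie between the default radii 0.8 and 1.0.
assert norms.min().item() >= 0.8
assert norms.max().item() <= 1.0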
3 changes: 3 additions & 0 deletions modules/transforms/data_transform.py
@@ -15,6 +15,7 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
    SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2graph.cover_lifting import CoverLifting

TRANSFORMS = {
    # Graph -> Hypergraph
@@ -23,6 +24,8 @@
    "SimplicialCliqueLifting": SimplicialCliqueLifting,
    # Graph -> Cell Complex
    "CellCycleLifting": CellCycleLifting,
    # Point Cloud -> Graph
    "CoverLifting": CoverLifting,
    # Feature Liftings
    "ProjectionSum": ProjectionSum,
    # Data Manipulations
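
With the registry entry in place, a transform can be resolved from the transform_name string in the lifting config. A minimal sketch, assuming the lifting's default constructor arguments suffice and that data is a point cloud torch_geometric.data.Data object:

from modules.transforms.data_transform import TRANSFORMS

lifting_cls = TRANSFORMS["CoverLifting"]  # matches transform_name in the config
lifting = lifting_cls()
lifted = lifting(data)  # liftings are applied as callable transforms over Data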
9 changes: 6 additions & 3 deletions modules/transforms/liftings/lifting.py
@@ -61,6 +61,7 @@ def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
        initial_data = data.to_dict()
        lifted_topology = self.lift_topology(data)
        lifted_topology = self.feature_lifting(lifted_topology)

        return torch_geometric.data.Data(**initial_data, **lifted_topology)


@@ -118,9 +119,11 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph:
        # In case edge features are given, assign features to every edge
        edge_index, edge_attr = (
            data.edge_index,
-           data.edge_attr
-           if is_undirected(data.edge_index, data.edge_attr)
-           else to_undirected(data.edge_index, data.edge_attr),
+           (
+               data.edge_attr
+               if is_undirected(data.edge_index, data.edge_attr)
+               else to_undirected(data.edge_index, data.edge_attr)
+           ),
        )
        edges = [
            (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1))
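
The reformatted conditional above hinges on is_undirected and to_undirected from torch_geometric.utils; a small illustration of their behavior, independent of this PR:

import torch
from torch_geometric.utils import is_undirected, to_undirected

edge_index = torch.tensor([[0, 1], [1, 2]])  # directed edges 0->1 and 1->2
edge_attr = torch.tensor([[1.0], [2.0]])
print(is_undirected(edge_index, edge_attr))  # False
edge_index, edge_attr = to_undirected(edge_index, edge_attr)
print(edge_index)  # now contains both directions; edge_attr is matched up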