Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
62ce27f
update
bertranMiquel May 13, 2024
e8a35dc
First ring-based commit
bertranMiquel Jun 13, 2024
1385b4e
Merge remote-tracking branch 'upstream/main'
bertranMiquel Jun 13, 2024
82c6b25
Readme update
bertranMiquel Jun 13, 2024
4be4aa6
mol ring lifting correctation
bertranMiquel Jun 13, 2024
02ea2bf
Correct more errors
bertranMiquel Jun 13, 2024
a83f484
Fix lifting details
bertranMiquel Jun 13, 2024
cd936f7
Hope it is the last one
bertranMiquel Jun 13, 2024
431e246
Deleting other test since they are using other test dataset
bertranMiquel Jun 13, 2024
ae365b5
git ignore modification
bertranMiquel Jun 13, 2024
b201b54
minor modifications
bertranMiquel Jun 14, 2024
917799b
Update ring_lifting.py
bertranMiquel Jun 14, 2024
a7534a8
Ring lifting modifications
bertranMiquel Jun 15, 2024
b884a92
Update dependencies
bertranMiquel Jun 18, 2024
1da64be
Updates due to warnings
bertranMiquel Jun 20, 2024
2f35444
Reload tests
bertranMiquel Jun 20, 2024
9ab1d8a
Trail spaces
bertranMiquel Jun 20, 2024
59735bd
Changes in test manual data.
bertranMiquel Jun 20, 2024
2e30876
Remove space
bertranMiquel Jun 20, 2024
28e7b74
Refine test.
bertranMiquel Jun 20, 2024
704e2e1
Remove notebook
bertranMiquel Jun 20, 2024
a07528f
correct torch in test
bertranMiquel Jun 20, 2024
db73c5e
Adding attributes functions
bertranMiquel Jul 1, 2024
0bbc8cd
Adding attributes
bertranMiquel Jul 1, 2024
0783baf
Load UniProt data
bertranMiquel Jul 3, 2024
29608e8
Trying lifting stuff
bertranMiquel Jul 4, 2024
456dddc
Lifting try
bertranMiquel Jul 5, 2024
06b2084
All developed
bertranMiquel Jul 8, 2024
7a7d2df
Comments updated
bertranMiquel Jul 8, 2024
cf7ffb7
Update readme
bertranMiquel Jul 8, 2024
530a198
Remove random package
bertranMiquel Jul 8, 2024
e1f83a3
update size parameter
bertranMiquel Jul 8, 2024
44e083d
Change size parameter
bertranMiquel Jul 8, 2024
c406216
Change attributes to mass function name
bertranMiquel Jul 8, 2024
7c46cb9
Update import class
bertranMiquel Jul 8, 2024
9cc415a
Update import class
bertranMiquel Jul 8, 2024
d1f49f6
Add parameter to test
bertranMiquel Jul 8, 2024
0fbae67
Finish test data
bertranMiquel Jul 8, 2024
cf27d2f
Update test data reference
bertranMiquel Jul 9, 2024
27df0a8
ruff modifications
bertranMiquel Jul 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions configs/datasets/QM9.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
data_domain: graph
data_type: QM9
data_name: QM9
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_features: 11
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph

22 changes: 22 additions & 0 deletions configs/datasets/UniProt.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
data_domain: graph
data_type: UniProt
data_name: UniProt
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Some parameters to do the query
query: "length:[95 TO 155]" # number of residues per protein
format: "tsv"
fields: "accession,length"
size: 100 # number of proteins to load

threshold: 6.0 # distance between proteins to create the initial graph

# Dataset parameters
num_features: 20
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph

12 changes: 12 additions & 0 deletions configs/datasets/manual_prot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: graph
data_type: toy_dataset
data_name: manual_prot
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_features: 1
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: node
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transform_type: 'lifting'
transform_name: "HypergraphCloseLifting"
feature_lifting: ProjectionSum

distance: 6.0
227 changes: 227 additions & 0 deletions loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import os

import numpy as np
import rootutils
import torch_geometric
from omegaconf import DictConfig

# silent RDKit warnings
from rdkit import Chem, RDLogger

from modules.data.load.base import AbstractLoader
from modules.data.utils.concat2geometric_dataset import ConcatToGeometricDataset
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_simplicial_dataset,
)

RDLogger.DisableLog("rdApp.*")


class GraphLoader(AbstractLoader):
r"""Loader for graph datasets.

Parameters
----------
parameters : DictConfig
Configuration parameters.
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def is_valid_smiles(self, smiles):
"""Check if a SMILES string is valid using RDKit."""
mol = Chem.MolFromSmiles(smiles)
return mol is not None

def filter_qm9_dataset(self, dataset):
"""Filter the QM9 dataset to remove invalid SMILES strings."""
return [data for data in dataset if self.is_valid_smiles(data.smiles)]

def load(self) -> torch_geometric.data.Dataset:
r"""Load graph dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
# Define the path to the data directory
root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])

self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])
if (
self.parameters.data_name.lower() in ["cora", "citeseer", "pubmed"]
and self.parameters.data_type == "cocitation"
):
dataset = torch_geometric.datasets.Planetoid(
root=root_data_dir,
name=self.parameters["data_name"],
)

elif self.parameters.data_name in [
"MUTAG",
"ENZYMES",
"PROTEINS",
"COLLAB",
"IMDB-BINARY",
"IMDB-MULTI",
"REDDIT-BINARY",
"NCI1",
"NCI109",
]:
dataset = torch_geometric.datasets.TUDataset(
root=root_data_dir,
name=self.parameters["data_name"],
use_node_attr=False,
)

elif self.parameters.data_name in ["ZINC", "AQSOL"]:
datasets = []
for split in ["train", "val", "test"]:
if self.parameters.data_name == "ZINC":
datasets.append(
torch_geometric.datasets.ZINC(
root=root_data_dir,
subset=True,
split=split,
)
)
elif self.parameters.data_name == "AQSOL":
datasets.append(
torch_geometric.datasets.AQSOL(
root=root_data_dir,
split=split,
)
)
# The splits are predefined
# Extract and prepare split_idx
split_idx = {"train": np.arange(len(datasets[0]))}
split_idx["valid"] = np.arange(
len(datasets[0]), len(datasets[0]) + len(datasets[1])
)
split_idx["test"] = np.arange(
len(datasets[0]) + len(datasets[1]),
len(datasets[0]) + len(datasets[1]) + len(datasets[2]),
)
# Join dataset to process it
dataset = datasets[0] + datasets[1] + datasets[2]
dataset = ConcatToGeometricDataset(dataset)

elif self.parameters.data_name == "QM9":
dataset = torch_geometric.datasets.QM9(root=root_data_dir)
# Filter the QM9 dataset to remove invalid SMILES strings
valid_dataset = self.filter_qm9_dataset(dataset)
# dataset = ConcatToGeometricDataset(valid_dataset)
dataset = CustomDataset(valid_dataset, self.data_dir)

elif self.parameters.data_name in ["manual"]:
data = load_manual_graph()
dataset = CustomDataset([data], self.data_dir)

else:
raise NotImplementedError(
f"Dataset {self.parameters.data_name} not implemented"
)

return dataset


class CellComplexLoader(AbstractLoader):
r"""Loader for cell complex datasets.

Parameters
----------
parameters : DictConfig
Configuration parameters.
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(
self,
) -> torch_geometric.data.Dataset:
r"""Load cell complex dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_cell_complex_dataset(self.parameters)


class SimplicialLoader(AbstractLoader):
r"""Loader for simplicial datasets.

Parameters
----------
parameters : DictConfig
Configuration parameters.
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(
self,
) -> torch_geometric.data.Dataset:
r"""Load simplicial dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_simplicial_dataset(self.parameters)


class HypergraphLoader(AbstractLoader):
r"""Loader for hypergraph datasets.

Parameters
----------
parameters : DictConfig
Configuration parameters.
"""

def __init__(self, parameters: DictConfig):
super().__init__(parameters)
self.parameters = parameters

def load(
self,
) -> torch_geometric.data.Dataset:
r"""Load hypergraph dataset.

Parameters
----------
None

Returns
-------
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)
Loading