Skip to content
This repository was archived by the owner on May 21, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions jraph/_src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,19 +547,26 @@ def _ApplyGCN(graph):
nodes = update_node_fn(nodes)
# Equivalent to jnp.sum(n_node), but jittable
total_num_nodes = tree.tree_leaves(nodes)[0].shape[0]

# Handle None senders and receivers by initializing empty arrays
if senders is None:
senders = jnp.array([], dtype=jnp.int32)
if receivers is None:
receivers = jnp.array([], dtype=jnp.int32)

if add_self_edges:
# We add self edges to the senders and receivers so that each node
# includes itself in aggregation.
# In principle, a `GraphsTuple` should partition by n_edge, but in
# this case it is not required since a GCN is agnostic to whether
# the `GraphsTuple` is a batch of graphs or a single large graph.
conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)),
# We add self edges to the senders and receivers so that each node
# includes itself in aggregation.
# In principle, a `GraphsTuple` should partition by n_edge, but in
# this case it is not required since a GCN is agnostic to whether
# the `GraphsTuple` is a batch of graphs or a single large graph.
conv_receivers = jnp.concatenate((receivers, jnp.arange(total_num_nodes)),
axis=0)
conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)),
conv_senders = jnp.concatenate((senders, jnp.arange(total_num_nodes)),
axis=0)
else:
conv_senders = senders
conv_receivers = receivers
conv_senders = senders
conv_receivers = receivers

# pylint: disable=g-long-lambda
if symmetric_normalization:
Expand Down