Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
tl.diff_exp_alphaquant
tl.pca
tl.diff_exp_ebayes
tl.extract_pca_anndata

```

Expand Down
229 changes: 222 additions & 7 deletions docs/notebooks/04_basic_PCA_workflow.ipynb

Large diffs are not rendered by default.

53 changes: 24 additions & 29 deletions src/alphatools/pl/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from alphatools.pl import defaults
from alphatools.pl.colors import BaseColors, BasePalettes, _get_colors_from_cmap, get_color_mapping
from alphatools.pl.figure import create_figure, label_axes
from alphatools.pl.plot_data_handling import (
from alphatools.pp.data import data_column_to_array
from alphatools.tl.plot_data_handling import (
extract_pca_anndata,
prepare_pca_1d_loadings_data_to_plot,
prepare_pca_2d_loadings_data_to_plot,
prepare_pca_data_to_plot,
prepare_scree_data_to_plot,
)
from alphatools.pp.data import data_column_to_array

# logging configuration
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -1122,14 +1122,13 @@ def rank_median_plot(
def plot_pca(
cls,
data: ad.AnnData,
x_column: int = 1,
y_column: int = 2,
pc_x: int = 1,
pc_y: int = 2,
color: str = "blue",
color_map_column: str | None = None,
color_column: str | None = None,
dim_space: str = "obs",
embeddings_name: str | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the TODO is no longer valid?

# TODO: the below argument is an antipattern resulting from this function doing multiple things. In the future, this should be replaced by a pca-plotting adapter so that pca_plot is no longer needed and scatter can be used instead, followed by label_plot, etc.
label: bool = False, # noqa: FBT001, FBT002
label_column: str | None = None,
ax: plt.Axes | None = None,
Expand Down Expand Up @@ -1180,34 +1179,25 @@ def plot_pca(
"""
scatter_kwargs = scatter_kwargs or {}

pca_coor_df = prepare_pca_data_to_plot(
data, x_column, y_column, dim_space, embeddings_name, color_map_column, label_column, label=label
pca_anndata = extract_pca_anndata(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adata_pca? for consistency..

data, dim_space=dim_space, embeddings_name=embeddings_name, expression_columns=color_map_column
)

# Check if the variance layer exists in uns
variance_key = f"variance_pca_{dim_space}" if embeddings_name is None else embeddings_name

if variance_key not in data.uns:
raise ValueError(
f"PCA metadata layer '{variance_key}' not found in AnnData object. "
f"Found layers: {list(data.uns.keys())}"
)

# get the explained variance ratio for the dimensions (for axis labels)
var_dim1 = data.uns[variance_key]["variance_ratio"][x_column - 1]
var_dim1 = pca_anndata.var["variance_ratio"][str(pc_x)]
var_dim1 = round(var_dim1 * 100, 2)
var_dim2 = data.uns[variance_key]["variance_ratio"][y_column - 1]
var_dim2 = pca_anndata.var["variance_ratio"][str(pc_y)]
var_dim2 = round(var_dim2 * 100, 2)

# add color column
if color_map_column is not None:
color_values = data_column_to_array(data, color_map_column)
pca_coor_df[color_map_column] = color_values
# check pc_x and pc_y are valid
n_pcs = pca_anndata.shape[1]
if pc_x < 1 or pc_x > n_pcs or pc_y < 1 or pc_y > n_pcs:
raise ValueError(f"pc_x and pc_y are out of bounds, must be between 1 and {n_pcs}")

cls.scatter(
data=pca_coor_df,
x_column="dim1",
y_column="dim2",
data=pca_anndata,
x_column=str(pc_x),
y_column=str(pc_y),
color=color,
color_column=color_column,
color_map_column=color_map_column,
Expand All @@ -1225,11 +1215,16 @@ def plot_pca(
labels = data.obs.index if label_column is None else data_column_to_array(data, label_column)
else: # dim_space == "var"
labels = data.var.index if label_column is None else data_column_to_array(data, label_column)

label_plot(ax=ax, x_values=pca_coor_df["dim1"], y_values=pca_coor_df["dim2"], labels=labels, x_anchors=None)
label_plot(
ax=ax,
x_values=pca_anndata.X[:, pc_x - 1],
y_values=pca_anndata.X[:, pc_y - 1],
labels=labels,
x_anchors=None,
)

# set axislabels
label_axes(ax, xlabel=f"PC{x_column} ({var_dim1}%)", ylabel=f"PC{y_column} ({var_dim2}%)")
label_axes(ax, xlabel=f"PC{pc_x} ({var_dim1}%)", ylabel=f"PC{pc_y} ({var_dim2}%)")

@classmethod
def scree_plot(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# TODO: move this submodule to tl module
"""
Auxiliary functions for handling data and formatting for PCA plot input.

Expand All @@ -13,8 +12,6 @@
import numpy as np
import pandas as pd

from alphatools.pp.data import data_column_to_array

# logging configuration
logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -49,8 +46,7 @@ def _validate_adata_and_dim_space(data, dim_space: str) -> None: # noqa: ANN001
def _validate_pca_plot_input(
data: ad.AnnData,
pca_embeddings_layer_name: str,
pc_x: int,
pc_y: int,
pca_var_key: str,
dim_space: str,
) -> None:
"""
Expand All @@ -62,10 +58,8 @@ def _validate_pca_plot_input(
AnnData object to be validated.
pca_embeddings_layer_name:
Name of the PCA layer to be checked.
pc_x:
First PCA dimension to be validated (1-indexed, i.e. the first PC is 1, not 0).
pc_y:
Second PCA dimension to be validated (1-indexed, i.e. the first PC is 1, not 0).
pca_var_key:
Name of the PCA variance metadata layer to be checked, stored in `data.uns`.
dim_space:
The dimension space used in PCA. Can be either "obs" or "var".
"""
Expand All @@ -82,10 +76,11 @@ def _validate_pca_plot_input(
f"Found layers: {available_layers}"
)

# Check PC dimensions
n_pcs = getattr(data, pca_coors_attr)[pca_embeddings_layer_name].shape[1]
if not (1 <= pc_x <= n_pcs) or not (1 <= pc_y <= n_pcs):
raise ValueError(f"pc_x and pc_y must be between 1 and {n_pcs} (inclusive). Got {pc_x=}, {pc_y=}")
# Check if the variance layer exists in uns
if pca_var_key not in data.uns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename to adata?

raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was there a reason to move the if not (1 <= pc_x <= n_pcs) check?

f"PCA metadata layer '{pca_var_key}' not found in AnnData object. Found layers: {list(data.uns.keys())}"
)


def _validate_scree_plot_input(
Expand Down Expand Up @@ -175,74 +170,78 @@ def _validate_pca_loadings_plot_inputs(
## Functions to prepare data frames for plotting using the scatter method


def prepare_pca_data_to_plot(
def extract_pca_anndata(
data: ad.AnnData,
pc_x: int = 1,
pc_y: int = 2,
dim_space: str = "obs",
embeddings_name: str | None = None,
color_map_column: str | None = None,
label_column: str | None = None,
*,
label: bool = False,
) -> pd.DataFrame:
expression_columns: list[str] | None = None,
) -> ad.AnnData:
"""
Fetch the PCA data required for PCA plotting from an AnnData object (as returned by the `pca` function).

Parameters
----------
data : ad.AnnData
AnnData object containing PCA results.
pc_x : int
First principal component (1-indexed).
pc_y : int
Second principal component (1-indexed).
dim_space : str
Either "obs" or "var" for observation or variable embeddings space.
embeddings_name : str | None
Custom embeddings name or None for default.
color_map_column : str | None
Column for color mapping.
label_column : str | None
Column for labeling points.
label : bool
Whether labels are requested.
expression_columns : list[str] | None
List of `var_names` (if `dim_space = "obs"`) or `obs_names` (if `dim_space = "var"`) whose values are joined onto the `obs` table of the returned object.

Returns
-------
pd.DataFrame
DataFrame with PCA coordinates,color data and labels if requested.
ad.AnnData
AnnData object with the PCA coordinates in `X`, the metadata of the selected `dim_space` in `obs`, and the PCA variance metadata in `var`.
"""
# Generate the correct key names based on dim_space and embeddings_name
pca_coors_key = f"X_pca_{dim_space}" if embeddings_name is None else embeddings_name
pca_var_key = f"variance_pca_{dim_space}" if embeddings_name is None else embeddings_name

# Input checks
_validate_pca_plot_input(data, pca_coors_key, pc_x, pc_y, dim_space)

# Create the dataframe for plotting
dim1_z = pc_x - 1 # to account for 0 indexing
dim2_z = pc_y - 1 # to account for 0 indexing
_validate_pca_plot_input(data, pca_coors_key, pca_var_key, dim_space)

# Get PCA coordinates from the correct attribute
pca_coordinates = data.obsm[pca_coors_key] if dim_space == "obs" else data.varm[pca_coors_key]
obs_df = data.obs if dim_space == "obs" else data.var
var_df = pd.DataFrame(data.uns[pca_var_key])

pca_coor_df = pd.DataFrame(pca_coordinates[:, [dim1_z, dim2_z]], columns=["dim1", "dim2"])
pca_anndata = ad.AnnData(X=pca_coordinates)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adata_pca?


# Add color column if specified
if color_map_column is not None:
color_values = data_column_to_array(data, color_map_column)
pca_coor_df[color_map_column] = color_values
pca_anndata.obs = obs_df
pca_anndata.var = var_df

# Prepare labels if requested
labels = None
if label:
if dim_space == "obs":
labels = data.obs.index if label_column is None else data_column_to_array(data, label_column)
else: # dim_space == "var"
labels = data.var.index if label_column is None else data_column_to_array(data, label_column)
pca_coor_df["labels"] = labels

return pca_coor_df
if expression_columns is not None:
expression_columns = (
expression_columns if isinstance(expression_columns, (list, tuple, np.ndarray)) else [expression_columns]
)
# Subset the data to only include the specified expression columns
if dim_space == "obs" and len(data.var_names.intersection(expression_columns)) > 0:
expr_cols = list(set(expression_columns) & set(data.var_names))
expr_data = pd.DataFrame(data[:, expr_cols].X)
expr_data = expr_data.apply(pd.to_numeric, errors="coerce")
# set colnames and indices of the data
expr_data = expr_data.set_axis(expr_cols, axis=1)
expr_data.index = data.obs_names
pca_anndata.obs = pca_anndata.obs.join(expr_data)

elif dim_space == "var" and len(list(set(expression_columns) & set(data.obs_names))) > 0:
expr_rows = list(set(expression_columns) & set(data.obs_names))
expr_data = pd.DataFrame(data[:, expr_cols].X.T)
expr_data = expr_data.apply(pd.to_numeric, errors="coerce")

expr_data = expr_data.set_axis(expr_rows, axis=1)
expr_data.index = data.var_names
pca_anndata.obs = pca_anndata.obs.join(expr_data)
else:
print("here")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accidental commit?

logging.warning(
f"No matching expression columns found in the data.X for the specified PCA dim_space '{dim_space}' and {expression_columns}."
)

pca_anndata.var_names = [str(i + 1) for i in range(pca_anndata.X.shape[1])]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just "1"? not "pc_1"?

return pca_anndata


def prepare_scree_data_to_plot(
Expand Down
Loading