-
Notifications
You must be signed in to change notification settings - Fork 0
Return pca results in AnnData format #132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,7 @@ | |
| tl.diff_exp_alphaquant | ||
| tl.pca | ||
| tl.diff_exp_ebayes | ||
| tl.extract_pca_anndata | ||
|
|
||
| ``` | ||
|
|
||
|
|
||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,13 +24,13 @@ | |
| from alphatools.pl import defaults | ||
| from alphatools.pl.colors import BaseColors, BasePalettes, _get_colors_from_cmap, get_color_mapping | ||
| from alphatools.pl.figure import create_figure, label_axes | ||
| from alphatools.pl.plot_data_handling import ( | ||
| from alphatools.pp.data import data_column_to_array | ||
| from alphatools.tl.plot_data_handling import ( | ||
| extract_pca_anndata, | ||
| prepare_pca_1d_loadings_data_to_plot, | ||
| prepare_pca_2d_loadings_data_to_plot, | ||
| prepare_pca_data_to_plot, | ||
| prepare_scree_data_to_plot, | ||
| ) | ||
| from alphatools.pp.data import data_column_to_array | ||
|
|
||
| # logging configuration | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
@@ -1122,14 +1122,13 @@ def rank_median_plot( | |
| def plot_pca( | ||
| cls, | ||
| data: ad.AnnData, | ||
| x_column: int = 1, | ||
| y_column: int = 2, | ||
| pc_x: int = 1, | ||
| pc_y: int = 2, | ||
| color: str = "blue", | ||
| color_map_column: str | None = None, | ||
| color_column: str | None = None, | ||
| dim_space: str = "obs", | ||
| embeddings_name: str | None = None, | ||
| # TODO: the below argument is an antipattern resulting from this function doing multiple things. In the future, this should be replaced by a pca-plotting adapter so that pca_plot is no longer needed and scatter can be used instead, followed by label_plot, etc. | ||
| label: bool = False, # noqa: FBT001, FBT002 | ||
| label_column: str | None = None, | ||
| ax: plt.Axes | None = None, | ||
|
|
@@ -1180,34 +1179,25 @@ def plot_pca( | |
| """ | ||
| scatter_kwargs = scatter_kwargs or {} | ||
|
|
||
| pca_coor_df = prepare_pca_data_to_plot( | ||
| data, x_column, y_column, dim_space, embeddings_name, color_map_column, label_column, label=label | ||
| pca_anndata = extract_pca_anndata( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| data, dim_space=dim_space, embeddings_name=embeddings_name, expression_columns=color_map_column | ||
| ) | ||
|
|
||
| # Check if the variance layer exists in uns | ||
| variance_key = f"variance_pca_{dim_space}" if embeddings_name is None else embeddings_name | ||
|
|
||
| if variance_key not in data.uns: | ||
| raise ValueError( | ||
| f"PCA metadata layer '{variance_key}' not found in AnnData object. " | ||
| f"Found layers: {list(data.uns.keys())}" | ||
| ) | ||
|
|
||
| # get the explained variance ratio for the dimensions (for axis labels) | ||
| var_dim1 = data.uns[variance_key]["variance_ratio"][x_column - 1] | ||
| var_dim1 = pca_anndata.var["variance_ratio"][str(pc_x)] | ||
| var_dim1 = round(var_dim1 * 100, 2) | ||
| var_dim2 = data.uns[variance_key]["variance_ratio"][y_column - 1] | ||
| var_dim2 = pca_anndata.var["variance_ratio"][str(pc_y)] | ||
| var_dim2 = round(var_dim2 * 100, 2) | ||
|
|
||
| # add color column | ||
| if color_map_column is not None: | ||
| color_values = data_column_to_array(data, color_map_column) | ||
| pca_coor_df[color_map_column] = color_values | ||
| # check pc_x and pc_y are valid | ||
| n_pcs = pca_anndata.shape[1] | ||
| if pc_x < 1 or pc_x > n_pcs or pc_y < 1 or pc_y > n_pcs: | ||
| raise ValueError(f"pc_x and pc_y are out of bounds, must be between 1 and {n_pcs}") | ||
|
|
||
| cls.scatter( | ||
| data=pca_coor_df, | ||
| x_column="dim1", | ||
| y_column="dim2", | ||
| data=pca_anndata, | ||
| x_column=str(pc_x), | ||
| y_column=str(pc_y), | ||
| color=color, | ||
| color_column=color_column, | ||
| color_map_column=color_map_column, | ||
|
|
@@ -1225,11 +1215,16 @@ def plot_pca( | |
| labels = data.obs.index if label_column is None else data_column_to_array(data, label_column) | ||
| else: # dim_space == "var" | ||
| labels = data.var.index if label_column is None else data_column_to_array(data, label_column) | ||
|
|
||
| label_plot(ax=ax, x_values=pca_coor_df["dim1"], y_values=pca_coor_df["dim2"], labels=labels, x_anchors=None) | ||
| label_plot( | ||
| ax=ax, | ||
| x_values=pca_anndata.X[:, pc_x - 1], | ||
| y_values=pca_anndata.X[:, pc_y - 1], | ||
| labels=labels, | ||
| x_anchors=None, | ||
| ) | ||
|
|
||
| # set axislabels | ||
| label_axes(ax, xlabel=f"PC{x_column} ({var_dim1}%)", ylabel=f"PC{y_column} ({var_dim2}%)") | ||
| label_axes(ax, xlabel=f"PC{pc_x} ({var_dim1}%)", ylabel=f"PC{pc_y} ({var_dim2}%)") | ||
|
|
||
| @classmethod | ||
| def scree_plot( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,3 @@ | ||
| # TODO: move this submodule to tl module | ||
| """ | ||
| Auxiliary functions for handling data and formatting for PCA plot input. | ||
|
|
||
|
|
@@ -13,8 +12,6 @@ | |
| import numpy as np | ||
| import pandas as pd | ||
|
|
||
| from alphatools.pp.data import data_column_to_array | ||
|
|
||
| # logging configuration | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
||
|
|
@@ -49,8 +46,7 @@ def _validate_adata_and_dim_space(data, dim_space: str) -> None: # noqa: ANN001 | |
| def _validate_pca_plot_input( | ||
| data: ad.AnnData, | ||
| pca_embeddings_layer_name: str, | ||
| pc_x: int, | ||
| pc_y: int, | ||
| pca_var_key: str, | ||
| dim_space: str, | ||
| ) -> None: | ||
| """ | ||
|
|
@@ -62,10 +58,8 @@ def _validate_pca_plot_input( | |
| AnnData object to be validated. | ||
| pca_embeddings_layer_name: | ||
| Name of the PCA layer to be checked. | ||
| pc_x: | ||
| First PCA dimension to be validated (1-indexed, i.e. the first PC is 1, not 0). | ||
| pc_y: | ||
| Second PCA dimension to be validated (1-indexed, i.e. the first PC is 1, not 0). | ||
| pca_var_key: | ||
| Name of the PCA variance metadata layer to be checked, stored in `data.uns`. | ||
| dim_space: | ||
| The dimension space used in PCA. Can be either "obs" or "var". | ||
| """ | ||
|
|
@@ -82,10 +76,11 @@ def _validate_pca_plot_input( | |
| f"Found layers: {available_layers}" | ||
| ) | ||
|
|
||
| # Check PC dimensions | ||
| n_pcs = getattr(data, pca_coors_attr)[pca_embeddings_layer_name].shape[1] | ||
| if not (1 <= pc_x <= n_pcs) or not (1 <= pc_y <= n_pcs): | ||
| raise ValueError(f"pc_x and pc_y must be between 1 and {n_pcs} (inclusive). Got {pc_x=}, {pc_y=}") | ||
| # Check if the variance layer exists in uns | ||
| if pca_var_key not in data.uns: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rename to |
||
| raise ValueError( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. was there a reason to move the |
||
| f"PCA metadata layer '{pca_var_key}' not found in AnnData object. Found layers: {list(data.uns.keys())}" | ||
| ) | ||
|
|
||
|
|
||
| def _validate_scree_plot_input( | ||
|
|
@@ -175,74 +170,78 @@ def _validate_pca_loadings_plot_inputs( | |
| ## Functions to prepare data frames for plotting using the scatter method | ||
|
|
||
|
|
||
| def prepare_pca_data_to_plot( | ||
| def extract_pca_anndata( | ||
| data: ad.AnnData, | ||
| pc_x: int = 1, | ||
| pc_y: int = 2, | ||
| dim_space: str = "obs", | ||
| embeddings_name: str | None = None, | ||
| color_map_column: str | None = None, | ||
| label_column: str | None = None, | ||
| *, | ||
| label: bool = False, | ||
| ) -> pd.DataFrame: | ||
| expression_columns: list[str] | None = None, | ||
| ) -> ad.AnnData: | ||
| """ | ||
| Fetched PCA data required from PCA plotting from AnnData object (as returned by `pca` function). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : ad.AnnData | ||
| AnnData object containing PCA results. | ||
| pc_x : int | ||
| First principal component (1-indexed). | ||
| pc_y : int | ||
| Second principal component (1-indexed). | ||
| dim_space : str | ||
| Either "obs" or "var" for observation or variable embeddings space. | ||
| embeddings_name : str | None | ||
| Custom embeddings name or None for default. | ||
| color_map_column : str | None | ||
| Column for color mapping. | ||
| label_column : str | None | ||
| Column for labeling points. | ||
| label : bool | ||
| Whether labels are requested. | ||
| expression_columns : str | None | ||
| a list of `var_names` (if `dim_space = obs`) or `obs_names` (if `dim_space = var`). | ||
|
|
||
| Returns | ||
| ------- | ||
| pd.DataFrame | ||
| DataFrame with PCA coordinates,color data and labels if requested. | ||
| ad.AnnData | ||
| AnnData object containing PCA coordinates, color mapping, and labels. | ||
| """ | ||
| # Generate the correct key names based on dim_space and embeddings_name | ||
| pca_coors_key = f"X_pca_{dim_space}" if embeddings_name is None else embeddings_name | ||
| pca_var_key = f"variance_pca_{dim_space}" if embeddings_name is None else embeddings_name | ||
|
|
||
| # Input checks | ||
| _validate_pca_plot_input(data, pca_coors_key, pc_x, pc_y, dim_space) | ||
|
|
||
| # Create the dataframe for plotting | ||
| dim1_z = pc_x - 1 # to account for 0 indexing | ||
| dim2_z = pc_y - 1 # to account for 0 indexing | ||
| _validate_pca_plot_input(data, pca_coors_key, pca_var_key, dim_space) | ||
|
|
||
| # Get PCA coordinates from the correct attribute | ||
| pca_coordinates = data.obsm[pca_coors_key] if dim_space == "obs" else data.varm[pca_coors_key] | ||
| obs_df = data.obs if dim_space == "obs" else data.var | ||
| var_df = pd.DataFrame(data.uns[pca_var_key]) | ||
|
|
||
| pca_coor_df = pd.DataFrame(pca_coordinates[:, [dim1_z, dim2_z]], columns=["dim1", "dim2"]) | ||
| pca_anndata = ad.AnnData(X=pca_coordinates) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| # Add color column if specified | ||
| if color_map_column is not None: | ||
| color_values = data_column_to_array(data, color_map_column) | ||
| pca_coor_df[color_map_column] = color_values | ||
| pca_anndata.obs = obs_df | ||
| pca_anndata.var = var_df | ||
|
|
||
| # Prepare labels if requested | ||
| labels = None | ||
| if label: | ||
| if dim_space == "obs": | ||
| labels = data.obs.index if label_column is None else data_column_to_array(data, label_column) | ||
| else: # dim_space == "var" | ||
| labels = data.var.index if label_column is None else data_column_to_array(data, label_column) | ||
| pca_coor_df["labels"] = labels | ||
|
|
||
| return pca_coor_df | ||
| if expression_columns is not None: | ||
| expression_columns = ( | ||
| expression_columns if isinstance(expression_columns, (list, tuple, np.ndarray)) else [expression_columns] | ||
| ) | ||
| # Subset the data to only include the specified expression columns | ||
| if dim_space == "obs" and len(data.var_names.intersection(expression_columns)) > 0: | ||
| expr_cols = list(set(expression_columns) & set(data.var_names)) | ||
| expr_data = pd.DataFrame(data[:, expr_cols].X) | ||
| expr_data = expr_data.apply(pd.to_numeric, errors="coerce") | ||
| # set colnames and indices of the data | ||
| expr_data = expr_data.set_axis(expr_cols, axis=1) | ||
| expr_data.index = data.obs_names | ||
| pca_anndata.obs = pca_anndata.obs.join(expr_data) | ||
|
|
||
| elif dim_space == "var" and len(list(set(expression_columns) & set(data.obs_names))) > 0: | ||
| expr_rows = list(set(expression_columns) & set(data.obs_names)) | ||
| expr_data = pd.DataFrame(data[:, expr_cols].X.T) | ||
| expr_data = expr_data.apply(pd.to_numeric, errors="coerce") | ||
|
|
||
| expr_data = expr_data.set_axis(expr_rows, axis=1) | ||
| expr_data.index = data.var_names | ||
| pca_anndata.obs = pca_anndata.obs.join(expr_data) | ||
| else: | ||
| print("here") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. accidental commit? |
||
| logging.warning( | ||
| f"No matching expression columns found in the data.X for the specified PCA dim_space '{dim_space}' and {expression_columns}." | ||
| ) | ||
|
|
||
| pca_anndata.var_names = [str(i + 1) for i in range(pca_anndata.X.shape[1])] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just "1"? not "pc_1"? |
||
| return pca_anndata | ||
|
|
||
|
|
||
| def prepare_scree_data_to_plot( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the TODO is no longer valid?