import scanpy as sc
import anndata as ad
import os
os.environ['USE_PYGEOS'] = '0' # To supress a warning from geopandas
import squidpy as sq
import pandas as pd
from tqdm import tqdm
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib as mpl
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as colors
from time import time
import pathlib
import sys
# Path to the spared root directory (the parent of this file's directory)
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent
# Add that parent directory to sys.path so the imports below can be resolved
sys.path.append(str(SPARED_PATH))
# Import the required project modules
from gene_features import gene_features
from filtering import filtering
from spot_features import spot_features
def plot_all_slides(dataset: str, processed_adata: ad.AnnData, path: str) -> None:
    """ Plot all the whole slide images

    This function takes a slide collection and plots the low resolution image of every slide in a grid
    that is as close to square as possible.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        processed_adata (ad.AnnData): Slide collection. The slide ids are read from ``obs['slide_id']`` and the
                                      images from ``uns['spatial'][slide_id]['images']['lowres']``.
        path (str): Path to save the plot.
    """
    # Get unique slide ids
    unique_ids = sorted(processed_adata.obs['slide_id'].unique())

    # Get number of columns (n) and rows (m) for the number of slides
    n = int(np.ceil(np.sqrt(len(unique_ids))))
    m = int(np.ceil(float(len(unique_ids)) / float(n)))

    # Define figure. squeeze=False guarantees a 2D array of axes even for a 1x1 grid (single slide),
    # where plt.subplots would otherwise return a bare Axes object without a flatten() method.
    fig, ax = plt.subplots(nrows=m, ncols=n, squeeze=False)
    fig.set_size_inches(7, (7 * m / n) + 1)

    # Flatten axes array
    ax = ax.flatten()

    # Draw one slide per axis; axes beyond the number of slides stay empty
    for axis, curr_id in zip(ax, unique_ids):
        curr_img = processed_adata.uns['spatial'][curr_id]['images']['lowres']
        axis.imshow(curr_img)
        axis.set_title(curr_id, fontsize='large')

    # Remove axis decorations for all the panels (including the empty ones)
    for axis in ax:
        axis.axis('off')

    fig.suptitle(f'All Histology Images from {dataset}', fontsize='x-large')
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    # Close this specific figure instead of whatever the current pyplot figure happens to be
    plt.close(fig)
def plot_exp_frac(param_dict: dict, dataset: str, raw_adata: ad.AnnData, path: str) -> None:
    """ Plot heatmap of expression fraction

    This function plots a heatmap (2D histogram of genes) of the expression fraction against the global
    expression fraction for the complete collection of slides. Spots whose total counts are not strictly
    between ``param_dict['cell_min_counts']`` and ``param_dict['cell_max_counts']`` are discarded first.
    The caller's ``raw_adata`` is not modified because the function works on a filtered copy.

    Args:
        param_dict (dict): Dataset parameters. The keys 'cell_min_counts' and 'cell_max_counts' are read.
        dataset (str): Name of the dataset. Only used in the figure title.
        raw_adata (ad.AnnData): An unfiltered and unprocessed (in raw counts) slide collection.
        path (str): Path to save the plot.
    """
    # Total counts per spot. ravel() (instead of squeeze) keeps a 1D vector even for a single spot.
    sample_counts = np.asarray(raw_adata.X.sum(axis=1)).ravel()
    # Keep only the spots with total counts strictly inside (cell_min_counts, cell_max_counts)
    bool_valid_samples = (sample_counts > param_dict['cell_min_counts']) & (sample_counts < param_dict['cell_max_counts'])
    valid_samples = raw_adata.obs_names[bool_valid_samples]

    # Subset the raw_adata to keep only the valid samples
    raw_adata = raw_adata[valid_samples, :].copy()

    # Add var['exp_frac'] and var['glob_exp_frac'] through the gene_features helpers
    raw_adata = gene_features.get_exp_frac(raw_adata)
    raw_adata = gene_features.get_glob_exp_frac(raw_adata)

    # Histogram matrix: 20x20 bins over [0, 1] x [0, 1]
    hist_mat, edge_exp_frac, edge_glob_exp_frac = np.histogram2d(
        raw_adata.var['exp_frac'], raw_adata.var['glob_exp_frac'], range=[[0, 1], [0, 1]], bins=20)

    # Define dataframe labelled with the upper edge of each bin as a percentage
    index_str = [f'{int(100*per)}%' for per in edge_exp_frac[1:]]
    col_str = [f'{int(100*per)}%' for per in edge_glob_exp_frac[1:]]
    hist_df = pd.DataFrame(hist_mat.astype(int), index=index_str, columns=col_str)

    # Plot params
    scale = 3
    fig_size = (50, 40)
    tit_size = 80
    lab_size = 40

    # Define colormap
    d_colors = ["white", "darkcyan"]
    cmap1 = LinearSegmentedColormap.from_list("mycmap", d_colors)

    fig = plt.figure(figsize=fig_size)
    try:
        # Plot the histogram dataframe. The log norm starts just below 1 so that bins holding a single
        # gene are still inside the color range.
        sns.set_theme(font_scale=scale)
        ax = sns.heatmap(hist_df, annot=True, linewidths=.5, fmt='g', cmap=cmap1, linecolor='k',
                         norm=colors.LogNorm(vmin=0.9, vmax=10000))

        # Define figure styling
        plt.suptitle(f'Expression Fraction {dataset}', fontsize=tit_size)
        plt.yticks(rotation=0)
        plt.xticks(rotation=90)
        ax.tick_params(labelsize=lab_size)
        plt.xlabel("Global Expression Fraction", fontsize=tit_size)
        plt.ylabel("Expression Fraction", fontsize=tit_size)

        # Define color bar configs
        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(labelsize=lab_size)
        cbar.ax.set_ylabel('Number of Genes', fontsize=tit_size)
        plt.tight_layout()

        # Save figure
        fig.savefig(path, dpi=300)
    finally:
        # Always release the figure and return the font scale and matplotlib settings to their defaults,
        # even if plotting or saving failed, so the other figures are not affected.
        plt.close(fig)
        sns.set_theme(font_scale=1.0)
        mpl.rcParams.update(mpl.rcParamsDefault)
def plot_histograms(processed_adata: ad.AnnData, raw_adata: ad.AnnData, path: str) -> None:
    """ Plot filtering histograms

    Draws a 2x5 grid of histograms to inspect the effect of filtering. The top row uses raw_adata and the
    bottom row uses processed_adata. The columns are:

        1. Total counts per cell (obs['total_counts'])
        2. Expression fraction per gene (var['exp_frac'])
        3. Cells with expression per gene (var['n_cells_by_counts'])
        4. Total counts per gene (var['total_counts'])
        5. Moran I statistic per gene (var['d_log1p_moran'], processed data only)

    Raw data is drawn in black; for processed data the cell histogram is dark red, the gene histograms are
    dark cyan and the Moran histogram is dark green. QC metrics are recomputed in place on both objects
    before plotting.

    Args:
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        raw_adata (ad.AnnData): Loaded data from .h5ad file that is not filtered but has patch information.
        path (str): Path to save histogram plot.
    """
    # Refresh the QC metrics so the total counts are up to date (processed data uses the 'counts' layer)
    sc.pp.calculate_qc_metrics(raw_adata, inplace=True, log1p=False, percent_top=None)
    sc.pp.calculate_qc_metrics(processed_adata, inplace=True, log1p=False, percent_top=None, layer='counts')

    # Add the expression fraction to the raw data
    raw_adata = gene_features.get_exp_frac(raw_adata)

    # Figure with one row per dataset and one column per statistic
    fig, axes = plt.subplots(nrows=2, ncols=5)
    fig.set_size_inches(18.75, 5)
    n_bins = 50

    # Column 0: total counts of each cell
    raw_adata.obs['total_counts'].hist(ax=axes[0, 0], bins=n_bins, grid=False, color='k')
    processed_adata.obs['total_counts'].hist(ax=axes[1, 0], bins=n_bins, grid=False, color='darkred')

    # Columns 1 to 3: per gene statistics in log scale, raw first and processed second
    for col, var_key in enumerate(('exp_frac', 'n_cells_by_counts', 'total_counts'), start=1):
        raw_adata.var[var_key].plot(kind='hist', ax=axes[0, col], bins=n_bins, grid=False, color='k', logy=True)
        processed_adata.var[var_key].plot(kind='hist', ax=axes[1, col], bins=n_bins, grid=False, color='darkcyan', logy=True)

    # Column 4: Moran I statistic, only drawn for the processed data
    processed_adata.var['d_log1p_moran'].plot(kind='hist', ax=axes[1, 4], bins=n_bins, grid=False, color='darkgreen', logy=True)

    # Text used to format the axes, listed in row-major order
    titles = ['Raw: Total counts', 'Raw: Expression fraction', 'Raw: Cells with expression', 'Raw: Total gene counts', 'Raw: MoranI statistic',
              'Processed: Total counts', 'Processed: Expression fraction', 'Processed: Cells with expression', 'Processed: Total gene counts', 'Processed: MoranI statistic']
    x_labels = 2 * ['Total counts', 'Expression fraction', 'Cells with expression', 'Total counts', 'MoranI statistic']
    y_labels = 2 * ['# of cells', '# of genes', '# of genes', '# of genes', '# of genes']

    for position, (axis, title, x_label, y_label) in enumerate(zip(axes.flatten(), titles, x_labels, y_labels)):
        if position == 4:
            # The raw Moran panel is never drawn, so hide its frame
            axis.axis('off')
        else:
            axis.set_title(title)
            axis.set_xlabel(x_label)
            axis.set_ylabel(y_label)
            axis.spines[['right', 'top']].set_visible(False)

    # Each processed panel shares the x axis of the raw panel above it
    for col in range(5):
        axes[1, col].sharex(axes[0, col])

    # Only the cell count panels share the y axis
    axes[1, 0].sharey(axes[0, 0])

    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close()
def plot_random_patches(dataset: str, processed_adata: ad.AnnData, path: str, patch_size: int = 224) -> None:
    """ Plot random set of patches

    This function draws 16 random flat patches from the single ``obsm`` entry whose key contains
    'patches_scale', reshapes them to (patch_size, patch_size, 3) images and shows them in a 4x4 grid.
    Patches are drawn without repetition whenever the collection holds at least 16 of them. The plot is
    saved to the specified path.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        path (str): Path to save the image.
        patch_size (int, optional): Side length in pixels of each square patch. Defaults to 224.

    Raises:
        AssertionError: If zero or more than one ``obsm`` key contains 'patches_scale'.
    """
    n_shown = 16

    # Verify that the patch scale exists and only exists once. The error is raised explicitly (same type
    # as the former assert) so the check is not stripped when Python runs with -O.
    patch_scale_keys = [key for key in processed_adata.obsm.keys() if "patches_scale" in key]
    if len(patch_scale_keys) != 1:
        raise AssertionError("patches_scale key either does not exist or exists more than once in keys_list.")

    # Get the flat patches using the key that was actually found
    flat_patches = processed_adata.obsm[patch_scale_keys[0]]

    # Reshape the patches for them to have image form
    patches = flat_patches.reshape((-1, patch_size, patch_size, 3))
    n_patches = patches.shape[0]

    # Choose the random patches. Repetition is only allowed when there are not enough patches to fill
    # the grid with different ones.
    chosen = np.random.choice(n_patches, size=n_shown, replace=n_patches < n_shown)
    plotting_patches = patches[chosen, :, :, :]

    # Declare image
    fig, axes = plt.subplots(nrows=4, ncols=4)

    # Draw each random patch in its own panel
    for axis, patch in zip(axes.reshape(-1), plotting_patches):
        axis.imshow(patch)
        axis.axis('off')

    # Set figure formatting
    fig.suptitle(f'Random patches from {dataset}')
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)
def visualize_moran_filtering(param_dict: dict, processed_adata: ad.AnnData, from_layer: str, path: str, split_names:dict, top: bool = True) -> None:
    """ Plot the most or least auto-correlated genes

    This function visualizes the spatial expression of the 4 most or least auto-correlated genes in
    processed_adata (ranked by ``var['d_log1p_moran']``). Each row is a gene, labelled with its Moran I
    value, and each column is a plotting slide. The plot is saved to the specified path.

    The slides come from the comma separated string in param_dict['plotting_slides'], after it is refined
    by refine_plotting_slides_str. The refined string is written back into param_dict.

    Args:
        param_dict (dict): Dictionary with dataset parameters. 'plotting_slides' is read and updated.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        from_layer (str): Layer of the adata object to use for plotting.
        path (str): Path to save the generated image.
        split_names (dict): Dictionary containing split names.
        top (bool, optional): If True, the top 4 most auto-correlated genes are visualized. If False, the top 4 least
                              auto-correlated genes are visualized. Defaults to True.
    """
    moran_key = 'd_log1p_moran'

    # Refine plotting slides string to assure they are in the dataset
    param_dict['plotting_slides'] = refine_plotting_slides_str(split_names, processed_adata, param_dict['plotting_slides'])
    slide_names = param_dict['plotting_slides'].split(',')

    # Get the slides to visualize in adata format
    s_adata_list = filtering.get_slides_adata(processed_adata, param_dict['plotting_slides'])
    n_slides = len(s_adata_list)

    # Get the 4 most or least auto-correlated genes depending on the value of top
    # NOTE: The selection of genes is done in the complete collection of slides, not in the specified slides
    if top:
        selected_table = processed_adata.var.nlargest(4, columns=moran_key)
    else:
        selected_table = processed_adata.var.nsmallest(4, columns=moran_key)

    # Declare figure. squeeze=False keeps ax two dimensional so ax[i, j] also works with a single slide.
    fig, ax = plt.subplots(nrows=4, ncols=n_slides, squeeze=False)
    fig.set_size_inches(4 * n_slides, 13)

    # Cycle over genes (rows)
    for i, gene in enumerate(selected_table.index):
        # Get min and max of the selected gene across the slides so every column shares one color scale
        gene_min = min(dat[:, gene].layers[from_layer].min() for dat in s_adata_list)
        gene_max = max(dat[:, gene].layers[from_layer].max() for dat in s_adata_list)
        norm = matplotlib.colors.Normalize(vmin=gene_min, vmax=gene_max)

        # Cycle over slides (columns)
        for j, slide_adata in enumerate(s_adata_list):
            # Only plot the colorbar in the last column
            cbar = j == (n_slides - 1)

            # Plot the selected gene in the current slide
            sq.pl.spatial_scatter(slide_adata, color=[gene], layer=from_layer, ax=ax[i, j], cmap='jet', norm=norm, colorbar=cbar)

            # Set slide name only on the first row
            if i == 0:
                ax[i, j].set_title(f'{slide_names[j]}', fontsize=15)
            else:
                ax[i, j].set_title('')

            # Set gene name and moran I value only on the first column
            if j == 0:
                ax[i, j].set_ylabel(f'{gene}: $I = {selected_table[moran_key].iloc[i].round(3)}$', fontsize=13)
            else:
                ax[i, j].set_ylabel('')

    # Format figure
    for axis in ax.flatten():
        axis.set_xlabel('')
        # Turn off all spines
        for side in ('top', 'right', 'bottom', 'left'):
            axis.spines[side].set_visible(False)

    # Define title
    tit_str = 'most (top)' if top else 'least (bottom)'
    fig.suptitle(f'Top 4 {tit_str} auto-correlated genes in processed data', fontsize=20)
    fig.tight_layout()

    # Save plot
    fig.savefig(path, dpi=300)
    plt.close(fig)
def visualize_gene_expression(param_dict: dict, processed_adata: ad.AnnData, from_layer: str, path: str, split_names:dict) -> None:
    """ Plot specific gene expression

    This function plots the spatial expression of the genes in param_dict['plotting_genes'] (comma
    separated string) in the slides of param_dict['plotting_slides']. Each row is a gene, labelled with its
    ``var['d_log1p_moran']`` value, and each column is a slide. If the requested genes cannot be selected,
    up to 4 genes are chosen randomly without repetition. The slide string is refined by
    refine_plotting_slides_str and written back into param_dict. The data is plotted from the
    .layers[from_layer] expression matrix.

    Args:
        param_dict (dict): Dictionary of dataset parameters. 'plotting_genes' is read; 'plotting_slides' is
                           read and updated.
        processed_adata (ad.AnnData): The processed adata with the filtered patient collection.
        from_layer (str): The key to the layer of the data to plot.
        path (str): Path to save the image.
        split_names (dict): Dictionary containing split names.
    """
    moran_key = 'd_log1p_moran'

    # Refine plotting slides string to assure they are in the dataset
    param_dict['plotting_slides'] = refine_plotting_slides_str(split_names, processed_adata, param_dict['plotting_slides'])
    slide_names = param_dict['plotting_slides'].split(',')

    # Get the slides to visualize in adata format
    s_adata_list = filtering.get_slides_adata(processed_adata, param_dict['plotting_slides'])
    n_slides = len(s_adata_list)

    # Define gene list
    gene_list = param_dict['plotting_genes'].split(',')

    # Try to get the specified genes, otherwise choose randomly. Exception (rather than a bare except)
    # keeps the best-effort fallback without swallowing KeyboardInterrupt or SystemExit.
    try:
        gene_table = processed_adata[:, gene_list].var
    except Exception:
        print('Could not find all the specified plotting genes, choosing randomly')
        n_random = min(4, len(processed_adata.var_names))
        gene_list = np.random.choice(processed_adata.var_names, size=n_random, replace=False)
        gene_table = processed_adata[:, gene_list].var

    # Declare figure with one row per gene instead of a fixed 4 rows, so longer gene lists do not index
    # out of the grid. The height keeps the original 13 inches for 4 genes. squeeze=False keeps ax two
    # dimensional so ax[i, j] also works with a single slide or a single gene.
    n_genes = len(gene_table)
    fig, ax = plt.subplots(nrows=n_genes, ncols=n_slides, squeeze=False)
    fig.set_size_inches(4 * n_slides, 13 * n_genes / 4)

    # Cycle over genes (rows)
    for i, gene in enumerate(gene_table.index):
        # Get min and max of the selected gene across the slides so every column shares one color scale
        gene_min = min(dat[:, gene].layers[from_layer].min() for dat in s_adata_list)
        gene_max = max(dat[:, gene].layers[from_layer].max() for dat in s_adata_list)
        norm = matplotlib.colors.Normalize(vmin=gene_min, vmax=gene_max)

        # Cycle over slides (columns)
        for j, slide_adata in enumerate(s_adata_list):
            # Only plot the colorbar in the last column
            cbar = j == (n_slides - 1)

            # Plot the selected gene in the current slide
            sq.pl.spatial_scatter(slide_adata, layer=from_layer, color=[gene], ax=ax[i, j], cmap='jet', norm=norm, colorbar=cbar)

            # Set slide name only on the first row
            if i == 0:
                ax[i, j].set_title(f'{slide_names[j]}', fontsize=15)
            else:
                ax[i, j].set_title('')

            # Set gene name with moran I value only on the first column
            if j == 0:
                ax[i, j].set_ylabel(f'{gene}: $I = {gene_table[moran_key].iloc[i].round(3)}$', fontsize=13)
            else:
                ax[i, j].set_ylabel('')

    # Format figure
    for axis in ax.flatten():
        axis.set_xlabel('')
        # Turn off all spines
        for side in ('top', 'right', 'bottom', 'left'):
            axis.spines[side].set_visible(False)

    fig.suptitle('Gene expression in processed data', fontsize=20)
    fig.tight_layout()

    # Save plot
    fig.savefig(path, dpi=300)
    plt.close(fig)
def plot_clusters(dataset: str, param_dict: dict, processed_adata: ad.AnnData, from_layer: str, path: str, split_names:dict) -> None:
    """ Plot clusters spatially

    This function generates a plot that visualizes the clusters stored in ``obs['cluster']`` after calling
    spot_features.compute_dim_red on the data. The slides come from param_dict['plotting_slides'] after it
    is refined by refine_plotting_slides_str; the refined string is written back into param_dict.

    It plots:

        1. The spatial distribution of the clusters in each plotting slide.
        2. The UMAP embedding of each plotting slide colored by cluster, over the complete dataset in gray.
        3. The UMAP embedding of the complete dataset colored by cluster and by the batch correction key.
        4. The PCA embedding of the complete dataset colored by the batch correction key.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        param_dict (dict): Dictionary of dataset parameters. 'plotting_slides' is read and updated and
                           'combat_key' is read.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        from_layer (str): The key in adata.layers where the expression matrix is stored.
        path (str): Path to save the image.
        split_names (dict): Dictionary containing split names.
    """
    # Update the adata object with the embeddings and clusters
    updated_adata = spot_features.compute_dim_red(processed_adata, from_layer)

    # Refine plotting slides string to assure they are in the dataset
    param_dict['plotting_slides'] = refine_plotting_slides_str(split_names, processed_adata, param_dict['plotting_slides'])
    slide_names = param_dict['plotting_slides'].split(',')
    batch_key = param_dict['combat_key']

    # Get the slides to visualize in adata format
    s_adata_list = filtering.get_slides_adata(updated_adata, param_dict['plotting_slides'])

    # Cluster labels sorted numerically and converted back to strings
    clusters = [str(cl) for cl in np.sort([int(cl) for cl in updated_adata.obs['cluster'].unique()])]

    # Define the mapping from cluster to color. The local name avoids shadowing the module level
    # `colors` alias of matplotlib.colors.
    cluster_colors = sns.color_palette('hls', len(clusters))
    palette = dict(zip(clusters, cluster_colors))
    gray_palette = dict(zip(clusters, ['gray'] * len(clusters)))

    # Declare figure: left half holds one (spatial, UMAP) row per slide, right half holds 3 global panels.
    # At least 4 rows are kept (the original layout) and more are added when more slides are requested,
    # which previously indexed out of the fixed 4 row grid.
    fig = plt.figure(layout="constrained")
    gs0 = fig.add_gridspec(1, 2)
    gs00 = gs0[0].subgridspec(max(4, len(s_adata_list)), 2)
    gs01 = gs0[1].subgridspec(3, 1)
    fig.set_size_inches(15, 14)

    # Cycle over slides
    for i, slide_adata in enumerate(s_adata_list):
        # Clusters present in the current slide, sorted numerically
        curr_clusters = [str(cl) for cl in np.sort([int(cl) for cl in slide_adata.obs['cluster'].unique()])]

        # Colormap restricted to the clusters of this slide, using the colors of the global palette
        spatial_colors = matplotlib.colors.ListedColormap([palette[x] for x in curr_clusters])

        # Get ax for spatial plot and UMAP plot
        spatial_ax = fig.add_subplot(gs00[i, 0])
        umap_ax = fig.add_subplot(gs00[i, 1])

        # Plot cluster colors in spatial space
        sq.pl.spatial_scatter(slide_adata, color=['cluster'], ax=spatial_ax, palette=spatial_colors)
        spatial_ax.get_legend().remove()
        spatial_ax.set_title('Spatial', fontsize=18)
        spatial_ax.set_ylabel(f'{slide_names[i]}', fontsize=12)
        spatial_ax.set_xlabel('')
        # Turn off all spines
        for side in ('top', 'right', 'bottom', 'left'):
            spatial_ax.spines[side].set_visible(False)

        # Plot the complete collection in gray as background and the current slide in color on top
        sc.pl.umap(updated_adata, layer=from_layer, color=['cluster'], ax=umap_ax, frameon=False, palette=gray_palette, s=10, cmap=None, alpha=0.2)
        umap_ax.get_legend().remove()
        sc.pl.umap(slide_adata, layer=from_layer, color=['cluster'], ax=umap_ax, frameon=False, palette=palette, s=10, cmap=None)
        umap_ax.get_legend().remove()
        umap_ax.set_title('UMAP', fontsize=18)

    # Get axes for the three global panels
    leiden_ax = fig.add_subplot(gs01[0])
    patient_ax = fig.add_subplot(gs01[1])
    pca_ax = fig.add_subplot(gs01[2])

    # Plot clusters in UMAP space
    sc.pl.umap(updated_adata, color=['cluster'], ax=leiden_ax, frameon=False, palette=palette, s=10, cmap=None)
    leiden_ax.get_legend().set_title('Leiden Clusters')
    leiden_ax.get_legend().get_title().set_fontsize(15)
    leiden_ax.set_title('UMAP & Leiden Clusters', fontsize=18)

    # Plot batch key in UMAP space
    sc.pl.umap(updated_adata, color=[batch_key], ax=patient_ax, frameon=False, palette='tab20', s=10, cmap=None)
    patient_ax.get_legend().set_title(batch_key)
    patient_ax.get_legend().get_title().set_fontsize(15)
    patient_ax.set_title(f"UMAP & {batch_key}", fontsize=18)

    # Plot batch key in PCA space
    sc.pl.pca(updated_adata, color=[batch_key], ax=pca_ax, frameon=False, palette='tab20', s=10, cmap=None)
    pca_ax.get_legend().set_title(batch_key)
    pca_ax.get_legend().get_title().set_fontsize(15)
    pca_ax.set_title(f'PCA & {batch_key}', fontsize=18)

    # Format figure and save. tight_layout is not called because the figure uses the constrained layout.
    fig.suptitle(f'Cluster visualization for {dataset} in layer {from_layer}', fontsize=25)
    fig.savefig(path, dpi=300)
    plt.close(fig)
def plot_mean_std(dataset: str, processed_adata: ad.AnnData, raw_adata: ad.AnnData, path: str) -> None:
    """ Plot mean and std of all genes

    Scatter plot of the per gene mean against the per gene standard deviation. The raw data is drawn in
    black after being normalized and log1p transformed on a copy (raw_adata itself is left untouched).
    The 'log1p', 'd_log1p', 'c_log1p' and 'c_d_log1p' layers of processed_adata are drawn on top, one
    color per layer. It is used to see the effect of filtering and processing in the genes. The plot is
    saved to the specified path.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        raw_adata (ad.AnnData): Data loaded from .h5ad file that is not filtered but has patch information.
        path (str): Path to save the image.
    """
    # Work on a normalized and log transformed copy of the raw data
    normalized_raw = raw_adata.copy()
    sc.pp.normalize_total(normalized_raw, inplace=True)
    sc.pp.log1p(normalized_raw)

    # Per gene statistics of the raw data
    raw_df = normalized_raw.to_df()
    raw_mean = raw_df.mean(axis=0)
    raw_std = raw_df.std(axis=0)

    plt.figure()
    plt.scatter(raw_mean, raw_std, s=1, c='k', label='Raw data')

    # Per gene statistics of every processed layer
    for layer_name in ('log1p', 'd_log1p', 'c_log1p', 'c_d_log1p'):
        layer_df = processed_adata.to_df(layer=layer_name)
        plt.scatter(layer_df.mean(axis=0), layer_df.std(axis=0), s=1, label=f'{layer_name} data')

    # Figure formatting
    plt.xlabel('Mean $Log(x+1)$')
    plt.ylabel('Std $Log(x+1)$')
    plt.legend(loc='best')
    plt.title(f'Mean Std plot {dataset}')
    plt.gca().spines[['right', 'top']].set_visible(False)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()
def plot_data_distribution_stats(dataset: str, processed_adata: ad.AnnData, path:str) -> None:
    """ Plot dataset's general stats

    Draws three panels describing how the data is divided into the train/val/test splits: a pie chart of
    spots per split, a stacked bar plot of spots per patient and a stacked bar plot of slides per patient.
    A text box with the number of patients, slides and spots is added to the figure. The test wedge is
    left out of the pie chart when the test split has no spots.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model. The 'patient',
                                      'slide_id', 'split' and 'unique_id' columns of obs are read.
        path (str): Path to save the image.
    """
    split_order = ['train', 'val', 'test']
    obs = processed_adata.obs

    # General numbers shown in the text box
    n_patients = len(obs['patient'].unique())
    n_slides = len(obs['slide_id'].unique())
    spots_per_split = [processed_adata[obs['split'] == split].shape[0] for split in split_order]
    metadata = f"Patients: {n_patients}\nSlides: {n_slides}\nSpots: {processed_adata.shape[0]}"

    # Pie chart labels; the test wedge is dropped when the split is empty
    pie_labels = ['Train', 'Valid', 'Test']
    if spots_per_split[-1] == 0:
        spots_per_split = spots_per_split[:-1]
        pie_labels = pie_labels[:-1]

    def wedge_text(pct):
        # Percentage on the first line and the corresponding number of spots on the second
        return '{:.1f}%\n{:.0f}'.format(pct, pct * processed_adata.shape[0] / 100)

    # Create figure
    fig, ax = plt.subplots(nrows=1, ncols=3)
    fig.set_size_inches(18.75, 6.5)

    # Panel 0: spots per split
    ax[0].pie(spots_per_split,
              labels=pie_labels,
              wedgeprops={"linewidth": 1, "edgecolor": "white"},
              autopct=wedge_text,
              colors=sns.color_palette('Set2'),
              textprops={"fontsize": 15})

    # Per patient and split counts of unique spots and unique slides, with columns in split order
    by_patient_split = obs.groupby(['patient', 'split'], observed=False)
    spot_counts = by_patient_split['unique_id'].nunique().unstack(fill_value=0).reindex(columns=split_order)
    slide_counts = by_patient_split['slide_id'].nunique().unstack(fill_value=0).reindex(columns=split_order)

    # Panels 1 and 2: stacked bars per patient
    spot_counts.plot.bar(stacked=True, rot=0, color=sns.color_palette('Set2'), ax=ax[1], legend=False, fontsize=13)
    slide_counts.plot.bar(stacked=True, rot=0, color=sns.color_palette('Set2'), ax=ax[2], legend=False, fontsize=13)

    ax[0].set_title('Spot distribution per split', fontsize=17)
    for bar_ax, title, y_label in ((ax[1], 'Spots per patient', '# spots'),
                                   (ax[2], 'Slides per patient', '# slides')):
        bar_ax.spines[['right', 'top']].set_visible(False)
        bar_ax.set_title(title, fontsize=17)
        bar_ax.set_ylabel(y_label, fontsize=15)
        bar_ax.set_xlabel('Patient', fontsize=15)

    fig.suptitle(f"Data distribution for {dataset}", fontsize=20)
    # Text box with the general numbers in the lower left corner of the figure
    fig.text(0.02, 0.08, metadata, fontsize=15, wrap=True, bbox={'facecolor': 'whitesmoke', 'alpha': 0.3, 'pad': 5})
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close()
def plot_mean_std_partitions(dataset: str, processed_adata: ad.AnnData, from_layer: str, path: str) -> None:
    """ Plot mean and std of genes by data split

    Scatter plot of the per gene mean against the per gene standard deviation in the layer given by
    from_layer, with one color for each value found in ``obs['split']``. This function is used to see how
    tractable is the task. The plot is saved to the specified path.

    Args:
        dataset (str): Name of the dataset. Only used in the figure title.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        from_layer (str): The key in adata.layers where the expression matrix is stored.
        path (str): Path to save the image.
    """
    # Work on a copy so the original object is never touched
    adata_copy = processed_adata.copy()

    plt.figure()

    for split_name in adata_copy.obs['split'].unique():
        # Expression table restricted to the spots of the current split
        in_split = adata_copy.obs['split'] == split_name
        split_df = adata_copy[in_split, :].to_df(layer=from_layer)
        # One point per gene
        plt.scatter(split_df.mean(axis=0), split_df.std(axis=0), s=1, label=f'{split_name} data')

    # Figure formatting
    plt.xlabel('Mean $Log(x+1)$')
    plt.ylabel('Std $Log(x+1)$')
    plt.legend(loc='best')
    plt.title(f'Mean Std plot {dataset}')
    plt.gca().spines[['right', 'top']].set_visible(False)
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()
# TODO: review for eventual removal (it is used in plot_tests)
def refine_plotting_slides_str(split_names: dict, collection: ad.AnnData, slide_list: str) -> str:
    """ Assure plotting slides are on the dataset.

    This function refines the plotting slides string to assure all slides are on the dataset. It works in the following way:

        1. If all slides are in the dataset it does nothing and returns the same slide_list parameter.
        2. If any slide is missing in the dataset or slide_list=='None' then it does one of 2 things:

            a. If the dataset has 4 or less slides all slides are set as plotting slides.
            b. If the dataset has more than 4 slides it iterates over the splits found in collection.obs['split']
               in round-robin order and chooses a single random slide of that split at a time, without
               replacement, until 4 slides are chosen or no slides are left.

    The lists inside split_names are never modified.

    Args:
        split_names (dict): Dictionary that maps each split name to the list of its slide ids.
        collection (ad.AnnData): Processed and filtered data ready to use by the model. Only obs['slide_id']
                                 and obs['split'] are read.
        slide_list (str): String with a list of slides separated by commas, or 'None'.

    Returns:
        str: Refined version of the slide_list string.
    """
    # Slides present in the dataset, computed once (list keeps the order, set gives fast membership)
    dataset_slides = list(collection.obs['slide_id'].unique())
    dataset_slide_set = set(dataset_slides)

    # Nothing to refine when every requested slide is in the dataset
    all_requested_present = all(sl in dataset_slide_set for sl in slide_list.split(','))
    if all_requested_present and slide_list != 'None':
        return slide_list

    # Check if there are 4 or less slides. If so, all are plotting slides
    if len(dataset_slides) <= 4:
        slide_list = ','.join(dataset_slides)
        print(f'Plotting slides were None or missing in the dataset. And there are 4 or less slides. Setting all slides as plotting slides: {slide_list}')
        return slide_list

    # More than 4 slides: iterate over splits and choose randomly without replacement
    split_list = list(collection.obs['split'].unique())

    # Private pool of candidate slides per split. Each list is copied because dict.copy() is shallow and
    # removing the chosen slides would otherwise alter the caller's split_names. A split without an
    # entry in split_names simply contributes no slides.
    candidate_pools = {split: list(split_names.get(split, [])) for split in split_list}

    plotting_slide_list = []
    count = 0
    # Stop at 4 slides, or earlier when every pool is exhausted (this used to loop forever)
    while len(plotting_slide_list) < 4 and any(candidate_pools.values()):
        curr_pool = candidate_pools[split_list[count % len(split_list)]]
        # If the current pool still has slides, draw one at random and take it out of the pool
        if curr_pool:
            plotting_slide_list.append(curr_pool.pop(np.random.randint(len(curr_pool))))
        count += 1

    slide_list = ','.join(plotting_slide_list)
    print(f'Plotting slides were None or missing in the dataset. And there are more than 4 slides. Setting slides internally from all splits: {slide_list}')
    return slide_list
def plot_tests(patch_size: int, dataset: str, split_names: dict, param_dict: dict, folder_path: str, processed_adata: ad.AnnData, raw_adata: ad.AnnData)->None:
    """ Plot all quality control plots

    This function calls all the plotting functions of this module to create the quality control plots used
    to check if the processing step of the dataset is performed correctly. The images are saved inside the
    'qc_plots' folder of folder_path, which is created if it does not exist.

    Before plotting, param_dict['plotting_genes'] is replaced by random genes (without repetition) when any
    of the requested genes is missing in the data, and param_dict['plotting_slides'] is refined by
    refine_plotting_slides_str. Both changes are written back into param_dict.

    Args:
        patch_size (int): Side length in pixels of the patches, passed to plot_random_patches.
        dataset (str): Name of the dataset, used in the figure titles.
        split_names (dict): Dictionary containing split names.
        param_dict (dict): Dictionary of dataset parameters. 'plotting_genes' and 'plotting_slides' are read and
                           may be updated.
        folder_path (str): Folder where the 'qc_plots' folder is created.
        processed_adata (ad.AnnData): Processed and filtered data ready to use by the model.
        raw_adata (ad.AnnData): Data that is not filtered nor processed.
    """
    print('Started quality control plotting')
    start = time()

    # Define directory path to save data
    save_path = os.path.join(folder_path, 'qc_plots')
    os.makedirs(save_path, exist_ok=True)

    # Define interest layers
    relevant_layers = ['log1p', 'd_log1p', 'c_d_log1p']
    complete_layers = ['counts', 'tpm', 'log1p', 'd_log1p', 'c_d_log1p']

    # Assure that the plotting genes are in the data and if not, set random plotting genes.
    # The genes are drawn without replacement so the same gene cannot be plotted twice.
    if not all(gene in processed_adata.var_names for gene in param_dict['plotting_genes'].split(',')):
        n_random_genes = min(4, len(processed_adata.var_names))
        param_dict['plotting_genes'] = ','.join(np.random.choice(processed_adata.var_names, size=n_random_genes, replace=False))
        print(f'Plotting genes not in data. Setting random plotting genes: {param_dict["plotting_genes"]}')

    # Refine plotting slides string to assure they are in the dataset
    param_dict['plotting_slides'] = refine_plotting_slides_str(split_names, processed_adata, param_dict['plotting_slides'])

    # Plot partitions mean vs std scatter
    print('Started partitions mean vs std scatter plotting')
    os.makedirs(os.path.join(save_path, 'mean_vs_std_partitions'), exist_ok=True)
    for lay in tqdm(relevant_layers):
        plot_mean_std_partitions(dataset, processed_adata, from_layer=lay, path=os.path.join(save_path, 'mean_vs_std_partitions', f'{lay}.png'))

    # Plot all slides in collection
    print('Started all slides plotting')
    plot_all_slides(dataset, processed_adata, os.path.join(save_path, 'all_slides.png'))

    # Make plot of filtering histograms
    print('Started filtering histograms plotting')
    plot_histograms(processed_adata, raw_adata, os.path.join(save_path, 'filtering_histograms.png'))

    # Make plot of random patches
    print('Started random patches plotting')
    plot_random_patches(dataset, processed_adata, os.path.join(save_path, 'random_patches.png'), patch_size)

    # Create save paths for top and bottom moran genes
    os.makedirs(os.path.join(save_path, 'top_moran_genes'), exist_ok=True)
    os.makedirs(os.path.join(save_path, 'bottom_moran_genes'), exist_ok=True)
    print('Started moran filtering plotting')
    # Plot moran filtering
    for lay in tqdm(relevant_layers):
        # Make plot of 4 most moran genes and 4 less moran genes (in the chosen slides)
        visualize_moran_filtering(param_dict, processed_adata, from_layer=lay, path=os.path.join(save_path, 'top_moran_genes', f'{lay}.png'), split_names=split_names, top=True)
        visualize_moran_filtering(param_dict, processed_adata, from_layer=lay, path=os.path.join(save_path, 'bottom_moran_genes', f'{lay}.png'), split_names=split_names, top=False)

    # Create save paths for cluster plots
    os.makedirs(os.path.join(save_path, 'cluster_plots'), exist_ok=True)
    print('Started cluster plotting')
    # Plot cluster graphs
    for lay in tqdm(relevant_layers):
        plot_clusters(dataset, param_dict, processed_adata, from_layer=lay, path=os.path.join(save_path, 'cluster_plots', f'{lay}.png'), split_names=split_names)

    # Create save path for expression plots
    os.makedirs(os.path.join(save_path, 'expression_plots'), exist_ok=True)
    print('Started gene expression plotting')
    # Plot of gene expression in the chosen slides for the chosen genes
    for lay in tqdm(complete_layers):
        visualize_gene_expression(param_dict, processed_adata, from_layer=lay, path=os.path.join(save_path, 'expression_plots', f'{lay}.png'), split_names=split_names)

    # Make plot of mean vs std per gene
    print('Started mean vs std plotting')
    plot_mean_std(dataset, processed_adata, raw_adata, os.path.join(save_path, 'mean_std_scatter.png'))

    # Make plot of data distribution statistics.
    print('Started data distribution statistics plotting')
    plot_data_distribution_stats(dataset, processed_adata, os.path.join(save_path, 'splits_stats.png'))

    # Plot expression fraction 2D histogram
    print('Started expression fraction plotting')
    plot_exp_frac(param_dict, dataset, raw_adata, os.path.join(save_path, 'exp_frac.png'))

    # Print the time that took to plot quality control
    end = time()
    print(f'Quality control plotting took {round(end-start, 2)}s')
    print(f'Images saved in {save_path}')