# Source code for spared.denoising.denoising
import anndata as ad
from tqdm import tqdm
import numpy as np
import sys
import pathlib
from datetime import datetime
import torch
import os
os.environ['USE_PYGEOS'] = '0' # To supress a warning from geopandas
import json
from lightning.pytorch import seed_everything
from torch.utils.data import DataLoader
import warnings
# Get the path of the spared database
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent
# Add the parent directory to sys.path so the package-local imports below resolve
sys.path.append(str(SPARED_PATH))
# Import the SpaRED submodules and the SpaCKLE components used in this file
from spot_features import spot_features
from layer_operations import layer_operations
from datasets import datasets
from spackle.utils import *
from spackle.model import GeneImputationModel
from spackle.dataset import ImputationDataset
from spackle.main import train_spackle
# Remove the path from sys.path
sys.path.remove(str(SPARED_PATH))
#clean noise with medians
# TODO: Think in making this function also add the binary mask layer
def median_cleaner(collection: ad.AnnData, from_layer: str, to_layer: str, n_hops: int, hex_geometry: bool) -> ad.AnnData:
    """Remove noise with adaptive median filter.

    Cleans noise (missing data) in every slide of an AnnData collection with the modified adaptive median
    filter initially proposed by `SEPAL <https://doi.org/10.48550/arXiv.2309.01036>`_. Windows to compute the
    medians are defined by topological distances (hops) in the neighbors graph defined by the ``hex_geometry``
    parameter, with a maximum window size of ``n_hops``. The adaptive median filter denoises each gene
    independently. In other words, gene A has no influence on the denoising of gene B. The data is taken from
    ``adata.layers[from_layer]`` and the results are stored in ``adata.layers[to_layer]``.

    Slides are identified by ``collection.obs['slide_id']`` and are processed one at a time on a copy, so the
    input collection's layers are not modified. The returned collection is rebuilt with ``ad.concat`` in the
    order given by ``np.unique`` of the slide ids.

    Args:
        collection (ad.AnnData): The AnnData collection to process.
        from_layer (str): The layer to compute the adaptive median filter from. Where to clean the noise from.
        to_layer (str): The layer to store the results of the adaptive median filter. Where to store the cleaned data.
        n_hops (int): The maximum number of concentric rings in the neighbors graph to take into account to compute the median. Analogous to the maximum window size.
        hex_geometry (bool): ``True`` if the graph has hexagonal spatial geometry (Visium technology). If ``False``, then the graph is a grid.

    Returns:
        ad.AnnData: New AnnData collection with the results of the adaptive median filter stored in ``adata.layers[to_layer]``.
    """

    ### Define cleaning function for single slide:
    def adaptive_median_filter_pepper(adata: ad.AnnData, from_layer: str, to_layer: str, n_hops: int, hex_geometry: bool) -> ad.AnnData:
        """Apply the modified adaptive median filter to a single slide.

        Every pair (obs, gene) with a zero value (pepper noise) in ``adata.layers[from_layer]`` is replaced by
        the median of that gene over the smallest neighborhood (1 hop, then 2 hops, ... up to ``n_hops``) whose
        median is not zero. If no neighborhood yields a non-zero median, the median of the non-zero values of
        that gene over the whole slide is used instead (0 if the gene is zero everywhere). Non-zero values are
        copied unchanged. Each gene is denoised independently of the others.

        Args:
            adata (ad.AnnData): The AnnData object to process. Importantly it is only from a single slide. Can not be a collection of slides.
            from_layer (str): The layer to compute the adaptive median filter from.
            to_layer (str): The layer to store the results of the adaptive median filter.
            n_hops (int): The maximum number of concentric rings in the graph to take into account to compute the median. Analogous to the max window size.
            hex_geometry (bool): Whether the graph is hexagonal or not. If True, then the graph is hexagonal. If False, then the graph is a grid. Only
                                true for visium datasets.

        Returns:
            ad.AnnData: The same AnnData object with the results of the adaptive median filter stored in the layer 'to_layer'.
        """
        # Define original expression matrix (assumed to be a dense 2D array of shape (n_obs, n_vars))
        original_exp = adata.layers[from_layer]

        # Start from a copy of the data: non-zero entries are never modified
        corrected_exp = original_exp.copy()
        # Entries that are still waiting for a replacement value (the pepper noise)
        pending = original_exp == 0

        # Buffer reused at every hop to hold the neighborhood medians of each (obs, gene) pair
        hop_medians = np.empty((adata.n_obs, adata.n_vars))

        # Grow the window one hop at a time. An entry is filled by the first (smallest) window whose
        # median is not zero and is never touched again by a bigger window.
        for hop in range(1, n_hops + 1):
            # Get dictionary of neighbors for a given number of hops
            curr_neighbors_dict = spot_features.get_spatial_neighbors(adata, hop, hex_geometry)
            for obs_idx in range(adata.n_obs):
                # Median expression of every gene over the neighbors of the current observation
                neighbors_idx = curr_neighbors_dict[obs_idx]
                hop_medians[obs_idx, :] = np.median(original_exp[neighbors_idx, :], axis=0)

            fillable = pending & (hop_medians != 0)
            corrected_exp[fillable] = hop_medians[fillable]
            pending &= ~fillable

        # Robustly compute the median of the non-zero values for each gene over the whole slide
        general_medians = np.apply_along_axis(lambda v: np.median(v[np.nonzero(v)]), 0, original_exp)
        general_medians[np.isnan(general_medians)] = 0.0  # Genes without non-zero values give nan

        # NOTE: Big modification to the median filter here. Be careful.
        # Whatever is still pending had a zero median in every window, so fall back to the gene-wide median.
        corrected_exp[pending] = np.broadcast_to(general_medians, corrected_exp.shape)[pending]

        # Add corrected expression to adata
        adata.layers[to_layer] = corrected_exp
        return adata

    # Print message
    print('Applying adaptive median filter to collection...')

    # Get the unique slides
    slides = np.unique(collection.obs['slide_id'])

    # Define the corrected adata list
    corrected_adata_list = []

    # Iterate over the slides
    for slide in tqdm(slides):
        # Get the adata of the slide (copied so the input collection is left untouched)
        adata = collection[collection.obs['slide_id'] == slide].copy()
        # Apply adaptive median filter
        adata = adaptive_median_filter_pepper(adata, from_layer, to_layer, n_hops, hex_geometry)
        # Append to the corrected adata list
        corrected_adata_list.append(adata)

    # Concatenate the corrected adata list
    corrected_collection = ad.concat(corrected_adata_list, join='inner', merge='same')
    # Restore the uns attribute
    corrected_collection.uns = collection.uns

    return corrected_collection
#Replicate SpaCKLE's results
def spackle_cleaner_experiment(adata: ad.AnnData, dataset: str, from_layer: str, device, lr = 1e-3, train = True, load_ckpt_path = "", optimizer = "Adam", max_steps = 1000) -> None:
    """Train or evaluate a SpaCKLE model to reproduce the results presented in SpaCKLE's paper.

    This function's purpose is solely to reproduce the results presented in SpaCKLE's paper. It delegates
    all the work to ``train_spackle`` (always called with ``get_performance=True``) and does not return
    anything.

    The ``adata`` must have data splits in ``adata.obs['split']`` and the values should be 'train', 'val',
    and (optional) 'test'.

    NOTE(review): the remaining SpaCKLE hyperparameters come from ``get_main_parser().parse_args()``, which
    presumably reads the command line of the running process - confirm before calling this from a script
    that defines its own command line arguments.

    Args:
        adata (ad.AnnData): The AnnData object handed to ``train_spackle``.
        dataset (str): Name of the dataset. Only used to build the output directory ``imput_results/<dataset>/<run date>`` when ``train`` is True.
        from_layer (str): Layer passed to ``train_spackle`` as ``prediction_layer``.
        device: Device passed to ``train_spackle``.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        train (bool, optional): If True a new model is trained and saved in a new directory. If False the checkpoint in ``load_ckpt_path`` is used instead. Defaults to True.
        load_ckpt_path (str, optional): Path to the checkpoint file of a trained model. Only required when ``train`` is False. It must end with the ckpt
                                        file, and that file must be inside a directory that also contains ``script_params.json``
                                        (e.g. ``imput_results/vicari_mouse_brain/2024-02-28-07-02-31/epoch=101-step=9370.ckpt``). Defaults to "".
        optimizer (str, optional): Name of the optimizer passed to ``train_spackle``. Defaults to "Adam".
        max_steps (int, optional): Maximum number of training steps passed to ``train_spackle``. Defaults to 1000.

    Raises:
        FileNotFoundError: If ``train`` is False and ``load_ckpt_path`` (or its ``script_params.json``) does not exist.
    """
    # Get parser and parse arguments
    parser = get_main_parser()
    args = parser.parse_args()
    args_dict = vars(args)

    # Get datetime
    run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # Set manual seeds and get cuda
    seed_everything(42)

    if train:
        # Create directory where the newly trained model will be saved
        # TODO: [PC] group opinion: should we set data naming with date like in our works?
        save_path = os.path.join('imput_results', dataset, run_date)
        os.makedirs(save_path, exist_ok=True)

        # Save script arguments in json file
        with open(os.path.join(save_path, 'script_params.json'), 'w') as f:
            json.dump(args_dict, f, indent=4)

        print(f"Training a new SpaCKLE model. The script arguments and best checkpoints will be saved in {save_path}")

    else:
        # Explicit check instead of an assert so the validation survives `python -O`
        if not os.path.exists(load_ckpt_path):
            raise FileNotFoundError("load_ckpts_path not found. Please use train = True if you do not have the checkpoints of a trained SpaCKLE model and its corresponding script_params.json file.")

        save_path = os.path.dirname(load_ckpt_path)

        # FIXME: [PC] decide which elements to compare, and remember to upload the weights (i.e. Drive) together with their corresponding params json
        # The parameters are loaded but not compared yet; reading them still ensures the json file exists and is valid.
        with open(os.path.join(save_path, 'script_params.json'), 'r') as f:
            saved_script_params = json.load(f)

        # Check that the parameters of the loaded model agree with the current inference process
        #if (saved_script_params['prediction_layer'] != args_dict['prediction_layer']) or (saved_script_params['prediction_layer'] != args_dict['prediction_layer']):
        #    warnings.warn("Saved model's parameters differ from those of the current argparse.")

        print(f"Model from {load_ckpt_path} will be loaded and tested. No new training will be undergone.")

    # Train a new SpaCKLE model or evaluate the loaded one, depending on `train`
    train_spackle(
        adata=adata,
        device=device,
        save_path=save_path,
        prediction_layer=from_layer,
        lr=lr,
        train=train,
        get_performance=True,
        load_ckpt_path=load_ckpt_path,
        optimizer=optimizer,
        max_steps=max_steps,
        args=args)
# Clean noise with SpaCKLE
def spackle_cleaner(adata: ad.AnnData, dataset: str, from_layer: str, to_layer: str, device, lr = 1e-3, train = True, get_performance_metrics = True, load_ckpt_path = "", optimizer = "Adam", max_steps = 1000) -> "tuple[ad.AnnData, str]":
    """Complete the missing values of an AnnData with a SpaCKLE model.

    If ``train`` is True a new SpaCKLE model is trained with ``train_spackle`` and its checkpoint is used.
    Otherwise the checkpoint in ``load_ckpt_path`` is loaded. The model is then run over the complete
    dataset and the resulting expression matrix is stored in ``adata.layers[to_layer]`` (the input ``adata``
    is modified in place and also returned).

    The ``adata`` must have data splits in ``adata.obs['split']`` and the values should be 'train', 'val',
    and (optional) 'test'.

    NOTE(review): the remaining SpaCKLE hyperparameters (e.g. ``batch_size``, ``num_workers``) come from
    ``get_main_parser().parse_args()``, which presumably reads the command line of the running process -
    confirm before calling this from a script that defines its own command line arguments.

    Args:
        adata (ad.AnnData): The AnnData object to process.
        dataset (str): Name of the dataset. Only used to build the output directory ``imput_results/<dataset>/<run date>`` when ``train`` is True.
        from_layer (str): Layer with the expression data used to train the model and to build the model input.
        to_layer (str): Layer where the completed expression matrix is stored.
        device: Device where the model and the batches are placed.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        train (bool, optional): If True a new model is trained. If False the checkpoint in ``load_ckpt_path`` is loaded. Defaults to True.
        get_performance_metrics (bool, optional): Passed to ``train_spackle`` as ``get_performance``. Only used when ``train`` is True. Defaults to True.
        load_ckpt_path (str, optional): Path to the checkpoint file of a trained model. Only required when ``train`` is False. It must end with the ckpt
                                        file, and that file must be inside a directory that also contains ``script_params.json``
                                        (e.g. ``imput_results/vicari_mouse_brain/2024-02-28-07-02-31/epoch=101-step=9370.ckpt``). Defaults to "".
        optimizer (str, optional): Name of the optimizer. Defaults to "Adam".
        max_steps (int, optional): Maximum number of training steps passed to ``train_spackle``. Defaults to 1000.

    Raises:
        FileNotFoundError: If ``train`` is False and ``load_ckpt_path`` (or its ``script_params.json``) does not exist, or if ``train`` is True and
                           no ``*.ckpt`` file is found in the output directory after training.

    Returns:
        tuple[ad.AnnData, str]: The adata with the cleaned layer and the path to the checkpoint used to complete the missing values.
    """
    # Imported here because this file has no top-level `import glob`
    import glob

    # Get parser and parse arguments
    parser = get_main_parser()
    args = parser.parse_args()
    args_dict = vars(args)

    # Get datetime
    run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    # Set manual seeds and get cuda
    seed_everything(42)

    if train:
        # Train a model (train_spackle()) and recover the path to the checkpoint it saved
        # Create directory where the newly trained model will be saved
        # TODO: [PC] group opinion: should we set data naming with date like in our works?
        save_path = os.path.join('imput_results', dataset, run_date)
        os.makedirs(save_path, exist_ok=True)

        # Save script arguments in json file
        with open(os.path.join(save_path, 'script_params.json'), 'w') as f:
            json.dump(args_dict, f, indent=4)

        # Train new SpaCKLE model
        train_spackle(
            adata=adata,
            device=device,
            save_path=save_path,
            prediction_layer=from_layer,
            lr=lr,
            train=train,
            get_performance=get_performance_metrics,
            load_ckpt_path=load_ckpt_path,
            optimizer=optimizer,
            max_steps=max_steps,
            args=args)

        # Sorted so the selected checkpoint does not depend on the directory listing order
        ckpt_candidates = sorted(glob.glob(os.path.join(save_path, '*.ckpt')))
        if not ckpt_candidates:
            raise FileNotFoundError(f"No checkpoint (*.ckpt) was found in {save_path} after training.")
        load_ckpt_path = ckpt_candidates[0]

    else:
        # Explicit check instead of an assert so the validation survives `python -O`
        if not os.path.exists(load_ckpt_path):
            raise FileNotFoundError("load_ckpts_path not found. Please use train = True if you do not have the checkpoints of a trained SpaCKLE model and its corresponding script_params.json file.")

        save_path = os.path.dirname(load_ckpt_path)

        # FIXME: [PC] decide which elements to compare, and remember to upload the weights (i.e. Drive) together with their corresponding params json
        with open(os.path.join(save_path, 'script_params.json'), 'r') as f:
            saved_script_params = json.load(f)

        # Check that the parameters of the loaded model agree with the current inference process
        #if (saved_script_params['prediction_layer'] != args_dict['prediction_layer']) or (saved_script_params['prediction_layer'] != args_dict['prediction_layer']):
        #    warnings.warn("Saved model's parameters differ from those of the current argparse.")
        if saved_script_params['transformer_dim'] != adata.n_vars:
            warnings.warn("The architecture of the model you want to load may not be compatible with the shape of the data.")

    # Declare model
    vis_features_dim = 0
    model = GeneImputationModel(
        args=args,
        data_input_size=adata.n_vars,
        lr=lr,
        optimizer=optimizer,
        vis_features_dim=vis_features_dim
        ).to(device)

    # Load best checkpoints. map_location lets a checkpoint saved on another device
    # (e.g. a GPU) be loaded on the device requested by the caller.
    state_dict = torch.load(load_ckpt_path, map_location=device)
    state_dict = state_dict['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Finished loading model with weights from {load_ckpt_path}")

    # Prepare data and dataloader
    data = ImputationDataset(adata, args, 'complete', from_layer)
    dataloader = DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,  # Keep the dataset order so the rows of the output match the dataset
        pin_memory=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Get gene imputations for the missing values throughout the entire dataset
    exp_with_imputation = []
    print("----"*30)
    print(f"Completing missing values in adata with the SpaCKLE model from {load_ckpt_path}")
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # 'split_name' is dropped before every remaining entry is moved to the device
            del batch['split_name']
            # Extract batch variables
            batch = {k: v.to(device) for k, v in batch.items()}
            expression_gt = batch['exp_matrix_gt']
            # NOTE(review): judging from the torch.where below, True entries of this mask are the
            # ones whose original value is kept - confirm against ImputationDataset
            mask = batch['real_missing']
            # Remove median imputations from gene expression matrix
            input_genes = expression_gt.clone()
            input_genes[~mask] = 0
            # Get predictions
            prediction = model.forward(input_genes)
            # Impute predicted gene expression only in missing data for 'main spot' (index 0) in the neighborhood
            imputed_exp = torch.where(mask[:,0,:], expression_gt[:,0,:], prediction[:,0,:])
            exp_with_imputation.append(imputed_exp)

    # Concatenate output tensors into complete data expression matrix
    exp_with_imputation = torch.cat(exp_with_imputation)

    # Add imputed data to adata
    adata.layers[to_layer] = np.asarray(exp_with_imputation.cpu().double())

    # Return the adata with cleaned layer and the path to the ckpts used to complete the missing values.
    return adata, load_ckpt_path