Source code for spared.spot_features.spot_features

import anndata as ad
import torch
import torchvision.models as tmodels
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import scanpy as sc
import squidpy as sq
import sys
import pathlib

# Directory two levels above this file (the parent of the directory that contains it)
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent
# Temporarily add that directory to sys.path so that the import below can be resolved
sys.path.append(str(SPARED_PATH))
# Import models.py file (presumably located in a ``models`` package under SPARED_PATH -- TODO confirm)
from models import models
# Remove the path from sys.path again so the change does not leak to later imports
sys.path.remove(str(SPARED_PATH))

### Patch processing functions
## Compute the patch embeddings
# TODO: Check if we can remove the patch_scale parameter
# TODO: Make all the backbone options correspond to PyTorch models.
# Supported backbones for embedding extraction. Each entry holds the name of the
# torchvision weights enum and the attribute path (inside ``model.encoder``) of the
# final classification layer that is swapped for ``nn.Identity``.
# The enums are stored by name and resolved lazily so that importing this module does
# not depend on every architecture being present in the installed torchvision.
_EMBEDDING_BACKBONES = {
    'resnet': ('ResNet18_Weights', ('fc',)),
    'resnet50': ('ResNet50_Weights', ('fc',)),
    'ConvNeXt': ('ConvNeXt_Tiny_Weights', ('classifier', 2)),
    'EfficientNetV2': ('EfficientNet_V2_S_Weights', ('classifier', 1)),
    'InceptionV3': ('Inception_V3_Weights', ('fc',)),
    'MaxVit': ('MaxVit_T_Weights', ('classifier', 5)),
    'MobileNetV3': ('MobileNet_V3_Small_Weights', ('classifier', 3)),
    'ResNetXt': ('ResNeXt50_32X4D_Weights', ('fc',)),
    'ShuffleNetV2': ('ShuffleNet_V2_X0_5_Weights', ('fc',)),
    'ViT': ('ViT_B_16_Weights', ('heads', 'head')),
    'WideResnet': ('Wide_ResNet50_2_Weights', ('fc',)),
    'densenet': ('DenseNet121_Weights', ('classifier',)),
    'swin': ('Swin_T_Weights', ('head',)),
}


def _replace_head_with_identity(encoder: nn.Module, head_path: tuple) -> None:
    """Replace the sub-module found at ``head_path`` inside ``encoder`` with ``nn.Identity``."""
    *parent_path, leaf = head_path
    parent = encoder
    for attribute in parent_path:
        parent = getattr(parent, attribute)
    if isinstance(leaf, int):
        # Integer leaves index into a sequential container, e.g. ``classifier[2]``
        parent[leaf] = nn.Identity()
    else:
        setattr(parent, leaf, nn.Identity())


def compute_patches_embeddings(adata: ad.AnnData, backbone: str = 'densenet', model_path: str = 'None', patch_size: int = 224) -> None:
    """ Compute embeddings for patches.

    This function computes embeddings (last layer representations) for a given backbone model and adata object. It can optionally
    compute using a stored model in ``model_path`` or a pretrained model from `pytorch <https://pytorch.org/vision/stable/models.html>`_.
    The embeddings are stored in ``adata.obsm[f'embeddings_{backbone}']``.

    The patches must already be stored in a flattened format inside an ``adata.obsm`` entry whose key contains ``patches_scale``
    (normally ``patches_scale_1.0``) and must be of shape ``(n_patches, patch_size*patch_size*3)``. There must be exactly one such
    key; it is used as found. The function only modifies the AnnData object in place.

    The patch information should be in ``int`` format from ``0`` to ``255``. All needed transformations are done inside the
    function: patches are reshaped to ``(n_patches, 3, patch_size, patch_size)``, scaled to ``[0, 1]`` and passed through the
    preprocessing transforms of the backbone's default torchvision weights. Conversion and preprocessing are done batch by batch
    and every batch output is moved to the CPU right away, so memory usage does not grow with the whole dataset in float format.

    Args:
        adata (ad.AnnData): The AnnData object with the patches to process.
        backbone (str, optional): A string specifying the backbone model to use. Must be one of the following ``['resnet', 'resnet50', 'ConvNeXt', 'EfficientNetV2', 'InceptionV3', 'MaxVit', 'MobileNetV3', 'ResNetXt', 'ShuffleNetV2', 'ViT', 'WideResnet', 'densenet', 'swin']``. Defaults to ``'densenet'``.
        model_path (str, optional): The path to a stored model. If set to ``'None'``, then an ImageNet pretrained model is used. Defaults to ``'None'``.
        patch_size (int, optional): The size of the patches. Defaults to ``224``.

    Raises:
        ValueError: If the backbone is not supported.
        ValueError: If no key, or more than one key, of ``adata.obsm`` contains ``patches_scale``.
    """
    # Reject unsupported backbones before any model is built
    if backbone not in _EMBEDDING_BACKBONES:
        raise ValueError(f'Backbone {backbone} not supported')
    weights_name, head_path = _EMBEDDING_BACKBONES[backbone]

    # Define a cuda device if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Verify that the patches key exists and only exists once (explicit raise: asserts vanish under -O)
    patch_keys = [key for key in adata.obsm.keys() if "patches_scale" in key]
    if len(patch_keys) != 1:
        raise ValueError("patches_scale key either does not exist or exists more than once in keys_list.")
    # Use the key exactly as found instead of rebuilding it from its last '_' separated token
    patches_key = patch_keys[0]

    # Define the model
    model = models.ImageEncoder(backbone=backbone, use_pretrained=True, latent_dim=adata.n_vars)

    if model_path != "None":
        # map_location lets checkpoints saved on a GPU be loaded on CPU-only machines.
        # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
        saved_model = torch.load(model_path, map_location=device)
        # Check if state_dict is inside a nested dictionary
        if 'state_dict' in saved_model:
            saved_model = saved_model['state_dict']
        model.load_state_dict(saved_model)

    # Drop the final layer only after the stored weights were loaded, so the state dict keys still match
    _replace_head_with_identity(model.encoder, head_path)

    # Pass model to device and put in eval mode
    model.to(device)
    model.eval()

    # Preprocessing that corresponds to the default weights of the backbone
    preprocess = getattr(tmodels, weights_name).DEFAULT.transforms()

    # Get the patches and reshape them to (n_patches, 3, patch_size, patch_size)
    flat_patches = adata.obsm[patches_key]
    all_patches = flat_patches.reshape((-1, patch_size, patch_size, 3))
    torch_patches = torch.from_numpy(all_patches).permute(0, 3, 1, 2)

    batch_size = 256
    outputs = []
    with torch.no_grad():
        for raw_batch in tqdm(torch.split(torch_patches, batch_size), desc='Getting embeddings'):
            # Rescale to [0, 1] and preprocess one batch at a time to bound memory usage
            batch = preprocess(raw_batch.float() / 255).to(device)
            # Move every result to the CPU immediately so outputs do not pile up on the device
            outputs.append(model(batch).cpu())

    # Concatenate all embeddings and add them to adata.obsm
    outputs = torch.cat(outputs, dim=0)
    adata.obsm[f'embeddings_{backbone}'] = outputs.numpy()
## Compute the patch predictions
# TODO: Check if we can remove the patch_scale parameter
# TODO: Make all the backbone options correspond to PyTorch models.
# Supported backbones for prediction and the name of the torchvision weights enum whose
# preprocessing transforms are applied to the patches. The enums are stored by name and
# resolved lazily so that importing this module does not depend on every architecture
# being present in the installed torchvision.
_PREDICTION_BACKBONE_WEIGHTS = {
    'resnet': 'ResNet18_Weights',
    'resnet50': 'ResNet50_Weights',
    'ConvNeXt': 'ConvNeXt_Tiny_Weights',
    'EfficientNetV2': 'EfficientNet_V2_S_Weights',
    'InceptionV3': 'Inception_V3_Weights',
    'MaxVit': 'MaxVit_T_Weights',
    'MobileNetV3': 'MobileNet_V3_Small_Weights',
    'ResNetXt': 'ResNeXt50_32X4D_Weights',
    'ShuffleNetV2': 'ShuffleNet_V2_X0_5_Weights',
    'ViT': 'ViT_B_16_Weights',
    'WideResnet': 'Wide_ResNet50_2_Weights',
    'densenet': 'DenseNet121_Weights',
    'swin': 'Swin_T_Weights',
}


def compute_patches_predictions(adata: ad.AnnData, backbone: str = 'densenet', model_path: str = "None", patch_size: int = 224) -> None:
    """ Compute predictions for patches.

    This function computes gene expression predictions for a given backbone model and adata object. It can optionally
    compute using a stored model in ``model_path`` or a pretrained model from `pytorch <https://pytorch.org/vision/stable/models.html>`_.
    The predictions are stored in ``adata.obsm[f'predictions_{backbone}']``.

    The patches must already be stored in a flattened format inside an ``adata.obsm`` entry whose key contains ``patches_scale``
    (normally ``patches_scale_1.0``) and must be of shape ``(n_patches, patch_size*patch_size*3)``. There must be exactly one such
    key; it is used as found. The function only modifies the AnnData object in place.

    The patch information should be in ``int`` format from ``0`` to ``255``. All needed transformations are done inside the
    function: patches are reshaped to ``(n_patches, 3, patch_size, patch_size)``, scaled to ``[0, 1]`` and passed through the
    preprocessing transforms of the backbone's default torchvision weights. Conversion and preprocessing are done batch by batch
    and every batch output is moved to the CPU right away, so memory usage does not grow with the whole dataset in float format.

    All models will be declared to have the same number of outputs as genes in the ``adata`` object (``adata.n_vars``).
    Please also note that if you try to predict with a model that has only been pretrained on ImageNet, the predictions
    will be random and not useful. So always try to use models pretrained in spatial transcriptomics datasets.

    Args:
        adata (ad.AnnData): The AnnData object with the patches to process.
        backbone (str, optional): A string specifying the backbone model to use. Must be one of the following ``['resnet', 'resnet50', 'ConvNeXt', 'EfficientNetV2', 'InceptionV3', 'MaxVit', 'MobileNetV3', 'ResNetXt', 'ShuffleNetV2', 'ViT', 'WideResnet', 'densenet', 'swin']``. Defaults to ``'densenet'``.
        model_path (str, optional): The path to a stored model. If set to ``'None'``, then an ImageNet pretrained model is used. Defaults to ``'None'``.
        patch_size (int, optional): The size of the patches. Defaults to ``224``.

    Raises:
        ValueError: If the backbone is not supported.
        ValueError: If no key, or more than one key, of ``adata.obsm`` contains ``patches_scale``.
    """
    # Reject unsupported backbones before any model is built
    if backbone not in _PREDICTION_BACKBONE_WEIGHTS:
        raise ValueError(f'Backbone {backbone} not supported')
    weights_name = _PREDICTION_BACKBONE_WEIGHTS[backbone]

    # Define a cuda device if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Verify that the patches key exists and only exists once (explicit raise: asserts vanish under -O)
    patch_keys = [key for key in adata.obsm.keys() if "patches_scale" in key]
    if len(patch_keys) != 1:
        raise ValueError("patches_scale key either does not exist or exists more than once in keys_list.")
    # Use the key exactly as found instead of rebuilding it from its last '_' separated token
    patches_key = patch_keys[0]

    # Define the model
    model = models.ImageEncoder(backbone=backbone, use_pretrained=True, latent_dim=adata.n_vars)

    if model_path != "None":
        # map_location lets checkpoints saved on a GPU be loaded on CPU-only machines.
        # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
        saved_model = torch.load(model_path, map_location=device)
        # Check if state_dict is inside a nested dictionary
        if 'state_dict' in saved_model:
            saved_model = saved_model['state_dict']
        model.load_state_dict(saved_model)

    # Pass model to device and put in eval mode
    model.to(device)
    model.eval()

    # Preprocessing that corresponds to the default weights of the backbone
    preprocess = getattr(tmodels, weights_name).DEFAULT.transforms()

    # Get the patches and reshape them to (n_patches, 3, patch_size, patch_size)
    flat_patches = adata.obsm[patches_key]
    all_patches = flat_patches.reshape((-1, patch_size, patch_size, 3))
    torch_patches = torch.from_numpy(all_patches).permute(0, 3, 1, 2)

    batch_size = 256
    outputs = []
    with torch.no_grad():
        for raw_batch in tqdm(torch.split(torch_patches, batch_size), desc='Getting predictions'):
            # Rescale to [0, 1] and preprocess one batch at a time to bound memory usage
            batch = preprocess(raw_batch.float() / 255).to(device)
            # Move every result to the CPU immediately so outputs do not pile up on the device
            outputs.append(model(batch).cpu())

    # Concatenate all predictions and add them to adata.obsm
    outputs = torch.cat(outputs, dim=0)
    adata.obsm[f'predictions_{backbone}'] = outputs.numpy()
### Define function to get dimensionality reductions depending on the layer
# TODO: Add in the documentation which keys and information are added by the function.
# TODO: Add references to scanpy's functions inside documentation.
def compute_dim_red(adata: ad.AnnData, from_layer: str) -> ad.AnnData:
    """ Compute dimensionality reductions and clusters

    Thin wrapper that runs ``sc.pp.pca()``, ``sc.pp.neighbors()``, ``sc.tl.umap()`` and ``sc.tl.leiden()`` in that order,
    each with ``random_state=42`` and otherwise default parameters. The computations use the expression matrix stored in
    ``adata.layers[from_layer]``. The leiden result is stored under the key ``cluster``.

    The work is done on a copy of ``adata``; the input object is left untouched. Before returning, the ``X`` matrix of the
    copy is set to ``layers['counts']``.

    Args:
        adata (ad.AnnData): The AnnData object to transform. Must have expression values in ``adata.layers[from_layer]`` and a ``counts`` layer.
        from_layer (str): The key in ``adata.layers`` where the expression matrix is stored.

    Returns:
        ad.AnnData: A copy of ``adata`` with the dimensionality reductions and clusters.
    """
    seed = 42

    # Work on a copy whose main matrix is the requested layer
    reduced = adata.copy()
    reduced.X = reduced.layers[from_layer]

    # Each step consumes the results of the previous one, so the order must be kept
    pipeline = (
        (sc.pp.pca, {}),
        (sc.pp.neighbors, {}),
        (sc.tl.umap, {}),
        (sc.tl.leiden, {"key_added": "cluster"}),
    )
    for step, extra_kwargs in pipeline:
        step(reduced, random_state=seed, **extra_kwargs)

    # Leave the counts layer as the main expression matrix of the returned object
    reduced.X = reduced.layers['counts']
    return reduced
### Define function to get spatial neighbors in an AnnData object
# TODO: Shouldn't we make this function robust for a collection of slides?
def get_spatial_neighbors(adata: ad.AnnData, n_hops: int, hex_geometry: bool) -> dict:
    """ Compute neighbors dictionary for an AnnData object.

    This function computes a neighbors dictionary for an AnnData object. The neighbors are computed according to topological
    distances over a graph defined by the ``hex_geometry`` connectivity. The neighbors dictionary is a dictionary where the keys
    are the indexes of the observations and the values are lists of the indexes of the neighbors of each observation. The
    neighbors include the observation itself as first element and are found inside an ``n_hops`` neighborhood (vicinity) of the
    observation.

    The graph is built with ``sq.gr.spatial_neighbors``, which is called on ``adata`` directly; the adjacency matrix is then
    read from ``adata.obsp['spatial_connectivities']``. Observations reachable in at most ``n_hops`` steps are found by summing
    the first ``n_hops`` powers of that adjacency matrix.

    Args:
        adata (ad.AnnData): The AnnData object to process. Importantly it is only from a single slide. Can not be a collection of slides.
        n_hops (int): The size of the neighborhood to take into account to compute the neighbors. Values below ``1`` give the same result as ``1``.
        hex_geometry (bool): Whether the graph is hexagonal or not. If ``True``, then the graph is hexagonal. If ``False``, then the graph is a grid. Only ``True`` for Visium datasets.

    Returns:
        dict: The neighbors dictionary. The keys are the indexes of the observations and the values are lists of the indexes of the neighbors of each observation.
    """
    # Compute spatial_neighbors
    if hex_geometry:
        sq.gr.spatial_neighbors(adata, coord_type='generic', n_neighs=6)  # Hexagonal visium case
    else:
        sq.gr.spatial_neighbors(adata, coord_type='grid', n_neighs=8)  # Grid dataset case

    # Get the adjacency matrix
    adj_matrix = adata.obsp['spatial_connectivities']

    # Copies keep the matrix stored in adata.obsp untouched by the in-place edits below
    power_matrix = adj_matrix.copy()
    output_matrix = adj_matrix.copy()

    # Accumulate adj + adj^2 + ... + adj^n_hops
    for _ in range(n_hops - 1):
        # '@' is the matrix product for both scipy sparse matrices and sparse arrays,
        # whereas '*' is element-wise for sparse arrays and would not advance a hop
        power_matrix = power_matrix @ adj_matrix
        output_matrix = output_matrix + power_matrix

    # Zero out the diagonal: the query observation is added explicitly as first element below
    output_matrix.setdiag(0)

    # Only the sparsity pattern is needed. In CSR format, once explicit zeros are dropped,
    # the column indices of row i are indices[indptr[i]:indptr[i + 1]]
    output_matrix = output_matrix.tocsr()
    output_matrix.eliminate_zeros()
    indptr = output_matrix.indptr
    indices = output_matrix.indices

    # NOTE: the first index of every list is the query observation
    neighbors_dict_index = {
        i: [i] + list(indices[indptr[i]:indptr[i + 1]])
        for i in range(output_matrix.shape[0])
    }

    # Return the neighbors dict
    return neighbors_dict_index