Source code for spared.graph_operations.graph_operations

import anndata as ad
import torch
from positional_encodings.torch_encodings import PositionalEncoding2D
from tqdm import tqdm
from torch_geometric.data import Data as geo_Data
import numpy as np
import pathlib
import squidpy as sq
from torch_geometric.utils import from_scipy_sparse_matrix
from typing import Tuple
import sys
from typing import Tuple

# Parent of the directory that contains this file (presumably the spared package root)
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent

# Temporarily add that directory to sys.path so the import below can be resolved
sys.path.append(str(SPARED_PATH))
# Import the filtering module
from filtering import filtering
# Remove the path from sys.path again so the change does not leak to other imports
sys.path.remove(str(SPARED_PATH))

### Graph building functions:
def _n_hop_adjacency(adj_matrix, n_hops: int):
    """Return the binary adjacency of all nodes reachable within ``n_hops`` steps.

    Args:
        adj_matrix: Sparse 1-hop adjacency matrix (read from
            ``adata.obsp['spatial_connectivities']`` by the caller).
        n_hops (int): Number of hops. Values below 2 return the 1-hop adjacency.

    Returns:
        Sparse integer matrix with a stored 1 at (i, j) when j != i is reachable
        from i in at most ``n_hops`` steps, and no stored diagonal entries.
    """
    power_matrix = adj_matrix.copy()
    reach_matrix = adj_matrix.copy()
    for _ in range(n_hops - 1):
        # '@' is the matrix product for both scipy sparse matrices and sparse
        # arrays; '*' would be element-wise for sparse arrays.
        power_matrix = power_matrix @ adj_matrix
        reach_matrix = reach_matrix + power_matrix

    # Zero out the diagonal so a node is never its own neighbor
    reach_matrix.setdiag(0)
    # setdiag(0) can leave explicitly stored zeros. Drop them so that an edge
    # list built from the stored entries does not contain self-loops.
    reach_matrix.eliminate_zeros()
    # Threshold path counts to 0/1
    return reach_matrix.astype(bool).astype(int)


def get_graphs_one_slide(adata: ad.AnnData, n_hops: int, layer: str, hex_geometry: bool) -> Tuple[dict,int]:
    """Get neighbor graphs for a single slide.

    This function receives an AnnData object with a single slide and, for each
    spot, builds the graph of all spots within an ``n_hops`` radius in pytorch
    geometric format. It returns a dictionary where the patch names are the keys
    and a pytorch geometric graph for each one are the values.

    NOTE: The first node of every graph is the center.

    Requirements read from ``adata`` by this function:
        * ``adata.obs`` must have ``array_row`` and ``array_col`` columns.
        * ``adata.obsm`` must have exactly one key containing ``'embeddings'`` and
          exactly one key containing ``'predictions'``.
        * ``adata.layers`` must contain ``layer`` and ``'mask'``.
        * If ``'deltas'`` is in ``layer``, ``adata.var`` must contain the column
          ``f'{layer}_avg_exp'`` with ``'deltas'`` replaced by ``'log1p'``.

    Args:
        adata (ad.AnnData): The AnnData object with the slide data.
        n_hops (int): The number of hops to compute the graph.
        layer (str): The layer of the graph to predict. Will be added as y to the graph.
        hex_geometry (bool): Whether the slide has hexagonal geometry or not.

    Returns:
        Tuple(dict,int)
        dict: A dictionary where the patch names are the keys and pytorch geometric
            graphs are the values. The first node of every graph is the center.
        int: Max absolute row or column difference between a center and its
            neighbors (-1 if the slide has no spots). Used for positional encoding.

    Raises:
        AssertionError: If there is not exactly one embeddings key or exactly one
            predictions key in ``adata.obsm``.
    """
    # Compute spatial neighbors
    if hex_geometry:
        sq.gr.spatial_neighbors(adata, coord_type='generic', n_neighs=6)  # Hexagonal visium case
    else:
        sq.gr.spatial_neighbors(adata, coord_type='grid', n_neighs=8)  # Grid STNet dataset case

    # Binary n-hop adjacency without self connections
    output_matrix = _n_hop_adjacency(adata.obsp['spatial_connectivities'], n_hops)

    # Position i in the matrix corresponds to obs name obs_names[i]
    obs_names = adata.obs.index.values

    ### Get pytorch geometric graphs ###
    # Only the target layer and the mask are used, so only those are converted.
    # Both are cast to float32 (the mask too), matching the stored graph dtype.
    target_layer = torch.from_numpy(adata.layers[layer]).type(torch.float32)
    mask_layer = torch.from_numpy(adata.layers['mask']).type(torch.float32)
    pos = torch.from_numpy(adata.obs[['array_row', 'array_col']].values)  # Global positions

    # Get embeddings and predictions keys
    emb_key_list = [k for k in adata.obsm.keys() if 'embeddings' in k]
    pred_key_list = [k for k in adata.obsm.keys() if 'predictions' in k]
    assert len(emb_key_list) == 1, 'There are more than 1 or no embedding keys in adata.obsm'
    assert len(pred_key_list) == 1, 'There are more than 1 or no prediction keys in adata.obsm'
    emb_key, pred_key = emb_key_list[0], pred_key_list[0]

    embeddings = torch.from_numpy(adata.obsm[emb_key]).type(torch.float32)
    predictions = torch.from_numpy(adata.obsm[pred_key]).type(torch.float32)

    # If layer contains 'deltas', attach the matching per-variable average
    # expression as used_mean; otherwise the attribute is None.
    used_mean = None
    if 'deltas' in layer:
        mean_column = f'{layer}_avg_exp'.replace('deltas', 'log1p')
        used_mean = torch.from_numpy(adata.var[mean_column].values).type(torch.float32)

    graph_dict = {}
    max_abs_d_pos = -1

    # Build one graph per spot
    for i in tqdm(range(output_matrix.shape[0]), leave=False, position=1):
        central_node_name = obs_names[i]

        # Node list of the subgraph. NOTE: the first index is the query spot
        neighbor_indexes = output_matrix[i].nonzero()[1]
        node_index_list = [i] + list(neighbor_indexes)
        curr_nodes_idx = torch.tensor(node_index_list)

        # Restrict the n-hop adjacency to the nodes of this subgraph
        curr_adj_matrix = output_matrix[node_index_list, :][:, node_index_list]
        curr_edge_index, _ = from_scipy_sparse_matrix(curr_adj_matrix)

        curr_pos = pos[curr_nodes_idx]
        curr_d_pos = curr_pos - curr_pos[0]  # Positions relative to the center node

        graph_dict[central_node_name] = geo_Data(
            y=target_layer[curr_nodes_idx],
            edge_index=curr_edge_index,
            pos=curr_pos,
            d_pos=curr_d_pos,
            embeddings=embeddings[curr_nodes_idx],
            predictions=predictions[curr_nodes_idx],
            used_mean=used_mean,
            num_nodes=len(curr_nodes_idx),
            mask=mask_layer[curr_nodes_idx]
        )

        # Track the largest absolute offset seen so far (as a plain int)
        max_abs_d_pos = max(max_abs_d_pos, int(curr_d_pos.abs().max()))

    return graph_dict, max_abs_d_pos
def get_sin_cos_positional_embeddings(graph_dict: dict, max_d_pos: int) -> dict:
    """Get positional encodings for a neighbor graph.

    This function adds transformer-like positional encodings to each graph in a
    graph dict. The encodings are stored in place under the attribute
    ``positional_embeddings`` of each graph, and have the same channel dimension
    as the graph's ``embeddings`` attribute.

    Args:
        graph_dict (dict): A dictionary where the patch names are the keys and
            pytorch geometric graphs are the values. Each graph must expose
            ``embeddings`` and ``d_pos`` attributes.
        max_d_pos (int): Max absolute value in the relative position matrix. Offsets
            are shifted by this amount to index a (2*max_d_pos+1)-sided look-up grid.

    Returns:
        dict: The input graph dict (same object) with positional encodings added to
            each graph. An empty dict is returned unchanged.
    """
    # Nothing to encode; also avoids probing a first graph that does not exist
    if not graph_dict:
        return graph_dict

    # The encoding dimension is taken from the first graph's embeddings
    first_graph = next(iter(graph_dict.values()))
    embedding_dim = first_graph.embeddings.shape[1]

    # Define the positional encoding model
    p_encoding_model = PositionalEncoding2D(embedding_dim)

    # Dummy grid of size (batch_size, x, y, channels) covering every possible offset
    grid_side = 2 * max_d_pos + 1
    grid = torch.zeros([1, grid_side, grid_side, embedding_dim])

    # Encoding for every grid position, computed once and shared by all graphs
    positional_look_up_table = p_encoding_model(grid)

    for graph in graph_dict.values():
        # Shift relative offsets from [-max_d_pos, max_d_pos] to [0, 2*max_d_pos].
        # The cast is a no-op for int64 positions and makes the index dtype explicit.
        grid_pos = (graph.d_pos + max_d_pos).long()
        graph.positional_embeddings = positional_look_up_table[0, grid_pos[:, 0], grid_pos[:, 1], :]

    return graph_dict
def get_graphs(adata: ad.AnnData, n_hops: int, layer: str, hex_geometry: bool=True) -> dict:
    """Get graphs for all the slides in a dataset.

    This function wraps get_graphs_one_slide to get the graphs for all the slides
    in the dataset. The per-slide graph dicts are merged into a single dictionary,
    which is then passed to get_sin_cos_positional_embeddings together with the
    largest relative offset found across slides.

    Args:
        adata (ad.AnnData): The AnnData object used to build the graphs. Slides are
            identified by the ``slide_id`` column of ``adata.obs``.
        n_hops (int): The number of hops to compute each graph.
        layer (str): The layer of the graph to predict. Will be added as y to the graph.
        hex_geometry (bool): Whether the graph is hexagonal or not. Only true for
            visium datasets. Defaults to True.

    Returns:
        dict: A dictionary where the spots' names are the keys and pytorch geometric
            graphs are the values. Empty if no graphs were built.
    """
    print('Computing graphs...')

    # Get unique slide ids
    unique_ids = adata.obs['slide_id'].unique()

    # Global dictionary to store the graphs (pytorch geometric graphs)
    graph_dict = {}
    max_global_d_pos = -1

    # Iterate through slides
    for slide in tqdm(unique_ids, leave=True, position=0):
        curr_adata = filtering.get_slide_from_collection(adata, slide)
        curr_graph_dict, max_curr_d_pos = get_graphs_one_slide(curr_adata, n_hops, layer, hex_geometry)

        # Merge in place; rebuilding the dict every iteration would re-copy all
        # previously collected graphs for each slide.
        graph_dict.update(curr_graph_dict)
        max_global_d_pos = max(max_global_d_pos, max_curr_d_pos)

    # Without graphs there is nothing to encode
    if not graph_dict:
        return graph_dict

    # Positional encodings share one look-up grid sized by the global max offset
    graph_dict = get_sin_cos_positional_embeddings(graph_dict, max_global_d_pos)

    return graph_dict