Source code for spared.dataloaders.dataloaders

import anndata as ad
from anndata.experimental.pytorch import AnnLoader
import torch
import os
import glob
import json
from time import time
from datetime import datetime
from torch_geometric.loader import DataLoader as geo_DataLoader
import numpy as np
import pathlib
from typing import Tuple
import sys
from typing import Tuple

# Directory two levels above this file (the parent of this module's folder)
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent

# Temporarily add that directory to sys.path so the imports below can be resolved
sys.path.append(str(SPARED_PATH))
# Import the project helpers (layer_operations, spot_features, graph_operations) through the path added above
from layer_operations import layer_operations
from spot_features import spot_features
from graph_operations import graph_operations
# Remove the path from sys.path again so the import search path is not left modified
sys.path.remove(str(SPARED_PATH))


# TODO: Fix the internal fixme (DISCUSS AGAIN)
def get_pretrain_dataloaders(adata: ad.AnnData, layer: str = 'c_d_log1p', batch_size: int = 128, shuffle: bool = True, use_cuda: bool = False) -> Tuple[AnnLoader, AnnLoader, AnnLoader]:
    """
    Build the AnnLoader objects used to pretrain an image encoder.

    Pretraining here means fitting a purely vision-based model that predicts the gene expression of each
    patch from the patch alone. Observations are assigned to loaders according to the values 'train',
    'val' and 'test' of ``adata.obs.split``. When no observation is marked as 'test', the third element
    of the returned tuple is None.

    Side effects on the AnnData that is used: ``adata.X`` is overwritten with ``adata.layers[layer]`` cast
    to float32, and when ``layer`` contains 'deltas' a ``used_mean`` column is written to ``adata.var``.
    The percentage of entries whose ``adata.layers["mask"]`` value is False is printed.

    Args:
        adata (ad.AnnData): The AnnData object to build the loaders from. Must provide ``obs['split']``, the requested layer and a boolean ``layers['mask']``.
        layer (str, optional): The layer to predict during pretraining; it becomes ``adata.X``. Defaults to 'c_d_log1p'.
        batch_size (int, optional): Batch size of every loader. Defaults to 128.
        shuffle (bool, optional): Whether the loaders shuffle their data (applied to all three loaders). Defaults to True.
        use_cuda (bool, optional): Forwarded to AnnLoader as ``use_cuda``. Defaults to False.

    Returns:
        Tuple[AnnLoader, AnnLoader, AnnLoader]: The train, validation and test dataloaders, in that order. The test dataloader is None if the dataset has no test split.
    """
    # Observation names of each split, resolved before the AnnData is touched
    obs_table = adata.obs
    split_index = {name: obs_table[obs_table.split == name].index for name in ('train', 'val', 'test')}

    ##### Noisy-training support #####
    # FIXME: This belongs in the data-processing pipeline rather than in the dataloader function. The noisy
    # layers for c_d_log1p, c_t_log1p, c_d_deltas and c_t_deltas should be generated automatically there.
    # FIXME: This call produces the unwanted message "Using noisy_delta layer for training. This will
    # probably yield bad results." in the quickstart tutorial.
    adata = layer_operations.add_noisy_layer(adata=adata, prediction_layer=layer)

    # The prediction layer becomes X, stored as float32
    adata.X = adata.layers[layer].astype(np.float32)

    # Name of the imputation method reported in the message below
    if layer in ['c_t_log1p', 'c_t_deltas']:
        imp_model_str = 'transformer model'
    else:
        imp_model_str = 'median filter'

    # Report which share of all (observation, variable) entries is False in the mask
    n_replaced = (~adata.layers["mask"]).sum()
    n_entries = adata.n_vars*adata.n_obs
    imp_pct = 100 * n_replaced / n_entries
    print('Percentage of imputed observations with {}: {:5.3f}%'.format(imp_model_str, imp_pct))

    # For delta layers, expose the mean of the matching log1p layer as var['used_mean']
    if 'deltas' in layer:
        mean_key = f'{layer}_avg_exp'.replace('deltas', 'log1p')
        adata.var['used_mean'] = adata.var[mean_key]

    def _make_loader(index):
        # Loader over the subset of observations named by index
        return AnnLoader(adata[index, :], batch_size=batch_size, shuffle=shuffle, use_cuda=use_cuda)

    train_dataloader = _make_loader(split_index['train'])
    val_dataloader = _make_loader(split_index['val'])
    # A dataset without test observations yields no test loader
    test_dataloader = _make_loader(split_index['test']) if len(split_index['test']) > 0 else None

    return train_dataloader, val_dataloader, test_dataloader
# TODO: Fix the internal fixme (DEPENDS ON THE PREVIOUS DISCUSSION)
def get_graph_dataloaders(adata: ad.AnnData, dataset_path: str='', layer: str = 'c_t_log1p', n_hops: int = 2, backbone: str ='densenet', model_path: str = "None", batch_size: int = 128, shuffle: bool = True, hex_geometry: bool=True, patch_size: int=224) -> Tuple[geo_DataLoader, geo_DataLoader, geo_DataLoader]:
    """
    Get dataloaders for the graphs of a dataset.

    This function performs the whole pipeline to get graph dataloaders for a dataset:

        1. Looks under ``<dataset_path>/graphs`` for graphs previously saved with the same parameters and loads them if found.
        2. Otherwise computes embeddings and predictions for the patches using the specified backbone and model.
        3. Computes the graph dictionary for the dataset and splits it according to ``adata.obs.split``.
        4. Saves the graphs and their parameters in a new time-stamped folder inside ``<dataset_path>/graphs``.
        5. Returns the train, validation and test dataloaders for the graphs.

    A saved entry is reused only when ``n_hops``, ``layer``, ``backbone``, ``model_path``, ``hex_geometry`` and
    ``patch_size`` all match the current call. Parameter files that do not record ``hex_geometry`` or
    ``patch_size`` are treated as if they had been computed with this function's defaults (True and 224).
    Entries whose train or validation file is missing are ignored and the graphs are recomputed.

    In case the dataset has no test set, the test dataloader is None.

    Args:
        adata (ad.AnnData): The AnnData object to process. Only used when the graphs have to be computed.
        dataset_path (str, optional): The path to the dataset (where the graphs will be stored). Defaults to ''.
        layer (str, optional): Layer to predict. Defaults to 'c_t_log1p'.
        n_hops (int, optional): Number of hops to compute the graph. Defaults to 2.
        backbone (str, optional): Backbone model to use. Defaults to 'densenet'.
        model_path (str, optional): Path to the model to use. Defaults to "None".
        batch_size (int, optional): Batch size of the dataloaders. Defaults to 128.
        shuffle (bool, optional): Whether to shuffle the data in the dataloaders (applied to all three loaders). Defaults to True.
        hex_geometry (bool, optional): Whether the graph is hexagonal or not. Defaults to True.
        patch_size (int, optional): Size of the patches. Defaults to 224.

    Returns:
        Tuple[geo_DataLoader, geo_DataLoader, geo_DataLoader]: The train, validation and test graph dataloaders, in that order. The test dataloader is None when there is no test set.
    """
    # Every argument that influences the computed graphs is part of the cache key. hex_geometry and
    # patch_size are included because they are forwarded to the graph/embedding computation below;
    # leaving them out would reuse graphs built with a different geometry or patch size.
    curr_graph_params = {
        'n_hops': n_hops,
        'layer': layer,
        'backbone': backbone,
        'model_path': model_path,
        'hex_geometry': hex_geometry,
        'patch_size': patch_size
    }
    # Parameter files written before hex_geometry/patch_size were recorded lack those keys.
    # They are assumed to come from calls that used the defaults of this function.
    legacy_defaults = {'hex_geometry': True, 'patch_size': 224}

    # Create graph directory if it does not exist
    graphs_root = os.path.join(dataset_path, 'graphs')
    os.makedirs(graphs_root, exist_ok=True)

    # Get the filenames of the parameters of all directories in the graph folder
    filenames = glob.glob(os.path.join(graphs_root, '**', 'graph_params.json'), recursive=True)

    # Define boolean to check if the graphs are already saved
    found_graphs = False

    # Iterate over all the filenames and check if the parameters are the same
    for filename in filenames:
        with open(filename, 'r') as f:
            # Load the parameters of the saved graphs
            saved_params = json.load(f)

        # Anything that is not a JSON object cannot describe a cache entry
        if not isinstance(saved_params, dict):
            continue
        # Fill in keys missing from older parameter files before comparing
        if {**legacy_defaults, **saved_params} != curr_graph_params:
            continue

        saved_dir = os.path.dirname(filename)
        train_path = os.path.join(saved_dir, 'train_graphs.pt')
        val_path = os.path.join(saved_dir, 'val_graphs.pt')
        test_path = os.path.join(saved_dir, 'test_graphs.pt')

        # An interrupted save can leave the parameter file without the graphs; skip such entries
        if not (os.path.exists(train_path) and os.path.exists(val_path)):
            continue

        print(f'Graph data already saved in {filename}')
        found_graphs = True
        # Track the time and load the graphs
        # NOTE(review): torch.load unpickles the files, so dataset_path must be a trusted location.
        start = time()
        train_graphs = torch.load(train_path)
        val_graphs = torch.load(val_path)
        # No test file means the dataset had no test set when the graphs were saved
        test_graphs = torch.load(test_path) if os.path.exists(test_path) else None
        print(f'Loaded graphs in {time() - start:.2f} seconds.')
        break

    # If the graphs are not found, compute them
    if not found_graphs:
        # Print that we are computing the graphs
        print('Graphs not found in file, computing graphs...')

        # FIXME: Again this should be in the processing part and not in the dataloader
        adata = layer_operations.add_noisy_layer(adata=adata, prediction_layer=layer)

        # We compute the embeddings and predictions for the patches
        spot_features.compute_patches_embeddings(adata=adata, backbone=backbone, model_path=model_path, patch_size=patch_size)
        spot_features.compute_patches_predictions(adata=adata, backbone=backbone, model_path=model_path, patch_size=patch_size)

        # Get graph dicts
        general_graph_dict = graph_operations.get_graphs(adata=adata, n_hops=n_hops, layer=layer, hex_geometry=hex_geometry)

        # Get the train, validation and test indexes
        obs = adata.obs
        idx_train = obs[obs.split == 'train'].index
        idx_val = obs[obs.split == 'val'].index
        idx_test = obs[obs.split == 'test'].index

        # Get list of graphs
        train_graphs = [general_graph_dict[idx] for idx in idx_train]
        val_graphs = [general_graph_dict[idx] for idx in idx_val]
        test_graphs = [general_graph_dict[idx] for idx in idx_test] if len(idx_test) > 0 else None

        print('Saving graphs...')
        # Create graph directory if it does not exist with the current time
        graph_dir = os.path.join(graphs_root, datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))
        os.makedirs(graph_dir, exist_ok=True)

        # Save the graphs first ...
        torch.save(train_graphs, os.path.join(graph_dir, 'train_graphs.pt'))
        torch.save(val_graphs, os.path.join(graph_dir, 'val_graphs.pt'))
        if test_graphs is not None:
            torch.save(test_graphs, os.path.join(graph_dir, 'test_graphs.pt'))

        # ... and the parameters last, so a parameter file only exists next to completely written graphs
        with open(os.path.join(graph_dir, 'graph_params.json'), 'w') as f:
            json.dump(curr_graph_params, f, indent=4)

    # Declare dataloaders
    train_dataloader = geo_DataLoader(train_graphs, batch_size=batch_size, shuffle=shuffle)
    val_dataloader = geo_DataLoader(val_graphs, batch_size=batch_size, shuffle=shuffle)
    test_dataloader = geo_DataLoader(test_graphs, batch_size=batch_size, shuffle=shuffle) if test_graphs is not None else None

    return train_dataloader, val_dataloader, test_dataloader