import glob
import anndata as ad
import os
os.environ['USE_PYGEOS'] = '0' # To supress a warning from geopandas
from PIL import Image
import warnings
import shutil
import torch.nn as nn
import torch
from torch.utils.data import Dataset
import numpy as np
import json
from torchvision.transforms import Normalize
from typing import Tuple
import torch
import argparse
import pathlib
import sys
# Root of the spared package: two directory levels above this file
SPARED_PATH = pathlib.Path(__file__).resolve().parent.parent
# Add SPARED_PATH to sys.path so the project-local imports below can be resolved
sys.path.append(str(SPARED_PATH))
# TODO: ADJUST ONCE THE NEW FILES HAVE BEEN CREATED
# Import visualization and processing function
from filtering import filtering
from layer_operations import layer_operations
from dataloaders import dataloaders
from plotting import plotting
# Import all reader classes
from readers.AbaloReader import AbaloReader
from readers.BatiukReader import BatiukReader
from readers.EricksonReader import EricksonReader
from readers.FanReader import FanReader
from readers.MirzazadehReader import MirzazadehReader
from readers.ParigiReader import ParigiReader
from readers.VicariReader import VicariReader
from readers.VillacampaReader import VillacampaReader
from readers.VisiumReader import VisiumReader
# Remove the parent directory from sys.path
# NOTE(review): the statement below appends SPARED_PATH a second time instead of removing it;
# confirm whether `sys.path.remove(str(SPARED_PATH))` was intended before changing it.
sys.path.append(str(SPARED_PATH))
# Disable Pillow's maximum-pixel limit so very large images can be opened
Image.MAX_IMAGE_PIXELS = None
# Base directory for processed data and the public database: the current working directory
DATA_PATH = os.getcwd()
# Set warnings to ignore
warnings.filterwarnings("ignore", message="No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored")
warnings.filterwarnings("ignore", message="Variable names are not unique. To make them unique, call `.var_names_make_unique`.")
warnings.filterwarnings("ignore", message="The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.")
warnings.filterwarnings("ignore", message="Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.")
# FIXME: Fix this warning FutureWarning: is_categorical_dtype is deprecated and will be removed in a future version. Use isinstance(dtype, CategoricalDtype) instead
# This is a problem between anndata and pandas
warnings.filterwarnings("ignore", category=FutureWarning, module='anndata')
warnings.filterwarnings("ignore", category=FutureWarning, module='squidpy')
warnings.filterwarnings("ignore", category=FutureWarning, module='scanpy')
# TODO: Think of implementing optional random subsampling of the dataset
class SpatialDataset():
    """Spatial transcriptomics dataset that is computed once and cached on disk.

    On construction the class picks a reader based on the dataset name, derives
    the processed-data directory under ``DATA_PATH`` and then either loads the
    cached ``adata.h5ad`` or runs the read -> filter -> process pipeline and
    stores the result. The processed AnnData is exposed as ``self.adata``.
    """

    @staticmethod
    def _default_param_dict() -> dict:
        """Return a fresh copy of the default filtering/processing parameters."""
        return {
            'cell_min_counts': 1000,
            'cell_max_counts': 100000,
            'gene_min_counts': 1e3,
            'gene_max_counts': 1e6,
            'min_exp_frac': 0.8,
            'min_glob_exp_frac': 0.8,
            'real_data_percentage': 0.7,
            'top_moran_genes': 256,
            'wildcard_genes': 'None',
            'combat_key': 'slide_id',
            'random_samples': -1,
            'plotting_slides': 'None',
            'plotting_genes': 'None',
        }

    def __init__(self,
        dataset: str = 'V1_Breast_Cancer_Block_A',
        param_dict: dict = None,
        patch_scale: float = 1.0,
        patch_size: int = 224,
        force_compute: bool = False,
        visualize: bool = True,
        ):
        """
        Build the dataset: initialize the reader, resolve paths and load or compute the processed adata.

        Args:
            dataset (str, optional): A string encoding the dataset type. It selects the reader class
                                     (see ``initialize_reader``). Defaults to 'V1_Breast_Cancer_Block_A'.
            param_dict (dict, optional): Dictionary with the filtering and processing parameters. It is forwarded to
                                         the reader, ``filtering.filter_dataset``, ``layer_operations.process_dataset``
                                         and ``plotting.plot_tests``. When ``None`` (the default) a fresh dictionary
                                         with the following values is used:
                                            {
                                            'cell_min_counts': 1000,
                                            'cell_max_counts': 100000,
                                            'gene_min_counts': 1e3,
                                            'gene_max_counts': 1e6,
                                            'min_exp_frac': 0.8,
                                            'min_glob_exp_frac': 0.8,
                                            'real_data_percentage': 0.7,
                                            'top_moran_genes': 256,
                                            'wildcard_genes': 'None',
                                            'combat_key': 'slide_id',
                                            'random_samples': -1,
                                            'plotting_slides': 'None',
                                            'plotting_genes': 'None',
                                            }
            patch_scale (float, optional): The scale of the patches to take into account. If bigger than 1, then the patches will be bigger than the original image. Defaults to 1.0.
            patch_size (int, optional): The pixel size of the patches. Defaults to 224.
            force_compute (bool, optional): Whether to force the processing computation even if a cached file exists. Defaults to False.
            visualize (bool, optional): Whether to produce the quality control plots and copy them into the public database. Defaults to True.
        """
        # We define the variables for the SpatialDataset class
        self.dataset = dataset
        # A new dict per instance avoids sharing (and accidentally mutating) one default object
        self.param_dict = self._default_param_dict() if param_dict is None else param_dict
        self.patch_scale = patch_scale
        self.patch_size = patch_size
        self.force_compute = force_compute
        self.visualize = visualize
        self.hex_geometry = self.dataset != 'stnet_dataset'  # FIXME: Be careful with this attribute if we want to include more technologies

        # We initialize the reader class
        self.reader_class = self.initialize_reader()
        # We get the dict of split names
        self.split_names = self.reader_class.split_names
        # We get the dataset download path
        self.download_path = self.reader_class.download_path
        # Re-root the reader's dataset path (everything after "processed_data") under DATA_PATH
        reader_path = self.reader_class.dataset_path
        split_reader = reader_path.split("processed_data")[1][1:]
        self.dataset_path = os.path.join(DATA_PATH, "processed_data", split_reader)
        # We load or compute the processed adata with patches.
        self.adata = self.load_or_compute_adata()

    def initialize_reader(self):
        """
        Instantiate the reader that matches ``self.dataset``.

        The first keyword found as a substring of the dataset name decides the reader class
        (the order below is significant); if none matches, ``VisiumReader`` is used.

        Returns:
            The reader instance built with the parameters of this class.
        """
        reader_by_keyword = (
            ('vicari', VicariReader),
            ('villacampa', VillacampaReader),
            ('mirzazadeh', MirzazadehReader),
            ('abalo', AbaloReader),
            ('erickson', EricksonReader),
            ('batiuk', BatiukReader),
            ('parigi', ParigiReader),
            ('fan', FanReader),
        )
        reader_type = next(
            (reader for keyword, reader in reader_by_keyword if keyword in self.dataset),
            VisiumReader,
        )
        reader_class = reader_type(
            dataset=self.dataset,
            param_dict=self.param_dict,
            patch_scale=self.patch_scale,
            patch_size=self.patch_size,
            force_compute=self.force_compute
        )
        return reader_class

    def _plot_and_publish_qc(self, collection_processed, collection_raw) -> None:
        """Run the QC plots, mirror them into the public database and create its README if missing."""
        # QC plotting (the copy below expects a 'qc_plots' folder inside self.dataset_path afterwards)
        plotting.plot_tests(self.patch_size, self.dataset, self.split_names, self.param_dict, self.dataset_path, collection_processed, collection_raw)

        # Copy figures folder into public database, replacing any previous copy
        public_dir = os.path.join(DATA_PATH, 'PublicDatabase', self.dataset)
        public_qc_dir = os.path.join(public_dir, 'qc_plots')
        os.makedirs(public_dir, exist_ok=True)
        if os.path.exists(public_qc_dir):
            shutil.rmtree(public_qc_dir)
        shutil.copytree(os.path.join(self.dataset_path, 'qc_plots'), public_qc_dir, dirs_exist_ok=True)

        # Create README for dataset (an existing README is never overwritten)
        readme_path = os.path.join(public_dir, 'README.md')
        if not os.path.exists(readme_path):
            shutil.copy(os.path.join(SPARED_PATH, 'PublicDatabase', 'README_template.md'), readme_path)

    def load_or_compute_adata(self) -> "ad.AnnData":
        """
        Run the main data pipeline.

        If ``adata.h5ad`` does not exist in ``self.dataset_path`` (or ``force_compute`` is set), the raw collection
        is obtained from the reader, filtered and processed; the raw and processed collections are written to
        ``adata_raw.h5ad`` / ``adata.h5ad`` and the class parameters to ``parameters.json``. Otherwise the processed
        file is loaded from disk. In both cases, when ``visualize`` is set, quality control plots are produced and
        copied into the public database.

        Returns:
            ad.AnnData: The processed AnnData object ready to be used for training.
        """
        # Keep only the class parameters for parameters.json (runtime-only attributes are left out)
        excluded_keys = ('reader_class', 'force_compute', 'visualize', 'download_path')
        curr_dict = {key: value for key, value in self.__dict__.items() if key not in excluded_keys}

        raw_path = os.path.join(self.dataset_path, 'adata_raw.h5ad')
        processed_path = os.path.join(self.dataset_path, 'adata.h5ad')

        # If processed data does not exist, then compute and save it
        if self.force_compute or not os.path.exists(processed_path):
            print('Computing main adata file from downloaded raw data...')
            collection_raw = self.reader_class.get_adata_collection()
            collection_filtered = filtering.filter_dataset(adata=collection_raw, param_dict=self.param_dict)

            # Process data
            collection_processed = layer_operations.process_dataset(
                adata=collection_filtered, param_dict=self.param_dict)

            # Save the processed data
            os.makedirs(self.dataset_path, exist_ok=True)
            collection_raw.write(raw_path)
            collection_processed.write(processed_path)

            # Save parameters
            with open(os.path.join(self.dataset_path, 'parameters.json'), 'w') as f:
                json.dump(curr_dict, f, sort_keys=True, indent=4)

            if self.visualize:
                self._plot_and_publish_qc(collection_processed, collection_raw)
        else:
            # Load processed adata
            print(f'Loading main adata file from disk ({processed_path})...')
            collection_processed = ad.read_h5ad(processed_path)

            print('The loaded adata object looks like this:')
            print(collection_processed)

            # QC plotting if visualize is set to True
            if self.visualize:
                # The raw collection is only needed for plotting, so it is read just once and only here
                collection_raw = ad.read_h5ad(raw_path)
                self._plot_and_publish_qc(collection_processed, collection_raw)

        return collection_processed
class HisToGeneDataset(Dataset):
    """Torch dataset that serves one whole slide per item.

    Each item is the tuple ``(patches, coordinates, expression, mask)`` of every
    spot whose ``obs.slide_id`` matches the slide at the requested index.
    """

    def __init__(self, adata, set_str):
        """
        Prepare normalized patches and the expression mask.

        Args:
            adata: AnnData with ``obs.split``, ``obs.slide_id``, ``obsm['patches_scale_1.0']`` and ``layers['tpm']``.
            set_str: Value of ``obs.split`` to keep. If ``None`` the complete ``adata`` is used
                     (and the new ``obsm``/``layers`` entries are written into the object that was passed in).
        """
        self.set = set_str
        if self.set is None:
            self.adata = adata
        else:
            self.adata = adata[adata.obs.split == self.set]

        # Map dataset indices to slide ids, in order of first appearance in obs
        self.idx_2_slide = dict(enumerate(self.adata.obs.slide_id.unique()))

        # Per-channel normalization with fixed mean/std
        self.transforms = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        tissue_tiles = self.adata.obsm['patches_scale_1.0']
        # Pass to torch tensor
        tissue_tiles = torch.from_numpy(tissue_tiles)
        # Recover the side length; assumes flattened square 3-channel patches (w * w * 3 values) -- TODO confirm
        w = round(np.sqrt(tissue_tiles.shape[1]/3))
        tissue_tiles = tissue_tiles.reshape((tissue_tiles.shape[0], w, w, -1))
        # Permute dimensions to channels-first, the order expected by the normalization
        tissue_tiles = tissue_tiles.permute(0,3,1,2).contiguous()
        # Scale pixel values by 1/255 before normalizing
        tissue_tiles = tissue_tiles/255.
        # Transform tiles
        tissue_tiles = self.transforms(tissue_tiles)
        # Flatten tiles and store them as numpy (reshape does not require a contiguous tensor)
        self.adata.obsm['patches_scale_1.0_transformed_numpy'] = tissue_tiles.reshape(tissue_tiles.shape[0], -1).numpy()
        # Define mask layer: True where the 'tpm' layer is non-zero
        self.adata.layers['mask'] = self.adata.layers['tpm'] != 0

    def __len__(self):
        """Return the number of slides (not the number of spots)."""
        return len(self.idx_2_slide)

    def __getitem__(self, idx):
        """Return ``(patch, coor, exp, mask)`` tensors for all spots of the slide at position ``idx``."""
        # Get the slide from the index
        slide = self.idx_2_slide[idx]
        # Get the adata of the slide
        adata_slide = self.adata[self.adata.obs.slide_id == slide]
        # Get the patches
        patch = torch.from_numpy(adata_slide.obsm['patches_scale_1.0_transformed_numpy'])
        # Get the coordinates
        coor = torch.from_numpy(adata_slide.obs[['array_row', 'array_col']].values)
        # Get the expression (X must support .toarray(), i.e. be a sparse matrix)
        exp = torch.from_numpy(adata_slide.X.toarray())
        # Get the mask
        mask = torch.from_numpy(adata_slide.layers['mask'])
        return patch, coor, exp, mask
def get_dataset(dataset_name: str, visualize: bool = True) -> SpatialDataset:
    """Get a dataset from name.

    This function receives the name of a dataset and retrieves the corresponding ``SpatialDataset``. It reads the
    configuration file ``configs/<dataset_name>.json`` predefined for each dataset in SpaRED. The name of the dataset
    should be one of the following:

        - ``10xgenomic_human_brain``
        - ``10xgenomic_human_breast_cancer``
        - ``10xgenomic_mouse_brain_coronal``
        - ``10xgenomic_mouse_brain_sagittal_anterior``
        - ``10xgenomic_mouse_brain_sagittal_posterior``
        - ``abalo_human_squamous_cell_carcinoma``
        - ``erickson_human_prostate_cancer_p1``
        - ``erickson_human_prostate_cancer_p2``
        - ``fan_mouse_brain_coronal``
        - ``fan_mouse_olfatory_bulb``
        - ``mirzazadeh_human_colon_p1``
        - ``mirzazadeh_human_colon_p2``
        - ``mirzazadeh_human_pediatric_brain_tumor_p1``
        - ``mirzazadeh_human_pediatric_brain_tumor_p2``
        - ``mirzazadeh_human_prostate_cancer``
        - ``mirzazadeh_human_small_intestine``
        - ``mirzazadeh_mouse_bone``
        - ``mirzazadeh_mouse_brain_p1``
        - ``mirzazadeh_mouse_brain_p2``
        - ``mirzazadeh_mouse_brain``
        - ``parigi_mouse_intestine``
        - ``vicari_human_striatium``
        - ``vicari_mouse_brain``
        - ``villacampa_kidney_organoid``
        - ``villacampa_lung_organoid``
        - ``villacampa_mouse_brain``

    After the first load of each dataset, the data will be stored and the next time the function is called it will load the data from the stored file much faster.

    Args:
        dataset_name (str): The name of the dataset.
        visualize (bool, optional): Whether to visualize the dataset or not. Can significantly increase run time of the command. Defaults to ``True``.

    Returns:
        SpatialDataset: The specified dataset in a ``SpatialDataset`` object.

    Raises:
        FileNotFoundError: If there is no config file for ``dataset_name``.
        KeyError: If the config lacks ``patch_scale``, ``patch_size`` or ``force_compute``.
    """
    # Get the name of the config based on the dataset
    config = os.path.join(SPARED_PATH, 'configs', f'{dataset_name}.json')

    # Load the config
    with open(config, 'r') as f:
        config_dict = json.load(f)

    # Extract the auxiliary variables for the dataset (required keys)
    patch_scale = config_dict.pop('patch_scale')
    patch_size = config_dict.pop('patch_size')
    force_compute = config_dict.pop('force_compute')

    # Refine config dict into a param dict: these keys are not processing parameters,
    # so they are discarded (and it is not an error if a config does not define them)
    for unused_key in ('dataset', 'n_hops', 'prediction_layer'):
        config_dict.pop(unused_key, None)

    # Declare the spatial dataset
    dataset = SpatialDataset(
        dataset=dataset_name,
        param_dict=config_dict,
        patch_scale=patch_scale,
        patch_size=patch_size,
        force_compute=force_compute,
        visualize=visualize,
    )
    return dataset
# Test code only for debugging
if __name__ == "__main__":

    def str2bool(x) -> bool:
        """Auxiliary function to use booleans in parser: only 'true' (any casing) maps to True."""
        return str(x).lower() == 'true'

    # Keys of a config file that make up the param dict of a SpatialDataset
    param_keys = (
        'cell_min_counts', 'cell_max_counts',
        'gene_min_counts', 'gene_max_counts',
        'min_exp_frac', 'min_glob_exp_frac',
        'real_data_percentage', 'top_moran_genes',
        'wildcard_genes', 'combat_key',
        'random_samples', 'plotting_slides',
        'plotting_genes',
    )

    def dataset_from_config(config_path):
        """Load a json config file and process the dataset it describes (stored as adata by SpatialDataset)."""
        # Load the config file
        with open(config_path, 'r') as f:
            config = json.load(f)
        # Define param dict
        param_dict = {key: config[key] for key in param_keys}
        # Process the dataset and store it as adata
        return SpatialDataset(
            dataset = config['dataset'],
            param_dict = param_dict,
            patch_scale = config['patch_scale'],
            patch_size = config['patch_size'],
            force_compute = config['force_compute']
        )

    # Define a simple parser and add an argument for the config file
    parser = argparse.ArgumentParser(description='Test code for datasets.')
    parser.add_argument('--config', type=str, default=os.path.join(SPARED_PATH, 'configs', '10xgenomic_human_breast_cancer.json'), help='Path to the config file.')
    parser.add_argument('--prepare_datasets', type=str2bool, default=False, help='If True then it processes all datasets.')
    parser.add_argument('--graphs_ie_paths', type=str, default='None', help='Path to the folders with optimal image encoder models to get graphs from. E.g. os.path.join("optimal_models", "spared_vit_backbone_c_d_deltas"')
    args = parser.parse_args()

    # Define dataset list
    dataset_list = ['10xgenomic_human_brain', '10xgenomic_human_breast_cancer',
                    '10xgenomic_mouse_brain_coronal', '10xgenomic_mouse_brain_sagittal_anterior',
                    '10xgenomic_mouse_brain_sagittal_posterior', 'abalo_human_squamous_cell_carcinoma',
                    'erickson_human_prostate_cancer_p1', 'erickson_human_prostate_cancer_p2',
                    'fan_mouse_brain_coronal', 'fan_mouse_olfatory_bulb',
                    'mirzazadeh_human_colon_p1', 'mirzazadeh_human_colon_p2',
                    'mirzazadeh_human_pediatric_brain_tumor_p1', 'mirzazadeh_human_pediatric_brain_tumor_p2',
                    'mirzazadeh_human_prostate_cancer', 'mirzazadeh_human_small_intestine',
                    'mirzazadeh_mouse_bone', 'mirzazadeh_mouse_brain_p1',
                    'mirzazadeh_mouse_brain_p2', 'mirzazadeh_mouse_brain',
                    'parigi_mouse_intestine', 'vicari_human_striatium',
                    'vicari_mouse_brain', 'villacampa_kidney_organoid',
                    'villacampa_lung_organoid', 'villacampa_mouse_brain']

    # If prepare datasets then run the dataset pipeline for all available datasets
    if args.prepare_datasets:
        # Iterate over the config files of every dataset
        for dset in dataset_list:
            curr_config_path = os.path.join(SPARED_PATH, 'configs', f'{dset}.json')
            test_dataset = dataset_from_config(curr_config_path)

    elif args.graphs_ie_paths != 'None':
        # The prediction layer is determined by the name of the models folder, so resolve it once up front
        layer_dict = {
            'spared_vit_backbone_c_d_deltas': 'c_d_deltas',
            'spared_vit_backbone_c_t_deltas': 'c_t_deltas',
            'spared_vit_backbone_noisy_d': 'noisy_d',
        }
        layer = layer_dict[os.path.basename(os.path.normpath(args.graphs_ie_paths))]

        # iterate over datasets
        for dset in dataset_list:
            # Get the dataset
            dataset = get_dataset(dset)
            # Get model path: first checkpoint found anywhere below the dataset folder
            checkpoint_paths = glob.glob(os.path.join(args.graphs_ie_paths, f'{dset}', '**', '*.ckpt'), recursive=True)
            if not checkpoint_paths:
                raise FileNotFoundError(f'No .ckpt file found under {os.path.join(args.graphs_ie_paths, dset)}')
            model_path = checkpoint_paths[0]
            # Get the graphs
            train_dl, val_dl, test_dl = dataloaders.get_graph_dataloaders(
                adata=dataset.adata, dataset_path=dataset.dataset_path, layer=layer, n_hops=3, backbone='ViT', model_path=model_path, batch_size=256, shuffle=False,
                hex_geometry=dataset.hex_geometry, patch_size=dataset.patch_size, patch_scale=dataset.patch_scale)

    # If prepare datasets and get graphs are false then only process the single dataset specified by the config arg
    else:
        test_dataset = dataset_from_config(args.config)