# Source code for spared.metrics.metrics

import numpy as np
import torch
import warnings
from typing import Union, Tuple
import warnings
from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from time import time

# NOTE(review): this call has no `message=`/`module=` filter, so importing this module silences
# every UserWarning in the whole process, not only the ones raised here. Presumably it is meant to
# hide the prototype-stage warnings emitted by torch.masked.MaskedTensor (used by the metric
# helpers) — TODO confirm and consider narrowing the filter.
warnings.filterwarnings(action='ignore', category=UserWarning)

def pearsonr_cols(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Compute the column-wise Pearson correlation between two matrices, ignoring masked entries.

    For every column i, the Pearson correlation between ``gt_mat[:, i]`` and ``pred_mat[:, i]`` is computed
    using only the rows where ``mask[:, i]`` is True. Columns without any valid entry are dropped from the
    output, and the mean is a NaN-aware average of the per-column correlations.

    NOTE: a column whose valid ground-truth (or prediction) entries are constant has zero norm after
    centering; ``cosine_similarity`` clamps norms with ``eps``, so such a column yields 0 instead of NaN.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_observations, n_variables).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_observations, n_variables).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_observations, n_variables).
            A 0/1 integer mask is also accepted and converted to bool.

    Returns:
        mean_pcc (float): Mean Pearson correlation computed by averaging the Pearson correlation of each column.
        detailed_pcc (list): Pearson correlation of each column that has at least one valid entry.
    """
    mask = mask.bool()
    invalid = ~mask

    # Masked column means (sum of valid entries / number of valid entries). Completely masked columns
    # give 0 / 0 = NaN. Plain tensor ops replace the prototype torch.masked API and the (1, n_variables)
    # means are broadcast instead of being repeated into full-size matrices.
    n_valid = mask.sum(dim=0, keepdim=True)
    gt_mean = gt_mat.masked_fill(invalid, 0.).sum(dim=0, keepdim=True) / n_valid
    pred_mean = pred_mat.masked_fill(invalid, 0.).sum(dim=0, keepdim=True) / n_valid

    # Columns that are completely masked (their mean is NaN)
    nan_columns = torch.isnan(gt_mean).squeeze(0)

    # Center with the masked means and zero the ignored entries so they contribute nothing to the
    # dot products / norms; then drop the completely masked columns.
    centered_gt_mat = (gt_mat - gt_mean).masked_fill(invalid, 0.)[:, ~nan_columns]
    centered_pred_mat = (pred_mat - pred_mean).masked_fill(invalid, 0.)[:, ~nan_columns]

    # Pearson correlation == cosine similarity of the centered columns
    pcc = torch.nn.functional.cosine_similarity(centered_gt_mat, centered_pred_mat, dim=0)

    # NaN-aware mean so a NaN correlation in one column does not poison the average
    mean_pcc = pcc.nanmean().item()
    detailed_pcc = pcc.tolist()

    return mean_pcc, detailed_pcc

def pearsonr_gene(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Gene-wise Pearson correlation between the ground truth and the predictions.

    Genes are stored along the columns, so each gene gets one Pearson correlation computed by
    ``pearsonr_cols`` over its unmasked samples, and the per-gene values are then averaged.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_samples, n_genes).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_samples, n_genes).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_samples, n_genes).

    Returns:
        mean_pcc (float): Average of the per-gene Pearson correlations.
        detailed_pcc (list): Pearson correlation of every gene that is not completely masked.
    """
    # Columns already correspond to genes, so the inputs are forwarded unchanged
    mean_pcc, per_gene_pcc = pearsonr_cols(gt_mat, pred_mat, mask)
    return mean_pcc, per_gene_pcc

def pearsonr_patch(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Patch-wise Pearson correlation between the ground truth and the predictions.

    Patches are the rows of the inputs. The matrices are transposed so that every patch becomes a
    column, and ``pearsonr_cols`` then yields one correlation per patch, averaged into a single value.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_samples, n_genes).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_samples, n_genes).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_samples, n_genes).

    Returns:
        mean_pcc (float): Average of the per-patch Pearson correlations.
        detailed_pcc (list): Pearson correlation of every patch that is not completely masked.
    """
    # Turn rows (patches) into columns so that pearsonr_cols works per patch
    gt_by_patch = gt_mat.T
    pred_by_patch = pred_mat.T
    mask_by_patch = mask.T
    mean_pcc, per_patch_pcc = pearsonr_cols(gt_by_patch, pred_by_patch, mask_by_patch)
    return mean_pcc, per_patch_pcc

def r2_score_cols(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Compute the column-wise R2 score between two matrices, ignoring masked entries.

    For every column i, R2 = 1 - RSS / TSS is computed between ``gt_mat[:, i]`` and ``pred_mat[:, i]`` using
    only the rows where ``mask[:, i]`` is True. Columns with exactly one valid entry (zero total variance) and
    completely masked columns are dropped. The mean is a plain (not NaN-aware) average of the per-column scores.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_observations, n_variables).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_observations, n_variables).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_observations, n_variables).
            A 0/1 integer mask is also accepted and converted to bool.

    Returns:
        mean_r2_score (float): Mean R2 score computed by averaging the R2 score for each kept column.
        detailed_r2_score (list): R2 score of each kept column.
    """
    mask = mask.bool()

    # Remove columns with a single valid value (their total sum of squares is 0, which makes R2 diverge)
    keep = mask.sum(dim=0) != 1
    gt_mat, pred_mat, mask = gt_mat[:, keep], pred_mat[:, keep], mask[:, keep]
    invalid = ~mask

    # Masked column means of the ground truth. Plain tensor ops replace the prototype torch.masked API;
    # completely masked columns give 0 / 0 = NaN.
    n_valid = mask.sum(dim=0)
    gt_col_means = gt_mat.masked_fill(invalid, 0.).sum(dim=0) / n_valid

    # Columns that are completely masked
    nan_columns = torch.isnan(gt_col_means)

    # Total and residual sums of squares over the valid entries of each column
    total_sum_squares = torch.square(gt_mat - gt_col_means).masked_fill(invalid, 0.).sum(dim=0)
    residual_sum_squares = torch.square(gt_mat - pred_mat).masked_fill(invalid, 0.).sum(dim=0)

    # Drop completely masked columns
    total_sum_squares = total_sum_squares[~nan_columns]
    residual_sum_squares = residual_sum_squares[~nan_columns]

    # R2 score for each column and its mean
    r2_scores = 1. - (residual_sum_squares / total_sum_squares)
    mean_r2_score = r2_scores.mean().item()
    detailed_r2_score = r2_scores.tolist()

    return mean_r2_score, detailed_r2_score

def r2_score_gene(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Gene-wise R2 score between the ground truth and the predictions.

    Genes are the columns of the inputs, so ``r2_score_cols`` is applied directly: one R2 score is
    computed per gene over its unmasked samples and the scores are averaged.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_samples, n_genes).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_samples, n_genes).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_samples, n_genes).

    Returns:
        mean_r2_score (float): Average of the per-gene R2 scores.
        detailed_r2_score (list): R2 score of every gene kept by ``r2_score_cols``.
    """
    # No reshaping needed: genes already lie along the columns
    mean_r2, per_gene_r2 = r2_score_cols(gt_mat, pred_mat, mask)
    return mean_r2, per_gene_r2

def r2_score_patch(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor) -> Tuple[float, list]:
    """
    Patch-wise R2 score between the ground truth and the predictions.

    Patches are the rows of the inputs. After transposing, every patch is a column, so
    ``r2_score_cols`` returns one R2 score per patch together with their average.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_samples, n_genes).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_samples, n_genes).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_samples, n_genes).

    Returns:
        mean_r2_score (float): Average of the per-patch R2 scores.
        detailed_r2_score (list): R2 score of every patch kept by ``r2_score_cols``.
    """
    # Rows (patches) become columns for r2_score_cols
    gt_by_patch = gt_mat.T
    pred_by_patch = pred_mat.T
    mask_by_patch = mask.T
    mean_r2, per_patch_r2 = r2_score_cols(gt_by_patch, pred_by_patch, mask_by_patch)
    return mean_r2, per_patch_r2

def get_pearsonr(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor, axis: int) -> Tuple[float, list]:
    """
    Compute the mean Pearson correlation between two matrices along columns (axis=0) or rows (axis=1).

    For axis=0 the correlation is computed between the i-th columns of both matrices; for axis=1 between
    the i-th rows. Only the entries where ``mask`` is True are used. Columns/rows without any valid entry
    are dropped from the output, and the mean is a NaN-aware average.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_observations, n_variables).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_observations, n_variables).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_observations, n_variables).
            A 0/1 integer mask is also accepted and converted to bool.
        axis (int): Whether to compute the pcc by columns (axis=0) or by rows (axis=1).

    Returns:
        mean_pcc (float): Mean of the per-column (or per-row) Pearson correlations.
        detailed_pcc (list): Pearson correlation of each column (or row) that has at least one valid entry.

    Raises:
        ValueError: If ``axis`` is not 0 or 1.
    """
    # Previously any other axis left a MaskedTensor in place and failed obscurely later on
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 (columns) or 1 (rows), got {axis!r}.")

    mask = mask.bool()
    # Row-wise correlations are the column-wise correlations of the transposed matrices
    if axis == 1:
        gt_mat, pred_mat, mask = gt_mat.T, pred_mat.T, mask.T
    invalid = ~mask

    # Masked column means; completely masked columns give 0 / 0 = NaN. The (1, n) means are broadcast
    # instead of being repeated into full-size matrices.
    n_valid = mask.sum(dim=0, keepdim=True)
    gt_mean = gt_mat.masked_fill(invalid, 0.).sum(dim=0, keepdim=True) / n_valid
    pred_mean = pred_mat.masked_fill(invalid, 0.).sum(dim=0, keepdim=True) / n_valid

    # Columns (original rows when axis == 1) that are completely masked
    nan_vectors = torch.isnan(gt_mean).squeeze(0)

    # Center, zero the ignored entries so they do not contribute, and drop completely masked vectors
    centered_gt_mat = (gt_mat - gt_mean).masked_fill(invalid, 0.)[:, ~nan_vectors]
    centered_pred_mat = (pred_mat - pred_mean).masked_fill(invalid, 0.)[:, ~nan_vectors]

    # Pearson correlation == cosine similarity of the centered vectors
    pcc = torch.nn.functional.cosine_similarity(centered_gt_mat, centered_pred_mat, dim=0)

    # NaN-aware mean so a single undefined correlation does not poison the average
    mean_pcc = pcc.nanmean().item()
    detailed_pcc = pcc.tolist()

    return mean_pcc, detailed_pcc
def get_r2_score(gt_mat: torch.Tensor, pred_mat: torch.Tensor, mask: torch.Tensor, axis: int = 0) -> Tuple[float, list]:
    """
    Compute the mean R2 score between two matrices along columns (axis=0) or rows (axis=1).

    For axis=0 the R2 score is computed between the i-th columns of both matrices; for axis=1 between the
    i-th rows. Only entries where ``mask`` is True are used. Columns/rows with exactly one valid entry
    (zero total variance) and completely masked ones are dropped. The mean is a plain (not NaN-aware) average.

    Args:
        gt_mat (torch.Tensor): Ground truth matrix of shape (n_observations, n_variables).
        pred_mat (torch.Tensor): Predicted matrix of shape (n_observations, n_variables).
        mask (torch.Tensor): Boolean mask with False in positions that must be ignored in metric computation (n_observations, n_variables).
            A 0/1 integer mask is also accepted and converted to bool.
        axis (int): Whether to compute the R2 score by columns (axis=0, default) or by rows (axis=1).
            The previous signature used ``axis=int`` (the class itself as default), so calls without
            ``axis`` always crashed; they now compute column-wise scores.

    Returns:
        mean_r2_score (float): Mean R2 score computed by averaging the R2 score of each kept column (or row).
        detailed_r2_score (list): R2 score of each kept column (or row).

    Raises:
        ValueError: If ``axis`` is not 0 or 1.
    """
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 (columns) or 1 (rows), got {axis!r}.")

    mask = mask.bool()
    # Row-wise scores are the column-wise scores of the transposed matrices
    if axis == 1:
        gt_mat, pred_mat, mask = gt_mat.T, pred_mat.T, mask.T

    # Remove vectors with a single valid value (their total sum of squares is 0, which makes R2 diverge)
    keep = mask.sum(dim=0) != 1
    gt_mat, pred_mat, mask = gt_mat[:, keep], pred_mat[:, keep], mask[:, keep]
    invalid = ~mask

    # Masked means of the ground truth; completely masked vectors give 0 / 0 = NaN
    n_valid = mask.sum(dim=0)
    gt_means = gt_mat.masked_fill(invalid, 0.).sum(dim=0) / n_valid
    nan_vectors = torch.isnan(gt_means)

    # Total and residual sums of squares over the valid entries, dropping completely masked vectors
    total_sum_squares = torch.square(gt_mat - gt_means).masked_fill(invalid, 0.).sum(dim=0)[~nan_vectors]
    residual_sum_squares = torch.square(gt_mat - pred_mat).masked_fill(invalid, 0.).sum(dim=0)[~nan_vectors]

    # R2 score for each row or column and its mean
    r2_scores = 1. - (residual_sum_squares / total_sum_squares)
    mean_r2_score = r2_scores.mean().item()
    detailed_r2_score = r2_scores.tolist()

    return mean_r2_score, detailed_r2_score
def get_metrics(gt_mat: Union[np.ndarray, torch.Tensor], pred_mat: Union[np.ndarray, torch.Tensor], mask: Union[np.ndarray, torch.Tensor], detailed: bool = False) -> dict:
    """
    Get general regression metrics.

    This function receives 2 matrices of shapes (n_samples, n_genes) and computes the following metrics:

        - Pearson correlation (gene-wise) [PCC-Gene]
        - Pearson correlation (patch-wise) [PCC-Patch]
        - r2 score (gene-wise) [R2-Gene]
        - r2 score (patch-wise) [R2-Patch]
        - Mean squared error [MSE]
        - Mean absolute error [MAE]
        - Global metric [Global] (Global = PCC-Gene + R2-Gene + PCC-Patch + R2-Patch - MAE - MSE)

    If any predicted gene (column) / patch (row) is constant, the gene-wise / patch-wise Pearson correlation
    and its detailed value are set to NaN (and therefore Global is NaN too). A message is printed in that case.

    If detailed == True, the function also returns these additional keys (lists of floats, or NaN for a
    skipped Pearson correlation):

        - Individual Pearson correlation for every gene [detailed_PCC-Gene]
        - Individual Pearson correlation for every patch [detailed_PCC-Patch]
        - Individual r2 score for every gene [detailed_R2-Gene]
        - Individual r2 score for every patch [detailed_R2-Patch]
        - Individual MSE for every gene [detailed_mse_gene]
        - Individual MAE for every gene [detailed_mae_gene]
        - Individual average error (prediction - ground truth) for every gene [detailed_error_gene]
        - Flat list of all errors in valid positions [detailed_errors]

    Args:
        gt_mat (Union[np.ndarray, torch.Tensor]): Ground truth matrix of shape (n_samples, n_genes).
        pred_mat (Union[np.ndarray, torch.Tensor]): Predicted matrix of shape (n_samples, n_genes).
        mask (Union[np.ndarray, torch.Tensor]): Boolean mask with False in positions that must be ignored in metric computation (n_samples, n_genes).
            A 0/1 integer mask is converted to bool.
        detailed (bool): If True, the dictionary also includes the detailed metrics.

    Returns:
        dict: Dictionary containing the metrics computed. The keys are: ['PCC-Gene', 'PCC-Patch', 'R2-Gene', 'R2-Patch', 'MSE', 'MAE', 'Global']
            plus the detailed keys listed above when ``detailed`` is True.
    """
    # Assert that all matrices have the same shape
    assert gt_mat.shape == pred_mat.shape, "gt_mat and pred_mat matrices must have the same shape."
    assert gt_mat.shape == mask.shape, "gt_mat and mask matrices must have the same shape."

    # If inputs are numpy arrays, convert them to torch tensors
    if isinstance(gt_mat, np.ndarray):
        gt_mat = torch.from_numpy(gt_mat)
    if isinstance(pred_mat, np.ndarray):
        pred_mat = torch.from_numpy(pred_mat)
    if isinstance(mask, np.ndarray):
        mask = torch.from_numpy(mask)
    # Boolean indexing below (gt_mat[mask], errors[~mask]) needs a bool mask: a 0/1 integer mask would
    # silently be interpreted as integer indices. No-op for a mask that is already bool.
    mask = mask.bool()

    # Constant predicted columns (genes) / rows (patches) make the Pearson correlation undefined
    constant_cols = torch.all(pred_mat == pred_mat[[0], :], axis=0)
    constant_rows = torch.all(pred_mat == pred_mat[:, [0]], axis=1)

    # If there are any constant columns, set pcc_g (and its detailed value) to NaN
    if torch.any(constant_cols):
        pcc_g, detailed_pcc_g = np.nan, np.nan
        print(f"There are {constant_cols.sum().item()} constant columns in the predicted matrix. This means a gene is being predicted as constant. The Pearson correlation (gene-wise) will be set to NaN.")
    else:
        # Compute Pearson correlation (gene-wise)
        pcc_g, detailed_pcc_g = pearsonr_gene(gt_mat, pred_mat, mask=mask)

    # If there are any constant rows, set pcc_p (and its detailed value) to NaN
    if torch.any(constant_rows):
        pcc_p, detailed_pcc_p = np.nan, np.nan
        print(f"There are {constant_rows.sum().item()} constant rows in the predicted matrix. This means a patch is being predicted as constant. The Pearson correlation (patch-wise) will be set to NaN.")
    else:
        # Compute Pearson correlation (patch-wise)
        pcc_p, detailed_pcc_p = pearsonr_patch(gt_mat, pred_mat, mask=mask)

    # Compute r2 score (gene-wise and patch-wise)
    r2_g, detailed_r2_g = r2_score_gene(gt_mat, pred_mat, mask=mask)
    r2_p, detailed_r2_p = r2_score_patch(gt_mat, pred_mat, mask=mask)

    # Compute mean squared error and mean absolute error over the valid entries
    mse = torch.nn.functional.mse_loss(gt_mat[mask], pred_mat[mask], reduction='mean').item()
    mae = torch.nn.functional.l1_loss(gt_mat[mask], pred_mat[mask], reduction='mean').item()

    # Create dictionary with the metrics computed
    metrics_dict = {
        'PCC-Gene': pcc_g,
        'PCC-Patch': pcc_p,
        'R2-Gene': r2_g,
        'R2-Patch': r2_p,
        'MSE': mse,
        'MAE': mae,
        'Global': pcc_g + pcc_p + r2_g + r2_p - mse - mae,
    }

    if detailed:
        # Detailed error metrics (only at gene level because patch level usually gives empty patches).
        # Computed only on request: errors[mask].tolist() alone creates one Python float per valid entry.
        errors = pred_mat - gt_mat
        errors[~mask] = torch.nan
        sq_errors = torch.square(errors)

        metrics_dict.update({
            'detailed_PCC-Gene': detailed_pcc_g,
            'detailed_PCC-Patch': detailed_pcc_p,
            'detailed_R2-Gene': detailed_r2_g,
            'detailed_R2-Patch': detailed_r2_p,
            'detailed_mse_gene': sq_errors.nanmean(dim=0).tolist(),
            'detailed_mae_gene': torch.abs(errors).nanmean(dim=0).tolist(),
            'detailed_error_gene': errors.nanmean(dim=0).tolist(),
            'detailed_errors': errors[mask].tolist(),
        })

    return metrics_dict
def _slow_mean_metric(metric_fn, gt_np: np.ndarray, pred_np: np.ndarray, mask_np: np.ndarray) -> float:
    """
    Average ``metric_fn(gt_column, pred_column)`` over the columns of the inputs, using for every
    column only its valid (mask == True) entries.

    Columns with fewer than two valid entries are skipped, because neither a correlation nor an R2
    score can be computed from a single point. NaN scores are ignored in the average.
    """
    scores = [
        metric_fn(gt_col[valid], pred_col[valid])
        for gt_col, pred_col, valid in zip(gt_np.T, pred_np.T, mask_np.T)
        if valid.sum() > 1
    ]
    return float(np.nanmean(scores)) if scores else float('nan')


def slow_get_metrics(gt_mat, pred_mat, mask) -> dict:
    """
    Slow but straightforward reference implementation of the aggregate metrics returned by ``get_metrics``.

    It loops over genes (columns) and patches (rows) and relies on ``scipy.stats.pearsonr`` and
    ``sklearn.metrics.r2_score``, so it is only meant to sanity-check the vectorized implementation.
    Edge cases (constant predictions, vectors with a single valid entry) are not handled exactly like in
    ``get_metrics``.

    Args:
        gt_mat: Ground truth matrix of shape (n_samples, n_genes) (numpy array or CPU torch tensor).
        pred_mat: Predicted matrix of shape (n_samples, n_genes).
        mask: Boolean mask with False in positions that must be ignored (n_samples, n_genes).

    Returns:
        dict: Same aggregate keys as ``get_metrics`` (no detailed keys).
    """
    gt_np = np.asarray(gt_mat, dtype=np.float64)
    pred_np = np.asarray(pred_mat, dtype=np.float64)
    mask_np = np.asarray(mask, dtype=bool)

    def _pcc(y, y_hat):
        # Element 0 of pearsonr's result is the correlation coefficient
        return pearsonr(y, y_hat)[0]

    # Gene-wise metrics iterate over columns; patch-wise metrics over rows (columns of the transpose)
    pcc_g = _slow_mean_metric(_pcc, gt_np, pred_np, mask_np)
    pcc_p = _slow_mean_metric(_pcc, gt_np.T, pred_np.T, mask_np.T)
    r2_g = _slow_mean_metric(r2_score, gt_np, pred_np, mask_np)
    r2_p = _slow_mean_metric(r2_score, gt_np.T, pred_np.T, mask_np.T)

    valid_errors = pred_np[mask_np] - gt_np[mask_np]
    mse = float(np.mean(np.square(valid_errors)))
    mae = float(np.mean(np.abs(valid_errors)))

    return {
        'PCC-Gene': pcc_g,
        'PCC-Patch': pcc_p,
        'R2-Gene': r2_g,
        'R2-Patch': r2_p,
        'MSE': mse,
        'MAE': mae,
        'Global': pcc_g + pcc_p + r2_g + r2_p - mse - mae,
    }


# Here we have some testing code
if __name__ == '__main__':
    # Set number of observations and genes (hypothetical)
    obs = 7777
    genes = 256
    imputed_fraction = 0.26  # Fraction of False (ignored) entries in the mask

    # Generate random matrices
    pred = torch.randn((obs, genes))
    gt = torch.randn((obs, genes))
    mask = torch.rand((obs, genes)) > imputed_fraction

    # Compute metrics with the fast way (efficient implementation)
    print('Fast metrics'+'-'*40)
    start = time()
    test_metrics = get_metrics(gt, pred, mask=mask)
    print("Time taken: {:5.2f}s".format(time()-start))
    for key, val in test_metrics.items():
        print("{} = {:5.7f}".format(key, val))

    # Compute metrics with the slow way (inefficient implementation but secure)
    print('Slow metrics'+'-'*40)
    start = time()
    slow_test_metrics = slow_get_metrics(gt, pred, mask=mask)
    print("Time taken: {:5.2f}s".format(time()-start))
    for key, val in slow_test_metrics.items():
        print("{} = {:5.7f}".format(key, val))