import torch
import torch.nn as nn
from torchvision import models
import numpy as np
import lightning as L
from torchvision.transforms import Compose, RandomApply, RandomHorizontalFlip, RandomRotation, RandomVerticalFlip, Normalize
from metrics import get_metrics
[docs]class ImageEncoder(torch.nn.Module):
[docs] def __init__(self, backbone, use_pretrained, latent_dim):
super(ImageEncoder, self).__init__()
self.backbone = backbone
self.use_pretrained = use_pretrained
self.latent_dim = latent_dim
# Initialize the model using various options
self.encoder, self.input_size = self.initialize_model()
def initialize_model(self):
# Initialize these variables which will be set in this if statement. Each of these
# variables is model specific.
model_ft = None
model_weights = 'IMAGENET1K_V1' if self.use_pretrained else None
input_size = 0
if self.backbone == "resnet": ##
""" Resnet18 acc@1 (on ImageNet-1K): 69.758
"""
model_ft = models.resnet18(weights=model_weights) #Get model
num_ftrs = model_ft.fc.in_features #Get in features of the fc layer (final layer)
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim) #Keep in features, but modify out features for self.latent_dim
input_size = 224 #Set input size of each image
elif self.backbone == "resnet50":
""" Resnet50 acc@1 (on ImageNet-1K): 76.13
"""
model_ft = models.resnet50(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ConvNeXt":
""" ConvNeXt tiny acc@1 (on ImageNet-1K): 82.52
"""
model_ft = models.convnext_tiny(weights=model_weights)
num_ftrs = model_ft.classifier[2].in_features
model_ft.classifier[2] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "EfficientNetV2":
""" EfficientNetV2 small acc@1 (on ImageNet-1K): 84.228
"""
model_ft = models.efficientnet_v2_s(weights=model_weights)
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 384
elif self.backbone == "InceptionV3":
""" InceptionV3 acc@1 (on ImageNet-1K): 77.294
"""
model_ft = models.inception_v3(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 299
elif self.backbone == "MaxVit":
""" MaxVit acc@1 (on ImageNet-1K): 83.7
"""
model_ft = models.maxvit_t(weights=model_weights)
num_ftrs = model_ft.classifier[5].in_features
model_ft.classifier[5] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "MobileNetV3":
""" MobileNet V3 acc@1 (on ImageNet-1K): 67.668
"""
model_ft = models.mobilenet_v3_small(weights=model_weights)
num_ftrs = model_ft.classifier[3].in_features
model_ft.classifier[3] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ResNetXt":
""" ResNeXt-50 32x4d acc@1 (on ImageNet-1K): 77.618
"""
model_ft = models.resnext50_32x4d(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ShuffleNetV2":
""" ShuffleNetV2 acc@1 (on ImageNet-1K): 60.552
"""
model_ft = models.shufflenet_v2_x0_5(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ViT":
""" Vision Transformer acc@1 (on ImageNet-1K): 81.072
"""
model_ft = models.vit_b_16(weights=model_weights)
num_ftrs = model_ft.heads.head.in_features
model_ft.heads.head = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "WideResNet":
""" Wide ResNet acc@1 (on ImageNet-1K): 78.468
"""
model_ft = models.wide_resnet50_2(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "densenet":
""" Densenet acc@1 (on ImageNet-1K): 74.434
"""
model_ft = models.densenet121(weights=model_weights)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "swin":
""" Swin Transformer tiny acc@1 (on ImageNet-1K): 81.474
"""
model_ft = models.swin_t(weights=model_weights)
num_ftrs = model_ft.head.in_features
model_ft.head = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
else:
print("Invalid model name, exiting...")
exit()
return model_ft, input_size
def forward(self, tissue_tiles):
latent_space = self.encoder(tissue_tiles)
return latent_space
class ImageBackbone(L.LightningModule):
def __init__(self, args, latent_dim):
super(ImageBackbone, self).__init__()
# Define normal hyperparameters
self.save_hyperparameters()
self.args = args
self.backbone = args.img_backbone
self.use_pretrained = args.img_use_pretrained
self.latent_dim = latent_dim
# Define image transformations
self.train_transforms = Compose([Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
RandomHorizontalFlip(p=0.5),
RandomVerticalFlip(p=0.5),
RandomApply([RandomRotation((90, 90))], p=0.5)])
if args.average_test:
self.test_transforms = Compose([Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), EightSymmetry()])
else:
self.test_transforms = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Define loss criterion
self.criterion = torch.nn.MSELoss()
# Initialize the model using various options
self.encoder, self.input_size = self.initialize_model()
# Define outputs of the validation, test and train step
self.validation_step_outputs = []
self.training_step_outputs = []
self.test_step_outputs = []
# Auxiliary variables to log best metrics
self.best_metrics = None
min_max_metric_dict = {'PCC-Gene': 'max', 'PCC-Patch': 'max', 'MSE': 'min', 'MAE': 'min', 'R2-Gene': 'max', 'R2-Patch': 'max', 'Global': 'max'}
self.metric_objective = min_max_metric_dict[self.args.optim_metric]
def initialize_model(self):
# Initialize these variables which will be set in this if statement. Each of these
# variables is model specific.
model_ft = None
model_weights = 'IMAGENET1K_V1' if self.use_pretrained else None
input_size = 0
if self.backbone == "resnet": ##
""" Resnet18 acc@1 (on ImageNet-1K): 69.758
"""
model_ft = models.resnet18(weights=model_weights) #Get model
num_ftrs = model_ft.fc.in_features #Get in features of the fc layer (final layer)
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim) #Keep in features, but modify out features for self.latent_dim
input_size = 224 #Set input size of each image
elif self.backbone == "ConvNeXt":
""" ConvNeXt tiny acc@1 (on ImageNet-1K): 82.52
"""
model_ft = models.convnext_tiny(weights=model_weights)
num_ftrs = model_ft.classifier[2].in_features
model_ft.classifier[2] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "EfficientNetV2":
""" EfficientNetV2 small acc@1 (on ImageNet-1K): 84.228
"""
model_ft = models.efficientnet_v2_s(weights=model_weights)
num_ftrs = model_ft.classifier[1].in_features
model_ft.classifier[1] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 384
elif self.backbone == "InceptionV3":
""" InceptionV3 acc@1 (on ImageNet-1K): 77.294
"""
model_ft = models.inception_v3(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 299
elif self.backbone == "MaxVit":
""" MaxVit acc@1 (on ImageNet-1K): 83.7
"""
model_ft = models.maxvit_t(weights=model_weights)
num_ftrs = model_ft.classifier[5].in_features
model_ft.classifier[5] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "MobileNetV3":
""" MobileNet V3 acc@1 (on ImageNet-1K): 67.668
"""
model_ft = models.mobilenet_v3_small(weights=model_weights)
num_ftrs = model_ft.classifier[3].in_features
model_ft.classifier[3] = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ResNetXt":
""" ResNeXt-50 32x4d acc@1 (on ImageNet-1K): 77.618
"""
model_ft = models.resnext50_32x4d(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ShuffleNetV2":
""" ShuffleNetV2 acc@1 (on ImageNet-1K): 60.552
"""
model_ft = models.shufflenet_v2_x0_5(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "ViT":
""" Vision Transformer acc@1 (on ImageNet-1K): 81.072
"""
model_ft = models.vit_b_16(weights=model_weights)
num_ftrs = model_ft.heads.head.in_features
model_ft.heads.head = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "WideResNet":
""" Wide ResNet acc@1 (on ImageNet-1K): 78.468
"""
model_ft = models.wide_resnet50_2(weights=model_weights)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "densenet":
""" Densenet acc@1 (on ImageNet-1K): 74.434
"""
model_ft = models.densenet121(weights=model_weights)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
elif self.backbone == "swin":
""" Swin Transformer tiny acc@1 (on ImageNet-1K): 81.474
"""
model_ft = models.swin_t(weights=model_weights)
num_ftrs = model_ft.head.in_features
model_ft.head = nn.Linear(num_ftrs, self.latent_dim)
input_size = 224
# elif self.backbone == "densenet121-kather100k":
# """ Densenet F1 (Kather-100k): 0.993
# """
# histo_model = PatchPredictor(pretrained_model="densenet121-kather100k", batch_size=self.args.batch_size).model
# num_ftrs = histo_model.classifier.in_features
# histo_model.classifier = nn.Linear(num_ftrs, self.latent_dim)
# histo_state_dict = histo_model.state_dict()
# # Replace some keys that are modified by the TIA toolbox
# histo_state_dict = {k.replace('feat_extract', 'features'): v for k, v in histo_state_dict.items()}
# model_ft = models.densenet121()
# model_ft.classifier = nn.Linear(num_ftrs, self.latent_dim)
# model_ft.load_state_dict(histo_state_dict)
# # NOTE This is for linear probing
# # # Freeze all layers except the last one
# # for param in model_ft.parameters():
# # param.requires_grad = False
# # for param in model_ft.classifier.parameters():
# # param.requires_grad = True
# input_size = 224
# else:
# print("Invalid model name, exiting...")
# exit()
return model_ft, input_size
def forward(self, tissue_tiles):
latent_space = self.encoder(tissue_tiles)
return latent_space
def find_batch_patch_key(self, batch):
# Find the key of dataset.obsm that contains the patches
patch_key = [k for k in batch.obsm.keys() if 'patches' in k]
# Assert that there is only one key
assert len(patch_key) == 1, 'There should be only one key with patches in data.obsm'
patch_key = patch_key[0]
return patch_key
def pred_outputs_from_batch(self, batch):
# Get the patch key in the batch
patch_key = self.find_batch_patch_key(batch)
# Get (and reshape) images from batch
tissue_tiles = batch.obsm[patch_key]
w = round(np.sqrt(tissue_tiles.shape[1]/3))
tissue_tiles = tissue_tiles.reshape((tissue_tiles.shape[0], w, w, -1))
# Permute dimensions to be in correct order for normalization
tissue_tiles = tissue_tiles.permute(0,3,1,2).contiguous()
# Make transformations in tissue tiles
tissue_tiles = tissue_tiles/255.
# Transform tiles
tissue_tiles = self.test_transforms(tissue_tiles)
# Get groundtruth of expression
expression_gt = batch.X
# Get output of the model
# If tissue tiles is tuple then we will compute outputs of the 8 symmetries and then average them for prediction
if isinstance(tissue_tiles, tuple):
pred_list = [self.forward(tissue_rot) for tissue_rot in tissue_tiles]
pred_stack = torch.stack(pred_list)
expression_pred = pred_stack.mean(dim=0)
# If tissue tiles is not tuple then a single prediction is done with the original image
else:
expression_pred = self.forward(tissue_tiles)
# Handle delta vs absolute prediction with means
# If the adata object has a used mean attribute then we will use it to unnormalize the data
general_adata = batch.adatas[0]
if 'used_mean' in general_adata.var.keys():
means = general_adata.var['used_mean'].values
# Pass means to torch tensor in the same device as the model
means = torch.tensor(means, device=expression_gt.device)
# Unnormalize data and predictions
expression_gt = expression_gt+means
expression_pred = expression_pred+means
# Get boolean mask
mask = torch.Tensor(batch.layers['mask']).to(expression_gt.device).bool()
return expression_pred, expression_gt, mask
def training_step(self, batch):
# Get the patch key in the batch
# FIXME: Automate this with if self.glob_step == 0: or something simmilar
patch_key = self.find_batch_patch_key(batch)
# Get (and reshape) images from batch
tissue_tiles = batch.obsm[patch_key]
w = round(np.sqrt(tissue_tiles.shape[1]/3))
tissue_tiles = tissue_tiles.reshape((tissue_tiles.shape[0], w, w, -1))
# Permute dimensions to be in correct order for normalization
tissue_tiles = tissue_tiles.permute(0,3,1,2).contiguous()
# Make transformations in tissue tiles
tissue_tiles = tissue_tiles/255.
# Transform tiles
tissue_tiles = self.train_transforms(tissue_tiles)
# Get groundtruth of expression
expression_gt = batch.X
# Get boolean mask
mask = torch.Tensor(batch.layers['mask']).to(expression_gt.device).bool()
# Get output of the model
expression_pred = self.forward(tissue_tiles)
# Compute expression MSE loss (handle case to ignore zeros)
if self.args.robust_loss == True:
real_gt, real_pred = expression_gt[mask], expression_pred[mask]
loss = self.criterion(real_gt, real_pred)
else:
loss = self.criterion(expression_gt, expression_pred)
train_log_dict = {'train_loss': loss}
self.log_dict(train_log_dict, on_step=True)
# Append train step outputs
self.training_step_outputs.append((expression_pred, expression_gt, mask))
return loss
def on_train_epoch_end(self):
# Unpack the list of tuples
glob_expression_pred, glob_expression_gt, glob_mask = zip(*self.training_step_outputs)
# Concatenate outputs along the sample dimension
glob_expression_pred, glob_expression_gt, glob_mask = torch.cat(glob_expression_pred), torch.cat(glob_expression_gt), torch.cat(glob_mask)
# Get metrics and log
metrics = get_metrics(glob_expression_gt, glob_expression_pred, glob_mask)
# Put train prefix in metric dict
metrics = {f'train_{k}': v for k, v in metrics.items()}
self.log_dict(metrics, on_epoch=True)
# Free memory
self.training_step_outputs.clear()
def validation_step(self, batch):
# Get the outputs from the batch with generalistic function
expression_pred, expression_gt, mask = self.pred_outputs_from_batch(batch)
# Append validation step outputs
self.validation_step_outputs.append((expression_pred, expression_gt, mask))
return expression_pred, expression_gt, mask
def on_validation_epoch_end(self):
if self.trainer.sanity_checking:
# Free memory
self.validation_step_outputs.clear()
return
else:
# Unpack the list of tuples
glob_expression_pred, glob_expression_gt, glob_mask = zip(*self.validation_step_outputs)
# Concatenate outputs along the sample dimension
glob_expression_pred, glob_expression_gt, glob_mask = torch.cat(glob_expression_pred), torch.cat(glob_expression_gt), torch.cat(glob_mask)
# Get metrics and log
metrics = get_metrics(glob_expression_gt, glob_expression_pred, glob_mask)
# Put val prefix in metric dict
metrics = {f'val_{k}': v for k, v in metrics.items()}
# Auxiliar metric dict with a changed name to facilitate things. aux_metrics is not necesarily representing best metrics.
aux_metrics = {f'best_{k}': v for k, v in metrics.items()}
# Log best metrics
if self.best_metrics is None:
self.best_metrics = aux_metrics
else:
# Define metric name
metric_name = f'best_val_{self.args.optim_metric}'
# Determine if we got a new best model (robust to minimization or maximization of any metric)
got_best_min = (self.metric_objective == 'min') and (aux_metrics[metric_name] < self.best_metrics[metric_name])
got_best_max = (self.metric_objective == 'max') and (aux_metrics[metric_name] > self.best_metrics[metric_name])
# If we got a new best model, save it and log the metrics in wandb
if got_best_min or got_best_max:
self.best_metrics = aux_metrics
# Log metrics and best metrics in each validation step
self.log_dict({**metrics, **self.best_metrics})
# Free memory
self.validation_step_outputs.clear()
def test_step(self, batch):
# Get the outputs from the batch with generalistic function
expression_pred, expression_gt, mask = self.pred_outputs_from_batch(batch)
# Append validation step outputs
self.test_step_outputs.append((expression_pred, expression_gt, mask))
return expression_pred, expression_gt, mask
def on_test_epoch_end(self):
# Unpack the list of tuples
glob_expression_pred, glob_expression_gt, glob_mask = zip(*self.test_step_outputs)
# Concatenate outputs along the sample dimension
glob_expression_pred, glob_expression_gt, glob_mask = torch.cat(glob_expression_pred), torch.cat(glob_expression_gt), torch.cat(glob_mask)
# Get metrics and log
metrics = get_metrics(glob_expression_gt, glob_expression_pred, glob_mask)
# Put test prefix in metric dict
metrics = {f'test_{k}': v for k, v in metrics.items()}
self.log_dict(metrics, on_epoch=True)
# Free memory
self.test_step_outputs.clear()
def configure_optimizers(self):
try:
optimizer = getattr(torch.optim, self.args.optimizer)(self.parameters(), lr=self.args.lr, momentum=self.args.momentum)
except:
optimizer = getattr(torch.optim, self.args.optimizer)(self.parameters(), lr=self.args.lr)
return optimizer
class EightSymmetry(object):
"""Returns a tuple of the eight symmetries resulting from rotation and reflection.
This behaves similarly to TenCrop.
This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your Dataset returns. See below for an example of how to deal with this.
Example:
transform = Compose([
EightSymmetry(), # this is a tuple of PIL Images
Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
])
"""
# This class function was taken fron the original ST-Net repository at:
# https://github.com/bryanhe/ST-Net/blob/43022c1cb7de1540d5a74ea2338a12c82491c5ad/stnet/transforms/eight_symmetry.py#L3
def __call__(self, img):
identity = lambda x: x
ans = []
for i in [identity, RandomHorizontalFlip(1)]:
for j in [identity, RandomVerticalFlip(1)]:
for k in [identity, RandomRotation((90, 90))]:
ans.append(i(j(k(img))))
return tuple(ans)
def __repr__(self):
return self.__class__.__name__ + "()"