# Source code for arctic_ai.detection_workflows.follicle_detection

"""
Follicle Detection  
==========
Contains functions for detecting follicles in images.
"""

from skimage.draw import disk
from PIL import Image
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.structures import BoxMode
import os
import pandas as pd
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch
import tifffile
import tqdm
import time
from scipy.special import softmax
import dask
from dask.diagnostics import ProgressBar
from pathpretrain.utils import load_image
import fire

# Template tile: 1024 x 1024 RGB, every pixel white (255), used for padding.
white_square = np.full((1024, 1024, 3), 255, dtype=np.uint8)


def check_update(new_square):
    """Pad an image tile out to 1024 x 1024 x 3 with white pixels.

    Args:
        new_square (np.ndarray): Tile of shape (rows, cols, 3) with rows, cols <= 1024.

    Returns:
        np.ndarray: ``new_square`` itself (no copy) when it already has shape
        (1024, 1024, 3); otherwise a fresh copy of ``white_square`` with
        ``new_square`` written into its top-left corner.
    """
    if tuple(new_square.shape) == (1024, 1024, 3):
        return new_square
    n_rows, n_cols = new_square.shape[:2]
    padded = white_square.copy()
    padded[:n_rows, :n_cols] = new_square
    return padded

def update_gnn_res(out_fname,gnn_res,alpha,patches_new,disk_mask,y_pred,y_tumor,tumor_thres,model,batch_size,num_workers,device="cuda"):
    """Lower GNN tumor probabilities of patches containing detected follicles, then save the result.

    Each tile in ``patches_new`` is segmented by the Detectron2 ``model``. For
    each of the three rings of ``disk_mask`` (labels 1, 2, 3), the fraction of
    ring pixels with a non-zero panoptic segment id is taken and weighted by
    ``alpha``; the mean of the three weighted fractions is subtracted (clipped
    at 0) from the tile's tumor probability. The removed probability is added
    to the benign column (column 0).

    Args:
        out_fname (str): Destination for ``torch.save([gnn_res], out_fname)``.
        gnn_res (dict): GNN result dict. It is updated in place: ``'y_pred_orig'``
            receives a copy of ``y_pred`` and ``'y_pred'`` the adjusted probabilities.
        alpha (sequence of float): Weights for rings 1, 2, 3 (``alpha[0]``..``alpha[2]``).
        patches_new (np.ndarray): Stacked tiles, shape (N, H, W, C). The channel order
            is reversed before inference (presumably RGB -> BGR for Detectron2; confirm
            against cfg.INPUT.FORMAT). N must equal the number of ``y_tumor > tumor_thres``.
        disk_mask (np.ndarray): (H, W) ring-label mask with values 0-3.
        y_pred (np.ndarray): (n_patches, n_classes) probabilities; column 0 is treated as
            benign and column 2 as tumor.
        y_tumor (np.ndarray): Tumor probabilities per patch (normally ``y_pred[:, 2]``).
            It is not modified.
        tumor_thres (float): Only patches with ``y_tumor > tumor_thres`` are adjusted.
        model: Detectron2 model, called on lists of ``{"image", "height", "width"}`` dicts.
        batch_size (int): DataLoader batch size.
        num_workers (int): DataLoader worker processes.
        device (str or torch.device): Device the image batches are sent to (default "cuda").

    Returns:
        None. The result is written to ``out_fname``.
    """
    tumor_mask = y_tumor > tumor_thres
    height, width = patches_new.shape[1:3]

    # Reverse channels but keep the integer dtype. Converting to float32 per
    # batch, rather than up front, avoids holding a 4x larger copy of all tiles.
    tiles = np.ascontiguousarray(patches_new[..., ::-1])
    if tiles.dtype != np.uint8:
        tiles = tiles.astype("float32")
    dataloader = DataLoader(TensorDataset(torch.from_numpy(tiles)), shuffle=False, batch_size=batch_size, num_workers=num_workers)

    # Ring masks (1 = outer, 2 = middle, 3 = inner) computed once, not per tile.
    rings = [disk_mask == label for label in (1, 2, 3)]

    # Score each batch immediately instead of keeping every H x W segmentation map.
    scores = []
    with torch.no_grad():
        for (batch,) in dataloader:
            images = batch.to(device).permute(0, 3, 1, 2).float()
            outputs = model([{"image": im, "height": height, "width": width} for im in images])
            for out in outputs:
                # panoptic_seg[0] is the per-pixel segment-id map; id > 0 marks pixels
                # assigned to some detected segment.
                in_segment = out["panoptic_seg"][0].cpu().numpy() > 0
                scores.append(np.mean([alpha[k] * in_segment[ring].mean() for k, ring in enumerate(rings)]))

    y_tumor_new = np.maximum(y_tumor[tumor_mask] - np.asarray(scores, dtype=float), 0)

    # Add the removed tumor probability to the benign column. This keeps each
    # row's total unchanged when y_tumor equals y_pred[:, 2].
    y_benign = y_pred[:, 0].copy()
    y_benign[tumor_mask] = y_benign[tumor_mask] + y_pred[tumor_mask, 2] - y_tumor_new
    y_tumor_adj = y_tumor.copy()
    y_tumor_adj[tumor_mask] = y_tumor_new

    gnn_res['y_pred_orig'] = y_pred.copy()
    gnn_res['y_pred'] = y_pred.copy()
    gnn_res['y_pred'][:, 0] = y_benign
    gnn_res['y_pred'][:, 2] = y_tumor_adj
    torch.save([gnn_res], out_fname)
    return None

def _build_follicle_model(model_dir, model_path, score_thresh):
    """Build the one-class Detectron2 panoptic FPN (R101) follicle detector and return its model."""
    base_model = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"
    n_classes = 1
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(base_model))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = n_classes
    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = n_classes
    cfg.OUTPUT_DIR = model_dir
    # Fine-tuned weights are looked up relative to model_dir.
    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, model_path)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # DefaultPredictor builds the model and loads cfg.MODEL.WEIGHTS. Only the
    # raw model is used, so tiles can be passed in batches.
    return DefaultPredictor(cfg).model


def _ring_disk_mask(size=1024):
    """Return a (size, size) mask of concentric labels: 1 outer ring, 2 middle ring, 3 inner disc, 0 outside."""
    mask = np.zeros((size, size))
    center = (size // 2, size // 2)
    # Paint from the largest disc to the smallest so inner labels overwrite outer ones.
    for label, radius in ((1, size // 2), (2, size // 4), (3, size // 8)):
        rr, cc = disk(center, radius, shape=mask.shape)
        mask[rr, cc] = label
    return mask


def _crop_tile(img, row_center, col_center):
    """Return the 1024 x 1024 tile of img centred on (row_center, col_center), padding off-image pixels with white."""
    half = 512  # tiles are 1024 x 1024 to match white_square / check_update
    row0, col0 = row_center - half, col_center - half
    if row0 >= 0 and col0 >= 0:
        # Tile starts inside the image; check_update pads any bottom/right overhang.
        return check_update(img[row0:row_center + half, col0:col_center + half])
    # A negative slice start would wrap around to the far edge of the image
    # (giving an empty or wrong crop). Clamp the start at 0 and place the
    # visible part at the matching offset inside a white tile.
    r_start, c_start = max(row0, 0), max(col0, 0)
    crop = img[r_start:row_center + half, c_start:col_center + half]
    tile = white_square.copy()
    dr, dc = r_start - row0, c_start - col0
    tile[dr:dr + crop.shape[0], dc:dc + crop.shape[1]] = crop
    return tile


def predict_hair_follicles(tumor_thres=0.3, patch_size=256, basename="340_A1a_ASAP", alpha_scale=2., alpha=(1.,2.,3.), model_path="model_final.pth", model_dir="./output", detectron_threshold=0.55, dirname="../../bcc_test_set/", ext=".tif", batch_size=16, num_workers=1):
    """Detect hair follicles around tumor-predicted patches and lower their GNN tumor probabilities.

    A Detectron2 panoptic FPN segments a 1024 x 1024 tile centred on every
    patch whose GNN tumor probability exceeds ``tumor_thres``. Follicle
    coverage in three concentric rings of the tile reduces that probability
    (see ``update_gnn_res``).

    Files read (relative to ``dirname``):
        ``inputs/{basename}{ext}``: the full image.
        ``metadata/{basename}.pkl``: iterated to obtain section names ``bn``.
        ``patches/{bn}.pkl``: patch table with ``tumor_map``, ``x_orig``, ``y_orig`` columns.
        ``gnn_results/{bn}_tumor_map.pkl``: GNN output; its ``'y_pred'`` is softmaxed over axis 1.

    Files written:
        ``gnn_follicle_results/{bn}_tumor_map.pkl`` for every section, containing
        ``[gnn_res]`` with ``'y_pred_orig'`` (GNN probabilities) and ``'y_pred'``
        (follicle-adjusted probabilities).

    Args:
        tumor_thres (float): Tumor probability above which a patch is checked for follicles (default 0.3).
        patch_size (int): Patch edge length; the tile is centred at patch origin + patch_size / 2 (default 256).
        basename (str): Image name without extension (default "340_A1a_ASAP").
        alpha_scale (float): ``alpha`` is normalised to sum to 1 and then multiplied by this (default 2.0).
        alpha (sequence of float): Relative weights of the outer, middle and inner rings (default (1., 2., 3.)).
        model_path (str): Detectron2 weights file name, relative to ``model_dir`` (default "model_final.pth").
        model_dir (str): Directory containing the Detectron2 weights (default "./output").
        detectron_threshold (float): Detectron2 ROI-head test score threshold (default 0.55).
        dirname (str): Root directory of the inputs and outputs (default "../../bcc_test_set/").
        ext (str): Extension of the input image file (default ".tif").
        batch_size (int): Tiles per model forward pass (default 16).
        num_workers (int): DataLoader worker processes (default 1).

    Returns:
        None.
    """
    out_dir = os.path.join(dirname, "gnn_follicle_results")
    os.makedirs(out_dir, exist_ok=True)
    model = _build_follicle_model(model_dir, model_path, detectron_threshold)
    alpha = np.array(alpha) / sum(alpha) * alpha_scale
    d_patch = int(patch_size / 2)
    disk_mask = _ring_disk_mask()
    img_ = load_image(os.path.join(dirname, "inputs", f"{basename}{ext}"))

    gnn_res_ = {}
    for bn in pd.read_pickle(os.path.join(dirname, "metadata", f"{basename}.pkl")):
        patch_info = pd.read_pickle(os.path.join(dirname, "patches", f"{bn}.pkl"))
        patch_info = patch_info[patch_info['tumor_map'].values]
        # NOTE(review): torch.load unpickles arbitrary objects, so use it only on
        # trusted pipeline outputs. With torch>=2.6 the weights_only=True default
        # may reject this dict; confirm.
        gnn_res = torch.load(os.path.join(dirname, "gnn_results", f"{bn}_tumor_map.pkl"))[0]
        y_pred = softmax(gnn_res['y_pred'], 1).copy()
        y_tumor = y_pred[:, 2].copy()
        coords = patch_info[['x_orig', 'y_orig']][y_tumor > tumor_thres].values.tolist()
        out_fname = os.path.join(out_dir, f"{bn}_tumor_map.pkl")
        patches_new = [_crop_tile(img_, x + d_patch, y + d_patch) for x, y in tqdm.tqdm(coords)]
        if patches_new:
            gnn_res_[bn] = dask.delayed(update_gnn_res)(out_fname, gnn_res, alpha, np.stack(patches_new), disk_mask, y_pred, y_tumor, tumor_thres, model, batch_size, num_workers)
        else:
            # Nothing exceeds tumor_thres, so there is nothing to adjust. Still
            # write the output, in the same format update_gnn_res produces, so
            # every section has a result file.
            gnn_res['y_pred_orig'] = y_pred.copy()
            gnn_res['y_pred'] = y_pred.copy()
            torch.save([gnn_res], out_fname)
            gnn_res_[bn] = [gnn_res]
    # NOTE(review): all delayed tasks share one model across dask threads.
    with ProgressBar():
        dask.compute(gnn_res_, scheduler="threading")
if __name__ == "__main__":
    # Expose predict_hair_follicles as a command-line interface via python-fire.
    fire.Fire(predict_hair_follicles)