Source code for arctic_ai.preprocessing

"""Preprocess 
==========
Contains functions for preprocessing images."""
import os, tqdm
import numpy as np, pandas as pd
from itertools import product
from scipy.ndimage import binary_fill_holes as fill_holes  # the scipy.ndimage.morphology namespace is deprecated in recent SciPy
from pathpretrain.utils import load_image, generate_tissue_mask
from scipy.sparse.csgraph import connected_components
from sklearn.neighbors import radius_neighbors_graph
from sklearn.cluster import SpectralClustering
from shapely.geometry import Point, MultiPoint
import cv2
import dask
from dask.diagnostics import ProgressBar
import alphashape
import warnings
# import pysnooper

def preprocess(basename="163_A1a", threshold=0.05, patch_size=256, ext='.npy',
               secondary_patch_size=0, alpha=1024**-1, no_break=False,
               df_section_pieces_file='', dirname=".", image_mask_compression=1.,
               use_section=False):
    """Preprocesses an image and generates patches for training or testing.

    Parameters
    ----------
    basename: str
        The base filename of the image to be preprocessed.
    threshold: float, optional
        The minimum fraction of a patch that must overlap masked tissue for
        the patch to be kept. Default is 0.05.
    patch_size: int, optional
        The size of the patches to be generated. Default is 256.
    ext: str, optional
        The file extension of the input image. Default is ".npy".
    secondary_patch_size: int, optional
        Unsupported in this version; must be 0. Default is 0.
    alpha: float, optional
        Alpha parameter of the alpha-shape drawn around each tissue piece.
        Default is 1024**-1.
    no_break: bool, optional
        If True, the function will not break large images into multiple
        smaller ones. Default is False.
    df_section_pieces_file: str, optional
        The filename of a pickle with metadata on the expected number of
        sections/pieces per slide. Default is '' (counts are inferred).
    dirname: str, optional
        The directory where input and output files are stored. Default is ".".
    image_mask_compression: float, optional
        The downsampling factor applied to exported masks/images; values
        greater than 1 shrink the output. Default is 1.
    use_section: bool, optional
        If True, group exported patch sets by section rather than by piece.
        Default is False.
    """
    assert secondary_patch_size==0
    write_images=False
    os.makedirs(os.path.join(dirname,"masks"),exist_ok=True)
    os.makedirs(os.path.join(dirname,"patches"),exist_ok=True)
    os.makedirs(os.path.join(dirname,"images"),exist_ok=True)
    os.makedirs(os.path.join(dirname,"metadata"),exist_ok=True)
    image=os.path.join(dirname,"inputs",f"{basename}{ext}")
    basename=os.path.basename(image).replace(ext,'')
    image=load_image(image)  # np.load(image)
    img_shape=image.shape[:-1]
    df_section_pieces=None if not (df_section_pieces_file and os.path.exists(df_section_pieces_file)) else pd.read_pickle(df_section_pieces_file).reset_index().drop_duplicates().set_index("index")
    # Binary tissue mask; macro_map additionally fills interior holes
    masks=dict()
    masks['tumor_map']=generate_tissue_mask(image, compression=10, otsu=False, threshold=240, connectivity=8, kernel=5, min_object_size=100000, return_convex_hull=False, keep_holes=False, max_hole_size=6000, gray_before_close=True, blur_size=51)
    x_max,y_max=masks['tumor_map'].shape
    if no_break: masks['macro_map']=fill_holes(masks['tumor_map'])
    patch_info=dict()
    patches=dict()
    include_patches=dict()
    # Regular grid of candidate patch coordinates across the whole image
    patch_info['orig']=pd.DataFrame([[basename,x,y,patch_size,"0"] for x,y in tqdm.tqdm(list(product(range(0,x_max-patch_size,patch_size),range(0,y_max-patch_size,patch_size))))],columns=['ID','x','y','patch_size','annotation'])
    for k in (masks if no_break else ['tumor_map']):
        patch_info[k]=patch_info['orig'].copy()
        # Keep a patch only if at least `threshold` of its area is masked tissue
        include_patches[k]=np.stack([masks[k][x:x+patch_size,y:y+patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())]).mean((1,2))>=threshold
        patch_info[k]=patch_info[k][include_patches[k]]
        if no_break:
            patches[k]=np.stack([image[x:x+patch_size,y:y+patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())])
            np.save(os.path.join(dirname,"masks",f"{basename}_{k}.npy"),masks[k])
            np.save(os.path.join(dirname,"patches",f"{basename}_{k}.npy"),patches[k])  # patches[k] is already filtered; indexing with the include_patches dict was a bug
            patch_info[k].to_pickle(os.path.join(dirname,"patches",f"{basename}_{k}.pkl"))
    if no_break: return None
    if not no_break:
        if df_section_pieces is not None:
            n_pieces=int(np.prod(df_section_pieces.loc[basename.replace("_ASAP","")]))
        # Patches within 512*sqrt(2) px of each other belong to the same tissue piece
        G=radius_neighbors_graph(patch_info['tumor_map'][['x','y']], radius=512*np.sqrt(2))
        patch_info['tumor_map']['piece_ID']=connected_components(G)[1]
        if df_section_pieces is None:
            n_pieces=int(patch_info['tumor_map']['piece_ID'].max()+1)
        patch_info['tumor_map']['piece_ID']=patch_info['tumor_map']['piece_ID'].max()-patch_info['tumor_map']['piece_ID']
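        # A toy sketch of the piece detection above (hypothetical coordinates,
        # not from this pipeline): patch origins within 512*sqrt(2) px of one
        # another fall in the same connected component, i.e. the same piece.
        #
        #     coords=np.array([[0,0],[256,0],[0,256],        # piece A
        #                      [5000,5000],[5256,5000]])     # piece B
        #     G=radius_neighbors_graph(coords, radius=512*np.sqrt(2))
        #     n,labels=connected_components(G)               # n==2, labels -> [0,0,0,1,1]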
        # Keep the n_pieces largest components and relabel them contiguously
        patch_info['tumor_map']=patch_info['tumor_map'][patch_info['tumor_map']['piece_ID'].isin(patch_info['tumor_map']['piece_ID'].value_counts().index[:n_pieces].values)]
        patch_info['tumor_map']['piece_ID']=patch_info['tumor_map']['piece_ID'].map({v:k for k,v in enumerate(sorted(patch_info['tumor_map']['piece_ID'].unique()))})
        if df_section_pieces is not None:
            assert df_section_pieces.loc[basename.replace("_ASAP","")]['Pieces']<=2
            # Coarser radius graph: components at 4096*sqrt(2) px correspond to sections
            G=radius_neighbors_graph(patch_info['tumor_map'][['x','y']], radius=4096*np.sqrt(2))
            # Order piece IDs by mean x coordinate
            patch_info['tumor_map']['piece_ID']=patch_info['tumor_map']['piece_ID'].map(dict(zip(patch_info['tumor_map'].groupby("piece_ID")['x'].mean().sort_values(ascending=False).index,range(patch_info['tumor_map']['piece_ID'].max()+1))))
            patch_info['tumor_map']['section_ID']=connected_components(G)[1]
            # A section is "complete" when it contains the expected number of pieces
            complete=patch_info['tumor_map'][['section_ID','piece_ID']].groupby("section_ID")['piece_ID'].nunique()==df_section_pieces.loc[basename.replace("_ASAP","")]['Pieces']
            patch_info['tumor_map']['complete']=patch_info['tumor_map']['section_ID'].isin(complete[complete].index)
            # Too few pieces found: bisect incomplete pieces with spectral clustering until the expected count is reached
            while patch_info['tumor_map']['piece_ID'].max()+1<n_pieces:
                split_pieces=patch_info['tumor_map']['piece_ID'].value_counts().index
                for split_piece in split_pieces:
                    if patch_info['tumor_map'].loc[patch_info['tumor_map']['piece_ID']==split_piece,'complete'].sum()==0:
                        patch_info['tumor_map'].loc[patch_info['tumor_map']['piece_ID']==split_piece,'complete']=True # TODO: this can break
                        break
                G=radius_neighbors_graph(patch_info['tumor_map'].query(f"piece_ID=={split_piece}")[['x','y']], radius=patch_size*np.sqrt(2))
                cl=SpectralClustering(n_clusters=2,affinity="precomputed",assign_labels="discretize",eigen_solver="amg",n_components=2).fit_predict(G)
                patch_info['tumor_map'].loc[patch_info['tumor_map']['piece_ID']==split_piece,'piece_ID']=cl+patch_info['tumor_map']['piece_ID'].max()+1
                patch_info['tumor_map']['piece_ID']=patch_info['tumor_map']['piece_ID'].map(dict(zip(patch_info['tumor_map'].groupby("piece_ID")['x'].mean().sort_values(ascending=False).index,range(patch_info['tumor_map']['piece_ID'].max()+1))))
            patch_info['tumor_map']['section_ID']=patch_info['tumor_map']['piece_ID']//df_section_pieces.loc[basename.replace("_ASAP","")]['Pieces']
            assert patch_info['tumor_map']['piece_ID'].max()+1==n_pieces
        else:
            G=radius_neighbors_graph(patch_info['tumor_map'][['x','y']], radius=4096*np.sqrt(2))
            patch_info['tumor_map']['section_ID']=connected_components(G)[1]
            patch_info['tumor_map']['section_ID']=patch_info['tumor_map']['section_ID'].max()-patch_info['tumor_map']['section_ID']
        n_sections=patch_info['tumor_map']['section_ID'].max()+1
        n_pieces_per_section=(patch_info['tumor_map']['piece_ID'].max()+1)/n_sections
        pts=MultiPoint(patch_info['orig'][['x','y']].values)
        patch_info_new=[]
        for ID in patch_info['tumor_map']['piece_ID'].unique():
            tmp_points=patch_info['tumor_map'][['x','y']][patch_info['tumor_map']['piece_ID']==ID].values
            # Alpha-shape outline of the piece, used to classify candidate patches
            alpha_shape = alphashape.alphashape(tmp_points,alpha=alpha)
            tmp_points=MultiPoint(tmp_points)
            xy=dict()
            xy['macro']=pts.intersection(alpha_shape).difference(alpha_shape.exterior.buffer(256))
            xy['macro_tumor']=xy['macro'].intersection(tmp_points.buffer(64))
            xy['macro_no_tumor']=xy['macro'].difference(tmp_points.buffer(64))
            xy['tumor_no_macro']=tmp_points.difference(xy['macro'].buffer(64))
            for k in list(xy.keys()):
                if isinstance(xy[k],Point) or len(xy[k].geoms)==0:
                    del xy[k]
            xy.pop('macro',None)  # pop instead of del: 'macro' may already have been dropped above
            # Iterate .geoms explicitly (bare MultiPoint iteration was removed in Shapely 2)
            xy={k:pd.DataFrame(np.array([(int(p.x),int(p.y)) for p in xy[k].geoms]),columns=['x','y']) for k in xy}
            for k in xy:
                xy[k]['basename']=basename
                xy[k]['section_ID']=ID//n_pieces_per_section
                xy[k]['piece_ID']=ID
                xy[k]['patch_size']=patch_size
                xy[k]['Type']=k
            xy=pd.concat(list(xy.values()),axis=0)
            xy['tumor_map']=xy['Type'].isin(['macro_tumor','tumor_no_macro'])
            xy['macro_map']=xy['Type'].isin(['macro_tumor','macro_no_tumor'])
            patch_info_new.append(xy)
        patch_info=pd.concat(patch_info_new,axis=0)
        for coord in ['x','y']:
            patch_info[f'{coord}_orig']=patch_info[coord]
        xy_bounds={}
        write_files=[]
        new_basenames=[]
        # Export one patch stack per piece (or per section if use_section)
        for ID in patch_info['section_ID' if use_section else 'piece_ID'].unique():
            new_basename=f"{basename}_{ID}"
            new_basenames.append(new_basename)
            include_patches=(patch_info['section_ID' if use_section else 'piece_ID']==ID).values
            patch_info_ID=patch_info[include_patches].copy()  # copy so the in-place shift below cannot hit a view
            (xmin,ymin),(xmax,ymax)=patch_info_ID[['x','y']].min(0).values,(patch_info_ID[['x','y']].max(0).values+patch_size)
            im=image[xmin:xmax,ymin:ymax]
            msk=masks['tumor_map'][xmin:xmax,ymin:ymax] # TODO: can break, need target tissue section
            # Shift patch coordinates into the cropped frame
            patch_info_ID.loc[:,['x','y']]-=patch_info_ID[['x','y']].min(0)
            patches_ID=np.stack([im[x:x+patch_size,y:y+patch_size] for x,y in tqdm.tqdm(patch_info_ID[['x','y']].values.tolist())])
            patch_info_ID.reset_index(drop=True).to_pickle(os.path.join(dirname,"patches",f"{new_basename}.pkl"))
            # Defer heavy array writes and run them in parallel at the end
            write_files.append(dask.delayed(np.save)(os.path.join(dirname,"patches",f"{new_basename}.npy"),patches_ID))
            write_files.append(dask.delayed(np.save)(os.path.join(dirname,"masks",f"{new_basename}.npy"),cv2.resize(msk.astype(np.uint8),None,fx=1/image_mask_compression,fy=1/image_mask_compression,interpolation=cv2.INTER_NEAREST)>0 if image_mask_compression>1 else msk))
            if write_images:
                if image_mask_compression>1:
                    write_files.append(dask.delayed(np.save)(os.path.join(dirname,"images",f"{new_basename}.npy"),cv2.resize(im,None,fx=1/image_mask_compression,fy=1/image_mask_compression,interpolation=cv2.INTER_CUBIC)))
                else:
                    write_files.append(dask.delayed(np.save)(os.path.join(dirname,"images",f"{new_basename}.npy"),im))
            xy_bounds[ID]=((xmin,ymin),(xmax,ymax))
        pd.to_pickle(xy_bounds,os.path.join(dirname,"masks",f"{basename}.pkl"))
        with ProgressBar():
            dask.compute(write_files,scheduler='threading')
        pd.to_pickle(new_basenames,os.path.join(dirname,"metadata",f"{basename}.pkl"))
    return None
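
# A minimal usage sketch (hypothetical paths): assumes `dirname` contains an
# "inputs/" folder holding "163_A1a.npy", as produced by an upstream ingestion
# step, and writes to "masks/", "patches/", "images/" and "metadata/" beneath it.
#
#     from arctic_ai.preprocessing import preprocess
#     preprocess(basename="163_A1a", threshold=0.05, patch_size=256,
#                ext=".npy", dirname="/path/to/case", no_break=True)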
def preprocess_old(basename="163_A1a", threshold=0.05, patch_size=256, ext='.npy', secondary_patch_size=0):
    warnings.warn("Old preprocessing is deprecated", DeprecationWarning)
    raise RuntimeError
    # NOTE: everything below is unreachable; kept for reference only
    os.makedirs("masks",exist_ok=True)
    os.makedirs("patches",exist_ok=True)
    image=f"inputs/{basename}{ext}"
    basename=os.path.basename(image).replace(ext,'')
    image=load_image(image)  # np.load(image)
    img_shape=image.shape[:-1]
    masks=dict()
    masks['tumor_map']=generate_tissue_mask(image, compression=10, otsu=False, threshold=240, connectivity=8, kernel=5, min_object_size=100000, return_convex_hull=False, keep_holes=False, max_hole_size=6000, gray_before_close=True, blur_size=51)
    x_max,y_max=masks['tumor_map'].shape
    masks['macro_map']=fill_holes(masks['tumor_map'])
    patch_info=dict()
    for k in masks:
        patch_info[k]=pd.DataFrame([[basename,x,y,patch_size,"0"] for x,y in tqdm.tqdm(list(product(range(0,x_max-patch_size,patch_size),range(0,y_max-patch_size,patch_size))))],columns=['ID','x','y','patch_size','annotation'])
        patches=np.stack([image[x:x+patch_size,y:y+patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())])
        include_patches=np.stack([masks[k][x:x+patch_size,y:y+patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())]).mean((1,2))>=threshold
        np.save(f"masks/{basename}_{k}.npy",masks[k])
        np.save(f"patches/{basename}_{k}.npy",patches[include_patches])
        patch_info[k].iloc[include_patches].to_pickle(f"patches/{basename}_{k}.pkl")
    if secondary_patch_size:
        patch_info=dict()
        for k in ['tumor_map']:
            patch_info[k]=pd.DataFrame([[basename,x,y,secondary_patch_size,"0"] for x,y in tqdm.tqdm(list(product(range(0,x_max-secondary_patch_size,secondary_patch_size),range(0,y_max-secondary_patch_size,secondary_patch_size))))],columns=['ID','x','y','patch_size','annotation'])
            patches=np.stack([image[x:x+secondary_patch_size,y:y+secondary_patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())])
            include_patches=np.stack([masks[k][x:x+secondary_patch_size,y:y+secondary_patch_size] for x,y in tqdm.tqdm(patch_info[k][['x','y']].values.tolist())]).mean((1,2))>=threshold
            np.save(f"patches/{basename}_{k}_{secondary_patch_size}.npy",patches[include_patches])
            patch_info[k].iloc[include_patches].to_pickle(f"patches/{basename}_{k}_{secondary_patch_size}.pkl")
    return img_shape
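
# `preprocess_old` raises immediately and is retained for reference only. A
# rough (assumed, not verified) modern equivalent of its per-mask export is
# the `no_break` path of `preprocess`:
#
#     preprocess(basename="163_A1a", patch_size=256, no_break=True)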