Source code for arctic_ai.generate_graph

"""
Graph Data 
==========
Functions for graph dataset generation.
"""
import os, torch, numpy as np, pandas as pd
import pickle
import scipy.sparse as sps
from torch_geometric.utils import subgraph, add_remaining_self_loops
from torch_cluster import radius_graph
from collections import Counter
from torch_geometric.data import Data

def _split_masks(num_nodes, seed=42):
    """Return boolean (train, val, test) node masks for a random 80/10/10 split.

    Parameters
    ----------
    num_nodes : int
        Number of nodes to split.
    seed : int
        Seed applied to NumPy's global random state before shuffling.

    Returns
    -------
    tuple of torch.Tensor
        Three boolean tensors of length ``num_nodes``; every node is True in
        exactly one of them.
    """
    # The global NumPy RNG is deliberately re-seeded on every call so that the
    # split for a graph depends only on its node count (kept for
    # reproducibility with previously generated datasets).
    np.random.seed(seed)
    order = np.arange(num_nodes)
    np.random.shuffle(order)

    train_end = int(0.8 * num_nodes)
    val_end = int(0.9 * num_nodes)

    masks = []
    for chosen in (order[:train_end], order[train_end:val_end], order[val_end:]):
        mask = np.zeros(num_nodes, dtype=bool)
        mask[chosen] = True
        masks.append(torch.tensor(mask))
    return tuple(masks)


def _build_dataset(x, edge_index, pos, basename, component):
    """Wrap node features, edges and positions in a ``Data`` object with split masks."""
    dataset = Data(
        x=x,
        edge_index=edge_index,
        # Placeholder all-ones vector, stored under the attribute name `y_new`.
        y_new=torch.ones(len(x)),
        edge_attr=None,
        pos=pos,
    )
    dataset.train_mask, dataset.val_mask, dataset.test_mask = _split_masks(x.shape[0])
    dataset.id = basename
    dataset.component = component
    return dataset


def create_graph_data(basename="163_A1a",
                      analysis_type="tumor",
                      radius=256,
                      min_component_size=600,
                      no_component_break=False,
                      dirname="."):
    """
    Creates graph data for use in the GNN model for a given tissue slide.

    Reads ``<dirname>/cnn_embeddings/<basename>_<analysis_type>_map.pkl``,
    connects patches whose (x, y) positions lie within ``radius * sqrt(2)`` of
    each other, and pickles a list of ``torch_geometric.data.Data`` objects to
    ``<dirname>/graph_datasets/<basename>_<analysis_type>_map.pkl``. Each
    dataset carries random 80/10/10 ``train_mask`` / ``val_mask`` /
    ``test_mask`` attributes plus ``id`` (the basename) and ``component``.

    Parameters
    ----------
    basename : str
        The basename of the tissue slide to create graph data for.
    analysis_type : str
        The type of analysis to perform. Can be "tumor" or "macro".
    radius : int
        The radius to use when creating the graph; the actual neighbourhood
        cutoff passed to ``radius_graph`` is ``radius * sqrt(2)``.
    min_component_size : int
        The minimum number of nodes a connected component must have to be
        written out. Only used when ``no_component_break`` is False.
    no_component_break : bool
        If False (default), the graph is split into its connected components
        and one dataset is produced per component with at least
        ``min_component_size`` nodes. If True, the whole graph is stored as a
        single dataset with ``component`` set to 0.
    dirname : str
        Working directory: embeddings are read from its ``cnn_embeddings``
        sub-directory and results are written to its ``graph_datasets``
        sub-directory (created if missing).

    Returns
    -------
    None
    """
    output_dir = os.path.join(dirname, "graph_datasets")
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): torch.load unpickles this file, so only load trusted
    # inputs. Recent PyTorch releases default to weights_only=True, which may
    # reject the table stored under 'patch_info' -- confirm against the
    # project's pinned torch version.
    embeddings = torch.load(
        os.path.join(dirname, f"cnn_embeddings/{basename}_{analysis_type}_map.pkl")
    )
    # 'patch_info' is presumably a pandas DataFrame with 'x'/'y' columns.
    xy = torch.tensor(embeddings['patch_info'][['x', 'y']].values).float()
    if torch.cuda.is_available():
        # Build the radius graph on the GPU when one is present.
        xy = xy.cuda()
    X = torch.tensor(embeddings['embeddings'])

    G = radius_graph(xy, r=radius * np.sqrt(2), batch=None, loop=True)
    G = G.detach().cpu()
    G = add_remaining_self_loops(G)[0]
    xy = xy.detach().cpu()

    datasets = []
    if not no_component_break:
        edges = G.numpy().astype(int)
        num_nodes = xy.shape[0]
        # Every node has a self loop, so the adjacency covers all nodes; the
        # shape is passed explicitly rather than inferred from the max index.
        adjacency = sps.coo_matrix(
            (np.ones_like(edges[0]), (edges[0], edges[1])),
            shape=(num_nodes, num_nodes),
        )
        n_components, components = sps.csgraph.connected_components(adjacency)
        comp_count = Counter(components)
        components = torch.LongTensor(components)

        for i in range(n_components):
            if comp_count[i] < min_component_size:
                continue
            in_component = components == i
            G_new = subgraph(in_component, G, relabel_nodes=True)[0]
            datasets.append(
                _build_dataset(X[in_component], G_new, xy[in_component], basename, i)
            )
    else:
        datasets.append(_build_dataset(X, G, xy, basename, 0))

    output_path = os.path.join(output_dir, f"{basename}_{analysis_type}_map.pkl")
    # Context manager guarantees the file is flushed and closed.
    with open(output_path, 'wb') as f:
        pickle.dump(datasets, f)