Source code for arctic_ai.detection_workflows.nuclei_gnn

"""Nuclei GNN
==========
Functions to support cell-graph neural network inference. Cell level CNN embedding code is being refactored and will be added.
"""
import sys, os
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pickle
import os,glob, pandas as pd
from sklearn.metrics import f1_score
import copy
from collections import Counter
import torch
import torch.nn.functional as F
import traceback
try:
    from torch import nn
    from torch_geometric.nn import GCNConv, GATConv, DeepGraphInfomax, SAGEConv
    from torch_geometric.nn import DenseGraphConv
    from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse
    from torch_geometric.nn import GINEConv
    from torch_geometric.utils import dropout_adj
    from torch_geometric.nn import APPNP
    from torch_cluster import knn_graph
    from torch_geometric.data import Data 
    from torch_geometric.utils import train_test_split_edges
    from torch_geometric.utils.convert import to_networkx
    from torch_geometric.data import InMemoryDataset,DataLoader
    from sklearn.model_selection import train_test_split
    from sklearn.utils.class_weight import compute_class_weight
except:
    print(traceback.format_exc())


def get_graph_dataset(basename, embedding_dir='gcn/nodes', k=8, radius=0):
    df=pickle.load(open(os.path.join(embedding_dir,f"{basename}.pkl"), "rb"))
    xy=torch.tensor(df[['X','Y']].values).cuda()
    
    
    X=torch.tensor(np.stack([x for x in df.Embedding.values]), dtype=torch.float32)
    y=torch.tensor(df.Label.values)
    # needs to be modified... build a graph for each individual patch
    if not radius:
        G=knn_graph(xy,k=k)
    else:
        G=radius_graph(xy, r=radius*np.sqrt(2), batch=None, loop=True)
        
    G=G.detach().cpu()
    G=torch_geometric.utils.add_remaining_self_loops(G)[0]
    xy=xy.detach().cpu()
    datasets=[]
    edges=G.detach().cpu().numpy().astype(int)
    n_components,components=list(sps.csgraph.connected_components(sps.coo_matrix((np.ones_like(edges[0]),(edges[0],edges[1])))))
    components=torch.LongTensor(components)
    for i in range(n_components):
        G_new=subgraph(components==i,G,relabel_nodes=True)[0]
        xy_new=xy[components==i]
        X_new=X[components==i]
        y_new=y[components==i]
        np.random.seed(42)
        idx=np.arange(X_new.shape[0])
        idx2=np.arange(X_new.shape[0])
        np.random.shuffle(idx)
        train_idx,val_idx,test_idx=torch.tensor(np.isin(idx2,idx[:int(0.8*len(idx))])),torch.tensor(np.isin(idx2,idx[int(0.8*len(idx)):int(0.9*len(idx))])),torch.tensor(np.isin(idx2,idx[int(0.9*len(idx)):]))
        dataset=Data(x=X_new, edge_index=G_new, edge_attr=None, y=y_new, pos=xy_new)
        dataset.train_mask=train_idx
        dataset.val_mask=val_idx
        dataset.test_mask=test_idx
        datasets.append(dataset)
    return datasets

EPS = 1e-15

[docs]class GCNNet(torch.nn.Module): def __init__(self, inp_dim, out_dim, hidden_topology=[32,64,128,128], p=0.5, p2=0.1, drop_each=True): super(GCNNet, self).__init__() self.out_dim=out_dim self.convs = nn.ModuleList([GATConv(inp_dim, hidden_topology[0])]+[GATConv(hidden_topology[i],hidden_topology[i+1]) for i in range(len(hidden_topology[:-1]))]) self.drop_edge = lambda edge_index: dropout_adj(edge_index,p=p2)[0] self.dropout = nn.Dropout(p) self.fc = nn.Linear(hidden_topology[-1], out_dim) self.drop_each=drop_each
[docs] def forward(self, x, edge_index, edge_attr=None): for conv in self.convs: if self.drop_each and self.training: edge_index=self.drop_edge(edge_index) x = conv(x, edge_index, edge_attr) x = F.relu(x) if self.training: x = self.dropout(x) x = self.fc(x) return x
[docs]class GCNFeatures(torch.nn.Module): def __init__(self, gcn, bayes=False): super(GCNFeatures, self).__init__() self.gcn=gcn self.drop_each=bayes
[docs] def forward(self, x, edge_index, edge_attr=None): for conv in self.gcn.convs: if self.drop_each: edge_index=self.gcn.drop_edge(edge_index) x = F.relu(conv(x, edge_index, edge_attr)) if self.drop_each: x = self.gcn.dropout(x) y = F.softmax(self.gcn.fc(x)) return x,y
def validate_model(graph_data="graph_data.pkl", use_weights=False, use_model=None, n_batches_backward=1, f1_metric='weighted', n_epochs=1500, out_dir='gcn/models', lr=1e-2, eta_min=1e-4, T_max=20, wd=0, hidden_topology=[32,64,128,128], p=0.5, p2=0.3, burnin=400, warmup=100, gpu_id=0, batch_size=1 ): print(gpu_id); torch.cuda.set_device(gpu_id) datasets=pickle.load(open(graph_data, "rb")) train_dataset = [dataset for key in datasets for dataset in datasets[key]['train']] val_dataset = [dataset for key in datasets for dataset in datasets[key]['val']] y_train=np.hstack([graph.y.numpy() for graph in train_dataset]) if use_weights: weights=compute_class_weight('balanced',classes=np.unique(y_train),y=y_train) else: weights=None print("number of classes", len(np.unique(y_train))) # load model model=GCNNet(train_dataset[0].x.shape[1],len(np.unique(y_train)),hidden_topology=hidden_topology,p=p,p2=p2) if use_model: model.load_state_dict(torch.load(use_model)) model=model.cuda() criterion=nn.CrossEntropyLoss(weight=torch.tensor(weights).float() if use_weights else None) criterion=criterion.cuda() Y_all = {} Y_Pred_all = {} pos_all = {} Y,Y_Pred=[],[] pos = [] wsi = [] for key in datasets: train_dataset = datasets[key]['train'] val_dataset = datasets[key]['val'] dataloaders={} dataloaders['train']=DataLoader(train_dataset,batch_size=batch_size,shuffle=True) dataloaders['warmup']=DataLoader(train_dataset,shuffle=False) train_loader=dataloaders['warmup'] for i,data in enumerate(dataloaders['train']): model.train(False) x=data.x.cuda() edge_index=data.edge_index.cuda() y=data.y.cuda() y_out=model(x,edge_index) #print(y_out, y) loss = criterion(y_out, y) y_prob=F.softmax(y_out, dim=1).detach().cpu().numpy() y_pred=y_prob.argmax(1).flatten() y_true=y.detach().cpu().numpy().flatten() Y_Pred.append(y_pred) Y.append(y_true) pos.append(data.pos) for i in range(len(y_pred)): wsi.append(key) del x, edge_index, loss, y_out if (len(val_dataset) > 0): dataloaders['val']=DataLoader(val_dataset,shuffle=True) for i,data in enumerate(dataloaders['val']): model.train(False) x=data.x.cuda() edge_index=data.edge_index.cuda() y=data.y.cuda() y_out=model(x,edge_index) loss = criterion(y_out, y) y_prob=F.softmax(y_out, dim=1).detach().cpu().numpy() y_pred=y_prob.argmax(1).flatten() y_true=y.detach().cpu().numpy().flatten() Y_Pred.append(y_pred) Y.append(y_true) pos.append(data.pos) for i in range(len(y_pred)): wsi.append(key) del x, edge_index, loss, y_out Y = np.hstack(Y) Y_Pred = np.hstack(Y_Pred) pos = np.vstack(pos) return Y, Y_Pred, pos, wsi