Source code for arctic_ai.gnn_prediction

"""
GNN 
==========

Graph neural network inference for tumor and completeness assessment.
"""
import os, torch, pickle, numpy as np, pandas as pd, torch.nn as nn
from torch_geometric.data import DataLoader as TG_DataLoader
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse, dropout_adj, to_networkx
from torch_geometric.nn import GATConv
import torch.nn.functional as F

class GCNNet(torch.nn.Module):
    """Stack of ``GATConv`` layers followed by a linear output layer.

    Parameters
    ----------
    inp_dim : int
        Number of input features per node.
    out_dim : int
        Output size of the final linear layer.
    hidden_topology : sequence of int, optional
        Output width of each ``GATConv`` layer, in order. One layer is built
        per entry; the sequence must not be empty.
    p : float, optional
        Probability given to the ``nn.Dropout`` applied before the linear layer.
    p2 : float, optional
        Probability given to ``dropout_adj`` when dropping edges.
    drop_each : bool, optional
        If True, edges are dropped before every convolution while the module
        is in training mode.
    """

    def __init__(self, inp_dim, out_dim, hidden_topology=(32, 64, 128, 128), p=0.5, p2=0.1, drop_each=True):
        super().__init__()
        self.out_dim = out_dim
        # Pair consecutive widths: (inp_dim -> h0), (h0 -> h1), ...
        layer_dims = [inp_dim, *hidden_topology]
        self.convs = nn.ModuleList(
            [GATConv(n_in, n_out) for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])]
        )
        self.drop_edge = lambda edge_index: dropout_adj(edge_index, p=p2)[0]
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden_topology[-1], out_dim)
        self.drop_each = drop_each

    def forward(self, x, edge_index, edge_attr=None):
        """Apply every convolution with ReLU, then dropout and the linear layer.

        Parameters
        ----------
        x : torch.Tensor
            Node feature matrix.
        edge_index : torch.Tensor
            Edge index passed to each ``GATConv``.
        edge_attr : optional
            Forwarded unchanged as the third positional argument of each convolution.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer.
        """
        for conv in self.convs:
            if self.drop_each and self.training:
                # The thinned edge set is reassigned, so it carries over to later layers.
                edge_index = self.drop_edge(edge_index)
            x = F.relu(conv(x, edge_index, edge_attr))
        # NOTE(review): the flattened source does not show nesting; dropout is
        # assumed to run once after the convolution stack - confirm.
        if self.training:
            x = self.dropout(x)
        x = self.fc(x)
        return x
class GCNFeatures(torch.nn.Module):
    """Wrap a trained GCN so a forward pass returns embeddings and outputs.

    Construction replaces ``gcn.drop_edge`` and ``gcn.dropout`` on the wrapped
    model with versions built from ``p2`` and ``p``.

    Parameters
    ----------
    gcn : torch.nn.Module
        Model exposing ``convs`` (iterable of layers) and ``fc``.
    bayes : bool, optional
        If True, edge dropping and feature dropout are invoked on every forward pass.
    p : float, optional
        Probability for the replacement ``nn.Dropout``.
    p2 : float, optional
        Probability passed to ``dropout_adj``.
    """

    def __init__(self, gcn, bayes=False, p=0.05, p2=0.1):
        super(GCNFeatures, self).__init__()
        self.gcn = gcn
        self.drop_each = bayes

        def _thin_edges(edge_index):
            return dropout_adj(edge_index, p=p2)[0]

        self.gcn.drop_edge = _thin_edges
        self.gcn.dropout = nn.Dropout(p)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(x, y)``: the last convolution's raw output and ``fc(relu(x))``.

        ReLU is applied after every convolution except the last one.
        """
        layers = self.gcn.convs
        last_idx = len(layers) - 1
        for idx, layer in enumerate(layers):
            if self.drop_each:
                edge_index = self.gcn.drop_edge(edge_index)
            x = layer(x, edge_index, edge_attr)
            if idx != last_idx:
                x = F.relu(x)
        # NOTE(review): the flattened source does not show nesting; dropout is
        # assumed to run once after the convolution stack - confirm. Also,
        # nn.Dropout is a no-op when the module is in eval mode - confirm that
        # is intended for bayes=True.
        if self.drop_each:
            x = self.gcn.dropout(x)
        logits = self.gcn.fc(F.relu(x))
        return x, logits
def fix_state_dict(state_dict):
    """Return a copy of ``state_dict`` with legacy attention/linear names updated.

    Dotted key components ``att_l``, ``att_r``, ``lin_l`` and ``lin_r`` are
    renamed to ``att_src``, ``att_dst``, ``lin_src`` and ``lin_dst``. Only whole
    components after the first dot are renamed, so other parts of a key that
    happen to contain ``_l`` or ``_r`` (for example a ``gcn_res.`` prefix) are
    left intact. Keys already using the new names pass through unchanged.

    Parameters
    ----------
    state_dict : mapping of str to object
        State dictionary to convert; it is not modified.

    Returns
    -------
    dict
        New dictionary with converted keys, in the original order.
    """
    # https://github.com/pyg-team/pytorch_geometric/issues/3139
    renamed = {
        "att_l": "att_src",
        "att_r": "att_dst",
        "lin_l": "lin_src",
        "lin_r": "lin_dst",
    }
    new_state_dict = {}
    for key, value in state_dict.items():
        # The first component is kept as-is: only sub-module parameters
        # (names preceded by a dot) are renamed.
        head, *rest = key.split(".")
        new_key = ".".join([head, *(renamed.get(part, part) for part in rest)])
        new_state_dict[new_key] = value
    return new_state_dict
def predict(basename="163_A1a", analysis_type="tumor", gpu_id=0, dirname="."):
    """
    Run GNN prediction on patches.

    Loads ``graph_datasets/{basename}_{analysis_type}_map.pkl`` and
    ``models/{analysis_type}_map_gnn.pth`` from ``dirname``, runs every graph
    through the network, and writes one record per graph to
    ``gnn_results/{basename}_{analysis_type}_map.pkl`` with ``torch.save``.
    Each record is a dict with keys ``G``, ``xy``, ``z``, ``y_pred``,
    ``slide`` and ``component``.

    Parameters
    ----------
    basename : str
        Base name of the slide.
    analysis_type : str
        Type of analysis to run. Must be "tumor" or "macro".
    gpu_id : int, optional
        ID of the GPU to use. Default is 0. A negative value, or a machine
        without CUDA, runs inference on the CPU.
    dirname : str, optional
        Working directory holding the ``graph_datasets`` and ``models``
        folders; results are written to its ``gnn_results`` folder.
        Default is current directory.

    Returns
    -------
    None
    """
    results_dir = os.path.join(dirname, "gnn_results")
    os.makedirs(results_dir, exist_ok=True)
    hidden_topology = dict(tumor=[32, 64, 64], macro=[32, 64, 64])
    num_classes = dict(macro=4, tumor=3)

    # Resolve the device once so weight loading, the model and the inputs all agree.
    use_cuda = gpu_id >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}" if use_cuda else "cpu")

    dataset_path = os.path.join(dirname, "graph_datasets", f"{basename}_{analysis_type}_map.pkl")
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at dataset files produced by a trusted pipeline.
    with open(dataset_path, "rb") as handle:
        dataset = pickle.load(handle)

    model = GCNNet(
        dataset[0].x.shape[1],
        num_classes[analysis_type],
        hidden_topology=hidden_topology[analysis_type],
        p=0.,
        p2=0.,
    )
    weights_path = os.path.join(dirname, "models", f"{analysis_type}_map_gnn.pth")
    model.load_state_dict(fix_state_dict(torch.load(weights_path, map_location=device)))
    model = model.to(device)
    model.eval()

    dataloader = TG_DataLoader(dataset, shuffle=False, batch_size=1)
    feature_extractor = GCNFeatures(model, bayes=False).to(device)
    # GCNFeatures installs a fresh dropout module on the model; put the whole
    # wrapper in eval mode once instead of on every iteration.
    feature_extractor.eval()

    graphs = []
    with torch.no_grad():
        for data in dataloader:
            graph = to_networkx(data).to_undirected()
            xy = data.pos.numpy()
            preds = feature_extractor(data.x.to(device), data.edge_index.to(device))
            z = preds[0].detach().cpu().numpy()
            y_pred = preds[1].detach().cpu().numpy()
            graphs.append(dict(G=graph, xy=xy, z=z, y_pred=y_pred, slide=data.id, component=data.component))
    torch.save(graphs, os.path.join(results_dir, f"{basename}_{analysis_type}_map.pkl"))