#!/usr/bin/env python
# coding:utf-8
"""
Created on 2019/7/29 11:10
Training/evaluation harness for inductive learning on graphs using n-order (higher-order) edge lists.
"""
__author__ = 'xx'
__version__ = '1.0'
from GNN_Implement.data_loader import PPI
from sklearn import metrics
from random import choice
import numpy as np
import torch
from itertools import chain
from GNN_Implement.model_component.utils.adj_mat import adj_list_to_n_order_adj_list
from GNN_Implement.model_component.utils.add_self_loop import list_add_self_loops_
class NOrderInductiveLearningTest(object):
    """Train and evaluate a graph model on n-order edge lists (inductive setting).

    For every graph of the dataset's train/valid/test splits a list of edge
    lists is precomputed (see ``get_n_order_edge_list``).  During training and
    evaluation the edge lists of one graph at a time are copied to
    ``self.device`` and passed to the model together with that graph's node
    features.  Model outputs are treated as logits: they are scored with
    ``BCEWithLogitsLoss`` and, after thresholding at 0, with micro-averaged
    F1 / precision / recall.

    Args:
        model_: Model class (not an instance); it is called as
            ``model_(node_ft_size, label_num)``.
        dataset: Dataset object (e.g. ``PPI()``) providing ``train_graphs``,
            ``valid_graphs``, ``test_graphs``, ``to_device``,
            ``get_node_ft_size`` and ``get_label_num``.
        order: Number of edge lists kept per graph; 1 means only the graph's
            own edge list (with self-loops added).

    Raises:
        ValueError: If ``order`` is smaller than 1.
    """

    def __init__(self, model_, dataset, order=2):
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # done first so a bad value fails before any dataset work happens.
        if order < 1:
            raise ValueError('order must be >= 1, got {!r}'.format(order))
        self.model_ = model_
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dataset = dataset
        self.adj_order = order
        # Only node labels and features are moved to the compute device.
        self.dataset.to_device(self.device, select_att=['node_label', 'node_ft'])
        self.node_ft_size = self.dataset.get_node_ft_size()
        self.label_num = self.dataset.get_label_num()
        self.get_n_order_edge_list()
        self.built_model()

    def get_n_order_edge_list(self, device='cpu'):
        """Precompute the per-graph lists of edge lists for every split.

        Fills ``self.train_n_order_edge_list``, ``self.valid_n_order_edge_list``
        and ``self.test_n_order_edge_list``.  Each holds one list per graph, in
        the same order as the corresponding dataset split.  Element 0 of a
        per-graph list is the graph's own ``edge_list`` with self-loops added;
        elements 1..adj_order-1 are the results of
        ``adj_list_to_n_order_adj_list(edge_list, order_idx, ...)`` for
        ``order_idx`` = 1..adj_order-1.

        Args:
            device: Device on which the edge lists are built and stored.
                Defaults to the CPU because there are too many graphs to keep
                all of their edge lists on the GPU.
        """
        print('bulit n order edge list')
        self.train_n_order_edge_list = [[graph.edge_list] for graph in self.dataset.train_graphs]
        self.valid_n_order_edge_list = [[graph.edge_list] for graph in self.dataset.valid_graphs]
        self.test_n_order_edge_list = [[graph.edge_list] for graph in self.dataset.test_graphs]
        for one_graph_list in chain(self.train_n_order_edge_list,
                                    self.valid_n_order_edge_list,
                                    self.test_n_order_edge_list):
            # Higher-order lists are derived from the original edge list,
            # before any self-loops are added to it.
            for order_idx in range(1, self.adj_order):
                # fliter_path_num=9 is passed through unchanged; its meaning is
                # defined by adj_list_to_n_order_adj_list.
                high_order_adj_list = adj_list_to_n_order_adj_list(
                    one_graph_list[0], order_idx, device=device, fliter_path_num=9)
                one_graph_list.append(high_order_adj_list)
                # Kept inside the loop: high_order_adj_list is unbound when
                # adj_order == 1.
                print('one_graph_list[0] size = ', one_graph_list[0].size(),
                      'high_order_adj_list size = ', high_order_adj_list.size())
            # Add self-loops to the first-order list once per graph, on the
            # same device as the other lists of this graph.
            one_graph_list[0] = list_add_self_loops_(one_graph_list[0], device=device)

    def built_model(self, lr=0.005, epochs=1000):
        """Instantiate the model, its Adam optimizer and the loss function.

        Args:
            lr: Adam learning rate (default 0.005).
            epochs: Number of epochs ``start`` runs (default 1000).
        """
        self.model = self.model_(self.node_ft_size, self.label_num).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.epochs = epochs
        self.test_acc_list = []
        self.loss_op = torch.nn.BCEWithLogitsLoss()

    def _edge_lists_to_device(self, edge_lists):
        """Return copies of one graph's edge lists on ``self.device``."""
        return [edge_list.to(self.device) for edge_list in edge_lists]

    def train(self):
        """Run one training epoch with one optimizer step per training graph.

        Every training graph is visited exactly once per epoch, in a random
        order.

        Returns:
            Tuple ``(mean loss, mean micro-F1)`` over the training graphs.

        Raises:
            ValueError: If the dataset has no training graphs.
        """
        train_graphs = self.dataset.train_graphs
        num_graphs = len(train_graphs)
        if num_graphs == 0:
            raise ValueError('no training graphs to train on')
        self.model.train()
        total_loss = 0.0
        train_f1 = 0.0
        # Iterate a shuffled permutation of indices so the precomputed edge
        # lists are looked up by position (no O(n) list.index equality scan,
        # which could also pick the wrong graph if graphs compare equal).
        for idx in torch.randperm(num_graphs).tolist():
            one_train_graph = train_graphs[idx]
            self.optimizer.zero_grad()
            device_n_order_edge_list = self._edge_lists_to_device(self.train_n_order_edge_list[idx])
            pred = self.model(one_train_graph.node_ft, device_n_order_edge_list).float()
            label = one_train_graph.node_label.float()
            loss = self.loss_op(pred, label)
            # Logits > 0 corresponds to a sigmoid probability > 0.5.
            hard_pred = (pred > 0).float().cpu()
            train_f1 += metrics.f1_score(label.cpu(), hard_pred, average='micro')
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return total_loss / num_graphs, train_f1 / num_graphs

    def metrics_model(self, graph_list, n_order_edge_lists):
        """Evaluate the model on a list of graphs without gradient tracking.

        Args:
            graph_list: Graphs to evaluate.
            n_order_edge_lists: Precomputed edge lists aligned index-by-index
                with ``graph_list``.

        Returns:
            Tuple ``(micro-F1, micro-precision, micro-recall)``, each averaged
            over the graphs.

        Raises:
            ValueError: If ``graph_list`` is empty.
        """
        num_graphs = len(graph_list)
        if num_graphs == 0:
            raise ValueError('graph_list is empty; nothing to evaluate')
        self.model.eval()
        total_micro_f1 = 0.0
        total_micro_p = 0.0
        total_micro_r = 0.0
        for idx, graph in enumerate(graph_list):
            # Edge lists are stored off-device; copy this graph's lists over.
            device_n_order_edge_list = self._edge_lists_to_device(n_order_edge_lists[idx])
            with torch.no_grad():
                out = self.model(graph.node_ft, device_n_order_edge_list)
            pred = (out > 0).float().cpu()
            label = graph.node_label.float().cpu()
            total_micro_f1 += metrics.f1_score(label, pred, average='micro')
            total_micro_p += metrics.precision_score(label, pred, average='micro')
            total_micro_r += metrics.recall_score(label, pred, average='micro')
        return (total_micro_f1 / num_graphs,
                total_micro_p / num_graphs,
                total_micro_r / num_graphs)

    def start(self, display=True):
        """Train for ``self.epochs`` epochs, evaluating on the test split after each.

        Args:
            display: When true (default), print the per-epoch loss and metrics.
        """
        print('start train')
        for epoch in range(1, self.epochs + 1):
            loss, train_f1 = self.train()
            test_f1, test_p, test_r = self.metrics_model(self.dataset.test_graphs,
                                                         self.test_n_order_edge_list)
            if display:
                print('epoch = ', epoch, 'loss = ', loss, 'train_f1 = ', train_f1,
                      'test_f1 = ', test_f1, 'test_p = ', test_p, 'test_r = ', test_r)
if __name__ == '__main__':
    # Script entry point: train the GCN2 (V2) variant on PPI with 2-order edge lists.
    # To try the GAT variant instead, import GAT2 from
    # GNN_Implement.model.modified_model.ppi_GAT2_V1 and pass it in place of GCN2.
    from GNN_Implement.model.modified_model.ppi_GCN2_V2 import GCN2
    # Pin to GPU 2 only when it exists; otherwise keep the default device
    # (set_device(2) would fail on CPU-only or fewer-than-3-GPU machines).
    if torch.cuda.is_available() and torch.cuda.device_count() > 2:
        torch.cuda.set_device(2)
    demo_test = NOrderInductiveLearningTest(GCN2, PPI())
    demo_test.start()
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
基于高阶邻居的图神经网络.zip (155个子文件)
ind.pubmed.allx 7.23MB
ind.citeseer.allx 581KB
ind.cora.allx 251KB
ind.pubmed.ally 219KB
ind.citeseer.ally 54KB
ind.cora.ally 47KB
Coradat 15.04MB
Cora.dat 420B
ind.pubmed.graph 461KB
trans.pubmed.graph 460KB
trans.citeseer.graph 61KB
ind.citeseer.graph 61KB
trans.cora.graph 58KB
ind.cora.graph 58KB
ind.pubmed.test.index 6KB
ind.citeseer.test.index 5KB
ind.cora.test.index 5KB
degree.ipynb 3KB
degree-checkpoint.ipynb 3KB
train_graph.json 59.69MB
ppi-G.json 30.22MB
ppi-class_map.json 20.3MB
valid_graph.json 9.15MB
test_graph.json 7.42MB
ppi-id_map.json 868KB
train_labels.npy 41.46MB
ppi-feats.npy 21.72MB
train_feats.npy 17.13MB
valid_labels.npy 6.01MB
test_labels.npy 5.1MB
valid_feats.npy 2.49MB
test_feats.npy 2.11MB
train_graph_id.npy 351KB
valid_graph_id.npy 51KB
test_graph_id.npy 43KB
demo.png 237KB
n_order_inductive_learning_test.py 7KB
HAN_two_order_mat.py 6KB
n_order_inductive_learning_test_input_mat.py 6KB
inductive_learning_test_input_mat.py 6KB
adj_mat.py 6KB
HAN_recur_mat.py 5KB
classical_citation.py 5KB
classical_citation_porcess_file.py 5KB
HAN.py 5KB
HAN_recur.py 5KB
hgcn_model_test.py 5KB
n_order_model_test.py 5KB
inductive_learning_test.py 5KB
ppi_GAT_V2.py 4KB
ppi_GAT2_V1.py 4KB
ppi_data_process_file.py 4KB
data.py 4KB
snap_data_process_file.py 4KB
GAT2_combine.py 4KB
n_order_GAT.py 4KB
gat_conv_input_list.py 4KB
gat_conv_input_mat.py 4KB
HGCN.py 3KB
ppi_GCN2.py 3KB
MLP_test.py 3KB
GCN2_combine.py 3KB
bit_gat_conv_input_list.py 3KB
ppi_GCN2_V2.py 3KB
ppi_HGCN.py 3KB
HGCN.py 3KB
ppi_GCN2_V2_mat.py 3KB
graph_test_base.py 3KB
sample_gat.py 3KB
statistical_scirpt.py 3KB
gcn_conv_input_list.py 3KB
one_graph_dataloader.py 3KB
igcn_conv_input_mat.py 3KB
gat_recurrence.py 3KB
GCN_two_order.py 3KB
ratio_model_test.py 3KB
baseline_model_test.py 3KB
base_process.py 3KB
bit_gat_conv_input_mat.py 2KB
ppi_GCN.py 2KB
ggnn_conv_input_list.py 2KB
aggregation_attn_conv.py 2KB
n_order_aggr_conv.py 2KB
two_order_GAT.py 2KB
add_self_loop.py 2KB
classical_gat.py 2KB
torch_dataset.py 2KB
BitGAT.py 2KB
GGNN.py 2KB
degree_display.py 2KB
GAT.py 2KB
ppi_IGCN.py 2KB
Cora.py 2KB
GAT_recur.py 2KB
process_tools.py 2KB
IGCN.py 2KB
two_order_aggr_conv_pro.py 2KB
two_order_aggr_conv.py 2KB
Alias_method.py 2KB
base_dataset.py 2KB
共 155 条
- 1
- 2
资源评论
博士僧小星
- 粉丝: 2161
- 资源: 5942
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 插入排序的Java实现方法InsertSort
- web前端实现四级联动效果
- 暴风电视刷机数据 X45 T45FUM屏V450HJ1-Q01机编60000AM2200 2300屏参30163802 强制升级
- 计算机Golang语言图书管理系统设计开发教程
- linux常用命令大全 Linux 系统中,如何查看文件内容?
- 作业作业作业作业作业作业作业作业作业作业作业作业作业作业
- Tina-Linux-系统软件-开发指南.pdf
- 精选电力集团信息化数字化转型总体规划数据治理大数据应用支撑平台建设可编辑PPT及word参考资料(4份).zip
- ssm+mysql的基于智慧医疗预约挂号管理系统(源码+lw+ppt)
- Ingress-Controller高可用方案及多租户场景
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功