import logging
import random
import warnings

import torch
import torch.nn.functional as F

from dataset import *
from gcn import GCN
# Blanket-suppress every warning for this script.
# NOTE(review): this also hides deprecation/numerical warnings raised by torch;
# consider narrowing the filter to specific categories or modules.
warnings.filterwarnings('ignore')

# Module logger: INFO level, records formatted with timestamp, logger name and
# level, emitted through a StreamHandler (stderr by default).
logger = logging.getLogger("Graph Classification")
# setLevel (rather than assigning `.level`) validates the level and clears the
# logging manager's cached isEnabledFor() results.
logger.setLevel(logging.INFO)
format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(format_str)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
# Only attach once: re-running/reloading this module in the same process would
# otherwise add a second handler and print every record twice.
if not logger.handlers:
    logger.addHandler(stream_handler)
def _batches(indices, batch_size):
    """Yield consecutive slices of ``indices`` holding at most ``batch_size`` items."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


def train(epochs=100, lr=0.1, train_ratio=0.8, seed=None):
    """Train a GCN graph classifier and log its validation accuracy.

    The sample indices of ``DatasetWrapper()`` are shuffled and split into a
    training part (the first ``int(train_ratio * len(dataset))`` shuffled
    indices) and a validation part (the rest). The whole training part is fed
    as a single batch every epoch; the model is optimised with Adam on the
    cross-entropy loss and then evaluated on the validation part.

    Args:
        epochs: Number of passes over the training split.
        lr: Adam learning rate.
        train_ratio: Fraction of the samples used for training.
        seed: Optional seed for the train/validation shuffle. ``None`` uses the
            global ``random`` state, as before.

    Raises:
        ValueError: If the split would leave the training or the validation
            set empty (e.g. a dataset with fewer than two samples).
    """
    # load data
    logger.info('Loading data...')
    dataset = DatasetWrapper()
    num_samples = len(dataset)
    train_size = int(train_ratio * num_samples)
    if not 0 < train_size < num_samples:
        # A degenerate split used to fail later with an opaque error
        # (a zero `range` step or a division by zero in the accuracy).
        raise ValueError(
            'Cannot split {} samples with train_ratio={}: training and '
            'validation sets must both be non-empty'.format(num_samples, train_ratio))

    # Sizes passed positionally to GCN(dim_feat, dim_hidden, gclasses).
    dim_feat = 1
    dim_hidden = 16
    gclasses = 2
    # Full-batch training: the whole training split forms one batch.
    batch_size = train_size

    # train and validation data indices
    random_indices = list(range(num_samples))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(random_indices)
    train_idx = random_indices[:train_size]
    val_idx = random_indices[train_size:]
    # Ceiling division so a trailing partial batch is counted in the log.
    num_train_batches = -(-len(train_idx) // batch_size)

    logger.info('Building model...')
    model = GCN(dim_feat, dim_hidden, gclasses)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    logger.info('Training...')
    model.train()
    for epoch in range(epochs):
        for i, indices in enumerate(_batches(train_idx, batch_size)):
            # NOTE(review): assumes dataset[list_of_indices] returns a
            # (batched data, labels) pair -- confirm against DatasetWrapper.
            data, labels = dataset[indices]
            # as_tensor avoids an unnecessary copy (and torch's copy-construct
            # warning) when labels is already a tensor.
            labels = torch.as_tensor(labels, dtype=torch.long)
            pred = model(data)
            loss = F.cross_entropy(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logger.info('Epoch: %s/%s | Batch: %s/%s | Loss: %s',
                        epoch + 1, epochs, i + 1, num_train_batches, loss.item())

    logger.info('Evaluating...')
    model.eval()
    num_correct = 0
    num_tests = 0
    with torch.no_grad():
        for indices in _batches(val_idx, batch_size):
            data, labels = dataset[indices]
            labels = torch.as_tensor(labels, dtype=torch.long)
            pred = model(data)
            # assumes pred holds per-class scores along dim 1 -- TODO confirm
            num_correct += (pred.argmax(1) == labels).sum().item()
            num_tests += len(labels)
    # The split check above guarantees val_idx is non-empty.
    logger.info('Accuracy: %s', num_correct / num_tests)
if __name__ == '__main__':
    # Script entry point: run the full train-then-evaluate pipeline.
    train()
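# NOTE: the lines below appear to be residue from the download page this file
# was copied from (archive name, view count, upload date, uploader and
# related-resource links). They are kept verbatim as comments so the module
# still parses.
# GCN.zip
# 版权申诉
# 94 浏览量
# 2023-08-25
# 11:26:11
# 上传
# 评论
# 收藏 2.64MB ZIP 举报
# sjx_alo
# - 粉丝: 1w+
# - 资源: 1216
# 最新资源
# - (完整)数据库课程设计餐厅点餐说明书-21ab6d3c8beb172ded630b1c59eef8c75ebf952c.doc
# - 2023-04-06-项目笔记 - 第一百五十四阶段 - 4.4.2.152全局变量的作用域-152 -2024.06.04
# - 松哥解协议松哥解协议松哥解协议松哥解协议松哥解协议
# - 618节日618节日618节日
# - tensorflow-gpu-2.9.1-cp37-cp37m-win-amd64.whl
# - tensorflow-gpu-2.9.0-cp37-cp37m-win-amd64.whl
# - tensorflow-gpu-2.9.0-cp39-cp39-win-amd64.whl
# - lcd daimalcd daima
# - 电影领域-推荐算法-个性化内容-观影决策-电影推荐小程序.zip
# - 电气控制PLC考试题库
# 资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
# 点击此处反馈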