import os
import copy
import torch
import deepsnap
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
from sklearn.metrics import f1_score
from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul
class HeteroGNNConv(pyg_nn.MessagePassing):
def __init__(self, in_channels_src, in_channels_dst, out_channels):
super(HeteroGNNConv, self).__init__(aggr="mean")
        self.in_channels_src = in_channels_src  # two input feature sizes (src and dst node types) map to a single output size, matching the two node-type inputs described above
self.in_channels_dst = in_channels_dst
self.out_channels = out_channels
        # To simplify the implementation, initialize out_features of both
        # self.lin_dst and self.lin_src to out_channels.
############# Your code here #############
## (~3 lines of code)
## Note:
## 1. Initialize the 3 linear layers.
## 2. Think through the connection between the mathematical
## definition of the update rule and torch linear layers!
self.lin_dst = nn.Linear(in_channels_dst, out_channels) # W_d^{(l)[m]}
self.lin_src = nn.Linear(in_channels_src, out_channels) # W_s^{(l)[m]}
        self.lin_update = nn.Linear(out_channels * 2, out_channels)  # W^{(l)[m]}; no convolution is involved, only linear layers
##########################################
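        # For reference, the three layers line up one-to-one with a standard
        # per-message-type update rule (notation assumed from the accompanying
        # write-up, not fixed by this code):
        #     h_v^{(l)[m]} = W^{(l)[m]} CONCAT(
        #         W_d^{(l)[m]} h_v^{(l-1)},                      # self.lin_dst
        #         W_s^{(l)[m]} AGG({h_u^{(l-1)} : u in N_m(v)})  # self.lin_src
        #     )                                                  # self.lin_update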
def forward(
self,
node_feature_src,
node_feature_dst,
edge_index,
size=None
):
############# Your code here #############
## (~1 line of code)
## Note:
## 1. Unlike Colabs 3 and 4, we just need to call self.propagate with
## proper/custom arguments.
        return self.propagate(edge_index, size=size,  # start message passing
                              node_feature_src=node_feature_src,
                              node_feature_dst=node_feature_dst)
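        # Note (PyG behavior): propagate routes extra keyword arguments by
        # parameter name, so node_feature_src reaches message_and_aggregate and
        # node_feature_dst reaches update; no per-edge x_i/x_j tensors are built.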
    def message_and_aggregate(self, edge_index, node_feature_src):
        ############# Your code here #############
        ## (~1 line of code)
        ## Note:
        ## 1. Different from what we implemented in Colabs 3 and 4, we use message_and_aggregate
        ##    to combine the previously separate message and aggregate functions.
        ##    The benefit is that we can avoid materializing x_i and x_j
        ##    to make the implementation more efficient.
        ## 2. To implement efficiently, refer to the PyG documentation for message_and_aggregate
        ##    and sparse-matrix multiplication:
        ##    https://pytorch-geometric.readthedocs.io/en/latest/notes/sparse_tensor.html
        ## 3. Here edge_index is a torch_sparse SparseTensor. Although interesting, you
        ##    do not need to deeply understand SparseTensor representations!
        ## 4. Conceptually, think through how the message passing and aggregation
        ##    expressed mathematically can be expressed through matrix multiplication.
        out = matmul(edge_index, node_feature_src, reduce=self.aggr)
        ##########################################
        return out
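    # A minimal shape sketch (sizes here are assumed for illustration): if
    # edge_index is a SparseTensor adjacency of shape [num_dst, num_src] and
    # node_feature_src is [num_src, in_channels_src], then
    #     matmul(edge_index, node_feature_src, reduce="mean")
    # yields a [num_dst, in_channels_src] tensor whose row v is the mean of the
    # src features of v's in-neighbors, i.e. message passing plus mean
    # aggregation performed as a single sparse-dense product.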
def update(self, aggr_out, node_feature_dst):
############# Your code here #############
## (~4 lines of code)
## Note:
## 1. The update function is called after message_and_aggregate
## 2. Think through the one-one connection between the mathematical update
## rule and the 3 linear layers defined in the constructor.
aggr_out = self.lin_src(aggr_out)
node_feature_dst = self.lin_dst(node_feature_dst)
        concat_features = torch.cat((node_feature_dst, aggr_out), dim=-1)
        # dim=-1 is the feature dimension (dim 1 for these 2-D tensors)
        aggr_out = self.lin_update(concat_features)
return aggr_out
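    # Shape walk-through (hidden size assumed purely for illustration): with
    # out_channels = 64, lin_src(aggr_out) and lin_dst(node_feature_dst) are
    # each [num_dst, 64]; concatenating on dim=-1 gives [num_dst, 128], which
    # lin_update maps back to [num_dst, 64].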
class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
def __init__(self, convs, args, aggr="mean"):
super(HeteroGNNWrapperConv, self).__init__(convs, None)
self.aggr = aggr
# Map the index and message type
self.mapping = {}
# A numpy array that stores the final attention probability
self.alpha = None
self.attn_proj = None
if self.aggr == "attn":
############# Your code here #############
## (~1 line of code)
## Note:
## 1. Initialize self.attn_proj, where self.attn_proj should include
## two linear layers. Note, make sure you understand
## which part of the equation self.attn_proj captures.
## 2. You should use nn.Sequential for self.attn_proj
## 3. nn.Linear and nn.Tanh are useful.
## 4. You can model a weight vector (rather than matrix) by using:
## nn.Linear(some_size, 1, bias=False).
## 5. The first linear layer should have out_features as args['attn_size']
## 6. You can assume we only have one "head" for the attention.
            ## 7. We recommend implementing the mean aggregation first. Once
            ##    mean aggregation works well in training, you can
            ##    implement this part.
self.attn_proj = nn.Sequential(
nn.Linear(args['hidden_size'], args['attn_size']),
nn.Tanh(), #https://pytorch.org/docs/stable/generated/torch.nn.Tanh.html#torch.nn.Tanh
nn.Linear(args['attn_size'], 1, bias=False),
)
##########################################
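            # For reference, this module computes semantic-level attention
            # scores over message types in the style of HAN (the exact equation
            # is assumed from the accompanying write-up):
            #     e_m = mean_v  q^T tanh(W h_v^{[m]} + b)
            #     alpha_m = softmax(e_m)  over message types m
            # nn.Linear(hidden_size, attn_size) + nn.Tanh computes tanh(W h + b),
            # and the bias-free nn.Linear(attn_size, 1) plays the role of q.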
    def reset_parameters(self):
        super(HeteroGNNWrapperConv, self).reset_parameters()
        if self.aggr == "attn":
            for layer in self.attn_proj.children():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()  # skip nn.Tanh, which has no parameters
def forward(self, node_features, edge_indices):
message_type_emb = {}
for message_key, message_type in edge_indices.items():
src_type, edge_type, dst_type = message_key
node_feature_src = node_features[src_type]
node_feature_dst = node_features[dst_type]
edge_index = edge_indices[message_key]
message_type_emb[message_key] = (
                self.convs[message_key](  # one HeteroGNNConv per message type, e.g. {('paper', 'author', 'paper'): HeteroGNNConv(...), ('paper', 'subject', 'paper'): HeteroGNNConv(...)}
node_feature_src,
node_feature_dst,
edge_index,
)
)
node_emb = {dst: [] for _, _, dst in message_type_emb.keys()}
mapping = {}
for (src, edge_type, dst), item in message_type_emb.items():
mapping[len(node_emb[dst])] = (src, edge_type, dst)
node_emb[dst].append(item)
self.mapping = mapping
for node_type, embs in node_emb.items():
if len(embs) == 1:
node_emb[node_type] = embs[0]
else:
node_emb[node_type] = self.aggregate(embs)
return node_emb
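    # Sketch of expected usage (message types and sizes are illustrative, not
    # fixed by this class):
    #     convs = {
    #         ('paper', 'author', 'paper'): HeteroGNNConv(64, 64, 64),
    #         ('paper', 'subject', 'paper'): HeteroGNNConv(64, 64, 64),
    #     }
    #     wrapper = HeteroGNNWrapperConv(convs, args={'hidden_size': 64, 'attn_size': 32}, aggr="attn")
    #     node_emb = wrapper(node_features, edge_indices)
    # where node_features and edge_indices are dicts keyed by node type and
    # message type respectively, as iterated over in forward above.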
def aggregate(self, xs):
# TODO: Implement this function that aggregates all message type results.
# Here, xs is a list of tensors (embeddings) with respect to message
# type aggregation results.
        ## Note:
        ## 1. Explore the function parameter `xs`!
        if self.aggr == "mean":
            # torch.stack joins same-shaped tensors along a new trailing
            # dimension (e.g. torch.Size([3025, 64, 2])); averaging over that
            # dimension means over message types.
            x = torch.stack(xs, dim=-1)
            return x.mean(dim=-1)
elif self.aggr == "attn":
N = xs[0].shape[0] # Number of nodes for that node type
M = len(xs) # Number of message types for that node type