-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathwalk_encoder.py
More file actions
160 lines (138 loc) · 5.26 KB
/
Copy pathwalk_encoder.py
File metadata and controls
160 lines (138 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
from torch import nn
from einops.layers.torch import Rearrange
from torch_scatter import scatter_mean, scatter_sum
def get_1d_sine_pos_embed(dim, length, base=10000.0):
    """Return a fixed sinusoidal position table of shape ``(length, dim)``.

    Even columns hold ``sin(pos * base ** (-2i / dim))`` and odd columns the
    matching ``cos`` (the "Attention Is All You Need" layout). An odd ``dim``
    is supported; its last column is a sine.

    ``WalkEncoder`` uses this to fill its non-trainable ``pos_embed`` table.
    """
    position = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = position * inv_freq  # (length, ceil(dim / 2))
    embed = torch.zeros(length, dim)
    embed[:, 0::2] = torch.sin(angles)
    embed[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return embed


class WalkEncoder(nn.Module):
    """Encode walks over a graph and pool the walk states back onto nodes/edges.

    Node features are gathered along each walk. They are optionally enriched
    with projected edge features, projected walk positional features and a
    fixed sinusoidal position table, then optionally layer-normalized. Next
    they pass through a sequence layer (depthwise conv stack or Mamba).
    Finally they are averaged back per node (and per edge) with
    ``scatter_mean``. Node results are added to ``x``; edge results are added
    to ``edge_attr`` (when given) and layer-normalized.

    Args:
        hidden_size: feature width of nodes, edges and walk states.
        sequence_layer_type: ``'conv'`` or ``'mamba'``; any other value raises
            ``NotImplementedError``.
        d_state: forwarded to ``mamba_ssm.Mamba``.
        d_conv: kernel size of the depthwise conv (conv mode), or forwarded
            to ``mamba_ssm.Mamba`` (mamba mode).
        expand: forwarded to ``mamba_ssm.Mamba``.
        num_heads: accepted but not used by this class.
        mlp_ratio: accepted but not used by this class.
        use_edge_proj: mix edge features into walks and produce updated edge
            features.
        use_encoder_norm: apply a LayerNorm before the sequence layer.
        proj_mlp_ratio: hidden-width multiplier of the output MLPs.
        dropout: dropout probability at the end of the output MLPs.
        walk_length: number of rows of the fixed ``pos_embed`` table.
        use_positional_encoding: project and add ``walk_pe`` in ``forward``.
        pos_embed: add a fixed (non-trainable) sinusoidal position table.
        window_size: ``walk_pe`` is expected to have ``2 * window_size - 1``
            features.
        bidirection: in mamba mode, also run a second Mamba over the reversed
            walk and average both directions; ignored in conv mode.
        layer_idx: forwarded to ``mamba_ssm.Mamba``.
    """

    def __init__(
        self,
        hidden_size,
        sequence_layer_type='conv',
        d_state=16,
        d_conv=9,
        expand=1,
        num_heads=4,
        mlp_ratio=1,
        use_edge_proj=True,
        use_encoder_norm=True,
        proj_mlp_ratio=1,
        dropout=0.,
        walk_length=50,
        use_positional_encoding=True,
        pos_embed=False,
        window_size=8,
        bidirection=False,
        layer_idx=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_positional_encoding = use_positional_encoding
        self.bidirection = bidirection
        self.use_edge_proj = use_edge_proj

        # Width of the per-step walk positional features passed as `walk_pe`.
        walk_pe_dim = window_size * 2 - 1 if use_positional_encoding else 0

        if use_edge_proj:
            self.edge_proj = nn.Linear(hidden_size, hidden_size)
        # NOTE: built even when positional encoding is disabled (then with
        # in_features=0); forward() only applies it when the flag is set.
        self.walk_pe_proj = nn.Linear(walk_pe_dim, hidden_size)

        self.pos_embed = None
        if pos_embed:
            # Fixed sinusoidal table stored as a frozen Parameter.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, walk_length, hidden_size), requires_grad=False
            )
            self.pos_embed.data.copy_(
                get_1d_sine_pos_embed(hidden_size, walk_length).unsqueeze(0)
            )

        self.norm = None
        if use_encoder_norm:
            self.norm = nn.LayerNorm(hidden_size, eps=1e-05)

        # Build the sequence layer.
        self.seq_layer_backward = None
        self.sequence_layer_type = sequence_layer_type
        if sequence_layer_type == "conv":
            self.seq_layer = nn.Sequential(
                # Move the walk axis last so Conv1d convolves along it.
                Rearrange("a b c -> a c b"),
                # Depthwise conv along the walk. padding="same" preserves the
                # walk length for every kernel size; a symmetric d_conv // 2
                # padding would add one step for even kernels and break the
                # scatter in forward().
                nn.Conv1d(hidden_size, hidden_size, d_conv, groups=hidden_size, padding="same"),
                nn.ReLU(),
                nn.Conv1d(hidden_size, hidden_size, 1, padding=0),
                nn.ReLU(),
                Rearrange("a b c -> a c b"),
            )
        elif sequence_layer_type == 'mamba':
            # Optional dependency, only imported when requested.
            from mamba_ssm import Mamba
            self.seq_layer = Mamba(
                d_model=hidden_size,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                layer_idx=layer_idx,
            )
            if bidirection:
                self.seq_layer_backward = Mamba(
                    d_model=hidden_size,
                    d_state=d_state,
                    d_conv=d_conv,
                    expand=expand,
                    layer_idx=layer_idx,
                )
        else:
            raise NotImplementedError(
                f"not supported sequence layer type: {sequence_layer_type}"
            )

        def out_mlp():
            # Two-layer MLP with dropout, shared shape for node and edge outputs.
            return nn.Sequential(
                nn.Linear(hidden_size, hidden_size * proj_mlp_ratio),
                nn.ReLU(),
                nn.Linear(hidden_size * proj_mlp_ratio, hidden_size),
                nn.Dropout(dropout),
            )

        self.out_node_proj = out_mlp()
        if use_edge_proj:
            self.out_edge_proj = out_mlp()
            self.edge_ln = nn.LayerNorm(hidden_size, eps=1e-05)

    def reset_parameters(self):
        """Call ``reset_parameters`` on every descendant module that defines it.

        Children are visited depth-first. A child without
        ``reset_parameters`` is still recursed into. The frozen ``pos_embed``
        is a direct Parameter of this module and is left untouched.

        NOTE(review): this also resets the submodules of third-party layers
        such as ``Mamba`` with their default PyTorch initialization. That may
        differ from the initialization those layers apply in their own
        constructors; confirm this is intended.
        """
        def reset_module_parameters(module):
            for layer in module.children():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
                reset_module_parameters(layer)

        reset_module_parameters(self)

    def forward(self, x, edge_index, walk_node_index, walk_edge_index, walk_pe, num_nodes, edge_attr=None):
        """Encode the walks and return updated ``(x, edge_attr)``.

        Args:
            x: node features; ``x[walk_node_index]`` gathers them along walks.
            edge_index: only its last dimension (the number of edges) is used,
                as the output size of the edge aggregation.
            walk_node_index: node id at each walk step, indexed as
                ``[walk, step]``.
            walk_edge_index: edge id at each walk step. It is expected to have
                the same ``[walk, step]`` layout as ``walk_node_index``, and
                its last step is excluded from the edge aggregation.
            walk_pe: per-step positional features with ``2 * window_size - 1``
                channels, or None.
            num_nodes: number of rows of the node aggregation.
            edge_attr: edge features, or None.

        Returns:
            ``x`` plus the projected per-node mean of the walk states, and the
            updated edge features. When ``use_edge_proj`` is False,
            ``edge_attr`` is returned unchanged.
        """
        walk_x = x[walk_node_index]
        if self.use_edge_proj and edge_attr is not None:
            walk_x = walk_x + self.edge_proj(edge_attr[walk_edge_index])
        if self.use_positional_encoding and walk_pe is not None:
            walk_x = walk_x + self.walk_pe_proj(walk_pe)
        if self.pos_embed is not None:
            # Slice the table so walks shorter than `walk_length` broadcast too.
            walk_x = walk_x + self.pos_embed[:, : walk_x.shape[1]]
        if self.norm is not None:
            walk_x = self.norm(walk_x)

        encoded = self.seq_layer(walk_x)
        if self.seq_layer_backward is not None:
            # Encode the reversed walks and average both directions.
            reverse = self.seq_layer_backward(walk_x.flip([1])).flip([1])
            encoded = (encoded + reverse) * 0.5
            del reverse
        walk_x = encoded

        # Mean of all walk states that visit each node.
        node_agg = scatter_mean(
            walk_x.reshape(-1, self.hidden_size),
            walk_node_index.flatten(),
            dim=0,
            dim_size=num_nodes,
        )
        x = x + self.out_node_proj(node_agg)
        del node_agg

        if self.use_edge_proj:
            # The last walk step is dropped, presumably because a walk of L
            # nodes traverses only L - 1 edges.
            edge_agg = scatter_mean(
                walk_x[:, :-1].reshape(-1, self.hidden_size),
                walk_edge_index[:, :-1].flatten(),
                dim=0,
                dim_size=edge_index.shape[-1],
            )
            edge_update = self.out_edge_proj(edge_agg)
            edge_attr = edge_update if edge_attr is None else edge_attr + edge_update
            edge_attr = self.edge_ln(edge_attr)
        return x, edge_attr