Commit 6f92bc3 (1 parent: fef62bc)

add variadic_sample, multi_slice & add more instances for spmm/rspmm

12 files changed: +152 −41 lines

doc/source/api/layers.rst

Lines changed: 6 additions & 0 deletions

@@ -202,6 +202,8 @@ Variadic
 
 .. autofunction:: variadic_log_softmax
 
+.. autofunction:: variadic_softmax
+
 .. autofunction:: variadic_sort
 
 .. autofunction:: variadic_topk
@@ -210,6 +212,8 @@ Variadic
 
 .. autofunction:: variadic_randperm
 
+.. autofunction:: variadic_sample
+
 Tensor Reduction
 ^^^^^^^^^^^^^^^^
 .. autofunction:: masked_mean
@@ -222,6 +226,8 @@ Tensor Construction
 
 .. autofunction:: one_hot
 
+.. autofunction:: multi_slice
+
 .. autofunction:: multi_slice_mask
 
 Sampling

torchdrug/layers/conv.py

Lines changed: 0 additions & 16 deletions

@@ -779,19 +779,3 @@ def forward(self, graph, input):
     def combine(self, input, update):
         output = input + update
         return output
-
-if __name__ == "__main__":
-    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
-    node_position = torch.randn(6, 3)
-    graph = data.Graph(edge_list, num_node=6)
-    with graph.node():
-        graph.node_position = node_position
-    layer = ContinuousFilterConv(16, 16)
-
-    for i in range(100):
-        input = torch.randn(6, 16)
-
-        message = layer.message(graph, input)
-        output1 = layer.aggregate(graph, message)
-        output2 = layer.message_and_aggregate(graph, input)
-        assert torch.allclose(output1, output2)
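The deleted block above was an inline smoke test asserting that the two-step message/aggregate path of ContinuousFilterConv matches its fused message_and_aggregate. A minimal sketch of the same check as a standalone pytest-style test (the test name and tolerance are illustrative; the API calls mirror the deleted snippet):

import torch
from torchdrug import data, layers

def test_message_and_aggregate_equivalence():
    # Ring graph of 6 nodes with random 3-D positions, as in the deleted snippet.
    edge_list = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 0]]
    graph = data.Graph(edge_list, num_node=6)
    with graph.node():
        graph.node_position = torch.randn(6, 3)
    layer = layers.ContinuousFilterConv(16, 16)

    input = torch.randn(6, 16)
    message = layer.message(graph, input)
    output1 = layer.aggregate(graph, message)
    # The fused kernel should agree with message followed by aggregate.
    output2 = layer.message_and_aggregate(graph, input)
    assert torch.allclose(output1, output2, atol=1e-6)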
torchdrug/layers/functional/__init__.py

Lines changed: 7 additions & 6 deletions

@@ -1,15 +1,16 @@
-from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice_mask, as_mask, \
-    _size_to_index, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, variadic_max, \
-    variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, one_hot, \
-    clipped_policy_gradient_objective, policy_gradient_objective
+from .functional import multinomial, masked_mean, mean_with_nan, shifted_softplus, multi_slice, multi_slice_mask, \
+    as_mask, _size_to_index, _extend, variadic_log_softmax, variadic_softmax, variadic_sum, variadic_mean, \
+    variadic_max, variadic_cross_entropy, variadic_sort, variadic_topk, variadic_arange, variadic_randperm, \
+    variadic_sample, one_hot, clipped_policy_gradient_objective, policy_gradient_objective
 from .embedding import transe_score, distmult_score, complex_score, simple_score, rotate_score
 from .spmm import generalized_spmm, generalized_rspmm
 
 __all__ = [
     "multinomial", "masked_mean", "mean_with_nan", "shifted_softplus", "multi_slice_mask", "as_mask",
     "variadic_log_softmax", "variadic_softmax", "variadic_sum", "variadic_mean", "variadic_max",
-    "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm", "one_hot",
-    "clipped_policy_gradient_objective", "policy_gradient_objective",
+    "variadic_cross_entropy", "variadic_sort", "variadic_topk", "variadic_arange", "variadic_randperm",
+    "variadic_sample",
+    "one_hot", "clipped_policy_gradient_objective", "policy_gradient_objective",
     "transe_score", "distmult_score", "complex_score", "simple_score", "rotate_score",
     "generalized_spmm", "generalized_rspmm",
 ]

torchdrug/layers/functional/extension/rspmm.cpp

Lines changed: 3 additions & 0 deletions

@@ -245,6 +245,9 @@ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
 DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
 DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
 
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
 DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
 DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
 
torchdrug/layers/functional/extension/rspmm.cu

Lines changed: 3 additions & 0 deletions

@@ -362,6 +362,9 @@ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
 DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
 DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
 
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
 DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
 DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
 
torchdrug/layers/functional/extension/rspmm.h

Lines changed: 10 additions & 0 deletions

@@ -35,6 +35,11 @@ Tensor rspmm_max_mul_forward_cpu(const SparseTensor &sparse, const Tensor &relat
 std::tuple<SparseTensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(const SparseTensor &sparse,
         const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad);
 
+Tensor rspmm_add_add_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor, Tensor> rspmm_add_add_backward_cpu(const SparseTensor &sparse,
+        const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
 Tensor rspmm_min_add_forward_cpu(const SparseTensor &sparse, const Tensor &relation, const Tensor &input);
 
 std::tuple<SparseTensor, Tensor, Tensor> rspmm_min_add_backward_cpu(const SparseTensor &sparse,
@@ -61,6 +66,11 @@ Tensor rspmm_max_mul_forward_cuda(const SparseTensor &sparse, const Tensor &rela
 std::tuple<SparseTensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(const SparseTensor &sparse,
         const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad);
 
+Tensor rspmm_add_add_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor, Tensor> rspmm_add_add_backward_cuda(const SparseTensor &sparse,
+        const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
 Tensor rspmm_min_add_forward_cuda(const SparseTensor &sparse, const Tensor &relation, const Tensor &input);
 
 std::tuple<SparseTensor, Tensor, Tensor> rspmm_min_add_backward_cuda(const SparseTensor &sparse,
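These declarations register the new (add, add) semiring for relational sparse-dense matrix multiplication, alongside the existing (add|min|max, mul) and (min|max, add) instances. A minimal sketch of reaching the new kernels from Python via generalized_rspmm, assuming the sum/mul keyword names of the wrapper exported from torchdrug.layers.functional (the toy adjacency below is illustrative):

import torch
from torchdrug.layers import functional

num_node, num_relation, dim = 4, 2, 8
# (row, column, relation) triples of a toy relational adjacency tensor
indices = torch.tensor([[0, 1, 2], [1, 2, 3], [0, 1, 0]])
adjacency = torch.sparse_coo_tensor(indices, torch.ones(3),
                                    (num_node, num_node, num_relation)).coalesce()
relation = torch.randn(num_relation, dim)  # one vector per relation type
input = torch.randn(num_node, dim)         # one vector per node

# sum="add", mul="add" should dispatch to the rspmm_add_add_* kernels declared here:
# each edge combines relation and neighbor features additively, then sums per row.
output = functional.generalized_rspmm(adjacency, relation, input, sum="add", mul="add")
assert output.shape == (num_node, dim)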

torchdrug/layers/functional/extension/spmm.cpp

Lines changed: 11 additions & 0 deletions

@@ -212,6 +212,9 @@ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
 DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
 DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
 
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
 DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
 DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
 
@@ -227,6 +230,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("spmm_min_mul_backward_cpu", &at::spmm_min_mul_backward_cpu);
     m.def("spmm_max_mul_forward_cpu", &at::spmm_max_mul_forward_cpu);
     m.def("spmm_max_mul_backward_cpu", &at::spmm_max_mul_backward_cpu);
+    m.def("spmm_add_add_forward_cpu", &at::spmm_add_add_forward_cpu);
+    m.def("spmm_add_add_backward_cpu", &at::spmm_add_add_backward_cpu);
     m.def("spmm_min_add_forward_cpu", &at::spmm_min_add_forward_cpu);
     m.def("spmm_min_add_backward_cpu", &at::spmm_min_add_backward_cpu);
     m.def("spmm_max_add_forward_cpu", &at::spmm_max_add_forward_cpu);
@@ -237,6 +242,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("rspmm_min_mul_backward_cpu", &at::rspmm_min_mul_backward_cpu);
     m.def("rspmm_max_mul_forward_cpu", &at::rspmm_max_mul_forward_cpu);
     m.def("rspmm_max_mul_backward_cpu", &at::rspmm_max_mul_backward_cpu);
+    m.def("rspmm_add_add_forward_cpu", &at::rspmm_add_add_forward_cpu);
+    m.def("rspmm_add_add_backward_cpu", &at::rspmm_add_add_backward_cpu);
     m.def("rspmm_min_add_forward_cpu", &at::rspmm_min_add_forward_cpu);
     m.def("rspmm_min_add_backward_cpu", &at::rspmm_min_add_backward_cpu);
     m.def("rspmm_max_add_forward_cpu", &at::rspmm_max_add_forward_cpu);
@@ -248,6 +255,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("spmm_min_mul_backward_cuda", &at::spmm_min_mul_backward_cuda);
     m.def("spmm_max_mul_forward_cuda", &at::spmm_max_mul_forward_cuda);
     m.def("spmm_max_mul_backward_cuda", &at::spmm_max_mul_backward_cuda);
+    m.def("spmm_add_add_forward_cuda", &at::spmm_add_add_forward_cuda);
+    m.def("spmm_add_add_backward_cuda", &at::spmm_add_add_backward_cuda);
     m.def("spmm_min_add_forward_cuda", &at::spmm_min_add_forward_cuda);
     m.def("spmm_min_add_backward_cuda", &at::spmm_min_add_backward_cuda);
     m.def("spmm_max_add_forward_cuda", &at::spmm_max_add_forward_cuda);
@@ -258,6 +267,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("rspmm_min_mul_backward_cuda", &at::rspmm_min_mul_backward_cuda);
     m.def("rspmm_max_mul_forward_cuda", &at::rspmm_max_mul_forward_cuda);
     m.def("rspmm_max_mul_backward_cuda", &at::rspmm_max_mul_backward_cuda);
+    m.def("rspmm_add_add_forward_cuda", &at::rspmm_add_add_forward_cuda);
+    m.def("rspmm_add_add_backward_cuda", &at::rspmm_add_add_backward_cuda);
     m.def("rspmm_min_add_forward_cuda", &at::rspmm_min_add_forward_cuda);
     m.def("rspmm_min_add_backward_cuda", &at::rspmm_min_add_backward_cuda);
     m.def("rspmm_max_add_forward_cuda", &at::rspmm_max_add_forward_cuda);

torchdrug/layers/functional/extension/spmm.cu

Lines changed: 3 additions & 0 deletions

@@ -321,6 +321,9 @@ DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
 DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
 DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
 
+DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
+
 DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
 DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
 
torchdrug/layers/functional/extension/spmm.h

Lines changed: 10 additions & 0 deletions

@@ -35,6 +35,11 @@ Tensor spmm_max_mul_forward_cpu(const SparseTensor &sparse, const Tensor &input)
 std::tuple<SparseTensor, Tensor> spmm_max_mul_backward_cpu(
         const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
 
+Tensor spmm_add_add_forward_cpu(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_add_backward_cpu(
+        const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
 Tensor spmm_min_add_forward_cpu(const SparseTensor &sparse, const Tensor &input);
 
 std::tuple<SparseTensor, Tensor> spmm_min_add_backward_cpu(
@@ -61,6 +66,11 @@ Tensor spmm_max_mul_forward_cuda(const SparseTensor &sparse, const Tensor &input
 std::tuple<SparseTensor, Tensor> spmm_max_mul_backward_cuda(
         const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
 
+Tensor spmm_add_add_forward_cuda(const SparseTensor &sparse, const Tensor &input);
+
+std::tuple<SparseTensor, Tensor> spmm_add_add_backward_cuda(
+        const SparseTensor &sparse, const Tensor &input, const Tensor &output, const Tensor &output_grad);
+
 Tensor spmm_min_add_forward_cuda(const SparseTensor &sparse, const Tensor &input);
 
 std::tuple<SparseTensor, Tensor> spmm_min_add_backward_cuda(
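The same (add, add) instance is declared for the non-relational spmm kernels. A companion sketch via generalized_spmm, under the same assumption about the sum/mul keyword names:

import torch
from torchdrug.layers import functional

num_node, dim = 4, 8
indices = torch.tensor([[0, 1, 2], [1, 2, 3]])  # (row, column) pairs
adjacency = torch.sparse_coo_tensor(indices, torch.ones(3),
                                    (num_node, num_node)).coalesce()
input = torch.randn(num_node, dim)

# Dispatches to the spmm_add_add_* kernels declared above: edge weight and
# neighbor feature combine additively, and contributions are summed per row.
output = functional.generalized_spmm(adjacency, input, sum="add", mul="add")
assert output.shape == (num_node, dim)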

torchdrug/layers/functional/functional.py

Lines changed: 48 additions & 18 deletions

@@ -60,13 +60,41 @@ def shifted_softplus(input):
     return F.softplus(input) - F.softplus(torch.zeros(1, device=input.device))
 
 
+def multi_slice(starts, ends):
+    """
+    Compute the union of indexes in multiple slices.
+
+    Example::
+
+        >>> indexes = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
+        >>> assert (indexes == torch.tensor([0, 1, 2, 4, 5])).all()
+
+    Parameters:
+        starts (LongTensor): start indexes of slices
+        ends (LongTensor): end indexes of slices
+    """
+    values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+    slices = torch.cat([starts, ends])
+    slices, order = slices.sort()
+    values = values[order]
+    depth = values.cumsum(0)
+    valid = ((values == 1) & (depth == 1)) | ((values == -1) & (depth == 0))  # boundaries of merged runs
+    slices = slices[valid]
+
+    starts, ends = slices.view(-1, 2).t()
+    size = ends - starts
+    indexes = variadic_arange(size)
+    indexes = indexes + starts.repeat_interleave(size)
+    return indexes
+
+
 def multi_slice_mask(starts, ends, length):
     """
     Compute the union of multiple slices into a binary mask.
 
     Example::
 
-        >>> mask = F.multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
+        >>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
         >>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
 
     Parameters:
@@ -75,10 +103,10 @@ def multi_slice_mask(starts, ends, length):
         length (int): length of mask
     """
     values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
-    indexes = torch.cat([starts, ends])
-    if indexes.numel():
-        assert indexes.min() >= 0 and indexes.max() <= length
-    mask = scatter_add(values, indexes, dim_size=length + 1)[:-1]
+    slices = torch.cat([starts, ends])
+    if slices.numel():
+        assert slices.min() >= 0 and slices.max() <= length
+    mask = scatter_add(values, slices, dim_size=length + 1)[:-1]
     mask = mask.cumsum(0).bool()
     return mask
 
@@ -108,12 +136,8 @@ def _size_to_index(size):
     Parameters:
         size (LongTensor): size of each sample
     """
-    cum_size = size.cumsum(0)
-    # special case 1: size[-1] = 0
-    index = cum_size[cum_size < cum_size[-1]]
-    # special case 2: size[i] = size[i+1] = 0
-    index2sample = scatter_add(torch.ones_like(index), index, dim_size=cum_size[-1])
-    index2sample = index2sample.cumsum(0)
+    range = torch.arange(len(size), device=size.device)
+    index2sample = range.repeat_interleave(size)
     return index2sample
 
 
@@ -363,7 +387,7 @@ def variadic_sort(input, size, descending=False):
     return value, index
 
 
-def variadic_arange(size, device=None):
+def variadic_arange(size):
     """
     Return a 1-D tensor that contains integer intervals of variadic sizes.
     This is a variadic variant of ``torch.arange(stop).expand(batch_size, -1)``.
@@ -372,17 +396,15 @@ def variadic_arange(size, device=None):
 
     Parameters:
         size (LongTensor): size of intervals of shape :math:`(N,)`
-        device (torch.device, optional): device of the tensor
     """
-    index2sample = _size_to_index(size)
     starts = size.cumsum(0) - size
 
-    range = torch.arange(size.sum(), device=device)
-    range = range - starts[index2sample]
+    range = torch.arange(size.sum(), device=size.device)
+    range = range - starts.repeat_interleave(size)
     return range
 
 
-def variadic_randperm(size, device=None):
+def variadic_randperm(size):
     """
     Return random permutations for sets with variadic sizes.
     The ``i``-th permutation contains integers from 0 to ``size[i] - 1``.
@@ -393,11 +415,19 @@ def variadic_randperm(size, device=None):
         size (LongTensor): size of sets of shape :math:`(N,)`
         device (torch.device, optional): device of the tensor
     """
-    rand = torch.rand(size.sum(), device=device)
+    rand = torch.rand(size.sum(), device=size.device)
     perm = variadic_sort(rand, size)[1]
     return perm
 
 
+def variadic_sample(input, size, k):
+    rand = torch.rand(len(size), k, device=size.device)
+    index = (rand * size.unsqueeze(-1)).long()  # uniform position within each set
+    index = index + (size.cumsum(0) - size).unsqueeze(-1)  # offset by the set's start
+    sample = input[index]
+    return sample
+
+
 def one_hot(index, size):
     """
     Expand indexes into one-hot vectors.
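For reference, a short usage sketch of the two new helpers, assuming they are imported from torchdrug.layers.functional as wired up in __init__.py above (the packed-set values are illustrative):

import torch
from torchdrug.layers import functional

# multi_slice: the union of the half-open ranges [0, 2), [1, 3), [4, 6) as flat indexes.
indexes = functional.multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
assert (indexes == torch.tensor([0, 1, 2, 4, 5])).all()

# variadic_sample: draw k samples with replacement from each packed set.
# Two sets stored back to back: {10, 11, 12} and {20, 21}.
input = torch.tensor([10, 11, 12, 20, 21])
size = torch.tensor([3, 2])
sample = functional.variadic_sample(input, size, k=2)
assert sample.shape == (2, 2)  # one row of two samples per set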
