Skip to content

Commit 124312a

Browse files
committed
Use Tensor.dim() in CGRTensor
Shorter syntax and more readable.
1 parent 58b75d4 commit 124312a

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchhd/tensors/cgr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,17 @@ def bundle(self, other: "CGRTensor") -> "CGRTensor":
6565

6666
# Ensure hypervectors are in the same shape, i.e., [..., 1, DIM]
6767
t1 = self
68-
if len(t1.shape) == 1:
68+
if t1.dim() == 1:
6969
t1 = t1.unsqueeze(0)
7070
t2 = other
71-
if len(t2.shape) == 1:
71+
if t2.dim() == 1:
7272
t2 = t2.unsqueeze(0)
7373

7474
t = torch.stack((t1, t2), dim=-2)
7575
val = t.multibundle()
7676

7777
# Convert shape back to [DIM] if inputs are plain hypervectors
78-
need_squeeze = len(self.shape) == 1 and len(other.shape) == 1
78+
need_squeeze = self.dim() == 1 and other.dim() == 1
7979
if need_squeeze:
8080
return val.squeeze(0)
8181

0 commit comments

Comments
 (0)