Skip to content
4 changes: 4 additions & 0 deletions torchhd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
from torchhd.tensors.fhrr import FHRRTensor
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.basemcr import BaseMCRTensor
from torchhd.tensors.mcr import MCRTensor
from torchhd.tensors.cgr import CGRTensor

from torchhd.functional import (
ensure_vsa_tensor,
Expand Down Expand Up @@ -91,7 +93,9 @@
"FHRRTensor",
"BSBCTensor",
"VTBTensor",
"BaseMCRTensor",
"MCRTensor",
"CGRTensor",
"functional",
"embeddings",
"structures",
Expand Down
7 changes: 5 additions & 2 deletions torchhd/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from torchhd.tensors.bsbc import BSBCTensor
from torchhd.tensors.vtb import VTBTensor
from torchhd.tensors.mcr import MCRTensor
from torchhd.tensors.cgr import CGRTensor
from torchhd.types import VSAOptions


Expand Down Expand Up @@ -93,6 +94,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
return VTBTensor
elif vsa == "MCR":
return MCRTensor
elif vsa == "CGR":
return CGRTensor

raise ValueError(f"Provided VSA model is not supported, specified: {vsa}")

Expand Down Expand Up @@ -361,7 +364,7 @@ def level(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC" or vsa == "MCR":
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
hv.block_size = span_hv.block_size

for i in range(num_vectors):
Expand Down Expand Up @@ -588,7 +591,7 @@ def circular(
device=span_hv.device,
).as_subclass(vsa_tensor)

if vsa == "BSBC" or vsa == "MCR":
if vsa == "BSBC" or vsa == "MCR" or vsa == "CGR":
hv.block_size = span_hv.block_size

mutation_history = deque()
Expand Down
Loading