From ee647ba499cf5439d6a5466ac6afa618112477fd Mon Sep 17 00:00:00 2001 From: fcy540 <57482812+fcy540@users.noreply.github.com> Date: Mon, 23 Dec 2024 01:47:35 +0800 Subject: [PATCH] Update model.py --- svi_percept/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/svi_percept/model.py b/svi_percept/model.py index 3abe3d3..7f53ab3 100644 --- a/svi_percept/model.py +++ b/svi_percept/model.py @@ -76,5 +76,5 @@ def forward(self, features, **kwargs): kweights = torch.pow(10.0, raw_weights) # [batch_size, self.k] kweights = torch.nn.functional.softmax(kweights, dim=1) # [batch_size, self.k] kscores = scores[indices] # [batch_size, self.k] - results[:, cat_i] = torch.sum(kscores * kweights, dim=1) + results[:, cat_i] = torch.sum(kscores * kweights, dim=1).cpu() return {"results": results}