diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94d..95514c5d47 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1528,13 +1528,15 @@ defmodule Torchx.Backend do |> then(unfold_flat) |> then(function) + {device, _} = from_nx(tensor) + indices_to_flatten = tensor |> Nx.axes() |> Enum.map(fn axis -> tensor |> Nx.shape() - |> Nx.iota(axis: axis, backend: Torchx.Backend) + |> Nx.iota(axis: axis, backend: {Torchx.Backend, device: device}) |> then(unfold_flat) |> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1) end) diff --git a/torchx/test/torchx/device_test.exs b/torchx/test/torchx/device_test.exs index 0f55a5366a..32e717ee22 100644 --- a/torchx/test/torchx/device_test.exs +++ b/torchx/test/torchx/device_test.exs @@ -45,4 +45,12 @@ defmodule Torchx.DeviceTest do # assert_raise ArgumentError, fn -> Nx.backend_transfer(t) end end end + + describe "indices_to_flatten" do + test "works" do + t = Nx.tensor([[1, 2], [3, 4]], backend: {TB, device: @device}) + t2 = Nx.tensor([[2, 6], [3, 1]], backend: {TB, device: @device}) + assert_equal Nx.window_scatter_max(t, t2, 0, {2, 3}), Nx.tensor([[0, 0, 0, 0, 6, 0], [0, 0, 2, 0, 0, 0], [0, 0, 3, 0, 0, 0], [0, 0, 0, 0, 0, 1]], backend: {TB, device: @device}) + end + end end