Skip to content

Conversation

@giordano
Copy link
Member

@giordano giordano commented Nov 12, 2025

Based on the template of #1626. This is very preliminary, but very basic stuff works if we ignore the result of a matmul is off by ~0.1%, which doesn't look terribly good (accuracy is a lot better on CPU):

julia> using Reactant

julia> Reactant.devices()
32-element Vector{Reactant.XLA.PJRT.Device}:
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001bd189f0), "TT:0 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d000), "TT:1 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d0b0), "TT:2 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d160), "TT:3 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d210), "TT:4 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d2c0), "TT:5 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d370), "TT:6 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d420), "TT:7 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d4d0), "TT:8 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d580), "TT:9 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d630), "TT:10 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d6e0), "TT:11 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d790), "TT:12 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d840), "TT:13 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d8f0), "TT:14 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16d9a0), "TT:15 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16da50), "TT:16 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16db00), "TT:17 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16dbb0), "TT:18 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16dc60), "TT:19 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16dd10), "TT:20 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16ddc0), "TT:21 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16de70), "TT:22 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16df20), "TT:23 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16dfd0), "TT:24 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e080), "TT:25 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e130), "TT:26 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e1e0), "TT:27 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e290), "TT:28 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e340), "TT:29 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e3f0), "TT:30 Wormhole_b0")
 Reactant.XLA.PJRT.Device(Ptr{Nothing}(0x000000001c16e4a0), "TT:31 Wormhole_b0")

julia> A = randn(Float32, 3, 3); B = randn(Float32, 3, 3); # random matrices on CPU

julia> rA = Reactant.to_rarray(A); rB = Reactant.to_rarray(B); # copy matrices to the device

julia> @jit rA * rB # matmul on the device
2025-11-12 18:50:25.481 | info     |          Fabric | TopologyMapper mapping start (mesh=0): n_log=32, n_phys=32, log_deg_hist={2:4, 3:16, 4:12}, phys_deg_hist={4:32} (topology_mapper.cpp:479)
2025-11-12 18:50:25.718 | warning  |           Metal | Opening subset of mmio devices slows down UMD read/write to remote chips. If opening more devices, consider using CreateDevices API. (device_pool.cpp:303)
2025-11-12 18:50:27.586 | info     |           Metal | Enabling program cache on MeshDevice 2 (mesh_device.cpp:642)
3×3 ConcretePJRTArray{Float32,2}:
  0.856026  -0.228477   0.102485
  1.01793    1.36568   -1.67887
 -0.344294  -0.512573   0.685867

julia> A * B # matmul on CPU, as a sanity check
3×3 Matrix{Float32}:
  0.856317  -0.22869    0.102752
  1.0184     1.36624   -1.68034
 -0.344221  -0.512788   0.686363

Ref: #1256, #1235 (CC @p-w-rs)
@fleclairTT you might be interested in this, too.

end
end

has_tt() = true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to figure out how to detect the devices automatically, but apart from that this should be good from my side as an experimental plugin. I'm not sure about the "TT" name everywhere though, bit too short and obscure, but that's how all the TensTorrent tools are called.

@avik-pal
Copy link
Collaborator

Let's add it to https://github.com/EnzymeAD/Reactant.jl/blob/mg/tt-plugin/docs/src/index.md#select-an-accelerator-backend as well (maybe with TensTorrent (Experimental))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants