
Commit 79d177d

jwfromm authored and meta-codesync[bot] committed
Fix workspace allocation for f8f8bf16_rowwise_batched (#5098)
Summary:
Pull Request resolved: #5098
X-link: https://github.com/facebookresearch/FBGEMM/pull/2105
X-link: https://github.com/meta-pytorch/MSLK/pull/6

This diff updates the workspace allocation for f8f8bf16_rowwise_batched to make sure it is on the proper device. Previously, the workspace could default to device 0 even when the other inputs were on a different GPU.

Reviewed By: q10

Differential Revision: D86439655

fbshipit-source-id: c5652c4791b5075103876c8ae76bd65213d6a9cb
1 parent 6d51557 · commit 79d177d

File tree: 1 file changed (+3, −2 lines)


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_common.cuh

Lines changed: 3 additions & 2 deletions
```diff
@@ -274,7 +274,8 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   size_t workspace_size = Gemm::get_workspace_size(arguments);
 
   // Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  at::Tensor workspace =
+      at::empty(workspace_size, XQ.options().dtype(at::kByte));
 
   // Check the problem size is supported or not
   cutlass::Status status = gemm.can_implement(arguments);
@@ -283,7 +284,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   }
 
   // Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
   if (status != cutlass::Status::kSuccess) {
     throw std::runtime_error("cutlass cannot initialize");
   }
```
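For background on why the allocation call matters: `cutlass::device_memory::allocation` ultimately calls `cudaMalloc`, which allocates on whatever CUDA device is current, not necessarily the device holding the inputs. Allocating through `at::empty` with the input tensor's options places the workspace on that tensor's device. A minimal sketch of the pattern follows; the `scratch_like` helper name is illustrative, not FBGEMM API:

```cpp
#include <ATen/ATen.h>

// Illustrative helper (not part of FBGEMM): allocate an nbytes-wide byte
// buffer on the same device as `ref`. Because the TensorOptions come from
// `ref`, the buffer follows the input's GPU even when the current CUDA
// device is a different one (e.g. device 0).
at::Tensor scratch_like(const at::Tensor& ref, int64_t nbytes) {
  return at::empty({nbytes}, ref.options().dtype(at::kByte));
}
```

A side benefit of routing the allocation through ATen is that the workspace comes from PyTorch's CUDA caching allocator, so repeated kernel launches avoid paying a fresh cudaMalloc/cudaFree on every call.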
