Skip to content

Commit b430273

Browse files
committed
issue/174: Separate getworkspace for rmsnorm
1 parent 02922ce commit b430273

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

src/infiniop/ops/rms_norm/ascend/rms_norm_aclnn.cc

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@ struct Descriptor::Opaque {
1010
aclnnTensorDescriptor_t w;
1111
aclnnTensorDescriptor_t rstd;
1212
size_t workspaceSize;
13+
aclOpExecutor *executor;
1314

1415
~Opaque() {
1516
delete y;
1617
delete x;
1718
delete w;
1819
delete rstd;
20+
21+
aclDestroyAclOpExecutor(executor);
1922
}
2023
};
2124

@@ -62,17 +65,16 @@ infiniStatus_t Descriptor::create(
6265

6366
// Get WorkspaceSize and set executor
6467
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(epsilon), ty, trstd, &workspace_size, &executor));
68+
aclSetAclOpExecutorRepeatable(executor);
6569

6670
auto handle_ascend = reinterpret_cast<device::ascend::Handle *>(handle);
6771
size_t all_workspace_size = workspace_size + rstd->numel() * aclDataTypeSize(rstd->dataType);
6872
*desc_ptr = new Descriptor(
69-
new Opaque{y, x, w, rstd, workspace_size},
73+
new Opaque{y, x, w, rstd, workspace_size, executor},
7074
std::move(info),
7175
all_workspace_size,
7276
handle_ascend->device, handle_ascend->device_id);
7377

74-
aclDestroyAclOpExecutor(executor);
75-
7678
return INFINI_STATUS_SUCCESS;
7779
}
7880

@@ -88,21 +90,16 @@ infiniStatus_t Descriptor::calculate(
8890
auto tx = _opaque->x->tensor;
8991
auto ty = _opaque->y->tensor;
9092
auto trstd = _opaque->rstd->tensor;
91-
size_t workspace_size_ = 0;
92-
aclOpExecutor *executor = nullptr;
93-
94-
CHECK_ACL(aclnnRmsNormGetWorkspaceSize(tx, tw, static_cast<double>(_info.epsilon), ty, trstd, &workspace_size_, &executor));
95-
CHECK_ACL(aclSetAclOpExecutorRepeatable(executor));
9693

9794
void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize);
9895

9996
auto unit = infiniSizeOf(_info.atype);
100-
AclSetTensorAddr(executor, 1, tw, (void *)w);
101-
AclSetTensorAddr(executor, 3, trstd, rstdPtr);
97+
AclSetTensorAddr(_opaque->executor, 1, tw, (void *)w);
98+
AclSetTensorAddr(_opaque->executor, 3, trstd, rstdPtr);
10299
for (size_t i = 0; i < (_info.shape)[0]; ++i) {
103-
AclSetTensorAddr(executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
104-
AclSetTensorAddr(executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
105-
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, executor, stream));
100+
AclSetTensorAddr(_opaque->executor, 0, tx, ((char *)x) + i * (_info.x_strides)[0] * unit);
101+
AclSetTensorAddr(_opaque->executor, 2, ty, ((char *)y) + i * (_info.y_strides)[0] * unit);
102+
CHECK_ACL(aclnnRmsNorm(workspace, _opaque->workspaceSize, _opaque->executor, stream));
106103
}
107104
return INFINI_STATUS_SUCCESS;
108105
}

0 commit comments

Comments
 (0)