@@ -10,12 +10,15 @@ struct Descriptor::Opaque {
1010 aclnnTensorDescriptor_t w;
1111 aclnnTensorDescriptor_t rstd;
1212 size_t workspaceSize;
13+ aclOpExecutor *executor;
1314
1415 ~Opaque () {
1516 delete y;
1617 delete x;
1718 delete w;
1819 delete rstd;
20+
21+ aclDestroyAclOpExecutor (executor);
1922 }
2023};
2124
@@ -62,17 +65,16 @@ infiniStatus_t Descriptor::create(
6265
6366 // Get WorkspaceSize and set executor
6467 CHECK_ACL (aclnnRmsNormGetWorkspaceSize (tx, tw, static_cast <double >(epsilon), ty, trstd, &workspace_size, &executor));
68+ aclSetAclOpExecutorRepeatable (executor);
6569
6670 auto handle_ascend = reinterpret_cast <device::ascend::Handle *>(handle);
6771 size_t all_workspace_size = workspace_size + rstd->numel () * aclDataTypeSize (rstd->dataType );
6872 *desc_ptr = new Descriptor (
69- new Opaque{y, x, w, rstd, workspace_size},
73+ new Opaque{y, x, w, rstd, workspace_size, executor },
7074 std::move (info),
7175 all_workspace_size,
7276 handle_ascend->device , handle_ascend->device_id );
7377
74- aclDestroyAclOpExecutor (executor);
75-
7678 return INFINI_STATUS_SUCCESS;
7779}
7880
@@ -88,21 +90,16 @@ infiniStatus_t Descriptor::calculate(
8890 auto tx = _opaque->x ->tensor ;
8991 auto ty = _opaque->y ->tensor ;
9092 auto trstd = _opaque->rstd ->tensor ;
91- size_t workspace_size_ = 0 ;
92- aclOpExecutor *executor = nullptr ;
93-
94- CHECK_ACL (aclnnRmsNormGetWorkspaceSize (tx, tw, static_cast <double >(_info.epsilon ), ty, trstd, &workspace_size_, &executor));
95- CHECK_ACL (aclSetAclOpExecutorRepeatable (executor));
9693
9794 void *rstdPtr = (void *)((uint8_t *)workspace + _opaque->workspaceSize );
9895
9996 auto unit = infiniSizeOf (_info.atype );
100- AclSetTensorAddr (executor, 1 , tw, (void *)w);
101- AclSetTensorAddr (executor, 3 , trstd, rstdPtr);
97+ AclSetTensorAddr (_opaque-> executor , 1 , tw, (void *)w);
98+ AclSetTensorAddr (_opaque-> executor , 3 , trstd, rstdPtr);
10299 for (size_t i = 0 ; i < (_info.shape )[0 ]; ++i) {
103- AclSetTensorAddr (executor, 0 , tx, ((char *)x) + i * (_info.x_strides )[0 ] * unit);
104- AclSetTensorAddr (executor, 2 , ty, ((char *)y) + i * (_info.y_strides )[0 ] * unit);
105- CHECK_ACL (aclnnRmsNorm (workspace, _opaque->workspaceSize , executor, stream));
100+ AclSetTensorAddr (_opaque-> executor , 0 , tx, ((char *)x) + i * (_info.x_strides )[0 ] * unit);
101+ AclSetTensorAddr (_opaque-> executor , 2 , ty, ((char *)y) + i * (_info.y_strides )[0 ] * unit);
102+ CHECK_ACL (aclnnRmsNorm (workspace, _opaque->workspaceSize , _opaque-> executor , stream));
106103 }
107104 return INFINI_STATUS_SUCCESS;
108105}
0 commit comments