From a25486c56a404b056d6ccc19d995d743342a36a0 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Fri, 18 Apr 2025 14:02:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=98=87=E8=85=BEEle?= =?UTF-8?q?mentWise=E7=AE=97=E5=AD=90=E7=BB=84=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../elementwise/ascend/elementwise_ascend.h | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 src/infiniop/elementwise/ascend/elementwise_ascend.h diff --git a/src/infiniop/elementwise/ascend/elementwise_ascend.h b/src/infiniop/elementwise/ascend/elementwise_ascend.h new file mode 100644 index 000000000..3b3c9c962 --- /dev/null +++ b/src/infiniop/elementwise/ascend/elementwise_ascend.h @@ -0,0 +1,54 @@ +#ifndef __INFINIOP_ELEMENTWISE_ASCEND_H__ +#define __INFINIOP_ELEMENTWISE_ASCEND_H__ + +#include "../../devices/ascend/common_ascend.h" +#include "../elementwise.h" +#include +#include + +namespace op::elementwise::ascend { +// template +class DeviceImpl final { + struct Opaque; + std::shared_ptr _opaque; + + DeviceImpl(std::shared_ptr opaque) : _opaque(std::move(opaque)) {} + +public: + ~DeviceImpl() = default; + + template + static utils::Result create(Args &&...args); +} + +template && ...), int> = 0> +struct Opaque { + mutable aclOpExecutor *executor; + size_t workspaceSize; + TensorDescs outTensorDesc; + std::tuple inTensorDescs; + + explicit Opaque(aclOpExecutor *exec, size_t wsSize, TensorDescs outDesc, TensorDescs... descs) + : executor(exec), workspaceSize(wsSize), outTensorDesc(outDesc), inTensorDescs(std::forward(descs)...) {} + + ~Opaque() { + aclDestroyAclOpExecutor(executor); + delete outDesc; + // 遍历元组并释放每个 Tensor 描述符 + std::apply([](auto &&...args) { + (..., (delete args)); + }, + inTensorDescs); + } + + // 获取输出 Tensor 描述符 + template + auto getInTensor() -> decltype(auto) { + return std::get(inTensorDescs); + } +} + +} // namespace op::elementwise::ascend + +#endif // __INFINIOP_ELEMENTWISE_ASCEND_H__