@@ -129,37 +129,22 @@ void add_permute_node(
129129 std::vector<PushConstantDataInfo> push_constants;
130130 vkapi::SpecVarList spec_vars;
131131
132- if (graph.is_buffer_storage (out)) {
133- param_buffers.append (graph.sizes_ubo (in));
134- param_buffers.append (graph.strides_ubo (out));
135- param_buffers.append (graph.numel_ubo (out));
136-
137- // Buffer storage - use permute_buffer shader
138- push_constants = {
139- graph.strides_pc_of (in),
140- PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims)),
141- };
142-
143- spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
144- } else {
145- // Texture storage - use permute_texture shader
146- const int32_t out_channels = dim_at<kChannel4D >(graph.sizes_of (out));
147- const int32_t in_channels = dim_at<kChannel4D >(graph.sizes_of (in));
148-
149- const int32_t packed_dim = graph.packed_dim_of (in);
150- ivec2 channel_info = {out_channels, in_channels};
151- if (packed_dim == WHCN::kChannelsDim ) {
152- channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
153- channel_info[1 ] = utils::align_up_4 (channel_info[1 ]);
154- }
132+ const int32_t out_channels = dim_at<kChannel4D >(graph.sizes_of (out));
133+ const int32_t in_channels = dim_at<kChannel4D >(graph.sizes_of (in));
134+
135+ const int32_t packed_dim = graph.packed_dim_of (in);
136+ ivec2 channel_info = {out_channels, in_channels};
137+ if (packed_dim == WHCN::kChannelsDim ) {
138+ channel_info[0 ] = utils::align_up_4 (channel_info[0 ]);
139+ channel_info[1 ] = utils::align_up_4 (channel_info[1 ]);
140+ }
155141
156- push_constants = {
157- graph.sizes_pc_of (out),
158- graph.sizes_pc_of (in),
159- PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims))};
142+ push_constants = {
143+ graph.sizes_pc_of (out),
144+ graph.sizes_pc_of (in),
145+ PushConstantDataInfo (&whcn_permute_dims, sizeof (whcn_permute_dims))};
160146
161- spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
162- }
147+ spec_vars = {graph.hashed_layout_of (out), graph.hashed_layout_of (in)};
163148
164149 graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
165150 graph,
@@ -179,8 +164,83 @@ void add_permute_node(
179164 resize_permute_node));
180165}
181166
167+ struct WHCNPermuteDims {
168+ int32_t whcn_permute_dims[api::kTensorDimLimit ];
169+
170+ void initialize (const std::vector<int64_t >& permute_dims) {
171+ const int32_t permute_ndim = permute_dims.size ();
172+ for (int32_t whcn_i = 0 ; whcn_i < permute_ndim; whcn_i++) {
173+ const int32_t nchw_i = permute_ndim - 1 - whcn_i;
174+ int64_t index_val = permute_dims.at (nchw_i);
175+ if (index_val < 0 ) {
176+ index_val += permute_ndim;
177+ }
178+ const int32_t permute_dim_whcn = permute_ndim - 1 - index_val;
179+ whcn_permute_dims[whcn_i] = permute_dim_whcn;
180+ }
181+ for (int32_t whcn_i = permute_ndim; whcn_i < api::kTensorDimLimit ;
182+ whcn_i++) {
183+ whcn_permute_dims[whcn_i] = whcn_i;
184+ }
185+ }
186+ };
187+
188+ void add_permute_buffer_node (
189+ ComputeGraph& graph,
190+ const ValueRef in,
191+ const ValueRef permute_dims,
192+ const ValueRef out) {
193+ check_args (graph, in, permute_dims, out);
194+
195+ WHCNPermuteDims whcn_permute_dims;
196+ // Convert the permute dims to WHCN dimension order, which is the standard in
197+ // our compute shaders. The following transformations are applied.
198+ // 1. Change dimension index values from NCHW order valueto WHCN order value
199+ // 2. Extend the permute array to kTensorDimLimit
200+ {
201+ IntListPtr permute_dims_ptr = graph.get_int_list (permute_dims);
202+ whcn_permute_dims.initialize (*permute_dims_ptr);
203+ }
204+
205+ std::string kernel_name = " permute" ;
206+ kernel_name.reserve (kShaderNameReserve );
207+ add_storage_type_suffix (kernel_name, graph.storage_type_of (out));
208+ add_dtype_suffix (kernel_name, graph.dtype_of (out));
209+
210+ vkapi::ParamsBindList param_buffers = {
211+ graph.buffer_meta_ubo (out),
212+ graph.buffer_meta_ubo (in),
213+ graph.create_params_buffer (whcn_permute_dims),
214+ };
215+
216+ graph.execute_nodes ().emplace_back (new DynamicDispatchNode (
217+ graph,
218+ VK_KERNEL_FROM_STR (kernel_name),
219+ default_pick_global_wg_size,
220+ default_pick_local_wg_size,
221+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
222+ // Parameter buffers
223+ param_buffers,
224+ // Push Constants
225+ {},
226+ // Specialization Constants
227+ {},
228+ // Resize Args
229+ {permute_dims},
230+ // Resizing Logic
231+ resize_permute_node));
232+ }
233+
182234void permute (ComputeGraph& graph, const std::vector<ValueRef>& args) {
183- return add_permute_node (graph, args[0 ], args[1 ], args[2 ]);
235+ int idx = 0 ;
236+ const ValueRef in = args.at (idx++);
237+ const ValueRef permute_dims = args.at (idx++);
238+ const ValueRef out = args.at (idx++);
239+
240+ if (graph.is_buffer_storage (args[2 ])) {
241+ return add_permute_buffer_node (graph, in, permute_dims, out);
242+ }
243+ return add_permute_node (graph, in, permute_dims, out);
184244}
185245
186246REGISTER_OPERATORS {
0 commit comments