@@ -186,23 +186,23 @@ defmodule EXLA.MLIR.Value do
186186
187187 def reverse ( % Value { function: func } = operand , dims , typespec ) do
188188 result_types = typespecs_to_mlir_types ( [ typespec ] )
189- attributes = [ dimensions: attr_dense_i64_elements ( dims ) ]
189+ attributes = [ dimensions: attr_array_i64_elements ( dims ) ]
190190 op ( func , "stablehlo.reverse" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
191191 end
192192
193193 def transpose ( % Value { function: func } = operand , axes , typespec ) do
194194 result_types = typespecs_to_mlir_types ( [ typespec ] )
195- attributes = [ permutation: attr_dense_i64_elements ( axes ) ]
195+ attributes = [ permutation: attr_array_i64_elements ( axes ) ]
196196 op ( func , "stablehlo.transpose" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
197197 end
198198
199199 def slice ( % Value { function: func } = operand , starts , limits , strides , typespec ) do
200200 result_types = typespecs_to_mlir_types ( [ typespec ] )
201201
202202 attributes = [
203- start_indices: attr_dense_i64_elements ( starts ) ,
204- limit_indices: attr_dense_i64_elements ( limits ) ,
205- strides: attr_dense_i64_elements ( strides )
203+ start_indices: attr_array_i64_elements ( starts ) ,
204+ limit_indices: attr_array_i64_elements ( limits ) ,
205+ strides: attr_array_i64_elements ( strides )
206206 ]
207207
208208 op ( func , "stablehlo.slice" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
@@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do
211211 def dynamic_slice ( % Value { function: func } = operand , starts , lengths , typespec ) do
212212 result_types = typespecs_to_mlir_types ( [ typespec ] )
213213 operands = [ operand ] ++ starts
214- attributes = [ slice_sizes: attr_dense_i64_elements ( lengths ) ]
214+ attributes = [ slice_sizes: attr_array_i64_elements ( lengths ) ]
215215 op ( func , "stablehlo.dynamic_slice" , operands , result_types , attributes: attributes ) |> one! ( )
216216 end
217217
@@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do
303303 result_types = typespecs_to_mlir_types ( [ typespec ] )
304304
305305 attributes = [
306- broadcast_dimensions: attr_dense_i64_elements ( axes )
306+ broadcast_dimensions: attr_array_i64_elements ( axes )
307307 ]
308308
309309 op ( func , "stablehlo.broadcast_in_dim" , [ operand ] , result_types , attributes: attributes )
@@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do
347347 { padding_low , padding_high , padding_mid } = unzip_padding_config ( padding_config )
348348
349349 attributes = [
350- edge_padding_low: attr_dense_i64_elements ( padding_low ) ,
351- edge_padding_high: attr_dense_i64_elements ( padding_high ) ,
352- interior_padding: attr_dense_i64_elements ( padding_mid )
350+ edge_padding_low: attr_array_i64_elements ( padding_low ) ,
351+ edge_padding_high: attr_array_i64_elements ( padding_high ) ,
352+ interior_padding: attr_array_i64_elements ( padding_mid )
353353 ]
354354
355355 op ( func , "stablehlo.pad" , [ operand , pad ] , result_types , attributes: attributes ) |> one! ( )
@@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do
375375
376376 attributes = [
377377 fft_type: fft_type ,
378- fft_length: attr_dense_i64_elements ( List . wrap ( fft_length ) )
378+ fft_length: attr_array_i64_elements ( List . wrap ( fft_length ) )
379379 ]
380380
381381 op ( func , "stablehlo.fft" , [ value ] , result_types , attributes: attributes ) |> one! ( )
@@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do
451451 result_types = typespecs_to_mlir_types ( [ typespec ] )
452452
453453 attributes = [
454- window_dimensions: attr_dense_i64_elements ( window_dimensions ) ,
455- window_strides: attr_dense_i64_elements ( window_strides ) ,
454+ window_dimensions: attr_array_i64_elements ( window_dimensions ) ,
455+ window_strides: attr_array_i64_elements ( window_strides ) ,
456456 padding: attr_padding ( padding )
457457 ]
458458
@@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do
501501
502502 attributes = [
503503 dimension_numbers: dimension_numbers ,
504- slice_sizes: attr_dense_i64_elements ( slice_sizes ) ,
504+ slice_sizes: attr_array_i64_elements ( slice_sizes ) ,
505505 indices_are_sorted: attr_boolean ( false )
506506 ]
507507
@@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do
546546 attr_precision_config = attr_precision_config ( precision_config )
547547
548548 attributes = [
549- window_strides: attr_dense_i64_elements ( strides ) ,
549+ window_strides: attr_array_i64_elements ( strides ) ,
550550 padding: attr_padding ( padding ) ,
551- lhs_dilation: attr_dense_i64_elements ( input_dilation ) ,
552- rhs_dilation: attr_dense_i64_elements ( kernel_dilation ) ,
551+ lhs_dilation: attr_array_i64_elements ( input_dilation ) ,
552+ rhs_dilation: attr_array_i64_elements ( kernel_dilation ) ,
553553 dimension_numbers: attr_conv_dimension_numbers ( dimension_numbers ) ,
554554 feature_group_count: attr_i64 ( feature_group_count ) ,
555555 batch_group_count: attr_i64 ( batch_group_count ) ,
@@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do
625625 ) do
626626 operands = inputs ++ init_values
627627 result_types = typespecs_to_mlir_types ( typespecs )
628- attributes = [ dimensions: attr_dense_i64_elements ( dimensions ) ]
628+ attributes = [ dimensions: attr_array_i64_elements ( dimensions ) ]
629629 regions = [ reducer ]
630630 op ( func , "stablehlo.reduce" , operands , result_types , attributes: attributes , regions: regions )
631631 end
@@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do
645645 result_types = typespecs_to_mlir_types ( typespecs )
646646
647647 attributes = [
648- window_dimensions: attr_dense_i64_elements ( window_dimensions ) ,
649- window_strides: attr_dense_i64_elements ( window_strides ) ,
650- base_dilations: attr_dense_i64_elements ( input_dilations ) ,
651- window_dilations: attr_dense_i64_elements ( window_dilations ) ,
648+ window_dimensions: attr_array_i64_elements ( window_dimensions ) ,
649+ window_strides: attr_array_i64_elements ( window_strides ) ,
650+ base_dilations: attr_array_i64_elements ( input_dilations ) ,
651+ window_dilations: attr_array_i64_elements ( window_dilations ) ,
652652 padding: attr_padding ( padding )
653653 ]
654654
@@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do
669669 result_types = typespecs_to_mlir_types ( [ typespec ] )
670670
671671 attributes = [
672- dimensions: attr_dense_i64_elements ( dimensions )
672+ dimensions: attr_array_i64_elements ( dimensions )
673673 ]
674674
675675 regions = [ mapper ]
@@ -904,8 +904,12 @@ defmodule EXLA.MLIR.Value do
904904 << value :: size ( size ) - big >>
905905 end
906906
907- defp attr_dense_i64_elements ( list ) do
908- attr_dense_elements ( list , { :s , 64 } , { length ( list ) } )
907+ defp attr_array_i64_elements ( [ ] ) do
908+ "array<i64>"
909+ end
910+
911+ defp attr_array_i64_elements ( list ) do
912+ "array<i64: #{ Enum . join ( list , ", " ) } >"
909913 end
910914
911915 defp attr_dense_elements ( [ ] , type , { 0 } = shape ) do
0 commit comments