@@ -198,7 +198,7 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
198198 ( "tessellation_evaluation" , TessellationEvaluation ) ,
199199 ( "geometry" , Geometry ) ,
200200 ( "fragment" , Fragment ) ,
201- ( "gl_compute " , GLCompute ) ,
201+ ( "compute " , GLCompute ) ,
202202 ( "kernel" , Kernel ) ,
203203 ( "task_nv" , TaskNV ) ,
204204 ( "mesh_nv" , MeshNV ) ,
@@ -218,6 +218,7 @@ enum ExecutionModeExtraDim {
218218 X ,
219219 Y ,
220220 Z ,
221+ Tuple ,
221222}
222223
223224const EXECUTION_MODES : & [ ( & str , ExecutionMode , ExecutionModeExtraDim ) ] = {
@@ -240,9 +241,7 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
240241 ( "depth_greater" , DepthGreater , None ) ,
241242 ( "depth_less" , DepthLess , None ) ,
242243 ( "depth_unchanged" , DepthUnchanged , None ) ,
243- ( "local_size_x" , LocalSize , X ) ,
244- ( "local_size_y" , LocalSize , Y ) ,
245- ( "local_size_z" , LocalSize , Z ) ,
244+ ( "threads" , LocalSize , Tuple ) ,
246245 ( "local_size_hint_x" , LocalSizeHint , X ) ,
247246 ( "local_size_hint_y" , LocalSizeHint , Y ) ,
248247 ( "local_size_hint_z" , LocalSizeHint , Z ) ,
@@ -690,6 +689,40 @@ fn parse_attr_int_value(arg: &NestedMetaItem) -> Result<u32, ParseAttrError> {
690689 }
691690}
692691
692+ fn parse_local_size_attr ( arg : & NestedMetaItem ) -> Result < [ u32 ; 3 ] , ParseAttrError > {
693+ let arg = match arg. meta_item ( ) {
694+ Some ( arg) => arg,
695+ None => return Err ( ( arg. span ( ) , "attribute must have value" . to_string ( ) ) ) ,
696+ } ;
697+ match arg. meta_item_list ( ) {
698+ Some ( tuple) if !tuple. is_empty ( ) && tuple. len ( ) < 4 => {
699+ let mut local_size = [ 1 ; 3 ] ;
700+ for ( idx, lit) in tuple. iter ( ) . enumerate ( ) {
701+ match lit. literal ( ) {
702+ Some ( & Lit {
703+ kind : LitKind :: Int ( x, LitIntType :: Unsuffixed ) ,
704+ ..
705+ } ) if x <= u32:: MAX as u128 => local_size[ idx] = x as u32 ,
706+ _ => return Err ( ( lit. span ( ) , "must be a u32 literal" . to_string ( ) ) ) ,
707+ }
708+ }
709+ Ok ( local_size)
710+ }
711+ Some ( tuple) if tuple. is_empty ( ) => Err ( (
712+ arg. span ,
713+ "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided" . to_string ( ) ,
714+ ) ) ,
715+ Some ( tuple) if tuple. len ( ) > 3 => Err ( (
716+ arg. span ,
717+ "#[spirv(compute(threads(x, y, z)))] is three dimensional" . to_string ( ) ,
718+ ) ) ,
719+ _ => Err ( (
720+ arg. span ,
721+ "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided" . to_string ( ) ,
722+ ) ) ,
723+ }
724+ }
725+
693726// for a given entry, gather up the additional attributes
694727// in this case ExecutionMode's, some have extra arguments
695728// others are specified with x, y, or z components
@@ -715,30 +748,23 @@ fn parse_entry_attrs(
715748 {
716749 use ExecutionModeExtraDim :: * ;
717750 let val = match extra_dim {
718- None => Option :: None ,
751+ None | Tuple => Option :: None ,
719752 _ => Some ( parse_attr_int_value ( attr) ?) ,
720753 } ;
721754 match execution_mode {
722755 OriginUpperLeft | OriginLowerLeft => {
723756 origin_mode. replace ( * execution_mode) ;
724757 }
725758 LocalSize => {
726- let val = val. unwrap ( ) ;
727759 if local_size. is_none ( ) {
728- local_size. replace ( [ 1 , 1 , 1 ] ) ;
729- }
730- let local_size = local_size. as_mut ( ) . unwrap ( ) ;
731- match extra_dim {
732- X => {
733- local_size[ 0 ] = val;
734- }
735- Y => {
736- local_size[ 1 ] = val;
737- }
738- Z => {
739- local_size[ 2 ] = val;
740- }
741- _ => unreachable ! ( ) ,
760+ local_size. replace ( parse_local_size_attr ( attr) ?) ;
761+ } else {
762+ return Err ( (
763+ attr_name. span ,
764+ String :: from (
765+ "`#[spirv(compute(threads))]` may only be specified once" ,
766+ ) ,
767+ ) ) ;
742768 }
743769 }
744770 LocalSizeHint => {
@@ -838,10 +864,18 @@ fn parse_entry_attrs(
838864 . push ( ( origin_mode, ExecutionModeExtra :: new ( [ ] ) ) ) ;
839865 }
840866 GLCompute => {
841- let local_size = local_size. unwrap_or ( [ 1 , 1 , 1 ] ) ;
842- entry
843- . execution_modes
844- . push ( ( LocalSize , ExecutionModeExtra :: new ( local_size) ) ) ;
867+ if let Some ( local_size) = local_size {
868+ entry
869+ . execution_modes
870+ . push ( ( LocalSize , ExecutionModeExtra :: new ( local_size) ) ) ;
871+ } else {
872+ return Err ( (
873+ arg. span ( ) ,
874+ String :: from (
875+ "The `threads` argument must be specified when using `#[spirv(compute)]`" ,
876+ ) ,
877+ ) ) ;
878+ }
845879 }
846880 Kernel => {
847881 if let Some ( local_size) = local_size {
0 commit comments