11#!/usr/bin/env python3
2- # Copyright (c) 2023 Triad National Security, LLC. All rights reserved.
2+ # Copyright (c) 2023 Triad National Security, LLC. All rights reserved.
3+ # Copyright (c) 2023 Research Organization for Information Science
4+ # and Technology (RIST). All rights reserved.
35# $COPYRIGHT$
46#
57# Additional copyrights may follow
2325import argparse
2426import re
2527import sys
26- import uuid
28+ import os
2729
2830# C type: const int
2931ERROR_CLASSES = [
@@ -1042,6 +1044,30 @@ def need_bigcount(self):
10421044 return any ('COUNT' in param .type_ for param in self .params )
10431045
10441046
1047+ class TemplateParseError (Exception ):
1048+ """Error raised during parsing."""
1049+ pass
1050+
1051+
1052+ def validate_body (body ):
1053+ """Validate the body of a template."""
1054+ # Just do a simple bracket balance test determine the bounds of the
1055+ # function body. All lines after the function body should be blank. There
1056+ # are cases where this will break, such as if someone puts code all on one
1057+ # line.
1058+ bracket_balance = 0
1059+ line_count = 0
1060+ for line in body :
1061+ line = line .strip ()
1062+ if bracket_balance == 0 and line_count > 0 and line :
1063+ raise TemplateParserError ('Extra code found in template; only one function body is allowed' )
1064+
1065+ update = line .count ('{' ) - line .count ('}' )
1066+ bracket_balance += update
1067+ if bracket_balance != 0 :
1068+ line_count += 1
1069+
1070+
10451071class SourceTemplate :
10461072 """Source template for a single API function."""
10471073
@@ -1051,8 +1077,10 @@ def __init__(self, prototype, header, body):
10511077 self .body = body
10521078
10531079 @staticmethod
1054- def load (fname ):
1080+ def load (fname , prefix = None ):
10551081 """Load a template file and return the SourceTemplate."""
1082+ if prefix is not None :
1083+ fname = os .path .join (prefix , fname )
10561084 with open (fname ) as fp :
10571085 header = []
10581086 prototype = []
@@ -1061,11 +1089,12 @@ def load(fname):
10611089 for line in fp :
10621090 line = line .rstrip ()
10631091 if prototype and line .startswith ('PROTOTYPE' ):
1064- raise RuntimeError ('more than one prototype found in template file' )
1092+ raise TemplateParseError ('more than one prototype found in template file' )
10651093 elif ((prototype and not any (')' in s for s in prototype ))
10661094 or line .startswith ('PROTOTYPE' )):
10671095 prototype .append (line )
10681096 elif prototype :
1097+ # Validate bracket balance
10691098 body .append (line )
10701099 else :
10711100 header .append (line )
@@ -1082,6 +1111,8 @@ def load(fname):
10821111 params = [param .strip () for param in prototype [i + 1 :j ].split (',' ) if param .strip ()]
10831112 params = [Parameter (param ) for param in params ]
10841113 prototype = Prototype (name , return_type , params )
1114+ # Ensure the body contains only one function
1115+ validate_body (body )
10851116 return SourceTemplate (prototype , header , body )
10861117
10871118 def print_header (self , file = sys .stdout ):
@@ -1148,7 +1179,7 @@ def standard_abi(base_name, template):
11481179 print (f'#include "{ ABI_INTERNAL_HEADER } "' )
11491180
11501181 # Static internal function (add a random component to avoid conflicts)
1151- internal_name = f'ompi_ { template .prototype .name } _ { uuid . uuid4 (). hex [: 10 ] } '
1182+ internal_name = f'ompi_abi_ { template .prototype .name } '
11521183 internal_sig = template .prototype .signature ('ompi' , internal_name ,
11531184 count_type = 'MPI_Count' )
11541185 print ('static inline' , internal_sig )
@@ -1190,7 +1221,7 @@ def generate_function(prototype, fn_name, internal_fn, count_type='int'):
11901221
11911222def gen_header (args ):
11921223 """Generate an ABI header and conversion code."""
1193- prototypes = [SourceTemplate .load (file_ ).prototype for file_ in args .file ]
1224+ prototypes = [SourceTemplate .load (file_ , args . srcdir ).prototype for file_ in args .file ]
11941225
11951226 builder = ABIHeaderBuilder (prototypes , external = args .external )
11961227 builder .dump_header ()
@@ -1219,6 +1250,7 @@ def main():
12191250 parser_header = subparsers .add_parser ('header' )
12201251 parser_header .add_argument ('file' , nargs = '+' , help = 'list of template source files' )
12211252 parser_header .add_argument ('--external' , action = 'store_true' , help = 'generate external mpi.h header file' )
1253+ parser_header .add_argument ('--srcdir' , help = 'source directory' )
12221254 parser_header .set_defaults (func = gen_header )
12231255
12241256 parser_gen = subparsers .add_parser ('source' )
0 commit comments