88import base64
99from collections import OrderedDict
1010import importlib
11+ import pickle
1112import os
1213import sys
1314
14- try :
15- import dill
16- from dill import load , loads , dumps , dump
17- except ImportError :
18- dill = None
19- from pickle import load , loads , dumps , dump
15+
2016import six
2117
22- from .ds2 import DS2Method , DS2Thread , DS2Variable , DS2Package
18+ from .ds2 import DS2Thread , DS2Variable , DS2Package
2319from .python import ds2_variables
2420
2521
26- def build_wrapper_function (func , variables , array_input , return_msg = True ):
22+ def build_wrapper_function (func , variables , array_input , setup = None ,
23+ return_msg = True ):
2724 """Wraps a function to ensure compatibility when called by PyMAS.
2825
29- PyMAS has strict expectations regarding the format of any function called directly by PyMAS.
30- Isolating the desired function inside a wrapping function provides a simple way to ensure that functions
31- called by PyMAS are compliant.
26+ PyMAS has strict expectations regarding the format of any function called
27+ directly by PyMAS. Isolating the desired function inside a wrapping
28+ function provides a simple way to ensure that functions called by PyMAS
29+ are compliant.
3230
3331 Parameters
3432 ----------
3533 func : function or str
3634 Function name or an instance of Function which will be wrapped
3735 variables : list of DS2Variable
3836 array_input : bool
37+ Whether `variables` should be combined into a single array before passing to `func`
38+ setup : iterable
39+ Python source code lines to be executed during package setup
3940 return_msg : bool
4041
4142 Returns
@@ -45,36 +46,54 @@ def build_wrapper_function(func, variables, array_input, return_msg=True):
4546
4647 Notes
4748 -----
48- The format for the `# Output: ` is very strict. It must be exactly "# Output: <var>, <var>". Any changes to
49- spelling, capitalization, punctuation, or spacing will result in an error when the DS2 code is executed.
49+ The format for the `# Output: ` is very strict. It must be exactly
50+ "# Output: <var>, <var>". Any changes to spelling, capitalization,
51+ punctuation, or spacing will result in an error when the DS2 code is
52+ executed.
5053
5154 """
5255
5356 input_names = [v .name for v in variables if not v .out ]
5457 output_names = [v .name for v in variables if v .out ]
55-
5658 args = input_names
57-
5859 func = func .__name__ if callable (func ) else func
5960
6061 # Statement to execute the function w/ provided parameters
6162 if array_input :
62- func_call = '{}(np.asarray({} ).reshape((1,-1)))' .format (func , ',' .join (args ))
63+ func_call = '{}(np.array([{}] ).reshape((1, -1)))' .format (func , ',' .join (args ))
6364 else :
6465 func_call = '{}({})' .format (func , ',' .join (args ))
6566
6667 # TODO: Verify that # of values returned by wrapped func matches length of output_names
6768 # TODO: cast all return types before returning (DS2 errors out if not exact match)
6869
69- # NOTE: 'Output:' section is required. All return variables must be listed separated by ', '
70- definition = ('def wrapper({}):' .format (', ' .join (args )),
70+ # NOTE: 'Output:' section is required. All return variables must be listed
71+ # separated by ', '
72+
73+ if setup :
74+ header = ('try:' , ) + \
75+ tuple (' ' + line for line in setup ) + \
76+ (' _compile_error = None' ,
77+ 'except Exception as e:' ,
78+ ' _compile_error = e' ,
79+ '' )
80+ else :
81+ header = ('' , )
82+
83+ definition = header + \
84+ ('def wrapper({}):' .format (', ' .join (args )),
7185 ' "Output: {}"' .format (', ' .join (output_names + ['msg' ]) if return_msg
7286 else ', ' .join (output_names )),
7387 ' result = None' ,
7488 ' try:' ,
89+ ' global _compile_error' ,
90+ ' if _compile_error is not None:' ,
91+ ' raise _compile_error' ,
7592 ' msg = ""' if return_msg else '' ,
7693 ' import numpy as np' ,
77- ' result = float({})' .format (func_call ),
94+ ' result = {}' .format (func_call ),
95+ ' if result.size == 1:' ,
96+ ' result = np.asscalar(result)' ,
7897 ' except Exception as e:' ,
7998 ' msg = str(e)' if return_msg else '' ,
8099 ' if result is None:' ,
@@ -110,12 +129,14 @@ def from_inline(func, input_types=None, array_input=False, return_code=True, ret
110129
111130 """
112131
113- obj = dumps (func )
132+ obj = pickle . dumps (func )
114133 return from_pickle (obj , None , input_types , array_input , return_code , return_message )
115134
116135
117- def from_python_file (file , func_name = None , input_types = None , array_input = False , return_code = True , return_message = True ):
118- """ Creates a PyMAS wrapper to execute a function defined in an external .py file.
136+ def from_python_file (file , func_name = None , input_types = None , array_input = False ,
137+ return_code = True , return_message = True ):
138+ """Creates a PyMAS wrapper to execute a function defined in an
139+ external .py file.
119140
120141 Parameters
121142 ----------
@@ -127,7 +148,8 @@ def from_python_file(file, func_name=None, input_types=None, array_input=False,
127148 The expected type for each input value of the target function.
128149 Can be ommitted if target function includes type hints.
129150 array_input : bool
130- Whether the function inputs should be treated as an array instead of individual parameters
151+ Whether the function inputs should be treated as an array instead of
152+ individual parameters
131153 return_code : bool
132154 Whether the DS2-generated return code should be included
133155 return_message : bool
@@ -156,29 +178,34 @@ def from_python_file(file, func_name=None, input_types=None, array_input=False,
156178 target_func = getattr (module , func_name )
157179
158180 if not callable (target_func ):
159- raise RuntimeError ("Could not find a valid function named {}" .format (func_name ))
181+ raise RuntimeError ("Could not find a valid function named %s"
182+ % func_name )
160183
161184 with open (file , 'r' ) as f :
162185 code = [line .strip ('\n ' ) for line in f .readlines ()]
163186
164- return _build_pymas (target_func , None , input_types , array_input , return_code , return_message , code )
187+ return _build_pymas (target_func , None , input_types , array_input ,
188+ return_code , return_message , code )
165189
166190
167- def from_pickle (file , func_name = None , input_types = None , array_input = False , return_code = True , return_message = True ):
191+ def from_pickle (file , func_name = None , input_types = None , array_input = False ,
192+ return_code = True , return_message = True ):
168193 """Create a deployable DS2 package from a Python pickle file.
169194
170195 Parameters
171196 ----------
172197 file : str or bytes or file_like
173- Pickled object to use. String is assumed to be a path to a picked file, file_like is assumed to be an open
174- file handle to a pickle object, and bytes is assumed to be the raw pickled bytes.
198+ Pickled object to use. String is assumed to be a path to a picked
199+ file, file_like is assumed to be an open file handle to a pickle
200+ object, and bytes is assumed to be the raw pickled bytes.
175201 func_name : str
176202 Name of the target function to call
177203 input_types : list of type, optional
178204 The expected type for each input value of the target function.
179205 Can be ommitted if target function includes type hints.
180206 array_input : bool
181- Whether the function inputs should be treated as an array instead of individual parameters
207+ Whether the function inputs should be treated as an array instead of
208+ individual parameters
182209 return_code : bool
183210 Whether the DS2-generated return code should be included
184211 return_message : bool
@@ -190,39 +217,41 @@ def from_pickle(file, func_name=None, input_types=None, array_input=False, retur
190217 Generated DS2 code which can be executed in a SAS scoring environment
191218
192219 """
193-
194220 try :
195- # In Python2 str could either be a path or the binary pickle data, so check if its a valid filepath too.
221+ # In Python2 str could either be a path or the binary pickle data,
222+ # so check if its a valid filepath too.
196223 is_file_path = isinstance (file , six .string_types ) and os .path .isfile (file )
197224 except TypeError :
198225 is_file_path = False
199226
200227 # Path to a pickle file
201228 if is_file_path :
202229 with open (file , 'rb' ) as f :
203- obj = load (f )
230+ obj = pickle . load (f )
204231
205232 # The actual pickled bytes
206233 elif isinstance (file , bytes ):
207- obj = loads (file )
234+ obj = pickle . loads (file )
208235 else :
209- obj = load (file )
236+ obj = pickle . load (file )
210237
211238 # Encode the pickled data so we can inline it in the DS2 package
212- pkl = base64 .b64encode (dumps (obj ))
239+ pkl = base64 .b64encode (pickle . dumps (obj ))
213240
214- package = 'dill' if dill else 'pickle'
241+ code = ('import pickle, base64' ,
242+ # Replace b' with " before embedding in DS2.
243+ 'bytes = {}' .format (pkl ).replace ("'" , '"' ),
244+ 'obj = pickle.loads(base64.b64decode(bytes))' )
215245
216- code = ('import %s, base64' % package ,
217- 'bytes = {}' .format (pkl ).replace ("'" , '"' ), # Replace b' with " before embedding in DS2.
218- 'obj = %s.loads(base64.b64decode(bytes))' % package )
246+ return _build_pymas (obj , func_name , input_types , array_input , return_code ,
247+ return_message , code )
219248
220- return _build_pymas (obj , func_name , input_types , array_input , return_code , return_message , code )
221249
250+ def _build_pymas (obj , func_name = None , input_types = None , array_input = False ,
251+ return_code = True , return_message = True , code = []):
222252
223- def _build_pymas (obj , func_name = None , input_types = None , array_input = False , return_code = True , return_message = True , code = []):
224-
225- # If the object passed was a function, no need to search for target function
253+ # If the object passed was a function, no need to search for
254+ # target function
226255 if six .callable (obj ) and (func_name is None or obj .__name__ == func_name ):
227256 target_func = obj
228257 elif func_name is None :
@@ -231,19 +260,23 @@ def _build_pymas(obj, func_name=None, input_types=None, array_input=False, retur
231260 target_func = getattr (obj , func_name )
232261
233262 if not callable (target_func ):
234- raise RuntimeError ("Could not find a valid function named {}" .format (func_name ))
263+ raise RuntimeError ("Could not find a valid function named %s"
264+ % func_name )
235265
236266 # Need to create DS2Variable instances to pass to PyMAS
237267 if hasattr (input_types , 'columns' ):
238- # Assuming input is a DataFrame representing model inputs. Use to get input variables
268+ # Assuming input is a DataFrame representing model inputs. Use to
269+ # get input variables
239270 vars = ds2_variables (input_types )
240271
241- # Run one observation through the model and use the result to determine output variables
272+ # Run one observation through the model and use the result to
273+ # determine output variables
242274 output = target_func (input_types .iloc [0 , :].values .reshape ((1 , - 1 )))
243275 output_vars = ds2_variables (output , output_vars = True )
244276 vars .extend (output_vars )
245277 elif isinstance (input_types , type ):
246- params = OrderedDict ([(k , input_types ) for k in target_func .__code__ .co_varnames ])
278+ params = OrderedDict ([(k , input_types )
279+ for k in target_func .__code__ .co_varnames ])
247280 vars = ds2_variables (params )
248281 elif isinstance (input_types , dict ):
249282 vars = ds2_variables (input_types )
@@ -253,22 +286,16 @@ def _build_pymas(obj, func_name=None, input_types=None, array_input=False, retur
253286
254287 target_func = 'obj.' + target_func .__name__
255288
256- # If all inputs should be passed as an array
257- if array_input :
258- first_input = vars [0 ]
259- array_type = first_input .type or 'double'
260- out_vars = [x for x in vars if x .out ]
261- num_inputs = len (vars ) - len (out_vars )
262- vars = [DS2Variable (first_input .name , array_type + '[{}]' .format (num_inputs ), out = False )] + out_vars
263-
264289 if not any ([v for v in vars if v .out ]):
265290 vars .append (DS2Variable (name = 'result' , type = 'float' , out = True ))
266291
267- return PyMAS (target_func , vars , code , return_code , return_message )
292+ return PyMAS (target_func , vars , code , return_code , return_message ,
293+ array_input = array_input )
268294
269295
270296class PyMAS :
271- def __init__ (self , target_function , variables , python_source , return_code = True , return_msg = True ):
297+ def __init__ (self , target_function , variables , python_source ,
298+ return_code = True , return_msg = True , ** kwargs ):
272299 """
273300
274301 Parameters
@@ -282,27 +309,30 @@ def __init__(self, target_function, variables, python_source, return_code=True,
282309 Whether the DS2-generated return code should be included
283310 return_msg : bool
284311 Whether the DS2-generated return message should be included
312+ kwargs : any
313+ Passed to :func:`build_wrapper_function`
285314
286315 """
287316
288317 self .target = target_function
289318
290319 # Any input variable that should be treated as an array
291- array_input = any (v for v in variables if v .is_array )
320+ # array_input = any(v for v in variables if v.is_array)
292321
293322 # Python wrapper function will serve as entrypoint from DS2
294323 self .wrapper = build_wrapper_function (target_function , variables ,
295- array_input , return_msg = return_msg ).split ('\n ' )
324+ setup = python_source ,
325+ return_msg = return_msg ,
326+ ** kwargs ).split ('\n ' )
296327
297328 # Lines of Python code to be embedded in DS2
298- python_source = list (python_source ) + list ( self .wrapper )
329+ python_source = list (self .wrapper )
299330
300331 self .variables = variables
301332 self .return_code = return_code
302333 self .return_message = return_msg
303334
304- self .package = DS2Package ()
305- self .package .methods .append (DS2Method (variables , python_source ))
335+ self .package = DS2Package (variables , python_source , return_code , return_msg )
306336
307337 def score_code (self , input_table = None , output_table = None , columns = None , dest = 'MAS' ):
308338 """Generate DS2 score code
@@ -333,7 +363,7 @@ def score_code(self, input_table=None, output_table=None, columns=None, dest='MA
333363 raise ValueError ('Output table name `{}` is a reserved term.' .format (output_table ))
334364
335365 # Get package code
336- code = ( str ( self .package ), )
366+ code = tuple ( self .package . code (). split ( ' \n ' ) )
337367
338368 if dest == 'ESP' :
339369 code = ('data sasep.out;' , ) + code + (' method run();' ,
0 commit comments