@@ -104,32 +104,27 @@ def train(
104104 save (model_dir , mlp_model )
105105
106106
107- def model_fn ( path_to_model_files ):
108- import neomxnet # noqa: F401
107+ def neo_preprocess ( payload , content_type ):
108+ logging . info ( "Invoking user-defined pre-processing function" )
109109
110- ctx = mx .cpu ()
111- sym , arg_params , aux_params = mx .model .load_checkpoint (
112- os .path .join (path_to_model_files , "compiled" ), 0
113- )
114- mod = mx .mod .Module (symbol = sym , context = ctx , label_names = None )
115- mod .bind (
116- for_training = False , data_shapes = [("data" , (1 , 1 , 28 , 28 ))], label_shapes = mod ._label_shapes
117- )
118- mod .set_params (arg_params , aux_params , allow_missing = True )
119- return mod
110+ if content_type != "application/vnd+python.numpy+binary" :
111+ raise RuntimeError ("Content type must be application/vnd+python.numpy+binary" )
120112
113+ return np .asarray (json .loads (payload .decode ("utf-8" )))
121114
122- def transform_fn (mod , payload , input_content_type , requested_output_content_type ):
123- import neomxnet # noqa: F401
124115
125- if input_content_type != "application/vnd+python.numpy+binary" :
126- raise RuntimeError ("Input content type must be application/vnd+python.numpy+binary" )
116+ # NOTE: this function cannot use MXNet
117+ def neo_postprocess (result ):
118+ logging .info ("Invoking user-defined post-processing function" )
127119
128- inference_payload = np .asarray (json .loads (payload .decode ("utf-8" )))
129- result = mod .predict (inference_payload )
120+ # Softmax (assumes batch size 1)
130121 result = np .squeeze (result )
131- response_body = json .dumps (result .asnumpy ().tolist ())
122+ result_exp = np .exp (result - np .max (result ))
123+ result = result_exp / np .sum (result_exp )
124+
125+ response_body = json .dumps (result .tolist ())
132126 content_type = "application/json"
127+
133128 return response_body , content_type
134129
135130
@@ -140,7 +135,7 @@ def transform_fn(mod, payload, input_content_type, requested_output_content_type
140135 parser = argparse .ArgumentParser ()
141136
142137 parser .add_argument ("--batch-size" , type = int , default = 100 )
143- parser .add_argument ("--epochs" , type = int , default = 1 )
138+ parser .add_argument ("--epochs" , type = int , default = 10 )
144139 parser .add_argument ("--learning-rate" , type = float , default = 0.1 )
145140
146141 parser .add_argument ("--model-dir" , type = str , default = os .environ ["SM_MODEL_DIR" ])
0 commit comments