Commit 8b0f041

Update input, add estimators (#154)

1 parent 7ad1c38

209 files changed: +10109 −5964 lines


cli/cmd/get.go (20 additions, 9 deletions)
@@ -27,6 +27,7 @@ import (
     "github.com/cortexlabs/cortex/pkg/lib/errors"
     "github.com/cortexlabs/cortex/pkg/lib/json"
     "github.com/cortexlabs/cortex/pkg/lib/msgpack"
+    "github.com/cortexlabs/cortex/pkg/lib/sets/strset"
     s "github.com/cortexlabs/cortex/pkg/lib/strings"
     libtime "github.com/cortexlabs/cortex/pkg/lib/time"
     "github.com/cortexlabs/cortex/pkg/lib/urls"
@@ -381,7 +382,6 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string

     ctx := resourcesRes.Context
     api := ctx.APIs[name]
-    model := ctx.Models[api.ModelName]

     var staleReplicas int32
     var ctxAPIStatus *resource.APIStatus
@@ -411,18 +411,29 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string
     }

     out += titleStr("Endpoint")
-    var samplePlaceholderFields []string
-    for _, colName := range ctx.RawColumnInputNames(model) {
-        column := ctx.GetColumn(colName)
-        fieldStr := `"` + colName + `": ` + column.GetType().JSONPlaceholder()
-        samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
-    }
-    samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
     out += "URL: " + urls.Join(resourcesRes.APIsBaseURL, anyAPIStatus.Path) + "\n"
     out += "Method: POST\n"
     out += `Header: "Content-Type: application/json"` + "\n"
-    out += "Payload: " + samplesPlaceholderStr + "\n"

+    if api.Model != nil {
+        model := ctx.Models[api.ModelName]
+        resIDs := strset.New()
+        combinedInput := []interface{}{model.Input, model.TrainingInput}
+        for _, res := range ctx.ExtractCortexResources(combinedInput, resource.ConstantType, resource.RawColumnType, resource.AggregateType, resource.TransformedColumnType) {
+            resIDs.Add(res.GetID())
+            resIDs.Merge(ctx.AllComputedResourceDependencies(res.GetID()))
+        }
+        var samplePlaceholderFields []string
+        for rawColumnName, rawColumn := range ctx.RawColumns {
+            if resIDs.Has(rawColumn.GetID()) {
+                fieldStr := fmt.Sprintf("\"%s\": %s", rawColumnName, rawColumn.GetColumnType().JSONPlaceholder())
+                samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
+            }
+        }
+        sort.Strings(samplePlaceholderFields)
+        samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
+        out += "Payload: " + samplesPlaceholderStr + "\n"
+    }
     if api != nil {
         out += resourceStr(api.API)
     }
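The endpoint description now prints a sample payload only when the API serves a model, and the placeholder fields are exactly the raw columns that the model's input and training_input depend on (resolved transitively through AllComputedResourceDependencies), sorted by name. A minimal sketch of calling such an endpoint from Python — the URL and column names are hypothetical placeholders, and the requests package is assumed:

    import requests

    # Shape mirrors the printed placeholder: { "samples": [ { ... } ] }
    payload = {"samples": [{"column1": 0.0, "column2": 0.0, "column3": 0.0}]}

    # The URL is whatever `cortex get <api>` prints after "URL:" (hypothetical here)
    response = requests.post(
        "https://<apis_base_url>/my-app/my-api",
        json=payload,  # sends the "Content-Type: application/json" header
    )
    print(response.json())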

cli/cmd/init.go (96 additions, 80 deletions)
@@ -91,10 +91,10 @@ func appInitFiles(appName string) map[string]string {
 #   csv_config:
 #     header: true
 #   schema:
-#     - column1
-#     - column2
-#     - column3
-#     - label
+#     - @column1
+#     - @column2
+#     - @column3
+#     - @label
 `,

     "resources/raw_columns.yaml": `## Sample raw columns:
@@ -125,45 +125,37 @@ func appInitFiles(appName string) map[string]string {
 # - kind: aggregate
 #   name: column1_bucket_boundaries
 #   aggregator: cortex.bucket_boundaries
-#   inputs:
-#     columns:
-#       col: column1
-#     args:
-#       num_buckets: 3
+#   input:
+#     col: @column1
+#     num_buckets: 3
 `,

     "resources/transformed_columns.yaml": `## Sample transformed columns:
 #
 # - kind: transformed_column
 #   name: column1_bucketized
 #   transformer: cortex.bucketize  # Cortex provided transformer in pkg/transformers
-#   inputs:
-#     columns:
-#       num: column1
-#     args:
-#       bucket_boundaries: column2_bucket_boundaries
+#   input:
+#     num: @column1
+#     bucket_boundaries: @column2_bucket_boundaries
 #
 # - kind: transformed_column
 #   name: column2_transformed
-#   transformer: my_transformer  # Your own custom transformer from the transformers folder
+#   transformer: my_transformer  # Your own custom transformer
 #   inputs:
-#     columns:
-#       num: column2
-#     args:
-#       arg1: 10
-#       arg2: 100
+#     col: @column2
+#     arg1: 10
+#     arg2: 100
 `,

     "resources/models.yaml": `## Sample model:
 #
 # - kind: model
-#   name: my_model
-#   type: classification
-#   target_column: label
-#   feature_columns:
-#     - column1
-#     - column2
-#     - column3
+#   name: dnn
+#   estimator: cortex.dnn_classifier
+#   target_column: @class
+#   input:
+#     numeric_columns: [@column1, @column2, @column3]
 #   hparams:
 #     hidden_units: [4, 2]
 #   data_partition_ratio:
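These templates reflect the new input scheme: the old inputs block with separate columns and args sections collapses into a single input object, and resource references carry an @ prefix (@column1, @column2_bucket_boundaries, @my_model). At runtime, column references are replaced by column names (or values) and other references by their runtime values, so a custom transformer in the style of the cortex.bucketize sample might look like the following sketch (PySpark assumed; the num and bucket_boundaries keys mirror the sample config above and are illustrative, not code from this commit):

    from pyspark.ml.feature import Bucketizer

    def transform_spark(data, input, transformed_column_name):
        # input["num"] resolves to the referenced raw column's dataframe name;
        # input["bucket_boundaries"] resolves to the aggregate's runtime value.
        bucketizer = Bucketizer(
            splits=input["bucket_boundaries"],
            inputCol=input["num"],
            outputCol=transformed_column_name,
        )
        return bucketizer.transform(data)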
@@ -178,7 +170,7 @@ func appInitFiles(appName string) map[string]string {
 #
 # - kind: api
 #   name: my-api
-#   model_name: my_model
+#   model: @my_model
 #   compute:
 #     replicas: 1
 `,
@@ -204,26 +196,25 @@ def create_estimator(run_config, model_config):
         run_config: An instance of tf.estimator.RunConfig to be used when creating
             the estimator.

-        model_config: The Cortex configuration for the model.
-            Note: nested resources are expanded (e.g. model_config["target_column"])
-            will be the configuration for the target column, rather than the
-            name of the target column).
+        model_config: The Cortex configuration for the model. Column references in all
+            inputs (i.e. model_config["target_column"], model_config["input"], and
+            model_config["training_input"]) are replaced by their names (e.g. "@column1"
+            will be replaced with "column1"). All other resource references (e.g. constants
+            and aggregates) are replaced by their runtime values.

     Returns:
         An instance of tf.estimator.Estimator to train the model.
     """

     ## Sample create_estimator implementation:
     #
-    # feature_columns = [
-    #     tf.feature_column.numeric_column("column1"),
-    #     tf.feature_column.indicator_column(
-    #         tf.feature_column.categorical_column_with_identity("column2", num_buckets=3)
-    #     ),
-    # ]
+    # feature_columns = []
+    # for col_name in model_config["input"]["numeric_columns"]:
+    #     feature_columns.append(tf.feature_column.numeric_column(col_name))
     #
-    # return tf.estimator.DNNRegressor(
+    # return tf.estimator.DNNClassifier(
     #     feature_columns=feature_columns,
+    #     n_classes=model_config["input"]["num_classes"],
     #     hidden_units=model_config["hparams"]["hidden_units"],
     #     config=run_config,
     # )
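Uncommented, the sample above is nearly runnable as-is. A sketch assuming TensorFlow 1.x's tf.estimator API, with num_classes carried in the model's input as in the commented sample:

    import tensorflow as tf

    def create_estimator(run_config, model_config):
        # One numeric feature column per raw column referenced by the model's input.
        feature_columns = [
            tf.feature_column.numeric_column(col_name)
            for col_name in model_config["input"]["numeric_columns"]
        ]
        return tf.estimator.DNNClassifier(
            feature_columns=feature_columns,
            n_classes=model_config["input"]["num_classes"],
            hidden_units=model_config["hparams"]["hidden_units"],
            config=run_config,
        )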
@@ -235,7 +226,6 @@ def create_estimator(run_config, model_config):
 #
 # - kind: constant
 #   name: my_constant
-#   type: [INT]
 #   value: [0, 50, 100]
 `,

@@ -244,38 +234,34 @@ def create_estimator(run_config, model_config):
 # - kind: aggregator
 #   name: my_aggregator
 #   output_type: [FLOAT]
-#   inputs:
-#     columns:
-#       column1: FLOAT_COLUMN|INT_COLUMN
-#     args:
-#       arg1: INT
+#   input:
+#     column1: FLOAT_COLUMN|INT_COLUMN
+#     arg1: INT
 `,

-    "implementations/aggregators/my_aggregator.py": `def aggregate_spark(data, columns, args):
+    "implementations/aggregators/my_aggregator.py": `def aggregate_spark(data, input):
     """Aggregate a column in a PySpark context.

     This function is required.

     Args:
         data: A dataframe including all of the raw columns.

-        columns: A dict with the same structure as the aggregator's input
-            columns specifying the names of the dataframe's columns that
-            contain the input columns.
-
-        args: A dict with the same structure as the aggregator's input args
-            containing the runtime values of the args.
+        input: The aggregate's input object. Column references in the input are
+            replaced by their names (e.g. "@column1" will be replaced with "column1"),
+            and all other resource references (e.g. constants) are replaced by their
+            runtime values.

     Returns:
-        Any json-serializable object that matches the data type of the aggregator.
+        Any serializable object that matches the output type of the aggregator.
     """

     ## Sample aggregate_spark implementation:
     #
     # from pyspark.ml.feature import QuantileDiscretizer
     #
     # discretizer = QuantileDiscretizer(
-    #     numBuckets=args["num_buckets"], inputCol=columns["col"], outputCol="_"
+    #     numBuckets=input["num_buckets"], inputCol=input["col"], outputCol="_"
     # ).fit(data)
     #
     # return discretizer.getSplits()
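For reference, here is the commented sample as a self-contained sketch (PySpark assumed; col and num_buckets are the keys from the sample aggregator config). The returned split points are the aggregate's value, e.g. the bucket boundaries consumed by the transformer above:

    from pyspark.ml.feature import QuantileDiscretizer

    def aggregate_spark(data, input):
        # Fit a discretizer on the referenced column; "_" is a throwaway output
        # column name, since only the learned splits are needed.
        discretizer = QuantileDiscretizer(
            numBuckets=input["num_buckets"], inputCol=input["col"], outputCol="_"
        ).fit(data)
        return discretizer.getSplits()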
@@ -288,28 +274,24 @@ def create_estimator(run_config, model_config):
 # - kind: transformer
 #   name: my_transformer
 #   output_type: INT_COLUMN
-#   inputs:
-#     columns:
-#       column1: INT_COLUMN|FLOAT_COLUMN
-#     args:
-#       arg1: FLOAT
-#       arg2: FLOAT
+#   input:
+#     column1: INT_COLUMN|FLOAT_COLUMN
+#     arg1: FLOAT
+#     arg2: FLOAT
 `,

-    "implementations/transformers/my_transformer.py": `def transform_spark(data, columns, args, transformed_column_name):
+    "implementations/transformers/my_transformer.py": `def transform_spark(data, input, transformed_column_name):
     """Transform a column in a PySpark context.

     This function is optional (recommended for large-scale data processing).

     Args:
         data: A dataframe including all of the raw columns.

-        columns: A dict with the same structure as the transformer's input
-            columns specifying the names of the dataframe's columns that
-            contain the input columns.
-
-        args: A dict with the same structure as the transformer's input args
-            containing the runtime values of the args.
+        input: The transformed column's input object. Column references in the input are
+            replaced by their names (e.g. "@column1" will be replaced with "column1"),
+            and all other resource references (e.g. constants and aggregates) are replaced
+            by their runtime values.

         transformed_column_name: The name of the column containing the transformed
             data that is to be appended to the dataframe.
@@ -322,36 +304,35 @@ def create_estimator(run_config, model_config):
     ## Sample transform_spark implementation:
     #
     # return data.withColumn(
-    #     transformed_column_name, ((data[columns["num"]] - args["mean"]) / args["stddev"])
+    #     transformed_column_name, ((data[input["col"]] - input["mean"]) / input["stddev"])
     # )

     pass


-def transform_python(sample, args):
+def transform_python(input):
     """Transform a single data sample outside of a PySpark context.

-    This function is required.
+    This function is required for any columns that are used during inference.

     Args:
-        sample: A dict with the same structure as the transformer's input
-            columns containing a data sample to transform.
-
-        args: A dict with the same structure as the transformer's input args
-            containing the runtime values of the args.
+        input: The transformed column's input object. Column references in the input are
+            replaced by their values in the sample (e.g. "@column1" will be replaced with
+            the value for column1), and all other resource references (e.g. constants and
+            aggregates) are replaced by their runtime values.

     Returns:
         The transformed value.
     """

     ## Sample transform_python implementation:
     #
-    # return (sample["num"] - args["mean"]) / args["stddev"]
+    # return (input["col"] - input["mean"]) / input["stddev"]

     pass


-def reverse_transform_python(transformed_value, args):
+def reverse_transform_python(transformed_value, input):
     """Reverse transform a single data sample outside of a PySpark context.

     This function is optional, and only relevant for certain one-to-one
@@ -360,16 +341,51 @@ def reverse_transform_python(transformed_value, args):
     Args:
         transformed_value: The transformed data value.

-        args: A dict with the same structure as the transformer's input args
-            containing the runtime values of the args.
+        input: The transformed column's input object. Column references in the input are
+            replaced by their names (e.g. "@column1" will be replaced with "column1"),
+            and all other resource references (e.g. constants and aggregates) are replaced
+            by their runtime values.

     Returns:
         The raw data value that corresponds to the transformed value.
     """

     ## Sample reverse_transform_python implementation:
     #
-    # return args["mean"] + (transformed_value * args["stddev"])
+    # return input["mean"] + (transformed_value * input["stddev"])
+
+    pass
+`,
+
+    "implementations/estimators/my_estimator.py": `def create_estimator(run_config, model_config):
+    """Create an estimator to train the model.
+
+    Args:
+        run_config: An instance of tf.estimator.RunConfig to be used when creating
+            the estimator.
+
+        model_config: The Cortex configuration for the model. Column references in all
+            inputs (i.e. model_config["target_column"], model_config["input"], and
+            model_config["training_input"]) are replaced by their names (e.g. "@column1"
+            will be replaced with "column1"). All other resource references (e.g. constants
+            and aggregates) are replaced by their runtime values.
+
+    Returns:
+        An instance of tf.estimator.Estimator to train the model.
+    """
+
+    ## Sample create_estimator implementation:
+    #
+    # feature_columns = []
+    # for col_name in model_config["input"]["numeric_columns"]:
+    #     feature_columns.append(tf.feature_column.numeric_column(col_name))
+    #
+    # return tf.estimator.DNNClassifier(
+    #     feature_columns=feature_columns,
+    #     n_classes=model_config["input"]["num_classes"],
+    #     hidden_units=model_config["hparams"]["hidden_units"],
+    #     config=run_config,
+    # )

     pass
 `,
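Taken together, the transformer samples describe a standardization (scale) transformer. A runnable sketch of the Python-side pair, assuming the input object carries the sample's value under "col" and runtime aggregates "mean" and "stddev" (key names follow the samples above):

    def transform_python(input):
        # "col" resolves to the sample's value for the referenced raw column;
        # "mean" and "stddev" resolve to the referenced aggregates' values.
        return (input["col"] - input["mean"]) / input["stddev"]

    def reverse_transform_python(transformed_value, input):
        # Invert the scaling to recover the raw value.
        return input["mean"] + (transformed_value * input["stddev"])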
