Commit 8ee32b8

Fix Spark file and cache warnings (#137)

* don't re-upload impls
* remove second cache call
* rename sample_df -> transformed_df

1 parent 4bb80ce commit 8ee32b8

3 files changed: 18 additions, 10 deletions

pkg/workloads/lib/context.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def __init__(self, **kwargs):
         self._aggregator_impls = {}
         self._model_impls = {}
         self._metadatas = {}
+        self.spark_uploaded_impls = {}

         # This affects Tensorflow S3 access
         os.environ["AWS_REGION"] = self.cortex_config.get("region", "")
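The new spark_uploaded_impls dict acts as a set of implementation file paths that have already been shipped to the Spark cluster; the spark_util.py hunks below check it before calling addPyFile again.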

pkg/workloads/spark_job/spark_job.py

Lines changed: 7 additions & 7 deletions
@@ -236,8 +236,8 @@ def validate_transformers(spark, ctx, cols_to_transform, raw_df):
     TEST_DF_SIZE = 100

     logger.info("Sanity checking transformers against the first {} samples".format(TEST_DF_SIZE))
-    sample_df = raw_df.limit(TEST_DF_SIZE).cache()
-    test_df = raw_df.limit(TEST_DF_SIZE).cache()
+    transformed_df = raw_df.limit(TEST_DF_SIZE).cache()
+    test_df = raw_df.limit(TEST_DF_SIZE)

     resource_list = sorted([ctx.tf_id_map[f] for f in cols_to_transform], key=lambda r: r["name"])
     for transformed_column in resource_list:
@@ -257,17 +257,17 @@ def validate_transformers(spark, ctx, cols_to_transform, raw_df):
         logger.info("Transforming {} to {}".format(", ".join(input_cols), tf_name))

         spark_util.validate_transformer(tf_name, test_df, ctx, spark)
-        sample_df = spark_util.transform_column(
-            transformed_column["name"], sample_df, ctx, spark
+        transformed_df = spark_util.transform_column(
+            transformed_column["name"], transformed_df, ctx, spark
         )

-        sample_df.select(tf_name).collect()  # run the transformer
-        show_df(sample_df.select(*input_cols, tf_name), ctx, n=3, sort=False)
+        transformed_df.select(tf_name).collect()  # run the transformer
+        show_df(transformed_df.select(*input_cols, tf_name), ctx, n=3, sort=False)

         for alias in transformed_column["aliases"][1:]:
             logger.info("Transforming {} to {}".format(", ".join(input_cols), alias))

-            display_transform_df = sample_df.withColumn(alias, F.col(tf_name)).select(
+            display_transform_df = transformed_df.withColumn(alias, F.col(tf_name)).select(
                 *input_cols, alias
             )
             show_df(display_transform_df, ctx, n=3, sort=False)
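Note on the cache fix above: caching two DataFrames built from the same logical plan causes Spark's CacheManager to warn on the second call, which is why test_df no longer calls .cache(). A minimal sketch of the warning (assuming a local SparkSession; the demo names are illustrative, not from this commit):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("cache-warning-demo").getOrCreate()
raw_df = spark.range(1000).toDF("id")

first_df = raw_df.limit(100).cache()   # first cache of this plan: fine
second_df = raw_df.limit(100).cache()  # same logical plan: Spark logs a warning
                                       # like "Asked to cache already cached data."

second_df.count()  # materializes the cached data
spark.stop()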

pkg/workloads/spark_job/spark_util.py

Lines changed: 10 additions & 3 deletions
@@ -501,7 +501,11 @@ def extract_inputs(column_name, ctx):

 def execute_transform_spark(column_name, df, ctx, spark):
     trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name)
-    spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
+
+    if trans_impl_path not in ctx.spark_uploaded_impls:
+        spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
+        ctx.spark_uploaded_impls[trans_impl_path] = True
+
     columns_input_config, impl_args = extract_inputs(column_name, ctx)
     try:
         return trans_impl.transform_spark(df, columns_input_config, impl_args, column_name)
@@ -513,8 +517,11 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False):
     trans_impl, trans_impl_path = ctx.get_transformer_impl(column_name)
     columns_input_config, impl_args = extract_inputs(column_name, ctx)

-    spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
-    # not a dictionary because it is possible that one column may map to multiple input names
+    if trans_impl_path not in ctx.spark_uploaded_impls:
+        spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
+        # not a dictionary because it is possible that one column may map to multiple input names
+        ctx.spark_uploaded_impls[trans_impl_path] = True
+
     required_columns_sorted, columns_input_config_indexed = column_names_to_index(
         columns_input_config
     )
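Note on the guard above: SparkContext.addPyFile logs a warning if the same path is added twice (added paths cannot be overwritten), so ctx.spark_uploaded_impls memoizes which implementation files have already been shipped to the executors. The same pattern in isolation (a sketch; the helper name and file path are hypothetical, and sc is assumed to be a live SparkContext):

def add_py_file_once(sc, path, uploaded_impls):
    # Ship `path` to the executors only the first time it is seen; a repeat
    # call to sc.addPyFile for the same path would make Spark log a warning.
    if path not in uploaded_impls:
        sc.addPyFile(path)
        uploaded_impls[path] = True

uploaded_impls = {}  # plays the role of ctx.spark_uploaded_impls
add_py_file_once(sc, "/tmp/transformer_impl.py", uploaded_impls)
add_py_file_once(sc, "/tmp/transformer_impl.py", uploaded_impls)  # no-op, no warning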
