Skip to content

Commit af7e12b

Browse files
committed
Set all TensorFlow version directories to "1" (#560)
(cherry picked from commit 0147e52)
1 parent 43847b5 commit af7e12b

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

pkg/operator/workloads/api_workload.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,11 @@ func (aw *APIWorkload) IsFailed(ctx *context.Context) (bool, error) {
234234
}
235235

236236
type downloadContainerArg struct {
237-
From string `json:"from"`
238-
To string `json:"to"`
239-
Unzip bool `json:"unzip"`
240-
ItemName string `json:"item_name"` // name of the item being downloaded, just for logging (if "" nothing will be logged)
237+
From string `json:"from"`
238+
To string `json:"to"`
239+
Unzip bool `json:"unzip"`
240+
ItemName string `json:"item_name"` // name of the item being downloaded, just for logging (if "" nothing will be logged)
241+
TFModelVersionRename string `json:"tf_model_version_rename"` // e.g. passing in /mnt/model/1 will rename /mnt/model/* to /mnt/model/1 only if there is one item in /mnt/model/
241242
}
242243

243244
func tfAPISpec(
@@ -269,10 +270,11 @@ func tfAPISpec(
269270

270271
downloadArgs := []downloadContainerArg{
271272
{
272-
From: ctx.APIs[api.Name].TensorFlow.Model,
273-
To: path.Join(consts.EmptyDirMountPath, "model"),
274-
Unzip: strings.HasSuffix(ctx.APIs[api.Name].TensorFlow.Model, ".zip"),
275-
ItemName: "model",
273+
From: ctx.APIs[api.Name].TensorFlow.Model,
274+
To: path.Join(consts.EmptyDirMountPath, "model"),
275+
Unzip: strings.HasSuffix(ctx.APIs[api.Name].TensorFlow.Model, ".zip"),
276+
ItemName: "model",
277+
TFModelVersionRename: path.Join(consts.EmptyDirMountPath, "model", "1"),
276278
},
277279
}
278280

pkg/workloads/cortex/downloader/download.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,21 @@ def start(args):
3737
logger.info("downloading {} from {}".format(item_name, from_path))
3838
s3_client.download(prefix, to_path)
3939

40-
if download_arg["unzip"]:
40+
if download_arg.get("unzip", False):
4141
if item_name != "":
4242
logger.info("unzipping {}".format(item_name))
4343
util.extract_zip(
4444
os.path.join(to_path, os.path.basename(from_path)), delete_zip_file=True
4545
)
4646

47+
if download_arg.get("tf_model_version_rename", "") != "":
48+
dest = util.trim_suffix(download_arg["tf_model_version_rename"], "/")
49+
dir_path = os.path.dirname(dest)
50+
entries = os.listdir(dir_path)
51+
if len(entries) == 1:
52+
src = os.path.join(dir_path, entries[0])
53+
os.rename(src, dest)
54+
4755

4856
def main():
4957
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)