Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.

Commit 8f1f0e6

Browse files
committed
chain
1 parent 06cff9d commit 8f1f0e6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

.actions/assistant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def _bash_download_data(folder: str) -> List[str]:
289289
if ext not in AssistantCLI._EXT_ARCHIVE:
290290
continue
291291
if ext in AssistantCLI._EXT_ARCHIVE_ZIP:
292-
cmd += [f"mkdir -p {name}", f"unzip -o {fn} -d {name} {UNZIP_PROGRESS_BAR}"]
292+
cmd += [f"unzip -o {fn} -d {name} {UNZIP_PROGRESS_BAR}"]
293293
else:
294294
cmd += [f"tar -zxvf {fn} --overwrite"]
295295
cmd += [f"rm {fn}"]

templates/titanic/tutorial.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@
9494
parameters=datamodule.parameters,
9595
batch_size=datamodule.batch_size,
9696
)
97-
predictions = trainer.predict(model, datamodule=dm, output="classes")
98-
print(predictions[0])
97+
preds = trainer.predict(model, datamodule=dm, output="classes")
98+
print(preds[0][:10])
9999

100100
# %%
101+
import itertools # noqa: E402]
101102
import numpy as np # noqa: E402]
102103

104+
predictions = list(itertools.chain(*preds))
103105
assert len(df_test) == len(predictions)
104106

105107
df_test["Survived"] = np.argmax(predictions, axis=-1)

0 commit comments

Comments
 (0)