Skip to content
This repository was archived by the owner on Aug 28, 2025. It is now read-only.

Commit 3aa8b17

Browse files
committed
inference
1 parent 252f35b commit 3aa8b17

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

templates/titanic/tutorial.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@
8989
# %%
9090
df_test = pd.read_csv(csv_test)
9191

92-
predictions = model.predict(csv_test)
92+
dm = TabularClassificationData.from_data_frame(
93+
predict_data_frame=df_test,
94+
parameters=datamodule.parameters,
95+
)
96+
predictions = trainer.predict(model, datamodule=dm, output="classes")
9397
print(predictions[0])
9498

9599
# %%

0 commit comments

Comments
 (0)