diff --git a/support/environment.go b/support/environment.go index fb4cd6a..abb2f1d 100644 --- a/support/environment.go +++ b/support/environment.go @@ -95,7 +95,7 @@ func GetRayTorchROCmImage() string { } func GetPyTorchImage() string { - return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime") + return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:2.4.1-cuda11.8-cudnn9-runtime") } func GetCudaTrainingImage() string {