@@ -77,19 +77,23 @@ cd ..
 rm -r ./TransformerEngine
 
 # cudnn frontend
-pip install nvidia-cudnn-cu12==9.5.0.50
+pip install nvidia-cudnn-cu12==9.7.1.26
 CMAKE_ARGS="-DCMAKE_POLICY_VERSION_MINIMUM=3.5" pip install nvidia-cudnn-frontend
 python -c "import torch; print('cuDNN version:', torch.backends.cudnn.version());"
 python -c "from transformer_engine.pytorch.utils import get_cudnn_version; get_cudnn_version()"
 
-# Megatron-LM requires flash-attn >= 2.1.1, <= 2.7.3
-cu=$(nvcc --version | grep "Cuda compilation tools" | awk '{print $5}' | cut -d '.' -f 1)
-torch=$(pip show torch | grep Version | awk '{print $2}' | cut -d '+' -f 1 | cut -d '.' -f 1,2)
-cp=$(python3 --version | awk '{print $2}' | awk -F. '{print $1$2}')
-cxx=$(g++ --version | grep 'g++' | awk '{print $3}' | cut -d '.' -f 1)
-wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
-pip install flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
-rm flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
+# # Megatron-LM requires flash-attn >= 2.1.1, <= 2.7.3
+# cu=$(nvcc --version | grep "Cuda compilation tools" | awk '{print $5}' | cut -d '.' -f 1)
+# torch=$(pip show torch | grep Version | awk '{print $2}' | cut -d '+' -f 1 | cut -d '.' -f 1,2)
+# cp=$(python3 --version | awk '{print $2}' | awk -F. '{print $1$2}')
+# cxx=$(g++ --version | grep 'g++' | awk '{print $3}' | cut -d '.' -f 1)
+# wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
+# pip install flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
+# rm flash_attn-2.7.3+cu${cu}torch${torch}cxx${cxx}abiFALSE-cp${cp}-cp${cp}-linux_x86_64.whl
+wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
+pip install flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
+rm flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl
+
 
 # From Megatron-LM log
 pip install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
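
Note on the hunk above: the pinned 2.8.0.post2 wheel hard-codes the tags that the commented-out block used to detect at install time (CUDA 12, torch 2.7, CPython 3.12, cxx11 ABI off). A minimal pre-flight check, sketched here and not part of the patch, to confirm the local environment matches those tags before downloading; the expected values are assumptions tied to this particular wheel name:

# sanity check for the flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312 wheel
python3 --version                                      # should report Python 3.12.x (cp312 tag)
python3 -c "import torch; print(torch.__version__)"    # should start with 2.7 (torch2.7 tag)
python3 -c "import torch; print(torch.version.cuda)"   # should start with 12 (cu12 tag)
nvcc --version | grep "Cuda compilation tools"         # CUDA toolkit major version should also be 12
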
@@ -153,7 +157,7 @@ if [ "${env}" == "train" ]; then
 fi
 
 # Replace the following code with torch version 2.6.0
-if [[ $torch_version == *"2.6.0"* ]]; then
+if [[ $torch_version == *"2.6.0"* ]] || [[ $torch_version == *"2.7.0"* ]]; then
 # Check and replace line 908
 LINE_908=$(sed -n '908p' "$FILE")
 EXPECTED_908='if num_nodes_waiting > 0:'
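
Note on the second hunk: the widened test is plain Bash glob matching inside [[ ... ]], so *"2.7.0"* is true whenever the substring 2.7.0 appears anywhere in $torch_version (e.g. 2.7.0+cu126). A minimal sketch of the gate in isolation, assuming torch_version is read from the installed torch package; how the script actually populates it is not shown in this hunk:

# hypothetical illustration of the version gate; not part of the patch itself
torch_version=$(python3 -c "import torch; print(torch.__version__)")   # e.g. 2.7.0+cu126
if [[ $torch_version == *"2.6.0"* ]] || [[ $torch_version == *"2.7.0"* ]]; then
    echo "torch ${torch_version}: applying the line-908 patch"
fi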