8 files changed (+45 −31 lines)
GitHub Actions CI workflow
@@ -55,7 +55,7 @@
         pylint --indent-string='  ' jetstream_pt/ benchmarks/
     - name: Format check with pyink
       run: |
-        pyink --pyink-indentation 2 --line-length 80 --check --verbose .
+        pyink --pyink-indentation 2 --line-length 80 --check --verbose --extend-exclude=deps .
 
   cpu:
     name: "jetstream_pt unit tests"
@@ -79,4 +79,28 @@
         JAX_PLATFORMS=cpu coverage run -m unittest -v
     - name: Create test coverage report
       run: |
-        coverage report -m
+        coverage report -m
+
+  interactive:
+    name: "jetstream_pt run interactive"
+    strategy:
+      matrix:
+        os: [ubuntu-20.04]
+        python-version: ['3.10']
+    runs-on: ${{ matrix.os }}
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v4
+    - name: Setup Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install Dependencies
+      run: |
+        source install_everything.sh
+    - name: Run interactive (bf16)
+      run: |
+        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=0 --quantize_kv_cache=0
+    - name: Run interactive (int8)
+      run: |
+        JAX_PLATFORMS=cpu python run_interactive.py --size=tiny --batch_size=1 --max_cache_length=2048 --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model --model_name=llama-2 --sharding_config=default_shardings/llama.yaml --quantize_weights=1 --quantize_kv_cache=1
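The new `interactive` job exercises the decode path end to end on CPU in both precisions. A sketch of reproducing it locally, assuming dependencies are already installed via `install_everything.sh` (every flag below is copied from the workflow steps above):

```sh
# Same interactive smoke tests the CI job runs, on CPU.
# bf16 weights and KV cache:
JAX_PLATFORMS=cpu python run_interactive.py \
  --size=tiny --batch_size=1 --max_cache_length=2048 \
  --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model \
  --model_name=llama-2 --sharding_config=default_shardings/llama.yaml \
  --quantize_weights=0 --quantize_kv_cache=0

# int8 weight and KV-cache quantization:
JAX_PLATFORMS=cpu python run_interactive.py \
  --size=tiny --batch_size=1 --max_cache_length=2048 \
  --tokenizer_path=jetstream_pt/third_party/llama/tokenizer.model \
  --model_name=llama-2 --sharding_config=default_shardings/llama.yaml \
  --quantize_weights=1 --quantize_kv_cache=1
```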
.gitignore
@@ -1,6 +1,3 @@
-# source dependencies
-deps/
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
.gitmodules
@@ -0,0 +1,6 @@
+[submodule "deps/JetStream"]
+	path = deps/JetStream
+	url = https://github.com/google/JetStream.git
+[submodule "deps/xla"]
+	path = deps/xla
+	url = https://github.com/pytorch/xla.git
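Each submodule entry pins an exact commit of the dependency inside this repository, which is what replaces the `TORCHXLA_TAG`/`JETSTREAM_TAG` variables removed from `install_everything.sh` below. A minimal sketch of how a maintainer could move such a pin; the commit hash reused here is the formerly pinned `TORCHXLA_TAG`, and the exact bump workflow is an assumption:

```sh
# Point deps/xla at a specific upstream commit and record the new pin
# in the superproject (hash shown is the old TORCHXLA_TAG, for illustration).
git -C deps/xla fetch origin
git -C deps/xla checkout f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60
git add deps/xla
git commit -m "Bump deps/xla submodule"
```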
install_everything.sh
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
-JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024
-
 # Uninstall existing jax
 pip show jax && pip uninstall -y jax
 pip show jaxlib && pip uninstall -y jaxlib
@@ -26,17 +23,5 @@ pip install torch --index-url https://download.pytorch.org/whl/cpu
 pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
 pip install safetensors colorama coverage ray[default] humanize
 
-mkdir -p deps
-pushd deps
-git clone https://github.com/google/JetStream.git
-git clone https://github.com/pytorch/xla.git
-pushd xla/experimental/torch_xla2
-git checkout $TORCHXLA_TAG
-pip install .
-popd # now at the folder deps
-pushd JetStream
-git checkout $JETSTREAM_TAG
-pip install .
-popd # now at the folder deps
-popd # now at the folder current file
+git submodule update --init --recursive
 pip install -e .
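With the clone-and-checkout block gone, the script only needs the submodules to be present before the editable install. A minimal sketch of the setup flow on a fresh machine; `<repo-url>` is a placeholder for this repository's clone URL:

```sh
# Option 1: clone with submodules in one step (placeholder URL).
git clone --recurse-submodules <repo-url>

# Option 2: in an existing checkout, fetch the pinned dependency sources,
# then run the install script, which now ends with `pip install -e .`.
git submodule update --init --recursive
source install_everything.sh
```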
pyproject.toml
@@ -3,7 +3,7 @@ requires = ["hatchling"]
 build-backend = "hatchling.build"
 
 [project]
-version = "0.2.0"
+version = "0.2.1"
 name = "jetstream_pt"
 dependencies = [
   "absl-py",
@@ -14,7 +14,12 @@ dependencies = [
   "google-jetstream",
   "google-cloud-storage",
   "safetensors",
+  "torch_xla2 @ {root:uri}/deps/xla/experimental/torch_xla2",
+  "google-jetstream @ {root:uri}/deps/JetStream",
 ]
 
 requires-python = ">=3.10"
 license = {file = "LICENSE"}
+
+[tool.hatch.metadata]
+allow-direct-references = true
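`allow-direct-references = true` is what lets hatchling accept the `pkg @ {root:uri}/...` entries added above, so `pip install -e .` builds `torch_xla2` and `google-jetstream` from the checked-out submodules instead of fetching them from an index. A rough post-install check (a sketch; package names are taken from the dependency strings):

```sh
# Editable install of jetstream_pt; the direct references pull the two
# dependencies from deps/xla/experimental/torch_xla2 and deps/JetStream.
pip install -e .
pip show torch_xla2 google-jetstream
```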
run_interactive.py
@@ -158,20 +158,15 @@ def main(argv):
     decode_state, result_tokens = engine.generate(params, decode_state)
     result_tokens = result_tokens.convert_to_numpy()
     res = result_tokens.get_result_at_slot(slot)
-    stop_tokens = set(tokenizer.tokenizer.stop_tokens)
+    stop_tokens = set(tokenizer.stop_tokens)
     stop_tokens.add(tokenizer.pad_id)
+    token_id = res.tokens[0][0].item()
+    sampled_tokens_list.append(token_id)
     if (
-        res.tokens[0][0] in stop_tokens
+        token_id in stop_tokens
         or len(sampled_tokens_list) > max_output_length
     ):
       break
-    token_id = res.tokens[0][0]
-    sampled_tokens_list.append(token_id)
-    # output_str = tokenizer.decode_str([token_id])
-    # print(Fore.GREEN + output_str, end="", flush=True)
-
-    # print(Style.RESET_ALL + "\n")
-    # print("---- Streaming decode finished.")
 
   print("---- All output tokens.")
   print(sampled_tokens_list)