1313# limitations under the License.
1414
1515
16+ import os
1617from absl import flags
18+ import jax
1719from jetstream_pt .environment import QuantizationConfig
1820
1921FLAGS = flags .FLAGS
154156 "page size per page" ,
155157)
156158flags .DEFINE_string (
157- "jax_compilation_cache_dir " ,
159+ "internal_jax_compilation_cache_dir " ,
158160 "~/jax_cache" ,
159161 "Jax compilation cache directory" ,
160162)
161163flags .DEFINE_integer (
162- "jax_persistent_cache_min_entry_size_bytes " ,
164+ "internal_jax_persistent_cache_min_entry_size_bytes " ,
163165 0 ,
164166 "Minimum size (in bytes) of an entry that will be cached in the persistent compilation cache" ,
165167)
166168flags .DEFINE_integer (
167- "jax_persistent_cache_min_compile_time_secs " ,
169+ "internal_jax_persistent_cache_min_compile_time_secs " ,
168170 1 ,
169171 "Minimum compilation time for a computation to be written to persistent cache" ,
170172)
@@ -190,3 +192,19 @@ def create_quantization_config_from_flags():
190192 else FLAGS .quantize_weights
191193 )
192194 return config
195+
196+
197+ def set_jax_compilation_cache_config ():
198+ """Sets the jax compilation cache configuration"""
199+ jax .config .update (
200+ "jax_compilation_cache_dir" ,
201+ os .path .expanduser (FLAGS .internal_jax_compilation_cache_dir ),
202+ )
203+ jax .config .update (
204+ "jax_persistent_cache_min_entry_size_bytes" ,
205+ FLAGS .internal_jax_persistent_cache_min_entry_size_bytes ,
206+ )
207+ jax .config .update (
208+ "jax_persistent_cache_min_compile_time_secs" ,
209+ FLAGS .internal_jax_persistent_cache_min_compile_time_secs ,
210+ )
0 commit comments