diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index f14139b890..b6b9148a86 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -134,8 +134,13 @@ function __init__() ) @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] maxlog = 1 - if XLA_REACTANT_GPU_MEM_FRACTION[] > 1 || XLA_REACTANT_GPU_MEM_FRACTION[] < 0 - error("XLA_REACTANT_GPU_MEM_FRACTION must be between 0 and 1") + if XLA_REACTANT_GPU_MEM_FRACTION[] < 0 + error("XLA_REACTANT_GPU_MEM_FRACTION must be not be negative") + elseif XLA_REACTANT_GPU_MEM_FRACTION[] > 1 + if get(ENV, "TF_FORCE_UNIFIED_MEMORY", "0") != "1" + error("XLA_REACTANT_GPU_MEM_FRACTION must be not greater than 1 without \ + TF_FORCE_UNIFIED_MEMORY=1") + end end end