@@ -119,11 +119,16 @@ int main(int argc, char** argv) {
119119 parser, " num_iters" , " Number of averaging timing iterations used to select kernels" , {" num-avg-timing-iters" });
120120 args::ValueFlag<uint64_t > workspace_size (
121121 parser, " workspace_size" , " Maximum size of workspace given to TensorRT" , {" workspace-size" });
122- args::ValueFlag<double > threshold (
122+ args::ValueFlag<double > atol (
123123 parser,
124- " threshold" ,
125- " Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)" ,
126- {' t' , " threshold" });
124+ " atol" ,
125+ " Absolute tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-8)" ,
126+ {" atol" });
127+ args::ValueFlag<double > rtol (
128+ parser,
129+ " rtol" ,
130+ " Relative tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-5)" ,
131+ {" rtol" });
127132
128133 args::Flag no_threshold_check (
129134 parser, " no-threshold-check" , " Skip checking threshold compliance" , {" no-threshold-check" , " no-threshold-check" });
@@ -392,9 +397,13 @@ int main(int argc, char** argv) {
392397 (compile_settings.enabled_precisions .size () == 1 &&
393398 compile_settings.enabled_precisions .find (torchtrt::DataType::kFloat ) !=
394399 compile_settings.enabled_precisions .end ())) {
395- double threshold_val = 2e-5 ;
396- if (threshold) {
397- threshold_val = args::get (threshold);
400+ double atol_val = 1e-8 ;
401+ double rtol_val = 1e-5 ;
402+ if (atol) {
403+ atol_val = args::get (atol);
404+ }
405+ if (rtol) {
406+ rtol_val = args::get (rtol);
398407 }
399408
400409 std::vector<torch::jit::IValue> jit_inputs_ivalues;
@@ -431,14 +440,18 @@ int main(int argc, char** argv) {
431440 }
432441
433442 for (size_t i = 0 ; i < trt_results.size (); i++) {
443+ std::ostringstream threshold_ss;
444+ threshold_ss << " atol: " << atol_val << " rtol: " << rtol_val;
434445 if (!torchtrtc::accuracy::almost_equal (
435- jit_results[i], trt_results[i].reshape_as (jit_results[i]), threshold_val)) {
436- std::ostringstream threshold_ss;
437- threshold_ss << threshold_val;
446+ jit_results[i], trt_results[i].reshape_as (jit_results[i]), atol_val, rtol_val)) {
438447 torchtrt::logging::log (
439448 torchtrt::logging::Level::kWARNING ,
440- std::string (" Maximum numerical deviation for output exceeds set threshold (" ) + threshold_ss.str () +
441- std::string (" )" ));
449+ std::string (" Maximum numerical deviation for output exceeds tolerance thresholds (" ) +
450+ threshold_ss.str () + std::string (" )" ));
451+ } else {
452+ torchtrt::logging::log (
453+ torchtrt::logging::Level::kDEBUG ,
454+ std::string (" Maximum numerical deviation within threshold limits " ) + threshold_ss.str ());
442455 }
443456 }
444457 } else {
0 commit comments