diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c886..627227ca93b16 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -6,8 +6,17 @@ #include #include +#include #include #include +#include + +// verbosity flag set via the params.verbosity CLI flag. This is used for two +// things: +// 1. If > 0, tensors are printed with 8 digits of precision instead of 4 +// 2. If > 1, all tensor values are printed instead of the pretty-printed +// partial output +static int verbosity = 0; /** * This the arbitrary data which will be passed to each callback. @@ -61,6 +70,10 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * } static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + std::stringstream ss; + const int float_digits = verbosity > 0 ? 8 : 4; + ss << "%12." << float_digits << "f"; + const auto float_fmt = ss.str(); GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { @@ -93,7 +106,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne i0 = ne[0] - n; } const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); - LOG("%12.4f", v); + LOG(float_fmt.c_str(), v); if (i0 < ne[0] - 1) LOG(", "); } LOG("],\n"); @@ -153,8 +166,12 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { } if (!ggml_is_quantized(t->type)) { + // The `--verbose` flag will set verbosity to INT_MAX. We want that to + // be the equivalent of `-lv 1` since it will be the most common command + // used and full-width printing is extremely verbose. + const int print_width = (verbosity > 1 && verbosity < std::numeric_limits::max()) ? std::numeric_limits::max() : 3; uint8_t * data = is_host ? 
(uint8_t *) t->data : cb_data->data.data(); - ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + ggml_print_tensor(data, t->type, t->ne, t->nb, print_width); } return true; @@ -192,6 +209,9 @@ int main(int argc, char ** argv) { common_init(); + // set verbosity for printing + verbosity = params.verbosity; + llama_backend_init(); llama_numa_init(params.numa);