diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index ad6bd08d..6e5c0cbf 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -122,12 +122,11 @@ QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics( return true; }; runtime_config_.batch_stream_token = batch_stream_token; - if (runtime_config_.verbosity >= 2) { - fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n", - runtime_config_.max_generated_tokens, runtime_config_.temperature, - runtime_config_.prefill_tbatch_size, - runtime_config_.decode_qbatch_size); - } + MaybePrint(runtime_config_.verbosity, + "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n", + runtime_config_.max_generated_tokens, runtime_config_.temperature, + runtime_config_.prefill_tbatch_size, + runtime_config_.decode_qbatch_size); // Ensure we have at least one KVCache per query. while (kv_caches_.size() < num_queries) { @@ -223,8 +222,6 @@ static constexpr const char* CompiledConfig() { return "tsan"; } else if constexpr (HWY_IS_HWASAN) { return "hwasan"; - } else if constexpr (HWY_IS_UBSAN) { - return "ubsan"; } else if constexpr (HWY_IS_DEBUG_BUILD) { return "dbg"; } else { @@ -245,6 +242,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config, WeightsPtrs::ToString(weight_read_mode)); if (args.inference.verbosity >= 2) { + // (fprintf instead of MaybePrint due to local variables) time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index db41fd03..39b352b9 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -601,10 +601,7 @@ static void GenerateT(const ModelConfig& config, SetWeightStats(layer, activations, env.ctx); } - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n"); - } + MaybePrint(timing_info.verbosity, "[ BEGIN PHASE: prefill ]"); const size_t max_gen_steps = PrefillTBatchOrQBatch( config, runtime_config, weights, activations, 
qbatch, env, timing_info); // No-op if the profiler is disabled, but useful to separate prefill and @@ -613,10 +610,6 @@ static void GenerateT(const ModelConfig& config, fprintf(stderr, "\n"); } env.ctx.profiler.PrintResults(); - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ END PHASE: prefill ]\n"); - } hwy::BitSet4096<> non_eos; // indexed by qi @@ -629,10 +622,8 @@ static void GenerateT(const ModelConfig& config, const SampleFunc sample_token = ChooseSampleFunc(runtime_config, engine, env.ctx); - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n"); - } + MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: generate ]\n"); + timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { Transformer(config, runtime_config, weights, activations, qbatch, env); @@ -640,10 +631,6 @@ static void GenerateT(const ModelConfig& config, qbatch, env, non_eos, timing_info); } timing_info.NotifyGenerateDone(); - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ END PHASE: generate ]\n"); - } } // Same as GenerateT, but uses ContinuousQBatch. 
@@ -749,10 +736,7 @@ void GenerateImageTokensT(const ModelConfig& config, const ModelConfig vit_config = GetVitConfig(config); const size_t num_tokens = vit_config.max_seq_len; - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n"); - } + MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: image_token_gen ]\n"); timing_info.NotifyImageTokenStart(); { @@ -775,10 +759,6 @@ void GenerateImageTokensT(const ModelConfig& config, env.ctx.profiler.PrintResults(); timing_info.NotifyImageTokenDone(num_tokens); - if (timing_info.verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n"); - } } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/run.cc b/gemma/run.cc index 8269ce30..7705fd59 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -258,31 +258,32 @@ void Run(const GemmaArgs& args) { KVCache kv_cache(gemma.Config(), inference, ctx.allocator); if (inference.verbosity >= 1) { - std::string instructions = - "*Usage*\n" - " Enter an instruction and press enter (%C resets conversation, " - "%Q quits).\n"; - const std::string multiturn = - inference.multiturn == 0 - ? 
std::string( - " Since multiturn is set to 0, conversation will " - "automatically reset every turn.\n\n") - : "\n"; - const std::string examples = - "*Examples*\n" - " - Write an email to grandma thanking her for the cookies.\n" - " - What are some historical attractions to visit around " - "Massachusetts?\n" - " - Compute the nth fibonacci number in javascript.\n" - " - Write a standup comedy bit about GPU programming.\n"; - instructions += multiturn; - instructions += examples; + ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx); // Skip the banner and instructions in non-interactive mode if (inference.IsInteractive()) { + std::string instructions = + "*Usage*\n" + " Enter an instruction and press enter (%C resets conversation, " + "%Q quits).\n"; + const std::string multiturn = + inference.multiturn == 0 + ? std::string( + " Since multiturn is set to 0, conversation will " + "automatically reset every turn.\n\n") + : "\n"; + const std::string examples = + "*Examples*\n" + " - Write an email to grandma thanking her for the cookies.\n" + " - What are some historical attractions to visit around " + "Massachusetts?\n" + " - Compute the nth fibonacci number in javascript.\n" + " - Write a standup comedy bit about GPU programming.\n"; + instructions += multiturn; + instructions += examples; + std::cout << "\033[2J\033[1;1H" // clear screen << kAsciiArtBanner << "\n\n"; - ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx); std::cout << "\n" << instructions << "\n"; } } @@ -317,14 +318,6 @@ int main(int argc, char** argv) { verbosity = args.inference.verbosity; gcpp::Run(args); } - if (verbosity >= 2) { - fflush(stdout); - fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n"); - } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. 
// Prints a printf-style diagnostic to stderr, but only if `verbosity` >= 2.
//
// `format` and the variadic arguments follow printf conventions. The formatted
// message is limited to the size of the local buffer (800 bytes including the
// terminator); longer messages are truncated. The output always ends with
// exactly one newline from this function's point of view: callers pass formats
// both with and without a trailing '\n', so one is appended only if missing.
void MaybePrint(int verbosity, const char* format, ...) {
  // Diagnostics are opt-in. Without this check every message would be printed
  // regardless of the requested verbosity.
  if (verbosity < 2 || format == nullptr) return;

  char buf[800];
  va_list args;
  va_start(args, format);
  const int written = vsnprintf(buf, sizeof(buf), format, args);
  va_end(args);

  // vsnprintf returns the untruncated length, or a negative value on error.
  // Derive the number of characters actually stored in `buf`.
  size_t len = 0;
  if (written > 0) {
    const size_t wanted = static_cast<size_t>(written);
    len = wanted < sizeof(buf) ? wanted : sizeof(buf) - 1;
  }
  buf[len] = '\0';  // Also yields an empty message if formatting failed.

  const bool ends_with_newline = len != 0 && buf[len - 1] == '\n';

  // Flush stdout first so text already written there appears before this
  // message when both streams go to the same terminal or file.
  fflush(stdout);
  fprintf(stderr, ends_with_newline ? "%s" : "%s\n", buf);
}