Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions evals/benchmark_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,11 @@ QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(
return true;
};
runtime_config_.batch_stream_token = batch_stream_token;
if (runtime_config_.verbosity >= 2) {
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
}
MaybePrint(runtime_config_.verbosity,
"Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);

// Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) {
Expand Down Expand Up @@ -223,8 +222,6 @@ static constexpr const char* CompiledConfig() {
return "tsan";
} else if constexpr (HWY_IS_HWASAN) {
return "hwasan";
} else if constexpr (HWY_IS_UBSAN) {
return "ubsan";
} else if constexpr (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
Expand All @@ -245,6 +242,7 @@ void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
WeightsPtrs::ToString(weight_read_mode));

if (args.inference.verbosity >= 2) {
// (fprintf instead of MaybePrint due to local variables)
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
Expand Down
28 changes: 4 additions & 24 deletions gemma/gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,7 @@ static void GenerateT(const ModelConfig& config,
SetWeightStats(layer, activations, env.ctx);
}

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
}
MaybePrint(timing_info.verbosity, "[ BEGIN PHASE: prefill ]");
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
Expand All @@ -613,10 +610,6 @@ static void GenerateT(const ModelConfig& config,
fprintf(stderr, "\n");
}
env.ctx.profiler.PrintResults();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
}

hwy::BitSet4096<> non_eos; // indexed by qi

Expand All @@ -629,21 +622,15 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
}
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: generate ]\n");

timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: generate ]\n");
}
}

// Same as GenerateT, but uses ContinuousQBatch.
Expand Down Expand Up @@ -749,10 +736,7 @@ void GenerateImageTokensT(const ModelConfig& config,
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
}
MaybePrint(timing_info.verbosity, "\n[ BEGIN PHASE: image_token_gen ]\n");
timing_info.NotifyImageTokenStart();

{
Expand All @@ -775,10 +759,6 @@ void GenerateImageTokensT(const ModelConfig& config,
env.ctx.profiler.PrintResults();

timing_info.NotifyImageTokenDone(num_tokens);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
}
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand Down
49 changes: 21 additions & 28 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,31 +258,32 @@ void Run(const GemmaArgs& args) {
KVCache kv_cache(gemma.Config(), inference, ctx.allocator);

if (inference.verbosity >= 1) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
"%Q quits).\n";
const std::string multiturn =
inference.multiturn == 0
? std::string(
" Since multiturn is set to 0, conversation will "
"automatically reset every turn.\n\n")
: "\n";
const std::string examples =
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
"Massachusetts?\n"
" - Compute the nth fibonacci number in javascript.\n"
" - Write a standup comedy bit about GPU programming.\n";
instructions += multiturn;
instructions += examples;
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);

// Skip the banner and instructions in non-interactive mode
if (inference.IsInteractive()) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
"%Q quits).\n";
const std::string multiturn =
inference.multiturn == 0
? std::string(
" Since multiturn is set to 0, conversation will "
"automatically reset every turn.\n\n")
: "\n";
const std::string examples =
"*Examples*\n"
" - Write an email to grandma thanking her for the cookies.\n"
" - What are some historical attractions to visit around "
"Massachusetts?\n"
" - Compute the nth fibonacci number in javascript.\n"
" - Write a standup comedy bit about GPU programming.\n";
instructions += multiturn;
instructions += examples;

std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(args, gemma.Config(), gemma.WeightReadMode(), ctx);
std::cout << "\n" << instructions << "\n";
}
}
Expand Down Expand Up @@ -317,14 +318,6 @@ int main(int argc, char** argv) {
verbosity = args.inference.verbosity;
gcpp::Run(args);
}
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
}
return 0;
}
12 changes: 12 additions & 0 deletions util/basics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,27 @@

#include "util/basics.h"

#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include "hwy/contrib/sort/vqsort.h"
#include "hwy/highway.h"
#include "hwy/timer.h"

namespace gcpp {

void MaybePrint(int verbosity, const char* format, ...) {
char buf[800];
va_list args;
va_start(args, format);
vsnprintf(buf, sizeof(buf), format, args);
va_end(args);

fprintf(stderr, "%s\n", buf); // \n ensures flush.
}

AesCtrEngine::AesCtrEngine(bool deterministic) {
// Pi-based nothing up my sleeve numbers from Randen.
key_[0] = 0x243F6A8885A308D3ull;
Expand Down
3 changes: 3 additions & 0 deletions util/basics.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ static inline intptr_t MaybeTestInitialized(const void* ptr, size_t size) {
#endif
}

// If verbosity >= 2, prints the formatted message to stderr.
void MaybePrint(int verbosity, const char* format, ...);

// Shared between gemma.h and ops-inl.h.
#pragma pack(push, 1)
struct TokenAndProb {
Expand Down
Loading