Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions gemma/gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ static void GenerateT(const ModelConfig& config,
SetWeightStats(layer, activations, env.ctx);
}

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: prefill ]\n");
}
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
// No-op if the profiler is disabled, but useful to separate prefill and
Expand All @@ -609,6 +613,10 @@ static void GenerateT(const ModelConfig& config,
fprintf(stderr, "\n");
}
env.ctx.profiler.PrintResults();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: prefill ]\n");
}

hwy::BitSet4096<> non_eos; // indexed by qi

Expand All @@ -621,13 +629,21 @@ static void GenerateT(const ModelConfig& config,
const SampleFunc sample_token =
ChooseSampleFunc(runtime_config, engine, env.ctx);

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: generate ]\n");
}
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
Transformer(config, runtime_config, weights, activations, qbatch, env);
SampleAndStream(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: generate ]\n");
}
}

// Same as GenerateT, but uses ContinuousQBatch.
Expand Down Expand Up @@ -733,6 +749,10 @@ void GenerateImageTokensT(const ModelConfig& config,
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;

if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: image_token_gen ]\n");
}
timing_info.NotifyImageTokenStart();

{
Expand All @@ -755,6 +775,10 @@ void GenerateImageTokensT(const ModelConfig& config,
env.ctx.profiler.PrintResults();

timing_info.NotifyImageTokenDone(num_tokens);
if (timing_info.verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: image_token_gen ]\n");
}
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand Down
10 changes: 10 additions & 0 deletions gemma/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ void Run(const GemmaArgs& args) {

int main(int argc, char** argv) {
gcpp::InternalInit();
int verbosity = 0;
{
// Negligible CPU time.
gcpp::ConsumedArgs consumed(argc, argv);
Expand All @@ -313,8 +314,17 @@ int main(int argc, char** argv) {
// After `HasHelp` so that we print --help even if unconsumed args remain.
consumed.AbortIfUnconsumed();

verbosity = args.inference.verbosity;
gcpp::Run(args);
}
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ BEGIN PHASE: final_stats ]\n");
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
if (verbosity >= 2) {
fflush(stdout);
fprintf(stderr, "\n[ END PHASE: final_stats ]\n");
}
return 0;
}
Loading