Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/adapters/chrome/background/PredictionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ export interface PredictorDebugSnapshot {
hasWebGPU: boolean;
initAttemptCount: number;
isGenerating: boolean;
cacheSize: number;
lastFailureAt: number | null;
lastInitStartedAt: number | null;
lastInitDurationMs: number | null;
Expand Down Expand Up @@ -205,7 +204,6 @@ export class PredictionManager {
logger.info("Clearing predictor debug traces");
this.debugTraces = [];
this.debugTraceById.clear();
this.getWebLLMPredictor().clearCache();
}

getPredictorDebugSnapshot(): PredictorDebugSnapshot {
Expand Down Expand Up @@ -244,7 +242,6 @@ export class PredictionManager {
hasWebGPU: webllmDebugState.hasWebGPU,
initAttemptCount: webllmDebugState.initAttemptCount,
isGenerating: webllmDebugState.isGenerating,
cacheSize: webllmDebugState.cacheSize,
lastFailureAt: webllmDebugState.lastFailureAt,
lastInitStartedAt: webllmDebugState.lastInitStartedAt,
lastInitDurationMs: webllmDebugState.lastInitDurationMs,
Expand Down
35 changes: 3 additions & 32 deletions src/adapters/chrome/background/PresageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ const SUGGESTION_COUNT = 5;
const MIN_WORD_LENGTH_TO_PREDICT = 1;
const logger = createLogger("PresageHandler");

interface LastPrediction {
pastStream: string;
templates: string[];
}

export interface PresageConfig {
numSuggestions: number;
engineNumSuggestions?: number;
Expand Down Expand Up @@ -50,7 +45,7 @@ export interface PresagePredictionContext {

export class PresageHandler {
private presageEngines: Record<string, PresageEngine>;
private lastPrediction: Record<string, LastPrediction>;
private lastPredictionInputByLang: Record<string, string> = {};
private numSuggestions: number;
private minWordLengthToPredict: number;
private predictNextWordAfterSeparatorChar: boolean;
Expand All @@ -74,7 +69,6 @@ export class PresageHandler {
prefixOnlyMode: false,
};
this.presageEngines = {};
this.lastPrediction = {};
this.numSuggestions = SUGGESTION_COUNT;
this.engineNumSuggestions = MAX_NUM_SUGGESTIONS;
this.minWordLengthToPredict = MIN_WORD_LENGTH_TO_PREDICT;
Expand All @@ -93,7 +87,6 @@ export class PresageHandler {
continue;
}
try {
this.lastPrediction[lang] = { pastStream: "", templates: [] };
this.presageEngines[lang] = new PresageEngine(Module, engineConfig, lang);
} catch (error) {
logger.warn("Failed to create Presage engine instance", {
Expand Down Expand Up @@ -128,8 +121,6 @@ export class PresageHandler {
this.dateFormat = config.dateFormat;
this.userDictionaryList = config.userDictionaryList || [];

this.resetLastPredictionState();

if (shouldRefreshEngines) {
this.refreshPresageEngines();
this.textExpansionsSignature = textExpansionsSignature;
Expand Down Expand Up @@ -197,19 +188,8 @@ export class PresageHandler {
this.dateFormat ?? "",
tabId,
);
const cachedPrediction = this.lastPrediction[lang];
if (cachedPrediction?.pastStream === predictionInput) {
return Promise.all(
cachedPrediction.templates.map((text) =>
TemplateExpander.parseStringTemplateAsync(text, resolver),
),
);
}
const predictions = this.presageEngines[lang].predict(predictionInput);
this.lastPrediction[lang] = {
pastStream: predictionInput,
templates: predictions.slice(),
};
this.lastPredictionInputByLang[lang] = predictionInput;
return Promise.all(
predictions.map((text) => TemplateExpander.parseStringTemplateAsync(text, resolver)),
);
Expand Down Expand Up @@ -346,21 +326,12 @@ export class PresageHandler {
}

getLastPredictionInput(lang: string): string {
if (lang in this.lastPrediction) {
return this.lastPrediction[lang].pastStream;
}
return "";
return this.lastPredictionInputByLang[lang] ?? "";
}

private refreshPresageEngines(): void {
for (const presageEngine of Object.values(this.presageEngines)) {
presageEngine.reinitialize();
}
}

private resetLastPredictionState(): void {
for (const lang of Object.keys(this.presageEngines)) {
this.lastPrediction[lang] = { pastStream: "", templates: [] };
}
}
}
18 changes: 0 additions & 18 deletions src/adapters/chrome/background/WebLLMPredictor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import type {
import { CandidateRanker } from "./webllm/CandidateRanker";
import { EngineLifecycleService } from "./webllm/EngineLifecycleService";
import { GenerationCoordinator } from "./webllm/GenerationCoordinator";
import { PredictionCache } from "./webllm/PredictionCache";
import { PromptBuilder } from "./webllm/PromptBuilder";
import { ResponseParser } from "./webllm/ResponseParser";
import { maybePredictFromRuntimeTestOverride } from "@adapters/chrome/background/testing/RuntimeTestHooks";
Expand All @@ -21,7 +20,6 @@ import type {
PredictionResponsePayload,
} from "./webllm/types";

const CACHE_TTL_MS = 5000;
const MAX_GENERATION_CHOICES = 5;
const logger = createLogger("WebLLMPredictor");

Expand All @@ -36,7 +34,6 @@ export interface WebLLMPredictorDebugState {
hasWebGPU: boolean;
initAttemptCount: number;
isGenerating: boolean;
cacheSize: number;
lastFailureAt: number | null;
lastInitStartedAt: number | null;
lastInitDurationMs: number | null;
Expand All @@ -60,7 +57,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
private readonly promptBuilder = new PromptBuilder(MAX_GENERATION_CHOICES);
private readonly responseParser = new ResponseParser();
private readonly candidateRanker = new CandidateRanker();
private readonly predictionCache = new PredictionCache(CACHE_TTL_MS);

private enabled = DEFAULT_AI_PREDICTOR_ENABLED;
private modelId = DEFAULT_AI_MODEL_ID;
Expand Down Expand Up @@ -99,7 +95,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
hasWebGPU: lifecycleState.hasWebGPU,
initAttemptCount: lifecycleState.initAttemptCount,
isGenerating: this.generationCoordinator.getIsGenerating(),
cacheSize: this.predictionCache.size(),
lastFailureAt: lifecycleState.lastFailureAt > 0 ? lifecycleState.lastFailureAt : null,
lastInitStartedAt:
lifecycleState.lastInitStartedAt > 0 ? lifecycleState.lastInitStartedAt : null,
Expand All @@ -122,10 +117,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
};
}

clearCache(): void {
this.predictionCache.clear();
}

preload(): void {
void this.ensureReady();
}
Expand Down Expand Up @@ -171,11 +162,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
}
return testOverridePredictions;
}
const cacheKey = this.predictionCache.getCacheKey(this.modelId, request);
const cachedPredictions = this.predictionCache.get(cacheKey);
if (cachedPredictions) {
return cachedPredictions;
}
const ready = await this.ensureReady();
if (!ready || !this.engineLifecycleService.getEngine() || this.isRequestStale(requestSeq)) {
return [];
Expand Down Expand Up @@ -252,9 +238,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
request.numSuggestions,
);

if (predictions.length > 0) {
this.predictionCache.set(cacheKey, predictions);
}
this.lastPredictDurationMs = Date.now() - predictStartedAt;
this.lastPredictSource = source;
this.lastRawOutputPreview = rawOutput.slice(0, 400);
Expand Down Expand Up @@ -314,7 +297,6 @@ export class WebLLMPredictor implements SecondaryPredictor {
}

private resetEngine(): void {
this.predictionCache.clear();
this.generationCoordinator.advanceGenerationSeq();
this.interruptActiveGeneration("reset");
this.generationCoordinator.clearGenerationTracking();
Expand Down
43 changes: 0 additions & 43 deletions src/adapters/chrome/background/webllm/PredictionCache.ts

This file was deleted.

5 changes: 0 additions & 5 deletions src/ui/options/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1398,11 +1398,6 @@ function renderPredictorDebugSnapshot(root: HTMLElement, snapshot: PredictorSnap
"WebLLM generating",
runtimeWebllm?.isGenerating ? "yes" : "no",
);
appendPredictorInfoItem(
runtimeCard,
"WebLLM cache entries",
formatMetricNumber(runtimeWebllm?.cacheSize),
);
const lastFailureAt = runtimeWebllm?.lastFailureAt;
appendPredictorInfoItem(
runtimeCard,
Expand Down
1 change: 0 additions & 1 deletion tests/ObservabilityService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ function createPredictorSnapshot() {
hasWebGPU: true,
initAttemptCount: 1,
isGenerating: false,
cacheSize: 0,
lastFailureAt: null,
lastInitStartedAt: null,
lastInitDurationMs: null,
Expand Down
4 changes: 2 additions & 2 deletions tests/PresageHandler.parallel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ describe("PredictionOrchestrator parallel merge", () => {
expect(result.predictions.length).toBeGreaterThan(0);
});

test("re-expands cached templates so random variables stay fresh", async () => {
test("re-expands random variables on every call", async () => {
const predictionsRef = { current: ["${random:alpha|beta}"] };
const { module, predictWithProbability } = createFakeModuleWithSpy(predictionsRef);
const presageHandler = new PresageHandler(module);
Expand All @@ -273,7 +273,7 @@ describe("PredictionOrchestrator parallel merge", () => {

expect(firstResult.predictions).toEqual(["alpha"]);
expect(secondResult.predictions).toEqual(["beta"]);
expect(predictWithProbability).toHaveBeenCalledTimes(1);
expect(predictWithProbability).toHaveBeenCalledTimes(2);
expect(randomSpy).toHaveBeenCalledTimes(2);
});
});
20 changes: 0 additions & 20 deletions tests/WebLLMPredictor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,26 +134,6 @@ describe("WebLLMPredictor", () => {
expect(engine.completions.create).toHaveBeenCalledTimes(1);
});

test("uses cache for identical requests", async () => {
const engine = createMockEngine();
createMLCEngineMock.mockResolvedValue(engine);
const { WebLLMPredictor } = await import("../src/adapters/chrome/background/WebLLMPredictor");
const predictor = new WebLLMPredictor();

const request = {
lang: "en_US",
predictionInput: "hello ",
numSuggestions: 3,
};

const first = await predictor.predict(request);
const second = await predictor.predict(request);

expect(first).toEqual(second);
expect(engine.chat.completions.create).toHaveBeenCalledTimes(1);
expect(engine.completions.create).toHaveBeenCalledTimes(0);
});

test("parses streamed chat chunks from async iterable responses", async () => {
const engine = createMockEngine({
chatCompletionImpl: async () =>
Expand Down
Loading