Skip to content

feature(providers): Added Vertex support#978

Open
machour wants to merge 5 commits intoprism-php:mainfrom
machour:vertex-support
Open

feature(providers): Added Vertex support#978
machour wants to merge 5 commits intoprism-php:mainfrom
machour:vertex-support

Conversation

@machour
Copy link

@machour machour commented Mar 24, 2026

Description

Added support for Vertex

Breaking Changes

None

Disclaimer: This is all Claude Opus 4.5 thinking work, with a small adjustment by copilot to support the "global" region special host

@machour machour mentioned this pull request Mar 24, 2026
@akalongman
Copy link

akalongman commented Mar 25, 2026

@machour this is my implementation what I use with prism in my project, if it can help:

<?php

declare(strict_types=1);

namespace App\Libraries\Ai\Prism;

use Exception;
use Google\ApiCore\ApiException;
use Google\Cloud\AIPlatform\V1\Candidate;
use Google\Cloud\AIPlatform\V1\Candidate\FinishReason;
use Google\Cloud\AIPlatform\V1\Client\PredictionServiceClient;
use Google\Cloud\AIPlatform\V1\Content;
use Google\Cloud\AIPlatform\V1\FunctionCall;
use Google\Cloud\AIPlatform\V1\FunctionDeclaration;
use Google\Cloud\AIPlatform\V1\FunctionResponse;
use Google\Cloud\AIPlatform\V1\GenerateContentRequest;
use Google\Cloud\AIPlatform\V1\GenerateContentResponse;
use Google\Cloud\AIPlatform\V1\GenerateContentResponse\PromptFeedback\BlockedReason;
use Google\Cloud\AIPlatform\V1\GenerationConfig;
use Google\Cloud\AIPlatform\V1\HarmCategory;
use Google\Cloud\AIPlatform\V1\Part;
use Google\Cloud\AIPlatform\V1\Retrieval;
use Google\Cloud\AIPlatform\V1\SafetyRating\HarmProbability;
use Google\Cloud\AIPlatform\V1\Schema;
use Google\Cloud\AIPlatform\V1\Tool as VertexTool;
use Google\Cloud\AIPlatform\V1\Tool\GoogleSearch;
use Google\Cloud\AIPlatform\V1\Type;
use Google\Cloud\AIPlatform\V1\VertexAISearch;
use Google\Protobuf\Struct;
use Prism\Prism\Concerns\CallsTools;
use Prism\Prism\Contracts\Schema as PrismSchema;
use Prism\Prism\Enums\FinishReason as PrismFinishReason;
use Prism\Prism\Providers\Provider;
use Prism\Prism\Text\Request as TextRequest;
use Prism\Prism\Text\Response as TextResponse;
use Prism\Prism\Text\ResponseBuilder;
use Prism\Prism\Text\Step;
use Prism\Prism\Tool;
use Prism\Prism\ValueObjects\Messages\AssistantMessage;
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
use Prism\Prism\ValueObjects\Meta;
use Prism\Prism\ValueObjects\ProviderTool;
use Prism\Prism\ValueObjects\ToolCall;
use Prism\Prism\ValueObjects\Usage;
use Throwable;

use function array_map;
use function count;
use function file_exists;
use function implode;
use function is_array;
use function json_decode;
use function json_encode;
use function sprintf;

class VertexAiProvider extends Provider
{
    use CallsTools;

    private PredictionServiceClient $client;
    private ResponseBuilder $responseBuilder;

    public function __construct(
        private readonly string $credentialsFile,
        private readonly string $projectId,
        private readonly string $location,
    ) {
        $clientOptions = [];

        if (file_exists($this->credentialsFile)) {
            $clientOptions = ['credentials' => $this->credentialsFile];
        }

        try {
            $this->client = new PredictionServiceClient($clientOptions);
        } catch (Throwable $e) {
            throw new Exception('Failed to initialize Vertex AI client: ' . $e->getMessage(), 0, $e);
        }
    }

    public function text(TextRequest $request): TextResponse
    {
        $this->responseBuilder = new ResponseBuilder();

        return $this->handleRequest($request);
    }

    private function handleRequest(TextRequest $request): TextResponse
    {
        $response = $this->sendRequest($request);

        $isToolCall = $this->hasToolCalls($response);
        $finishReason = $this->resolveFinishReason($response, $isToolCall);

        if ($finishReason === PrismFinishReason::ToolCalls) {
            return $this->handleToolUse($response, $request);
        }

        return $this->handleStop($response, $request, $finishReason);
    }

    private function sendRequest(TextRequest $request): GenerateContentResponse
    {
        $modelEndpoint = PredictionServiceClient::projectLocationPublisherModelName(
            $this->projectId,
            $this->location,
            'google',
            $request->model(),
        );

        $contents = $this->buildContents($request);

        $generateRequest = new GenerateContentRequest();
        $generateRequest->setModel($modelEndpoint);
        $generateRequest->setContents($contents);

        $systemInstruction = $this->buildSystemInstruction($request);
        if ($systemInstruction !== null) {
            $generateRequest->setSystemInstruction($systemInstruction);
        }

        $generationConfig = $this->buildGenerationConfig($request);
        if ($generationConfig !== null) {
            $generateRequest->setGenerationConfig($generationConfig);
        }

        $tools = $this->buildTools($request);
        if (! empty($tools)) {
            $generateRequest->setTools($tools);
        }

        $response = null;

        try {
            $response = $this->client->generateContent($generateRequest, [
                'timeoutMillis' => 120_000,
            ]);

            return $response;
        } catch (ApiException $e) {
            throw new VertexAiException(
                message: sprintf(
                    'Vertex AI request failed: %s (Status: %s, Code: %s)',
                    $e->getMessage(),
                    $e->getStatus(),
                    $e->getCode(),
                ),
                code: $e->getCode(),
                previous: $e,
            );
        } catch (Throwable $e) {
            if ($response !== null) {
                throw VertexAiException::fromResponse(
                    message: $e->getMessage(),
                    response: $response,
                    previous: $e,
                );
            }

            throw new VertexAiException(
                message: $e->getMessage(),
                previous: $e,
            );
        }
    }

    private function handleStop(GenerateContentResponse $response, TextRequest $request, PrismFinishReason $finishReason): TextResponse
    {
        $this->addStep($response, $request, $finishReason);

        return $this->responseBuilder->toResponse();
    }

    private function handleToolUse(GenerateContentResponse $response, TextRequest $request): TextResponse
    {
        $toolCalls = $this->extractToolCalls($response);
        $toolResults = $this->callTools($request->tools(), $toolCalls);
        $text = $this->extractText($response);

        $this->addStep($response, $request, PrismFinishReason::ToolCalls, $toolCalls, $toolResults);

        $request->addMessage(new AssistantMessage($text, $toolCalls));
        $request->addMessage(new ToolResultMessage($toolResults));
        $request->resetToolChoice();

        if ($this->responseBuilder->steps->count() < $request->maxSteps()) {
            return $this->handleRequest($request);
        }

        return $this->responseBuilder->toResponse();
    }

    /**
     * @param array<int, \Prism\Prism\ValueObjects\ToolCall> $toolCalls
     * @param array<int, \Prism\Prism\ValueObjects\ToolResult> $toolResults
     */
    private function addStep(
        GenerateContentResponse $response,
        TextRequest $request,
        PrismFinishReason $finishReason,
        array $toolCalls = [],
        array $toolResults = [],
    ): void {
        $text = $this->extractText($response);
        $additionalContent = $this->extractAdditionalContent($response);

        $usageMetadata = $response->getUsageMetadata();

        $this->responseBuilder->addStep(
            new Step(
                text: $text,
                finishReason: $finishReason,
                toolCalls: $toolCalls,
                toolResults: $toolResults,
                providerToolCalls: [],
                usage: new Usage(
                    promptTokens: $usageMetadata?->getPromptTokenCount() ?: 0,
                    completionTokens: $usageMetadata?->getCandidatesTokenCount() ?: 0,
                    cacheReadInputTokens: $usageMetadata?->getCachedContentTokenCount() ?: null,
                ),
                meta: new Meta(
                    id: '',
                    model: $request->model(),
                ),
                messages: $request->messages(),
                systemPrompts: $request->systemPrompts(),
                additionalContent: $additionalContent,
            ),
        );
    }

    /**
     * @return array<int, \Google\Cloud\AIPlatform\V1\Content>
     */
    private function buildContents(TextRequest $request): array
    {
        $contents = [];

        foreach ($request->messages() as $message) {
            if ($message instanceof AssistantMessage) {
                $contents[] = $this->mapAssistantMessage($message);
            } elseif ($message instanceof ToolResultMessage) {
                $contents[] = $this->mapToolResultMessage($message);
            } else {
                $content = new Content();
                $content->setRole('user');
                $content->setParts([new Part(['text' => $message->content])]);
                $contents[] = $content;
            }
        }

        if ($request->prompt()) {
            $content = new Content();
            $content->setRole('user');
            $content->setParts([new Part(['text' => $request->prompt()])]);
            $contents[] = $content;
        }

        return $contents;
    }

    private function buildSystemInstruction(TextRequest $request): ?Content
    {
        $systemPrompts = $request->systemPrompts();
        if (empty($systemPrompts)) {
            return null;
        }

        $parts = [];
        foreach ($systemPrompts as $systemPrompt) {
            $parts[] = new Part(['text' => $systemPrompt->content]);
        }

        $content = new Content();
        $content->setParts($parts);

        return $content;
    }

    private function mapAssistantMessage(AssistantMessage $message): Content
    {
        $parts = [];

        if ($message->content !== '' && $message->content !== '0') {
            $parts[] = new Part(['text' => $message->content]);
        }

        foreach ($message->toolCalls as $toolCall) {
            $functionCall = new FunctionCall();
            $functionCall->setName($toolCall->name);

            $args = $toolCall->arguments();
            if (count($args) > 0) {
                $struct = new Struct();
                $struct->mergeFromJsonString(json_encode($args));
                $functionCall->setArgs($struct);
            }

            $part = new Part();
            $part->setFunctionCall($functionCall);
            $parts[] = $part;
        }

        $content = new Content();
        $content->setRole('model');
        $content->setParts($parts);

        return $content;
    }

    private function mapToolResultMessage(ToolResultMessage $message): Content
    {
        $parts = [];

        foreach ($message->toolResults as $toolResult) {
            $functionResponse = new FunctionResponse();
            $functionResponse->setName($toolResult->toolName);

            $responseStruct = new Struct();
            $responseData = is_array($toolResult->result)
                ? $toolResult->result
                : ['output' => $toolResult->result];
            $responseStruct->mergeFromJsonString(json_encode($responseData));
            $functionResponse->setResponse($responseStruct);

            $part = new Part();
            $part->setFunctionResponse($functionResponse);
            $parts[] = $part;
        }

        $content = new Content();
        $content->setRole('user');
        $content->setParts($parts);

        return $content;
    }

    private function buildGenerationConfig(TextRequest $request): ?GenerationConfig
    {
        $config = new GenerationConfig();
        $hasValues = false;

        if ($request->temperature() !== null) {
            $config->setTemperature((float) $request->temperature());
            $hasValues = true;
        }

        if ($request->maxTokens() !== null) {
            $config->setMaxOutputTokens($request->maxTokens());
            $hasValues = true;
        }

        if ($request->topP() !== null) {
            $config->setTopP((float) $request->topP());
            $hasValues = true;
        }

        $providerOptions = $request->providerOptions() ?? [];

        if (isset($providerOptions['top_k'])) {
            $config->setTopK((float) $providerOptions['top_k']);
            $hasValues = true;
        }

        if (isset($providerOptions['seed'])) {
            $config->setSeed((int) $providerOptions['seed']);
            $hasValues = true;
        }

        if (isset($providerOptions['response_mime_type'])) {
            $config->setResponseMimeType($providerOptions['response_mime_type']);
            $hasValues = true;
        }

        if (! $hasValues) {
            return null;
        }

        return $config;
    }

    /**
     * @return array<int, \Google\Cloud\AIPlatform\V1\Tool>
     */
    private function buildTools(TextRequest $request): array
    {
        $tools = [];

        foreach ($request->providerTools() as $providerTool) {
            $tool = $this->mapProviderTool($providerTool);

            if ($tool !== null) {
                $tools[] = $tool;
            }
        }

        $prismTools = $request->tools();
        if (! empty($prismTools)) {
            $tool = new VertexTool();
            $tool->setFunctionDeclarations(
                array_map(static fn(Tool $t) => self::mapToolToFunctionDeclaration($t), $prismTools),
            );
            $tools[] = $tool;
        }

        return $tools;
    }

    private static function mapToolToFunctionDeclaration(Tool $tool): FunctionDeclaration
    {
        $declaration = new FunctionDeclaration();
        $declaration->setName($tool->name());
        $declaration->setDescription($tool->description());

        if ($tool->hasParameters()) {
            $schema = new Schema();
            $schema->setType(Type::OBJECT);

            $properties = [];
            foreach ($tool->parameters() as $name => $param) {
                $properties[$name] = self::mapSchemaToProtobuf($param);
            }
            $schema->setProperties($properties);

            $required = $tool->requiredParameters();
            if (! empty($required)) {
                $schema->setRequired($required);
            }

            $declaration->setParameters($schema);
        }

        return $declaration;
    }

    private static function mapSchemaToProtobuf(PrismSchema $schema): Schema
    {
        $protoSchema = new Schema();

        $schemaArray = $schema->toArray();

        $typeMap = [
            'string'  => Type::STRING,
            'number'  => Type::NUMBER,
            'integer' => Type::INTEGER,
            'boolean' => Type::BOOLEAN,
            'array'   => Type::ARRAY,
            'object'  => Type::OBJECT,
        ];

        $type = $schemaArray['type'] ?? 'string';
        $protoSchema->setType($typeMap[$type] ?? Type::STRING);

        if (isset($schemaArray['description'])) {
            $protoSchema->setDescription($schemaArray['description']);
        }

        if (isset($schemaArray['enum'])) {
            $protoSchema->setEnum($schemaArray['enum']);
        }

        if (isset($schemaArray['items']) && $type === 'array') {
            $itemSchema = $schemaArray['items'];
            $itemProto = new Schema();
            $itemProto->setType($typeMap[$itemSchema['type'] ?? 'string'] ?? Type::STRING);
            if (isset($itemSchema['description'])) {
                $itemProto->setDescription($itemSchema['description']);
            }
            $protoSchema->setItems($itemProto);
        }

        if (isset($schemaArray['properties']) && $type === 'object') {
            $properties = [];
            foreach ($schemaArray['properties'] as $propName => $propSchema) {
                $propProto = new Schema();
                $propProto->setType($typeMap[$propSchema['type'] ?? 'string'] ?? Type::STRING);
                if (isset($propSchema['description'])) {
                    $propProto->setDescription($propSchema['description']);
                }
                $properties[$propName] = $propProto;
            }
            $protoSchema->setProperties($properties);
        }

        return $protoSchema;
    }

    private function mapProviderTool(ProviderTool $providerTool): ?VertexTool
    {
        return match ($providerTool->type) {
            'retrieval'     => $this->createRetrievalTool($providerTool->options),
            'google_search' => $this->createGoogleSearchTool(),
            default         => null,
        };
    }

    private function createRetrievalTool(array $options): VertexTool
    {
        $datastoreResourceName = sprintf(
            'projects/%s/locations/%s/collections/%s/dataStores/%s',
            $this->projectId,
            $options['location'] ?? 'global',
            $options['collection_id'] ?? 'default_collection',
            $options['datastore_id'],
        );

        $vertexAiSearch = new VertexAISearch();
        $vertexAiSearch->setDatastore($datastoreResourceName);

        $retrieval = new Retrieval();
        $retrieval->setVertexAiSearch($vertexAiSearch);

        $tool = new VertexTool();
        $tool->setRetrieval($retrieval);

        return $tool;
    }

    private function createGoogleSearchTool(): VertexTool
    {
        $tool = new VertexTool();
        $tool->setGoogleSearch(new GoogleSearch());

        return $tool;
    }

    private function hasToolCalls(GenerateContentResponse $response): bool
    {
        $candidates = $response->getCandidates();
        if (empty($candidates)) {
            return false;
        }

        $content = $candidates[0]->getContent();
        if (! $content) {
            return false;
        }

        foreach ($content->getParts() as $part) {
            if ($part->hasFunctionCall()) {
                return true;
            }
        }

        return false;
    }

    /**
     * @return array<int, \Prism\Prism\ValueObjects\ToolCall>
     */
    private function extractToolCalls(GenerateContentResponse $response): array
    {
        $toolCalls = [];
        $candidates = $response->getCandidates();
        if (empty($candidates)) {
            return $toolCalls;
        }

        $content = $candidates[0]->getContent();
        if (! $content) {
            return $toolCalls;
        }

        foreach ($content->getParts() as $part) {
            if (! $part->hasFunctionCall()) {
                continue;
            }

            $functionCall = $part->getFunctionCall();
            $name = $functionCall->getName();
            $args = [];

            $argsStruct = $functionCall->getArgs();
            if ($argsStruct !== null) {
                $argsJson = $argsStruct->serializeToJsonString();
                $args = json_decode($argsJson, true) ?: [];
            }

            $toolCalls[] = new ToolCall(
                id: $name,
                name: $name,
                arguments: $args,
            );
        }

        return $toolCalls;
    }

    private function extractText(GenerateContentResponse $response): string
    {
        $candidates = $response->getCandidates();
        if (empty($candidates)) {
            return '';
        }

        $content = $candidates[0]->getContent();
        if (! $content) {
            return '';
        }

        $textParts = [];
        foreach ($content->getParts() as $part) {
            if ($part->hasText()) {
                $textParts[] = $part->getText();
            }
        }

        return implode('', $textParts);
    }

    private function resolveFinishReason(GenerateContentResponse $response, bool $isToolCall): PrismFinishReason
    {
        if ($isToolCall) {
            return PrismFinishReason::ToolCalls;
        }

        $candidates = $response->getCandidates();
        if (empty($candidates)) {
            return PrismFinishReason::Unknown;
        }

        return $this->mapFinishReason($candidates[0]->getFinishReason());
    }

    /**
     * @return array<string, mixed>
     */
    private function extractAdditionalContent(GenerateContentResponse $response): array
    {
        $blockedReason = null;
        $promptFeedback = $response->getPromptFeedback();
        if ($promptFeedback && $promptFeedback->getBlockReason() !== BlockedReason::BLOCKED_REASON_UNSPECIFIED) {
            $blockedReason = BlockedReason::name($promptFeedback->getBlockReason());
        }

        $safetyRatings = [];
        $groundingMetadata = [];

        $candidates = $response->getCandidates();
        if (! empty($candidates)) {
            $candidate = $candidates[0];

            $safetyRatings = $this->extractSafetyRatings($candidate);
            $groundingMetadata = $this->extractGroundingMetadata($candidate);
        }

        return [
            'safety_ratings'     => $safetyRatings,
            'grounding_metadata' => $groundingMetadata,
            'blocked_reason'     => $blockedReason,
        ];
    }

    /**
     * @return array<int, array<string, mixed>>
     */
    private function extractSafetyRatings(Candidate $candidate): array
    {
        $ratings = [];

        foreach ($candidate->getSafetyRatings() as $rating) {
            $ratings[] = [
                'category'    => HarmCategory::name($rating->getCategory()),
                'probability' => HarmProbability::name($rating->getProbability()),
                'blocked'     => $rating->getBlocked(),
            ];
        }

        return $ratings;
    }

    /**
     * @return array<string, mixed>
     */
    private function extractGroundingMetadata(Candidate $candidate): array
    {
        $grounding = $candidate->getGroundingMetadata();
        if (! $grounding) {
            return [];
        }

        $queries = [];
        foreach ($grounding->getWebSearchQueries() as $query) {
            $queries[] = $query;
        }

        $chunks = [];
        foreach ($grounding->getGroundingChunks() as $chunk) {
            $web = $chunk->getWeb();
            if ($web) {
                $chunks[] = [
                    'type'  => 'web',
                    'uri'   => $web->getUri(),
                    'title' => $web->getTitle(),
                ];
            }

            $retrievedContext = $chunk->getRetrievedContext();
            if ($retrievedContext) {
                $chunks[] = [
                    'type'  => 'retrieved_context',
                    'uri'   => $retrievedContext->getUri(),
                    'title' => $retrievedContext->getTitle(),
                ];
            }
        }

        return [
            'web_search_queries' => $queries,
            'grounding_chunks'   => $chunks,
        ];
    }

    private function mapFinishReason(int $reason): PrismFinishReason
    {
        return match ($reason) {
            FinishReason::STOP                    => PrismFinishReason::Stop,
            FinishReason::MAX_TOKENS              => PrismFinishReason::Length,
            FinishReason::SAFETY                  => PrismFinishReason::ContentFilter,
            FinishReason::MALFORMED_FUNCTION_CALL => PrismFinishReason::Error,
            default                               => PrismFinishReason::Unknown,
        };
    }
}

@machour
Copy link
Author

machour commented Mar 25, 2026

@akalongman this PR content have been used in production for a few weeks on my side, no issues so far

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants