From 78b958f9ae9f5828f62252e4b1bfcc006ec0f158 Mon Sep 17 00:00:00 2001 From: Kynan Ware <47394200+BagToad@users.noreply.github.com> Date: Wed, 18 Mar 2026 12:14:02 -0600 Subject: [PATCH] fix(agent-task): resolve Copilot API URL dynamically (#12956) * fix(agent-task): resolve Copilot API URL dynamically Query viewer.copilotEndpoints.api to get the correct Copilot API URL for the user's host instead of hardcoding api.githubcopilot.com. This fixes 401 errors for ghe.com tenancy users whose Copilot API lives at a different endpoint. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pkg/cmd/agent-task/capi/client.go | 43 +++++++----- pkg/cmd/agent-task/capi/job.go | 8 ++- pkg/cmd/agent-task/capi/job_test.go | 4 +- pkg/cmd/agent-task/capi/sessions.go | 13 ++-- pkg/cmd/agent-task/capi/sessions_test.go | 8 +-- pkg/cmd/agent-task/shared/capi.go | 35 +++++++++- pkg/cmd/agent-task/shared/capi_test.go | 88 ++++++++++++++++++++++++ 7 files changed, 166 insertions(+), 33 deletions(-) diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index af618f4cee1..2f6c649a12f 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -3,13 +3,11 @@ package capi import ( "context" "net/http" + "net/url" ) //go:generate moq -rm -out client_mock.go . CapiClient -const baseCAPIURL = "https://api.githubcopilot.com" -const capiHost = "api.githubcopilot.com" - // CapiClient defines the methods used by the caller. Implementations // may be replaced with test doubles in unit tests. type CapiClient interface { @@ -24,33 +22,42 @@ type CapiClient interface { // CAPIClient is a client for interacting with the Copilot API type CAPIClient struct { - httpClient *http.Client - host string + httpClient *http.Client + host string + capiBaseURL string } -// NewCAPIClient creates a new CAPI client. Provide a token, host, and an HTTP client which -// will be used as the base transport for CAPI requests. +// NewCAPIClient creates a new CAPI client. Provide a token, the user's GitHub +// host, the resolved Copilot API URL, and an HTTP client which will be used as +// the base transport for CAPI requests. // // The provided HTTP client will be mutated for use with CAPI, so it should not // be reused elsewhere. -func NewCAPIClient(httpClient *http.Client, token string, host string) *CAPIClient { - httpClient.Transport = newCAPITransport(token, httpClient.Transport) +func NewCAPIClient(httpClient *http.Client, token string, host string, capiBaseURL string) *CAPIClient { + httpClient.Transport = newCAPITransport(token, capiBaseURL, httpClient.Transport) return &CAPIClient{ - httpClient: httpClient, - host: host, + httpClient: httpClient, + host: host, + capiBaseURL: capiBaseURL, } } // capiTransport adds the Copilot auth headers type capiTransport struct { - rp http.RoundTripper - token string + rp http.RoundTripper + token string + capiHost string } -func newCAPITransport(token string, rp http.RoundTripper) *capiTransport { +func newCAPITransport(token string, capiBaseURL string, rp http.RoundTripper) *capiTransport { + capiHost := "" + if u, err := url.Parse(capiBaseURL); err == nil { + capiHost = u.Host + } return &capiTransport{ - rp: rp, - token: token, + rp: rp, + token: token, + capiHost: capiHost, } } @@ -60,10 +67,10 @@ func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Since this RoundTrip is reused for both Copilot API and // GitHub API requests, we conditionally add the integration // ID only when performing requests to the Copilot API. - if req.URL.Host == capiHost { + if req.URL.Host == ct.capiHost { req.Header.Add("Copilot-Integration-Id", "copilot-4-cli") - // This is quick fix to ensure that we are not using GitHub API versions while targeting CAPI. + // Ensure we are not using GitHub API versions while targeting CAPI. req.Header.Set("X-GitHub-Api-Version", "2026-01-09") } return ct.rp.RoundTrip(req) diff --git a/pkg/cmd/agent-task/capi/job.go b/pkg/cmd/agent-task/capi/job.go index 2d5c2d26467..2e37d4f5ee2 100644 --- a/pkg/cmd/agent-task/capi/job.go +++ b/pkg/cmd/agent-task/capi/job.go @@ -51,7 +51,9 @@ type JobError struct { Service string `json:"service"` } -const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs" +func (c *CAPIClient) jobsBasePathV1() string { + return c.capiBaseURL + "/agents/swe/v1/jobs" +} // CreateJob queues a new job using the v1 Jobs API. It may or may not // return Pull Request information. If Pull Request information is required @@ -64,7 +66,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen return nil, errors.New("problem statement is required") } - url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo)) + url := fmt.Sprintf("%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo)) prOpts := JobPullRequest{} if baseBranch != "" { @@ -130,7 +132,7 @@ func (c *CAPIClient) GetJob(ctx context.Context, owner, repo, jobID string) (*Jo if owner == "" || repo == "" || jobID == "" { return nil, errors.New("owner, repo, and jobID are required") } - url := fmt.Sprintf("%s/%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID)) + url := fmt.Sprintf("%s/%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, err diff --git a/pkg/cmd/agent-task/capi/job_test.go b/pkg/cmd/agent-task/capi/job_test.go index 58b9805b713..b80e8e6a902 100644 --- a/pkg/cmd/agent-task/capi/job_test.go +++ b/pkg/cmd/agent-task/capi/job_test.go @@ -167,7 +167,7 @@ func TestGetJob(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") job, err := capiClient.GetJob(context.Background(), "OWNER", "REPO", "job123") @@ -410,7 +410,7 @@ func TestCreateJob(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent) diff --git a/pkg/cmd/agent-task/capi/sessions.go b/pkg/cmd/agent-task/capi/sessions.go index 4b457d799bb..69ea80820c5 100644 --- a/pkg/cmd/agent-task/capi/sessions.go +++ b/pkg/cmd/agent-task/capi/sessions.go @@ -217,13 +217,16 @@ func (c *CAPIClient) ListLatestSessionsForViewer(ctx context.Context, limit int) return nil, nil } - url := baseCAPIURL + "/agents/sessions" + sessionsURL, err := url.JoinPath(c.capiBaseURL, "agents", "sessions") + if err != nil { + return nil, fmt.Errorf("failed to build sessions URL: %w", err) + } pageSize := defaultSessionsPerPage seenResources := make(map[int64]struct{}) latestSessions := make([]session, 0, limit) for page := 1; ; page++ { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionsURL, http.NoBody) if err != nil { return nil, err } @@ -296,7 +299,7 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error return nil, fmt.Errorf("missing session ID") } - url := fmt.Sprintf("%s/agents/sessions/%s", baseCAPIURL, url.PathEscape(id)) + url := fmt.Sprintf("%s/agents/sessions/%s", c.capiBaseURL, url.PathEscape(id)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -335,7 +338,7 @@ func (c *CAPIClient) GetSessionLogs(ctx context.Context, id string) ([]byte, err return nil, fmt.Errorf("missing session ID") } - url := fmt.Sprintf("%s/agents/sessions/%s/logs", baseCAPIURL, url.PathEscape(id)) + url := fmt.Sprintf("%s/agents/sessions/%s/logs", c.capiBaseURL, url.PathEscape(id)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -368,7 +371,7 @@ func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType return nil, nil } - url := fmt.Sprintf("%s/agents/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID) + url := fmt.Sprintf("%s/agents/resource/%s/%d", c.capiBaseURL, url.PathEscape(resourceType), resourceID) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { diff --git a/pkg/cmd/agent-task/capi/sessions_test.go b/pkg/cmd/agent-task/capi/sessions_test.go index 223f04320f6..fd7614a38c6 100644 --- a/pkg/cmd/agent-task/capi/sessions_test.go +++ b/pkg/cmd/agent-task/capi/sessions_test.go @@ -1161,7 +1161,7 @@ func TestListLatestSessionsForViewer(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") if tt.perPage != 0 { last := defaultSessionsPerPage @@ -1540,7 +1540,7 @@ func TestListSessionsByResourceID(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") if tt.perPage != 0 { last := defaultSessionsPerPage @@ -1819,7 +1819,7 @@ func TestGetSession(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") session, err := capiClient.GetSession(context.Background(), "some-uuid") @@ -1895,7 +1895,7 @@ func TestGetPullRequestDatabaseID(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") databaseID, url, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42) diff --git a/pkg/cmd/agent-task/shared/capi.go b/pkg/cmd/agent-task/shared/capi.go index 65794581777..9d43fd3cce6 100644 --- a/pkg/cmd/agent-task/shared/capi.go +++ b/pkg/cmd/agent-task/shared/capi.go @@ -3,8 +3,11 @@ package shared import ( "errors" "fmt" + "net/http" "regexp" + "time" + "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -30,8 +33,38 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { authCfg := cfg.Authentication() host, _ := authCfg.DefaultHost() token, _ := authCfg.ActiveToken(host) - return capi.NewCAPIClient(httpClient, token, host), nil + + cachedClient := api.NewCachedHTTPClient(httpClient, time.Minute*10) + capiBaseURL, err := resolveCapiURL(cachedClient, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve Copilot API URL: %w", err) + } + + return capi.NewCAPIClient(httpClient, token, host, capiBaseURL), nil + } +} + +// resolveCapiURL queries the GitHub API for the Copilot API endpoint URL. +func resolveCapiURL(httpClient *http.Client, host string) (string, error) { + apiClient := api.NewClientFromHTTP(httpClient) + + var resp struct { + Viewer struct { + CopilotEndpoints struct { + Api string `graphql:"api"` + } `graphql:"copilotEndpoints"` + } `graphql:"viewer"` + } + + if err := apiClient.Query(host, "CopilotEndpoints", &resp, nil); err != nil { + return "", err } + + if resp.Viewer.CopilotEndpoints.Api == "" { + return "", errors.New("empty Copilot API URL returned") + } + + return resp.Viewer.CopilotEndpoints.Api, nil } func IsSessionID(s string) bool { diff --git a/pkg/cmd/agent-task/shared/capi_test.go b/pkg/cmd/agent-task/shared/capi_test.go index 205d881c8d0..3699d25c752 100644 --- a/pkg/cmd/agent-task/shared/capi_test.go +++ b/pkg/cmd/agent-task/shared/capi_test.go @@ -1,12 +1,100 @@ package shared import ( + "net/http" "testing" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" + ghmock "github.com/cli/cli/v2/internal/gh/mock" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestResolveCapiURL(t *testing.T) { + tests := []struct { + name string + resp string + wantURL string + wantErr bool + }{ + { + name: "returns resolved URL", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`, + wantURL: "https://test-copilot-api.example.com", + }, + { + name: "ghe.com tenant URL", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`, + wantURL: "https://test-copilot-api.tenant.example.com", + }, + { + name: "empty URL returns error", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`query CopilotEndpoints\b`), + httpmock.StringResponse(tt.resp), + ) + + httpClient := &http.Client{Transport: reg} + url, err := resolveCapiURL(httpClient, "github.com") + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantURL, url) + }) + } +} + +func TestCapiClientFuncResolvesURL(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`query CopilotEndpoints\b`), + httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`), + ) + + f := &cmdutil.Factory{ + Config: func() (gh.Config, error) { + return &ghmock.ConfigMock{ + AuthenticationFunc: func() gh.AuthConfig { + c := &config.AuthConfig{} + c.SetDefaultHost("github.com", "hosts") + c.SetActiveToken("gho_TOKEN", "oauth_token") + return c + }, + }, nil + }, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + clientFunc := CapiClientFunc(f) + client, err := clientFunc() + require.NoError(t, err) + require.NotNil(t, client) + + // Verify the GraphQL resolution was called + require.Len(t, reg.Requests, 1) +} + func TestIsSession(t *testing.T) { assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000")) assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))