diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index af618f4cee1..2f6c649a12f 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -3,13 +3,11 @@ package capi import ( "context" "net/http" + "net/url" ) //go:generate moq -rm -out client_mock.go . CapiClient -const baseCAPIURL = "https://api.githubcopilot.com" -const capiHost = "api.githubcopilot.com" - // CapiClient defines the methods used by the caller. Implementations // may be replaced with test doubles in unit tests. type CapiClient interface { @@ -24,33 +22,42 @@ type CapiClient interface { // CAPIClient is a client for interacting with the Copilot API type CAPIClient struct { - httpClient *http.Client - host string + httpClient *http.Client + host string + capiBaseURL string } -// NewCAPIClient creates a new CAPI client. Provide a token, host, and an HTTP client which -// will be used as the base transport for CAPI requests. +// NewCAPIClient creates a new CAPI client. Provide a token, the user's GitHub +// host, the resolved Copilot API URL, and an HTTP client which will be used as +// the base transport for CAPI requests. // // The provided HTTP client will be mutated for use with CAPI, so it should not // be reused elsewhere. -func NewCAPIClient(httpClient *http.Client, token string, host string) *CAPIClient { - httpClient.Transport = newCAPITransport(token, httpClient.Transport) +func NewCAPIClient(httpClient *http.Client, token string, host string, capiBaseURL string) *CAPIClient { + httpClient.Transport = newCAPITransport(token, capiBaseURL, httpClient.Transport) return &CAPIClient{ - httpClient: httpClient, - host: host, + httpClient: httpClient, + host: host, + capiBaseURL: capiBaseURL, } } // capiTransport adds the Copilot auth headers type capiTransport struct { - rp http.RoundTripper - token string + rp http.RoundTripper + token string + capiHost string } -func newCAPITransport(token string, rp http.RoundTripper) *capiTransport { +func newCAPITransport(token string, capiBaseURL string, rp http.RoundTripper) *capiTransport { + capiHost := "" + if u, err := url.Parse(capiBaseURL); err == nil { + capiHost = u.Host + } return &capiTransport{ - rp: rp, - token: token, + rp: rp, + token: token, + capiHost: capiHost, } } @@ -60,10 +67,10 @@ func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) { // Since this RoundTrip is reused for both Copilot API and // GitHub API requests, we conditionally add the integration // ID only when performing requests to the Copilot API. - if req.URL.Host == capiHost { + if req.URL.Host == ct.capiHost { req.Header.Add("Copilot-Integration-Id", "copilot-4-cli") - // This is quick fix to ensure that we are not using GitHub API versions while targeting CAPI. + // Ensure we are not using GitHub API versions while targeting CAPI. req.Header.Set("X-GitHub-Api-Version", "2026-01-09") } return ct.rp.RoundTrip(req) diff --git a/pkg/cmd/agent-task/capi/job.go b/pkg/cmd/agent-task/capi/job.go index 2d5c2d26467..2e37d4f5ee2 100644 --- a/pkg/cmd/agent-task/capi/job.go +++ b/pkg/cmd/agent-task/capi/job.go @@ -51,7 +51,9 @@ type JobError struct { Service string `json:"service"` } -const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs" +func (c *CAPIClient) jobsBasePathV1() string { + return c.capiBaseURL + "/agents/swe/v1/jobs" +} // CreateJob queues a new job using the v1 Jobs API. It may or may not // return Pull Request information. If Pull Request information is required @@ -64,7 +66,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen return nil, errors.New("problem statement is required") } - url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo)) + url := fmt.Sprintf("%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo)) prOpts := JobPullRequest{} if baseBranch != "" { @@ -130,7 +132,7 @@ func (c *CAPIClient) GetJob(ctx context.Context, owner, repo, jobID string) (*Jo if owner == "" || repo == "" || jobID == "" { return nil, errors.New("owner, repo, and jobID are required") } - url := fmt.Sprintf("%s/%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID)) + url := fmt.Sprintf("%s/%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return nil, err diff --git a/pkg/cmd/agent-task/capi/job_test.go b/pkg/cmd/agent-task/capi/job_test.go index 58b9805b713..b80e8e6a902 100644 --- a/pkg/cmd/agent-task/capi/job_test.go +++ b/pkg/cmd/agent-task/capi/job_test.go @@ -167,7 +167,7 @@ func TestGetJob(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") job, err := capiClient.GetJob(context.Background(), "OWNER", "REPO", "job123") @@ -410,7 +410,7 @@ func TestCreateJob(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent) diff --git a/pkg/cmd/agent-task/capi/sessions.go b/pkg/cmd/agent-task/capi/sessions.go index 4b457d799bb..69ea80820c5 100644 --- a/pkg/cmd/agent-task/capi/sessions.go +++ b/pkg/cmd/agent-task/capi/sessions.go @@ -217,13 +217,16 @@ func (c *CAPIClient) ListLatestSessionsForViewer(ctx context.Context, limit int) return nil, nil } - url := baseCAPIURL + "/agents/sessions" + sessionsURL, err := url.JoinPath(c.capiBaseURL, "agents", "sessions") + if err != nil { + return nil, fmt.Errorf("failed to build sessions URL: %w", err) + } pageSize := defaultSessionsPerPage seenResources := make(map[int64]struct{}) latestSessions := make([]session, 0, limit) for page := 1; ; page++ { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionsURL, http.NoBody) if err != nil { return nil, err } @@ -296,7 +299,7 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error return nil, fmt.Errorf("missing session ID") } - url := fmt.Sprintf("%s/agents/sessions/%s", baseCAPIURL, url.PathEscape(id)) + url := fmt.Sprintf("%s/agents/sessions/%s", c.capiBaseURL, url.PathEscape(id)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -335,7 +338,7 @@ func (c *CAPIClient) GetSessionLogs(ctx context.Context, id string) ([]byte, err return nil, fmt.Errorf("missing session ID") } - url := fmt.Sprintf("%s/agents/sessions/%s/logs", baseCAPIURL, url.PathEscape(id)) + url := fmt.Sprintf("%s/agents/sessions/%s/logs", c.capiBaseURL, url.PathEscape(id)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { @@ -368,7 +371,7 @@ func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType return nil, nil } - url := fmt.Sprintf("%s/agents/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID) + url := fmt.Sprintf("%s/agents/resource/%s/%d", c.capiBaseURL, url.PathEscape(resourceType), resourceID) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { diff --git a/pkg/cmd/agent-task/capi/sessions_test.go b/pkg/cmd/agent-task/capi/sessions_test.go index 223f04320f6..fd7614a38c6 100644 --- a/pkg/cmd/agent-task/capi/sessions_test.go +++ b/pkg/cmd/agent-task/capi/sessions_test.go @@ -1161,7 +1161,7 @@ func TestListLatestSessionsForViewer(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") if tt.perPage != 0 { last := defaultSessionsPerPage @@ -1540,7 +1540,7 @@ func TestListSessionsByResourceID(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") if tt.perPage != 0 { last := defaultSessionsPerPage @@ -1819,7 +1819,7 @@ func TestGetSession(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") session, err := capiClient.GetSession(context.Background(), "some-uuid") @@ -1895,7 +1895,7 @@ func TestGetPullRequestDatabaseID(t *testing.T) { httpClient := &http.Client{Transport: reg} - capiClient := NewCAPIClient(httpClient, "", "github.com") + capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com") databaseID, url, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42) diff --git a/pkg/cmd/agent-task/shared/capi.go b/pkg/cmd/agent-task/shared/capi.go index 65794581777..9d43fd3cce6 100644 --- a/pkg/cmd/agent-task/shared/capi.go +++ b/pkg/cmd/agent-task/shared/capi.go @@ -3,8 +3,11 @@ package shared import ( "errors" "fmt" + "net/http" "regexp" + "time" + "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -30,8 +33,38 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { authCfg := cfg.Authentication() host, _ := authCfg.DefaultHost() token, _ := authCfg.ActiveToken(host) - return capi.NewCAPIClient(httpClient, token, host), nil + + cachedClient := api.NewCachedHTTPClient(httpClient, time.Minute*10) + capiBaseURL, err := resolveCapiURL(cachedClient, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve Copilot API URL: %w", err) + } + + return capi.NewCAPIClient(httpClient, token, host, capiBaseURL), nil + } +} + +// resolveCapiURL queries the GitHub API for the Copilot API endpoint URL. +func resolveCapiURL(httpClient *http.Client, host string) (string, error) { + apiClient := api.NewClientFromHTTP(httpClient) + + var resp struct { + Viewer struct { + CopilotEndpoints struct { + Api string `graphql:"api"` + } `graphql:"copilotEndpoints"` + } `graphql:"viewer"` + } + + if err := apiClient.Query(host, "CopilotEndpoints", &resp, nil); err != nil { + return "", err } + + if resp.Viewer.CopilotEndpoints.Api == "" { + return "", errors.New("empty Copilot API URL returned") + } + + return resp.Viewer.CopilotEndpoints.Api, nil } func IsSessionID(s string) bool { diff --git a/pkg/cmd/agent-task/shared/capi_test.go b/pkg/cmd/agent-task/shared/capi_test.go index 205d881c8d0..3699d25c752 100644 --- a/pkg/cmd/agent-task/shared/capi_test.go +++ b/pkg/cmd/agent-task/shared/capi_test.go @@ -1,12 +1,100 @@ package shared import ( + "net/http" "testing" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" + ghmock "github.com/cli/cli/v2/internal/gh/mock" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestResolveCapiURL(t *testing.T) { + tests := []struct { + name string + resp string + wantURL string + wantErr bool + }{ + { + name: "returns resolved URL", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`, + wantURL: "https://test-copilot-api.example.com", + }, + { + name: "ghe.com tenant URL", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`, + wantURL: "https://test-copilot-api.tenant.example.com", + }, + { + name: "empty URL returns error", + resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`query CopilotEndpoints\b`), + httpmock.StringResponse(tt.resp), + ) + + httpClient := &http.Client{Transport: reg} + url, err := resolveCapiURL(httpClient, "github.com") + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantURL, url) + }) + } +} + +func TestCapiClientFuncResolvesURL(t *testing.T) { + reg := &httpmock.Registry{} + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`query CopilotEndpoints\b`), + httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`), + ) + + f := &cmdutil.Factory{ + Config: func() (gh.Config, error) { + return &ghmock.ConfigMock{ + AuthenticationFunc: func() gh.AuthConfig { + c := &config.AuthConfig{} + c.SetDefaultHost("github.com", "hosts") + c.SetActiveToken("gho_TOKEN", "oauth_token") + return c + }, + }, nil + }, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + } + + clientFunc := CapiClientFunc(f) + client, err := clientFunc() + require.NoError(t, err) + require.NotNil(t, client) + + // Verify the GraphQL resolution was called + require.Len(t, reg.Requests, 1) +} + func TestIsSession(t *testing.T) { assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000")) assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))