Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions pkg/cmd/agent-task/capi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ package capi
import (
"context"
"net/http"
"net/url"
)

//go:generate moq -rm -out client_mock.go . CapiClient

const baseCAPIURL = "https://api.githubcopilot.com"
const capiHost = "api.githubcopilot.com"

// CapiClient defines the methods used by the caller. Implementations
// may be replaced with test doubles in unit tests.
type CapiClient interface {
Expand All @@ -24,33 +22,42 @@ type CapiClient interface {

// CAPIClient is a client for interacting with the Copilot API
type CAPIClient struct {
httpClient *http.Client
host string
httpClient *http.Client
host string
capiBaseURL string
}

// NewCAPIClient creates a new CAPI client. Provide a token, host, and an HTTP client which
// will be used as the base transport for CAPI requests.
// NewCAPIClient creates a new CAPI client. Provide a token, the user's GitHub
// host, the resolved Copilot API URL, and an HTTP client which will be used as
// the base transport for CAPI requests.
//
// The provided HTTP client will be mutated for use with CAPI, so it should not
// be reused elsewhere.
func NewCAPIClient(httpClient *http.Client, token string, host string) *CAPIClient {
httpClient.Transport = newCAPITransport(token, httpClient.Transport)
func NewCAPIClient(httpClient *http.Client, token string, host string, capiBaseURL string) *CAPIClient {
httpClient.Transport = newCAPITransport(token, capiBaseURL, httpClient.Transport)
return &CAPIClient{
httpClient: httpClient,
host: host,
httpClient: httpClient,
host: host,
capiBaseURL: capiBaseURL,
}
}

// capiTransport adds the Copilot auth headers
type capiTransport struct {
rp http.RoundTripper
token string
rp http.RoundTripper
token string
capiHost string
}

func newCAPITransport(token string, rp http.RoundTripper) *capiTransport {
func newCAPITransport(token string, capiBaseURL string, rp http.RoundTripper) *capiTransport {
capiHost := ""
if u, err := url.Parse(capiBaseURL); err == nil {
capiHost = u.Host
}
return &capiTransport{
rp: rp,
token: token,
rp: rp,
token: token,
capiHost: capiHost,
}
}

Expand All @@ -60,10 +67,10 @@ func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Since this RoundTrip is reused for both Copilot API and
// GitHub API requests, we conditionally add the integration
// ID only when performing requests to the Copilot API.
if req.URL.Host == capiHost {
if req.URL.Host == ct.capiHost {
req.Header.Add("Copilot-Integration-Id", "copilot-4-cli")

// This is quick fix to ensure that we are not using GitHub API versions while targeting CAPI.
// Ensure we are not using GitHub API versions while targeting CAPI.
req.Header.Set("X-GitHub-Api-Version", "2026-01-09")
}
return ct.rp.RoundTrip(req)
Expand Down
8 changes: 5 additions & 3 deletions pkg/cmd/agent-task/capi/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ type JobError struct {
Service string `json:"service"`
}

const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
func (c *CAPIClient) jobsBasePathV1() string {
return c.capiBaseURL + "/agents/swe/v1/jobs"
}

// CreateJob queues a new job using the v1 Jobs API. It may or may not
// return Pull Request information. If Pull Request information is required
Expand All @@ -64,7 +66,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
return nil, errors.New("problem statement is required")
}

url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo))
url := fmt.Sprintf("%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo))

prOpts := JobPullRequest{}
if baseBranch != "" {
Expand Down Expand Up @@ -130,7 +132,7 @@ func (c *CAPIClient) GetJob(ctx context.Context, owner, repo, jobID string) (*Jo
if owner == "" || repo == "" || jobID == "" {
return nil, errors.New("owner, repo, and jobID are required")
}
url := fmt.Sprintf("%s/%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
url := fmt.Sprintf("%s/%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions pkg/cmd/agent-task/capi/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func TestGetJob(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

job, err := capiClient.GetJob(context.Background(), "OWNER", "REPO", "job123")

Expand Down Expand Up @@ -410,7 +410,7 @@ func TestCreateJob(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)

Expand Down
13 changes: 8 additions & 5 deletions pkg/cmd/agent-task/capi/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,16 @@ func (c *CAPIClient) ListLatestSessionsForViewer(ctx context.Context, limit int)
return nil, nil
}

url := baseCAPIURL + "/agents/sessions"
sessionsURL, err := url.JoinPath(c.capiBaseURL, "agents", "sessions")
if err != nil {
return nil, fmt.Errorf("failed to build sessions URL: %w", err)
}
pageSize := defaultSessionsPerPage

seenResources := make(map[int64]struct{})
latestSessions := make([]session, 0, limit)
for page := 1; ; page++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionsURL, http.NoBody)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -296,7 +299,7 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error
return nil, fmt.Errorf("missing session ID")
}

url := fmt.Sprintf("%s/agents/sessions/%s", baseCAPIURL, url.PathEscape(id))
url := fmt.Sprintf("%s/agents/sessions/%s", c.capiBaseURL, url.PathEscape(id))

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
Expand Down Expand Up @@ -335,7 +338,7 @@ func (c *CAPIClient) GetSessionLogs(ctx context.Context, id string) ([]byte, err
return nil, fmt.Errorf("missing session ID")
}

url := fmt.Sprintf("%s/agents/sessions/%s/logs", baseCAPIURL, url.PathEscape(id))
url := fmt.Sprintf("%s/agents/sessions/%s/logs", c.capiBaseURL, url.PathEscape(id))

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
Expand Down Expand Up @@ -368,7 +371,7 @@ func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType
return nil, nil
}

url := fmt.Sprintf("%s/agents/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID)
url := fmt.Sprintf("%s/agents/resource/%s/%d", c.capiBaseURL, url.PathEscape(resourceType), resourceID)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions pkg/cmd/agent-task/capi/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,7 @@ func TestListLatestSessionsForViewer(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

if tt.perPage != 0 {
last := defaultSessionsPerPage
Expand Down Expand Up @@ -1540,7 +1540,7 @@ func TestListSessionsByResourceID(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

if tt.perPage != 0 {
last := defaultSessionsPerPage
Expand Down Expand Up @@ -1819,7 +1819,7 @@ func TestGetSession(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

session, err := capiClient.GetSession(context.Background(), "some-uuid")

Expand Down Expand Up @@ -1895,7 +1895,7 @@ func TestGetPullRequestDatabaseID(t *testing.T) {

httpClient := &http.Client{Transport: reg}

capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")

databaseID, url, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42)

Expand Down
35 changes: 34 additions & 1 deletion pkg/cmd/agent-task/shared/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package shared
import (
"errors"
"fmt"
"net/http"
"regexp"
"time"

"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
Expand All @@ -30,8 +33,38 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) {
authCfg := cfg.Authentication()
host, _ := authCfg.DefaultHost()
token, _ := authCfg.ActiveToken(host)
return capi.NewCAPIClient(httpClient, token, host), nil

cachedClient := api.NewCachedHTTPClient(httpClient, time.Minute*10)
capiBaseURL, err := resolveCapiURL(cachedClient, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve Copilot API URL: %w", err)
}

return capi.NewCAPIClient(httpClient, token, host, capiBaseURL), nil
}
}

// resolveCapiURL queries the GitHub API for the Copilot API endpoint URL.
func resolveCapiURL(httpClient *http.Client, host string) (string, error) {
apiClient := api.NewClientFromHTTP(httpClient)

var resp struct {
Viewer struct {
CopilotEndpoints struct {
Api string `graphql:"api"`
} `graphql:"copilotEndpoints"`
} `graphql:"viewer"`
}

if err := apiClient.Query(host, "CopilotEndpoints", &resp, nil); err != nil {
return "", err
}

if resp.Viewer.CopilotEndpoints.Api == "" {
return "", errors.New("empty Copilot API URL returned")
}

return resp.Viewer.CopilotEndpoints.Api, nil
}

func IsSessionID(s string) bool {
Expand Down
88 changes: 88 additions & 0 deletions pkg/cmd/agent-task/shared/capi_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,100 @@
package shared

import (
"net/http"
"testing"

"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
ghmock "github.com/cli/cli/v2/internal/gh/mock"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestResolveCapiURL(t *testing.T) {
tests := []struct {
name string
resp string
wantURL string
wantErr bool
}{
{
name: "returns resolved URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`,
wantURL: "https://test-copilot-api.example.com",
},
{
name: "ghe.com tenant URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`,
wantURL: "https://test-copilot-api.tenant.example.com",
},
{
name: "empty URL returns error",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`,
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)

reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(tt.resp),
)

httpClient := &http.Client{Transport: reg}
url, err := resolveCapiURL(httpClient, "github.com")

if tt.wantErr {
require.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tt.wantURL, url)
})
}
}

func TestCapiClientFuncResolvesURL(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)

reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`),
)

f := &cmdutil.Factory{
Config: func() (gh.Config, error) {
return &ghmock.ConfigMock{
AuthenticationFunc: func() gh.AuthConfig {
c := &config.AuthConfig{}
c.SetDefaultHost("github.com", "hosts")
c.SetActiveToken("gho_TOKEN", "oauth_token")
return c
},
}, nil
},
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
}

clientFunc := CapiClientFunc(f)
client, err := clientFunc()
require.NoError(t, err)
require.NotNil(t, client)

// Verify the GraphQL resolution was called
require.Len(t, reg.Requests, 1)
}

func TestIsSession(t *testing.T) {
assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000"))
assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))
Expand Down
Loading