Skip to content

Commit

Permalink
fix(github): use apiEndpoint for basic or no auth (#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
rgmz authored Jul 26, 2023
1 parent f48a635 commit 2290954
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 22 deletions.
65 changes: 45 additions & 20 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,11 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli

switch cred := s.conn.GetCredential().(type) {
case *sourcespb.GitHub_BasicAuth:
if err = s.enumerateBasicAuth(ctx, cred.BasicAuth); err != nil {
if err = s.enumerateBasicAuth(ctx, apiEndpoint, cred.BasicAuth); err != nil {
return nil, err
}
case *sourcespb.GitHub_Unauthenticated:
s.enumerateUnauthenticated(ctx)
s.enumerateUnauthenticated(ctx, apiEndpoint)
case *sourcespb.GitHub_Token:
if err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil {
return nil, err
Expand All @@ -382,11 +382,16 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli
return installationClient, nil
}

func (s *Source) enumerateBasicAuth(ctx context.Context, basicAuth *credentialspb.BasicAuth) error {
s.apiClient = github.NewClient(&http.Client{Transport: &github.BasicAuthTransport{
func (s *Source) enumerateBasicAuth(ctx context.Context, apiEndpoint string, basicAuth *credentialspb.BasicAuth) error {
s.httpClient.Transport = &github.BasicAuthTransport{
Username: basicAuth.Username,
Password: basicAuth.Password,
}})
}
ghClient, err := createGitHubClient(s.httpClient, apiEndpoint)
if err != nil {
s.log.Error(err, "error creating GitHub client")
}
s.apiClient = ghClient

for _, org := range s.orgsCache.Keys() {
if err := s.getReposByOrg(ctx, org); err != nil {
Expand All @@ -397,8 +402,12 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, basicAuth *credentialsp
return nil
}

func (s *Source) enumerateUnauthenticated(ctx context.Context) {
s.apiClient = github.NewClient(s.httpClient)
func (s *Source) enumerateUnauthenticated(ctx context.Context, apiEndpoint string) {
ghClient, err := createGitHubClient(s.httpClient, apiEndpoint)
if err != nil {
s.log.Error(err, "error creating GitHub client")
}
s.apiClient = ghClient
if s.orgsCache.Count() > unauthGithubOrgRateLimt {
s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.")
}
Expand Down Expand Up @@ -432,19 +441,14 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
Source: oauth2.ReuseTokenSource(nil, ts),
}

var err error
// If we're using public Github, make a regular client.
// If we're using public GitHub, make a regular client.
// Otherwise, make an enterprise client.
var isGHE bool
if apiEndpoint == "https://api.github.com" {
s.apiClient = github.NewClient(s.httpClient)
} else {
isGHE = true
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient)
if err != nil {
return errors.New(err)
}
var isGHE bool = apiEndpoint != "https://api.github.com"
ghClient, err := createGitHubClient(s.httpClient, apiEndpoint)
if err != nil {
s.log.Error(err, "error creating GitHub client")
}
s.apiClient = ghClient

// TODO: this should support scanning users too

Expand Down Expand Up @@ -560,7 +564,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return nil, errors.New(err)
}
itr.BaseURL = apiEndpoint
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, &http.Client{Transport: itr})

s.httpClient.Transport = itr
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient)
if err != nil {
return nil, errors.New(err)
}
Expand All @@ -575,7 +581,11 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return nil, errors.New(err)
}
appItr.BaseURL = apiEndpoint
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, &http.Client{Transport: appItr})

// Does this need to be separate from |s.httpClient|?
instHttpClient := common.RetryableHttpClientTimeout(60)
instHttpClient.Transport = appItr
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHttpClient)
if err != nil {
return nil, errors.New(err)
}
Expand Down Expand Up @@ -608,6 +618,21 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return installationClient, nil
}

func createGitHubClient(httpClient *http.Client, apiEndpoint string) (ghClient *github.Client, err error) {
// If we're using public GitHub, make a regular client.
// Otherwise, make an enterprise client.
if apiEndpoint == "https://api.github.com" {
ghClient = github.NewClient(httpClient)
} else {
ghClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient)
if err != nil {
return nil, errors.New(err)
}
}

return ghClient, err
}

func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
var scanned uint64

Expand Down
5 changes: 3 additions & 2 deletions pkg/sources/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,16 @@ func TestHandleRateLimit(t *testing.T) {
func TestEnumerateUnauthenticated(t *testing.T) {
defer gock.Off()

gock.New("https://api.github.com").
apiEndpoint := "https://api.github.com"
gock.New(apiEndpoint).
Get("/orgs/super-secret-org/repos").
Reply(200).
JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}})

s := initTestSource(nil)
s.orgsCache = memory.New()
s.orgsCache.Set("super-secret-org", "super-secret-org")
s.enumerateUnauthenticated(context.Background())
s.enumerateUnauthenticated(context.Background(), apiEndpoint)
assert.Equal(t, 1, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-repo")
assert.True(t, ok)
Expand Down

0 comments on commit 2290954

Please sign in to comment.