From 2290954b028a330a821306092611546b749f450f Mon Sep 17 00:00:00 2001 From: Richard Gomez <32133502+rgmz@users.noreply.github.com> Date: Tue, 25 Jul 2023 23:03:08 -0400 Subject: [PATCH] fix(github): use apiEndpoint for basic or no auth (#1454) --- pkg/sources/github/github.go | 65 +++++++++++++++++++++---------- pkg/sources/github/github_test.go | 5 ++- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 1ab6283198f1..4749fb6557c0 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -356,11 +356,11 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli switch cred := s.conn.GetCredential().(type) { case *sourcespb.GitHub_BasicAuth: - if err = s.enumerateBasicAuth(ctx, cred.BasicAuth); err != nil { + if err = s.enumerateBasicAuth(ctx, apiEndpoint, cred.BasicAuth); err != nil { return nil, err } case *sourcespb.GitHub_Unauthenticated: - s.enumerateUnauthenticated(ctx) + s.enumerateUnauthenticated(ctx, apiEndpoint) case *sourcespb.GitHub_Token: if err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil { return nil, err @@ -382,11 +382,16 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli return installationClient, nil } -func (s *Source) enumerateBasicAuth(ctx context.Context, basicAuth *credentialspb.BasicAuth) error { - s.apiClient = github.NewClient(&http.Client{Transport: &github.BasicAuthTransport{ +func (s *Source) enumerateBasicAuth(ctx context.Context, apiEndpoint string, basicAuth *credentialspb.BasicAuth) error { + s.httpClient.Transport = &github.BasicAuthTransport{ Username: basicAuth.Username, Password: basicAuth.Password, - }}) + } + ghClient, err := createGitHubClient(s.httpClient, apiEndpoint) + if err != nil { + s.log.Error(err, "error creating GitHub client") + } + s.apiClient = ghClient for _, org := range s.orgsCache.Keys() { if err := s.getReposByOrg(ctx, org); err != nil { @@ -397,8 +402,12 @@ func (s *Source) enumerateBasicAuth(ctx context.Context, basicAuth *credentialsp return nil } -func (s *Source) enumerateUnauthenticated(ctx context.Context) { - s.apiClient = github.NewClient(s.httpClient) +func (s *Source) enumerateUnauthenticated(ctx context.Context, apiEndpoint string) { + ghClient, err := createGitHubClient(s.httpClient, apiEndpoint) + if err != nil { + s.log.Error(err, "error creating GitHub client") + } + s.apiClient = ghClient if s.orgsCache.Count() > unauthGithubOrgRateLimt { s.log.Info("You may experience rate limiting when using the unauthenticated GitHub api. Consider using an authenticated scan instead.") } @@ -432,19 +441,14 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri Source: oauth2.ReuseTokenSource(nil, ts), } - var err error - // If we're using public Github, make a regular client. + // If we're using public GitHub, make a regular client. // Otherwise, make an enterprise client. - var isGHE bool - if apiEndpoint == "https://api.github.com" { - s.apiClient = github.NewClient(s.httpClient) - } else { - isGHE = true - s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient) - if err != nil { - return errors.New(err) - } + var isGHE bool = apiEndpoint != "https://api.github.com" + ghClient, err := createGitHubClient(s.httpClient, apiEndpoint) + if err != nil { + s.log.Error(err, "error creating GitHub client") } + s.apiClient = ghClient // TODO: this should support scanning users too @@ -560,7 +564,9 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app * return nil, errors.New(err) } itr.BaseURL = apiEndpoint - s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, &http.Client{Transport: itr}) + + s.httpClient.Transport = itr + s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient) if err != nil { return nil, errors.New(err) } @@ -575,7 +581,11 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app * return nil, errors.New(err) } appItr.BaseURL = apiEndpoint - installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, &http.Client{Transport: appItr}) + + // Does this need to be separate from |s.httpClient|? + instHttpClient := common.RetryableHttpClientTimeout(60) + instHttpClient.Transport = appItr + installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHttpClient) if err != nil { return nil, errors.New(err) } @@ -608,6 +618,21 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app * return installationClient, nil } +func createGitHubClient(httpClient *http.Client, apiEndpoint string) (ghClient *github.Client, err error) { + // If we're using public GitHub, make a regular client. + // Otherwise, make an enterprise client. + if apiEndpoint == "https://api.github.com" { + ghClient = github.NewClient(httpClient) + } else { + ghClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient) + if err != nil { + return nil, errors.New(err) + } + } + + return ghClient, err +} + func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error { var scanned uint64 diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index e7946bbc1178..ee106638b80a 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -339,7 +339,8 @@ func TestHandleRateLimit(t *testing.T) { func TestEnumerateUnauthenticated(t *testing.T) { defer gock.Off() - gock.New("https://api.github.com"). + apiEndpoint := "https://api.github.com" + gock.New(apiEndpoint). Get("/orgs/super-secret-org/repos"). Reply(200). JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) @@ -347,7 +348,7 @@ func TestEnumerateUnauthenticated(t *testing.T) { s := initTestSource(nil) s.orgsCache = memory.New() s.orgsCache.Set("super-secret-org", "super-secret-org") - s.enumerateUnauthenticated(context.Background()) + s.enumerateUnauthenticated(context.Background(), apiEndpoint) assert.Equal(t, 1, s.filteredRepoCache.Count()) ok := s.filteredRepoCache.Exists("super-secret-repo") assert.True(t, ok)