Skip to content

Commit

Permalink
[chore] - small updates (#2288)
Browse files Browse the repository at this point in the history
* small updates

* fix logic

* simplify fxn

* remove errors

* use strings.EqualFold
  • Loading branch information
ahrav authored Jan 11, 2024
1 parent aa40654 commit 9408425
Showing 1 changed file with 28 additions and 30 deletions.
58 changes: 28 additions & 30 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package github

import (
"errors"
"fmt"
"net/http"
"net/url"
Expand All @@ -15,7 +16,6 @@ import (
"time"

"github.com/bradleyfalzon/ghinstallation/v2"
"github.com/go-errors/errors"
gogit "github.com/go-git/go-git/v5"
"github.com/go-logr/logr"
"github.com/gobwas/glob"
Expand Down Expand Up @@ -217,7 +217,7 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so
var conn sourcespb.GitHub
err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{})
if err != nil {
return errors.WrapPrefix(err, "error unmarshalling connection", 0)
return fmt.Errorf("error unmarshalling connection: %w", err)
}
s.conn = &conn

Expand Down Expand Up @@ -323,7 +323,7 @@ func (s *Source) Validate(ctx context.Context) []error {
errs = append(errs, fmt.Errorf("error creating GitHub client: %+v", err))
}
default:
errs = append(errs, errors.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type()))
errs = append(errs, fmt.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type()))
}

// Run a simple query to check if the client is actually valid
Expand Down Expand Up @@ -419,11 +419,13 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s
return
}

const cloudEndpoint = "https://api.github.com"

// Chunks emits chunks of bytes over a channel.
func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, targets ...sources.ChunkingTarget) error {
apiEndpoint := s.conn.Endpoint
if len(apiEndpoint) == 0 || endsWithGithub.MatchString(apiEndpoint) {
apiEndpoint = "https://api.github.com"
apiEndpoint = cloudEndpoint
}

// If targets are provided, we're only scanning the data in those targets.
Expand Down Expand Up @@ -469,7 +471,7 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli
}
default:
// TODO: move this error to Init
return nil, errors.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type())
return nil, fmt.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type())
}

s.repos = make([]string, 0, s.filteredRepoCache.Count())
Expand Down Expand Up @@ -550,7 +552,6 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri

// If we're using public GitHub, make a regular client.
// Otherwise, make an enterprise client.
var isGHE bool = apiEndpoint != "https://api.github.com"
ghClient, err := createGitHubClient(s.httpClient, apiEndpoint)
if err != nil {
s.log.Error(err, "error creating GitHub client")
Expand All @@ -577,7 +578,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
continue
}
if err != nil {
return errors.New(err)
return fmt.Errorf("error getting user: %w", err)
}
break
}
Expand Down Expand Up @@ -606,6 +607,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
s.log.Error(err, "error fetching repos by user")
}

isGHE := !strings.EqualFold(apiEndpoint, cloudEndpoint)
if isGHE {
s.addAllVisibleOrgs(ctx)
} else {
Expand Down Expand Up @@ -653,12 +655,12 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri
func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *credentialspb.GitHubApp) (installationClient *github.Client, err error) {
installationID, err := strconv.ParseInt(app.InstallationId, 10, 64)
if err != nil {
return nil, errors.New(err)
return nil, err
}

appID, err := strconv.ParseInt(app.AppId, 10, 64)
if err != nil {
return nil, errors.New(err)
return nil, err
}

// This client is required to create installation tokens for cloning.
Expand All @@ -671,16 +673,16 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
appID,
[]byte(app.PrivateKey))
if err != nil {
return nil, errors.New(err)
return nil, err
}
appItr.BaseURL = apiEndpoint

// Does this need to be separate from |s.httpClient|?
instHttpClient := common.RetryableHttpClientTimeout(60)
instHttpClient.Transport = appItr
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHttpClient)
instHTTPClient := common.RetryableHttpClientTimeout(60)
instHTTPClient.Transport = appItr
installationClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, instHTTPClient)
if err != nil {
return nil, errors.New(err)
return nil, err
}

// This client is used for most APIs.
Expand All @@ -690,14 +692,14 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
installationID,
[]byte(app.PrivateKey))
if err != nil {
return nil, errors.New(err)
return nil, err
}
itr.BaseURL = apiEndpoint

s.httpClient.Transport = itr
s.apiClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, s.httpClient)
if err != nil {
return nil, errors.New(err)
return nil, err
}

// If no repos were provided, enumerate them.
Expand Down Expand Up @@ -728,19 +730,14 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app *
return installationClient, nil
}

func createGitHubClient(httpClient *http.Client, apiEndpoint string) (ghClient *github.Client, err error) {
func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Client, error) {
// If we're using public GitHub, make a regular client.
// Otherwise, make an enterprise client.
if apiEndpoint == "https://api.github.com" {
ghClient = github.NewClient(httpClient)
} else {
ghClient, err = github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient)
if err != nil {
return nil, errors.New(err)
}
if strings.EqualFold(apiEndpoint, cloudEndpoint) {
return github.NewClient(httpClient), nil
}

return ghClient, err
return github.NewEnterpriseClient(apiEndpoint, apiEndpoint, httpClient)
}

func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error {
Expand Down Expand Up @@ -969,7 +966,7 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {
continue
}
if err != nil {
s.log.Error(err, "Could not list all organizations")
s.log.Error(err, "could not list all organizations")
return
}
if len(orgs) == 0 {
Expand All @@ -981,11 +978,12 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) {

for _, org := range orgs {
var name string
if org.Name != nil {
switch {
case org.Name != nil:
name = *org.Name
} else if org.Login != nil {
case org.Login != nil:
name = *org.Login
} else {
default:
continue
}
s.orgsCache.Set(name, name)
Expand Down Expand Up @@ -1046,7 +1044,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error {
continue
}
if err != nil || len(members) == 0 {
return errors.New("Could not list organization members: account may not have access to list organization members")
return fmt.Errorf("could not list organization members: account may not have access to list organization members %w", err)
}
if res == nil {
break
Expand Down

0 comments on commit 9408425

Please sign in to comment.