diff --git a/bugout.go b/bugout.go index 041ae34..a28bad9 100644 --- a/bugout.go +++ b/bugout.go @@ -3,17 +3,51 @@ package main // Much of this code is copied from waggle: https://github.com/bugout-dev/waggle/blob/main/main.go import ( + "bytes" "encoding/csv" "encoding/json" "fmt" "io" + "net/http" "os" + "strconv" "strings" + "time" bugout "github.com/bugout-dev/bugout-go/pkg" spire "github.com/bugout-dev/bugout-go/pkg/spire" ) +type BugoutAPIClient struct { + BroodBaseURL string + SpireBaseURL string + HTTPClient *http.Client +} + +func InitBugoutAPIClient() (*BugoutAPIClient, error) { + if BROOD_API_URL == "" { + BROOD_API_URL = "https://auth.bugout.dev" + } + if SPIRE_API_URL == "" { + SPIRE_API_URL = "https://spire.bugout.dev" + } + if BUGOUT_API_TIMEOUT_SECONDS == "" { + BUGOUT_API_TIMEOUT_SECONDS = "10" + } + timeoutSeconds, conversionErr := strconv.Atoi(BUGOUT_API_TIMEOUT_SECONDS) + if conversionErr != nil { + return nil, conversionErr + } + timeout := time.Duration(timeoutSeconds) * time.Second + httpClient := http.Client{Timeout: timeout} + + return &BugoutAPIClient{ + BroodBaseURL: BROOD_API_URL, + SpireBaseURL: SPIRE_API_URL, + HTTPClient: &httpClient, + }, nil +} + func CleanTimestamp(rawTimestamp string) string { return strings.ReplaceAll(rawTimestamp, " ", "T") } @@ -137,3 +171,106 @@ func ProcessDropperClaims(client *bugout.BugoutClient, bugoutToken, journalID, c return processedErr } + +type User struct { + Id string `json:"user_id"` + Username string `json:"username"` + ApplicationId string `json:"application_id"` +} + +func (c *BugoutAPIClient) GetUser(accessToken string) (User, error) { + var user User + var requestBodyBytes []byte + request, requestErr := http.NewRequest("GET", fmt.Sprintf("%s/user", c.BroodBaseURL), bytes.NewBuffer(requestBodyBytes)) + if requestErr != nil { + return user, requestErr + } + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + request.Header.Add("Accept", "application/json") + request.Header.Add("Content-Type", "application/json") + + response, responseErr := c.HTTPClient.Do(request) + if responseErr != nil { + return user, responseErr + } + defer response.Body.Close() + + responseBody, responseBodyErr := io.ReadAll(response.Body) + + if response.StatusCode < 200 || response.StatusCode >= 300 { + if responseBodyErr != nil { + return user, fmt.Errorf("unexpected status code: %d -- could not read response body: %s", response.StatusCode, responseBodyErr.Error()) + } + } + + if responseBodyErr != nil { + return user, fmt.Errorf("could not read response body: %s", responseBodyErr.Error()) + } + + unmarshalErr := json.Unmarshal(responseBody, &user) + if unmarshalErr != nil { + return user, fmt.Errorf("could not parse response body: %s", unmarshalErr.Error()) + } + + return user, nil +} + +type AccessWaggleResourceData struct { + Type string `json:"type"` + Customer string `json:"customer"` + AccessLevel string `json:"access_level"` + UserId string `json:"user_id"` +} + +type AccessWaggleResource struct { + Id string `json:"id"` + ApplicationId string `json:"application_id"` + ResourceData AccessWaggleResourceData `json:"resource_data"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type AccessWaggleResources struct { + Resources []AccessWaggleResource `json:"resources"` +} + +func (c *BugoutAPIClient) GetAccessLevelFromResources() (AccessWaggleResources, error) { + var accessWaggleResources AccessWaggleResources + var requestBodyBytes []byte + request, requestErr := http.NewRequest("GET", fmt.Sprintf("%s/resources", c.BroodBaseURL), bytes.NewBuffer(requestBodyBytes)) + if requestErr != nil { + return accessWaggleResources, requestErr + } + queryParameters := request.URL.Query() + queryParameters.Add("application_id", MOONSTREAM_APPLICATION_ID) + queryParameters.Add("type", BUGOUT_RESOURCE_TYPE_WAGGLE_ACCESS) + + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", MOONSTREAM_WAGGLE_ADMIN_ACCESS_TOKEN)) + request.Header.Add("Accept", "application/json") + request.Header.Add("Content-Type", "application/json") + + response, responseErr := c.HTTPClient.Do(request) + if responseErr != nil { + return accessWaggleResources, responseErr + } + defer response.Body.Close() + + responseBody, responseBodyErr := io.ReadAll(response.Body) + + if response.StatusCode < 200 || response.StatusCode >= 300 { + if responseBodyErr != nil { + return accessWaggleResources, fmt.Errorf("unexpected status code: %d -- could not read response body: %s", response.StatusCode, responseBodyErr.Error()) + } + } + + if responseBodyErr != nil { + return accessWaggleResources, fmt.Errorf("could not read response body: %s", responseBodyErr.Error()) + } + + unmarshalErr := json.Unmarshal(responseBody, &accessWaggleResources) + if unmarshalErr != nil { + return accessWaggleResources, fmt.Errorf("could not parse response body: %s", unmarshalErr.Error()) + } + + return accessWaggleResources, nil +} diff --git a/cmd.go b/cmd.go index e15ac3e..c02a29a 100644 --- a/cmd.go +++ b/cmd.go @@ -555,7 +555,10 @@ func CreateServerCommand() *cobra.Command { for _, o := range strings.Split(WAGGLE_CORS_ALLOWED_ORIGINS, ",") { corsWhitelist[o] = true } - + bugoutClient, bugoutClientErr := InitBugoutAPIClient() + if bugoutClientErr != nil { + return bugoutClientErr + } moonstreamClient, clientErr := InitMoonstreamEngineAPIClient() if clientErr != nil { return clientErr @@ -566,6 +569,7 @@ func CreateServerCommand() *cobra.Command { Port: port, AvailableSigners: availableSigners, CORSWhitelist: corsWhitelist, + BugoutAPIClient: bugoutClient, MoonstreamEngineAPIClient: moonstreamClient, } diff --git a/sample.env b/sample.env index 0af2552..4889105 100644 --- a/sample.env +++ b/sample.env @@ -5,3 +5,5 @@ export BUGOUT_ACCESS_TOKEN="" # Server related environment variables export WAGGLE_CORS_ALLOWED_ORIGINS="http://localhost:3000,https://moonstream.to,https://portal.moonstream.to,https://www.moonstream.to" +export MOONSTREAM_APPLICATION_ID="" +export MOONSTREAM_WAGGLE_ADMIN_ACCESS_TOKEN="" diff --git a/server.go b/server.go index f9eee01..699b801 100644 --- a/server.go +++ b/server.go @@ -29,6 +29,7 @@ type Server struct { AvailableSigners map[string]AvailableSigner LogLevel int CORSWhitelist map[string]bool + BugoutAPIClient *BugoutAPIClient MoonstreamEngineAPIClient *MoonstreamEngineAPIClient ServerMu sync.Mutex @@ -50,6 +51,16 @@ type SignDropperRequest struct { MetatxRegistered bool `json:"metatx_registered"` } +type AccessLevel struct { + Admin bool + RequestSignatures bool +} + +type AuthorizationContext struct { + AuthorizationToken string + AccessLevel AccessLevel +} + // Check access id was provided correctly and save user access configuration to request context func (server *Server) accessMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -60,21 +71,66 @@ func (server *Server) accessMiddleware(next http.Handler) http.Handler { authorizationTokenRaw = h } var authorizationToken string - if authorizationTokenRaw != "" { - authorizationTokenSlice := strings.Split(authorizationTokenRaw, " ") - if len(authorizationTokenSlice) != 2 || authorizationTokenSlice[0] != "Bearer" || authorizationTokenSlice[1] == "" { - http.Error(w, "Wrong authorization token provided", http.StatusForbidden) - return - } - authorizationToken = authorizationTokenSlice[1] - _, uuidParseErr := uuid.Parse(authorizationToken) - if uuidParseErr != nil { - http.Error(w, "Wrong authorization token provided", http.StatusForbidden) - return + if authorizationTokenRaw == "" { + http.Error(w, "No authorization header passed with request", http.StatusForbidden) + return + } + + authorizationTokenSlice := strings.Split(authorizationTokenRaw, " ") + if len(authorizationTokenSlice) != 2 || authorizationTokenSlice[0] != "Bearer" || authorizationTokenSlice[1] == "" { + http.Error(w, "Wrong authorization token provided", http.StatusForbidden) + return + } + authorizationToken = authorizationTokenSlice[1] + _, uuidParseErr := uuid.Parse(authorizationToken) + if uuidParseErr != nil { + http.Error(w, "Wrong authorization token provided", http.StatusForbidden) + return + } + + user, getUserErr := server.BugoutAPIClient.GetUser(authorizationToken) + if getUserErr != nil { + log.Println(getUserErr) + http.Error(w, "Access token not found", http.StatusNotFound) + return + } + if user.ApplicationId != MOONSTREAM_APPLICATION_ID { + http.Error(w, "Wrong bugout application", http.StatusForbidden) + return + } + + accessWaggleResources, getAccessErr := server.BugoutAPIClient.GetAccessLevelFromResources() + if getAccessErr != nil { + log.Println(getAccessErr) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + var accessLevel AccessLevel + accessGranted := false + for _, resource := range accessWaggleResources.Resources { + if resource.ResourceData.UserId == user.Id { + if resource.ResourceData.AccessLevel == "admin" { + accessLevel.Admin = true + accessGranted = true + } + if resource.ResourceData.AccessLevel == "request_signatures" { + accessLevel.RequestSignatures = true + accessGranted = true + } } } + if !accessGranted { + http.Error(w, "Access restricted", http.StatusForbidden) + return + } + + authorizationContext := AuthorizationContext{ + AuthorizationToken: authorizationToken, + AccessLevel: accessLevel, + } - ctxUser := context.WithValue(r.Context(), "authorizationToken", authorizationToken) + ctxUser := context.WithValue(r.Context(), "authorizationContext", authorizationContext) next.ServeHTTP(w, r.WithContext(ctxUser)) }) @@ -171,7 +227,8 @@ func (server *Server) pingRoute(w http.ResponseWriter, r *http.Request) { // signDropperRoute sign dropper call requests func (server *Server) signDropperRoute(w http.ResponseWriter, r *http.Request) { - authorizationToken := r.Context().Value("authorizationToken").(string) + authorizationContext := r.Context().Value("authorizationContext").(AuthorizationContext) + authorizationToken := authorizationContext.AuthorizationToken queries := r.URL.Query() isMetatxDrop := queries.Has("is_metatx_drop") @@ -252,12 +309,11 @@ func (server *Server) signDropperRoute(w http.ResponseWriter, r *http.Request) { // Serve handles server run func (server *Server) Serve() error { serveMux := http.NewServeMux() + serveMux.Handle("/sign/dropper", server.accessMiddleware(http.HandlerFunc(server.signDropperRoute))) serveMux.HandleFunc("/ping", server.pingRoute) - serveMux.HandleFunc("/sign/dropper", server.signDropperRoute) // Set list of common middleware, from bottom to top - commonHandler := server.accessMiddleware(serveMux) - commonHandler = server.corsMiddleware(commonHandler) + commonHandler := server.corsMiddleware(serveMux) commonHandler = server.logMiddleware(commonHandler) commonHandler = server.panicMiddleware(commonHandler) diff --git a/settings.go b/settings.go index 3484058..2e11451 100644 --- a/settings.go +++ b/settings.go @@ -22,9 +22,17 @@ var ( MOONSTREAM_API_URL = os.Getenv("MOONSTREAM_API_URL") MOONSTREAM_API_TIMEOUT_SECONDS = os.Getenv("MOONSTREAM_API_TIMEOUT_SECONDS") + BROOD_API_URL = os.Getenv("BUGOUT_AUTH_URL") + SPIRE_API_URL = os.Getenv("BUGOUT_SPIRE_URL") + BUGOUT_API_TIMEOUT_SECONDS = os.Getenv("BUGOUT_API_TIMEOUT_SECONDS") + BUGOUT_ACCESS_TOKEN = os.Getenv("BUGOUT_ACCESS_TOKEN") - WAGGLE_CORS_ALLOWED_ORIGINS = os.Getenv("WAGGLE_CORS_ALLOWED_ORIGINS") + WAGGLE_CORS_ALLOWED_ORIGINS = os.Getenv("WAGGLE_CORS_ALLOWED_ORIGINS") + MOONSTREAM_APPLICATION_ID = os.Getenv("MOONSTREAM_APPLICATION_ID") + MOONSTREAM_WAGGLE_ADMIN_ACCESS_TOKEN = os.Getenv("MOONSTREAM_WAGGLE_ADMIN_ACCESS_TOKEN") + + BUGOUT_RESOURCE_TYPE_WAGGLE_ACCESS = "waggle-access" CASER = cases.Title(language.English) )