From 02eb176440e7c1fb7a0068400e9054761c3011c8 Mon Sep 17 00:00:00 2001 From: joel Date: Thu, 7 Mar 2024 16:22:17 +0800 Subject: [PATCH] fix: refactor out mixed hook invocation logic --- internal/api/hooks.go | 38 +++++++++++++++++++++++++++++++++----- internal/api/phone.go | 5 ----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index bc1465cd6..937568a4d 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/http/httptrace" + "net/url" "strings" "time" @@ -22,8 +23,7 @@ import ( "github.com/supabase/auth/internal/storage" ) -func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { - +func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { db := a.db.WithContext(ctx) request, err := json.Marshal(input) @@ -218,7 +218,32 @@ func isOverSizeLimit(payload []byte) bool { return len(payload) > maxSizeKB } +func validateHTTPHook(uri string) error { + u, err := url.Parse(uri) + if err != nil { + return err + } + if !(strings.ToLower(u.Scheme) == "http" || strings.ToLower(u.Scheme) == "https") { + return fmt.Errorf("invalid http hook") + } + return nil +} + +func validatePostgresHook(uri string) error { + u, err := url.Parse(uri) + if err != nil { + return err + } + if !(strings.ToLower(u.Scheme) == "pg-functions") { + return fmt.Errorf("invalid postgres hook") + } + return nil +} + func (a *API) invokeHTTPHook(input, output any, hookURI string) error { + if err := validateHTTPHook(hookURI); err != nil { + return err + } switch input.(type) { case *hooks.CustomSMSProviderInput: hookOutput, ok := output.(*hooks.CustomSMSProviderOutput) @@ -252,6 +277,9 @@ func (a *API) invokeHTTPHook(input, output any, hookURI string) error { // pass the current transaciton, as pool-exhaustion deadlocks are very easy to // trigger. func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error { + if err := validatePostgresHook(hookURI); err != nil { + return err + } config := a.config // Switch based on hook type switch input.(type) { @@ -261,7 +289,7 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, panic("output should be *hooks.MFAVerificationAttemptOutput") } - if _, err := a.runHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) } @@ -287,7 +315,7 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, panic("output should be *hooks.PasswordVerificationAttemptOutput") } - if _, err := a.runHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking password verification hook.").WithInternalError(err) } @@ -313,7 +341,7 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, panic("output should be *hooks.CustomAccessTokenOutput") } - if _, err := a.runHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil { return internalServerError("Error invoking access token hook.").WithInternalError(err) } diff --git a/internal/api/phone.go b/internal/api/phone.go index b530705d8..71129cbed 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -99,15 +99,10 @@ func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, p OTP: otp, } output := hooks.CustomSMSProviderOutput{} - // TODO: Fix this by either passing down the context ore removing the need for it err := a.invokeHTTPHook(&input, &output, config.Hook.CustomSMSProvider.URI) if err != nil { return "", err } - // if !output.Success { - // // TODO: adjust this - // return "", internalServerError("error sending sms using custom provider") - // } } else { messageID, err = smsProvider.SendMessage(phone, message, channel, otp)