Skip to content

Commit

Permalink
Update to reduce unnecessary API calls and sanitize input
Browse files Browse the repository at this point in the history
Update UpdateHTTPServers and UpdateStreamServers:
- No longer make extra GET requests for each PUT and DELETE request.
- Removes identical duplicate servers.
- Returns errors for duplicate servers with different parameters.
  • Loading branch information
dylan-way committed Jan 24, 2025
1 parent eb3d45d commit f625289
Show file tree
Hide file tree
Showing 3 changed files with 642 additions and 184 deletions.
206 changes: 156 additions & 50 deletions client/nginx.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
ErrServerExists = errors.New("server already exists")
ErrNotSupported = errors.New("not supported")
ErrInvalidTimeout = errors.New("invalid timeout")
ErrParameterMismatch = errors.New("encountered duplicate server with different parameters")
ErrPlusVersionNotFound = errors.New("plus version not found in the input string")
)

Expand Down Expand Up @@ -775,9 +776,13 @@ func (client *NginxClient) AddHTTPServer(ctx context.Context, upstream string, s
if id != -1 {
return fmt.Errorf("failed to add %v server to %v upstream: %w", server.Server, upstream, ErrServerExists)
}
err = client.addHTTPServer(ctx, upstream, server)
return err
}

func (client *NginxClient) addHTTPServer(ctx context.Context, upstream string, server UpstreamServer) error {
path := fmt.Sprintf("http/upstreams/%v/servers/", upstream)
err = client.post(ctx, path, &server)
err := client.post(ctx, path, &server)
if err != nil {
return fmt.Errorf("failed to add %v server to %v upstream: %w", server.Server, upstream, err)
}
Expand All @@ -794,9 +799,13 @@ func (client *NginxClient) DeleteHTTPServer(ctx context.Context, upstream string
if id == -1 {
return fmt.Errorf("failed to remove %v server from %v upstream: %w", server, upstream, ErrServerNotFound)
}
err = client.deleteHTTPServer(ctx, upstream, server, id)
return err
}

path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, id)
err = client.delete(ctx, path, http.StatusOK)
func (client *NginxClient) deleteHTTPServer(ctx context.Context, upstream, server string, serverID int) error {
path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, serverID)
err := client.delete(ctx, path, http.StatusOK)
if err != nil {
return fmt.Errorf("failed to remove %v server from %v upstream: %w", server, upstream, err)
}
Expand All @@ -809,6 +818,8 @@ func (client *NginxClient) DeleteHTTPServer(ctx context.Context, upstream string
// Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX.
// Servers that are in the slice and exist in NGINX, but have different parameters, will be updated.
// The client will attempt to update all servers, returning all the errors that occurred.
// If there are duplicate servers with equivalent parameters, the duplicates will be ignored.
// If there are duplicate servers with different parameters, those server entries will be ignored and an error returned.
func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream string, servers []UpstreamServer) (added []UpstreamServer, deleted []UpstreamServer, updated []UpstreamServer, err error) {
serversInNginx, err := client.GetHTTPServers(ctx, upstream)
if err != nil {
Expand All @@ -822,10 +833,12 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin
formattedServers = append(formattedServers, server)
}

formattedServers, err = deduplicateServers(upstream, formattedServers)

toAdd, toDelete, toUpdate := determineUpdates(formattedServers, serversInNginx)

for _, server := range toAdd {
addErr := client.AddHTTPServer(ctx, upstream, server)
addErr := client.addHTTPServer(ctx, upstream, server)
if addErr != nil {
err = errors.Join(err, addErr)
continue
Expand All @@ -834,7 +847,7 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin
}

for _, server := range toDelete {
deleteErr := client.DeleteHTTPServer(ctx, upstream, server.Server)
deleteErr := client.deleteHTTPServer(ctx, upstream, server.Server, server.ID)
if deleteErr != nil {
err = errors.Join(err, deleteErr)
continue
Expand All @@ -858,46 +871,82 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin
return added, deleted, updated, err
}

// haveSameParameters checks if a given server has the same parameters as a server already present in NGINX. Order matters.
func haveSameParameters(newServer UpstreamServer, serverNGX UpstreamServer) bool {
newServer.ID = serverNGX.ID
func deduplicateServers(upstream string, servers []UpstreamServer) ([]UpstreamServer, error) {
type serverCheck struct {
server UpstreamServer
valid bool
}

if serverNGX.MaxConns != nil && newServer.MaxConns == nil {
newServer.MaxConns = &defaultMaxConns
serverMap := make(map[string]*serverCheck, len(servers))
var err error
for _, server := range servers {
if prev, ok := serverMap[server.Server]; ok {
if !prev.valid {
continue
}
if !server.hasSameParametersAs(prev.server) {
prev.valid = false
err = errors.Join(err, fmt.Errorf(
"failed to update %s server to %s upstream: %w",
server.Server, upstream, ErrParameterMismatch))
}
continue
}
serverMap[server.Server] = &serverCheck{server, true}
}
retServers := make([]UpstreamServer, 0, len(serverMap))
for _, server := range servers {
if check, ok := serverMap[server.Server]; ok && check.valid {
retServers = append(retServers, server)
delete(serverMap, server.Server)
}
}
return retServers, err
}

if serverNGX.MaxFails != nil && newServer.MaxFails == nil {
newServer.MaxFails = &defaultMaxFails
// hasSameParametersAs checks if a given server has the same parameters.
func (s UpstreamServer) hasSameParametersAs(compareServer UpstreamServer) bool {
s.ID = compareServer.ID
s.applyDefaults()
compareServer.applyDefaults()
return reflect.DeepEqual(s, compareServer)
}

func (s *UpstreamServer) applyDefaults() {
if s.MaxConns == nil {
s.MaxConns = &defaultMaxConns
}

if serverNGX.FailTimeout != "" && newServer.FailTimeout == "" {
newServer.FailTimeout = defaultFailTimeout
if s.MaxFails == nil {
s.MaxFails = &defaultMaxFails
}

if serverNGX.SlowStart != "" && newServer.SlowStart == "" {
newServer.SlowStart = defaultSlowStart
if s.FailTimeout == "" {
s.FailTimeout = defaultFailTimeout
}

if serverNGX.Backup != nil && newServer.Backup == nil {
newServer.Backup = &defaultBackup
if s.SlowStart == "" {
s.SlowStart = defaultSlowStart
}

if serverNGX.Down != nil && newServer.Down == nil {
newServer.Down = &defaultDown
if s.Backup == nil {
s.Backup = &defaultBackup
}

if serverNGX.Weight != nil && newServer.Weight == nil {
newServer.Weight = &defaultWeight
if s.Down == nil {
s.Down = &defaultDown
}

return reflect.DeepEqual(newServer, serverNGX)
if s.Weight == nil {
s.Weight = &defaultWeight
}
}

func determineUpdates(updatedServers []UpstreamServer, nginxServers []UpstreamServer) (toAdd []UpstreamServer, toRemove []UpstreamServer, toUpdate []UpstreamServer) {
for _, server := range updatedServers {
updateFound := false
for _, serverNGX := range nginxServers {
if server.Server == serverNGX.Server && !haveSameParameters(server, serverNGX) {
if server.Server == serverNGX.Server && !server.hasSameParametersAs(serverNGX) {
server.ID = serverNGX.ID
updateFound = true
break
Expand Down Expand Up @@ -1089,9 +1138,13 @@ func (client *NginxClient) AddStreamServer(ctx context.Context, upstream string,
if id != -1 {
return fmt.Errorf("failed to add %v stream server to %v upstream: %w", server.Server, upstream, ErrServerExists)
}
err = client.addStreamServer(ctx, upstream, server)
return err
}

func (client *NginxClient) addStreamServer(ctx context.Context, upstream string, server StreamUpstreamServer) error {
path := fmt.Sprintf("stream/upstreams/%v/servers/", upstream)
err = client.post(ctx, path, &server)
err := client.post(ctx, path, &server)
if err != nil {
return fmt.Errorf("failed to add %v stream server to %v upstream: %w", server.Server, upstream, err)
}
Expand All @@ -1107,9 +1160,13 @@ func (client *NginxClient) DeleteStreamServer(ctx context.Context, upstream stri
if id == -1 {
return fmt.Errorf("failed to remove %v stream server from %v upstream: %w", server, upstream, ErrServerNotFound)
}
err = client.deleteStreamServer(ctx, upstream, server, id)
return err
}

path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, id)
err = client.delete(ctx, path, http.StatusOK)
func (client *NginxClient) deleteStreamServer(ctx context.Context, upstream, server string, serverID int) error {
path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, serverID)
err := client.delete(ctx, path, http.StatusOK)
if err != nil {
return fmt.Errorf("failed to remove %v stream server from %v upstream: %w", server, upstream, err)
}
Expand All @@ -1121,6 +1178,8 @@ func (client *NginxClient) DeleteStreamServer(ctx context.Context, upstream stri
// Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX.
// Servers that are in the slice and exist in NGINX, but have different parameters, will be updated.
// The client will attempt to update all servers, returning all the errors that occurred.
// If there are duplicate servers with equivalent parameters, the duplicates will be ignored.
// If there are duplicate servers with different parameters, those server entries will be ignored and an error returned.
func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream string, servers []StreamUpstreamServer) (added []StreamUpstreamServer, deleted []StreamUpstreamServer, updated []StreamUpstreamServer, err error) {
serversInNginx, err := client.GetStreamServers(ctx, upstream)
if err != nil {
Expand All @@ -1133,10 +1192,12 @@ func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream str
formattedServers = append(formattedServers, server)
}

formattedServers, err = deduplicateStreamServers(upstream, formattedServers)

toAdd, toDelete, toUpdate := determineStreamUpdates(formattedServers, serversInNginx)

for _, server := range toAdd {
addErr := client.AddStreamServer(ctx, upstream, server)
addErr := client.addStreamServer(ctx, upstream, server)
if addErr != nil {
err = errors.Join(err, addErr)
continue
Expand All @@ -1145,7 +1206,7 @@ func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream str
}

for _, server := range toDelete {
deleteErr := client.DeleteStreamServer(ctx, upstream, server.Server)
deleteErr := client.deleteStreamServer(ctx, upstream, server.Server, server.ID)
if deleteErr != nil {
err = errors.Join(err, deleteErr)
continue
Expand Down Expand Up @@ -1184,45 +1245,82 @@ func (client *NginxClient) getIDOfStreamServer(ctx context.Context, upstream str
return -1, nil
}

// haveSameParametersForStream checks if a given server has the same parameters as a server already present in NGINX. Order matters.
func haveSameParametersForStream(newServer StreamUpstreamServer, serverNGX StreamUpstreamServer) bool {
newServer.ID = serverNGX.ID
if serverNGX.MaxConns != nil && newServer.MaxConns == nil {
newServer.MaxConns = &defaultMaxConns
func deduplicateStreamServers(upstream string, servers []StreamUpstreamServer) ([]StreamUpstreamServer, error) {
type serverCheck struct {
server StreamUpstreamServer
valid bool
}

serverMap := make(map[string]*serverCheck, len(servers))
var err error
for _, server := range servers {
if prev, ok := serverMap[server.Server]; ok {
if !prev.valid {
continue
}
if !server.hasSameParametersAs(prev.server) {
prev.valid = false
err = errors.Join(err, fmt.Errorf(
"failed to update stream %s server to %s upstream: %w",
server.Server, upstream, ErrParameterMismatch))
}
continue
}
serverMap[server.Server] = &serverCheck{server, true}
}
retServers := make([]StreamUpstreamServer, 0, len(serverMap))
for _, server := range servers {
if check, ok := serverMap[server.Server]; ok && check.valid {
retServers = append(retServers, server)
delete(serverMap, server.Server)
}
}
return retServers, err
}

// hasSameParametersAs checks if a given server has the same parameters.
func (s StreamUpstreamServer) hasSameParametersAs(compareServer StreamUpstreamServer) bool {
s.ID = compareServer.ID
s.applyDefaults()
compareServer.applyDefaults()
return reflect.DeepEqual(s, compareServer)
}

if serverNGX.MaxFails != nil && newServer.MaxFails == nil {
newServer.MaxFails = &defaultMaxFails
func (s *StreamUpstreamServer) applyDefaults() {
if s.MaxConns == nil {
s.MaxConns = &defaultMaxConns
}

if serverNGX.FailTimeout != "" && newServer.FailTimeout == "" {
newServer.FailTimeout = defaultFailTimeout
if s.MaxFails == nil {
s.MaxFails = &defaultMaxFails
}

if serverNGX.SlowStart != "" && newServer.SlowStart == "" {
newServer.SlowStart = defaultSlowStart
if s.FailTimeout == "" {
s.FailTimeout = defaultFailTimeout
}

if serverNGX.Backup != nil && newServer.Backup == nil {
newServer.Backup = &defaultBackup
if s.SlowStart == "" {
s.SlowStart = defaultSlowStart
}

if serverNGX.Down != nil && newServer.Down == nil {
newServer.Down = &defaultDown
if s.Backup == nil {
s.Backup = &defaultBackup
}

if serverNGX.Weight != nil && newServer.Weight == nil {
newServer.Weight = &defaultWeight
if s.Down == nil {
s.Down = &defaultDown
}

return reflect.DeepEqual(newServer, serverNGX)
if s.Weight == nil {
s.Weight = &defaultWeight
}
}

func determineStreamUpdates(updatedServers []StreamUpstreamServer, nginxServers []StreamUpstreamServer) (toAdd []StreamUpstreamServer, toRemove []StreamUpstreamServer, toUpdate []StreamUpstreamServer) {
for _, server := range updatedServers {
updateFound := false
for _, serverNGX := range nginxServers {
if server.Server == serverNGX.Server && !haveSameParametersForStream(server, serverNGX) {
if server.Server == serverNGX.Server && !server.hasSameParametersAs(serverNGX) {
server.ID = serverNGX.ID
updateFound = true
break
Expand Down Expand Up @@ -1950,9 +2048,13 @@ func (client *NginxClient) deleteKeyValPairs(ctx context.Context, zone string, s
return nil
}

// UpdateHTTPServer updates the server of the upstream.
// UpdateHTTPServer updates the server of the upstream with the matching server ID.
func (client *NginxClient) UpdateHTTPServer(ctx context.Context, upstream string, server UpstreamServer) error {
path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, server.ID)
// The server ID is expected in the URI, but not expected in the body.
// The NGINX API will return
// {"error":{"status":400,"text":"unknown parameter \"id\"","code":"UpstreamConfFormatError"}
// if the ID field is present.
server.ID = 0
err := client.patch(ctx, path, &server, http.StatusOK)
if err != nil {
Expand All @@ -1962,9 +2064,13 @@ func (client *NginxClient) UpdateHTTPServer(ctx context.Context, upstream string
return nil
}

// UpdateStreamServer updates the stream server of the upstream.
// UpdateStreamServer updates the stream server of the upstream with the matching server ID.
func (client *NginxClient) UpdateStreamServer(ctx context.Context, upstream string, server StreamUpstreamServer) error {
path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, server.ID)
// The server ID is expected in the URI, but not expected in the body.
// The NGINX API will return
// {"error":{"status":400,"text":"unknown parameter \"id\"","code":"UpstreamConfFormatError"}
// if the ID field is present.
server.ID = 0
err := client.patch(ctx, path, &server, http.StatusOK)
if err != nil {
Expand Down
Loading

0 comments on commit f625289

Please sign in to comment.