Skip to content

Commit

Permalink
Showing 4 changed files with 74 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -293,16 +293,16 @@ func (a *Azure) configFromCD() ([]byte, error) {
//
//nolint:gocyclo
func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan<- *runtime.PlatformNetworkConfig) error {
log.Printf("fetching azure instance config from: %q", AzureMetadataEndpoint)

metadata, err := a.getMetadata(ctx)
metadata, apiVersion, err := a.getMetadata(ctx)
if err != nil {
return err
}

log.Printf("fetching network config from %q", AzureInterfacesEndpoint)
interfacesEndpoint := fmt.Sprintf(AzureInterfacesEndpoint, apiVersion)

log.Printf("fetching network config from %q", interfacesEndpoint)

metadataNetworkConfig, err := download.Download(ctx, AzureInterfacesEndpoint,
metadataNetworkConfig, err := download.Download(ctx, interfacesEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}))
if err != nil {
return fmt.Errorf("failed to fetch network config from metadata service: %w", err)
@@ -319,11 +319,13 @@ func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan
return fmt.Errorf("failed to parse network metadata: %w", err)
}

log.Printf("fetching load balancer metadata from: %q", AzureLoadbalancerEndpoint)
loadbalancerEndpoint := fmt.Sprintf(AzureLoadbalancerEndpoint, apiVersion)

log.Printf("fetching load balancer metadata from: %q", loadbalancerEndpoint)

var loadBalancerAddresses LoadBalancerMetadata

lbConfig, err := download.Download(ctx, AzureLoadbalancerEndpoint,
lbConfig, err := download.Download(ctx, loadbalancerEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
download.WithErrorOnNotFound(errors.ErrNoConfigSource),
download.WithErrorOnEmptyResponse(errors.ErrNoConfigSource))
Original file line number Diff line number Diff line change
@@ -9,8 +9,8 @@ import (
"encoding/json"
stderrors "errors"
"fmt"
"log"

"github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/errors"
"github.com/siderolabs/talos/pkg/download"
)

@@ -19,15 +19,21 @@ const (
// ref: https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service
// ref: https://github.com/Azure/azure-rest-api-specs/blob/main/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2023-07-01/examples/GetInstanceMetadata.json

// AzureVersion is the version of the Azure metadata service.
AzureVersion = "2021-12-13"

// AzureVersionFallback is the fallback version of the Azure metadata service (e.g. Azure Stack Hub).
AzureVersionFallback = "2019-06-01"

// AzureInternalEndpoint is the Azure Internal Channel IP
// https://blogs.msdn.microsoft.com/mast/2015/05/18/what-is-the-ip-address-168-63-129-16/
AzureInternalEndpoint = "http://168.63.129.16"
// AzureMetadataEndpoint is the local endpoint for the metadata.
AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=2021-12-13&format=json"
AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=%s&format=json"
// AzureInterfacesEndpoint is the local endpoint to get external IPs.
AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=2021-12-13&format=json"
AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=%s&format=json"
// AzureLoadbalancerEndpoint is the local endpoint for load balancer config.
AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=2021-05-01&format=json"
AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=%s&format=json"

mnt = "/mnt"
)
@@ -54,18 +60,38 @@ type ComputeMetadata struct {
EvictionPolicy string `json:"evictionPolicy,omitempty"`
}

func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, error) {
metadataDl, err := download.Download(ctx, AzureMetadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}))
if err != nil && !stderrors.Is(err, errors.ErrNoHostname) {
return nil, fmt.Errorf("error fetching metadata: %w", err)
func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, string, error) {
apiVersion := AzureVersion
errBadRequest := stderrors.New("bad request")

metadataEndpoint := fmt.Sprintf(AzureMetadataEndpoint, apiVersion)

log.Printf("fetching azure instance config from: %q", metadataEndpoint)

metadataDl, err := download.Download(ctx, metadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
download.WithErrorOnBadRequest(errBadRequest),
)
if err != nil && stderrors.Is(err, errBadRequest) {
apiVersion = AzureVersionFallback
metadataEndpoint = fmt.Sprintf(AzureMetadataEndpoint, apiVersion)

log.Printf("fetching azure instance config from: %q", metadataEndpoint)

metadataDl, err = download.Download(ctx, metadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
)
}

if err != nil {
return nil, "", fmt.Errorf("error fetching metadata: %w", err)
}

var metadata ComputeMetadata

if err = json.Unmarshal(metadataDl, &metadata); err != nil {
return nil, fmt.Errorf("failed to parse compute metadata: %w", err)
return nil, "", fmt.Errorf("failed to parse compute metadata: %w", err)
}

return &metadata, nil
return &metadata, apiVersion, nil
}
13 changes: 13 additions & 0 deletions pkg/download/download.go
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ type downloadOptions struct {
EndpointFunc func(context.Context) (string, error)

ErrorOnNotFound error
ErrorOnBadRequest error
ErrorOnEmptyResponse error

Timeout time.Duration
@@ -108,6 +109,13 @@ func WithErrorOnEmptyResponse(e error) Option {
}
}

// WithErrorOnBadRequest provides specific error to return when response has HTTP 400 error.
func WithErrorOnBadRequest(e error) Option {
return func(d *downloadOptions) {
d.ErrorOnBadRequest = e
}
}

// WithEndpointFunc provides a function that sets the endpoint of the download options.
func WithEndpointFunc(endpointFunc func(context.Context) (string, error)) Option {
return func(d *downloadOptions) {
@@ -212,6 +220,7 @@ func Download(ctx context.Context, endpoint string, opts ...Option) (b []byte, e
return b, nil
}

//nolint:gocyclo
func download(req *http.Request, options *downloadOptions) (data []byte, err error) {
transport := httpdefaults.PatchTransport(cleanhttp.DefaultTransport())
transport.RegisterProtocol("tftp", NewTFTPTransport())
@@ -249,6 +258,10 @@ func download(req *http.Request, options *downloadOptions) (data []byte, err err
return data, options.ErrorOnNotFound
}

if resp.StatusCode == http.StatusBadRequest && options.ErrorOnBadRequest != nil {
return data, options.ErrorOnBadRequest
}

if resp.StatusCode != http.StatusOK {
// try to read first 32 bytes of the response body
// to provide more context in case of error
15 changes: 15 additions & 0 deletions pkg/download/download_test.go
Original file line number Diff line number Diff line change
@@ -53,6 +53,9 @@ func TestDownload(t *testing.T) {
case "/base64":
w.WriteHeader(http.StatusOK)
w.Write([]byte("ZGF0YQ==")) //nolint:errcheck
case "/400":
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintln(w, "bad request")
case "/404":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintln(w, "not found")
@@ -107,12 +110,24 @@ func TestDownload(t *testing.T) {
opts: []download.Option{download.WithErrorOnNotFound(errors.New("gone forever"))},
expectedError: "gone forever",
},
{
name: "bad request error",
path: "/400",
opts: []download.Option{download.WithErrorOnBadRequest(errors.New("bad req"))},
expectedError: "bad req",
},
{
name: "failure 404",
path: "/404",
opts: []download.Option{download.WithTimeout(2 * time.Second)},
expectedError: "failed to download config, status code 404, body \"not found\\n\"",
},
{
name: "failure 400",
path: "/400",
opts: []download.Option{download.WithTimeout(2 * time.Second)},
expectedError: "failed to download config, status code 400, body \"bad request\\n\"",
},
{
name: "retry endpoint change",
opts: []download.Option{

0 comments on commit 9a23d84

Please sign in to comment.