diff --git a/src/control/fault/code/codes.go b/src/control/fault/code/codes.go index 3eba1b3e8cd..e9d78fd6588 100644 --- a/src/control/fault/code/codes.go +++ b/src/control/fault/code/codes.go @@ -154,6 +154,7 @@ const ( ServerNoCompatibilityInsecure ServerPoolHasContainers ServerHugepagesDisabled + ServerBadFaultDomainLabels ) // server config fault codes diff --git a/src/control/server/config/faults.go b/src/control/server/config/faults.go index a48c1fe1863..b5128ffcb20 100644 --- a/src/control/server/config/faults.go +++ b/src/control/server/config/faults.go @@ -61,11 +61,6 @@ var ( "no DAOS IO Engines specified in configuration", "specify at least one IO Engine configuration ('engines' list parameter) and restart the control server", ) - FaultConfigFaultDomainInvalid = serverConfigFault( - code.ServerConfigFaultDomainInvalid, - "invalid fault domain", - "specify a valid fault domain ('fault_path' parameter) or callback script ('fault_cb' parameter) and restart the control server", - ) FaultConfigFaultCallbackNotFound = serverConfigFault( code.ServerConfigFaultCallbackNotFound, "fault domain callback script not found", @@ -113,6 +108,14 @@ var ( ) ) +func FaultConfigFaultDomainInvalid(err error) *fault.Fault { + return serverConfigFault( + code.ServerConfigFaultDomainInvalid, + fmt.Sprintf("invalid fault domain: %s", err.Error()), + "specify a valid fault domain ('fault_path' parameter) or callback script ('fault_cb' parameter) and restart the control server", + ) +} + func FaultConfigDuplicateFabric(curIdx, seenIdx int) *fault.Fault { return serverConfigFault( code.ServerConfigDuplicateFabric, diff --git a/src/control/server/faultdomain.go b/src/control/server/faultdomain.go index d7463bb630d..373803795f5 100644 --- a/src/control/server/faultdomain.go +++ b/src/control/server/faultdomain.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2023 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -54,8 +54,11 @@ func getFaultDomain(cfg *config.Server) (*system.FaultDomain, error) { func newFaultDomainFromConfig(domainStr string) (*system.FaultDomain, error) { fd, err := system.NewFaultDomainFromString(domainStr) - if err != nil || fd.NumLevels() == 0 { - return nil, config.FaultConfigFaultDomainInvalid + if err != nil { + return nil, config.FaultConfigFaultDomainInvalid(err) + } + if fd.NumLevels() == 0 { + return nil, config.FaultConfigFaultDomainInvalid(errors.New("at least one domain level is required")) } // TODO DAOS-6353: remove when multiple layers supported if fd.NumLevels() > 2 { diff --git a/src/control/server/faultdomain_test.go b/src/control/server/faultdomain_test.go index 51e2fbe29a5..bf168007a8d 100644 --- a/src/control/server/faultdomain_test.go +++ b/src/control/server/faultdomain_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2022 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -52,7 +52,7 @@ func TestServer_getDefaultFaultDomain(t *testing.T) { getHostname: func() (string, error) { return "/////////", nil }, - expErr: config.FaultConfigFaultDomainInvalid, + expErr: config.FaultConfigFaultDomainInvalid(errors.New("domain name \"\": empty string is an invalid fault domain")), }, } { t.Run(name, func(t *testing.T) { @@ -107,13 +107,13 @@ func TestServer_getFaultDomain(t *testing.T) { cfg: &config.Server{ FaultPath: "junk", }, - expErr: config.FaultConfigFaultDomainInvalid, + expErr: config.FaultConfigFaultDomainInvalid(errors.New("fault path must start with root (/)")), }, "root-only path is not valid": { cfg: &config.Server{ FaultPath: "/", }, - expErr: config.FaultConfigFaultDomainInvalid, + expErr: config.FaultConfigFaultDomainInvalid(errors.New("at least one domain level is required")), }, "too many layers": { // TODO DAOS-6353: change when arbitrary layers supported cfg: &config.Server{ @@ -284,11 +284,11 @@ func TestServer_getFaultDomainFromCallback(t *testing.T) { }, "script returned invalid fault domain": { scriptPath: invalidScriptPath, - expErr: config.FaultConfigFaultDomainInvalid, + expErr: config.FaultConfigFaultDomainInvalid(errors.New("fault path must start with root (/)")), }, "script returned root fault domain": { scriptPath: rootScriptPath, - expErr: config.FaultConfigFaultDomainInvalid, + expErr: config.FaultConfigFaultDomainInvalid(errors.New("at least one domain level is required")), }, "script returned fault domain with too many layers": { // TODO DAOS-6353: change when multiple layers supported scriptPath: multiLayerScriptPath, diff --git a/src/control/server/faults.go b/src/control/server/faults.go index cf7d4e2ade2..5a5526e36fa 100644 --- a/src/control/server/faults.go +++ b/src/control/server/faults.go @@ -171,6 +171,15 @@ func FaultNoCompatibilityInsecure(self, other build.Version) *fault.Fault { ) } +func FaultBadFaultDomainLabels(faultPath, addr string, reqLabels, systemLabels []string) *fault.Fault { + return serverFault( + code.ServerBadFaultDomainLabels, + fmt.Sprintf("labels in join request [%s] don't match system labels [%s] for server %s (fault path: %s)", + strings.Join(reqLabels, ", "), strings.Join(systemLabels, ", "), addr, faultPath), + "update the 'fault_path' or executable specified in 'fault_cb' in the affected server's configuration file to match the system labels", + ) +} + func serverFault(code code.Code, desc, res string) *fault.Fault { return &fault.Fault{ Domain: "server", diff --git a/src/control/server/mgmt_system.go b/src/control/server/mgmt_system.go index a35753d95ba..536b3397a42 100644 --- a/src/control/server/mgmt_system.go +++ b/src/control/server/mgmt_system.go @@ -36,6 +36,7 @@ import ( "github.com/daos-stack/daos/src/control/lib/hostlist" "github.com/daos-stack/daos/src/control/lib/ranklist" "github.com/daos-stack/daos/src/control/logging" + "github.com/daos-stack/daos/src/control/server/config" "github.com/daos-stack/daos/src/control/system" "github.com/daos-stack/daos/src/control/system/checker" "github.com/daos-stack/daos/src/control/system/raft" @@ -43,6 +44,9 @@ import ( const fabricProviderProp = "fabric_providers" const groupUpdatePauseProp = "group_update_paused" +const domainLabelsProp = "domain_labels" + +const domainLabelsSep = "=" // invalid in a label name // GetAttachInfo handles a request to retrieve a map of ranks to fabric URIs, in addition // to client network autoconfiguration hints. @@ -182,9 +186,9 @@ func (svc *mgmtSvc) join(ctx context.Context, req *mgmtpb.JoinReq, peerAddr *net return nil, errors.Wrapf(err, "invalid uuid %q", req.Uuid) } - fd, err := system.NewFaultDomainFromString(req.SrvFaultDomain) + fd, err := svc.verifyFaultDomain(req) if err != nil { - return nil, errors.Wrapf(err, "invalid server fault domain %q", req.SrvFaultDomain) + return nil, err } if err := svc.checkReqFabricProvider(req, peerAddr, svc.events); err != nil { @@ -255,6 +259,67 @@ func (svc *mgmtSvc) join(ctx context.Context, req *mgmtpb.JoinReq, peerAddr *net return resp, nil } +func (svc *mgmtSvc) verifyFaultDomain(req *mgmtpb.JoinReq) (*system.FaultDomain, error) { + fd, err := system.NewFaultDomainFromString(req.SrvFaultDomain) + if err != nil { + return nil, config.FaultConfigFaultDomainInvalid(err) + } + + if fd.Empty() { + return nil, errors.New("no fault domain in join request") + } + + labels := fd.Labels + if !fd.HasLabels() { + // While saving the labels, an unlabeled fault domain sets the labels to empty + // strings. This allows us to distinguish between unset and unlabeled. + labels = make([]string, fd.NumLevels()) + } + + sysLabels, err := svc.getDomainLabels() + if system.IsErrSystemAttrNotFound(err) { + svc.log.Debugf("setting fault domain labels for the first time: %+v", labels) + if err := svc.setDomainLabels(labels); err != nil { + return nil, errors.Wrap(err, "failed to set fault domain labels") + } + return fd, nil + } + if err != nil { + return nil, errors.Wrap(err, "failed to get current fault domain labels") + } + + // If system labels are all empty strings, that indicates an unlabeled system. In errors + // and logging, clearer to present this as a completely empty array. + var printSysLabels []string + if sysLabels[0] != "" { + printSysLabels = sysLabels + } + + svc.log.Tracef("system labels: [%s], request labels: [%s]", strings.Join(printSysLabels, ", "), strings.Join(labels, ", ")) + if len(sysLabels) != len(labels) { + return nil, FaultBadFaultDomainLabels(req.SrvFaultDomain, req.Uri, fd.Labels, printSysLabels) + } + for i := range sysLabels { + if labels[i] != sysLabels[i] { + return nil, FaultBadFaultDomainLabels(req.SrvFaultDomain, req.Uri, fd.Labels, printSysLabels) + } + } + return fd, nil +} + +func (svc *mgmtSvc) getDomainLabels() ([]string, error) { + propStr, err := system.GetMgmtProperty(svc.sysdb, domainLabelsProp) + if err != nil { + return nil, err + } + return strings.Split(propStr, domainLabelsSep), nil +} + +func (svc *mgmtSvc) setDomainLabels(labels []string) error { + propStr := strings.Join(labels, domainLabelsSep) + return system.SetMgmtProperty(svc.sysdb, domainLabelsProp, propStr) +} + // allRanksJoined checks whether all ranks that the system knows about, and that are not admin // excluded, are joined. // diff --git a/src/control/server/mgmt_system_test.go b/src/control/server/mgmt_system_test.go index 0cffe7f63d8..0095cd3c87f 100644 --- a/src/control/server/mgmt_system_test.go +++ b/src/control/server/mgmt_system_test.go @@ -25,6 +25,7 @@ import ( "github.com/daos-stack/daos/src/control/build" "github.com/daos-stack/daos/src/control/common" + "github.com/daos-stack/daos/src/control/common/proto/mgmt" mgmtpb "github.com/daos-stack/daos/src/control/common/proto/mgmt" sharedpb "github.com/daos-stack/daos/src/control/common/proto/shared" "github.com/daos-stack/daos/src/control/common/test" @@ -1974,7 +1975,7 @@ func TestServer_MgmtSvc_Join(t *testing.T) { req: &mgmtpb.JoinReq{ SrvFaultDomain: "bad fault domain", }, - expErr: errors.New("bad fault domain"), + expErr: errors.New("invalid fault domain"), }, "dupe host same rank diff uuid": { req: &mgmtpb.JoinReq{ @@ -2366,6 +2367,154 @@ func TestServer_MgmtSvc_doGroupUpdate(t *testing.T) { } } +func TestMgmtSvc_verifyFaultDomain(t *testing.T) { + testURI := "tcp://localhost:10001" + for name, tc := range map[string]struct { + getSvc func(*testing.T, logging.Logger) *mgmtSvc + curLabels []string + req *mgmtpb.JoinReq + expFaultDomain *system.FaultDomain + expErr error + expLabels []string + }{ + "no fault domain": { + req: &mgmtpb.JoinReq{}, + expErr: errors.New("no fault domain"), + }, + "invalid fault domain": { + req: &mgmtpb.JoinReq{SrvFaultDomain: "junk"}, + expErr: errors.New("invalid fault domain"), + }, + "failed to get system domain labels": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + svc := newTestMgmtSvcMulti(t, log, maxEngines, false) + // not a replica + svc.sysdb = raft.MockDatabaseWithCfg(t, log, &raft.DatabaseConfig{ + SystemName: build.DefaultSystemName, + }) + + return svc + }, + req: &mgmt.JoinReq{SrvFaultDomain: "/rack=r1/node=n2"}, + expErr: &system.ErrNotReplica{}, + }, + "failed to set system domain labels": { + getSvc: func(t *testing.T, log logging.Logger) *mgmtSvc { + svc := newTestMgmtSvcMulti(t, log, maxEngines, true) + svc.sysdb = raft.MockDatabaseWithCfg(t, log, &raft.DatabaseConfig{ + SystemName: build.DefaultSystemName, + Replicas: []*net.TCPAddr{common.LocalhostCtrlAddr()}, + }) + if err := svc.sysdb.ResignLeadership(errors.New("test")); err != nil { + t.Fatal(err) + } + + return svc + }, + req: &mgmt.JoinReq{SrvFaultDomain: "/rack=r1/node=n2"}, + expErr: &system.ErrNotLeader{}, + }, + "first success with labels": { + req: &mgmt.JoinReq{SrvFaultDomain: "/rack=r1/node=n2"}, + expFaultDomain: system.MustCreateFaultDomainFromString("/rack=r1/node=n2"), + expLabels: []string{"rack", "node"}, + }, + "first success with no labels": { + req: &mgmt.JoinReq{SrvFaultDomain: "/r1/n2"}, + expFaultDomain: system.MustCreateFaultDomainFromString("/r1/n2"), + expLabels: []string{"", ""}, + }, + "success with labels": { + curLabels: []string{"rack", "node"}, + req: &mgmt.JoinReq{SrvFaultDomain: "/rack=r1/node=n2"}, + expFaultDomain: system.MustCreateFaultDomainFromString("/rack=r1/node=n2"), + expLabels: []string{"rack", "node"}, + }, + "success with no labels": { + curLabels: []string{"", ""}, + req: &mgmt.JoinReq{SrvFaultDomain: "/r1/n2"}, + expFaultDomain: system.MustCreateFaultDomainFromString("/r1/n2"), + expLabels: []string{"", ""}, + }, + "labeled request with unlabeled system": { + curLabels: []string{"", ""}, + req: &mgmt.JoinReq{ + SrvFaultDomain: "/rack=r1/node=n2", + Uri: testURI, + }, + expErr: FaultBadFaultDomainLabels("/rack=r1/node=n2", testURI, []string{"rack", "node"}, nil), + expLabels: []string{"", ""}, + }, + "unlabeled request with labeled system": { + curLabels: []string{"rack", "node"}, + req: &mgmt.JoinReq{ + SrvFaultDomain: "/r1/n2", + Uri: testURI, + }, + expErr: FaultBadFaultDomainLabels("/r1/n2", testURI, nil, []string{"rack", "node"}), + expLabels: []string{"rack", "node"}, + }, + "mismatched labels": { + curLabels: []string{"rack", "node"}, + req: &mgmt.JoinReq{ + SrvFaultDomain: "/rack=r1/host=n2", + Uri: testURI, + }, + expErr: FaultBadFaultDomainLabels("/rack=r1/host=n2", testURI, []string{"rack", "host"}, []string{"rack", "node"}), + expLabels: []string{"rack", "node"}, + }, + "mismatched length": { + curLabels: []string{"rack"}, + req: &mgmt.JoinReq{ + SrvFaultDomain: "/rack=r1/node=n2", + Uri: testURI, + }, + expErr: FaultBadFaultDomainLabels("/rack=r1/node=n2", testURI, []string{"rack", "node"}, []string{"rack"}), + expLabels: []string{"rack"}, + }, + } { + t.Run(name, func(t *testing.T) { + log, buf := logging.NewTestLogger(t.Name()) + defer test.ShowBufferOnFailure(t, buf) + + if tc.getSvc == nil { + tc.getSvc = func(t *testing.T, l logging.Logger) *mgmtSvc { + svc := mgmtSystemTestSetup(t, l, + system.Members{ + mockMember(t, 1, 1, "stopped"), + mockMember(t, 2, 2, "stopped"), + }, + []*control.HostResponse{}) + return svc + } + } + svc := tc.getSvc(t, log) + if tc.curLabels != nil { + if err := svc.setDomainLabels(tc.curLabels); err != nil { + t.Fatal(err) + } + } + + fd, err := svc.verifyFaultDomain(tc.req) + + test.CmpErr(t, tc.expErr, err) + test.AssertTrue(t, fd.Equals(tc.expFaultDomain), fmt.Sprintf("want %q, got %q", tc.expFaultDomain, fd)) + + if tc.expLabels == nil { + return + } + + newLabels, labelErr := svc.getDomainLabels() + if len(tc.expLabels) == 0 { + test.AssertTrue(t, system.IsErrSystemAttrNotFound(labelErr), "") + } else if labelErr != nil { + t.Fatal(labelErr) + } + test.CmpAny(t, "", tc.expLabels, newLabels) + }) + } +} + func TestMgmtSvc_updateFabricProviders(t *testing.T) { for name, tc := range map[string]struct { getSvc func(*testing.T, logging.Logger) *mgmtSvc diff --git a/src/control/system/faultdomain.go b/src/control/system/faultdomain.go index 43ee8644c6f..02de44ab5d3 100644 --- a/src/control/system/faultdomain.go +++ b/src/control/system/faultdomain.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2021 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -7,12 +7,13 @@ package system import ( - "errors" "fmt" "io" "math" "sort" "strings" + + "github.com/pkg/errors" ) const ( @@ -25,11 +26,15 @@ const ( // FaultDomainRootID is the ID of the root node FaultDomainRootID = 1 + + // FaultDomainLabelAssign is used to assign a label to a domain layer. + FaultDomainLabelAssign = "=" ) // FaultDomain represents a multi-layer fault domain. type FaultDomain struct { - Domains []string // Hierarchical sequence of fault domain levels + Domains []string `json:"domains"` // Hierarchical sequence of fault domain levels + Labels []string `json:"labels,omitempty"` // Labels for each layer of the domain (optional) } func (f *FaultDomain) String() string { @@ -39,7 +44,33 @@ func (f *FaultDomain) String() string { if f.Empty() { return FaultDomainSeparator } - return FaultDomainSeparator + strings.Join(f.Domains, FaultDomainSeparator) + return FaultDomainSeparator + strings.Join(f.DomainStrings(), FaultDomainSeparator) +} + +// DomainStrings returns the array of domain strings, including labels if applicable. +func (f *FaultDomain) DomainStrings() []string { + if f == nil || f.Empty() { + return []string{} + } + if f.HasLabels() { + levels := make([]string, 0, len(f.Domains)) + for i := range f.Domains { + levels = append(levels, fmt.Sprintf("%s%s%s", f.Labels[i], FaultDomainLabelAssign, f.Domains[i])) + } + return levels + } + return f.Domains +} + +// HasLabels checks whether the FaultDomain has associated labels for every level. +func (f *FaultDomain) HasLabels() bool { + if f == nil || f.Empty() { + return false + } + if len(f.Labels) == len(f.Domains) { + return true + } + return false } // Equals checks if the fault domains are equal. @@ -59,6 +90,14 @@ func (f *FaultDomain) Equals(other *FaultDomain) bool { if other.Domains[i] != dom { return false } + if other.HasLabels() || f.HasLabels() { + if other.HasLabels() != f.HasLabels() { + return false + } + if other.Labels[i] != f.Labels[i] { + return false + } + } } return true } @@ -82,7 +121,11 @@ func (f *FaultDomain) Level(level int) (string, error) { if level < 0 || level >= f.NumLevels() { return "", errors.New("out of range") } - return f.Domains[f.topLevelIdx()-level], nil + return f.Domains[f.levelIdx(level)], nil +} + +func (f *FaultDomain) levelIdx(level int) int { + return f.topLevelIdx() - level } func (f *FaultDomain) topLevelIdx() int { @@ -102,6 +145,17 @@ func (f *FaultDomain) TopLevel() string { return top } +// GetLabel returns the label for the bottom layer of the fault domain. +func (f *FaultDomain) GetLabel() string { + if f == nil { + return "(nil)" + } + if f.Empty() || !f.HasLabels() { + return "" + } + return f.Labels[f.levelIdx(0)] +} + // IsAncestorOf determines if this fault domain is an ancestor of the one passed in. // This includes a 0th-degree ancestor (i.e. it is identical). func (f *FaultDomain) IsAncestorOf(d *FaultDomain) bool { @@ -146,22 +200,74 @@ func (f *FaultDomain) MustCreateChild(childLevel string) *FaultDomain { return child } +func errFaultDomainSpecialChar(specialChar string) error { + return fmt.Errorf("invalid fault domain contains special character %q", specialChar) +} + // NewFaultDomain creates a FaultDomain from a sequence of strings representing // individual levels of the domain. // For each level of the domain, we assume case insensitivity and trim // leading/trailing whitespace. func NewFaultDomain(domains ...string) (*FaultDomain, error) { - for i := range domains { - domains[i] = strings.TrimSpace(domains[i]) - if domains[i] == "" || strings.Contains(domains[i], FaultDomainSeparator) { - return nil, errors.New("invalid fault domain") + fd := &FaultDomain{} + + normalize := func(s string) (string, error) { + // strip out all the spaces and quote marks without worrying about quote mark matching. + // e.g. " string1"" becomes string1 + for prev := s; ; prev = s { + s = strings.TrimSpace(s) + s = strings.Trim(s, "\"") + if s == prev { + break + } + } + if s == "" { + return "", fmt.Errorf("empty string is an invalid fault domain") + } + if strings.Contains(s, FaultDomainSeparator) { + return "", errFaultDomainSpecialChar(FaultDomainSeparator) + } + if strings.Contains(s, FaultDomainLabelAssign) { + return "", errFaultDomainSpecialChar(FaultDomainLabelAssign) + } + return strings.ToLower(s), nil + } + + var useLabels bool + for i, d := range domains { + var label string + var dom string + var err error + parts := strings.SplitN(d, "=", 2) + if len(parts) == 1 { + if i != 0 && useLabels { + return nil, fmt.Errorf("layer %d (%s) has no label, but other layers include labels", i, d) + } + + dom = parts[0] + } else { + if i == 0 { + useLabels = true + } else if !useLabels { + return nil, fmt.Errorf("layer %d (%s) has a label, but other layers don't include labels", i, d) + } + label = parts[0] + dom = parts[1] + + if label, err = normalize(label); err != nil { + return nil, errors.Wrapf(err, "domain label %q", label) + } + + fd.Labels = append(fd.Labels, label) } - domains[i] = strings.ToLower(domains[i]) + + if dom, err = normalize(dom); err != nil { + return nil, errors.Wrapf(err, "domain name %q", dom) + } + fd.Domains = append(fd.Domains, dom) } - return &FaultDomain{ - Domains: domains, - }, nil + return fd, nil } // MustCreateFaultDomain creates a FaultDomain from a sequence of strings @@ -189,7 +295,7 @@ func NewFaultDomainFromString(domainStr string) (*FaultDomain, error) { } if !strings.HasPrefix(domainStr, FaultDomainSeparator) { - return nil, errors.New("invalid fault domain") + return nil, errors.New("fault path must start with root (/)") } domains := strings.Split(domainStr, FaultDomainSeparator) @@ -214,9 +320,9 @@ type ( // This tree structure is not thread-safe and callers are expected to // add access synchronization if needed. FaultDomainTree struct { - Domain *FaultDomain - ID uint32 - Children []*FaultDomainTree + Domain *FaultDomain `json:"domain"` + ID uint32 `json:"id"` + Children []*FaultDomainTree `json:"children"` } ) @@ -253,6 +359,43 @@ func (t *FaultDomainTree) nextID() uint32 { return nextID } +// GetLabel gets the label for the top level of the FaultDomainTree. +func (t *FaultDomainTree) GetLabel() string { + if t == nil { + return "(nil)" + } + if t.IsRoot() { + return "" + } + return t.Domain.GetLabel() +} + +// Labels returns the sequence of non-root, labels from the top of this tree to the bottom. +// NB: This method assumes a balanced tree, and does not search for the longest branch when collecting the labels. +func (t *FaultDomainTree) Labels() ([]string, error) { + if t == nil { + return nil, errors.New("nil FaultDomainTree") + } + list := t.labels() + hasLabels := len(list) > 0 && list[0] != "" + if !hasLabels { + return []string{}, nil + } + return list, nil +} + +func (t *FaultDomainTree) labels() []string { + labelList := []string{} + if !t.IsRoot() { + labelList = append(labelList, t.GetLabel()) + } + if t.IsLeaf() { + return labelList + } + // all children must have the same label + return append(labelList, t.Children[0].labels()...) +} + // AddDomain adds a child fault domain, including intermediate nodes, to the // fault domain tree. func (t *FaultDomainTree) AddDomain(domain *FaultDomain) error { @@ -283,35 +426,47 @@ func (t *FaultDomainTree) Merge(t2 *FaultDomainTree) error { return nil // nothing to do } - // To merge, tree domains must match at the top. if !t.Domain.Equals(t2.Domain) { - return errors.New("trees cannot be merged") + return fmt.Errorf("FaultDomainTrees don't share a root, so cannot be merged") } nextID := t.nextID() - t.mergeTree(t2, &nextID) - return nil + return t.mergeTree(t2, &nextID, t.labels()) } -func (t *FaultDomainTree) mergeTree(toBeMerged *FaultDomainTree, nextID *uint32) { +func (t *FaultDomainTree) mergeTree(toBeMerged *FaultDomainTree, nextID *uint32, labels []string) error { for _, m := range toBeMerged.Children { foundBranch := false for _, p := range t.Children { if p.Domain.Equals(m.Domain) { foundBranch = true - p.mergeTree(m, nextID) + if err := p.mergeTree(m, nextID, labels[1:]); err != nil { + return err + } break } + if p.Domain.HasLabels() != m.Domain.HasLabels() { + if m.Domain.HasLabels() { + return errors.New("cannot merge a fault domain tree with labels into one with no labels") + } + return errors.New("cannot merge a fault domain tree with no labels into one with labels") + } + if len(labels) > 0 && m.GetLabel() != labels[0] { + return fmt.Errorf("cannot merge a fault domain tree with label %q at level with different existing label %q", m.GetLabel(), labels[0]) + } } if !foundBranch { if nextID != nil { m.updateAllIDs(nextID) } + if err := t.verifyChildLabels(m, labels); err != nil { + return err + } t.addChild(m) } } - return + return nil } func (t *FaultDomainTree) updateAllIDs(nextID *uint32) { @@ -322,6 +477,19 @@ func (t *FaultDomainTree) updateAllIDs(nextID *uint32) { } } +func (t *FaultDomainTree) verifyChildLabels(child *FaultDomainTree, labels []string) error { + if len(labels) == 0 { + return nil // nothing to check + } + if child.GetLabel() != labels[0] { + return fmt.Errorf("child tree with label %q at level with different existing label %q", child.GetLabel(), labels[0]) + } + if child.IsLeaf() { + return nil + } + return t.verifyChildLabels(child.Children[0], labels[1:]) +} + func (t *FaultDomainTree) addChild(child *FaultDomainTree) { t.Children = append(t.Children, child) sort.Slice(t.Children, func(i, j int) bool { @@ -530,7 +698,7 @@ func NewFaultDomainTree(domains ...*FaultDomain) *FaultDomainTree { nextID := tree.ID + 1 for _, d := range domains { subtree := faultDomainTreeFromDomain(d) - tree.mergeTree(subtree, &nextID) + tree.mergeTree(subtree, &nextID, tree.labels()) } return tree } @@ -540,8 +708,9 @@ func faultDomainTreeFromDomain(d *FaultDomain) *FaultDomainTree { nextID := tree.ID + 1 if !d.Empty() { node := tree + domainStrs := d.DomainStrings() for i := 0; i < d.NumLevels(); i++ { - childDomain := MustCreateFaultDomain(d.Domains[:i+1]...) + childDomain := MustCreateFaultDomain(domainStrs[:i+1]...) child := NewFaultDomainTree(). WithNodeDomain(childDomain). WithID(nextID) diff --git a/src/control/system/faultdomain_test.go b/src/control/system/faultdomain_test.go index 7011fa7b436..2f38c7bea3f 100644 --- a/src/control/system/faultdomain_test.go +++ b/src/control/system/faultdomain_test.go @@ -1,5 +1,5 @@ // -// (C) Copyright 2020-2022 Intel Corporation. +// (C) Copyright 2020-2024 Intel Corporation. // // SPDX-License-Identifier: BSD-2-Clause-Patent // @@ -27,19 +27,19 @@ func TestSystem_NewFaultDomain(t *testing.T) { }, "empty strings": { input: []string{"ok", ""}, - expErr: errors.New("invalid fault domain"), + expErr: errors.New("empty string"), }, "explicit root": { input: []string{"/"}, - expErr: errors.New("invalid fault domain"), + expErr: errors.New("special character"), }, "whitespace-only strings": { input: []string{"ok", "\t "}, - expErr: errors.New("invalid fault domain"), + expErr: errors.New("empty string"), }, "name contains separator": { input: []string{"ok", "alpha/beta"}, - expErr: errors.New("invalid fault domain"), + expErr: errors.New("special character"), }, "single-level": { input: []string{"ok"}, @@ -65,13 +65,47 @@ func TestSystem_NewFaultDomain(t *testing.T) { Domains: []string{"ok", "go"}, }, }, + "labels": { + input: []string{"First=ok", "second=go"}, + expResult: &FaultDomain{ + Domains: []string{"ok", "go"}, + Labels: []string{"first", "second"}, + }, + }, + "strip quote marks": { + input: []string{"first=\"ok\"", "SECOND=go"}, + expResult: &FaultDomain{ + Domains: []string{"ok", "go"}, + Labels: []string{"first", "second"}, + }, + }, + "empty label": { + input: []string{"=ok", "second=go"}, + expErr: errors.New("empty string"), + }, + "empty label with quote marks": { + input: []string{" \" \"\"=ok", "second=go"}, + expErr: errors.New("empty string"), + }, + "missing first label": { + input: []string{"ok", "second=go"}, + expErr: errors.New("has a label"), + }, + "missing second label": { + input: []string{"first=ok", "go"}, + expErr: errors.New("has no label"), + }, + "domain contains label assigner": { + input: []string{"first=ok", "second=go=back"}, + expErr: errors.New("special character"), + }, } { t.Run(name, func(t *testing.T) { result, err := NewFaultDomain(tc.input...) test.CmpErr(t, tc.expErr, err) - if diff := cmp.Diff(result, tc.expResult); diff != "" { + if diff := cmp.Diff(tc.expResult, result); diff != "" { t.Fatalf("(-want, +got): %s", diff) } }) @@ -130,7 +164,7 @@ func TestSystem_NewFaultDomainFromString(t *testing.T) { }, "fault domain doesn't start with separator": { input: "junk", - expErr: errors.New("invalid fault domain"), + expErr: errors.New("fault path must start with root"), }, "fault domain ends with separator": { input: "/junk/", @@ -144,6 +178,13 @@ func TestSystem_NewFaultDomainFromString(t *testing.T) { input: "/dc0/ /pdu2/host", expErr: errors.New("invalid fault domain"), }, + "labels": { + input: "/datacenter=dc0/rack=rack1/power=pdu2/node=host", + expResult: &FaultDomain{ + Domains: []string{"dc0", "rack1", "pdu2", "host"}, + Labels: []string{"datacenter", "rack", "power", "node"}, + }, + }, } { t.Run(name, func(t *testing.T) { result, err := NewFaultDomainFromString(tc.input) @@ -181,6 +222,13 @@ func TestSystem_FaultDomain_String(t *testing.T) { }, expStr: "/rack0/pdu1/host", }, + "labels": { + domain: &FaultDomain{ + Domains: []string{"rack0", "pdu1", "host"}, + Labels: []string{"rack", "power", "node"}, + }, + expStr: "/rack=rack0/power=pdu1/node=host", + }, } { t.Run(name, func(t *testing.T) { test.AssertEqual(t, tc.expStr, tc.domain.String(), "unexpected result") @@ -188,6 +236,84 @@ func TestSystem_FaultDomain_String(t *testing.T) { } } +func TestSystem_FaultDomain_DomainStrings(t *testing.T) { + for name, tc := range map[string]struct { + domain *FaultDomain + expStr []string + }{ + "nil": { + expStr: []string{}, + }, + "empty": { + domain: &FaultDomain{}, + expStr: []string{}, + }, + "single level": { + domain: &FaultDomain{ + Domains: []string{"host"}, + }, + expStr: []string{"host"}, + }, + "multi level": { + domain: &FaultDomain{ + Domains: []string{"rack0", "pdu1", "host"}, + }, + expStr: []string{"rack0", "pdu1", "host"}, + }, + "labels": { + domain: &FaultDomain{ + Domains: []string{"rack0", "pdu1", "host"}, + Labels: []string{"rack", "power", "node"}, + }, + expStr: []string{"rack=rack0", "power=pdu1", "node=host"}, + }, + } { + t.Run(name, func(t *testing.T) { + test.AssertEqual(t, tc.expStr, tc.domain.DomainStrings(), "unexpected result") + }) + } +} + +func TestSystem_FaultDomain_HasLabels(t *testing.T) { + for name, tc := range map[string]struct { + domain *FaultDomain + expResult bool + }{ + "nil": {}, + "empty": { + domain: &FaultDomain{}, + }, + "matching labels for domains": { + domain: &FaultDomain{ + Domains: []string{"a", "b", "c"}, + Labels: []string{"1", "2", "3"}, + }, + expResult: true, + }, + "no labels for domains": { + domain: &FaultDomain{ + Domains: []string{"a", "b", "c"}, + }, + }, + "not enough labels for domains": { + domain: &FaultDomain{ + Domains: []string{"a", "b", "c"}, + Labels: []string{"1", "2"}, + }, + }, + "too many labels for domains": { // length mismatch means we don't know how to match labels + domain: &FaultDomain{ + Domains: []string{"a", "b", "c"}, + Labels: []string{"1", "2", "3", "4"}, + }, + }, + } { + t.Run(name, func(t *testing.T) { + test.AssertEqual(t, tc.expResult, tc.domain.HasLabels(), "") + }) + } +} + func TestSystem_FaultDomain_Equals(t *testing.T) { for name, tc := range map[string]struct { domain1 *FaultDomain @@ -246,6 +372,38 @@ func TestSystem_FaultDomain_Equals(t *testing.T) { }, expResult: false, }, + "label matching": { + domain1: &FaultDomain{ + Domains: []string{"one", "two"}, + Labels: []string{"l1", "l2"}, + }, + domain2: &FaultDomain{ + Domains: []string{"one", "two"}, + Labels: []string{"l1", "l2"}, + }, + expResult: true, + }, + "label doesn't match": { + domain1: &FaultDomain{ + Domains: []string{"one", "two"}, + Labels: []string{"l1", "l2"}, + }, + domain2: &FaultDomain{ + Domains: []string{"one", "two"}, + Labels: []string{"l1", "l3"}, + }, + expResult: false, + }, + "labels vs no labels": { + domain1: &FaultDomain{ + Domains: []string{"one", "two"}, + Labels: []string{"l1", "l2"}, + }, + domain2: &FaultDomain{ + Domains: []string{"one", "two"}, + }, + expResult: false, + }, } { t.Run(name, func(t *testing.T) { test.AssertEqual(t, tc.domain1.Equals(tc.domain2), tc.expResult, "domain1.Equals failed") @@ -513,6 +671,35 @@ func TestSystem_FaultDomain_IsAncestorOf(t *testing.T) { } } +func TestSystem_FaultDomain_GetLabel(t *testing.T) { + for name, tc := range map[string]struct { + fd *FaultDomain + expResult string + }{ + "nil": { + expResult: "(nil)", + }, + "empty": { + fd: MustCreateFaultDomain(), + }, + "unlabeled": { + fd: MustCreateFaultDomain("one", "two", "three"), + }, + "single layer": { + fd: MustCreateFaultDomain("layer1=one"), + expResult: "layer1", + }, + "multi layer": { + fd: MustCreateFaultDomain("layer1=one", "layer2=two"), + expResult: "layer2", + }, + } { + t.Run(name, func(t *testing.T) { + test.AssertEqual(t, tc.expResult, tc.fd.GetLabel(), "") + }) + } +} + func TestSystem_FaultDomain_NewChild(t *testing.T) { for name, tc := range map[string]struct { orig *FaultDomain @@ -562,6 +749,30 @@ func TestSystem_FaultDomain_NewChild(t *testing.T) { childLevel: "/", expErr: errors.New("invalid fault domain"), }, + "with labels": { + orig: &FaultDomain{ + Domains: []string{"one=parent"}, + }, + childLevel: "two=child", + expResult: &FaultDomain{ + Domains: []string{"parent", "child"}, + Labels: []string{"one", "two"}, + }, + }, + "unlabeled child with labeled parent": { + orig: &FaultDomain{ + Domains: []string{"one=parent"}, + }, + childLevel: "child", + expErr: errors.New("labels"), + }, + "labeled child with unlabeled parent": { + orig: &FaultDomain{ + Domains: []string{"parent"}, + }, + childLevel: "two=child", + expErr: errors.New("labels"), + }, } { t.Run(name, func(t *testing.T) { result, err := tc.orig.NewChild(tc.childLevel) @@ -820,6 +1031,32 @@ func TestSystem_NewFaultDomainTree(t *testing.T) { }, }, }, + "multi-layer with labels": { + domains: []*FaultDomain{MustCreateFaultDomainFromString("/rack=r1/pdu=p2/node=n3")}, + expResult: &FaultDomainTree{ + Domain: MustCreateFaultDomain(), + ID: FaultDomainRootID, + Children: []*FaultDomainTree{ + { + Domain: MustCreateFaultDomainFromString("/rack=r1"), + ID: ExpFaultDomainID(1), + Children: []*FaultDomainTree{ + { + Domain: MustCreateFaultDomainFromString("/rack=r1/pdu=p2"), + ID: ExpFaultDomainID(2), + Children: []*FaultDomainTree{ + { + Domain: MustCreateFaultDomainFromString("/rack=r1/pdu=p2/node=n3"), + ID: ExpFaultDomainID(3), + Children: []*FaultDomainTree{}, + }, + }, + }, + }, + }, + }, + }, + }, } { t.Run(name, func(t *testing.T) { result := NewFaultDomainTree(tc.domains...) @@ -985,6 +1222,69 @@ func TestSystem_FaultDomainTree_nextID(t *testing.T) { } } +func TestSystem_FaultDomainTree_GetLabel(t *testing.T) { + for name, tc := range map[string]struct { + tree *FaultDomainTree + expResult string + }{ + "nil tree": { + expResult: "(nil)", + }, + "nil domain": { + tree: &FaultDomainTree{}, + }, + "root": { + tree: NewFaultDomainTree(), + }, + "labeled": { + tree: &FaultDomainTree{ + Domain: MustCreateFaultDomain("rack=r1"), + }, + expResult: "rack", + }, + "unlabeled": { + tree: &FaultDomainTree{ + Domain: MustCreateFaultDomain("r1"), + }, + }, + } { + t.Run(name, func(t *testing.T) { + test.AssertEqual(t, tc.expResult, tc.tree.GetLabel(), "") + }) + } +} + +func TestSystem_FaultDomainTree_Labels(t *testing.T) { + for name, tc := range map[string]struct { + tree *FaultDomainTree + expResult []string + expErr error + }{ + "nil": { + expErr: errors.New("nil"), + }, + "empty tree": { + tree: NewFaultDomainTree(), + expResult: []string{}, + }, + "labeled tree": { + tree: NewFaultDomainTree(MustCreateFaultDomainFromString("/room=lab100/rack=123/node=host456")), + expResult: []string{"room", "rack", "node"}, + }, + "unlabeled tree": { + tree: NewFaultDomainTree(MustCreateFaultDomainFromString("/lab100/123/host456")), + expResult: []string{}, + }, + } { + t.Run(name, func(t *testing.T) { + result, err := tc.tree.Labels() + + test.CmpErr(t, tc.expErr, err) + test.AssertEqual(t, tc.expResult, result, "") + }) + } +} + func TestSystem_FaultDomainTree_AddDomain(t *testing.T) { single := MustCreateFaultDomain("rack0") multi := single.MustCreateChild("node1") @@ -1191,7 +1491,7 @@ func TestSystem_FaultDomainTree_Merge(t *testing.T) { "different top level domains can't merge": { tree: NewFaultDomainTree(), toMerge: NewFaultDomainTree().WithNodeDomain(rack0), - expErr: errors.New("trees cannot be merged"), + expErr: errors.New("cannot be merged"), }, "merge single branch into empty tree": { tree: NewFaultDomainTree(), @@ -1264,6 +1564,31 @@ func TestSystem_FaultDomainTree_Merge(t *testing.T) { toMerge: fullTree(), expResult: fullTree(), }, + "merge labeled into empty tree": { + tree: NewFaultDomainTree(), + toMerge: NewFaultDomainTree(MustCreateFaultDomain("rack=rack1")), + expResult: NewFaultDomainTree(MustCreateFaultDomain("rack=rack1")), + }, + "merge labeled into unlabeled tree": { + tree: NewFaultDomainTree(rack0), + toMerge: NewFaultDomainTree(MustCreateFaultDomain("rack=rack1")), + expErr: errors.New("with labels into one with no labels"), + }, + "merge unlabeled into labeled tree": { + tree: NewFaultDomainTree(MustCreateFaultDomain("rack=rack1")), + toMerge: NewFaultDomainTree(rack0), + expErr: errors.New("with no labels into one with labels"), + }, + "merge different labels at same layer": { + tree: NewFaultDomainTree(MustCreateFaultDomain("room=lab123", "rack=rack1")), + toMerge: NewFaultDomainTree(MustCreateFaultDomain("room=lab123", "group=rack0")), + expErr: errors.New("different existing label"), + }, + "merge different labels on totally separate branches": { + tree: NewFaultDomainTree(MustCreateFaultDomain("room=lab123", "rack=rack1")), + toMerge: NewFaultDomainTree(MustCreateFaultDomain("room=lab456", "group=rack0")), + expErr: errors.New("different existing label"), + }, } { t.Run(name, func(t *testing.T) { if tc.expResult == nil && tc.tree != nil {