Skip to content

Commit

Permalink
cell: Add DecorateAll for decorating everywhere
Browse files Browse the repository at this point in the history
Add the ability to decorate a type in all scopes. This is sometimes
useful when one needs to swap out a value or implementation coming
from a cell that one doesn't control.

Signed-off-by: Jussi Maki <[email protected]>
  • Loading branch information
joamaki committed Nov 6, 2024
1 parent a606391 commit 316b277
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 11 deletions.
4 changes: 3 additions & 1 deletion cell/cell.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Cell interface {
Info(container) Info

// Apply the cell to the dependency graph container.
Apply(container) error
Apply(container, rootContainer) error
}

// In when embedded into a struct used as constructor parameter makes the exported
Expand Down Expand Up @@ -49,3 +49,5 @@ type container interface {
Decorate(fn any, opts ...dig.DecorateOption) error
Scope(name string, opts ...dig.ScopeOption) *dig.Scope
}

type rootContainer = container
2 changes: 1 addition & 1 deletion cell/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func decoderConfig(target any, extraHooks DecodeHooks) *mapstructure.DecoderConf
}
}

func (c *config[Cfg]) Apply(cont container) error {
func (c *config[Cfg]) Apply(cont container, _ rootContainer) error {
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
c.defaultConfig.Flags(flags)

Expand Down
38 changes: 36 additions & 2 deletions cell/decorator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ type decorator struct {
cells []Cell
}

func (d *decorator) Apply(c container) error {
func (d *decorator) Apply(c container, rc rootContainer) error {
scope := c.Scope(fmt.Sprintf("(decorate %s)", internal.PrettyType(d.decorator)))
if err := scope.Decorate(d.decorator); err != nil {
return err
}

for _, cell := range d.cells {
if err := cell.Apply(scope); err != nil {
if err := cell.Apply(scope, rc); err != nil {
return err
}
}
Expand All @@ -60,3 +60,37 @@ func (d *decorator) Info(c container) Info {
}
return n
}

// DecorateAll takes a decorator function and applies the decoration globally.
//
// Example:
//
// cell.Module(
// "my-app",
// "My application",
// foo.Cell, // provides foo.Foo
// bar.Cell,
//
// // Wrap 'foo.Foo' everywhere, including inside foo.Cell.
// cell.DecorateAll(
// func(f foo.Foo) foo.Foo {
// return myFooWrapper{f}
// },
// ),
// )
func DecorateAll(dtor any) Cell {
return &allDecorator{dtor}
}

type allDecorator struct {
decorator any
}

func (d *allDecorator) Apply(_ container, rc rootContainer) error {
return rc.Decorate(d.decorator)
}

func (d *allDecorator) Info(_ container) Info {
n := NewInfoNode(fmt.Sprintf("🔀* %s: %s", internal.FuncNameAndLocation(d.decorator), internal.PrettyType(d.decorator)))
return n
}
4 changes: 2 additions & 2 deletions cell/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ func Group(cells ...Cell) Cell {
return group(cells)
}

func (g group) Apply(c container) error {
func (g group) Apply(c container, rc rootContainer) error {
for _, cell := range g {
if err := cell.Apply(c); err != nil {
if err := cell.Apply(c, rc); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion cell/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (inv *invoker) invoke(log *slog.Logger, cont container, logThreshold time.D
return nil
}

func (inv *invoker) Apply(c container) error {
func (inv *invoker) Apply(c container, _ rootContainer) error {
// Remember the scope in which we need to invoke.
invoker := func(log *slog.Logger, logThreshold time.Duration) error { return inv.invoke(log, c, logThreshold) }

Expand Down
4 changes: 2 additions & 2 deletions cell/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (m *module) modulePrivateProviders(scope *dig.Scope) error {
return scope.Invoke(provide)
}

func (m *module) Apply(c container) error {
func (m *module) Apply(c container, rc rootContainer) error {
scope := c.Scope(m.id)

// Provide ModuleID and FullModuleID in the module's scope.
Expand Down Expand Up @@ -190,7 +190,7 @@ func (m *module) Apply(c container) error {
}

for _, cell := range m.cells {
if err := cell.Apply(scope); err != nil {
if err := cell.Apply(scope, rc); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion cell/provide.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type provider struct {
export bool
}

func (p *provider) Apply(c container) error {
func (p *provider) Apply(c container, rc rootContainer) error {
// Since the same Provide cell may be used multiple times
// in different hives we use a mutex to protect it and we
// fill the provide info only the first time.
Expand Down
2 changes: 1 addition & 1 deletion hive.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func NewWithOptions(opts Options, cells ...cell.Cell) *Hive {
// and adds all config flags. Invokes are delayed until Start() is
// called.
for _, cell := range cells {
if err := cell.Apply(h.container); err != nil {
if err := cell.Apply(h.container, h.container); err != nil {
panic(fmt.Sprintf("Failed to apply cell: %s", err))
}
}
Expand Down
26 changes: 26 additions & 0 deletions hive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,32 @@ func TestDecorate(t *testing.T) {
assert.True(t, invoked, "expected decorated invoke function to be called")
}

func TestDecorateAll(t *testing.T) {
rootX, testX := 0, 0
hive := hive.New(
cell.Module("test", "test",
cell.Provide(func() *SomeObject { return &SomeObject{1} }),
cell.Invoke(func(o *SomeObject) { testX = o.X }),
),

cell.Invoke(func(o *SomeObject) {
rootX = o.X
}),

cell.DecorateAll(
func(o *SomeObject) *SomeObject {
return &SomeObject{X: o.X + 1}
},
),

shutdownOnStartCell,
)

assert.NoError(t, hive.Run(hivetest.Logger(t)), "expected Run() to succeed")
assert.Equal(t, 2, rootX, "expected object at root scope to have X=2")
assert.Equal(t, 2, testX, "expected object in test module scope to have X=2")
}

func TestShutdown(t *testing.T) {
//
// Happy paths without a shutdown error:
Expand Down

0 comments on commit 316b277

Please sign in to comment.