diff --git a/.gitignore b/.gitignore index 23249243..cf4f3388 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ /vendor /.bench +.idea/ +.vscode/ *.mem *.cpu *.test diff --git a/dig.go b/dig.go index 2607e7db..9a0bc724 100644 --- a/dig.go +++ b/dig.go @@ -229,6 +229,18 @@ type Container struct { // Defer acyclic check on provide until Invoke. deferAcyclicVerification bool + + // Name of the container. + name string + + // Sub graphs of the container. + children []*Container + + // Parent is the container that spawned this. + parent *Container + + // Decorator functions of already provided dependencies + decorators map[key][]*node } // containerWriter provides write access to the Container's underlying data @@ -257,7 +269,7 @@ type containerStore interface { // Retrieves all values for the provided group and type. // // The order in which the values are returned is undefined. - getValueGroup(name string, t reflect.Type) []reflect.Value + getValueGroup(name string, t reflect.Type) ([]reflect.Value, bool) // Returns the providers that can produce a value with the given name and // type. @@ -267,6 +279,9 @@ type containerStore interface { // type. getGroupProviders(name string, t reflect.Type) []provider + // Returns the decorator list of a particular node + getDecorators(k key) []*node + createGraph() *dot.Graph } @@ -297,10 +312,11 @@ type provider interface { // New constructs a Container. func New(opts ...Option) *Container { c := &Container{ - providers: make(map[key][]*node), - values: make(map[key]reflect.Value), - groups: make(map[key][]reflect.Value), - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + providers: make(map[key][]*node), + values: make(map[key]reflect.Value), + groups: make(map[key][]reflect.Value), + decorators: make(map[key][]*node), + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } for _, opt := range opts { @@ -331,6 +347,8 @@ func setRand(r *rand.Rand) Option { }) } +// knownTypes returns the types known to this container, including types known +// by its descendants. func (c *Container) knownTypes() []reflect.Type { typeSet := make(map[reflect.Type]struct{}, len(c.providers)) for k := range c.providers { @@ -341,6 +359,11 @@ func (c *Container) knownTypes() []reflect.Type { for t := range typeSet { types = append(types, t) } + + for _, c := range append(c.children) { + types = append(types, c.knownTypes()...) + } + sort.Sort(byTypeName(types)) return types } @@ -351,13 +374,17 @@ func (c *Container) getValue(name string, t reflect.Type) (v reflect.Value, ok b } func (c *Container) setValue(name string, t reflect.Type, v reflect.Value) { - c.values[key{name: name, t: t}] = v + k := key{t: t, name: name} + c.values[k] = v } -func (c *Container) getValueGroup(name string, t reflect.Type) []reflect.Value { - items := c.groups[key{group: name, t: t}] +func (c *Container) getValueGroup(name string, t reflect.Type) ([]reflect.Value, bool) { + items, ok := c.groups[key{group: name, t: t}] + if !ok { + return []reflect.Value{}, ok + } // shuffle the list so users don't rely on the ordering of grouped values - return shuffledCopy(c.rand, items) + return shuffledCopy(c.rand, items), true } func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Value) { @@ -366,11 +393,23 @@ func (c *Container) submitGroupedValue(name string, t reflect.Type, v reflect.Va } func (c *Container) getValueProviders(name string, t reflect.Type) []provider { - return c.getProviders(key{name: name, t: t}) + providers := c.getProviders(key{name: name, t: t}) + + for _, c := range c.children { + providers = append(providers, c.getValueProviders(name, t)...) + } + + return providers } func (c *Container) getGroupProviders(name string, t reflect.Type) []provider { - return c.getProviders(key{group: name, t: t}) + providers := c.getProviders(key{group: name, t: t}) + + for _, c := range c.children { + providers = append(providers, c.getGroupProviders(name, t)...) + } + + return providers } func (c *Container) getProviders(k key) []provider { @@ -382,6 +421,41 @@ func (c *Container) getProviders(k key) []provider { return providers } +func (c *Container) getDecorators(k key) []*node { + p := c + if _, ok := c.providers[k]; !ok { + cont := c.children + for len(cont) > 0 { + v := cont[0] + cont = cont[1:] + if _, ok := v.providers[k]; !ok { + cont = append(cont, v.children...) + } else { + p = v + break + } + } + } else { + p = c + } + decorators := make([]*node, 0) + for p != nil { + if _, ok := p.decorators[k]; ok { + decorators = append(decorators, p.decorators[k]...) + } + p = p.parent + } + return decorators +} + +func (c *Container) getRoot() *Container { + if c.parent == nil { + return c + } + + return c.parent.getRoot() +} + // Provide teaches the container how to build values of one or more types and // expresses their dependencies. // @@ -433,6 +507,7 @@ func (c *Container) Provide(constructor interface{}, opts ...ProvideOption) erro // The function may return an error to indicate failure. The error will be // returned to the caller as-is. func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { + cp := c.getRoot() // run invoke on root to get access to all the graphs ftype := reflect.TypeOf(function) if ftype == nil { return errors.New("can't invoke an untyped nil") @@ -446,20 +521,20 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return err } - if err := shallowCheckDependencies(c, pl); err != nil { + if err := shallowCheckDependencies(cp, pl); err != nil { return errMissingDependencies{ Func: digreflect.InspectFunc(function), Reason: err, } } - if !c.isVerifiedAcyclic { - if err := c.verifyAcyclic(); err != nil { + if !cp.isVerifiedAcyclic { + if err := cp.verifyAcyclic(); err != nil { return err } } - args, err := pl.BuildList(c) + args, err := pl.BuildList(cp) if err != nil { return errArgumentsFailed{ Func: digreflect.InspectFunc(function), @@ -479,6 +554,54 @@ func (c *Container) Invoke(function interface{}, opts ...InvokeOption) error { return nil } +func (c *Container) Decorate(decorator interface{}, opts ...ProvideOption) error { + dtype := reflect.TypeOf(decorator) + if dtype == nil { + return errors.New("can't decorate with an untyped nil") + } + if dtype.Kind() != reflect.Func { + return fmt.Errorf("can't call non-function %v (type %v)", decorator, dtype) + } + + var options provideOptions + for _, o := range opts { + o.applyProvideOption(&options) + } + if err := options.Validate(); err != nil { + return err + } + + if err := c.decorate(decorator, options); err != nil { + return errConstructorFailed{ + Func: digreflect.InspectFunc(decorator), + Reason: err, + } + } + return nil +} + +// Child returns a named child of this container. The child container has +// full access to the parent's types, and any types provided to the child +// will be made available to the parent. +// +// The name of the child is for observability purposes only. As such, it +// does not have to be unique across different children of the container. +func (c *Container) Child(name string) *Container { + child := &Container{ + providers: make(map[key][]*node), + values: make(map[key]reflect.Value), + groups: make(map[key][]reflect.Value), + decorators: make(map[key][]*node), + rand: c.rand, + name: name, + parent: c, + } + + c.children = append(c.children, child) + + return child +} + func (c *Container) verifyAcyclic() error { visited := make(map[key]struct{}) for _, n := range c.nodes { @@ -522,7 +645,7 @@ func (c *Container) provide(ctor interface{}, opts provideOptions) error { if c.deferAcyclicVerification { continue } - if err := verifyAcyclic(c, n, k); err != nil { + if err := verifyAcyclic(c.getRoot(), n, k); err != nil { c.providers[k] = oldProviders return err } @@ -539,7 +662,7 @@ func (c *Container) findAndValidateResults(n *node) (map[key]struct{}, error) { var err error keyPaths := make(map[key]string) walkResult(n.ResultList(), connectionVisitor{ - c: c, + c: c.getRoot(), n: n, err: &err, keyPaths: keyPaths, @@ -556,6 +679,142 @@ func (c *Container) findAndValidateResults(n *node) (map[key]struct{}, error) { return keys, nil } +func (c *Container) decorate(dtor interface{}, opts provideOptions) error { + n, err := newNode( + dtor, + nodeOptions{ + ResultName: opts.Name, + ResultGroup: opts.Group, + ResultAs: opts.As, + }, + ) + if err != nil { + return err + } + + dtype := reflect.TypeOf(dtor) + + // Check if all the result types exist among the input types + inTypes := make(map[key]struct{}) + for i := 0; i < dtype.NumIn(); i++ { + in := dtype.In(i) + if IsIn(in) { + for j := 0; j < in.NumField(); j++ { + t := in.Field(j).Type + //Exclude embedded In type + if IsIn(t) { + continue + } + name := in.Field(j).Tag.Get(_nameTag) + group := in.Field(j).Tag.Get(_groupTag) + if name != "" && group != "" { + return errors.New("cannot use name tags and group tags together") + } + if group != "" { + if _, ok := inTypes[key{t.Elem(), name, group}]; ok { + return fmt.Errorf("cannot provide %v multple times in decorator", t) + } + inTypes[key{t.Elem(), name, group}] = struct{}{} + } else { + if _, ok := inTypes[key{t, name, group}]; ok { + return fmt.Errorf("cannot provide %v multple times in decorator", t) + } + inTypes[key{t, name, group}] = struct{}{} + } + } + } else { + inTypes[key{t: in}] = struct{}{} + } + } + outTypes := make(map[key]struct{}) + for i := 0; i < dtype.NumOut(); i++ { + out := dtype.Out(i) + if IsOut(out) { + for j := 0; j < out.NumField(); j++ { + t := out.Field(j).Type + //Exclude embedded Out type + if IsOut(t) { + continue + } + name := out.Field(j).Tag.Get(_nameTag) + group := out.Field(j).Tag.Get(_groupTag) + if name != "" && group != "" { + return errors.New("cannot use name tags and group tags together") + } + if _, ok := outTypes[key{t, name, group}]; ok { + return fmt.Errorf("cannot provide %v multple times in decorator", t) + } + outTypes[key{t, name, group}] = struct{}{} + } + } else { + outTypes[key{t: out}] = struct{}{} + } + } + + for k := range outTypes { + if _, ok := inTypes[k]; !ok { + return errors.New("the result types, with the exception of error, must be present among the input parameters") + } + delete(inTypes, k) + } + + params := []param{} + for k := range inTypes { + if k.group != "" { + params = append(params, paramGroupedSlice{k.group, reflect.SliceOf(k.t)}) + } else { + params = append(params, paramSingle{ + Name: k.name, + Type: k.t, + }) + } + } + + for k := range outTypes { + found := false + // Checking for the decorator output's existence in the sub graph with the + // current container as root. + if _, ok := c.providers[k]; !ok { + var cont []*Container + cont = append(cont, c.children...) + for !found && !(len(cont) == 0) { + v := cont[0] + cont = cont[1:] + if _, ok := v.providers[k]; !ok { + cont = append(cont, v.children...) + } else { + found = true + } + } + } else { + found = true + } + if !found { + return errors.New("decorator must be declared in the scope of the node's container or its ancestors')") + } + + if len(params) > 0 { + c.isVerifiedAcyclic = false + oldParams := n.paramList.Params + oldProviders := c.providers[k] + for _, p := range c.providers[k] { + params = append(params, p.paramList.Params...) + } + n.paramList.Params = params + c.providers[k] = append([]*node{n}, c.providers[k]...) + if err := verifyAcyclic(c.getRoot(), n, k); err != nil { + c.providers[k] = oldProviders + return err + } + c.providers[k] = oldProviders + n.paramList.Params = oldParams + c.isVerifiedAcyclic = true + } + c.decorators[k] = append(c.decorators[k], n) + } + return nil +} + // Visits the results of a node and compiles a collection of all the keys // produced by that node. type connectionVisitor struct { @@ -616,6 +875,7 @@ func (cv connectionVisitor) Visit(res result) resultVisitor { *cv.err = err return nil } + cv.keyPaths[k] = path for _, asType := range r.As { k := key{name: r.Name, t: asType} @@ -643,7 +903,7 @@ func (cv connectionVisitor) checkKey(k key, path string) error { "cannot provide %v from %v: already provided by %v", k, path, conflict) } - if ps := cv.c.providers[k]; len(ps) > 0 { + if ps := cv.c.getValueProviders(k.name, k.t); len(ps) > 0 { cons := make([]string, len(ps)) for i, p := range ps { cons[i] = fmt.Sprint(p.Location()) @@ -653,6 +913,7 @@ func (cv connectionVisitor) checkKey(k key, path string) error { "cannot provide %v from %v: already provided by %v", k, path, strings.Join(cons, "; ")) } + return nil } @@ -733,14 +994,12 @@ func (n *node) Call(c containerStore) error { if n.called { return nil } - if err := shallowCheckDependencies(c, n.paramList); err != nil { return errMissingDependencies{ Func: n.location, Reason: err, } } - args, err := n.paramList.BuildList(c) if err != nil { return errArgumentsFailed{ @@ -748,7 +1007,9 @@ func (n *node) Call(c containerStore) error { Reason: err, } } - + if n.called { + return nil + } receiver := newStagingContainerWriter() results := reflect.ValueOf(n.ctor).Call(args) if err := n.resultList.ExtractList(receiver, results); err != nil { @@ -804,8 +1065,9 @@ func shallowCheckDependencies(c containerStore, p param) error { // stagingContainerWriter is a containerWriter that records the changes that // would be made to a containerWriter and defers them until Commit is called. type stagingContainerWriter struct { - values map[key]reflect.Value - groups map[key][]reflect.Value + values map[key]reflect.Value + groups map[key][]reflect.Value + isDecorated map[key]bool } var _ containerWriter = (*stagingContainerWriter)(nil) diff --git a/dig_test.go b/dig_test.go index 82b519bd..ee748a7f 100644 --- a/dig_test.go +++ b/dig_test.go @@ -29,6 +29,7 @@ import ( "math/rand" "os" "reflect" + "strconv" "testing" "time" @@ -36,11 +37,25 @@ import ( "github.com/stretchr/testify/require" ) +// containerView is a view of one or more containers. +// +// The provide and invoke methods may point to different Container instances. +type containerView struct { + Provide func(interface{}, ...ProvideOption) error + Invoke func(interface{}, ...InvokeOption) error +} + +type newContainerFunc func(...Option) containerView + func TestEndToEndSuccess(t *testing.T) { + testSubGraphs(t, testEndToEndSuccess) +} + +func testEndToEndSuccess(t *testing.T, newContainer newContainerFunc) { t.Parallel() t.Run("pointer constructor", func(t *testing.T) { - c := New() + c := newContainer() var b *bytes.Buffer require.NoError(t, c.Provide(func() *bytes.Buffer { b = &bytes.Buffer{} @@ -56,7 +71,7 @@ func TestEndToEndSuccess(t *testing.T) { // Dig shouldn't forbid this - it's perfectly reasonable to explicitly // provide a typed nil, since that's often a convenient way to supply a // default no-op implementation. - c := New() + c := newContainer() require.NoError(t, c.Provide(func() *bytes.Buffer { return nil }), "provide failed") require.NoError(t, c.Invoke(func(b *bytes.Buffer) { require.Nil(t, b, "expected to get nil buffer") @@ -64,7 +79,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("struct constructor", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() bytes.Buffer { var buf bytes.Buffer buf.WriteString("foo") @@ -77,7 +92,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("slice constructor", func(t *testing.T) { - c := New() + c := newContainer() b1 := &bytes.Buffer{} b2 := &bytes.Buffer{} require.NoError(t, c.Provide(func() []*bytes.Buffer { @@ -91,7 +106,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("array constructor", func(t *testing.T) { - c := New() + c := newContainer() bufs := [1]*bytes.Buffer{{}} require.NoError(t, c.Provide(func() [1]*bytes.Buffer { return bufs }), "provide failed") require.NoError(t, c.Invoke(func(bs [1]*bytes.Buffer) { @@ -100,7 +115,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("map constructor", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() map[string]string { return map[string]string{} }), "provide failed") @@ -110,7 +125,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("channel constructor", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() chan int { return make(chan int) }), "provide failed") @@ -120,7 +135,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("func constructor", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() func(int) { return func(int) {} }), "provide failed") @@ -130,7 +145,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("interface constructor", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() io.Writer { return &bytes.Buffer{} }), "provide failed") @@ -140,7 +155,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("param", func(t *testing.T) { - c := New() + c := newContainer() type contents string type Args struct { In @@ -166,7 +181,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("invoke param", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() *bytes.Buffer { return new(bytes.Buffer) }), "provide failed") @@ -188,7 +203,7 @@ func TestEndToEndSuccess(t *testing.T) { called bool ) - c := New() + c := newContainer() require.NoError(t, c.Provide(func() *bytes.Buffer { require.False(t, called, "constructor must be called exactly once") called = true @@ -230,7 +245,7 @@ func TestEndToEndSuccess(t *testing.T) { called bool ) - c := New() + c := newContainer() require.NoError(t, c.Provide(func() *bytes.Buffer { require.False(t, called, "constructor must be called exactly once") called = true @@ -250,7 +265,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("multiple-type constructor", func(t *testing.T) { - c := New() + c := newContainer() constructor := func() (*bytes.Buffer, []int, error) { return &bytes.Buffer{}, []int{42}, nil } @@ -263,7 +278,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("multiple-type constructor is called once", func(t *testing.T) { - c := New() + c := newContainer() type A struct{} type B struct{} count := 0 @@ -285,7 +300,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("method invocation inside Invoke", func(t *testing.T) { - c := New() + c := newContainer() type A struct{} type B struct{} cA := func() (*A, error) { @@ -307,7 +322,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("collections and instances of same type", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() []*bytes.Buffer { return []*bytes.Buffer{{}} }), "providing collection failed") @@ -326,7 +341,7 @@ func TestEndToEndSuccess(t *testing.T) { return &type1{}, &type3{}, &type4{} } - c := New() + c := newContainer() type param struct { In @@ -357,7 +372,7 @@ func TestEndToEndSuccess(t *testing.T) { myA := A{"string A"} myB := &B{"string B"} - c := New() + c := newContainer() require.NoError(t, c.Provide(func() Ret { return Ret{A: myA, B: myB} }), "provide for the Ret struct should succeed") @@ -379,7 +394,7 @@ func TestEndToEndSuccess(t *testing.T) { T1 *type1 `optional:"true"` } - c := New() + c := newContainer() var gave *type2 require.NoError(t, c.Provide(func(p param) *type2 { @@ -394,7 +409,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("nested dependencies", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ name string } type B struct{ name string } @@ -411,7 +426,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("primitives", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() string { return "piper" }), "string provide failed") require.NoError(t, c.Provide(func() int { return 42 }), "int provide failed") require.NoError(t, c.Provide(func() int64 { return 24 }), "int provide failed") @@ -441,7 +456,7 @@ func TestEndToEndSuccess(t *testing.T) { *B C } - c := New() + c := newContainer() require.NoError(t, c.Provide(func() Ret2 { return Ret2{ @@ -458,7 +473,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named instances can be created with tags", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } // returns three named instances of A @@ -487,7 +502,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named instances can be created with Name option", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } @@ -513,7 +528,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named and unnamed instances coexist", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } type out struct { @@ -538,7 +553,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named instances recurse", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } type Ret1 struct { @@ -572,7 +587,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named instances do not cause cycles", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } type param struct { In @@ -606,7 +621,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("struct constructor with as interface option", func(t *testing.T) { - c := New() + c := newContainer() provider := c.Provide( func() *bytes.Buffer { @@ -630,7 +645,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("As with Name", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide( func() *bytes.Buffer { @@ -660,21 +675,21 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("As same interface", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() io.Reader { panic("this function should not be called") }, As(new(io.Reader))), "failed to provide") }) t.Run("As different interface", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() io.ReadCloser { panic("this function should not be called") }, As(new(io.Reader), new(io.Closer))), "failed to provide") }) t.Run("invoke on a type that depends on named parameters", func(t *testing.T) { - c := New() + c := newContainer() type A struct{ idx int } type B struct{ sum int } type param struct { @@ -718,7 +733,7 @@ func TestEndToEndSuccess(t *testing.T) { } t.Run("optional", func(t *testing.T) { - c := New() + c := newContainer() called1 := false require.NoError(t, c.Invoke(func(p param1) { @@ -737,7 +752,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("named", func(t *testing.T) { - c := New() + c := newContainer() require.NoError(t, c.Provide(func() *struct{} { return &struct{}{} @@ -764,7 +779,7 @@ func TestEndToEndSuccess(t *testing.T) { t.Run("dynamically generated dig.In", func(t *testing.T) { // This test verifies that a dig.In generated using reflect.StructOf // works with our dig.In detection logic. - c := New() + c := newContainer() type type1 struct{} type type2 struct{} @@ -826,7 +841,7 @@ func TestEndToEndSuccess(t *testing.T) { // This test verifies that a dig.Out generated using reflect.StructOf // works with our dig.Out detection logic. - c := New() + c := newContainer() type A struct{ Value int } @@ -871,7 +886,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("variadic arguments invoke", func(t *testing.T) { - c := New() + c := newContainer() type A struct{} @@ -893,7 +908,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("variadic arguments dependency", func(t *testing.T) { - c := New() + c := newContainer() type A struct{} type B struct{} @@ -924,7 +939,7 @@ func TestEndToEndSuccess(t *testing.T) { }) t.Run("non-error return arguments from invoke are ignored", func(t *testing.T) { - c := New() + c := newContainer() type A struct{} type B struct{} @@ -934,16 +949,96 @@ func TestEndToEndSuccess(t *testing.T) { err := c.Invoke(func(B) {}) require.Error(t, err, "invoking with B param should error out") assertErrorMatches(t, err, - `missing dependencies for function "go.uber.org/dig".TestEndToEndSuccess.func\S+ \(\S+/src/go.uber.org/dig/dig_test.go:\d+\):`, + `missing dependencies for function "go.uber.org/dig".testEndToEndSuccess.func\S+ \(\S+/src/go.uber.org/dig/dig_test.go:\d+\):`, "type dig.B is not in the container,", "did you mean to Provide it?", ) }) } +func TestChildren(t *testing.T) { + t.Parallel() + + t.Run("parent providers available to deeply nested children", func(t *testing.T) { + c := New() + + var b *bytes.Buffer + require.NoError(t, c.Provide(func() *bytes.Buffer { + b = &bytes.Buffer{} + return b + }), "provide failed") + ch := c.Child("1").Child("2").Child("3") + require.NoError(t, ch.Invoke(func(got *bytes.Buffer) { + require.NotNil(t, got, "invoke got nil buffer") + require.True(t, got == b, "invoke got wrong buffer") + }), "invoke failed") + }) + + t.Run("multiple sub graphs", func(t *testing.T) { + c := New() + + cc := make([]*Container, 0, 5) + for i := 0; i < 5; i++ { + cc = append(cc, c.Child(strconv.Itoa(i))) + } + + for i := 0; i < 5; i++ { + cc = append(cc, cc[1].Child(strconv.Itoa(i))) + } + + var b *bytes.Buffer + require.NoError(t, cc[2].Provide(func() *bytes.Buffer { + b = &bytes.Buffer{} + return b + }), "provide failed") + require.NoError(t, c.Invoke(func(got *bytes.Buffer) { + require.NotNil(t, got, "invoke got nil buffer") + require.True(t, got == b, "invoke got wrong buffer") + }), "invoke failed") + }) +} + func TestGroups(t *testing.T) { + testSubGraphs(t, testGroups) +} + +func testSubGraphs(t *testing.T, tf func(*testing.T, newContainerFunc)) { + t.Run("root", func(t *testing.T) { + tf(t, func(oo ...Option) containerView { + c := New(oo...) + return containerView{ + Provide: c.Provide, + Invoke: c.Invoke, + } + }) + }) + + t.Run("provide in parent, invoke in child", func(t *testing.T) { + tf(t, func(oo ...Option) containerView { + parent := New(oo...) + child := parent.Child("child") + return containerView{ + Provide: parent.Provide, + Invoke: child.Invoke, + } + }) + }) + + t.Run("provide in child, invoke in parent", func(t *testing.T) { + tf(t, func(oo ...Option) containerView { + parent := New(oo...) + child := parent.Child("child") + return containerView{ + Provide: child.Provide, + Invoke: parent.Invoke, + } + }) + }) +} + +func testGroups(t *testing.T, newContainer newContainerFunc) { t.Run("empty slice received without provides", func(t *testing.T) { - c := New() + c := newContainer() type in struct { In @@ -957,7 +1052,7 @@ func TestGroups(t *testing.T) { }) t.Run("values are provided", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -987,7 +1082,7 @@ func TestGroups(t *testing.T) { }) t.Run("groups are provided via option", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) provide := func(i int) { require.NoError(t, c.Provide(func() int { @@ -1011,7 +1106,7 @@ func TestGroups(t *testing.T) { }) t.Run("different types may be grouped", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) provide := func(i int, s string) { require.NoError(t, c.Provide(func() (int, string) { @@ -1037,7 +1132,7 @@ func TestGroups(t *testing.T) { }) t.Run("group options may not be provided for result structs", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -1053,7 +1148,7 @@ func TestGroups(t *testing.T) { }) t.Run("constructor is called at most once", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -1100,7 +1195,7 @@ func TestGroups(t *testing.T) { }) t.Run("consume groups in constructor", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -1143,7 +1238,7 @@ func TestGroups(t *testing.T) { }) t.Run("provide multiple values", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type outInt struct { Out @@ -1206,7 +1301,7 @@ func TestGroups(t *testing.T) { }) t.Run("duplicate values are supported", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -1245,7 +1340,7 @@ func TestGroups(t *testing.T) { }) t.Run("failure to build a grouped value fails everything", func(t *testing.T) { - c := New(setRand(rand.New(rand.NewSource(0)))) + c := newContainer(setRand(rand.New(rand.NewSource(0)))) type out struct { Out @@ -1278,9 +1373,9 @@ func TestGroups(t *testing.T) { }) require.Error(t, err, "expected failure") assertErrorMatches(t, err, - `could not build arguments for function "go.uber.org/dig".TestGroups`, + `could not build arguments for function "go.uber.org/dig".testGroups`, `could not build value group string\[group="x"\]:`, - `function "go.uber.org/dig".TestGroups\S+ \(\S+:\d+\) returned a non-nil error:`, + `function "go.uber.org/dig".testGroups\S+ \(\S+:\d+\) returned a non-nil error:`, "great sadness", ) assert.Equal(t, gaveErr, RootCause(err)) @@ -1677,6 +1772,18 @@ func TestProvideKnownTypesFails(t *testing.T) { assert.NoError(t, c.Provide(func() *bytes.Buffer { return nil })) assert.Error(t, c.Provide(func() *bytes.Buffer { return nil })) }) + t.Run("provide constructor twice first in parent and then in child", func(t *testing.T) { + parent := New() + child := parent.Child("child") + assert.NoError(t, parent.Provide(func() *bytes.Buffer { return nil })) + assert.Error(t, child.Provide(func() *bytes.Buffer { return nil })) + }) + t.Run("provide constructor twice first in parent and then in child", func(t *testing.T) { + parent := New() + child := parent.Child("child") + assert.NoError(t, child.Provide(func() *bytes.Buffer { return nil })) + assert.Error(t, parent.Provide(func() *bytes.Buffer { return nil })) + }) } func TestProvideCycleFails(t *testing.T) { @@ -1914,7 +2021,11 @@ func TestTypeCheckingEquality(t *testing.T) { func TestInvokesUseCachedObjects(t *testing.T) { t.Parallel() - c := New() + testSubGraphs(t, testInvokesUseCachedObjects) +} + +func testInvokesUseCachedObjects(t *testing.T, newContainer newContainerFunc) { + c := newContainer() constructorCalls := 0 buf := &bytes.Buffer{} @@ -2108,6 +2219,34 @@ func TestProvideFailures(t *testing.T) { assert.Contains(t, err.Error(), "cannot provide *bytes.Buffer") assert.Contains(t, err.Error(), "already provided") }) + + t.Run("provide multiple instances with the same name in different children", func(t *testing.T) { + c := New() + + ca := c.Child("1") + cb := ca.Child("2") + type A struct{} + type ret1 struct { + Out + *A `name:"foo"` + } + type ret2 struct { + Out + *A `name:"foo"` + } + require.NoError(t, ca.Provide(func() ret1 { + return ret1{A: &A{}} + })) + err := cb.Provide(func() ret2 { + return ret2{A: &A{}} + }) + require.Error(t, err, "expected error on the second provide") + assertErrorMatches(t, err, + `function "go.uber.org/dig".TestProvideFailures\S+ \(\S+:\d+\) cannot be provided:`, + `cannot provide \*dig.A\[name="foo"\] from \[0\].A:`, + `already provided by "go.uber.org/dig".TestProvideFailures\S+`, + ) + }) } func TestInvokeFailures(t *testing.T) { diff --git a/param.go b/param.go index 4daa4135..922b0bc4 100644 --- a/param.go +++ b/param.go @@ -48,6 +48,7 @@ type param interface { // This MAY panic if the param does not produce a single value. Build(containerStore) (reflect.Value, error) + Decorate(containerStore) (reflect.Value, error) // DotParam returns a slice of dot.Param(s). DotParam() []*dot.Param } @@ -206,6 +207,10 @@ func (pl paramList) BuildList(c containerStore) ([]reflect.Value, error) { return args, nil } +func (pl paramList) Decorate(c containerStore) (reflect.Value, error) { + panic("not supposed to happen") +} + // paramSingle is an explicitly requested type, optionally with a name. // // This object must be present in the graph as-is unless it's specified as @@ -259,6 +264,27 @@ func (ps paramSingle) Build(c containerStore) (reflect.Value, error) { Reason: err, } } + if v, err := ps.Decorate(c); err != nil { + return _noValue, err + } else { + return v, nil + } +} + +func (ps paramSingle) Decorate(c containerStore) (reflect.Value, error) { + + decorators := c.getDecorators(key{name: ps.Name, t: ps.Type}) + for _, n := range decorators { + err := n.Call(c) + if err == nil { + continue + } + return _noValue, errParamSingleFailed{ + CtorID: n.ID(), + Key: key{t: ps.Type, name: ps.Name}, + Reason: err, + } + } // If we get here, it's impossible for the value to be absent from the // container. @@ -317,6 +343,10 @@ func (po paramObject) Build(c containerStore) (reflect.Value, error) { return dest, nil } +func (po paramObject) Decorate(c containerStore) (reflect.Value, error) { + panic("not supposed to happen") +} + // paramObjectField is a single field of a dig.In struct. type paramObjectField struct { // Name of the field in the struct. @@ -388,6 +418,10 @@ func (pof paramObjectField) Build(c containerStore) (reflect.Value, error) { return v, nil } +func (pof paramObjectField) Decorate(c containerStore) (reflect.Value, error) { + panic("not supposed not happen") +} + // paramGroupedSlice is a param which produces a slice of values with the same // group name. type paramGroupedSlice struct { @@ -434,6 +468,13 @@ func newParamGroupedSlice(f reflect.StructField) (paramGroupedSlice, error) { } func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { + if items, ok := c.getValueGroup(pt.Group, pt.Type.Elem()); ok { + result := reflect.MakeSlice(pt.Type, len(items), len(items)) + for i, v := range items { + result.Index(i).Set(v) + } + return result, nil + } for _, n := range c.getGroupProviders(pt.Group, pt.Type.Elem()) { if err := n.Call(c); err != nil { return _noValue, errParamGroupFailed{ @@ -443,12 +484,25 @@ func (pt paramGroupedSlice) Build(c containerStore) (reflect.Value, error) { } } } + val, err := pt.Decorate(c) + if err != nil { + return _noValue, err + } + return val, nil +} - items := c.getValueGroup(pt.Group, pt.Type.Elem()) +func (pt paramGroupedSlice) Decorate(c containerStore) (reflect.Value, error) { + decs := c.getDecorators(key{t: pt.Type.Elem(), group: pt.Group}) + for _, n := range decs { + if err := n.Call(c); err != nil { + return _noValue, err + } + } + items, _ := c.getValueGroup(pt.Group, pt.Type.Elem()) result := reflect.MakeSlice(pt.Type, len(items), len(items)) for i, v := range items { result.Index(i).Set(v) } return result, nil -} +} \ No newline at end of file diff --git a/stringer.go b/stringer.go index d10fa0fb..8411f3b3 100644 --- a/stringer.go +++ b/stringer.go @@ -29,6 +29,9 @@ import ( // String representation of the entire Container func (c *Container) String() string { b := &bytes.Buffer{} + if c.parent != nil { + fmt.Fprintf(b, "parent: %p\n", c.parent) + } fmt.Fprintln(b, "nodes: {") for k, vs := range c.providers { for _, v := range vs { @@ -48,6 +51,14 @@ func (c *Container) String() string { } fmt.Fprintln(b, "}") + fmt.Fprintln(b, "children: [") + for _, v := range c.children { + fmt.Fprintln(b, "\t{") + fmt.Fprintln(b, "\t\t", v.name, "->", v) + fmt.Fprintln(b, "\t}") + } + fmt.Fprintln(b, "]") + return b.String() } diff --git a/stringer_test.go b/stringer_test.go index e096bb6e..c1af8121 100644 --- a/stringer_test.go +++ b/stringer_test.go @@ -21,6 +21,7 @@ package dig import ( + "fmt" "math/rand" "testing" @@ -105,4 +106,14 @@ func TestStringer(t *testing.T) { assert.Contains(t, s, `string[group="baz"] => foo`) assert.Contains(t, s, `string[group="baz"] => bar`) assert.Contains(t, s, `string[group="baz"] => baz`) + + s = c.Child("child").String() + + // Parent + assert.Contains(t, s, fmt.Sprintf("parent: %p", c)) + + s = c.String() + // Children + assert.Contains(t, s, "children: [") + assert.Contains(t, s, "child -> ") }