From b5c69faa5740214abb3d43c22e925fd45fe8ab3d Mon Sep 17 00:00:00 2001 From: Antoine Pourchet Date: Thu, 26 Oct 2017 15:12:55 -0700 Subject: [PATCH] Added the ability to override non-zero values Previously, if some subtree of an object that was to be populated was not zero/nil, we would terminate the population of that subtree then and there. However this makes the package hard to use for very target dependencies that only occur at the bottom of the dependency tree (e.g: a network client or database layer). With this patch, the injection does move downwards and populate the fields that it can. Furthermore, we can now have constructors populate our structs with default values, that will then be overridden by the injection mechanism if the tag `inject:"override"` is present on the field. This means that we can now use sane defaults when creating our objects that we can inject mocks for during unit tests. ```go // A is our bogus wrapper of the http.Client type A struct { Client *http.Client `inject:"override"` } // NewA returns a fully functional A, with a non-nil http client. func NewA() *A { return &A{ Client: &http.Client{} } } var testClient *http.Client func TestMain(m *testing.M) testClient = NewTestClient() os.Exit(m.Run()) } func TestWithBogusClient(t *testing.T) { a := NewA() inject.Populate(a, testClient) // Here the injector should have overridden the default http client // inside A to use the bogus client. a.Do(...) } ``` The last feature that I was looking for was the traversal of non-nil interfaces. For instance: ```go // I is our bogus interface. It has no functions for the sake of // brevity. type I interface {} // A implements that bogus interface type A struct { Client *http.Client `inject:""` } // Nested contains an I which does not need to be injected. However, // the I (which will be of type A at runtime) needs to be traversed so that // we can inject the right *http.Client. type Nested struct { Iface I } func main() { n := &Nested{ Iface: &A{}, } specialclient := &http.Client{} inject.Populate(n, specialclient) } ``` This is useful in bigger projects when you have many subcomponents that use an *http.Client or sub-interface of that, and need to all have that client mocked during unit tests. --- inject.go | 57 +++++++++++++++++++++++++++++++------------------- inject_test.go | 43 +++++++++++++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 24 deletions(-) diff --git a/inject.go b/inject.go index 300b9a3..9fdd15d 100644 --- a/inject.go +++ b/inject.go @@ -254,11 +254,6 @@ StructLoop: ) } - // Don't overwrite existing values. - if !isNilOrZero(field, fieldType) { - continue - } - // Named injects must have been explicitly provided. if tag.Name != "" { existing := g.named[tag.Name] @@ -326,7 +321,19 @@ StructLoop: } // Interface injection is handled in a second pass. + if fieldType.Kind() == reflect.Interface && isNilOrZero(field, fieldType) { + continue + } + if fieldType.Kind() == reflect.Interface { + err := g.Provide(&Object{ + Value: field.Elem().Interface(), + private: true, + embedded: o.reflectType.Elem().Field(i).Anonymous, + }) + if err != nil { + return err + } continue } @@ -383,6 +390,10 @@ StructLoop: } } + if !tag.Override && !isNilOrZero(field, fieldType) { + continue + } + newValue := reflect.New(fieldType.Elem()) newObject := &Object{ Value: newValue.Interface(), @@ -453,11 +464,6 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { ) } - // Don't overwrite existing values. - if !isNilOrZero(field, fieldType) { - continue - } - // Named injects must have already been handled in populateExplicit. if tag.Name != "" { panic(fmt.Sprintf("unhandled named instance with name %s", tag.Name)) @@ -482,6 +488,9 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { existing.reflectValue, ) } + if !tag.Override && !isNilOrZero(field, fieldType) { + continue + } found = existing field.Set(reflect.ValueOf(existing.Value)) if g.Logger != nil { @@ -497,13 +506,14 @@ func (g *Graph) populateUnnamedInterface(o *Object) error { } // If we didn't find an assignable value, we're missing something. - if found == nil { + if found == nil && isNilOrZero(field, fieldType) { return fmt.Errorf( "found no assignable value for field %s in type %s", o.reflectType.Elem().Field(i).Name, o.reflectType, ) } + } return nil } @@ -531,15 +541,17 @@ func (g *Graph) Objects() []*Object { } var ( - injectOnly = &tag{} - injectPrivate = &tag{Private: true} - injectInline = &tag{Inline: true} + injectOnly = &tag{} + injectPrivate = &tag{Private: true} + injectInline = &tag{Inline: true} + injectOverride = &tag{Override: true} ) type tag struct { - Name string - Inline bool - Private bool + Name string + Inline bool + Private bool + Override bool } func parseTag(t string) (*tag, error) { @@ -550,14 +562,15 @@ func parseTag(t string) (*tag, error) { if !found { return nil, nil } - if value == "" { + switch value { + case "": return injectOnly, nil - } - if value == "inline" { + case "inline": return injectInline, nil - } - if value == "private" { + case "private": return injectPrivate, nil + case "override": + return injectOverride, nil } return &tag{Name: value}, nil } diff --git a/inject_test.go b/inject_test.go index 6433eed..148b131 100644 --- a/inject_test.go +++ b/inject_test.go @@ -3,6 +3,7 @@ package inject_test import ( "fmt" "math/rand" + "reflect" "strings" "testing" "time" @@ -114,6 +115,44 @@ func TestInjectSimple(t *testing.T) { } } +func TestInjectOverride(t *testing.T) { + var v struct { + A *TypeAnswerStruct `inject:""` + B *TypeAnswerStruct `inject:"override"` + } + olda, oldb := &TypeAnswerStruct{}, &TypeAnswerStruct{} + v.A, v.B = olda, oldb + if err := inject.Populate(&v); err != nil { + t.Fatal(err) + } + if v.A != olda { + t.Fatal("original A was lost") + } + if v.B == oldb { + t.Fatal("original B was not overridden") + } +} + +type TypeNestedInterfaceStruct struct { + A Answerable `inject:""` +} + +func TestNonEmptyInterfaceTraversal(t *testing.T) { + olda := &TypeNestedStruct{} + v := TypeNestedInterfaceStruct{ + A: olda, + } + if err := inject.Populate(&v); err != nil { + t.Fatal(err) + } + if v.A != olda { + t.Fatal("original A was lost") + } + if olda.A == nil { + t.Fatal("v.A.A is nil") + } +} + func TestDoesNotOverwrite(t *testing.T) { a := &TypeAnswerStruct{} var v struct { @@ -238,7 +277,7 @@ func TestProvideTwoOfTheSame(t *testing.T) { t.Fatal("expected error") } - const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct" + msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath()) if err.Error() != msg { t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error()) } @@ -251,7 +290,7 @@ func TestProvideTwoOfTheSameWithPopulate(t *testing.T) { t.Fatal("expected error") } - const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct" + msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath()) if err.Error() != msg { t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error()) }