Skip to content
This repository was archived by the owner on Mar 22, 2019. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,6 @@ StructLoop:
)
}

// Don't overwrite existing values.
if !isNilOrZero(field, fieldType) {
continue
}

// Named injects must have been explicitly provided.
if tag.Name != "" {
existing := g.named[tag.Name]
Expand Down Expand Up @@ -326,7 +321,19 @@ StructLoop:
}

// Interface injection is handled in a second pass.
if fieldType.Kind() == reflect.Interface && isNilOrZero(field, fieldType) {
continue
}

if fieldType.Kind() == reflect.Interface {
err := g.Provide(&Object{
Value: field.Elem().Interface(),
private: true,
embedded: o.reflectType.Elem().Field(i).Anonymous,
})
if err != nil {
return err
}
continue
}

Expand Down Expand Up @@ -383,6 +390,10 @@ StructLoop:
}
}

if !tag.Override && !isNilOrZero(field, fieldType) {
continue
}

newValue := reflect.New(fieldType.Elem())
newObject := &Object{
Value: newValue.Interface(),
Expand Down Expand Up @@ -453,11 +464,6 @@ func (g *Graph) populateUnnamedInterface(o *Object) error {
)
}

// Don't overwrite existing values.
if !isNilOrZero(field, fieldType) {
continue
}

// Named injects must have already been handled in populateExplicit.
if tag.Name != "" {
panic(fmt.Sprintf("unhandled named instance with name %s", tag.Name))
Expand All @@ -482,6 +488,9 @@ func (g *Graph) populateUnnamedInterface(o *Object) error {
existing.reflectValue,
)
}
if !tag.Override && !isNilOrZero(field, fieldType) {
continue
}
found = existing
field.Set(reflect.ValueOf(existing.Value))
if g.Logger != nil {
Expand All @@ -497,13 +506,14 @@ func (g *Graph) populateUnnamedInterface(o *Object) error {
}

// If we didn't find an assignable value, we're missing something.
if found == nil {
if found == nil && isNilOrZero(field, fieldType) {
return fmt.Errorf(
"found no assignable value for field %s in type %s",
o.reflectType.Elem().Field(i).Name,
o.reflectType,
)
}

}
return nil
}
Expand Down Expand Up @@ -531,15 +541,17 @@ func (g *Graph) Objects() []*Object {
}

var (
injectOnly = &tag{}
injectPrivate = &tag{Private: true}
injectInline = &tag{Inline: true}
injectOnly = &tag{}
injectPrivate = &tag{Private: true}
injectInline = &tag{Inline: true}
injectOverride = &tag{Override: true}
)

type tag struct {
Name string
Inline bool
Private bool
Name string
Inline bool
Private bool
Override bool
}

func parseTag(t string) (*tag, error) {
Expand All @@ -550,14 +562,15 @@ func parseTag(t string) (*tag, error) {
if !found {
return nil, nil
}
if value == "" {
switch value {
case "":
return injectOnly, nil
}
if value == "inline" {
case "inline":
return injectInline, nil
}
if value == "private" {
case "private":
return injectPrivate, nil
case "override":
return injectOverride, nil
}
return &tag{Name: value}, nil
}
Expand Down
43 changes: 41 additions & 2 deletions inject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package inject_test
import (
"fmt"
"math/rand"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -114,6 +115,44 @@ func TestInjectSimple(t *testing.T) {
}
}

func TestInjectOverride(t *testing.T) {
var v struct {
A *TypeAnswerStruct `inject:""`
B *TypeAnswerStruct `inject:"override"`
}
olda, oldb := &TypeAnswerStruct{}, &TypeAnswerStruct{}
v.A, v.B = olda, oldb
if err := inject.Populate(&v); err != nil {
t.Fatal(err)
}
if v.A != olda {
t.Fatal("original A was lost")
}
if v.B == oldb {
t.Fatal("original B was not overridden")
}
}

type TypeNestedInterfaceStruct struct {
A Answerable `inject:""`
}

func TestNonEmptyInterfaceTraversal(t *testing.T) {
olda := &TypeNestedStruct{}
v := TypeNestedInterfaceStruct{
A: olda,
}
if err := inject.Populate(&v); err != nil {
t.Fatal(err)
}
if v.A != olda {
t.Fatal("original A was lost")
}
if olda.A == nil {
t.Fatal("v.A.A is nil")
}
}

func TestDoesNotOverwrite(t *testing.T) {
a := &TypeAnswerStruct{}
var v struct {
Expand Down Expand Up @@ -238,7 +277,7 @@ func TestProvideTwoOfTheSame(t *testing.T) {
t.Fatal("expected error")
}

const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct"
msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath())
if err.Error() != msg {
t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error())
}
Expand All @@ -251,7 +290,7 @@ func TestProvideTwoOfTheSameWithPopulate(t *testing.T) {
t.Fatal("expected error")
}

const msg = "provided two unnamed instances of type *github.com/facebookgo/inject_test.TypeAnswerStruct"
msg := fmt.Sprintf("provided two unnamed instances of type *%s.TypeAnswerStruct", reflect.TypeOf(a).PkgPath())
if err.Error() != msg {
t.Fatalf("expected:\n%s\nactual:\n%s", msg, err.Error())
}
Expand Down