From 3a5618b505564ec20ecce83f2fe6697c00c82617 Mon Sep 17 00:00:00 2001 From: Joe Atzberger Date: Tue, 1 Dec 2020 19:02:21 -0500 Subject: [PATCH] json Marshaller interface for Inet Note that this does allow the serialization/deserialization between empty string and a Null struct. It does NOT permit invalid addresses or masks. See #79 --- inet.go | 26 ++++++++++++++++++++++-- inet_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/inet.go b/inet.go index b449819..56d04f5 100644 --- a/inet.go +++ b/inet.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "net" + "strings" errors "golang.org/x/xerrors" ) @@ -122,7 +123,7 @@ func (src *Inet) AssignTo(dst interface{}) error { return errors.Errorf("cannot decode %#v into %T", src, dst) } -func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { +func (dst *Inet) DecodeText(_ *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -150,7 +151,7 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error { return nil } -func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error { +func (dst *Inet) DecodeBinary(_ *ConnInfo, src []byte) error { if src == nil { *dst = Inet{Status: Null} return nil @@ -218,6 +219,27 @@ func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) { return append(buf, src.IPNet.IP...), nil } +// MarshalJSON implements the json.Marshaler interface +func (src Inet) MarshalJSON() ([]byte, error) { + if src.Status != Present { + return []byte(`""`), nil + } + v, err := src.Value() + if err != nil || v == nil { + return []byte(`""`), err + } + return []byte(`"` + v.(string) + `"`), nil +} + +// UnmarshalJSON implements the json.Marshaler interface +func (dst *Inet) UnmarshalJSON(data []byte) error { + trimmed := strings.Trim(string(data), `"`) + if trimmed == "" { + return dst.DecodeText(nil, nil) + } + return dst.DecodeText(nil, []byte(trimmed)) +} + // Scan implements the database/sql Scanner interface. func (dst *Inet) Scan(src interface{}) error { if src == nil { diff --git a/inet_test.go b/inet_test.go index cb420a5..a09e270 100644 --- a/inet_test.go +++ b/inet_test.go @@ -114,3 +114,60 @@ func TestInetAssignTo(t *testing.T) { } } } + +func TestInetMarshalJSON(t *testing.T) { + successfulTests := []struct { + json string + source pgtype.Inet + }{ + {source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`}, + {source: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`}, + {source: pgtype.Inet{Status: pgtype.Null}, json: `""`}, + {source: pgtype.Inet{}, json: `""`}, + } + + for i, tt := range successfulTests { + got, err := tt.source.MarshalJSON() + if err != nil { + t.Errorf("%d: %v", i, err) + } + if !reflect.DeepEqual(got, []byte(tt.json)) { + t.Errorf("%d: expected JSON `%s`, but it was %s", i, tt.json, string(got)) + } + } +} + +func TestInetUnmarshalJSON(t *testing.T) { + successfulTests := []struct { + json string + expected pgtype.Inet + }{ + {expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`}, + {expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1"`}, + {expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`}, + {expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e"`}, + {expected: pgtype.Inet{Status: pgtype.Null}, json: `""`}, // empty is OK, equivalent to our null struct + } + badJSON := []string{ + `"127.0.0.1/"`, // no network + `"444.555.666.777/32"`, // bad addr + `"nonsense"`, // bad everything + } + + for i, tt := range successfulTests { + got := pgtype.Inet{} + if err := got.UnmarshalJSON([]byte(tt.json)); err != nil { + t.Errorf("%d: %v", i, err) + } + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("%d: expected %v from JSON `%s`, but it was %v", i, tt.expected, tt.json, got) + } + } + + for i, example := range badJSON { + got := pgtype.Inet{} + if err := got.UnmarshalJSON([]byte(example)); err == nil { + t.Errorf("%d: Expected error for %s, but got none", i, example) + } + } +}