Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/stvp/go-udp-testing

go 1.18.0
33 changes: 19 additions & 14 deletions udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net"
"runtime"
"strings"
"testing"
"time"
)

Expand All @@ -16,6 +15,12 @@ var (
Timeout = time.Millisecond
)

// TestingT interface for the methods used from *testing.T
type TestingT interface {
Fatal(args ...any)
Errorf(format string, args ...any)
}

type fn func()

// SetAddr sets the UDP port that will be listened on.
Expand All @@ -35,7 +40,7 @@ func WriteTo(b []byte, addr net.Addr) (n int, err error) {
return listener.WriteTo(b, addr)
}

func start(t *testing.T) {
func start(t TestingT) {
resAddr, err := net.ResolveUDPAddr("udp", *addr)
if err != nil {
t.Fatal(err)
Expand All @@ -46,13 +51,13 @@ func start(t *testing.T) {
}
}

func stop(t *testing.T) {
func stop(t TestingT) {
if err := listener.Close(); err != nil {
t.Fatal(err)
}
}

func getMessage(t *testing.T, body fn) string {
func getMessage(t TestingT, body fn) string {
start(t)
defer stop(t)

Expand All @@ -73,21 +78,21 @@ func getMessage(t *testing.T, body fn) string {
return string(message[0:bufLen])
}

func get(t *testing.T, match string, body fn) (got string, equals bool, contains bool) {
func get(t TestingT, match string, body fn) (got string, equals bool, contains bool) {
got = getMessage(t, body)
equals = got == match
contains = strings.Contains(got, match)
return got, equals, contains
}

func printLocation(t *testing.T) {
func printLocation(t TestingT) {
_, file, line, _ := runtime.Caller(2)
t.Errorf("At: %s:%d", file, line)
}

// ShouldReceiveOnly will fire a test error if the given function doesn't send
// exactly the given string over UDP.
func ShouldReceiveOnly(t *testing.T, expected string, body fn) {
func ShouldReceiveOnly(t TestingT, expected string, body fn) {
got, equals, _ := get(t, expected, body)
if !equals {
printLocation(t)
Expand All @@ -98,7 +103,7 @@ func ShouldReceiveOnly(t *testing.T, expected string, body fn) {

// ShouldNotReceiveOnly will fire a test error if the given function sends
// exactly the given string over UDP.
func ShouldNotReceiveOnly(t *testing.T, notExpected string, body fn) {
func ShouldNotReceiveOnly(t TestingT, notExpected string, body fn) {
_, equals, _ := get(t, notExpected, body)
if equals {
printLocation(t)
Expand All @@ -108,7 +113,7 @@ func ShouldNotReceiveOnly(t *testing.T, notExpected string, body fn) {

// ShouldReceive will fire a test error if the given function doesn't send the
// given string over UDP.
func ShouldReceive(t *testing.T, expected string, body fn) {
func ShouldReceive(t TestingT, expected string, body fn) {
got, _, contains := get(t, expected, body)
if !contains {
printLocation(t)
Expand All @@ -119,7 +124,7 @@ func ShouldReceive(t *testing.T, expected string, body fn) {

// ShouldNotReceive will fire a test error if the given function sends the
// given string over UDP.
func ShouldNotReceive(t *testing.T, expected string, body fn) {
func ShouldNotReceive(t TestingT, expected string, body fn) {
got, _, contains := get(t, expected, body)
if contains {
printLocation(t)
Expand All @@ -130,7 +135,7 @@ func ShouldNotReceive(t *testing.T, expected string, body fn) {

// ShouldReceiveAll will fire a test error unless all of the given strings are
// sent over UDP.
func ShouldReceiveAll(t *testing.T, expected []string, body fn) {
func ShouldReceiveAll(t TestingT, expected []string, body fn) {
got := getMessage(t, body)
failed := false

Expand All @@ -151,7 +156,7 @@ func ShouldReceiveAll(t *testing.T, expected []string, body fn) {

// ShouldNotReceiveAny will fire a test error if any of the given strings are
// sent over UDP.
func ShouldNotReceiveAny(t *testing.T, unexpected []string, body fn) {
func ShouldNotReceiveAny(t TestingT, unexpected []string, body fn) {
got := getMessage(t, body)
failed := false

Expand All @@ -170,7 +175,7 @@ func ShouldNotReceiveAny(t *testing.T, unexpected []string, body fn) {
}
}

func ShouldReceiveAllAndNotReceiveAny(t *testing.T, expected []string, unexpected []string, body fn) {
func ShouldReceiveAllAndNotReceiveAny(t TestingT, expected []string, unexpected []string, body fn) {
got := getMessage(t, body)
failed := false

Expand Down Expand Up @@ -198,6 +203,6 @@ func ShouldReceiveAllAndNotReceiveAny(t *testing.T, expected []string, unexpecte
}
}

func ReceiveString(t *testing.T, body fn) string {
func ReceiveString(t TestingT, body fn) string {
return getMessage(t, body)
}