Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,4 +671,21 @@ type Config struct {
// When possible, setting Strict to true is recommended for applications
// running on modern Linux kernels.
Strict bool

// MessageBufferSize specifies a fixed buffer size for receiving netlink
// messages. When set, the connection will reuse a single pre-allocated
// buffer of this size instead of peeking at each message to determine
// the exact size needed.
//
// This is useful for high-throughput applications where the overhead of
// peeking at each message is undesirable and the maximum message size
// is known in advance.
//
// If set to 0 (the default), the connection will peek at the upcoming
// message before allocating a buffer for it.
//
// Note: this is not the same as the kernel socket receive buffer which
// can be configured using SetReadBuffer. MessageBufferSize only controls
// the userspace buffer passed to recvmsg.
MessageBufferSize int
}
39 changes: 31 additions & 8 deletions conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"iter"
"os"
"sync"
"syscall"
"time"
"unsafe"
Expand All @@ -20,7 +21,8 @@ var _ Socket = &conn{}

// A conn is the Linux implementation of a netlink sockets connection.
type conn struct {
s *socket.Conn
s *socket.Conn
pool *sync.Pool
}

// dial is the entry point for Dial. dial opens a netlink socket using
Expand Down Expand Up @@ -72,6 +74,16 @@ func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) {
}

c := &conn{s: s}

if config.MessageBufferSize > 0 {
c.pool = &sync.Pool{
New: func() any {
b := make([]byte, config.MessageBufferSize)
return &b
},
}
}

if config.Strict {
// The caller has requested the strict option set. Historically we have
// recommended checking for ENOPROTOOPT if the kernel does not support
Expand Down Expand Up @@ -133,22 +145,33 @@ func (c *conn) Receive() ([]Message, error) {
return msgs, nil
}

// getMsgBufferSize peeks at the upcoming message to determine the size of the
// buffer needed to read it.
func (c *conn) getMsgBufferSize() (int, error) {
// getBuffer returns the buffer to use for receiving messages and a function to
// release it back to the pool if applicable. If the pool is not configured, a
// new buffer is allocated by peeking the size of the next message to be
// received.
func (c *conn) getBuffer() ([]byte, func(), error) {
if c.pool != nil {
bp := c.pool.Get().(*[]byte)
return *bp, func() { c.pool.Put(bp) }, nil
}

n, _, _, _, err := c.s.Recvmsg(context.Background(), nil, nil, unix.MSG_PEEK|unix.MSG_TRUNC)
return n, err
if err != nil {
return nil, nil, err
}

return make([]byte, n), func() {}, nil
}

// ReceiveIter returns an iterator over Messages received from netlink.
func (c *conn) ReceiveIter() iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
n, err := c.getMsgBufferSize()
b, release, err := c.getBuffer()
if err != nil {
yield(Message{}, err)
return
}
b := make([]byte, n)
defer release()

// Read out all available messages
// TODO(mdlayher): deal with OOB message data if available, such as
Expand All @@ -163,7 +186,7 @@ func (c *conn) ReceiveIter() iter.Seq2[Message, error] {
// Our buffer was too small to read the entire message,
// this should not happen since we peeked above, but if it does,
// return an error.
yield(Message{}, unix.ENOSPC)
yield(Message{}, errMessageTruncated)
return
}

Expand Down
42 changes: 42 additions & 0 deletions conn_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,48 @@ func TestIntegrationConnStrict(t *testing.T) {
}
}

func TestIntegrationConnMessageBufferSize(t *testing.T) {
tests := []struct {
name string
cfg *netlink.Config
wantErr bool
}{
{
name: "valid message buffer size",
cfg: &netlink.Config{MessageBufferSize: 8192},
},
{
name: "invalid message buffer size",
cfg: &netlink.Config{MessageBufferSize: 1},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := netlink.Dial(unix.NETLINK_GENERIC, tt.cfg)
if err != nil {
t.Fatalf("failed to dial netlink: %v", err)
}
defer c.Close()

req := netlink.Message{
Header: netlink.Header{
Flags: netlink.Request | netlink.Acknowledge,
},
}

_, err = c.Execute(req)
if tt.wantErr && err == nil {
t.Fatal("expected an error, but none occurred")
}
if !tt.wantErr && err != nil {
t.Fatalf("failed to execute request: %v", err)
}
})
}
}

func mustBeTimeoutNetError(t *testing.T, err error) {
t.Helper()

Expand Down
4 changes: 4 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ var (

var errNotSupported = errors.New("operation not supported")

// errMessageTruncated is returned when a received message was truncated
// because the receive buffer was too small.
var errMessageTruncated = errors.New("message truncated")

// notSupported provides a concise constructor for "not supported" errors.
func notSupported(op string) error {
return newOpError(op, errNotSupported)
Expand Down
Loading