diff --git a/conn.go b/conn.go index c986c13..606b051 100644 --- a/conn.go +++ b/conn.go @@ -185,37 +185,34 @@ func (cc *Conn) receiveSeq(conn *netlink.Conn) iter.Seq2[netlink.Message, error] break } - replies, err := conn.Receive() - if err != nil { - // Yield the error but continue iterating - if !yield(netlink.Message{}, err) { - return + for res, err := range conn.ReceiveIter() { + if err != nil { + // Yield the error but continue iterating + if !yield(netlink.Message{}, err) { + return + } } - continue - } - if len(replies) == 0 && cc.TestDial != nil { - // When using a test dial function, we don't always get a reply for each - // sent message. Additionally, there is no buffer to poll for more data, - // so we stop here. - return - } - - for _, msg := range replies { // Filter out non-nftables messages. // In practice, this would only be netlink.Error messages. // Those are handled by the netlink library itself and should be - // reported as errors by conn.Receive(). - subsystem := msg.Header.Type >> 8 + // reported as errors by conn.ReceiveIter(). + subsystem := res.Header.Type >> 8 if subsystem != unix.NFNL_SUBSYS_NFTABLES { continue } - // Stop iteration if yield returns false - if !yield(msg, nil) { + if !yield(res, nil) { return } } + + if cc.TestDial != nil { + // When using a test dial function, we don't always get a reply for each + // sent message. Additionally, there is no buffer to poll for more data, + // so we stop here. + return + } } } } diff --git a/go.mod b/go.mod index c862aa1..e8e0f84 100644 --- a/go.mod +++ b/go.mod @@ -15,3 +15,5 @@ require ( golang.org/x/net v0.43.0 // indirect golang.org/x/sync v0.6.0 // indirect ) + +replace github.com/mdlayher/netlink => ../netlink diff --git a/rule.go b/rule.go index cd4fd39..3e842d0 100644 --- a/rule.go +++ b/rule.go @@ -170,20 +170,25 @@ func (cc *Conn) getRules(t *Table, c *Chain, msgType nftMsgType, handle uint64) return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := cc.receive(conn) - if err != nil { - return nil, fmt.Errorf("receive: %w", err) - } var rules []*Rule - for _, msg := range reply { - r, err := ruleFromMsg(t.Family, msg) - if err != nil { - return nil, err + var firstErr error + + for msg, err := range cc.receiveSeq(conn) { + if err != nil && firstErr == nil { + firstErr = err + continue } - rules = append(rules, r) + + rule, err := ruleFromMsg(t.Family, msg) + if err != nil && firstErr == nil { + firstErr = err + continue + } + + rules = append(rules, rule) } - return rules, nil + return rules, firstErr } func (cc *Conn) newRule(r *Rule, op ruleOperation) *Rule { diff --git a/set.go b/set.go index 38997e1..4faabf8 100644 --- a/set.go +++ b/set.go @@ -1106,20 +1106,24 @@ func (cc *Conn) getSetElements(s *Set, e []SetElement, reset bool) ([]SetElement return nil, fmt.Errorf("SendMessages: %v", err) } - reply, err := cc.receive(conn) - if err != nil { - return nil, fmt.Errorf("receive: %w", err) - } var elems []SetElement - for _, msg := range reply { - s, err := elementsFromMsg(uint8(s.Table.Family), msg) - if err != nil { - return nil, err + var firstErr error + for msg, err := range cc.receiveSeq(conn) { + if err != nil && firstErr == nil { + firstErr = err + continue } - elems = append(elems, s...) + + e, err := elementsFromMsg(uint8(s.Table.Family), msg) + if err != nil && firstErr == nil { + firstErr = err + continue + } + + elems = append(elems, e...) } - return elems, nil + return elems, firstErr } // GetSetElements returns the elements in the specified set.