From 1df77928c7bfd02ee6241d2c5eb4f2a43e5053ac Mon Sep 17 00:00:00 2001 From: Edgar Lee Date: Wed, 9 Feb 2022 21:47:10 -0800 Subject: [PATCH] Use muesli/cancelreader for cancellable readers without data being consumed Signed-off-by: Edgar Lee --- console.go | 42 +++++++++++------- expect.go | 34 ++++++++++++-- expect_test.go | 80 +++++++++++++++++---------------- go.mod | 1 + go.sum | 4 ++ passthrough_pipe.go | 95 ---------------------------------------- passthrough_pipe_test.go | 53 ---------------------- reader_lease.go | 87 ------------------------------------ reader_lease_test.go | 64 --------------------------- 9 files changed, 99 insertions(+), 361 deletions(-) delete mode 100644 passthrough_pipe.go delete mode 100644 passthrough_pipe_test.go delete mode 100644 reader_lease.go delete mode 100644 reader_lease_test.go diff --git a/console.go b/console.go index 26deead..df7f776 100644 --- a/console.go +++ b/console.go @@ -25,6 +25,7 @@ import ( "unicode/utf8" "github.com/creack/pty" + "github.com/muesli/cancelreader" ) // Console is an interface to automate input and output for interactive @@ -32,12 +33,12 @@ import ( // input back on it's tty. Console can also multiplex other sources of input // and multiplex its output to other writers. type Console struct { - opts ConsoleOpts - ptm *os.File - pts *os.File - passthroughPipe *PassthroughPipe - runeReader *bufio.Reader - closers []io.Closer + opts ConsoleOpts + ptm *os.File + pts *os.File + cancelReader cancelreader.CancelReader + runeReader *bufio.Reader + closers []io.Closer } // ConsoleOpt allows setting Console options. @@ -151,19 +152,16 @@ func NewConsole(opts ...ConsoleOpt) (*Console, error) { } closers := append(options.Closers, pts, ptm) - passthroughPipe, err := NewPassthroughPipe(ptm) - if err != nil { - return nil, err + c := &Console{ + opts: options, + ptm: ptm, + pts: pts, + closers: closers, } - closers = append(closers, passthroughPipe) - c := &Console{ - opts: options, - ptm: ptm, - pts: pts, - passthroughPipe: passthroughPipe, - runeReader: bufio.NewReaderSize(passthroughPipe, utf8.UTFMax), - closers: closers, + err = c.reset() + if err != nil { + return nil, err } for _, stdin := range options.Stdins { @@ -178,6 +176,16 @@ func NewConsole(opts ...ConsoleOpt) (*Console, error) { return c, nil } +func (c *Console) reset() error { + var err error + c.cancelReader, err = cancelreader.NewReader(c.ptm) + if err != nil { + return err + } + c.runeReader = bufio.NewReaderSize(c.cancelReader, utf8.UTFMax) + return nil +} + // Tty returns Console's pts (slave part of a pty). A pseudoterminal, or pty is // a pair of psuedo-devices, one of which, the slave, emulates a real text // terminal device. diff --git a/expect.go b/expect.go index b99b326..1fbbc40 100644 --- a/expect.go +++ b/expect.go @@ -17,10 +17,19 @@ package expect import ( "bufio" "bytes" + "errors" "fmt" "io" "time" "unicode/utf8" + + "github.com/muesli/cancelreader" +) + +var ( + // ErrTimeout is returned if expect read timeout is reached before + // the expect condition is reached. + ErrTimeout = errors.New("read timeout") ) // Expectf reads from the Console's tty until the provided formatted string @@ -78,16 +87,33 @@ func (c *Console) Expect(opts ...ExpectOpt) (string, error) { }() for { + var readCh chan struct{} if readTimeout != nil { - err = c.passthroughPipe.SetReadDeadline(time.Now().Add(*readTimeout)) - if err != nil { - return buf.String(), err - } + readCh = make(chan struct{}) + go func() { + timer := time.NewTimer(*readTimeout) + select { + case <-readCh: + timer.Stop() + case <-timer.C: + c.cancelReader.Cancel() + } + }() } var r rune r, _, err = c.runeReader.ReadRune() + if readCh != nil { + close(readCh) + } if err != nil { + if errors.Is(err, cancelreader.ErrCanceled) { + rerr := c.reset() + if rerr != nil { + return buf.String(), rerr + } + err = ErrTimeout + } matcher = options.Match(err) if matcher != nil { err = nil diff --git a/expect_test.go b/expect_test.go index 0ecee34..fdade4d 100644 --- a/expect_test.go +++ b/expect_test.go @@ -25,7 +25,6 @@ import ( "os/exec" "runtime/debug" "strings" - "sync" "testing" "time" ) @@ -123,10 +122,9 @@ func TestExpectf(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) c.Expectf("What is 1+%d?", 1) c.SendLine("2") c.Expectf("What is %s backwards?", "Netflix") @@ -139,7 +137,7 @@ func TestExpectf(t *testing.T) { t.Errorf("Expected no error but got '%s'", err) } testCloser(t, c.Tty()) - wg.Wait() + waitTestEnd(t, done) } func TestExpect(t *testing.T) { @@ -151,10 +149,9 @@ func TestExpect(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) c.ExpectString("What is 1+1?") c.SendLine("2") c.ExpectString("What is Netflix backwards?") @@ -168,7 +165,15 @@ func TestExpect(t *testing.T) { } // close the pts so we can expect EOF testCloser(t, c.Tty()) - wg.Wait() + waitTestEnd(t, done) +} + +func waitTestEnd(t *testing.T, done <-chan struct{}) { + select { + case <-done: + case <-time.After(3 * time.Second): + t.Error("Expected test to end within 3s") + } } func TestExpectOutput(t *testing.T) { @@ -180,10 +185,9 @@ func TestExpectOutput(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) c.ExpectString("What is 1+1?") c.SendLine("3") c.ExpectEOF() @@ -194,7 +198,7 @@ func TestExpectOutput(t *testing.T) { t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err) } testCloser(t, c.Tty()) - wg.Wait() + waitTestEnd(t, done) } func TestExpectDefaultTimeout(t *testing.T) { @@ -206,21 +210,20 @@ func TestExpectDefaultTimeout(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) Prompt(c.Tty(), c.Tty()) }() _, err = c.ExpectString("What is 1+2?") - if err == nil || !strings.Contains(err.Error(), "i/o timeout") { - t.Errorf("Expected error to contain 'i/o timeout' but got '%s' instead", err) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Errorf("Expected error to be ErrTimeout but got '%s' instead", err) } // Close to unblock Prompt and wait for the goroutine to exit. c.Tty().Close() - wg.Wait() + waitTestEnd(t, done) } func TestExpectTimeout(t *testing.T) { @@ -232,21 +235,20 @@ func TestExpectTimeout(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) Prompt(c.Tty(), c.Tty()) }() _, err = c.Expect(String("What is 1+2?"), WithTimeout(0)) - if err == nil || !strings.Contains(err.Error(), "i/o timeout") { - t.Errorf("Expected error to contain 'i/o timeout' but got '%s' instead", err) + if err == nil || !errors.Is(err, ErrTimeout) { + t.Errorf("Expected error to be ErrTimeout but got '%s' instead", err) } // Close to unblock Prompt and wait for the goroutine to exit. c.Tty().Close() - wg.Wait() + waitTestEnd(t, done) } func TestExpectDefaultTimeoutOverride(t *testing.T) { @@ -258,10 +260,9 @@ func TestExpectDefaultTimeoutOverride(t *testing.T) { } defer testCloser(t, c) - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) err = Prompt(c.Tty(), c.Tty()) if err != nil { t.Errorf("Expected no error but got '%s'", err) @@ -276,7 +277,7 @@ func TestExpectDefaultTimeoutOverride(t *testing.T) { c.SendLine("xilfteN") c.Expect(EOF, PTSClosed, WithTimeout(time.Second)) - wg.Wait() + waitTestEnd(t, done) } func TestConsoleChain(t *testing.T) { @@ -288,10 +289,9 @@ func TestConsoleChain(t *testing.T) { } defer testCloser(t, c1) - var wg1 sync.WaitGroup - wg1.Add(1) + done1 := make(chan struct{}) go func() { - defer wg1.Done() + defer close(done1) c1.ExpectString("What is Netflix backwards?") c1.SendLine("xilfteN") c1.ExpectEOF() @@ -303,10 +303,9 @@ func TestConsoleChain(t *testing.T) { } defer testCloser(t, c2) - var wg2 sync.WaitGroup - wg2.Add(1) + done2 := make(chan struct{}) go func() { - defer wg2.Done() + defer close(done2) c2.ExpectString("What is 1+1?") c2.SendLine("2") c2.ExpectEOF() @@ -318,10 +317,10 @@ func TestConsoleChain(t *testing.T) { } testCloser(t, c2.Tty()) - wg2.Wait() + waitTestEnd(t, done2) testCloser(t, c1.Tty()) - wg1.Wait() + waitTestEnd(t, done1) } func TestEditor(t *testing.T) { @@ -346,10 +345,9 @@ func TestEditor(t *testing.T) { cmd.Stdout = c.Tty() cmd.Stderr = c.Tty() - var wg sync.WaitGroup - wg.Add(1) + done := make(chan struct{}) go func() { - defer wg.Done() + defer close(done) c.Send("iHello world\x1b") c.SendLine(":wq!") c.ExpectEOF() @@ -361,7 +359,7 @@ func TestEditor(t *testing.T) { } testCloser(t, c.Tty()) - wg.Wait() + waitTestEnd(t, done) data, err := ioutil.ReadFile(file.Name()) if err != nil { diff --git a/go.mod b/go.mod index 4e0e2a1..2245ccb 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.13 require ( github.com/creack/pty v1.1.17 + github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4 github.com/stretchr/testify v1.6.1 ) diff --git a/go.sum b/go.sum index 5eeba84..c1f307d 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,15 @@ github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4 h1:E6Wlth8ZhQOO+cyKvDvyuzHJSxFzoNUfbjtd51UHQrs= +github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE= +golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/passthrough_pipe.go b/passthrough_pipe.go deleted file mode 100644 index 0056075..0000000 --- a/passthrough_pipe.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package expect - -import ( - "io" - "os" - "time" -) - -// PassthroughPipe is pipes data from a io.Reader and allows setting a read -// deadline. If a timeout is reached the error is returned, otherwise the error -// from the provided io.Reader is returned is passed through instead. -type PassthroughPipe struct { - reader *os.File - errC chan error -} - -// NewPassthroughPipe returns a new pipe for a io.Reader that passes through -// non-timeout errors. -func NewPassthroughPipe(reader io.Reader) (*PassthroughPipe, error) { - pipeReader, pipeWriter, err := os.Pipe() - if err != nil { - return nil, err - } - - errC := make(chan error, 1) - go func() { - defer close(errC) - _, readerErr := io.Copy(pipeWriter, reader) - if readerErr == nil { - // io.Copy reads from reader until EOF, and a successful Copy returns - // err == nil. We set it back to io.EOF to surface the error to Expect. - readerErr = io.EOF - } - - // Closing the pipeWriter will unblock the pipeReader.Read. - err = pipeWriter.Close() - if err != nil { - // If we are unable to close the pipe, and the pipe isn't already closed, - // the caller will hang indefinitely. - panic(err) - return - } - - // When an error is read from reader, we need it to passthrough the err to - // callers of (*PassthroughPipe).Read. - errC <- readerErr - }() - - return &PassthroughPipe{ - reader: pipeReader, - errC: errC, - }, nil -} - -func (pp *PassthroughPipe) Read(p []byte) (n int, err error) { - n, err = pp.reader.Read(p) - if err != nil { - if os.IsTimeout(err) { - return n, err - } - - // If the pipe is closed, this is the second time calling Read on - // PassthroughPipe, so just return the error from the os.Pipe io.Reader. - perr, ok := <-pp.errC - if !ok { - return n, err - } - - return n, perr - } - - return n, nil -} - -func (pp *PassthroughPipe) Close() error { - return pp.reader.Close() -} - -func (pp *PassthroughPipe) SetReadDeadline(t time.Time) error { - return pp.reader.SetReadDeadline(t) -} diff --git a/passthrough_pipe_test.go b/passthrough_pipe_test.go deleted file mode 100644 index 8553a4d..0000000 --- a/passthrough_pipe_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package expect - -import ( - "errors" - "io" - "os" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -func TestPassthroughPipe(t *testing.T) { - r, w := io.Pipe() - - passthroughPipe, err := NewPassthroughPipe(r) - require.NoError(t, err) - - err = passthroughPipe.SetReadDeadline(time.Now().Add(time.Hour)) - require.NoError(t, err) - - pipeError := errors.New("pipe error") - err = w.CloseWithError(pipeError) - require.NoError(t, err) - - p := make([]byte, 1) - _, err = passthroughPipe.Read(p) - require.Equal(t, err, pipeError) -} - -func TestPassthroughPipeTimeout(t *testing.T) { - r, w := io.Pipe() - - passthroughPipe, err := NewPassthroughPipe(r) - require.NoError(t, err) - - err = passthroughPipe.SetReadDeadline(time.Now()) - require.NoError(t, err) - - _, err = w.Write([]byte("a")) - require.NoError(t, err) - - p := make([]byte, 1) - _, err = passthroughPipe.Read(p) - require.True(t, os.IsTimeout(err)) - - err = passthroughPipe.SetReadDeadline(time.Time{}) - require.NoError(t, err) - - n, err := passthroughPipe.Read(p) - require.Equal(t, 1, n) - require.NoError(t, err) -} diff --git a/reader_lease.go b/reader_lease.go deleted file mode 100644 index 50180de..0000000 --- a/reader_lease.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2018 Netflix, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package expect - -import ( - "context" - "fmt" - "io" -) - -// ReaderLease provides cancellable io.Readers from an underlying io.Reader. -type ReaderLease struct { - reader io.Reader - bytec chan byte -} - -// NewReaderLease returns a new ReaderLease that begins reading the given -// io.Reader. -func NewReaderLease(reader io.Reader) *ReaderLease { - rm := &ReaderLease{ - reader: reader, - bytec: make(chan byte), - } - - go func() { - for { - p := make([]byte, 1) - n, err := rm.reader.Read(p) - if err != nil { - return - } - if n == 0 { - panic("non eof read 0 bytes") - } - rm.bytec <- p[0] - } - }() - - return rm -} - -// NewReader returns a cancellable io.Reader for the underlying io.Reader. -// Readers can be cancelled without interrupting other Readers, and once -// a reader is a cancelled it will not read anymore bytes from ReaderLease's -// underlying io.Reader. -func (rm *ReaderLease) NewReader(ctx context.Context) io.Reader { - return NewChanReader(ctx, rm.bytec) -} - -type chanReader struct { - ctx context.Context - bytec <-chan byte -} - -// NewChanReader returns a io.Reader over a byte chan. If context is cancelled, -// future Reads will return io.EOF. -func NewChanReader(ctx context.Context, bytec <-chan byte) io.Reader { - return &chanReader{ - ctx: ctx, - bytec: bytec, - } -} - -func (cr *chanReader) Read(p []byte) (n int, err error) { - select { - case <-cr.ctx.Done(): - return 0, io.EOF - case b := <-cr.bytec: - if len(p) < 1 { - return 0, fmt.Errorf("cannot read into 0 len byte slice") - } - p[0] = b - return 1, nil - } -} diff --git a/reader_lease_test.go b/reader_lease_test.go deleted file mode 100644 index 401bd8d..0000000 --- a/reader_lease_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package expect - -import ( - "context" - "io" - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestReaderLease(t *testing.T) { - in, out := io.Pipe() - defer out.Close() - defer in.Close() - - rm := NewReaderLease(in) - - tests := []struct { - title string - expected string - }{ - { - "Read cancels with deadline", - "apple", - }, - { - "Second read has no bytes stolen", - "banana", - }, - } - - for _, test := range tests { - t.Run(test.title, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - tin, tout := io.Pipe() - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - io.Copy(tout, rm.NewReader(ctx)) - }() - - wg.Add(1) - go func() { - defer wg.Done() - _, err := out.Write([]byte(test.expected)) - require.Nil(t, err) - }() - - for i := 0; i < len(test.expected); i++ { - p := make([]byte, 1) - n, err := tin.Read(p) - require.Nil(t, err) - require.Equal(t, 1, n) - require.Equal(t, test.expected[i], p[0]) - } - - cancel() - wg.Wait() - }) - } -}