Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions console.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ import (
"unicode/utf8"

"github.com/creack/pty"
"github.com/muesli/cancelreader"
)

// Console is an interface to automate input and output for interactive
// applications. Console can block until a specified output is received and send
// input back on it's tty. Console can also multiplex other sources of input
// and multiplex its output to other writers.
type Console struct {
opts ConsoleOpts
ptm *os.File
pts *os.File
passthroughPipe *PassthroughPipe
runeReader *bufio.Reader
closers []io.Closer
opts ConsoleOpts
ptm *os.File
pts *os.File
cancelReader cancelreader.CancelReader
runeReader *bufio.Reader
closers []io.Closer
}

// ConsoleOpt allows setting Console options.
Expand Down Expand Up @@ -151,19 +152,16 @@ func NewConsole(opts ...ConsoleOpt) (*Console, error) {
}
closers := append(options.Closers, pts, ptm)

passthroughPipe, err := NewPassthroughPipe(ptm)
if err != nil {
return nil, err
c := &Console{
opts: options,
ptm: ptm,
pts: pts,
closers: closers,
}
closers = append(closers, passthroughPipe)

c := &Console{
opts: options,
ptm: ptm,
pts: pts,
passthroughPipe: passthroughPipe,
runeReader: bufio.NewReaderSize(passthroughPipe, utf8.UTFMax),
closers: closers,
err = c.reset()
if err != nil {
return nil, err
}

for _, stdin := range options.Stdins {
Expand All @@ -178,6 +176,16 @@ func NewConsole(opts ...ConsoleOpt) (*Console, error) {
return c, nil
}

func (c *Console) reset() error {
var err error
c.cancelReader, err = cancelreader.NewReader(c.ptm)
if err != nil {
return err
}
c.runeReader = bufio.NewReaderSize(c.cancelReader, utf8.UTFMax)
return nil
}

// Tty returns Console's pts (slave part of a pty). A pseudoterminal, or pty is
// a pair of psuedo-devices, one of which, the slave, emulates a real text
// terminal device.
Expand Down
34 changes: 30 additions & 4 deletions expect.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ package expect
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"time"
"unicode/utf8"

"github.com/muesli/cancelreader"
)

var (
// ErrTimeout is returned if expect read timeout is reached before
// the expect condition is reached.
ErrTimeout = errors.New("read timeout")
)

// Expectf reads from the Console's tty until the provided formatted string
Expand Down Expand Up @@ -78,16 +87,33 @@ func (c *Console) Expect(opts ...ExpectOpt) (string, error) {
}()

for {
var readCh chan struct{}
if readTimeout != nil {
err = c.passthroughPipe.SetReadDeadline(time.Now().Add(*readTimeout))
if err != nil {
return buf.String(), err
}
readCh = make(chan struct{})
go func() {
timer := time.NewTimer(*readTimeout)
select {
case <-readCh:
timer.Stop()
case <-timer.C:
c.cancelReader.Cancel()
}
}()
}

var r rune
r, _, err = c.runeReader.ReadRune()
if readCh != nil {
close(readCh)
}
if err != nil {
if errors.Is(err, cancelreader.ErrCanceled) {
rerr := c.reset()
if rerr != nil {
return buf.String(), rerr
}
err = ErrTimeout
}
matcher = options.Match(err)
if matcher != nil {
err = nil
Expand Down
80 changes: 39 additions & 41 deletions expect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"os/exec"
"runtime/debug"
"strings"
"sync"
"testing"
"time"
)
Expand Down Expand Up @@ -123,10 +122,9 @@ func TestExpectf(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
c.Expectf("What is 1+%d?", 1)
c.SendLine("2")
c.Expectf("What is %s backwards?", "Netflix")
Expand All @@ -139,7 +137,7 @@ func TestExpectf(t *testing.T) {
t.Errorf("Expected no error but got '%s'", err)
}
testCloser(t, c.Tty())
wg.Wait()
waitTestEnd(t, done)
}

func TestExpect(t *testing.T) {
Expand All @@ -151,10 +149,9 @@ func TestExpect(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
c.ExpectString("What is 1+1?")
c.SendLine("2")
c.ExpectString("What is Netflix backwards?")
Expand All @@ -168,7 +165,15 @@ func TestExpect(t *testing.T) {
}
// close the pts so we can expect EOF
testCloser(t, c.Tty())
wg.Wait()
waitTestEnd(t, done)
}

func waitTestEnd(t *testing.T, done <-chan struct{}) {
select {
case <-done:
case <-time.After(3 * time.Second):
t.Error("Expected test to end within 3s")
}
}

func TestExpectOutput(t *testing.T) {
Expand All @@ -180,10 +185,9 @@ func TestExpectOutput(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
c.ExpectString("What is 1+1?")
c.SendLine("3")
c.ExpectEOF()
Expand All @@ -194,7 +198,7 @@ func TestExpectOutput(t *testing.T) {
t.Errorf("Expected error '%s' but got '%s' instead", ErrWrongAnswer, err)
}
testCloser(t, c.Tty())
wg.Wait()
waitTestEnd(t, done)
}

func TestExpectDefaultTimeout(t *testing.T) {
Expand All @@ -206,21 +210,20 @@ func TestExpectDefaultTimeout(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
Prompt(c.Tty(), c.Tty())
}()

_, err = c.ExpectString("What is 1+2?")
if err == nil || !strings.Contains(err.Error(), "i/o timeout") {
t.Errorf("Expected error to contain 'i/o timeout' but got '%s' instead", err)
if err == nil || !errors.Is(err, ErrTimeout) {
t.Errorf("Expected error to be ErrTimeout but got '%s' instead", err)
}

// Close to unblock Prompt and wait for the goroutine to exit.
c.Tty().Close()
wg.Wait()
waitTestEnd(t, done)
}

func TestExpectTimeout(t *testing.T) {
Expand All @@ -232,21 +235,20 @@ func TestExpectTimeout(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
Prompt(c.Tty(), c.Tty())
}()

_, err = c.Expect(String("What is 1+2?"), WithTimeout(0))
if err == nil || !strings.Contains(err.Error(), "i/o timeout") {
t.Errorf("Expected error to contain 'i/o timeout' but got '%s' instead", err)
if err == nil || !errors.Is(err, ErrTimeout) {
t.Errorf("Expected error to be ErrTimeout but got '%s' instead", err)
}

// Close to unblock Prompt and wait for the goroutine to exit.
c.Tty().Close()
wg.Wait()
waitTestEnd(t, done)
}

func TestExpectDefaultTimeoutOverride(t *testing.T) {
Expand All @@ -258,10 +260,9 @@ func TestExpectDefaultTimeoutOverride(t *testing.T) {
}
defer testCloser(t, c)

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
err = Prompt(c.Tty(), c.Tty())
if err != nil {
t.Errorf("Expected no error but got '%s'", err)
Expand All @@ -276,7 +277,7 @@ func TestExpectDefaultTimeoutOverride(t *testing.T) {
c.SendLine("xilfteN")
c.Expect(EOF, PTSClosed, WithTimeout(time.Second))

wg.Wait()
waitTestEnd(t, done)
}

func TestConsoleChain(t *testing.T) {
Expand All @@ -288,10 +289,9 @@ func TestConsoleChain(t *testing.T) {
}
defer testCloser(t, c1)

var wg1 sync.WaitGroup
wg1.Add(1)
done1 := make(chan struct{})
go func() {
defer wg1.Done()
defer close(done1)
c1.ExpectString("What is Netflix backwards?")
c1.SendLine("xilfteN")
c1.ExpectEOF()
Expand All @@ -303,10 +303,9 @@ func TestConsoleChain(t *testing.T) {
}
defer testCloser(t, c2)

var wg2 sync.WaitGroup
wg2.Add(1)
done2 := make(chan struct{})
go func() {
defer wg2.Done()
defer close(done2)
c2.ExpectString("What is 1+1?")
c2.SendLine("2")
c2.ExpectEOF()
Expand All @@ -318,10 +317,10 @@ func TestConsoleChain(t *testing.T) {
}

testCloser(t, c2.Tty())
wg2.Wait()
waitTestEnd(t, done2)

testCloser(t, c1.Tty())
wg1.Wait()
waitTestEnd(t, done1)
}

func TestEditor(t *testing.T) {
Expand All @@ -346,10 +345,9 @@ func TestEditor(t *testing.T) {
cmd.Stdout = c.Tty()
cmd.Stderr = c.Tty()

var wg sync.WaitGroup
wg.Add(1)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done)
c.Send("iHello world\x1b")
c.SendLine(":wq!")
c.ExpectEOF()
Expand All @@ -361,7 +359,7 @@ func TestEditor(t *testing.T) {
}

testCloser(t, c.Tty())
wg.Wait()
waitTestEnd(t, done)

data, err := ioutil.ReadFile(file.Name())
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ go 1.13

require (
github.com/creack/pty v1.1.17
github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4
github.com/stretchr/testify v1.6.1
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI=
github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4 h1:E6Wlth8ZhQOO+cyKvDvyuzHJSxFzoNUfbjtd51UHQrs=
github.com/muesli/cancelreader v0.1.1-0.20220210172959-fc0e97880fe4/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a h1:ppl5mZgokTT8uPkmYOyEUmPTr3ypaKkg5eFOGrAmxxE=
golang.org/x/sys v0.0.0-20220204135822-1c1b9b1eba6a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
Expand Down
Loading