Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 145 additions & 20 deletions ska_shell/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
"""Utilities to run subprocesses"""

import datetime
import errno
import functools
import logging
import re
import os
import sys
import pty
import re
import select
import signal
import subprocess
import sys
import termios


class ShellError(Exception):
Expand Down Expand Up @@ -112,6 +116,69 @@ def communicate(process, logfile=None, logger=None, log_level=None):
return lines


def _communicate_pty(process, master_fd, logfile=None, logger=None, log_level=None):
"""
Real-time reading of subprocess output from a pty master file descriptor.

Used when ``run_shell`` is invoked with ``unbuffered=True`` to avoid the
libc block-buffering that happens when the child's stdout is a pipe.

Parameters
----------
:param process: process returned by subprocess.Popen
:param master_fd: pty master file descriptor to read from
:param logfile: append output to the supplied file object (flushed per line)
:param logger: log output to the supplied logging.Logger
:param log_level: log level for logger
"""
log_level = "INFO" if log_level is None else log_level
log_level = getattr(logging, log_level) if type(log_level) is str else log_level

lines = []
buf = b""

def _emit(raw_bytes):
line = raw_bytes.decode(errors="replace").rstrip("\r")
if logfile:
logfile.write(line + "\n")
logfile.flush()
if logger is not None:
logger.log(log_level, line)
lines.append(line)

def _read_chunk():
try:
return os.read(master_fd, 4096)
except OSError as e:
if e.errno == errno.EIO:
return b""
raise

while True:
r, _, _ = select.select([master_fd], [], [], 0.1)
if r:
chunk = _read_chunk()
if not chunk:
break
buf += chunk
while b"\n" in buf:
raw, buf = buf.split(b"\n", 1)
_emit(raw)
elif process.poll() is not None:
chunk = _read_chunk()
if not chunk:
break
buf += chunk

while b"\n" in buf:
raw, buf = buf.split(b"\n", 1)
_emit(raw)
if buf:
_emit(buf)

return lines


@functools.cache
def _shell_ok(shell):
p = subprocess.run(["which", shell], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand All @@ -128,6 +195,7 @@ def run_shell(
logger=None,
log_level=None,
check=None,
unbuffered=False,
):
"""
Run the command string ``cmdstr`` in a ``shell`` ('bash' or 'tcsh'). It can have
Expand All @@ -142,6 +210,8 @@ def run_shell(
:param getenv: get the environent changes after running ``cmdstr``
:param env: set environment using ``env`` dict prior to running commands
:param check: raise an exception if any command fails
:param unbuffered: stream output in real time via a pty so ``logfile`` / ``logger``
see each line as it is produced instead of at process exit

:rtype: (outlines, deltaenv)
"""
Expand Down Expand Up @@ -169,21 +239,50 @@ def run_shell(
elif shell in ["bash", "zsh"] and check:
actual_cmdstr = f"set -e; {actual_cmdstr}"

proc = subprocess.Popen(
[actual_cmdstr],
executable=actual_shell,
shell=True,
env=environ,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
if logfile:
time = datetime.datetime.now().isoformat()[:22]
logfile.write(f"{shell.capitalize()}-{time}> {cmdstr}\n")
stdout = communicate(proc, logfile=logfile, logger=logger, log_level=log_level)
if logfile:
time = datetime.datetime.now().isoformat()[:22]
logfile.write(f"{shell.capitalize()}-{time}>\n")
if unbuffered:
master_fd, slave_fd = pty.openpty()
try:
attrs = termios.tcgetattr(slave_fd)
attrs[1] &= ~termios.OPOST
termios.tcsetattr(slave_fd, termios.TCSANOW, attrs)
proc = subprocess.Popen(
[actual_cmdstr],
executable=actual_shell,
shell=True,
env=environ,
stdout=slave_fd,
stderr=slave_fd,
)
finally:
os.close(slave_fd)
else:
master_fd = None
proc = subprocess.Popen(
[actual_cmdstr],
executable=actual_shell,
shell=True,
env=environ,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)

try:
if logfile:
time = datetime.datetime.now().isoformat()[:22]
logfile.write(f"{shell.capitalize()}-{time}> {cmdstr}\n")
if master_fd is None:
stdout = communicate(proc, logfile=logfile, logger=logger, log_level=log_level)
else:
stdout = _communicate_pty(
proc, master_fd, logfile=logfile, logger=logger, log_level=log_level
)
if logfile:
time = datetime.datetime.now().isoformat()[:22]
logfile.write(f"{shell.capitalize()}-{time}>\n")
finally:
if master_fd is not None:
os.close(master_fd)
proc.wait()
if check and proc.returncode:
msg = " ".join(stdout[-1:]) # stdout could be empty
exc = NonZeroReturnCode(
Expand Down Expand Up @@ -224,7 +323,8 @@ def bash_shell(
env=None,
logger=None,
log_level=None,
check=None
check=None,
unbuffered=False,
):
"""
Run the command string ``cmdstr`` in a bash shell. It can have
Expand All @@ -239,6 +339,7 @@ def bash_shell(
:param env: set environment using ``env`` dict prior to running commands
:param logger: log output to the supplied logging.Logger
:param log_level: log level for logger
:param unbuffered: stream output in real time via a pty

:rtype: (outlines, deltaenv)
"""
Expand All @@ -252,11 +353,21 @@ def bash_shell(
logger=logger,
log_level=log_level,
check=check,
unbuffered=unbuffered,
)
return outlines, newenv


def bash(cmdstr, logfile=None, importenv=False, env=None, logger=None, log_level=None, check=None):
def bash(
cmdstr,
logfile=None,
importenv=False,
env=None,
logger=None,
log_level=None,
check=None,
unbuffered=False,
):
"""Run the ``cmdstr`` string in a bash shell. See ``run_shell`` for options.

:returns: bash output
Expand All @@ -270,10 +381,20 @@ def bash(cmdstr, logfile=None, importenv=False, env=None, logger=None, log_level
logger=logger,
log_level=log_level,
check=check,
unbuffered=unbuffered,
)[0]


def tcsh(cmdstr, logfile=None, importenv=False, env=None, logger=None, log_level=None, check=None):
def tcsh(
cmdstr,
logfile=None,
importenv=False,
env=None,
logger=None,
log_level=None,
check=None,
unbuffered=False,
):
"""Run the ``cmdstr`` string in a tcsh shell. See ``run_shell`` for options.

:returns: tcsh output
Expand All @@ -287,6 +408,7 @@ def tcsh(cmdstr, logfile=None, importenv=False, env=None, logger=None, log_level
logger=logger,
log_level=log_level,
check=check,
unbuffered=unbuffered,
)[0]


Expand All @@ -299,6 +421,7 @@ def tcsh_shell(
logger=None,
log_level=None,
check=None,
unbuffered=False,
):
"""
Run the command string ``cmdstr`` in a tcsh shell. It can have
Expand All @@ -313,6 +436,7 @@ def tcsh_shell(
:param env: set environment using ``env`` dict prior to running commands
:param logger: log output to the supplied logging.Logger
:param log_level: log level for logger
:param unbuffered: stream output in real time via a pty

:rtype: (outlines, deltaenv)
"""
Expand All @@ -326,6 +450,7 @@ def tcsh_shell(
logger=logger,
log_level=log_level,
check=check,
unbuffered=unbuffered,
)
return outlines, newenv

Expand Down
45 changes: 45 additions & 0 deletions ska_shell/tests/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import os
import threading
import time

import pytest
from six.moves import cStringIO as StringIO
Expand Down Expand Up @@ -229,6 +231,49 @@ def test_check(self):
out = bash("lsd; echo DONE", check=True)


@pytest.mark.parametrize("shell_fn", [bash_shell, tcsh_shell])
def test_unbuffered_streaming(shell_fn, tmpdir):
"""With unbuffered=True, output must appear in the logfile as lines are
produced, not all at once when the process exits."""
logpath = tmpdir.join("stream.log")
snapshot = {}

def peek():
time.sleep(0.4)
snapshot["content"] = logpath.read() if logpath.exists() else ""

# Use sentinels that don't appear in the echoed command header.
cmd = "echo __alpha__\nsleep 1\necho __beta__"
t = threading.Thread(target=peek)
with open(str(logpath), "w", buffering=1) as fh:
t.start()
shell_fn(cmd, unbuffered=True, logfile=fh)
t.join()

mid_lines = snapshot["content"].splitlines()
assert "__alpha__" in mid_lines, f"'__alpha__' line should be present at 0.4s, got: {mid_lines!r}"
assert "__beta__" not in mid_lines, f"'__beta__' line should not be present at 0.4s, got: {mid_lines!r}"

full_lines = logpath.read().splitlines()
assert "__alpha__" in full_lines
assert "__beta__" in full_lines


@pytest.mark.parametrize("shell", ["bash", "tcsh"])
def test_unbuffered_getenv(shell):
"""unbuffered=True must not break __PRINTENV__ parsing for getenv/importenv."""
setter = 'export UNBUF_VAR="value"' if shell == "bash" else 'setenv UNBUF_VAR "value"'
_, env = run_shell(setter, shell=shell, unbuffered=True, getenv=True)
assert env.get("UNBUF_VAR") == "value"


@pytest.mark.parametrize("shell", ["bash", "tcsh"])
def test_unbuffered_lines_no_cr(shell):
"""pty OPOST is disabled, so lines must not carry stray '\\r' chars."""
outlines, _ = run_shell("echo one\necho two", shell=shell, unbuffered=True)
assert outlines == ["one", "two"]


@pytest.mark.parametrize("shell", ["bash", "tcsh"])
def test_err(shell):
cmds = """
Expand Down