Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions mama/utils/mama_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@
import sys

# Allow running as a standalone script, not just as a package module.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# Important: do NOT put `<...>/mama` on sys.path — `mama/types/` would then
# shadow Python's stdlib `types` module the moment anything (e.g. contextlib)
# does `from types import ...`. Add the package's PARENT instead, so that
# `mama.utils.ssh_multiplex` resolves as a normal qualified import.
if __package__ in (None, ''):
    # Script mode: relative imports are unavailable, so use the fully
    # qualified package path.
    try:
        # Works as-is when `mama` is already importable.
        from mama.utils import ssh_multiplex
    except ImportError:
        # _THIS_DIR is `<parent>/mama/utils`; two levels up is the directory
        # that CONTAINS the `mama` package, which is the only safe entry to
        # add to sys.path.
        _MAMA_PARENT = os.path.dirname(os.path.dirname(_THIS_DIR))
        sys.path.insert(0, _MAMA_PARENT)
        from mama.utils import ssh_multiplex
else:
    # Imported as part of the package: a plain relative import is enough.
    from . import ssh_multiplex

Expand Down
23 changes: 21 additions & 2 deletions mama/utils/ssh_multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import atexit
import contextlib
import functools
import os
import re
import shlex
Expand All @@ -34,6 +35,9 @@
import threading
from urllib.parse import urlparse

from .sub_process import execute_piped
from .system import System


DEFAULT_MAX_CONCURRENT_FETCHES = 20

Expand Down Expand Up @@ -137,15 +141,30 @@ def is_multiplex_configured(probe: dict[str, str]) -> bool:
return cm not in ('no', 'false', '') and cp not in ('none', '', 'no')


@functools.cache
def multiplex_known_broken() -> bool:
    """Report whether ssh ControlMaster multiplexing should be avoided here.

    Off Windows this is always False and ssh is never invoked. On Windows the
    `ssh -V` banner is inspected: Microsoft's OpenSSH for Windows identifies
    itself as `OpenSSH_for_Windows_<ver>` and its ControlMaster is flaky
    (master drops, stale socket blocks reattach), so a banner containing
    `for_windows` (case-insensitive) yields True. Cygwin/MSYS/Git-Bash builds
    report the standard banner and work fine, so they yield False.

    If the probe returns no output at all (None), the answer is also True.
    The result is memoized for the lifetime of the process.
    """
    if System.windows:
        # `ssh -V` prints its banner on stderr, hence merge_stderr=True.
        banner = execute_piped(['ssh', '-V'], timeout=5, throw=False, merge_stderr=True)
        # No banner: ssh missing or failed — be safe and report "broken".
        return banner is None or 'for_windows' in banner.lower()
    return False


def options_to_add(probe: dict[str, str]) -> tuple[list[str], bool]:
"""
Return (-o args, we_own_master). `we_own_master` is True when we are the
one configuring multiplex (and therefore responsible for pre-warming and
cleaning it up). False if the user already has multiplex configured.
cleaning it up). False if the user already has multiplex configured, or
if multiplex is known-broken on this platform.
"""
opts: list[str] = []
we_own_master = False
if not is_multiplex_configured(probe):
if not multiplex_known_broken() and not is_multiplex_configured(probe):
we_own_master = True
os.makedirs(_OUR_CONTROL_DIR, mode=0o700, exist_ok=True)
opts += [
Expand Down
9 changes: 7 additions & 2 deletions mama/utils/sub_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,24 @@ def execute(command, echo=False, throw=True):


# TODO: use new SubProcess.run instead
def execute_piped(command, cwd=None, timeout=None, throw=True):
def execute_piped(command, cwd=None, timeout=None, throw=True, merge_stderr=False):
"""
Executes a command and returns the piped outout string
- command: command string
- cwd: working dir for the subprocess
- timeout: timeout in seconds
- throw: if True, throws exception on status_code != 0
- merge_stderr: if True, stderr is captured into the returned string too.
Useful for tools like `ssh -V` that emit to stderr.
- returns: output string or None if throw=False
"""
if not isinstance(command, list):
command = shlex.split(command)
try:
cp = subprocess.run(command, stdout=subprocess.PIPE, cwd=cwd, timeout=timeout)
cp = subprocess.run(command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT if merge_stderr else None,
cwd=cwd, timeout=timeout)
return cp.stdout.decode('utf-8').rstrip()
except Exception as e:
if throw:
Expand Down
131 changes: 131 additions & 0 deletions tests/test_ssh_multiplex/test_ssh_multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,97 @@ def test_user_has_multiplex_only(self):
assert not any(o.startswith('-oControlPath=') for o in opts)
assert any(o.startswith('-oServerAliveInterval=') for o in opts)

def test_windows_microsoft_ssh_skips_multiplex_keeps_keepalives(self, monkeypatch, tmp_path):
    """Known-broken ssh + no user multiplex: only keepalive options are added."""
    # Microsoft OpenSSH on Windows has unreliable ControlMaster — the
    # master drops mid-session and leaves the socket file behind. Here the
    # detection result is forced to "broken" so multiplex must be skipped.
    control_dir = tmp_path / 'cm'
    monkeypatch.setattr(sm, 'multiplex_known_broken', lambda: True)
    monkeypatch.setattr(sm, '_OUR_CONTROL_DIR', str(control_dir))
    monkeypatch.setattr(sm, '_OUR_CONTROL_PATH', str(control_dir / '%C'))
    probe = {
        'controlmaster': 'no',
        'controlpath': 'none',
        'serveraliveinterval': '0',
    }

    opts, we_own = sm.options_to_add(probe)

    assert we_own is False
    # None of the multiplex-related options may be present...
    for forbidden in ('-oControlMaster=', '-oControlPath=', '-oControlPersist='):
        assert not any(o.startswith(forbidden) for o in opts)
    # ...but the keepalive options are still expected.
    for required in ('-oServerAliveInterval=', '-oServerAliveCountMax='):
        assert any(o.startswith(required) for o in opts)

def test_windows_cygwin_ssh_keeps_multiplex(self, monkeypatch, tmp_path):
    """A non-broken ssh with no user multiplex config gets our ControlMaster."""
    # Cygwin/Git-Bash ssh on Windows reports the standard banner and has
    # working ControlMaster — so multiplex IS added even on Windows.
    # (Equivalent to "non-buggy ssh" in detection terms.)
    control_dir = tmp_path / 'cm'
    patches = (
        ('multiplex_known_broken', lambda: False),
        ('_OUR_CONTROL_DIR', str(control_dir)),
        ('_OUR_CONTROL_PATH', str(control_dir / '%C')),
    )
    for attr_name, replacement in patches:
        monkeypatch.setattr(sm, attr_name, replacement)
    probe = dict(controlmaster='no', controlpath='none', serveraliveinterval='0')

    opts, we_own = sm.options_to_add(probe)

    assert we_own is True
    assert any(o.startswith('-oControlMaster=') for o in opts)

def test_windows_user_configured_multiplex_respected(self, monkeypatch):
    """A user's own multiplex + keepalive config is left untouched."""
    # Even when the active ssh is the buggy one, if the user has multiplex
    # explicitly configured (e.g. via ~/.ssh/config pointing at Cygwin ssh)
    # their config must be respected, not overridden.
    monkeypatch.setattr(sm, 'multiplex_known_broken', lambda: True)
    user_probe = {
        'controlmaster': 'auto',
        'controlpath': '~/.ssh/sockets/%C',
        'serveraliveinterval': '30',
        'serveralivecountmax': '5',
    }

    result = sm.options_to_add(user_probe)
    opts, we_own = result

    assert we_own is False
    assert opts == [], 'user has full config — we add nothing'


class TestMultiplexKnownBroken:
    """`ssh -V` banner parsing for known-buggy clients."""

    @pytest.fixture(autouse=True)
    def _clear_cache(self):
        # multiplex_known_broken() is memoized; reset it before and after
        # every test so results never leak between cases.
        sm.multiplex_known_broken.cache_clear()
        yield
        sm.multiplex_known_broken.cache_clear()

    def test_non_windows_never_broken(self, monkeypatch):
        # On Linux/macOS we don't even probe — multiplex always works.
        probe_spy = mock.Mock()
        monkeypatch.setattr(sm.System, 'windows', False)
        monkeypatch.setattr(sm, 'execute_piped', probe_spy)
        assert sm.multiplex_known_broken() is False
        probe_spy.assert_not_called()

    def test_microsoft_for_windows_banner_detected(self, monkeypatch):
        banner = 'OpenSSH_for_Windows_8.6p1, LibreSSL 3.4.3'
        monkeypatch.setattr(sm.System, 'windows', True)
        monkeypatch.setattr(sm, 'execute_piped', mock.Mock(return_value=banner))
        assert sm.multiplex_known_broken() is True

    def test_cygwin_banner_not_broken(self, monkeypatch):
        banner = 'OpenSSH_9.6p1, OpenSSL 3.0.13 30 Jan 2024'
        monkeypatch.setattr(sm.System, 'windows', True)
        monkeypatch.setattr(sm, 'execute_piped', mock.Mock(return_value=banner))
        assert sm.multiplex_known_broken() is False

    def test_result_is_cached(self, monkeypatch):
        probe_spy = mock.Mock(return_value='OpenSSH_for_Windows_8.6p1')
        monkeypatch.setattr(sm.System, 'windows', True)
        monkeypatch.setattr(sm, 'execute_piped', probe_spy)
        # Three calls, but ssh must only be probed the first time.
        for _ in range(3):
            sm.multiplex_known_broken()
        assert probe_spy.call_count == 1

    def test_ssh_missing_treated_as_broken_on_windows(self, monkeypatch):
        # Conservative default: if we can't even invoke ssh, don't risk
        # configuring multiplex on Windows.
        monkeypatch.setattr(sm.System, 'windows', True)
        # execute_piped(throw=False) returns None on failure; that is treated
        # as the conservative "skip mux" default.
        monkeypatch.setattr(sm, 'execute_piped', mock.Mock(return_value=None))
        assert sm.multiplex_known_broken() is True


class TestProbeSshConfig:
def test_parses_keys(self):
Expand Down Expand Up @@ -270,6 +361,46 @@ def worker():
assert probe_count[0] == 1


class TestWrapperPathSafety:
    """Regression: running mama_ssh.py as a script must not shadow stdlib
    modules. Earlier versions inserted `<...>/mama` onto sys.path, which made
    `mama/types/` shadow Python's stdlib `types` module — breaking `contextlib`
    on uv-installed Pythons that hadn't pre-imported it."""

    def test_invocation_does_not_put_mama_dir_on_syspath(self, tmp_path):
        """Exec the wrapper in a fresh interpreter and inspect its sys.path."""
        import json
        import subprocess
        import textwrap

        def _norm(entry):
            # Compare paths modulo case (on case-insensitive platforms) and
            # redundant/trailing separators, so a differently spelled entry
            # for the same directory is still caught.
            return os.path.normcase(os.path.normpath(entry))

        wrapper = os.path.abspath(os.path.join(
            os.path.dirname(__file__), '..', '..', 'mama', 'utils', 'mama_ssh.py'))
        mama_dir = os.path.dirname(os.path.dirname(wrapper))
        # Subprocess so we get a fresh interpreter (no pre-cached `types` etc).
        # Monkey-patch os.execvp to a no-op BEFORE running the wrapper, so it
        # can't replace the process before we read sys.path back.
        probe = tmp_path / 'probe.py'
        # Explicit UTF-8 on both the probe file and the wrapper read: Python
        # parses source files as UTF-8, and the locale default encoding
        # (notably on Windows) may differ.
        probe.write_text(textwrap.dedent(f"""
            import json, os, sys
            os.execvp = lambda *a, **k: None
            sys.argv = [{wrapper!r}, 'git@example.com:foo.git', 'git-upload-pack']
            ns = {{'__name__': '__main__', '__package__': '', '__file__': {wrapper!r}}}
            with open({wrapper!r}, encoding='utf-8') as f:
                code = f.read()
            try:
                exec(code, ns)
            except SystemExit:
                pass
            print('PATH_PROBE:' + json.dumps(sys.path))
        """), encoding='utf-8')
        cp = subprocess.run([sys.executable, str(probe)],
                            capture_output=True, text=True, timeout=15)
        marker = [line for line in cp.stdout.splitlines()
                  if line.startswith('PATH_PROBE:')]
        assert marker, f'probe did not produce output. stderr={cp.stderr!r}'
        path = json.loads(marker[-1][len('PATH_PROBE:'):])
        normalized_path = {_norm(entry) for entry in path}
        assert _norm(mama_dir) not in normalized_path, (
            f'{mama_dir!r} ended up on sys.path — `mama/types/` would shadow '
            f'stdlib `types`. sys.path={path!r}')


class TestWrapperMain:
"""The wrapper passes options + destination unchanged to ssh -G, then
exec's ssh with whatever extra -o flags are needed."""
Expand Down