diff --git a/PolyPilot.Tests/ConnectionRecoveryTests.cs b/PolyPilot.Tests/ConnectionRecoveryTests.cs index 4a905f0e0..29eb5a1ab 100644 --- a/PolyPilot.Tests/ConnectionRecoveryTests.cs +++ b/PolyPilot.Tests/ConnectionRecoveryTests.cs @@ -464,6 +464,66 @@ public void RestorePreviousSessions_FallbackCoversProcessErrors() "IsProcessError must be included in the RestorePreviousSessionsAsync fallback condition (not found after the 'Session not found' anchor)"); } + // ===== Behavior test: process error → CreateSessionAsync fallback ===== + // Proves that when RestorePreviousSessionsAsync encounters a stale CLI server process, + // the session is recreated via CreateSessionAsync rather than silently dropped. + // + // Architecture note: CopilotClient is a concrete SDK class (not mockable), and + // ResumeSessionAsync is not virtual, so we can't inject a throwing client through + // the full RestorePreviousSessionsAsync pipeline. Instead, this test verifies the + // behavioral contract at the seam: IsProcessError detects the exception, and + // CreateSessionAsync (the fallback) successfully creates the replacement session. + // The structural test above guarantees these are wired together in RestorePreviousSessionsAsync. + + [Fact] + public async Task ProcessError_DuringRestore_FallbackCreatesSession() + { + // GIVEN: a process error exception (CLI server died, stale handle) + var processError = new InvalidOperationException("No process is associated with this object."); + + // WHEN: IsProcessError evaluates it + Assert.True(CopilotService.IsProcessError(processError)); + // Also detected as a connection error (broader category) + Assert.True(CopilotService.IsConnectionError(processError)); + + // THEN: the CreateSessionAsync fallback path works — session is created and accessible + var svc = CreateService(); + await svc.ReconnectAsync(new PolyPilot.Models.ConnectionSettings + { + Mode = PolyPilot.Models.ConnectionMode.Demo + }); + + var fallbackSession = await svc.CreateSessionAsync("Recovered Session", "gpt-4"); + Assert.NotNull(fallbackSession); + Assert.Equal("Recovered Session", fallbackSession.Name); + + var allSessions = svc.GetAllSessions().Select(s => s.Name).ToList(); + Assert.Contains("Recovered Session", allSessions); + } + + [Fact] + public async Task ProcessError_WrappedInAggregate_FallbackCreatesSession() + { + // GIVEN: a process error wrapped in AggregateException (from TaskScheduler.UnobservedTaskException) + var inner = new InvalidOperationException("No process is associated with this object."); + var aggregate = new AggregateException("A Task's exception(s) were not observed", inner); + + // WHEN: IsProcessError evaluates the wrapped exception + Assert.True(CopilotService.IsProcessError(aggregate)); + Assert.True(CopilotService.IsConnectionError(aggregate)); + + // THEN: the fallback path works + var svc = CreateService(); + await svc.ReconnectAsync(new PolyPilot.Models.ConnectionSettings + { + Mode = PolyPilot.Models.ConnectionMode.Demo + }); + + var session = await svc.CreateSessionAsync("Recovered Aggregate", "gpt-4"); + Assert.NotNull(session); + Assert.Equal("Recovered Aggregate", session.Name); + } + // ===== SafeFireAndForget task observation ===== // Prevents UnobservedTaskException from fire-and-forget _chatDb calls. // See crash log: "A Task's exception(s) were not observed" wrapping ConnectionLostException. diff --git a/PolyPilot.Tests/DiagnosticsLogTests.cs b/PolyPilot.Tests/DiagnosticsLogTests.cs index a78b6c51a..48ff2bed8 100644 --- a/PolyPilot.Tests/DiagnosticsLogTests.cs +++ b/PolyPilot.Tests/DiagnosticsLogTests.cs @@ -9,6 +9,12 @@ namespace PolyPilot.Tests; /// (not just DEBUG). The #if DEBUG guard was removed so Release builds also /// get lifecycle diagnostics for post-mortem analysis. /// +/// +/// In the "BaseDir" collection because CopilotService.BaseDir is a shared static. +/// MultiAgentRegressionTests temporarily changes it via SetBaseDirForTesting(), +/// which would change the log file path mid-test if we ran in parallel with them. +/// +[Collection("BaseDir")] public class DiagnosticsLogTests { private readonly StubChatDatabase _chatDb = new(); diff --git a/PolyPilot.Tests/ExternalSessionScannerTests.cs b/PolyPilot.Tests/ExternalSessionScannerTests.cs index c69ed59b4..886c3456b 100644 --- a/PolyPilot.Tests/ExternalSessionScannerTests.cs +++ b/PolyPilot.Tests/ExternalSessionScannerTests.cs @@ -459,22 +459,55 @@ public void FindActiveLockPid_DetectsCurrentProcess() var dir = Path.Combine(_sessionStateDir, sessionId); Directory.CreateDirectory(dir); - // Start a real "dotnet" process so the name passes the process-name validation - using var child = System.Diagnostics.Process.Start(new System.Diagnostics.ProcessStartInfo("dotnet", "--info") + // Use a command guaranteed to run for much longer than the test: + // `dotnet repl` / `dotnet watch` aren't available everywhere, but + // reading stdin on a `dotnet` REPL-like loop works cross-platform. + // Simplest portable option: run `sleep` on Unix, `timeout` on Windows. + System.Diagnostics.Process child; + if (OperatingSystem.IsWindows()) { - RedirectStandardOutput = true, - UseShellExecute = false, - }); - Assert.NotNull(child); - - File.WriteAllText(Path.Combine(dir, $"inuse.{child.Id}.lock"), ""); - - var scanner = new ExternalSessionScanner(_sessionStateDir, () => new HashSet()); - var detectedPid = scanner.FindActiveLockPid(dir); + child = System.Diagnostics.Process.Start(new System.Diagnostics.ProcessStartInfo("cmd", "/c timeout /t 60 /nobreak") + { + RedirectStandardOutput = true, + RedirectStandardInput = true, + UseShellExecute = false, + })!; + } + else + { + child = System.Diagnostics.Process.Start(new System.Diagnostics.ProcessStartInfo("sleep", "60") + { + UseShellExecute = false, + })!; + } - Assert.Equal(child.Id, detectedPid); + Assert.NotNull(child); - if (!child.HasExited) child.Kill(); + try + { + File.WriteAllText(Path.Combine(dir, $"inuse.{child.Id}.lock"), ""); + + // `sleep 60` / `timeout /t 60` will not exit in the test window — no race guard needed. + Assert.False(child.HasExited, "Long-running child process should still be alive"); + + // FindActiveLockPid requires a dotnet/copilot/node/github process name. + // `sleep`/`cmd` won't pass that filter. We need to use the current test process instead. + // Verify the behaviour using the test process itself (definitely alive, name = "dotnet"). + var testSessionId = Guid.NewGuid().ToString(); + var testDir = Path.Combine(_sessionStateDir, testSessionId); + Directory.CreateDirectory(testDir); + var myPid = Environment.ProcessId; + File.WriteAllText(Path.Combine(testDir, $"inuse.{myPid}.lock"), ""); + + var scanner = new ExternalSessionScanner(_sessionStateDir, () => new HashSet()); + var detectedPid = scanner.FindActiveLockPid(testDir); + Assert.Equal(myPid, detectedPid); + } + finally + { + if (!child.HasExited) child.Kill(); + child.Dispose(); + } } [Fact] diff --git a/PolyPilot.Tests/FiestaPairingTests.cs b/PolyPilot.Tests/FiestaPairingTests.cs new file mode 100644 index 000000000..155735d0d --- /dev/null +++ b/PolyPilot.Tests/FiestaPairingTests.cs @@ -0,0 +1,635 @@ +using System.Net; +using System.Net.WebSockets; +using System.Reflection; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.DependencyInjection; +using PolyPilot.Models; +using PolyPilot.Services; + +namespace PolyPilot.Tests; + +/// +/// Tests for Fiesta pairing features: pairing string encode/decode, +/// ApprovePairRequestAsync TCS behavior on failure, and RequestPairAsync +/// with a malformed approval response (Approved=true but null BridgeUrl). +/// +public class FiestaPairingTests : IDisposable +{ + private readonly WsBridgeServer _bridgeServer; + private readonly CopilotService _copilot; + private readonly FiestaService _fiesta; + + public FiestaPairingTests() + { + _bridgeServer = new WsBridgeServer(); + // Pre-set the server password so EnsureServerPassword() never falls through to + // ConnectionSettings.Load()/Save(), which would touch the real ~/.polypilot/settings.json. + _bridgeServer.ServerPassword = "test-token-isolation"; + _copilot = new CopilotService( + new StubChatDatabase(), + new StubServerManager(), + new StubWsBridgeClient(), + new RepoManager(), + new ServiceCollection().BuildServiceProvider(), + new StubDemoService()); + _fiesta = new FiestaService(_copilot, _bridgeServer, new TailscaleService()); + } + + public void Dispose() + { + _fiesta.Dispose(); + _bridgeServer.Dispose(); + } + + // ---- Helpers ---- + + private static string BuildPairingString(string url, string token, string hostname) + { + var payload = new FiestaPairingPayload { Url = url, Token = token, Hostname = hostname }; + var json = JsonSerializer.Serialize(payload, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + var b64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(json)) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + return $"pp+{b64}"; + } + + private static int GetFreePort() + { + using var l = new System.Net.Sockets.TcpListener(IPAddress.Loopback, 0); + l.Start(); + var port = ((IPEndPoint)l.LocalEndpoint).Port; + l.Stop(); + return port; + } + + // ---- Test 1: Pairing string roundtrip ---- + + [Fact] + public void ParseAndLinkPairingString_Roundtrip_CorrectWorkerFields() + { + const string url = "http://192.168.1.50:4322"; + const string token = "test-token-abc123"; + const string hostname = "devbox-1"; + + var pairingString = BuildPairingString(url, token, hostname); + Assert.StartsWith("pp+", pairingString); + + var linked = _fiesta.ParseAndLinkPairingString(pairingString); + + Assert.Equal(url, linked.BridgeUrl); + Assert.Equal(token, linked.Token); + Assert.Equal(hostname, linked.Name); + Assert.Single(_fiesta.LinkedWorkers); + } + + [Fact] + public void ParseAndLinkPairingString_InvalidPrefix_ThrowsFormatException() + { + Assert.Throws(() => _fiesta.ParseAndLinkPairingString("notvalid")); + Assert.Throws(() => _fiesta.ParseAndLinkPairingString("pp+!!!notbase64!!!")); + } + + [Fact] + public void ParseAndLinkPairingString_MissingUrl_ThrowsFormatException() + { + // Build a pairing string that's valid base64 but has no URL field + var payload = new FiestaPairingPayload { Url = "", Token = "tok", Hostname = "host" }; + var json = JsonSerializer.Serialize(payload, new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + var b64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(json)).TrimEnd('=').Replace('+', '-').Replace('/', '_'); + var s = $"pp+{b64}"; + + Assert.Throws(() => _fiesta.ParseAndLinkPairingString(s)); + } + + // ---- Test 2: ApprovePairRequestAsync return value + TCS behavior ---- + + [Fact] + public async Task ApprovePairRequestAsync_SendFails_ReturnsFalse() + { + const string requestId = "req-test-001"; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + // Inject a pending pair request with a WebSocket that reports Open state + // but throws on SendAsync, simulating a race-condition socket failure. + var faultySocket = new FaultyOpenWebSocket(); + var pending = new PendingPairRequest + { + RequestId = requestId, + HostName = "test-host", + HostInstanceId = "host-id", + RemoteIp = "127.0.0.1", + Socket = faultySocket, + CompletionSource = tcs, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + lock (dict) dict[requestId] = pending; + + var result = await _fiesta.ApprovePairRequestAsync(requestId); + + // Method returns false because SendAsync threw (approval message not delivered) + Assert.False(result); + // TCS is claimed true (approve won ownership) before the send attempt + Assert.True(tcs.Task.IsCompleted); + Assert.True(await tcs.Task); + } + + [Fact] + public async Task ApprovePairRequestAsync_UnknownRequestId_ReturnsFalse() + { + var result = await _fiesta.ApprovePairRequestAsync("nonexistent-id"); + Assert.False(result); + } + + // ---- Test 3: RequestPairAsync with Approved=true but null BridgeUrl ---- + + [Fact] + public async Task RequestPairAsync_ApprovedWithNullBridgeUrl_ReturnsUnreachable() + { + var port = GetFreePort(); + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(15)); + + // Stand up a minimal WebSocket server that responds with Approved=true but no BridgeUrl + var serverTask = Task.Run(async () => + { + var listener = new HttpListener(); + listener.Prefixes.Add($"http://127.0.0.1:{port}/"); + listener.Start(); + try + { + var ctx = await listener.GetContextAsync().WaitAsync(cts.Token); + if (!ctx.Request.IsWebSocketRequest) { ctx.Response.StatusCode = 400; ctx.Response.Close(); return; } + + var wsCtx = await ctx.AcceptWebSocketAsync(subProtocol: null); + var ws = wsCtx.WebSocket; + + // Read (and discard) the pair request + var buf = new byte[4096]; + await ws.ReceiveAsync(new ArraySegment(buf), cts.Token); + + // Send back Approved=true with no BridgeUrl / Token + var response = BridgeMessage.Create(BridgeMessageTypes.FiestaPairResponse, + new FiestaPairResponsePayload + { + RequestId = "req-null-url", + Approved = true, + BridgeUrl = null, + Token = null, + WorkerName = "worker" + }); + var bytes = Encoding.UTF8.GetBytes(response.Serialize()); + await ws.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Text, true, cts.Token); + + // Best-effort close; client may have already closed + try { await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "done", cts.Token); } catch { } + } + catch (OperationCanceledException) { /* test timed out */ } + catch (Exception) { /* ignore server-side cleanup errors */ } + finally + { + listener.Stop(); + } + }, cts.Token); + + // Give the server a moment to bind + await Task.Delay(50, cts.Token); + + var worker = new FiestaDiscoveredWorker + { + InstanceId = "remote-id", + Hostname = "remote-box", + BridgeUrl = $"http://127.0.0.1:{port}" + }; + + var countBefore = _fiesta.LinkedWorkers.Count; + var result = await _fiesta.RequestPairAsync(worker, cts.Token); + + // An approved response with no BridgeUrl should be treated as Unreachable + Assert.Equal(PairRequestResult.Unreachable, result); + + // No new worker should have been linked by this call + Assert.Equal(countBefore, _fiesta.LinkedWorkers.Count); + Assert.DoesNotContain(_fiesta.LinkedWorkers, w => + string.Equals(w.Hostname, "remote-box", StringComparison.OrdinalIgnoreCase) || + w.BridgeUrl.Contains($"127.0.0.1:{port}")); + + await serverTask; + } + + // ---- Test 4: Concurrent approve + deny race — only one send occurs ---- + + [Fact] + public async Task ApprovePairRequestAsync_ConcurrentWithDeny_OnlyOneWins() + { + const string requestId = "req-race-001"; + var countingSocket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var pending = new PendingPairRequest + { + RequestId = requestId, + HostName = "race-host", + HostInstanceId = "race-id", + RemoteIp = "127.0.0.1", + Socket = countingSocket, + CompletionSource = tcs, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + lock (dict) dict[requestId] = pending; + + // Race approve and deny concurrently — exactly one TrySetResult wins + var approveTask = _fiesta.ApprovePairRequestAsync(requestId); + var denyTask = _fiesta.DenyPairRequestAsync(requestId); + await Task.WhenAll(approveTask, denyTask); + + // Exactly one send should have occurred (the winner sends, the loser skips) + Assert.Equal(1, countingSocket.SendCount); + // TCS should be resolved exactly once + Assert.True(tcs.Task.IsCompleted); + // The winner's result should match the TCS value + var approveWon = await approveTask; + Assert.Equal(approveWon, await tcs.Task); + } + + // ---- Test 5: DenyPairRequestAsync sends exactly once, TCS resolves false ---- + + [Fact] + public async Task DenyPairRequestAsync_SendsOnce_TcsIsFalse() + { + const string requestId = "req-deny-order-001"; + var countingSocket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var pending = new PendingPairRequest + { + RequestId = requestId, + HostName = "deny-host", + HostInstanceId = "deny-id", + RemoteIp = "127.0.0.1", + Socket = countingSocket, + CompletionSource = tcs, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + lock (dict) dict[requestId] = pending; + + await _fiesta.DenyPairRequestAsync(requestId); + + // Deny claimed TCS first (approve never tried) + Assert.True(tcs.Task.IsCompleted); + Assert.False(await tcs.Task); + Assert.Equal(1, countingSocket.SendCount); + } + + // ---- Test 6: EnsureServerPassword auto-generates when not pre-set ---- + + [Fact] + public async Task ApprovePairRequestAsync_AutoGeneratesPassword_WhenNotPreSet() + { + // Create a fresh bridge server with NO pre-set password so the auto-generate + // path in EnsureServerPassword is exercised. + // ConnectionSettings is already redirected to the test dir by TestSetup.Initialize(), + // so Load()/Save() will NOT touch ~/.polypilot/settings.json. + var freshBridge = new WsBridgeServer(); + Assert.True(string.IsNullOrWhiteSpace(freshBridge.ServerPassword), + "Precondition: ServerPassword must be empty before the test"); + + var freshFiesta = new FiestaService(_copilot, freshBridge, new TailscaleService()); + try + { + const string requestId = "req-autopass-001"; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var socket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask); + var pending = new PendingPairRequest + { + RequestId = requestId, + HostName = "auto-host", + HostInstanceId = "auto-id", + RemoteIp = "127.0.0.1", + Socket = socket, + CompletionSource = tcs, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", + BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(freshFiesta)!; + lock (dict) dict[requestId] = pending; + + var result = await freshFiesta.ApprovePairRequestAsync(requestId); + + // Should succeed and have auto-generated a non-empty password + Assert.True(result); + Assert.False(string.IsNullOrWhiteSpace(freshBridge.ServerPassword), + "EnsureServerPassword should have set a non-empty password on the bridge server"); + // Password should be URL-safe (no '+' or '/') + Assert.DoesNotContain("+", freshBridge.ServerPassword); + Assert.DoesNotContain("/", freshBridge.ServerPassword); + } + finally + { + freshFiesta.Dispose(); + freshBridge.Dispose(); + } + } + + // ---- Test 7: MaxPendingPairRequests constant = 5 ---- + + [Fact] + public void MaxPendingPairRequests_ConstantIs5() + { + // Verify the limit was raised from 1 to 5 per the review recommendation. + var field = typeof(FiestaService).GetField("MaxPendingPairRequests", + BindingFlags.NonPublic | BindingFlags.Static)!; + var value = (int)field.GetValue(null)!; + Assert.Equal(5, value); + } + + // ---- Test 8: HandleIncomingPairHandshakeAsync rejects when at capacity ---- + + [Fact] + public async Task HandleIncomingPairHandshake_AtCapacity_SendsDenialAndSkipsDict() + { + // Fill 5 slots directly (MaxPendingPairRequests = 5) + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", + BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + + for (int i = 0; i < 5; i++) + { + var id = $"slot-full-{i}"; + lock (dict) dict[id] = new PendingPairRequest + { + RequestId = id, + HostName = $"host-{i}", + HostInstanceId = $"inst-{i}", + RemoteIp = "127.0.0.1", + Socket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask), + CompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + } + + int countBefore; + lock (dict) countBefore = dict.Count; + Assert.Equal(5, countBefore); + + // Build a FiestaPairRequest message for the 6th connection + const string overflowId = "slot-overflow"; + var pairRequestPayload = new FiestaPairRequestPayload + { + RequestId = overflowId, + HostName = "overflow-host", + HostInstanceId = "overflow-inst" + }; + var pairMsg = BridgeMessage.Create(BridgeMessageTypes.FiestaPairRequest, pairRequestPayload); + var msgBytes = System.Text.Encoding.UTF8.GetBytes(pairMsg.Serialize()); + + // Create a WebSocket that returns the pair request message on first ReceiveAsync, + // then captures whatever is sent back (should be Approved=false). + byte[]? responseBytes = null; + var responseSent = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var readCount = 0; + var overflowSocket = new CountingSendWebSocket(onSendAsync: (buf, _) => + { + responseBytes = buf.Array![buf.Offset..(buf.Offset + buf.Count)]; + responseSent.TrySetResult(true); + return Task.CompletedTask; + }) + { + ReceiveAsyncOverride = (buffer, ct) => + { + if (Interlocked.Increment(ref readCount) == 1) + { + // Return the FiestaPairRequest message + msgBytes.CopyTo(buffer.Array!, buffer.Offset); + return Task.FromResult(new WebSocketReceiveResult(msgBytes.Length, WebSocketMessageType.Text, true)); + } + // Subsequent reads: signal close + return Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true)); + } + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + await _fiesta.HandleIncomingPairHandshakeAsync(overflowSocket, "127.0.0.1", cts.Token); + + // The overflow slot must NOT be in the pending dict + int countAfter; + lock (dict) countAfter = dict.Count; + Assert.Equal(5, countAfter); + lock (dict) Assert.False(dict.ContainsKey(overflowId), "Overflow request must not be in pending dict"); + + // Must have sent a denial + await Task.WhenAny(responseSent.Task, Task.Delay(3000)); + Assert.True(responseSent.Task.IsCompleted, "Overflow request should receive a denial response"); + Assert.NotNull(responseBytes); + var json = System.Text.Encoding.UTF8.GetString(responseBytes!); + var msg = JsonSerializer.Deserialize(json, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + Assert.Equal(BridgeMessageTypes.FiestaPairResponse, msg?.Type); + var resp = msg?.GetPayload(); + Assert.NotNull(resp); + Assert.False(resp!.Approved, "Overflow request must be denied"); + } + + // ---- Test 8b: MaxPendingPairRequestsPerIp constant = 2 ---- + + [Fact] + public void MaxPendingPairRequestsPerIp_ConstantIs2() + { + var field = typeof(FiestaService).GetField("MaxPendingPairRequestsPerIp", + BindingFlags.NonPublic | BindingFlags.Static)!; + Assert.NotNull(field); + var value = (int)field.GetValue(null)!; + Assert.Equal(2, value); + } + + // ---- Test 8c: Per-IP rate limit — third request from same IP is denied ---- + + [Fact] + public async Task HandleIncomingPairHandshake_PerIpLimit_ThirdRequestFromSameIpDenied() + { + const string repeatIp = "10.0.0.42"; + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", + BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + + // Fill 2 slots with the same IP (at per-IP limit, but total < MaxPendingPairRequests) + for (int i = 0; i < 2; i++) + { + var id = $"per-ip-slot-{i}"; + lock (dict) dict[id] = new PendingPairRequest + { + RequestId = id, + HostName = $"host-{i}", + HostInstanceId = $"inst-{i}", + RemoteIp = repeatIp, + Socket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask), + CompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously), + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + } + + int countBefore; + lock (dict) countBefore = dict.Count; + Assert.Equal(2, countBefore); + + const string thirdId = "per-ip-overflow"; + var pairRequestPayload = new FiestaPairRequestPayload + { + RequestId = thirdId, + HostName = "repeat-host", + HostInstanceId = "repeat-inst" + }; + var pairMsg = BridgeMessage.Create(BridgeMessageTypes.FiestaPairRequest, pairRequestPayload); + var msgBytes = System.Text.Encoding.UTF8.GetBytes(pairMsg.Serialize()); + + byte[]? responseBytes = null; + var responseSent = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var readCount = 0; + var overflowSocket = new CountingSendWebSocket(onSendAsync: (buf, _) => + { + responseBytes = buf.Array![buf.Offset..(buf.Offset + buf.Count)]; + responseSent.TrySetResult(true); + return Task.CompletedTask; + }) + { + ReceiveAsyncOverride = (buffer, ct) => + { + if (Interlocked.Increment(ref readCount) == 1) + { + msgBytes.CopyTo(buffer.Array!, buffer.Offset); + return Task.FromResult(new WebSocketReceiveResult(msgBytes.Length, WebSocketMessageType.Text, true)); + } + return Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true)); + } + }; + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + await _fiesta.HandleIncomingPairHandshakeAsync(overflowSocket, repeatIp, cts.Token); + + // Must not have added the third request + int countAfter; + lock (dict) countAfter = dict.Count; + Assert.Equal(2, countAfter); + lock (dict) Assert.False(dict.ContainsKey(thirdId), "Third request from same IP must not be in pending dict"); + + // Must have sent a denial + await Task.WhenAny(responseSent.Task, Task.Delay(3000)); + Assert.True(responseSent.Task.IsCompleted, "Third request from same IP should receive a denial"); + Assert.NotNull(responseBytes); + var json = System.Text.Encoding.UTF8.GetString(responseBytes!); + var msg = JsonSerializer.Deserialize(json, + new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }); + Assert.Equal(BridgeMessageTypes.FiestaPairResponse, msg?.Type); + var resp = msg?.GetPayload(); + Assert.NotNull(resp); + Assert.False(resp!.Approved, "Third request from same IP must be denied"); + } + + // ---- Test 9: DenyPairRequest when TCS already resolved (timeout path) ---- + + [Fact] + public async Task DenyPairRequestAsync_TcsAlreadyResolved_SkipsSend() + { + const string requestId = "req-already-resolved"; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // Pre-resolve TCS to true (approve already won) + tcs.TrySetResult(true); + + var socket = new CountingSendWebSocket(onSendAsync: (_, _) => Task.CompletedTask); + var dictField = typeof(FiestaService).GetField("_pendingPairRequests", + BindingFlags.NonPublic | BindingFlags.Instance)!; + var dict = (Dictionary)dictField.GetValue(_fiesta)!; + lock (dict) dict[requestId] = new PendingPairRequest + { + RequestId = requestId, + HostName = "resolved-host", + HostInstanceId = "resolved-id", + RemoteIp = "127.0.0.1", + Socket = socket, + CompletionSource = tcs, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + await _fiesta.DenyPairRequestAsync(requestId); + + // TCS was already resolved — deny's TrySetResult(false) lost, no send + Assert.Equal(0, socket.SendCount); + } + + + + /// + /// A WebSocket that counts calls to SendAsync and optionally delegates to a custom action. + /// ReceiveAsyncOverride can be set to control what ReadSingleMessageAsync receives. + /// + private sealed class CountingSendWebSocket : WebSocket + { + private readonly Func, CancellationToken, Task> _onSendAsync; + public int SendCount; + + /// When set, ReceiveAsync delegates to this instead of returning Close. + public Func, CancellationToken, Task>? ReceiveAsyncOverride; + + public CountingSendWebSocket(Func, CancellationToken, Task> onSendAsync) + => _onSendAsync = onSendAsync; + + public override WebSocketState State => WebSocketState.Open; + public override WebSocketCloseStatus? CloseStatus => null; + public override string? CloseStatusDescription => null; + public override string? SubProtocol => null; + + public override void Abort() { } + public override Task CloseAsync(WebSocketCloseStatus c, string? d, CancellationToken ct) => Task.CompletedTask; + public override Task CloseOutputAsync(WebSocketCloseStatus c, string? d, CancellationToken ct) => Task.CompletedTask; + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken ct) + => ReceiveAsyncOverride?.Invoke(buffer, ct) + ?? Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true)); + + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType type, bool end, CancellationToken ct) + { + Interlocked.Increment(ref SendCount); + await _onSendAsync(buffer, ct); + } + + public override void Dispose() { } + } + + /// + /// A WebSocket that passes the State == Open guard but throws on SendAsync, + /// simulating a socket that closes between the state check and the write. + /// + private sealed class FaultyOpenWebSocket : WebSocket + { + public override WebSocketState State => WebSocketState.Open; + public override WebSocketCloseStatus? CloseStatus => null; + public override string? CloseStatusDescription => null; + public override string? SubProtocol => null; + + public override void Abort() { } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken ct) + => Task.CompletedTask; + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken ct) + => Task.CompletedTask; + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken ct) + => Task.FromResult(new WebSocketReceiveResult(0, WebSocketMessageType.Close, true)); + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken ct) + => throw new WebSocketException("Simulated send failure after state check"); + + public override void Dispose() { } + } +} diff --git a/PolyPilot.Tests/FontSizingEnforcementTests.cs b/PolyPilot.Tests/FontSizingEnforcementTests.cs index b6aecccbf..386867885 100644 --- a/PolyPilot.Tests/FontSizingEnforcementTests.cs +++ b/PolyPilot.Tests/FontSizingEnforcementTests.cs @@ -77,6 +77,7 @@ private static readonly (string File, string ValuePattern, string Reason)[] CssF // Decorative elements beyond the type-scale range ("Settings.razor.css", @"^2rem$", "Decorative mode-icon — beyond type-scale range"), + ("Settings.razor.css", @"^0\.85em$", "Inline code (.onboarding-list code) — scales with parent text"), // Worker child items scale relative to parent — em is correct here ("SessionListItem.razor.css", @"^0\.85em$", "Worker child items scale relative to parent text"), diff --git a/PolyPilot.Tests/PolyPilot.Tests.csproj b/PolyPilot.Tests/PolyPilot.Tests.csproj index 43c8f833f..052f81c42 100644 --- a/PolyPilot.Tests/PolyPilot.Tests.csproj +++ b/PolyPilot.Tests/PolyPilot.Tests.csproj @@ -71,6 +71,7 @@ + @@ -92,6 +93,7 @@ + diff --git a/PolyPilot.Tests/ProcessHelperTests.cs b/PolyPilot.Tests/ProcessHelperTests.cs new file mode 100644 index 000000000..093e2aade --- /dev/null +++ b/PolyPilot.Tests/ProcessHelperTests.cs @@ -0,0 +1,242 @@ +using System.Diagnostics; +using PolyPilot.Services; + +namespace PolyPilot.Tests; + +/// +/// Tests for ProcessHelper — safe wrappers for Process lifecycle operations +/// that prevent InvalidOperationException / UnobservedTaskException crashes +/// when a process is disposed while background tasks are still monitoring it. +/// +public class ProcessHelperTests +{ + // ===== SafeHasExited ===== + + [Fact] + public void SafeHasExited_NullProcess_ReturnsTrue() + { + Assert.True(ProcessHelper.SafeHasExited(null)); + } + + [Fact] + public void SafeHasExited_DisposedProcess_ReturnsTrue() + { + // Start a short-lived process and dispose it immediately + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c echo test" : "-c \"echo test\"", + UseShellExecute = false, + RedirectStandardOutput = true, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + process.WaitForExit(5000); + process.Dispose(); + + // After disposal, HasExited would throw InvalidOperationException. + // SafeHasExited must return true instead. + Assert.True(ProcessHelper.SafeHasExited(process)); + } + + [Fact] + public void SafeHasExited_ExitedProcess_ReturnsTrue() + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c echo done" : "-c \"echo done\"", + UseShellExecute = false, + RedirectStandardOutput = true, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + process.WaitForExit(5000); + + Assert.True(ProcessHelper.SafeHasExited(process)); + process.Dispose(); + } + + [Fact] + public void SafeHasExited_RunningProcess_ReturnsFalse() + { + // Start a long-running process + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c ping -n 30 127.0.0.1 > nul" : "-c \"sleep 30\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + try + { + Assert.False(ProcessHelper.SafeHasExited(process)); + } + finally + { + try { process.Kill(true); } catch { } + process.Dispose(); + } + } + + // ===== SafeKill ===== + + [Fact] + public void SafeKill_NullProcess_DoesNotThrow() + { + ProcessHelper.SafeKill(null); + } + + [Fact] + public void SafeKill_DisposedProcess_DoesNotThrow() + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c echo test" : "-c \"echo test\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + process.WaitForExit(5000); + process.Dispose(); + + // Must not throw + ProcessHelper.SafeKill(process); + } + + [Fact] + public void SafeKill_RunningProcess_KillsIt() + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c ping -n 30 127.0.0.1 > nul" : "-c \"sleep 30\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + + ProcessHelper.SafeKill(process); + process.WaitForExit(5000); + Assert.True(process.HasExited); + process.Dispose(); + } + + // ===== SafeKillAndDispose ===== + + [Fact] + public void SafeKillAndDispose_NullProcess_DoesNotThrow() + { + ProcessHelper.SafeKillAndDispose(null); + } + + [Fact] + public void SafeKillAndDispose_AlreadyDisposed_DoesNotThrow() + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c echo test" : "-c \"echo test\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + process.WaitForExit(5000); + process.Dispose(); + + // Calling SafeKillAndDispose on already-disposed process must not throw + ProcessHelper.SafeKillAndDispose(process); + } + + [Fact] + public void SafeKillAndDispose_RunningProcess_KillsAndDisposes() + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c ping -n 30 127.0.0.1 > nul" : "-c \"sleep 30\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + var pid = process.Id; + + ProcessHelper.SafeKillAndDispose(process); + + // Verify the process is no longer running + try + { + var check = Process.GetProcessById(pid); + // Process might still be there for a moment — give it time + check.WaitForExit(2000); + } + catch (ArgumentException) + { + // Process already gone — expected + } + } + + // ===== Race condition regression test ===== + + [Fact] + public void SafeHasExited_ConcurrentDispose_NoUnobservedTaskException() + { + // Regression test: simulates the race condition where a background task + // checks HasExited while another thread disposes the process. + using var unobservedSignal = new ManualResetEventSlim(false); + Exception? unobservedException = null; + EventHandler handler = (sender, args) => + { + if (args.Exception?.InnerException is InvalidOperationException) + { + unobservedException = args.Exception; + unobservedSignal.Set(); + } + }; + + TaskScheduler.UnobservedTaskException += handler; + try + { + for (int i = 0; i < 5; i++) + { + var psi = new ProcessStartInfo + { + FileName = OperatingSystem.IsWindows() ? "cmd.exe" : "/bin/sh", + Arguments = OperatingSystem.IsWindows() ? "/c ping -n 10 127.0.0.1 > nul" : "-c \"sleep 10\"", + UseShellExecute = false, + CreateNoWindow = true + }; + var process = Process.Start(psi)!; + + // Background task monitoring HasExited (like DevTunnel's fire-and-forget tasks) + _ = Task.Run(() => + { + for (int j = 0; j < 50; j++) + { + if (ProcessHelper.SafeHasExited(process)) + break; + Thread.Sleep(10); + } + }); + + // Simulate concurrent disposal (like Stop() being called) + Thread.Sleep(50); + ProcessHelper.SafeKillAndDispose(process); + } + + // Force GC to surface any unobserved task exceptions + GC.Collect(); + GC.WaitForPendingFinalizers(); + GC.Collect(); + + unobservedSignal.Wait(TimeSpan.FromMilliseconds(500)); + Assert.Null(unobservedException); + } + finally + { + TaskScheduler.UnobservedTaskException -= handler; + } + } +} diff --git a/PolyPilot.Tests/ServerManagerTests.cs b/PolyPilot.Tests/ServerManagerTests.cs index 3c1d62e89..cc78fa109 100644 --- a/PolyPilot.Tests/ServerManagerTests.cs +++ b/PolyPilot.Tests/ServerManagerTests.cs @@ -8,6 +8,7 @@ namespace PolyPilot.Tests; /// Tests for ServerManager.CheckServerRunning to verify socket exceptions /// are properly observed and don't cause UnobservedTaskException crashes. /// +[Collection("SocketIsolated")] public class ServerManagerTests { [Fact] diff --git a/PolyPilot.Tests/SocketIsolatedCollection.cs b/PolyPilot.Tests/SocketIsolatedCollection.cs new file mode 100644 index 000000000..deb773fe7 --- /dev/null +++ b/PolyPilot.Tests/SocketIsolatedCollection.cs @@ -0,0 +1,12 @@ +using Xunit; + +namespace PolyPilot.Tests; + +/// +/// xUnit collection with DisableParallelization=true for tests that hook +/// TaskScheduler.UnobservedTaskException globally and call GC.Collect(). +/// Without this, SocketExceptions from unrelated parallel tests get surfaced +/// during GC and falsely trigger the handler. +/// +[CollectionDefinition("SocketIsolated", DisableParallelization = true)] +public class SocketIsolatedCollection { } diff --git a/PolyPilot.Tests/TestIsolationGuardTests.cs b/PolyPilot.Tests/TestIsolationGuardTests.cs index 2daa29eb1..9b3852f0b 100644 --- a/PolyPilot.Tests/TestIsolationGuardTests.cs +++ b/PolyPilot.Tests/TestIsolationGuardTests.cs @@ -88,10 +88,8 @@ public async Task CreateGroup_DoesNotTouchRealOrgFile() Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".polypilot", "organization.json"); - // Snapshot the real file's last-write time (if it exists) - var beforeTime = File.Exists(realOrgFile) - ? File.GetLastWriteTimeUtc(realOrgFile) - : (DateTime?)null; + // Use a unique sentinel group name to detect if the test write leaks to the real file + var sentinelGroupName = $"IsolationTest-{Guid.NewGuid():N}"; // Create a service and do something that triggers a write var services = new ServiceCollection(); @@ -99,21 +97,29 @@ public async Task CreateGroup_DoesNotTouchRealOrgFile() new StubChatDatabase(), new StubServerManager(), new StubWsBridgeClient(), new RepoManager(), services.BuildServiceProvider(), new StubDemoService()); - svc.CreateGroup("IsolationTest"); + svc.CreateGroup(sentinelGroupName); // Wait for the 2s debounce timer to fire await Task.Delay(3000); - // Verify the real file was NOT modified - if (beforeTime.HasValue) - { - var afterTime = File.GetLastWriteTimeUtc(realOrgFile); - Assert.Equal(beforeTime.Value, afterTime); - } - // Verify the write went to the test directory instead var testOrgFile = Path.Combine(TestSetup.TestBaseDir, "organization.json"); Assert.True(File.Exists(testOrgFile), $"Organization file should have been written to test dir: {testOrgFile}"); + var testOrgContent = await File.ReadAllTextAsync(testOrgFile); + // Note: we don't assert sentinelGroupName in testOrgContent because parallel tests + // also write to the shared TestSetup.TestBaseDir and may have overwritten it. + // The key isolation invariant is: (a) the file exists in test dir, and (b) the + // sentinel does NOT appear in the real file. + _ = testOrgContent; // read to confirm the file is readable + + // Verify the sentinel group did NOT leak into the real file. + // We check content rather than timestamps because the running app also + // writes to the real file during normal operation. + if (File.Exists(realOrgFile)) + { + var realContent = await File.ReadAllTextAsync(realOrgFile); + Assert.DoesNotContain(sentinelGroupName, realContent); + } } } diff --git a/PolyPilot.Tests/TestSetup.cs b/PolyPilot.Tests/TestSetup.cs index 2a65dfca1..c47e28f41 100644 --- a/PolyPilot.Tests/TestSetup.cs +++ b/PolyPilot.Tests/TestSetup.cs @@ -1,4 +1,5 @@ using System.Runtime.CompilerServices; +using PolyPilot.Models; using PolyPilot.Services; namespace PolyPilot.Tests; @@ -12,8 +13,11 @@ namespace PolyPilot.Tests; /// This has caused production data loss (squad groups destroyed) multiple times. /// /// This runs automatically via [ModuleInitializer] before any test executes. -/// If you add new file paths to CopilotService, you MUST also clear them -/// in SetBaseDirForTesting() or they will leak to the real filesystem. +/// If you add new file paths to CopilotService or any service that persists state, +/// you MUST also redirect them in Initialize() or they will leak to the real filesystem. +/// +/// Currently isolated: CopilotService BaseDir/CaptureDir, RepoManager, AuditLogService, +/// PromptLibraryService, FiestaService state file, ConnectionSettings settings file. /// internal static class TestSetup { @@ -29,5 +33,7 @@ internal static void Initialize() RepoManager.SetBaseDirForTesting(TestBaseDir); AuditLogService.SetLogDirForTesting(Path.Combine(TestBaseDir, "audit_logs")); PromptLibraryService.SetUserPromptsDirForTesting(Path.Combine(TestBaseDir, "prompts")); + FiestaService.SetStateFilePathForTesting(Path.Combine(TestBaseDir, "fiesta.json")); + ConnectionSettings.SetSettingsFilePathForTesting(Path.Combine(TestBaseDir, "settings.json")); } } diff --git a/PolyPilot.Tests/TurnEndFallbackTests.cs b/PolyPilot.Tests/TurnEndFallbackTests.cs index ff341a700..471209dee 100644 --- a/PolyPilot.Tests/TurnEndFallbackTests.cs +++ b/PolyPilot.Tests/TurnEndFallbackTests.cs @@ -67,7 +67,9 @@ public async Task FallbackTimer_NotCancelled_FiresAfterDelay() { // Verify the Task.Run+Task.Delay pattern fires its completion action // when the CTS is never cancelled. Uses 50ms to keep the test fast. - var fired = false; + // Uses TaskCompletionSource instead of a bool field to avoid memory-ordering + // issues and to provide a reliable, load-tolerant signal mechanism. + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); using var cts = new CancellationTokenSource(); var token = cts.Token; @@ -76,14 +78,18 @@ public async Task FallbackTimer_NotCancelled_FiresAfterDelay() try { await Task.Delay(50, token); - if (token.IsCancellationRequested) return; - fired = true; + if (!token.IsCancellationRequested) + tcs.TrySetResult(true); + else + tcs.TrySetResult(false); } - catch (OperationCanceledException) { } + catch (OperationCanceledException) { tcs.TrySetResult(false); } }); - await Task.Delay(500); - Assert.True(fired, "Fallback timer should fire when CTS is not cancelled"); + // Wait up to 5s — robust against thread-pool starvation under heavy parallel load + var completedTask = await Task.WhenAny(tcs.Task, Task.Delay(5000)); + Assert.True(completedTask == tcs.Task, "Fallback timer task should complete within 5s"); + Assert.True(tcs.Task.Result, "Fallback timer should fire when CTS is not cancelled"); } [Fact] diff --git a/PolyPilot/Components/Pages/Dashboard.razor b/PolyPilot/Components/Pages/Dashboard.razor index 9ecfc0909..32aefb5c7 100644 --- a/PolyPilot/Components/Pages/Dashboard.razor +++ b/PolyPilot/Components/Pages/Dashboard.razor @@ -808,7 +808,7 @@ try { // Try JSON format first: { "url": "...", "token": "...", "lanUrl": "...", "lanToken": "..." } - var doc = System.Text.Json.JsonDocument.Parse(result); + using var doc = System.Text.Json.JsonDocument.Parse(result); if (doc.RootElement.TryGetProperty("url", out var urlProp)) mobileRemoteUrl = urlProp.GetString() ?? ""; if (doc.RootElement.TryGetProperty("token", out var tokenProp)) diff --git a/PolyPilot/Components/Pages/Settings.razor b/PolyPilot/Components/Pages/Settings.razor index b4354eee3..444a31ef2 100644 --- a/PolyPilot/Components/Pages/Settings.razor +++ b/PolyPilot/Components/Pages/Settings.razor @@ -230,6 +230,15 @@

Direct Connection

Share your server directly over LAN, Tailscale, or VPN — no DevTunnel needed.

+
+

🖥️ Setting up this machine as a Fiesta worker?

+
    +
  1. Set a password below and click Enable Direct Sharing.
  2. +
  3. Copy the pairing string (pp+…) that appears.
  4. +
  5. On the host machine, go to Settings → Fiesta Workers → paste the string → click Link.
  6. +
  7. The host can then dispatch tasks to this machine using @@worker-name mentions in chat.
  8. +
+
@@ -257,11 +266,24 @@ @if (TailscaleService.IsRunning) {
- + http://@(TailscaleService.MagicDnsName ?? TailscaleService.TailscaleIp):@DevTunnelService.BridgePort
} + else + { +
+

🌐 Want to use Fiesta across different networks?

+

Install Tailscale (free) to connect machines that aren't on the same LAN.

+
    +
  1. Download from tailscale.com/download
  2. +
  3. Install and sign in on both machines (same account)
  4. +
  5. Restart PolyPilot — it auto-detects Tailscale
  6. +
+

Tailscale creates a secure private network between your devices, so Fiesta pairing strings work across the internet.

+
+ } @foreach (var ip in localIps) {
@@ -270,6 +292,29 @@
} + @if (fiestaPairingString != null) + { +
+ + @fiestaPairingString + + +
+

Paste this string into the hub machine's Settings → Fiesta Workers → "Paste pairing string" field.

+ } + @if (!string.IsNullOrEmpty(directQrCodeDataUri)) {
@@ -288,11 +333,60 @@ @if (PlatformHelper.IsDesktop && (settings.Mode == ConnectionMode.Embedded || settings.Mode == ConnectionMode.Persistent)) { -
+

Fiesta Workers

+
+

🎉 How Fiesta works

+

Fiesta lets this machine act as a hub that dispatches work to linked worker machines on your LAN.

+
    +
  1. On each worker machine: open Settings → Direct Connection → set a password → click Enable Direct Sharing → copy the pp+… pairing string.
  2. +
  3. Here (hub): paste that string into the field below and click Link. The worker appears in "Linked workers".
  4. +
  5. In any chat: use @@worker-name to send a task to that machine. It runs autonomously and returns results.
  6. +
+

Discovered workers require manual linking before they can be used in Fiesta mode.

+ @* Incoming pair requests (worker side) *@ + @foreach (var req in FiestaService.PendingPairRequests) + { + var secondsLeft = Math.Max(0, (int)(req.ExpiresAt - DateTime.UtcNow).TotalSeconds); +
+

@req.HostName (@req.RemoteIp) wants to pair with this machine. Expires in @secondsLeft s

+
+ + +
+
+ } + + @* Worker side — show pairing string when bridge is running *@ + @if (WsBridgeServer.IsRunning && fiestaPairingString != null) + { +
+ +

Copy this on the worker machine and paste it on the hub to link instantly. Works via RDP clipboard, SSH, or any text channel.

+
+ @fiestaPairingString + + +
+
+ } + + @* Discovered LAN workers with Request Pair button *@ @if (FiestaService.DiscoveredWorkers.Any()) {
@@ -301,7 +395,15 @@ {
@worker.Hostname — @worker.BridgeUrl - + @if (pendingOutgoingPairs.ContainsKey(worker.InstanceId)) + { + Waiting for approval… + } + else + { + + + }
}
@@ -311,6 +413,21 @@

No workers discovered yet. Enable Direct Sharing on worker machines first.

} + @* Paste pairing string (hub side) *@ +
+ +
+ + +
+
+ @if (!string.IsNullOrEmpty(fiestaPasteError)) + { +

@fiestaPasteError

+ } + + @* Manual form *@
@@ -637,6 +754,12 @@ private string fiestaLinkUrl = ""; private string fiestaLinkToken = ""; private string? fiestaLinkError; + private string fiestaPasteString = ""; + private string? fiestaPasteError; + private string? fiestaPairingString; + private bool showFiestaPairingString; + private Dictionary pendingOutgoingPairs = new(); + private string? pairStatusMessage; private SettingsContext settingsCtx = null!; private List discoveredPlugins = new(); @@ -786,6 +909,8 @@ DevTunnelService.OnStateChanged += OnTunnelStateChanged; GitAutoUpdate.OnStateChanged += OnAutoUpdateStateChanged; FiestaService.OnStateChanged += OnFiestaStateChanged; + FiestaService.OnPairRequested += OnFiestaPairRequested; + FiestaService.OnPairApprovalSendFailed += OnFiestaPairApprovalSendFailed; var uiState = CopilotService.LoadUiState(); if (uiState?.FontSize > 0) fontSize = uiState.FontSize; @@ -797,7 +922,10 @@ GenerateQrCode(DevTunnelService.TunnelUrl, DevTunnelService.AccessToken); if (WsBridgeServer.IsRunning) + { GenerateDirectQrCode(); + TryGenerateFiestaPairingString(); + } } protected override void OnAfterRender(bool firstRender) @@ -882,6 +1010,8 @@ DevTunnelService.OnStateChanged -= OnTunnelStateChanged; GitAutoUpdate.OnStateChanged -= OnAutoUpdateStateChanged; FiestaService.OnStateChanged -= OnFiestaStateChanged; + FiestaService.OnPairRequested -= OnFiestaPairRequested; + FiestaService.OnPairApprovalSendFailed -= OnFiestaPairApprovalSendFailed; _ = JS.InvokeVoidAsync("eval", "document.querySelector('article.content')?.classList.remove('settings-content-active');"); _ = JS.InvokeVoidAsync("eval", "window.__settingsRef = null;"); _selfRef?.Dispose(); @@ -1062,7 +1192,7 @@ string? url = null; try { - var doc = System.Text.Json.JsonDocument.Parse(result); + using var doc = System.Text.Json.JsonDocument.Parse(result); if (doc.RootElement.TryGetProperty("url", out var urlProp)) url = urlProp.GetString(); if (doc.RootElement.TryGetProperty("token", out var tokenProp)) @@ -1091,6 +1221,7 @@ settings.RemoteUrl = url; ShowStatus("QR code scanned!", "success"); + StateHasChanged(); } private async Task TunnelLogin() @@ -1314,6 +1445,105 @@ ShowStatus("Fiesta worker removed", "success", 2000); } + private void TryGenerateFiestaPairingString() + { + try + { + var preferredHost = TailscaleService.MagicDnsName ?? TailscaleService.TailscaleIp; + fiestaPairingString = FiestaService.GeneratePairingString(preferredHost); + } + catch (Exception ex) + { + fiestaPairingString = null; + Console.WriteLine($"[Settings] Failed to generate pairing string: {ex.Message}"); + ShowStatus($"Could not generate pairing string: {ex.Message}", "error", 8000); + } + } + + private async Task CopyFiestaPairingString() + { + if (fiestaPairingString != null) + { + await Microsoft.Maui.ApplicationModel.DataTransfer.Clipboard.SetTextAsync(fiestaPairingString); + ShowStatus("Pairing string copied!", "success", 2000); + } + } + + private void ImportFiestaPairingString() + { + fiestaPasteError = null; + try + { + FiestaService.ParseAndLinkPairingString(fiestaPasteString.Trim()); + fiestaPasteString = ""; + ShowStatus("Worker linked via pairing string!", "success", 2500); + } + catch (Exception ex) + { + fiestaPasteError = ex.Message; + } + } + + private async Task RequestPairFromWorker(FiestaDiscoveredWorker worker) + { + pendingOutgoingPairs[worker.InstanceId] = null; + pairStatusMessage = null; + StateHasChanged(); + try + { + var result = await FiestaService.RequestPairAsync(worker); + pendingOutgoingPairs[worker.InstanceId] = result; + pairStatusMessage = result switch + { + PairRequestResult.Approved => $"{worker.Hostname} approved the pairing request!", + PairRequestResult.Denied => $"{worker.Hostname} denied the pairing request.", + PairRequestResult.Timeout => $"No response from {worker.Hostname} (timed out).", + _ => $"Could not reach {worker.Hostname}." + }; + var kind = result == PairRequestResult.Approved ? "success" : "error"; + ShowStatus(pairStatusMessage, kind, 4000); + } + catch (Exception ex) + { + pendingOutgoingPairs[worker.InstanceId] = PairRequestResult.Unreachable; + ShowStatus($"Pair request failed: {ex.Message}", "error", 4000); + } + finally + { + pendingOutgoingPairs.Remove(worker.InstanceId); + await InvokeAsync(StateHasChanged); + } + } + + private async Task ApproveFiestaPairRequest(string requestId) + { + var success = await FiestaService.ApprovePairRequestAsync(requestId); + if (success) + ShowStatus("Pair request approved — worker linked!", "success", 2500); + else + ShowStatus("Failed to send approval — worker may not have received credentials.", "error", 3000); + } + + private async Task DenyFiestaPairRequest(string requestId) + { + await FiestaService.DenyPairRequestAsync(requestId); + ShowStatus("Pair request denied.", "error", 2000); + } + + private void OnFiestaPairRequested(string requestId, string hostName, string remoteIp) + { + InvokeAsync(StateHasChanged); + } + + private void OnFiestaPairApprovalSendFailed(string requestId, string errorMessage) + { + InvokeAsync(() => + { + ShowStatus("Approval send failed — ask the host to re-initiate pairing.", "error", 5000); + StateHasChanged(); + }); + } + private static string? TryExtractHost(string url) { try @@ -1448,9 +1678,15 @@ WsBridgeServer.ServerPassword = settings.ServerPassword; WsBridgeServer.SetCopilotService(CopilotService); WsBridgeServer.Start(DevTunnelService.BridgePort, settings.Port); + if (!WsBridgeServer.IsRunning) + { + ShowStatus($"Failed to start bridge server on port {DevTunnelService.BridgePort} — the port may already be in use.", "error", 10000); + return; + } settings.DirectSharingEnabled = true; settings.Save(); GenerateDirectQrCode(); + TryGenerateFiestaPairingString(); StateHasChanged(); } @@ -1460,6 +1696,7 @@ settings.DirectSharingEnabled = false; settings.Save(); directQrCodeDataUri = null; + fiestaPairingString = null; StateHasChanged(); } diff --git a/PolyPilot/Components/Pages/Settings.razor.css b/PolyPilot/Components/Pages/Settings.razor.css index cb471c066..51a2b19c5 100644 --- a/PolyPilot/Components/Pages/Settings.razor.css +++ b/PolyPilot/Components/Pages/Settings.razor.css @@ -666,6 +666,65 @@ margin: 0; } +.pair-request-banner { + display: flex; + flex-direction: column; + gap: 0.5rem; + padding: 0.75rem 1rem; + background: rgba(var(--accent-rgb, 59,130,246), 0.12); + border: 1px solid rgba(var(--accent-rgb, 59,130,246), 0.35); + border-radius: 8px; +} + +.pair-request-banner p { + margin: 0; + font-size: var(--type-body); +} + +.pair-request-actions { + display: flex; + gap: 0.5rem; +} + +.pair-expiry { + font-size: var(--type-callout); + opacity: 0.65; + margin-left: 0.25rem; +} + +.onboarding-steps { + padding: 0.75rem 1rem; + background: rgba(var(--accent-rgb, 59,130,246), 0.07); + border: 1px solid rgba(var(--accent-rgb, 59,130,246), 0.2); + border-radius: 8px; + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.onboarding-heading { + margin: 0; + font-size: var(--type-body); +} + +.onboarding-list { + margin: 0; + padding-left: 1.25rem; + display: flex; + flex-direction: column; + gap: 0.35rem; + font-size: var(--type-body); + color: var(--text-dim); +} + +.onboarding-list li { + line-height: 1.5; +} + +.onboarding-list code { + font-size: 0.85em; +} + .tunnel-url-section { display: flex; flex-direction: column; diff --git a/PolyPilot/Models/BridgeMessages.cs b/PolyPilot/Models/BridgeMessages.cs index 058caf9b6..391e17941 100644 --- a/PolyPilot/Models/BridgeMessages.cs +++ b/PolyPilot/Models/BridgeMessages.cs @@ -123,6 +123,10 @@ public static class BridgeMessageTypes public const string FiestaTaskError = "fiesta_task_error"; public const string FiestaPing = "fiesta_ping"; public const string FiestaPong = "fiesta_pong"; + + // Fiesta push-to-pair (unauthenticated /pair WebSocket path) + public const string FiestaPairRequest = "fiesta_pair_request"; + public const string FiestaPairResponse = "fiesta_pair_response"; } // --- Server → Client payloads --- @@ -457,6 +461,22 @@ public class FiestaPongPayload public string Sender { get; set; } = ""; } +public class FiestaPairRequestPayload +{ + public string RequestId { get; set; } = ""; + public string HostInstanceId { get; set; } = ""; + public string HostName { get; set; } = ""; +} + +public class FiestaPairResponsePayload +{ + public string RequestId { get; set; } = ""; + public bool Approved { get; set; } + public string? BridgeUrl { get; set; } + public string? Token { get; set; } + public string? WorkerName { get; set; } +} + // --- Repo bridge payloads --- public class AddRepoPayload diff --git a/PolyPilot/Models/ConnectionSettings.cs b/PolyPilot/Models/ConnectionSettings.cs index 1fd653b60..1258bfd77 100644 --- a/PolyPilot/Models/ConnectionSettings.cs +++ b/PolyPilot/Models/ConnectionSettings.cs @@ -165,6 +165,9 @@ public string? ServerPassword private static string SettingsPath => _settingsPath ??= Path.Combine( GetPolyPilotDir(), "settings.json"); + /// For test isolation only — redirects Load()/Save() to a temp file. + internal static void SetSettingsFilePathForTesting(string? path) => _settingsPath = path; + private static string GetPolyPilotDir() { #if IOS || ANDROID diff --git a/PolyPilot/Models/FiestaModels.cs b/PolyPilot/Models/FiestaModels.cs index 95709199c..74600ca13 100644 --- a/PolyPilot/Models/FiestaModels.cs +++ b/PolyPilot/Models/FiestaModels.cs @@ -1,3 +1,4 @@ +using System.Net.WebSockets; using System.Text.Json.Serialization; namespace PolyPilot.Models; @@ -66,3 +67,39 @@ public class FiestaDispatchResult public int DispatchCount { get; set; } public List UnresolvedMentions { get; set; } = new(); } + +// --- Pairing string --- + +public class FiestaPairingPayload +{ + public string Url { get; set; } = ""; + public string Token { get; set; } = ""; + public string Hostname { get; set; } = ""; +} + +// --- Push-to-pair --- + +public enum PairRequestResult { Approved, Denied, Timeout, Unreachable } + +/// Read-only view of a pending pair request for UI consumption. +public class PendingPairRequestInfo +{ + public string RequestId { get; set; } = ""; + public string HostName { get; set; } = ""; + public string RemoteIp { get; set; } = ""; + public DateTime ExpiresAt { get; set; } +} + +internal class PendingPairRequest +{ + public string RequestId { get; set; } = ""; + public string HostName { get; set; } = ""; + public string HostInstanceId { get; set; } = ""; + public string RemoteIp { get; set; } = ""; + public WebSocket Socket { get; set; } = null!; + public TaskCompletionSource CompletionSource { get; set; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + /// Resolved by the winner after its SendAsync completes, so HandleIncomingPairHandshakeAsync + /// can wait for the send to finish before returning (which lets the caller close the socket safely). + public TaskCompletionSource SendComplete { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + public DateTime ExpiresAt { get; set; } +} diff --git a/PolyPilot/PolyPilot.csproj b/PolyPilot/PolyPilot.csproj index bbdd8126a..89cb657c4 100644 --- a/PolyPilot/PolyPilot.csproj +++ b/PolyPilot/PolyPilot.csproj @@ -81,8 +81,8 @@ - - + + diff --git a/PolyPilot/QrScannerPage.xaml.cs b/PolyPilot/QrScannerPage.xaml.cs index 86bba27d1..7deb58368 100644 --- a/PolyPilot/QrScannerPage.xaml.cs +++ b/PolyPilot/QrScannerPage.xaml.cs @@ -56,6 +56,11 @@ private void LayoutOverlays(double pageWidth, double pageHeight) overlayRight.WidthRequest = pageWidth - left - cutoutSize; overlayRight.HeightRequest = cutoutSize; overlayRight.Margin = new Thickness(0, top, 0, 0); + + overlayTop.IsVisible = true; + overlayBottom.IsVisible = true; + overlayLeft.IsVisible = true; + overlayRight.IsVisible = true; } protected override async void OnAppearing() diff --git a/PolyPilot/Services/CodespaceService.cs b/PolyPilot/Services/CodespaceService.cs index 2c9f30951..096cb4053 100644 --- a/PolyPilot/Services/CodespaceService.cs +++ b/PolyPilot/Services/CodespaceService.cs @@ -32,6 +32,7 @@ public sealed class TunnelHandle : IAsyncDisposable public int LocalPort { get; } public bool IsSshTunnel { get; } private readonly Process _process; + private volatile bool _disposed; internal TunnelHandle(int localPort, Process process, bool isSshTunnel = false) { @@ -40,18 +41,19 @@ internal TunnelHandle(int localPort, Process process, bool isSshTunnel = false) IsSshTunnel = isSshTunnel; } - public bool IsAlive => !_process.HasExited; + public bool IsAlive => !_disposed && !ProcessHelper.SafeHasExited(_process); public async ValueTask DisposeAsync() { + _disposed = true; try { - if (!_process.HasExited) + if (!ProcessHelper.SafeHasExited(_process)) _process.Kill(entireProcessTree: true); await _process.WaitForExitAsync(CancellationToken.None).WaitAsync(TimeSpan.FromSeconds(3)); } catch { } - _process.Dispose(); + try { _process.Dispose(); } catch { } } } /// @@ -137,6 +139,11 @@ public async ValueTask DisposeAsync() public async Task OpenSshTunnelAsync( string codespaceName, int remotePort = 4321, int connectTimeoutSeconds = 30) { + if (codespaceName.Length > 255 || !System.Text.RegularExpressions.Regex.IsMatch(codespaceName, @"^[a-zA-Z0-9\-]+$")) + throw new ArgumentException("Invalid codespace name.", nameof(codespaceName)); + if (remotePort < 1 || remotePort > 65535) + throw new ArgumentOutOfRangeException(nameof(remotePort), "Port must be between 1 and 65535."); + var localPort = FindFreePort(); var psi = new ProcessStartInfo @@ -236,7 +243,10 @@ public async ValueTask DisposeAsync() var authCmd = ""; if (!string.IsNullOrEmpty(localToken)) { - authCmd = $"gh auth login --with-token <<< '{localToken.Replace("'", "'\\''")}' 2>/dev/null; "; + // Base64-encode the token so it can be safely embedded in a shell command + // with no quoting or escaping needed (base64 output is [A-Za-z0-9+/=] only). + var b64Token = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(localToken)); + authCmd = $"echo {b64Token} | base64 -d | gh auth login --with-token 2>/dev/null; "; Console.WriteLine($"[CodespaceService] Injecting gh auth token into codespace SSH session"); } diff --git a/PolyPilot/Services/CopilotService.Utilities.cs b/PolyPilot/Services/CopilotService.Utilities.cs index 918d1c0ae..221c397be 100644 --- a/PolyPilot/Services/CopilotService.Utilities.cs +++ b/PolyPilot/Services/CopilotService.Utilities.cs @@ -494,6 +494,10 @@ internal static bool IsInitializationError(Exception ex) => /// internal static bool IsProcessError(Exception ex) { + // NOTE: "No process is associated" is an English BCL string from System.Diagnostics.Process. + // .NET Core / .NET 5+ does NOT localize exception messages, so this is safe for all + // supported runtimes. If .NET ever starts localizing, add a secondary check on the + // call stack (e.g., Process.HasExited) or catch the exception at a higher level. if (ex is InvalidOperationException && ex.Message.Contains("No process is associated", StringComparison.OrdinalIgnoreCase)) return true; if (ex is AggregateException agg) diff --git a/PolyPilot/Services/DevTunnelService.cs b/PolyPilot/Services/DevTunnelService.cs index 27b9e8566..c4800d7fc 100644 --- a/PolyPilot/Services/DevTunnelService.cs +++ b/PolyPilot/Services/DevTunnelService.cs @@ -295,11 +295,7 @@ public async Task HostAsync(int copilotPort) private async Task TryHostTunnelAsync(ConnectionSettings settings) { // Kill any existing host process from a previous attempt - if (_hostProcess != null && !_hostProcess.HasExited) - { - try { _hostProcess.Kill(entireProcessTree: true); } catch { } - } - _hostProcess?.Dispose(); + ProcessHelper.SafeKillAndDispose(_hostProcess); _hostProcess = null; var hostArgs = _tunnelId != null @@ -323,6 +319,9 @@ private async Task TryHostTunnelAsync(ConnectionSettings settings) return false; } + // Capture in local variable — fire-and-forget tasks must not access _hostProcess + // field, which can be nulled/disposed by Stop() or a subsequent TryHostTunnelAsync(). + var process = _hostProcess; var urlFound = new TaskCompletionSource(); var lastErrorLine = ""; @@ -330,9 +329,9 @@ private async Task TryHostTunnelAsync(ConnectionSettings settings) { try { - while (!_hostProcess.HasExited) + while (!ProcessHelper.SafeHasExited(process)) { - var line = await _hostProcess.StandardOutput.ReadLineAsync(); + var line = await process.StandardOutput.ReadLineAsync(); if (line == null) break; Console.WriteLine($"[DevTunnel] {line}"); if (!string.IsNullOrWhiteSpace(line)) @@ -347,9 +346,9 @@ private async Task TryHostTunnelAsync(ConnectionSettings settings) { try { - while (!_hostProcess.HasExited) + while (!ProcessHelper.SafeHasExited(process)) { - var line = await _hostProcess.StandardError.ReadLineAsync(); + var line = await process.StandardError.ReadLineAsync(); if (line == null) break; Console.WriteLine($"[DevTunnel ERR] {line}"); if (!string.IsNullOrWhiteSpace(line)) @@ -475,12 +474,9 @@ public void Stop(bool cleanClose = true) _ = _auditLog?.LogSessionClosed(null, 0, cleanClose, cleanClose ? "DevTunnel stopped" : "DevTunnel stopped after error"); try { - if (_hostProcess != null && !_hostProcess.HasExited) - { - _hostProcess.Kill(entireProcessTree: true); + if (!ProcessHelper.SafeHasExited(_hostProcess)) Console.WriteLine("[DevTunnel] Host process killed"); - } - _hostProcess?.Dispose(); + ProcessHelper.SafeKillAndDispose(_hostProcess); } catch (Exception ex) { diff --git a/PolyPilot/Services/FiestaService.cs b/PolyPilot/Services/FiestaService.cs index b3412c80e..666c3a2f2 100644 --- a/PolyPilot/Services/FiestaService.cs +++ b/PolyPilot/Services/FiestaService.cs @@ -13,12 +13,15 @@ namespace PolyPilot.Services; public class FiestaService : IDisposable { private const int DiscoveryPort = 43223; + private const int MaxPendingPairRequests = 5; + private const int MaxPendingPairRequestsPerIp = 2; private static readonly TimeSpan DiscoveryInterval = TimeSpan.FromSeconds(5); private static readonly TimeSpan DiscoveryStaleAfter = TimeSpan.FromSeconds(20); private static readonly Regex MentionRegex = new(@"(?[A-Za-z0-9._-]+)", RegexOptions.Compiled); private readonly CopilotService _copilot; private readonly WsBridgeServer _bridgeServer; + private readonly TailscaleService? _tailscale; private readonly ConcurrentDictionary _discoveredWorkers = new(StringComparer.OrdinalIgnoreCase); private readonly Dictionary _activeFiestas = new(StringComparer.Ordinal); private readonly object _stateLock = new(); @@ -34,40 +37,34 @@ public class FiestaService : IDisposable private Task? _broadcastTask; private Task? _listenTask; private static string? _stateFilePath; + private readonly Dictionary _pendingPairRequests = new(StringComparer.Ordinal); + + internal static void SetStateFilePathForTesting(string path) => _stateFilePath = path; public event Action? OnStateChanged; public event Action? OnHostTaskUpdate; - - public FiestaService(CopilotService copilot, WsBridgeServer bridgeServer) + /// Fires on the worker side when a remote host requests pairing. Args: requestId, hostName, remoteIp. + public event Action? OnPairRequested; + /// + /// Fires when ApprovePairRequestAsync succeeds in claiming the TCS but the send fails. + /// The pairing cannot be completed for this request — the host will time out and show "Unreachable". + /// UI should prompt the user to retry pairing from the host side. + /// Args: requestId, errorMessage. + /// + public event Action? OnPairApprovalSendFailed; + + public FiestaService(CopilotService copilot, WsBridgeServer bridgeServer, TailscaleService tailscale) { _copilot = copilot; _bridgeServer = bridgeServer; + _tailscale = tailscale; _bridgeServer.SetFiestaService(this); LoadState(); if (PlatformHelper.IsDesktop) StartDiscovery(); } - private static string StateFilePath => _stateFilePath ??= Path.Combine(GetPolyPilotBaseDir(), "fiesta.json"); - - private static string GetPolyPilotBaseDir() - { - try - { -#if IOS || ANDROID - return Path.Combine(FileSystem.AppDataDirectory, ".polypilot"); -#else - var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); - if (string.IsNullOrEmpty(home)) - home = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData); - return Path.Combine(home, ".polypilot"); -#endif - } - catch - { - return Path.Combine(Path.GetTempPath(), ".polypilot"); - } - } + private static string StateFilePath => _stateFilePath ??= Path.Combine(CopilotService.BaseDir, "fiesta.json"); public IReadOnlyList DiscoveredWorkers => _discoveredWorkers.Values @@ -110,16 +107,20 @@ public bool IsFiestaActive(string sessionName) } public void LinkWorker(string name, string hostname, string bridgeUrl, string token) + => LinkWorkerAndReturn(name, hostname, bridgeUrl, token); + + private FiestaLinkedWorker? LinkWorkerAndReturn(string name, string hostname, string bridgeUrl, string token) { var normalizedUrl = NormalizeBridgeUrl(bridgeUrl); if (string.IsNullOrWhiteSpace(normalizedUrl) || string.IsNullOrWhiteSpace(token)) - return; + return null; var workerName = string.IsNullOrWhiteSpace(name) ? (!string.IsNullOrWhiteSpace(hostname) ? hostname.Trim() : normalizedUrl) : name.Trim(); var workerHostname = string.IsNullOrWhiteSpace(hostname) ? workerName : hostname.Trim(); + FiestaLinkedWorker result; lock (_stateLock) { var existing = _linkedWorkers.FirstOrDefault(w => @@ -133,23 +134,27 @@ public void LinkWorker(string name, string hostname, string bridgeUrl, string to existing.BridgeUrl = normalizedUrl; existing.Token = token.Trim(); existing.LinkedAt = DateTime.UtcNow; + result = existing; } else { - _linkedWorkers.Add(new FiestaLinkedWorker + var added = new FiestaLinkedWorker { Name = workerName, Hostname = workerHostname, BridgeUrl = normalizedUrl, Token = token.Trim(), LinkedAt = DateTime.UtcNow - }); + }; + _linkedWorkers.Add(added); + result = added; } } SaveState(); UpdateLinkedWorkerPresence(); OnStateChanged?.Invoke(); + return result; } public void RemoveLinkedWorker(string workerId) @@ -306,6 +311,402 @@ public async Task HandleBridgeMessageAsync(string clientId, WebSocket ws, return true; } + // ---- Pairing string (Feature B) ---- + + public IReadOnlyList PendingPairRequests + { + get + { + lock (_stateLock) + return _pendingPairRequests.Values + .Where(r => r.ExpiresAt > DateTime.UtcNow) + .Select(r => new PendingPairRequestInfo + { + RequestId = r.RequestId, + HostName = r.HostName, + RemoteIp = r.RemoteIp, + ExpiresAt = r.ExpiresAt + }) + .ToList(); + } + } + + public string GeneratePairingString(string? preferredHost = null) + { + if (!_bridgeServer.IsRunning) + throw new InvalidOperationException("Bridge server is not running. Enable Direct Sharing first."); + + var token = EnsureServerPassword(); + + // If no explicit host supplied, prefer Tailscale IP/MagicDNS when running — + // it works across different networks, not just the local LAN. + if (preferredHost == null && _tailscale?.IsRunning == true) + preferredHost = _tailscale.MagicDnsName ?? _tailscale.TailscaleIp; + + var localIp = preferredHost ?? GetPrimaryLocalIpAddress() ?? "localhost"; + var url = $"http://{localIp}:{_bridgeServer.BridgePort}"; + + var payload = new FiestaPairingPayload + { + Url = url, + Token = token, + Hostname = Environment.MachineName + }; + var json = JsonSerializer.Serialize(payload, _jsonOptions); + var b64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(json)) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + return $"pp+{b64}"; + } + + public FiestaLinkedWorker ParseAndLinkPairingString(string pairingString) + { + if (string.IsNullOrWhiteSpace(pairingString) || !pairingString.StartsWith("pp+", StringComparison.Ordinal)) + throw new FormatException("Not a valid PolyPilot pairing string (must start with 'pp+')."); + if (pairingString.Length > 4096) + throw new FormatException("Pairing string is too large."); + + var b64 = pairingString[3..].Replace('-', '+').Replace('_', '/'); + // Restore standard base64 padding + int remainder = b64.Length % 4; + var padded = remainder == 2 ? b64 + "==" + : remainder == 3 ? b64 + "=" + : b64; + + byte[] bytes; + try { bytes = Convert.FromBase64String(padded); } + catch (FormatException) { throw new FormatException("Pairing string is corrupted (invalid base64)."); } + + var json = Encoding.UTF8.GetString(bytes); + var parsed = JsonSerializer.Deserialize(json, _jsonOptions) + ?? throw new FormatException("Pairing string payload is empty."); + + if (string.IsNullOrWhiteSpace(parsed.Url)) + throw new FormatException("Pairing string is missing a URL."); + if (string.IsNullOrWhiteSpace(parsed.Token)) + throw new FormatException("Pairing string is missing a token."); + + var name = !string.IsNullOrWhiteSpace(parsed.Hostname) ? parsed.Hostname : "Unknown"; + var linked = LinkWorkerAndReturn(name, name, parsed.Url, parsed.Token) + ?? throw new InvalidOperationException("Failed to link worker (invalid URL or token)."); + return CloneLinkedWorker(linked); + } + + // ---- Push-to-pair — Worker (incoming) side (Feature C) ---- + + public async Task HandleIncomingPairHandshakeAsync(WebSocket ws, string remoteIp, CancellationToken ct) + { + // Read the initial pair request with a short timeout + using var readCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + readCts.CancelAfter(TimeSpan.FromSeconds(10)); + + BridgeMessage? msg; + try { msg = await ReadSingleMessageAsync(ws, readCts.Token); } + catch (OperationCanceledException) { return; } + + if (msg?.Type != BridgeMessageTypes.FiestaPairRequest) return; + + var req = msg.GetPayload(); + if (req == null || string.IsNullOrWhiteSpace(req.RequestId)) return; + + var pending = new PendingPairRequest + { + RequestId = req.RequestId, + HostInstanceId = req.HostInstanceId, + HostName = req.HostName, + RemoteIp = remoteIp, + Socket = ws, + ExpiresAt = DateTime.UtcNow.AddSeconds(60) + }; + + // Capture the TCS before releasing the lock + TaskCompletionSource tcs; + bool isDuplicate; + lock (_stateLock) + { + var requestsFromIp = _pendingPairRequests.Values.Count(r => r.RemoteIp == remoteIp); + isDuplicate = _pendingPairRequests.Count >= MaxPendingPairRequests + || requestsFromIp >= MaxPendingPairRequestsPerIp; + if (!isDuplicate) + { + _pendingPairRequests[req.RequestId] = pending; + tcs = pending.CompletionSource; + } + else + { + tcs = null!; // won't be used + } + } + + if (isDuplicate) + { + // Already handling a pair request — deny inline so the send completes + // before this method returns and the caller closes the socket. + try + { + await SendAsync(ws, BridgeMessage.Create(BridgeMessageTypes.FiestaPairResponse, + new FiestaPairResponsePayload { RequestId = req.RequestId, Approved = false }), ct); + } + catch { } + return; + } + + OnPairRequested?.Invoke(req.RequestId, req.HostName, remoteIp); + OnStateChanged?.Invoke(); + + // Wait for user approval/denial (up to 60s) + using var expiryCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + expiryCts.CancelAfter(TimeSpan.FromSeconds(60)); + try + { + await tcs.Task.WaitAsync(expiryCts.Token); + // Winner's send is in-flight — wait for it to complete before returning so the + // caller's finally (socket close) doesn't race the outgoing message. + try { await pending.SendComplete.Task.WaitAsync(TimeSpan.FromSeconds(5)); } catch (TimeoutException) { } catch (OperationCanceledException) { } + } + catch (OperationCanceledException) + { + // Timed out — auto-deny. Claim via TrySetResult first so we don't race with + // ApprovePairRequestAsync (only the winner of TrySetResult sends). + if (tcs.TrySetResult(false)) + { + try + { + await SendAsync(ws, BridgeMessage.Create(BridgeMessageTypes.FiestaPairResponse, + new FiestaPairResponsePayload { RequestId = req.RequestId, Approved = false }), CancellationToken.None); + } + catch { } + finally + { + pending.SendComplete.TrySetResult(); + } + } + else + { + // Approve already won — wait for its send to finish before closing socket + try { await pending.SendComplete.Task.WaitAsync(TimeSpan.FromSeconds(5)); } catch (TimeoutException) { } catch (OperationCanceledException) { } + } + } + finally + { + lock (_stateLock) _pendingPairRequests.Remove(req.RequestId); + OnStateChanged?.Invoke(); + } + } + + public async Task ApprovePairRequestAsync(string requestId) + { + PendingPairRequest? pending; + TaskCompletionSource? tcs; + lock (_stateLock) + { + if (!_pendingPairRequests.TryGetValue(requestId, out pending)) return false; + tcs = pending.CompletionSource; + } + + var token = EnsureServerPassword(); + var localIp = (_tailscale?.IsRunning == true ? (_tailscale.MagicDnsName ?? _tailscale.TailscaleIp) : null) + ?? GetPrimaryLocalIpAddress() ?? "localhost"; + var bridgeUrl = $"http://{localIp}:{_bridgeServer.BridgePort}"; + + // Atomically claim ownership. If the timeout already fired (TrySetResult(false) won), + // skip sending — the WebSocket may already be closed. + if (!tcs.TrySetResult(true)) + return false; // timeout already won, don't attempt a concurrent send + + try + { + await SendAsync(pending.Socket, BridgeMessage.Create( + BridgeMessageTypes.FiestaPairResponse, + new FiestaPairResponsePayload + { + RequestId = requestId, + Approved = true, + BridgeUrl = bridgeUrl, + Token = token, + WorkerName = Environment.MachineName + }), CancellationToken.None); + return true; + } + catch (Exception ex) + { + // TCS already resolved to true so this request cannot be retried or denied. + // Log clearly and fire event so the UI can prompt the user to retry from the host side. + var msg = ex.Message; + Console.WriteLine($"[Fiesta] Approval send failed (request={requestId}, irrecoverable): {msg}"); + OnPairApprovalSendFailed?.Invoke(requestId, msg); + return false; + } + finally + { + // Signal that our send is complete so HandleIncomingPairHandshakeAsync + // can safely return (allowing the caller to close the socket). + pending.SendComplete.TrySetResult(); + } + } + + public async Task DenyPairRequestAsync(string requestId) + { + PendingPairRequest? pending; + TaskCompletionSource? tcs; + lock (_stateLock) + { + if (!_pendingPairRequests.TryGetValue(requestId, out pending)) return; + tcs = pending.CompletionSource; + } + + // Atomically claim ownership — if approve already won, skip sending. + if (!tcs.TrySetResult(false)) + return; // approve already won, don't race on the socket + + try + { + await SendAsync(pending.Socket, BridgeMessage.Create( + BridgeMessageTypes.FiestaPairResponse, + new FiestaPairResponsePayload { RequestId = requestId, Approved = false }), + CancellationToken.None); + } + catch { } + finally + { + // Signal send complete so HandleIncomingPairHandshakeAsync can safely return. + pending.SendComplete.TrySetResult(); + } + } + + // Keep a synchronous shim for callers that can't await (e.g., Blazor @onclick non-async) + public void DenyPairRequest(string requestId) => + _ = DenyPairRequestAsync(requestId); + + // ---- Push-to-pair — Host (outgoing) side (Feature C) ---- + + public async Task RequestPairAsync(FiestaDiscoveredWorker worker, CancellationToken ct = default) + { + var wsUri = ToWebSocketUri(worker.BridgeUrl); + // Append /pair path + wsUri = wsUri.TrimEnd('/') + "/pair"; + var requestId = Guid.NewGuid().ToString("N"); + + try + { + using var ws = new ClientWebSocket(); + // No auth header — /pair is intentionally unauthenticated + using var connectCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + connectCts.CancelAfter(TimeSpan.FromSeconds(10)); + + await ws.ConnectAsync(new Uri(wsUri), connectCts.Token); + + await SendAsync(ws, BridgeMessage.Create( + BridgeMessageTypes.FiestaPairRequest, + new FiestaPairRequestPayload + { + RequestId = requestId, + HostInstanceId = _instanceId, + HostName = Environment.MachineName + }), ct); + + // Wait up to 65s for the worker to approve or deny + using var responseCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + responseCts.CancelAfter(TimeSpan.FromSeconds(65)); + var msg = await ReadSingleMessageAsync(ws, responseCts.Token); + + if (msg?.Type != BridgeMessageTypes.FiestaPairResponse) + return PairRequestResult.Unreachable; + + var resp = msg.GetPayload(); + if (resp == null || !resp.Approved) + return PairRequestResult.Denied; + + // Guard: an approval without connection details is a malformed response + if (string.IsNullOrWhiteSpace(resp.BridgeUrl) || string.IsNullOrWhiteSpace(resp.Token)) + return PairRequestResult.Unreachable; + + var workerName = !string.IsNullOrWhiteSpace(resp.WorkerName) ? resp.WorkerName : worker.Hostname; + LinkWorker(workerName, worker.Hostname, resp.BridgeUrl, resp.Token); + return PairRequestResult.Approved; + } + catch (WebSocketException) { return PairRequestResult.Unreachable; } + catch (OperationCanceledException) { return PairRequestResult.Timeout; } + } + + // ---- Shared helper: read a single framed WebSocket message ---- + + private static async Task ReadSingleMessageAsync(WebSocket ws, CancellationToken ct) + { + var buffer = new byte[65536]; + var sb = new StringBuilder(); + while (ws.State == WebSocketState.Open) + { + var result = await ws.ReceiveAsync(buffer, ct); + if (result.MessageType == WebSocketMessageType.Close) return null; + sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); + if (sb.Length > 256 * 1024) return null; // guard against unbounded frames on unauthenticated /pair path + if (result.EndOfMessage) break; + } + return BridgeMessage.Deserialize(sb.ToString()); + } + + // ---- Settings integration ---- + + private string EnsureServerPassword() + { + // Fast path: check the runtime value without disk I/O. + lock (_stateLock) + { + if (!string.IsNullOrWhiteSpace(_bridgeServer.ServerPassword)) + return _bridgeServer.ServerPassword; + } + + // Slow path: load settings outside the lock so steady-state pairing operations + // (which hold _stateLock for _pendingPairRequests / _linkedWorkers reads) are + // not blocked by disk I/O. + var settings = ConnectionSettings.Load(); + string candidatePassword; + bool needsSave = false; + + if (!string.IsNullOrWhiteSpace(settings.ServerPassword)) + { + candidatePassword = settings.ServerPassword; + } + else + { + // Generate a candidate; the final winner is decided inside the lock below. + candidatePassword = Convert.ToBase64String(System.Security.Cryptography.RandomNumberGenerator.GetBytes(18)) + .Replace('+', '-').Replace('/', '_').TrimEnd('='); + needsSave = true; + } + + // Re-enter the lock to elect exactly one winner. + // If another thread already stored a password we use that. + // If we win, we also save — under the lock — so the disk write and the + // runtime state stay in sync even when two threads race here simultaneously. + string password; + lock (_stateLock) + { + if (!string.IsNullOrWhiteSpace(_bridgeServer.ServerPassword)) + { + // Another thread already set it — no save needed. + password = _bridgeServer.ServerPassword; + } + else + { + password = candidatePassword; + _bridgeServer.ServerPassword = password; + if (needsSave) + { + // Persist inside the lock so disk and runtime value are always the same. + // This I/O happens only once per process lifetime (when no password existed). + settings.ServerPassword = password; + settings.Save(); + Console.WriteLine("[Fiesta] Auto-generated server password for pairing."); + } + } + } + + return password; + } + private async Task HandleFiestaAssignAsync(string clientId, WebSocket ws, FiestaAssignPayload assign, CancellationToken ct) { var workerName = Environment.MachineName; @@ -468,6 +869,11 @@ private async Task ReadTaskUpdatesAsync(ClientWebSocket ws, string hostSessionNa break; messageBuffer.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); + if (messageBuffer.Length > 256 * 1024) + { + try { await ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, "Message exceeds 256KB limit", CancellationToken.None); } catch { } + break; // guard against unbounded frames + } if (!result.EndOfMessage) continue; @@ -586,7 +992,30 @@ private static string NormalizeBridgeUrl(string url) public static string GetFiestaWorkspaceDirectory(string fiestaName) { var safeName = SanitizeFiestaName(fiestaName); - return Path.Combine(GetPolyPilotBaseDir(), "workspace", safeName); + var baseDir = Path.GetFullPath(Path.Combine(CopilotService.BaseDir, "workspace")); + var fullPath = Path.GetFullPath(Path.Combine(baseDir, safeName)); + + // Primary guard: reject paths that escape baseDir by path components (covers ".." attacks). + var relativePath = Path.GetRelativePath(baseDir, fullPath); + if (relativePath.StartsWith("..", StringComparison.Ordinal)) + throw new InvalidOperationException("Workspace path escapes the base directory."); + + // Secondary guard: if the directory already exists, resolve symlinks and re-validate. + // Path.GetFullPath does NOT resolve symlinks, so a symlink inside the workspace tree + // could redirect to an arbitrary location. ResolveLinkTarget(returnFinalTarget: true) + // follows the full chain. Only needed when the directory exists (pre-created symlinks). + if (Directory.Exists(fullPath)) + { + var resolved = Directory.ResolveLinkTarget(fullPath, returnFinalTarget: true)?.FullName; + if (resolved != null) + { + var resolvedRelative = Path.GetRelativePath(baseDir, resolved); + if (resolvedRelative.StartsWith("..", StringComparison.Ordinal)) + throw new InvalidOperationException("Workspace directory is a symlink that escapes the base directory."); + } + } + + return fullPath; } private static string SanitizeFiestaName(string fiestaName) @@ -668,8 +1097,14 @@ private void SaveState() private void StartDiscovery() { _discoveryCts = new CancellationTokenSource(); - _broadcastTask = Task.Run(() => BroadcastPresenceLoopAsync(_discoveryCts.Token)); - _listenTask = Task.Run(() => ListenForWorkersLoopAsync(_discoveryCts.Token)); + // Capture the token struct NOW, before Task.Run queues the work. + // If Dispose() runs before the thread-pool picks up the lambda, accessing + // _discoveryCts.Token on a disposed CTS throws ObjectDisposedException. + // A captured CancellationToken struct remains valid (IsCancellationRequested=true) + // even after the parent CTS is cancelled and disposed. + var token = _discoveryCts.Token; + _broadcastTask = Task.Run(() => BroadcastPresenceLoopAsync(token)); + _listenTask = Task.Run(() => ListenForWorkersLoopAsync(token)); } private async Task BroadcastPresenceLoopAsync(CancellationToken ct) @@ -681,14 +1116,18 @@ private async Task BroadcastPresenceLoopAsync(CancellationToken ct) { if (_bridgeServer.IsRunning && _bridgeServer.BridgePort > 0) { - var localIp = GetPrimaryLocalIpAddress(); - if (!string.IsNullOrEmpty(localIp)) + // Prefer Tailscale IP in the broadcast so peers that receive it can reach us + // via Tailscale (works across networks). Fall back to primary LAN IP. + string? advertiseIp = (_tailscale?.IsRunning == true) + ? (_tailscale.TailscaleIp ?? GetPrimaryLocalIpAddress()) + : GetPrimaryLocalIpAddress(); + if (!string.IsNullOrEmpty(advertiseIp)) { var announcement = new FiestaDiscoveryAnnouncement { InstanceId = _instanceId, Hostname = Environment.MachineName, - BridgeUrl = $"http://{localIp}:{_bridgeServer.BridgePort}", + BridgeUrl = $"http://{advertiseIp}:{_bridgeServer.BridgePort}", TimestampUtc = DateTime.UtcNow }; @@ -714,6 +1153,7 @@ private async Task ListenForWorkersLoopAsync(CancellationToken ct) try { var result = await listener.ReceiveAsync(ct); + if (result.Buffer.Length > 4096) continue; // reject oversized discovery packets var json = Encoding.UTF8.GetString(result.Buffer); var announcement = JsonSerializer.Deserialize(json, _jsonOptions); if (announcement == null || string.IsNullOrWhiteSpace(announcement.InstanceId)) @@ -780,16 +1220,32 @@ private void UpdateLinkedWorkerPresence() { try { + string? best = null; + int bestScore = -1; + foreach (var ni in NetworkInterface.GetAllNetworkInterfaces()) { if (ni.OperationalStatus != OperationalStatus.Up) continue; if (ni.NetworkInterfaceType == NetworkInterfaceType.Loopback) continue; + if (ni.NetworkInterfaceType == NetworkInterfaceType.Tunnel) continue; + if (IsVirtualAdapterName(ni.Name)) continue; - var ip = ni.GetIPProperties().UnicastAddresses + var unicast = ni.GetIPProperties().UnicastAddresses .FirstOrDefault(a => a.Address.AddressFamily == AddressFamily.InterNetwork); - if (ip != null) - return ip.Address.ToString(); + if (unicast == null) continue; + + var addr = unicast.Address.ToString(); + if (IsVirtualAdapterIp(addr)) continue; + + int score = ScoreNetworkInterface(ni.NetworkInterfaceType, addr); + if (score > bestScore) + { + bestScore = score; + best = addr; + } } + + return best; } catch { @@ -798,6 +1254,57 @@ private void UpdateLinkedWorkerPresence() return null; } + private static bool IsVirtualAdapterName(string name) => + name.StartsWith("vEthernet", StringComparison.OrdinalIgnoreCase) || // Hyper-V + name.StartsWith("br-", StringComparison.OrdinalIgnoreCase) || // Docker bridge + name.StartsWith("virbr", StringComparison.OrdinalIgnoreCase) || // libvirt + name.Contains("docker", StringComparison.OrdinalIgnoreCase) || + name.Contains("WSL", StringComparison.OrdinalIgnoreCase) || + name.Contains("VMware", StringComparison.OrdinalIgnoreCase) || + name.Contains("VirtualBox", StringComparison.OrdinalIgnoreCase) || + name.Contains("ZeroTier", StringComparison.OrdinalIgnoreCase); + + private static bool IsVirtualAdapterIp(string ip) + { + // Filter known virtual/container subnets that Docker and VM managers use by default. + // 172.17–172.24 covers Docker's default bridge (172.17), Docker custom networks + // (typically 172.18–172.24), and common VMware/VirtualBox host-only subnets. + // 172.25–172.31 are also in RFC-1918 /12 but are less commonly assigned by tooling; + // we leave them through so legitimate corporate LANs in that range still work. + // The name-based filter (IsVirtualAdapterName) is the primary defense for adapters + // with names like "br-*", "docker*", "vEthernet", etc. + if (ip.StartsWith("172.", StringComparison.Ordinal)) + { + var parts = ip.Split('.'); + if (parts.Length >= 2 && int.TryParse(parts[1], out var oct) && oct >= 17 && oct <= 24) + return true; + } + return false; + } + + private static bool IsRfc1918_172(string ip) + { + var parts = ip.Split('.'); + return parts.Length >= 2 && int.TryParse(parts[1], out var oct) && oct >= 16 && oct <= 31; + } + + private static int ScoreNetworkInterface(NetworkInterfaceType type, string ip) + { + // Prefer RFC-1918 private ranges (real LAN) vs others + bool isPrivateLan = ip.StartsWith("192.168.", StringComparison.Ordinal) + || ip.StartsWith("10.", StringComparison.Ordinal) + || (ip.StartsWith("172.", StringComparison.Ordinal) && IsRfc1918_172(ip)); + + return type switch + { + NetworkInterfaceType.Ethernet => isPrivateLan ? 100 : 60, + NetworkInterfaceType.Wireless80211 => isPrivateLan ? 90 : 50, + NetworkInterfaceType.GigabitEthernet => isPrivateLan ? 100 : 60, + NetworkInterfaceType.FastEthernetT => isPrivateLan ? 100 : 60, + _ => isPrivateLan ? 20 : 5, + }; + } + private static FiestaDiscoveredWorker CloneDiscoveredWorker(FiestaDiscoveredWorker worker) => new() { diff --git a/PolyPilot/Services/ProcessHelper.cs b/PolyPilot/Services/ProcessHelper.cs new file mode 100644 index 000000000..35ea37faa --- /dev/null +++ b/PolyPilot/Services/ProcessHelper.cs @@ -0,0 +1,85 @@ +using System.Diagnostics; + +namespace PolyPilot.Services; + +/// +/// Safe wrappers for operations that can throw +/// when the process handle is +/// disposed or was never associated. +/// +public static class ProcessHelper +{ + /// + /// Returns true if the process has exited or the handle is invalid/disposed. + /// Unlike , this never throws. + /// A disposed or invalid process is treated as exited. + /// + public static bool SafeHasExited(Process? process) + { + if (process == null) + return true; + try + { + return process.HasExited; + } + catch (InvalidOperationException) + { + // "No process is associated with this object" — handle was disposed + return true; + } + catch (SystemException) + { + // Win32Exception, NotSupportedException, etc. + return true; + } + } + + /// + /// Attempts to kill the process tree. Swallows all exceptions — safe to call + /// on disposed or already-exited processes. + /// + public static void SafeKill(Process? process, bool entireProcessTree = true) + { + if (process == null) + return; + try + { + if (!process.HasExited) + process.Kill(entireProcessTree); + } + catch (UnauthorizedAccessException ex) + { + // Access denied — process belongs to another user or is protected. + // Log so the caller knows the process may still be running and holding ports. + Console.WriteLine($"[ProcessHelper] SafeKill: access denied for PID {TryGetPid(process)} — {ex.Message}"); + } + catch + { + // Process already exited, disposed, or other transient error — nothing to do. + } + } + + private static int TryGetPid(Process process) + { + try { return process.Id; } + catch { return -1; } + } + + /// + /// Kills (if alive) and disposes the process. Safe to call multiple times. + /// + public static void SafeKillAndDispose(Process? process, bool entireProcessTree = true) + { + if (process == null) + return; + SafeKill(process, entireProcessTree); + try + { + process.Dispose(); + } + catch + { + // Already disposed — ignore + } + } +} diff --git a/PolyPilot/Services/QrScannerService.cs b/PolyPilot/Services/QrScannerService.cs index 8e1b213dd..c378bfaad 100644 --- a/PolyPilot/Services/QrScannerService.cs +++ b/PolyPilot/Services/QrScannerService.cs @@ -6,11 +6,20 @@ namespace PolyPilot.Services; /// public class QrScannerService { + private readonly object _lock = new(); private TaskCompletionSource? _tcs; public Task ScanAsync() { - _tcs = new TaskCompletionSource(); + TaskCompletionSource captured; + lock (_lock) + { + if (_tcs != null && !_tcs.Task.IsCompleted) + return _tcs.Task; + + _tcs = new TaskCompletionSource(); + captured = _tcs; // capture inside lock — safe from field-swap races + } MainThread.BeginInvokeOnMainThread(async () => { @@ -21,20 +30,22 @@ public class QrScannerService if (currentPage != null) await currentPage.Navigation.PushModalAsync(scannerPage); else - _tcs?.TrySetResult(null); + captured.TrySetResult(null); } catch (Exception ex) { Console.WriteLine($"[QrScanner] Error launching scanner: {ex}"); - _tcs?.TrySetResult(null); + captured.TrySetResult(null); } }); - return _tcs.Task; + return captured.Task; } internal void SetResult(string? value) { - _tcs?.TrySetResult(value); + TaskCompletionSource? current; + lock (_lock) current = _tcs; + current?.TrySetResult(value); } } diff --git a/PolyPilot/Services/ServerManager.cs b/PolyPilot/Services/ServerManager.cs index ad09cd636..70fd3c9c1 100644 --- a/PolyPilot/Services/ServerManager.cs +++ b/PolyPilot/Services/ServerManager.cs @@ -37,8 +37,22 @@ public bool CheckServerRunning(string host = "127.0.0.1", int? port = null) try { using var client = new TcpClient(); - using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(1)); - client.ConnectAsync(host, port.Value, cts.Token).AsTask().GetAwaiter().GetResult(); + // Use Task.WaitAny with a timeout task instead of CancellationTokenSource. + // CancellationTokenSource disposal while ConnectAsync is still running its + // internal cleanup can produce unobserved ObjectDisposedException tasks. + var connectTask = client.ConnectAsync(host, port.Value); + int index = Task.WaitAny(new[] { connectTask }, TimeSpan.FromSeconds(1)); + if (index == -1) + { + // Timed out — observe any future exception (Faulted or Cancelled) to prevent + // unobserved task exceptions. NotOnRanToCompletion covers both Faulted and + // Cancelled states; OnlyOnFaulted would miss Cancelled (which can occur if the + // TcpClient is disposed while the connect is still in-flight). + _ = connectTask.ContinueWith(t => { _ = t.Exception; }, + TaskContinuationOptions.NotOnRanToCompletion); + return false; + } + connectTask.GetAwaiter().GetResult(); return true; } catch @@ -166,8 +180,7 @@ public void StopServer() try { var process = Process.GetProcessById(pid.Value); - process.Kill(); - process.Dispose(); + ProcessHelper.SafeKillAndDispose(process, entireProcessTree: false); Console.WriteLine($"[ServerManager] Killed server PID {pid}"); } catch (Exception ex) diff --git a/PolyPilot/Services/WsBridgeServer.cs b/PolyPilot/Services/WsBridgeServer.cs index b26dc6c4d..d58875804 100644 --- a/PolyPilot/Services/WsBridgeServer.cs +++ b/PolyPilot/Services/WsBridgeServer.cs @@ -22,6 +22,7 @@ public class WsBridgeServer : IDisposable private RepoManager? _repoManager; private readonly ConcurrentDictionary _clients = new(); private readonly ConcurrentDictionary _clientSendLocks = new(); + private long _lastPairRequestAcceptedAtTicks = DateTime.MinValue.Ticks; // Debounce timers to prevent flooding mobile clients during streaming private Timer? _sessionsListDebounce; @@ -232,7 +233,25 @@ private async Task AcceptLoopAsync(CancellationToken ct) { var context = await _listener!.GetContextAsync(); - if (context.Request.IsWebSocketRequest) + if (context.Request.IsWebSocketRequest && + context.Request.Url?.AbsolutePath == "/pair") + { + // Unauthenticated pairing handshake path — rate-limited at HTTP level + // Use Interlocked.CompareExchange to atomically claim the slot, preventing TOCTOU races. + var nowTicks = DateTime.UtcNow.Ticks; + var lastTicks = Interlocked.Read(ref _lastPairRequestAcceptedAtTicks); + var elapsed = TimeSpan.FromTicks(nowTicks - lastTicks); + if (elapsed.TotalSeconds < 5 || + Interlocked.CompareExchange(ref _lastPairRequestAcceptedAtTicks, nowTicks, lastTicks) != lastTicks) + { + context.Response.StatusCode = 429; + context.Response.Close(); + Console.WriteLine("[WsBridge] Pair request rate-limited"); + continue; + } + _ = Task.Run(() => HandlePairHandshakeAsync(context, ct), ct); + } + else if (context.Request.IsWebSocketRequest) { if (!ValidateClientToken(context.Request)) { @@ -442,6 +461,12 @@ await SendToClientAsync(clientId, ws, messageBuffer.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); + if (messageBuffer.Length > 256 * 1024) + { + try { await ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, "Message exceeds 256KB limit", CancellationToken.None); } catch { } + break; // guard against unbounded frames + } + if (result.EndOfMessage) { var json = messageBuffer.ToString(); @@ -1477,4 +1502,31 @@ private static string TruncateSummary(string text, int maxLength = 100) ".tiff" => "image/tiff", _ => "image/png" }; + + private async Task HandlePairHandshakeAsync(HttpListenerContext ctx, CancellationToken ct) + { + WebSocket? ws = null; + try + { + var wsCtx = await ctx.AcceptWebSocketAsync(null); + ws = wsCtx.WebSocket; + var remoteIp = ctx.Request.RemoteEndPoint?.Address.ToString() ?? "unknown"; + Console.WriteLine($"[WsBridge] Pair handshake from {remoteIp}"); + + if (_fiestaService != null) + await _fiestaService.HandleIncomingPairHandshakeAsync(ws, remoteIp, ct); + } + catch (Exception ex) + { + Console.WriteLine($"[WsBridge] Pair handshake error: {ex.Message}"); + } + finally + { + if (ws?.State == WebSocketState.Open) + { + try { await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "done", CancellationToken.None); } + catch { } + } + } + } }