diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 07502ee2d..d5cb6707b 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -226,6 +226,7 @@ async Task StartCoreAsync(CancellationToken ct) // Verify protocol version compatibility await VerifyProtocolVersionAsync(connection, ct); + await ConfigureSessionFsAsync(ct); _logger.LogInformation("Copilot client connected"); return connection; @@ -474,6 +475,7 @@ public async Task CreateSessionAsync(SessionConfig config, Cance { session.On(config.OnEvent); } + ConfigureSessionFsHandlers(session, config.CreateSessionFsHandler); _sessions[sessionId] = session; try @@ -594,6 +596,7 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes { session.On(config.OnEvent); } + ConfigureSessionFsHandlers(session, config.CreateSessionFsHandler); _sessions[sessionId] = session; try @@ -1078,6 +1081,37 @@ private Task EnsureConnectedAsync(CancellationToken cancellationToke return (Task)StartAsync(cancellationToken); } + private async Task ConfigureSessionFsAsync(CancellationToken cancellationToken) + { + if (_options.SessionFs is null) + { + return; + } + + await Rpc.SessionFs.SetProviderAsync( + _options.SessionFs.InitialCwd, + _options.SessionFs.SessionStatePath, + _options.SessionFs.Conventions, + cancellationToken); + } + + private void ConfigureSessionFsHandlers(CopilotSession session, Func? createSessionFsHandler) + { + if (_options.SessionFs is null) + { + return; + } + + if (createSessionFsHandler is null) + { + throw new InvalidOperationException( + "CreateSessionFsHandler is required in the session config when CopilotClientOptions.SessionFs is configured."); + } + + session.ClientSessionApis.SessionFs = createSessionFsHandler(session) + ?? throw new InvalidOperationException("CreateSessionFsHandler returned null."); + } + private async Task VerifyProtocolVersionAsync(Connection connection, CancellationToken cancellationToken) { var maxVersion = SdkProtocolVersion.GetVersion(); @@ -1319,6 +1353,11 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); + ClientSessionApiRegistration.RegisterClientSessionApiHandlers(rpc, sessionId => + { + var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return session.ClientSessionApis; + }); rpc.StartListening(); // Transition state to Disconnected if the JSON-RPC connection drops diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 9907641b5..86d3daf2e 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -1264,6 +1264,230 @@ internal class SessionShellKillRequest public SessionShellKillRequestSignal? Signal { get; set; } } +/// RPC data type for SessionFsReadFile operations. +public class SessionFsReadFileResult +{ + /// File content as UTF-8 string. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsReadFile operations. +public class SessionFsReadFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsWriteFile operations. +public class SessionFsWriteFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + + /// Content to write. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + + /// Optional POSIX-style mode for newly created files. + [JsonPropertyName("mode")] + public double? Mode { get; set; } +} + +/// RPC data type for SessionFsAppendFile operations. +public class SessionFsAppendFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + + /// Content to append. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + + /// Optional POSIX-style mode for newly created files. + [JsonPropertyName("mode")] + public double? Mode { get; set; } +} + +/// RPC data type for SessionFsExists operations. +public class SessionFsExistsResult +{ + /// Whether the path exists. + [JsonPropertyName("exists")] + public bool Exists { get; set; } +} + +/// RPC data type for SessionFsExists operations. +public class SessionFsExistsParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsStat operations. +public class SessionFsStatResult +{ + /// Whether the path is a file. + [JsonPropertyName("isFile")] + public bool IsFile { get; set; } + + /// Whether the path is a directory. + [JsonPropertyName("isDirectory")] + public bool IsDirectory { get; set; } + + /// File size in bytes. + [JsonPropertyName("size")] + public double Size { get; set; } + + /// ISO 8601 timestamp of last modification. + [JsonPropertyName("mtime")] + public string Mtime { get; set; } = string.Empty; + + /// ISO 8601 timestamp of creation. + [JsonPropertyName("birthtime")] + public string Birthtime { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsStat operations. +public class SessionFsStatParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsMkdir operations. +public class SessionFsMkdirParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + + /// Create parent directories as needed. + [JsonPropertyName("recursive")] + public bool? Recursive { get; set; } + + /// Optional POSIX-style mode for newly created directories. + [JsonPropertyName("mode")] + public double? Mode { get; set; } +} + +/// RPC data type for SessionFsReaddir operations. +public class SessionFsReaddirResult +{ + /// Entry names in the directory. + [JsonPropertyName("entries")] + public List Entries { get => field ??= []; set; } +} + +/// RPC data type for SessionFsReaddir operations. +public class SessionFsReaddirParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// RPC data type for Entry operations. +public class Entry +{ + /// Entry name. + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + + /// Entry type. + [JsonPropertyName("type")] + public EntryType Type { get; set; } +} + +/// RPC data type for SessionFsReaddirWithTypes operations. +public class SessionFsReaddirWithTypesResult +{ + /// Directory entries with type information. + [JsonPropertyName("entries")] + public List Entries { get => field ??= []; set; } +} + +/// RPC data type for SessionFsReaddirWithTypes operations. +public class SessionFsReaddirWithTypesParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// RPC data type for SessionFsRm operations. +public class SessionFsRmParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + + /// Remove directories and their contents recursively. + [JsonPropertyName("recursive")] + public bool? Recursive { get; set; } + + /// Ignore errors if the path does not exist. + [JsonPropertyName("force")] + public bool? Force { get; set; } +} + +/// RPC data type for SessionFsRename operations. +public class SessionFsRenameParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + + /// Source path using SessionFs conventions. + [JsonPropertyName("src")] + public string Src { get; set; } = string.Empty; + + /// Destination path using SessionFs conventions. + [JsonPropertyName("dest")] + public string Dest { get; set; } = string.Empty; +} + /// Path conventions used by this filesystem. [JsonConverter(typeof(JsonStringEnumConverter))] public enum SessionFsSetProviderRequestConventions @@ -1398,6 +1622,19 @@ public enum SessionShellKillRequestSignal } +/// Entry type. +[JsonConverter(typeof(JsonStringEnumConverter))] +public enum EntryType +{ + /// The file variant. + [JsonStringEnumMemberName("file")] + File, + /// The directory variant. + [JsonStringEnumMemberName("directory")] + Directory, +} + + /// Provides server-scoped RPC methods (no session required). public class ServerRpc { @@ -2075,6 +2312,151 @@ public async Task KillAsync(string processId, SessionShe } } +/// Handles `sessionFs` client session API methods. +public interface ISessionFsHandler +{ + /// Handles "sessionFs.readFile". + Task ReadFileAsync(SessionFsReadFileParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.writeFile". + Task WriteFileAsync(SessionFsWriteFileParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.appendFile". + Task AppendFileAsync(SessionFsAppendFileParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.exists". + Task ExistsAsync(SessionFsExistsParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.stat". + Task StatAsync(SessionFsStatParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.mkdir". + Task MkdirAsync(SessionFsMkdirParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.readdir". + Task ReaddirAsync(SessionFsReaddirParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.readdirWithTypes". + Task ReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.rm". + Task RmAsync(SessionFsRmParams request, CancellationToken cancellationToken = default); + /// Handles "sessionFs.rename". + Task RenameAsync(SessionFsRenameParams request, CancellationToken cancellationToken = default); +} + +/// Provides all client session API handler groups for a session. +public class ClientSessionApiHandlers +{ + /// Optional handler for SessionFs client session API methods. + public ISessionFsHandler? SessionFs { get; set; } +} + +/// Registers client session API handlers on a JSON-RPC connection. +public static class ClientSessionApiRegistration +{ + /// + /// Registers handlers for server-to-client session API calls. + /// Each incoming call includes a sessionId in its params object, + /// which is used to resolve the session's handler group. + /// + public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func getHandlers) + { + var registerSessionFsReadFileMethod = (Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + return await handler.ReadFileAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsReadFileMethod.Method, registerSessionFsReadFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readFile") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsWriteFileMethod = (Func)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + await handler.WriteFileAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsWriteFileMethod.Method, registerSessionFsWriteFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.writeFile") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsAppendFileMethod = (Func)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + await handler.AppendFileAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsAppendFileMethod.Method, registerSessionFsAppendFileMethod.Target!, new JsonRpcMethodAttribute("sessionFs.appendFile") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsExistsMethod = (Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + return await handler.ExistsAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsExistsMethod.Method, registerSessionFsExistsMethod.Target!, new JsonRpcMethodAttribute("sessionFs.exists") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsStatMethod = (Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + return await handler.StatAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsStatMethod.Method, registerSessionFsStatMethod.Target!, new JsonRpcMethodAttribute("sessionFs.stat") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsMkdirMethod = (Func)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + await handler.MkdirAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsMkdirMethod.Method, registerSessionFsMkdirMethod.Target!, new JsonRpcMethodAttribute("sessionFs.mkdir") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsReaddirMethod = (Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + return await handler.ReaddirAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsReaddirMethod.Method, registerSessionFsReaddirMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readdir") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsReaddirWithTypesMethod = (Func>)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + return await handler.ReaddirWithTypesAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsReaddirWithTypesMethod.Method, registerSessionFsReaddirWithTypesMethod.Target!, new JsonRpcMethodAttribute("sessionFs.readdirWithTypes") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsRmMethod = (Func)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + await handler.RmAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsRmMethod.Method, registerSessionFsRmMethod.Target!, new JsonRpcMethodAttribute("sessionFs.rm") + { + UseSingleObjectParameterDeserialization = true + }); + var registerSessionFsRenameMethod = (Func)(async (request, cancellationToken) => + { + var handler = getHandlers(request.SessionId).SessionFs; + if (handler is null) throw new InvalidOperationException($"No sessionFs handler registered for session: {request.SessionId}"); + await handler.RenameAsync(request, cancellationToken); + }); + rpc.AddLocalRpcMethod(registerSessionFsRenameMethod.Method, registerSessionFsRenameMethod.Target!, new JsonRpcMethodAttribute("sessionFs.rename") + { + UseSingleObjectParameterDeserialization = true + }); + } +} + [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, @@ -2082,6 +2464,7 @@ public async Task KillAsync(string processId, SessionShe [JsonSerializable(typeof(AccountGetQuotaResult))] [JsonSerializable(typeof(AccountGetQuotaResultQuotaSnapshotsValue))] [JsonSerializable(typeof(Agent))] +[JsonSerializable(typeof(Entry))] [JsonSerializable(typeof(Extension))] [JsonSerializable(typeof(Model))] [JsonSerializable(typeof(ModelBilling))] @@ -2125,8 +2508,23 @@ public async Task KillAsync(string processId, SessionShe [JsonSerializable(typeof(SessionExtensionsReloadResult))] [JsonSerializable(typeof(SessionFleetStartRequest))] [JsonSerializable(typeof(SessionFleetStartResult))] +[JsonSerializable(typeof(SessionFsAppendFileParams))] +[JsonSerializable(typeof(SessionFsExistsParams))] +[JsonSerializable(typeof(SessionFsExistsResult))] +[JsonSerializable(typeof(SessionFsMkdirParams))] +[JsonSerializable(typeof(SessionFsReadFileParams))] +[JsonSerializable(typeof(SessionFsReadFileResult))] +[JsonSerializable(typeof(SessionFsReaddirParams))] +[JsonSerializable(typeof(SessionFsReaddirResult))] +[JsonSerializable(typeof(SessionFsReaddirWithTypesParams))] +[JsonSerializable(typeof(SessionFsReaddirWithTypesResult))] +[JsonSerializable(typeof(SessionFsRenameParams))] +[JsonSerializable(typeof(SessionFsRmParams))] [JsonSerializable(typeof(SessionFsSetProviderRequest))] [JsonSerializable(typeof(SessionFsSetProviderResult))] +[JsonSerializable(typeof(SessionFsStatParams))] +[JsonSerializable(typeof(SessionFsStatResult))] +[JsonSerializable(typeof(SessionFsWriteFileParams))] [JsonSerializable(typeof(SessionLogRequest))] [JsonSerializable(typeof(SessionLogResult))] [JsonSerializable(typeof(SessionMcpDisableRequest))] diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 6d0a78d4c..4e5142cb8 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -124,6 +124,8 @@ public sealed partial class CopilotSession : IAsyncDisposable /// public ISessionUiApi Ui { get; } + internal ClientSessionApiHandlers ClientSessionApis { get; } = new(); + /// /// Initializes a new instance of the class. /// diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 265781bac..2f81f3b4c 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -68,6 +68,7 @@ protected CopilotClientOptions(CopilotClientOptions? other) UseLoggedInUser = other.UseLoggedInUser; UseStdio = other.UseStdio; OnListModels = other.OnListModels; + SessionFs = other.SessionFs; } /// @@ -150,6 +151,14 @@ public string? GithubToken /// public Func>>? OnListModels { get; set; } + /// + /// Custom session filesystem provider configuration. + /// When set, the client registers as the session filesystem provider on connect, + /// routing session-scoped file I/O through per-session handlers created via + /// or . + /// + public SessionFsConfig? SessionFs { get; set; } + /// /// OpenTelemetry configuration for the CLI server. /// When set to a non- instance, the CLI server is started with OpenTelemetry instrumentation enabled. @@ -217,6 +226,28 @@ public sealed class TelemetryConfig public bool? CaptureContent { get; set; } } +/// +/// Configuration for a custom session filesystem provider. +/// +public sealed class SessionFsConfig +{ + /// + /// Initial working directory for sessions (user's project directory). + /// + public required string InitialCwd { get; init; } + + /// + /// Path within each session's SessionFs where the runtime stores + /// session-scoped files (events, workspace, checkpoints, and temp files). + /// + public required string SessionStatePath { get; init; } + + /// + /// Path conventions used by this filesystem provider. + /// + public required SessionFsSetProviderRequestConventions Conventions { get; init; } +} + /// /// Represents a binary result returned by a tool invocation. /// @@ -1586,6 +1617,7 @@ protected SessionConfig(SessionConfig? other) OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; ReasoningEffort = other.ReasoningEffort; + CreateSessionFsHandler = other.CreateSessionFsHandler; SessionId = other.SessionId; SkillDirectories = other.SkillDirectories is not null ? [.. other.SkillDirectories] : null; Streaming = other.Streaming; @@ -1737,6 +1769,12 @@ protected SessionConfig(SessionConfig? other) /// public SessionEventHandler? OnEvent { get; set; } + /// + /// Supplies a handler for session filesystem operations. + /// This is used only when is configured. + /// + public Func? CreateSessionFsHandler { get; set; } + /// /// Creates a shallow clone of this instance. /// @@ -1793,6 +1831,7 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) OnUserInputRequest = other.OnUserInputRequest; Provider = other.Provider; ReasoningEffort = other.ReasoningEffort; + CreateSessionFsHandler = other.CreateSessionFsHandler; SkillDirectories = other.SkillDirectories is not null ? [.. other.SkillDirectories] : null; Streaming = other.Streaming; SystemMessage = other.SystemMessage; @@ -1941,6 +1980,12 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) /// public SessionEventHandler? OnEvent { get; set; } + /// + /// Supplies a handler for session filesystem operations. + /// This is used only when is configured. + /// + public Func? CreateSessionFsHandler { get; set; } + /// /// Creates a shallow clone of this instance. /// diff --git a/dotnet/test/Harness/E2ETestContext.cs b/dotnet/test/Harness/E2ETestContext.cs index 0da0fdad5..47c8b2c4d 100644 --- a/dotnet/test/Harness/E2ETestContext.cs +++ b/dotnet/test/Harness/E2ETestContext.cs @@ -92,16 +92,27 @@ public IReadOnlyDictionary GetEnvironment() return env!; } - public CopilotClient CreateClient(bool useStdio = true) + public CopilotClient CreateClient(bool useStdio = true, CopilotClientOptions? options = null) { - return new(new CopilotClientOptions + options ??= new CopilotClientOptions(); + + options.Cwd ??= WorkDir; + options.Environment ??= GetEnvironment(); + options.UseStdio = useStdio; + + if (string.IsNullOrEmpty(options.CliUrl)) { - Cwd = WorkDir, - CliPath = GetCliPath(_repoRoot), - Environment = GetEnvironment(), - UseStdio = useStdio, - GitHubToken = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("GITHUB_ACTIONS")) ? "fake-token-for-e2e-tests" : null, - }); + options.CliPath ??= GetCliPath(_repoRoot); + } + + if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("GITHUB_ACTIONS")) + && string.IsNullOrEmpty(options.GitHubToken) + && string.IsNullOrEmpty(options.CliUrl)) + { + options.GitHubToken = "fake-token-for-e2e-tests"; + } + + return new(options); } public async ValueTask DisposeAsync() diff --git a/dotnet/test/SessionFsTests.cs b/dotnet/test/SessionFsTests.cs new file mode 100644 index 000000000..b985e15af --- /dev/null +++ b/dotnet/test/SessionFsTests.cs @@ -0,0 +1,526 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Rpc; +using GitHub.Copilot.SDK.Test.Harness; +using Microsoft.Extensions.AI; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test; + +public class SessionFsTests(E2ETestFixture fixture, ITestOutputHelper output) + : E2ETestBase(fixture, "session_fs", output) +{ + private static readonly SessionFsConfig SessionFsConfig = new() + { + InitialCwd = "/", + SessionStatePath = "/session-state", + Conventions = SessionFsSetProviderRequestConventions.Posix, + }; + + [Fact] + public async Task Should_Route_File_Operations_Through_The_Session_Fs_Provider() + { + var providerRoot = CreateProviderRoot(); + try + { + await using var client = CreateSessionFsClient(providerRoot); + + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new TestSessionFsHandler(s.SessionId, providerRoot), + }); + + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 100 + 200?" }); + Assert.Contains("300", msg?.Data.Content ?? string.Empty); + await session.DisposeAsync(); + + var eventsPath = GetStoredPath(providerRoot, session.SessionId, "/session-state/events.jsonl"); + await WaitForConditionAsync(() => File.Exists(eventsPath)); + var content = await ReadAllTextSharedAsync(eventsPath); + Assert.Contains("300", content); + } + finally + { + await TryDeleteDirectoryAsync(providerRoot); + } + } + + [Fact] + public async Task Should_Load_Session_Data_From_Fs_Provider_On_Resume() + { + var providerRoot = CreateProviderRoot(); + try + { + await using var client = CreateSessionFsClient(providerRoot); + Func createSessionFsHandler = s => new TestSessionFsHandler(s.SessionId, providerRoot); + + var session1 = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = createSessionFsHandler, + }); + var sessionId = session1.SessionId; + + var msg = await session1.SendAndWaitAsync(new MessageOptions { Prompt = "What is 50 + 50?" }); + Assert.Contains("100", msg?.Data.Content ?? string.Empty); + await session1.DisposeAsync(); + + var eventsPath = GetStoredPath(providerRoot, sessionId, "/session-state/events.jsonl"); + await WaitForConditionAsync(() => File.Exists(eventsPath)); + + var session2 = await client.ResumeSessionAsync(sessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = createSessionFsHandler, + }); + + var msg2 = await session2.SendAndWaitAsync(new MessageOptions { Prompt = "What is that times 3?" }); + Assert.Contains("300", msg2?.Data.Content ?? string.Empty); + await session2.DisposeAsync(); + } + finally + { + await TryDeleteDirectoryAsync(providerRoot); + } + } + + [Fact] + public async Task Should_Reject_SetProvider_When_Sessions_Already_Exist() + { + var providerRoot = CreateProviderRoot(); + try + { + await using var client1 = CreateSessionFsClient(providerRoot, useStdio: false); + var createSessionFsHandler = (Func)(s => new TestSessionFsHandler(s.SessionId, providerRoot)); + + _ = await client1.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = createSessionFsHandler, + }); + + var port = client1.ActualPort + ?? throw new InvalidOperationException("Client1 is not using TCP mode; ActualPort is null"); + + var client2 = Ctx.CreateClient( + useStdio: false, + options: new CopilotClientOptions + { + CliUrl = $"localhost:{port}", + LogLevel = "error", + SessionFs = SessionFsConfig, + }); + + try + { + await Assert.ThrowsAnyAsync(() => client2.StartAsync()); + } + finally + { + try + { + await client2.ForceStopAsync(); + } + catch (IOException ex) + { + Console.Error.WriteLine($"Ignoring expected teardown IOException from ForceStopAsync: {ex.Message}"); + } + } + } + finally + { + await TryDeleteDirectoryAsync(providerRoot); + } + } + + [Fact] + public async Task Should_Map_Large_Output_Handling_Into_SessionFs() + { + var providerRoot = CreateProviderRoot(); + try + { + const int largeContentSize = 100_000; + var suppliedFileContent = new string('x', largeContentSize); + + await using var client = CreateSessionFsClient(providerRoot); + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new TestSessionFsHandler(s.SessionId, providerRoot), + Tools = + [ + AIFunctionFactory.Create(() => suppliedFileContent, "get_big_string", "Returns a large string") + ], + }); + + await session.SendAndWaitAsync(new MessageOptions + { + Prompt = "Call the get_big_string tool and reply with the word DONE only.", + }); + + var messages = await session.GetMessagesAsync(); + var toolResult = FindToolCallResult(messages, "get_big_string"); + Assert.NotNull(toolResult); + Assert.Contains("/session-state/temp/", toolResult); + + var match = System.Text.RegularExpressions.Regex.Match( + toolResult!, + @"([/\\]session-state[/\\]temp[/\\][^\s]+)"); + Assert.True(match.Success); + + var fileContent = await ReadAllTextSharedAsync(GetStoredPath(providerRoot, session.SessionId, match.Groups[1].Value)); + Assert.Equal(suppliedFileContent, fileContent); + await session.DisposeAsync(); + } + finally + { + await TryDeleteDirectoryAsync(providerRoot); + } + } + + [Fact] + public async Task Should_Succeed_With_Compaction_While_Using_SessionFs() + { + var providerRoot = CreateProviderRoot(); + try + { + await using var client = CreateSessionFsClient(providerRoot); + var session = await client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new TestSessionFsHandler(s.SessionId, providerRoot), + }); + + SessionCompactionCompleteEvent? compactionEvent = null; + using var _ = session.On(evt => + { + if (evt is SessionCompactionCompleteEvent complete) + { + compactionEvent = complete; + } + }); + + await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 2+2?" }); + + var eventsPath = GetStoredPath(providerRoot, session.SessionId, "/session-state/events.jsonl"); + await WaitForConditionAsync(() => File.Exists(eventsPath), TimeSpan.FromSeconds(30)); + var contentBefore = await ReadAllTextSharedAsync(eventsPath); + Assert.DoesNotContain("checkpointNumber", contentBefore); + + await session.Rpc.Compaction.CompactAsync(); + await WaitForConditionAsync(() => compactionEvent is not null, TimeSpan.FromSeconds(30)); + Assert.True(compactionEvent!.Data.Success); + + await WaitForConditionAsync(async () => + { + var content = await ReadAllTextSharedAsync(eventsPath); + return content.Contains("checkpointNumber", StringComparison.Ordinal); + }, TimeSpan.FromSeconds(30)); + } + finally + { + await TryDeleteDirectoryAsync(providerRoot); + } + } + + private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = true) + { + Directory.CreateDirectory(providerRoot); + return Ctx.CreateClient( + useStdio: useStdio, + options: new CopilotClientOptions + { + SessionFs = SessionFsConfig, + }); + } + + private static string? FindToolCallResult(IReadOnlyList messages, string toolName) + { + var callId = messages + .OfType() + .FirstOrDefault(m => string.Equals(m.Data.ToolName, toolName, StringComparison.Ordinal)) + ?.Data.ToolCallId; + + if (callId is null) + { + return null; + } + + return messages + .OfType() + .FirstOrDefault(m => string.Equals(m.Data.ToolCallId, callId, StringComparison.Ordinal)) + ?.Data.Result?.Content; + } + + private static string CreateProviderRoot() + => Path.Join(Path.GetTempPath(), $"copilot-sessionfs-{Guid.NewGuid():N}"); + + private static string GetStoredPath(string providerRoot, string sessionId, string sessionPath) + { + var safeSessionId = NormalizeRelativePathSegment(sessionId, nameof(sessionId)); + var relativeSegments = sessionPath + .TrimStart('/', '\\') + .Split(['/', '\\'], StringSplitOptions.RemoveEmptyEntries) + .Select(segment => NormalizeRelativePathSegment(segment, nameof(sessionPath))) + .ToArray(); + + return Path.Join([providerRoot, safeSessionId, .. relativeSegments]); + } + + private static async Task WaitForConditionAsync(Func condition, TimeSpan? timeout = null) + { + await WaitForConditionAsync(() => Task.FromResult(condition()), timeout); + } + + private static async Task WaitForConditionAsync(Func> condition, TimeSpan? timeout = null) + { + var deadline = DateTime.UtcNow + (timeout ?? TimeSpan.FromSeconds(30)); + Exception? lastException = null; + while (DateTime.UtcNow < deadline) + { + try + { + if (await condition()) + { + return; + } + } + catch (IOException ex) + { + lastException = ex; + } + catch (UnauthorizedAccessException ex) + { + lastException = ex; + } + + await Task.Delay(100); + } + + throw new TimeoutException("Timed out waiting for condition.", lastException); + } + + private static async Task ReadAllTextSharedAsync(string path, CancellationToken cancellationToken = default) + { + await using var stream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.ReadWrite | FileShare.Delete); + using var reader = new StreamReader(stream); + return await reader.ReadToEndAsync(cancellationToken); + } + + private static async Task TryDeleteDirectoryAsync(string path) + { + if (!Directory.Exists(path)) + { + return; + } + + var deadline = DateTime.UtcNow + TimeSpan.FromSeconds(5); + Exception? lastException = null; + + while (DateTime.UtcNow < deadline) + { + try + { + if (!Directory.Exists(path)) + { + return; + } + + Directory.Delete(path, recursive: true); + return; + } + catch (IOException ex) + { + lastException = ex; + } + catch (UnauthorizedAccessException ex) + { + lastException = ex; + } + + await Task.Delay(100); + } + + if (lastException is not null) + { + throw lastException; + } + } + + private static string NormalizeRelativePathSegment(string segment, string paramName) + { + if (string.IsNullOrWhiteSpace(segment)) + { + throw new InvalidOperationException($"{paramName} must not be empty."); + } + + var normalized = segment.TrimStart(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); + if (Path.IsPathRooted(normalized) || normalized.Contains(Path.VolumeSeparatorChar)) + { + throw new InvalidOperationException($"{paramName} must be a relative path segment: {segment}"); + } + + return normalized; + } + + private sealed class TestSessionFsHandler(string sessionId, string rootDir) : ISessionFsHandler + { + public async Task ReadFileAsync(SessionFsReadFileParams request, CancellationToken cancellationToken = default) + { + var content = await File.ReadAllTextAsync(ResolvePath(request.Path), cancellationToken); + return new SessionFsReadFileResult { Content = content }; + } + + public async Task WriteFileAsync(SessionFsWriteFileParams request, CancellationToken cancellationToken = default) + { + var fullPath = ResolvePath(request.Path); + Directory.CreateDirectory(Path.GetDirectoryName(fullPath)!); + await File.WriteAllTextAsync(fullPath, request.Content, cancellationToken); + } + + public async Task AppendFileAsync(SessionFsAppendFileParams request, CancellationToken cancellationToken = default) + { + var fullPath = ResolvePath(request.Path); + Directory.CreateDirectory(Path.GetDirectoryName(fullPath)!); + await File.AppendAllTextAsync(fullPath, request.Content, cancellationToken); + } + + public Task ExistsAsync(SessionFsExistsParams request, CancellationToken cancellationToken = default) + { + var fullPath = ResolvePath(request.Path); + return Task.FromResult(new SessionFsExistsResult + { + Exists = File.Exists(fullPath) || Directory.Exists(fullPath), + }); + } + + public Task StatAsync(SessionFsStatParams request, CancellationToken cancellationToken = default) + { + var fullPath = ResolvePath(request.Path); + if (File.Exists(fullPath)) + { + var info = new FileInfo(fullPath); + return Task.FromResult(new SessionFsStatResult + { + IsFile = true, + IsDirectory = false, + Size = info.Length, + Mtime = info.LastWriteTimeUtc.ToString("O"), + Birthtime = info.CreationTimeUtc.ToString("O"), + }); + } + + var dirInfo = new DirectoryInfo(fullPath); + if (!dirInfo.Exists) + { + throw new FileNotFoundException($"Path does not exist: {request.Path}"); + } + + return Task.FromResult(new SessionFsStatResult + { + IsFile = false, + IsDirectory = true, + Size = 0, + Mtime = dirInfo.LastWriteTimeUtc.ToString("O"), + Birthtime = dirInfo.CreationTimeUtc.ToString("O"), + }); + } + + public Task MkdirAsync(SessionFsMkdirParams request, CancellationToken cancellationToken = default) + { + Directory.CreateDirectory(ResolvePath(request.Path)); + return Task.CompletedTask; + } + + public Task ReaddirAsync(SessionFsReaddirParams request, CancellationToken cancellationToken = default) + { + var entries = Directory + .EnumerateFileSystemEntries(ResolvePath(request.Path)) + .Select(Path.GetFileName) + .Where(name => name is not null) + .Cast() + .ToList(); + + return Task.FromResult(new SessionFsReaddirResult { Entries = entries }); + } + + public Task ReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request, CancellationToken cancellationToken = default) + { + var entries = Directory + .EnumerateFileSystemEntries(ResolvePath(request.Path)) + .Select(path => new Entry + { + Name = Path.GetFileName(path), + Type = Directory.Exists(path) ? EntryType.Directory : EntryType.File, + }) + .ToList(); + + return Task.FromResult(new SessionFsReaddirWithTypesResult { Entries = entries }); + } + + public Task RmAsync(SessionFsRmParams request, CancellationToken cancellationToken = default) + { + var fullPath = ResolvePath(request.Path); + + if (File.Exists(fullPath)) + { + File.Delete(fullPath); + return Task.CompletedTask; + } + + if (Directory.Exists(fullPath)) + { + Directory.Delete(fullPath, request.Recursive ?? false); + return Task.CompletedTask; + } + + if (request.Force == true) + { + return Task.CompletedTask; + } + + throw new FileNotFoundException($"Path does not exist: {request.Path}"); + } + + public Task RenameAsync(SessionFsRenameParams request, CancellationToken cancellationToken = default) + { + var src = ResolvePath(request.Src); + var dest = ResolvePath(request.Dest); + Directory.CreateDirectory(Path.GetDirectoryName(dest)!); + + if (Directory.Exists(src)) + { + Directory.Move(src, dest); + } + else + { + File.Move(src, dest, overwrite: true); + } + + return Task.CompletedTask; + } + + private string ResolvePath(string sessionPath) + { + var normalizedSessionId = NormalizeRelativePathSegment(sessionId, nameof(sessionId)); + var sessionRoot = Path.GetFullPath(Path.Join(rootDir, normalizedSessionId)); + var relativeSegments = sessionPath + .TrimStart('/', '\\') + .Split(['/', '\\'], StringSplitOptions.RemoveEmptyEntries) + .Select(segment => NormalizeRelativePathSegment(segment, nameof(sessionPath))) + .ToArray(); + + var fullPath = Path.GetFullPath(Path.Join([sessionRoot, .. relativeSegments])); + if (!fullPath.StartsWith(sessionRoot, StringComparison.Ordinal)) + { + throw new InvalidOperationException($"Path escapes session root: {sessionPath}"); + } + + return fullPath; + } + } +} diff --git a/go/client.go b/go/client.go index 731efbe24..188fae920 100644 --- a/go/client.go +++ b/go/client.go @@ -53,6 +53,22 @@ import ( const noResultPermissionV2Error = "permission handlers cannot return 'no-result' when connected to a protocol v2 server" +func validateSessionFsConfig(config *SessionFsConfig) error { + if config == nil { + return nil + } + if config.InitialCwd == "" { + return errors.New("SessionFs.InitialCwd is required") + } + if config.SessionStatePath == "" { + return errors.New("SessionFs.SessionStatePath is required") + } + if config.Conventions != rpc.ConventionsPosix && config.Conventions != rpc.ConventionsWindows { + return errors.New("SessionFs.Conventions must be either 'posix' or 'windows'") + } + return nil +} + // Client manages the connection to the Copilot CLI server and provides session management. // // The Client can either spawn a CLI server process or connect to an existing server. @@ -192,6 +208,13 @@ func NewClient(options *ClientOptions) *Client { if options.OnListModels != nil { client.onListModels = options.OnListModels } + if options.SessionFs != nil { + if err := validateSessionFsConfig(options.SessionFs); err != nil { + panic(err.Error()) + } + sessionFs := *options.SessionFs + opts.SessionFs = &sessionFs + } } // Default Env to current environment if not set @@ -305,6 +328,20 @@ func (c *Client) Start(ctx context.Context) error { return errors.Join(err, killErr) } + // If a session filesystem provider was configured, register it. + if c.options.SessionFs != nil { + _, err := c.RPC.SessionFs.SetProvider(ctx, &rpc.SessionFSSetProviderParams{ + InitialCwd: c.options.SessionFs.InitialCwd, + SessionStatePath: c.options.SessionFs.SessionStatePath, + Conventions: c.options.SessionFs.Conventions, + }) + if err != nil { + killErr := c.killProcess() + c.state = StateError + return errors.Join(err, killErr) + } + } + c.state = StateConnected return nil } @@ -623,6 +660,16 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses c.sessions[sessionID] = session c.sessionsMux.Unlock() + if c.options.SessionFs != nil { + if config.CreateSessionFsHandler == nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") + } + session.clientSessionApis.SessionFs = config.CreateSessionFsHandler(session) + } + result, err := c.client.Request("session.create", req) if err != nil { c.sessionsMux.Lock() @@ -763,6 +810,16 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, c.sessions[sessionID] = session c.sessionsMux.Unlock() + if c.options.SessionFs != nil { + if config.CreateSessionFsHandler == nil { + c.sessionsMux.Lock() + delete(c.sessions, sessionID) + c.sessionsMux.Unlock() + return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") + } + session.clientSessionApis.SessionFs = config.CreateSessionFsHandler(session) + } + result, err := c.client.Request("session.resume", req) if err != nil { c.sessionsMux.Lock() @@ -1526,6 +1583,15 @@ func (c *Client) setupNotificationHandler() { c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) c.client.SetRequestHandler("systemMessage.transform", jsonrpc2.RequestHandlerFor(c.handleSystemMessageTransform)) + rpc.RegisterClientSessionApiHandlers(c.client, func(sessionID string) *rpc.ClientSessionApiHandlers { + c.sessionsMux.Lock() + defer c.sessionsMux.Unlock() + session := c.sessions[sessionID] + if session == nil { + return nil + } + return session.clientSessionApis + }) } func (c *Client) handleSessionEvent(req sessionEventRequest) { diff --git a/go/client_test.go b/go/client_test.go index 8f302f338..1b88eda20 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -9,6 +9,8 @@ import ( "regexp" "sync" "testing" + + "github.com/github/copilot-sdk/go/rpc" ) // This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.test.go instead @@ -223,6 +225,48 @@ func TestClient_URLParsing(t *testing.T) { }) } +func TestClient_SessionFsConfig(t *testing.T) { + t.Run("should throw error when InitialCwd is missing", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for missing SessionFs.InitialCwd") + } else { + matched, _ := regexp.MatchString("SessionFs.InitialCwd is required", r.(string)) + if !matched { + t.Errorf("Expected panic message to contain 'SessionFs.InitialCwd is required', got: %v", r) + } + } + }() + + NewClient(&ClientOptions{ + SessionFs: &SessionFsConfig{ + SessionStatePath: "/session-state", + Conventions: rpc.ConventionsPosix, + }, + }) + }) + + t.Run("should throw error when SessionStatePath is missing", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for missing SessionFs.SessionStatePath") + } else { + matched, _ := regexp.MatchString("SessionFs.SessionStatePath is required", r.(string)) + if !matched { + t.Errorf("Expected panic message to contain 'SessionFs.SessionStatePath is required', got: %v", r) + } + } + }() + + NewClient(&ClientOptions{ + SessionFs: &SessionFsConfig{ + InitialCwd: "/", + Conventions: rpc.ConventionsPosix, + }, + }) + }) +} + func TestClient_AuthOptions(t *testing.T) { t.Run("should accept GitHubToken option", func(t *testing.T) { client := NewClient(&ClientOptions{ diff --git a/go/internal/e2e/session_fs_test.go b/go/internal/e2e/session_fs_test.go new file mode 100644 index 000000000..0f51791db --- /dev/null +++ b/go/internal/e2e/session_fs_test.go @@ -0,0 +1,443 @@ +package e2e + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" + "github.com/github/copilot-sdk/go/rpc" +) + +func TestSessionFs(t *testing.T) { + ctx := testharness.NewTestContext(t) + providerRoot := t.TempDir() + createSessionFsHandler := func(session *copilot.Session) rpc.SessionFsHandler { + return &testSessionFsHandler{ + root: providerRoot, + sessionID: session.SessionID, + } + } + p := func(sessionID string, path string) string { + return providerPath(providerRoot, sessionID, path) + } + + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.SessionFs = sessionFsConfig + }) + t.Cleanup(func() { client.ForceStop() }) + + t.Run("should route file operations through the session fs provider", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + msg, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 100 + 200?"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + content := "" + if msg != nil && msg.Data.Content != nil { + content = *msg.Data.Content + } + if !strings.Contains(content, "300") { + t.Fatalf("Expected response to contain 300, got %q", content) + } + if err := session.Disconnect(); err != nil { + t.Fatalf("Failed to disconnect session: %v", err) + } + + events, err := os.ReadFile(p(session.SessionID, "/session-state/events.jsonl")) + if err != nil { + t.Fatalf("Failed to read events file: %v", err) + } + if !strings.Contains(string(events), "300") { + t.Fatalf("Expected events file to contain 300") + } + }) + + t.Run("should load session data from fs provider on resume", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session1, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sessionID := session1.SessionID + + msg, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 50 + 50?"}) + if err != nil { + t.Fatalf("Failed to send first message: %v", err) + } + content := "" + if msg != nil && msg.Data.Content != nil { + content = *msg.Data.Content + } + if !strings.Contains(content, "100") { + t.Fatalf("Expected response to contain 100, got %q", content) + } + if err := session1.Disconnect(); err != nil { + t.Fatalf("Failed to disconnect first session: %v", err) + } + + if _, err := os.Stat(p(sessionID, "/session-state/events.jsonl")); err != nil { + t.Fatalf("Expected events file to exist before resume: %v", err) + } + + session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to resume session: %v", err) + } + + msg2, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is that times 3?"}) + if err != nil { + t.Fatalf("Failed to send second message: %v", err) + } + content2 := "" + if msg2 != nil && msg2.Data.Content != nil { + content2 = *msg2.Data.Content + } + if !strings.Contains(content2, "300") { + t.Fatalf("Expected response to contain 300, got %q", content2) + } + if err := session2.Disconnect(); err != nil { + t.Fatalf("Failed to disconnect resumed session: %v", err) + } + }) + + t.Run("should reject setProvider when sessions already exist", func(t *testing.T) { + ctx.ConfigureForTest(t) + + client1 := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + }) + t.Cleanup(func() { client1.ForceStop() }) + + if _, err := client1.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }); err != nil { + t.Fatalf("Failed to create initial session: %v", err) + } + + actualPort := client1.ActualPort() + if actualPort == 0 { + t.Fatalf("Expected non-zero port from TCP mode client") + } + + client2 := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", actualPort), + LogLevel: "error", + Env: ctx.Env(), + SessionFs: sessionFsConfig, + }) + t.Cleanup(func() { client2.ForceStop() }) + + if err := client2.Start(t.Context()); err == nil { + t.Fatal("Expected Start to fail when sessionFs provider is set after sessions already exist") + } + }) + + t.Run("should map large output handling into sessionFs", func(t *testing.T) { + ctx.ConfigureForTest(t) + + suppliedFileContent := strings.Repeat("x", 100_000) + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + Tools: []copilot.Tool{ + copilot.DefineTool("get_big_string", "Returns a large string", + func(_ struct{}, inv copilot.ToolInvocation) (string, error) { + return suppliedFileContent, nil + }), + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if _, err := session.SendAndWait(t.Context(), copilot.MessageOptions{ + Prompt: "Call the get_big_string tool and reply with the word DONE only.", + }); err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + messages, err := session.GetMessages(t.Context()) + if err != nil { + t.Fatalf("Failed to get messages: %v", err) + } + toolResult := findToolCallResult(messages, "get_big_string") + if !strings.Contains(toolResult, "/session-state/temp/") { + t.Fatalf("Expected tool result to reference /session-state/temp/, got %q", toolResult) + } + match := regexp.MustCompile(`(/session-state/temp/[^\s]+)`).FindStringSubmatch(toolResult) + if len(match) < 2 { + t.Fatalf("Expected temp file path in tool result, got %q", toolResult) + } + + fileContent, err := os.ReadFile(p(session.SessionID, match[1])) + if err != nil { + t.Fatalf("Failed to read temp file: %v", err) + } + if string(fileContent) != suppliedFileContent { + t.Fatalf("Expected temp file content to match supplied content") + } + }) + + t.Run("should succeed with compaction while using sessionFs", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: createSessionFsHandler, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if _, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 2+2?"}); err != nil { + t.Fatalf("Failed to send message: %v", err) + } + + eventsPath := p(session.SessionID, "/session-state/events.jsonl") + if err := waitForFile(eventsPath, 5*time.Second); err != nil { + t.Fatalf("Timed out waiting for events file: %v", err) + } + contentBefore, err := os.ReadFile(eventsPath) + if err != nil { + t.Fatalf("Failed to read events file before compaction: %v", err) + } + if strings.Contains(string(contentBefore), "checkpointNumber") { + t.Fatalf("Expected events file to not contain checkpointNumber before compaction") + } + + compactionResult, err := session.RPC.Compaction.Compact(t.Context()) + if err != nil { + t.Fatalf("Failed to compact session: %v", err) + } + if compactionResult == nil || !compactionResult.Success { + t.Fatalf("Expected compaction to succeed, got %+v", compactionResult) + } + + if err := waitForFileContent(eventsPath, "checkpointNumber", 5*time.Second); err != nil { + t.Fatalf("Timed out waiting for checkpoint rewrite: %v", err) + } + }) +} + +var sessionFsConfig = &copilot.SessionFsConfig{ + InitialCwd: "/", + SessionStatePath: "/session-state", + Conventions: rpc.ConventionsPosix, +} + +type testSessionFsHandler struct { + root string + sessionID string +} + +func (h *testSessionFsHandler) ReadFile(request *rpc.SessionFSReadFileParams) (*rpc.SessionFSReadFileResult, error) { + content, err := os.ReadFile(providerPath(h.root, h.sessionID, request.Path)) + if err != nil { + return nil, err + } + return &rpc.SessionFSReadFileResult{Content: string(content)}, nil +} + +func (h *testSessionFsHandler) WriteFile(request *rpc.SessionFSWriteFileParams) error { + path := providerPath(h.root, h.sessionID, request.Path) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + mode := os.FileMode(0o666) + if request.Mode != nil { + mode = os.FileMode(uint32(*request.Mode)) + } + return os.WriteFile(path, []byte(request.Content), mode) +} + +func (h *testSessionFsHandler) AppendFile(request *rpc.SessionFSAppendFileParams) error { + path := providerPath(h.root, h.sessionID, request.Path) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + mode := os.FileMode(0o666) + if request.Mode != nil { + mode = os.FileMode(uint32(*request.Mode)) + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, mode) + if err != nil { + return err + } + defer f.Close() + _, err = f.WriteString(request.Content) + return err +} + +func (h *testSessionFsHandler) Exists(request *rpc.SessionFSExistsParams) (*rpc.SessionFSExistsResult, error) { + _, err := os.Stat(providerPath(h.root, h.sessionID, request.Path)) + if err == nil { + return &rpc.SessionFSExistsResult{Exists: true}, nil + } + if os.IsNotExist(err) { + return &rpc.SessionFSExistsResult{Exists: false}, nil + } + return nil, err +} + +func (h *testSessionFsHandler) Stat(request *rpc.SessionFSStatParams) (*rpc.SessionFSStatResult, error) { + info, err := os.Stat(providerPath(h.root, h.sessionID, request.Path)) + if err != nil { + return nil, err + } + ts := info.ModTime().UTC().Format(time.RFC3339) + return &rpc.SessionFSStatResult{ + IsFile: !info.IsDir(), + IsDirectory: info.IsDir(), + Size: float64(info.Size()), + Mtime: ts, + Birthtime: ts, + }, nil +} + +func (h *testSessionFsHandler) Mkdir(request *rpc.SessionFSMkdirParams) error { + path := providerPath(h.root, h.sessionID, request.Path) + mode := os.FileMode(0o777) + if request.Mode != nil { + mode = os.FileMode(uint32(*request.Mode)) + } + if request.Recursive != nil && *request.Recursive { + return os.MkdirAll(path, mode) + } + return os.Mkdir(path, mode) +} + +func (h *testSessionFsHandler) Readdir(request *rpc.SessionFSReaddirParams) (*rpc.SessionFSReaddirResult, error) { + entries, err := os.ReadDir(providerPath(h.root, h.sessionID, request.Path)) + if err != nil { + return nil, err + } + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name()) + } + return &rpc.SessionFSReaddirResult{Entries: names}, nil +} + +func (h *testSessionFsHandler) ReaddirWithTypes(request *rpc.SessionFSReaddirWithTypesParams) (*rpc.SessionFSReaddirWithTypesResult, error) { + entries, err := os.ReadDir(providerPath(h.root, h.sessionID, request.Path)) + if err != nil { + return nil, err + } + result := make([]rpc.Entry, 0, len(entries)) + for _, entry := range entries { + entryType := rpc.EntryTypeFile + if entry.IsDir() { + entryType = rpc.EntryTypeDirectory + } + result = append(result, rpc.Entry{ + Name: entry.Name(), + Type: entryType, + }) + } + return &rpc.SessionFSReaddirWithTypesResult{Entries: result}, nil +} + +func (h *testSessionFsHandler) Rm(request *rpc.SessionFSRmParams) error { + path := providerPath(h.root, h.sessionID, request.Path) + if request.Recursive != nil && *request.Recursive { + err := os.RemoveAll(path) + if err != nil && request.Force != nil && *request.Force && os.IsNotExist(err) { + return nil + } + return err + } + err := os.Remove(path) + if err != nil && request.Force != nil && *request.Force && os.IsNotExist(err) { + return nil + } + return err +} + +func (h *testSessionFsHandler) Rename(request *rpc.SessionFSRenameParams) error { + dest := providerPath(h.root, h.sessionID, request.Dest) + if err := os.MkdirAll(filepath.Dir(dest), 0o755); err != nil { + return err + } + return os.Rename( + providerPath(h.root, h.sessionID, request.Src), + dest, + ) +} + +func providerPath(root string, sessionID string, path string) string { + trimmed := strings.TrimPrefix(path, "/") + if trimmed == "" { + return filepath.Join(root, sessionID) + } + return filepath.Join(root, sessionID, filepath.FromSlash(trimmed)) +} + +func findToolCallResult(messages []copilot.SessionEvent, toolName string) string { + for _, message := range messages { + if message.Type == "tool.execution_complete" && + message.Data.Result != nil && + message.Data.Result.Content != nil && + message.Data.ToolCallID != nil && + findToolName(messages, *message.Data.ToolCallID) == toolName { + return *message.Data.Result.Content + } + } + return "" +} + +func findToolName(messages []copilot.SessionEvent, toolCallID string) string { + for _, message := range messages { + if message.Type == "tool.execution_start" && + message.Data.ToolCallID != nil && + *message.Data.ToolCallID == toolCallID && + message.Data.ToolName != nil { + return *message.Data.ToolName + } + } + return "" +} + +func waitForFile(path string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if _, err := os.Stat(path); err == nil { + return nil + } + time.Sleep(50 * time.Millisecond) + } + return fmt.Errorf("file did not appear: %s", path) +} + +func waitForFileContent(path string, needle string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + content, err := os.ReadFile(path) + if err == nil && strings.Contains(string(content), needle) { + return nil + } + time.Sleep(50 * time.Millisecond) + } + return fmt.Errorf("file %s did not contain %q", path, needle) +} diff --git a/go/internal/e2e/testharness/context.go b/go/internal/e2e/testharness/context.go index 1ec68d77e..269b53789 100644 --- a/go/internal/e2e/testharness/context.go +++ b/go/internal/e2e/testharness/context.go @@ -166,15 +166,15 @@ func (c *TestContext) NewClient(opts ...func(*copilot.ClientOptions)) *copilot.C Env: c.Env(), } - // Use fake token in CI to allow cached responses without real auth - if os.Getenv("GITHUB_ACTIONS") == "true" { - options.GitHubToken = "fake-token-for-e2e-tests" - } - for _, opt := range opts { opt(options) } + // Use fake token in CI to allow cached responses without real auth for spawned subprocess clients. + if os.Getenv("GITHUB_ACTIONS") == "true" && options.GitHubToken == "" && options.CLIUrl == "" { + options.GitHubToken = "fake-token-for-e2e-tests" + } + return copilot.NewClient(options) } diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index 6eee90963..c32510083 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -6,7 +6,8 @@ package rpc import ( "context" "encoding/json" - + "errors" + "fmt" "github.com/github/copilot-sdk/go/internal/jsonrpc2" ) @@ -749,6 +750,134 @@ type SessionShellKillParams struct { Signal *Signal `json:"signal,omitempty"` } +type SessionFSReadFileResult struct { + // File content as UTF-8 string + Content string `json:"content"` +} + +type SessionFSReadFileParams struct { + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSWriteFileParams struct { + // Content to write + Content string `json:"content"` + // Optional POSIX-style mode for newly created files + Mode *float64 `json:"mode,omitempty"` + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSAppendFileParams struct { + // Content to append + Content string `json:"content"` + // Optional POSIX-style mode for newly created files + Mode *float64 `json:"mode,omitempty"` + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSExistsResult struct { + // Whether the path exists + Exists bool `json:"exists"` +} + +type SessionFSExistsParams struct { + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSStatResult struct { + // ISO 8601 timestamp of creation + Birthtime string `json:"birthtime"` + // Whether the path is a directory + IsDirectory bool `json:"isDirectory"` + // Whether the path is a file + IsFile bool `json:"isFile"` + // ISO 8601 timestamp of last modification + Mtime string `json:"mtime"` + // File size in bytes + Size float64 `json:"size"` +} + +type SessionFSStatParams struct { + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSMkdirParams struct { + // Optional POSIX-style mode for newly created directories + Mode *float64 `json:"mode,omitempty"` + // Path using SessionFs conventions + Path string `json:"path"` + // Create parent directories as needed + Recursive *bool `json:"recursive,omitempty"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSReaddirResult struct { + // Entry names in the directory + Entries []string `json:"entries"` +} + +type SessionFSReaddirParams struct { + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSReaddirWithTypesResult struct { + // Directory entries with type information + Entries []Entry `json:"entries"` +} + +type Entry struct { + // Entry name + Name string `json:"name"` + // Entry type + Type EntryType `json:"type"` +} + +type SessionFSReaddirWithTypesParams struct { + // Path using SessionFs conventions + Path string `json:"path"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSRmParams struct { + // Ignore errors if the path does not exist + Force *bool `json:"force,omitempty"` + // Path using SessionFs conventions + Path string `json:"path"` + // Remove directories and their contents recursively + Recursive *bool `json:"recursive,omitempty"` + // Target session identifier + SessionID string `json:"sessionId"` +} + +type SessionFSRenameParams struct { + // Destination path using SessionFs conventions + Dest string `json:"dest"` + // Target session identifier + SessionID string `json:"sessionId"` + // Source path using SessionFs conventions + Src string `json:"src"` +} + type FilterMappingEnum string const ( @@ -887,6 +1016,14 @@ const ( SignalSIGTERM Signal = "SIGTERM" ) +// Entry type +type EntryType string + +const ( + EntryTypeDirectory EntryType = "directory" + EntryTypeFile EntryType = "file" +) + type FilterMappingUnion struct { Enum *FilterMappingEnum EnumMap map[string]FilterMappingEnum @@ -1683,3 +1820,201 @@ func NewSessionRpc(client *jsonrpc2.Client, sessionID string) *SessionRpc { r.Shell = (*ShellApi)(&r.common) return r } + +type SessionFsHandler interface { + ReadFile(request *SessionFSReadFileParams) (*SessionFSReadFileResult, error) + WriteFile(request *SessionFSWriteFileParams) error + AppendFile(request *SessionFSAppendFileParams) error + Exists(request *SessionFSExistsParams) (*SessionFSExistsResult, error) + Stat(request *SessionFSStatParams) (*SessionFSStatResult, error) + Mkdir(request *SessionFSMkdirParams) error + Readdir(request *SessionFSReaddirParams) (*SessionFSReaddirResult, error) + ReaddirWithTypes(request *SessionFSReaddirWithTypesParams) (*SessionFSReaddirWithTypesResult, error) + Rm(request *SessionFSRmParams) error + Rename(request *SessionFSRenameParams) error +} + +// ClientSessionApiHandlers provides all client session API handler groups for a session. +type ClientSessionApiHandlers struct { + SessionFs SessionFsHandler +} + +func clientSessionHandlerError(err error) *jsonrpc2.Error { + if err == nil { + return nil + } + var rpcErr *jsonrpc2.Error + if errors.As(err, &rpcErr) { + return rpcErr + } + return &jsonrpc2.Error{Code: -32603, Message: err.Error()} +} + +// RegisterClientSessionApiHandlers registers handlers for server-to-client session API calls. +func RegisterClientSessionApiHandlers(client *jsonrpc2.Client, getHandlers func(sessionID string) *ClientSessionApiHandlers) { + client.SetRequestHandler("sessionFs.readFile", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSReadFileParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + result, err := handlers.SessionFs.ReadFile(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("sessionFs.writeFile", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSWriteFileParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + if err := handlers.SessionFs.WriteFile(&request); err != nil { + return nil, clientSessionHandlerError(err) + } + return json.RawMessage("null"), nil + }) + client.SetRequestHandler("sessionFs.appendFile", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSAppendFileParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + if err := handlers.SessionFs.AppendFile(&request); err != nil { + return nil, clientSessionHandlerError(err) + } + return json.RawMessage("null"), nil + }) + client.SetRequestHandler("sessionFs.exists", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSExistsParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + result, err := handlers.SessionFs.Exists(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("sessionFs.stat", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSStatParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + result, err := handlers.SessionFs.Stat(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("sessionFs.mkdir", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSMkdirParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + if err := handlers.SessionFs.Mkdir(&request); err != nil { + return nil, clientSessionHandlerError(err) + } + return json.RawMessage("null"), nil + }) + client.SetRequestHandler("sessionFs.readdir", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSReaddirParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + result, err := handlers.SessionFs.Readdir(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("sessionFs.readdirWithTypes", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSReaddirWithTypesParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + result, err := handlers.SessionFs.ReaddirWithTypes(&request) + if err != nil { + return nil, clientSessionHandlerError(err) + } + raw, err := json.Marshal(result) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)} + } + return raw, nil + }) + client.SetRequestHandler("sessionFs.rm", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSRmParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + if err := handlers.SessionFs.Rm(&request); err != nil { + return nil, clientSessionHandlerError(err) + } + return json.RawMessage("null"), nil + }) + client.SetRequestHandler("sessionFs.rename", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) { + var request SessionFSRenameParams + if err := json.Unmarshal(params, &request); err != nil { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)} + } + handlers := getHandlers(request.SessionID) + if handlers == nil || handlers.SessionFs == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No sessionFs handler registered for session: %s", request.SessionID)} + } + if err := handlers.SessionFs.Rename(&request); err != nil { + return nil, clientSessionHandlerError(err) + } + return json.RawMessage("null"), nil + }) +} diff --git a/go/session.go b/go/session.go index 71facb03b..8108180cc 100644 --- a/go/session.go +++ b/go/session.go @@ -53,6 +53,7 @@ type Session struct { SessionID string workspacePath string client *jsonrpc2.Client + clientSessionApis *rpc.ClientSessionApiHandlers handlers []sessionHandler nextHandlerID uint64 handlerMutex sync.RWMutex @@ -92,14 +93,15 @@ func (s *Session) WorkspacePath() string { // newSession creates a new session wrapper with the given session ID and client. func newSession(sessionID string, client *jsonrpc2.Client, workspacePath string) *Session { s := &Session{ - SessionID: sessionID, - workspacePath: workspacePath, - client: client, - handlers: make([]sessionHandler, 0), - toolHandlers: make(map[string]ToolHandler), - commandHandlers: make(map[string]CommandHandler), - eventCh: make(chan SessionEvent, 128), - RPC: rpc.NewSessionRpc(client, sessionID), + SessionID: sessionID, + workspacePath: workspacePath, + client: client, + clientSessionApis: &rpc.ClientSessionApiHandlers{}, + handlers: make([]sessionHandler, 0), + toolHandlers: make(map[string]ToolHandler), + commandHandlers: make(map[string]CommandHandler), + eventCh: make(chan SessionEvent, 128), + RPC: rpc.NewSessionRpc(client, sessionID), } go s.processEvents() return s diff --git a/go/types.go b/go/types.go index ff9b4aed3..d80a80f54 100644 --- a/go/types.go +++ b/go/types.go @@ -63,6 +63,10 @@ type ClientOptions struct { // querying the CLI server. Useful in BYOK mode to return models // available from your custom provider. OnListModels func(ctx context.Context) ([]ModelInfo, error) + // SessionFs configures a custom session filesystem provider. + // When provided, the client registers as the session filesystem provider + // on connection, routing session-scoped file I/O through per-session handlers. + SessionFs *SessionFsConfig // Telemetry configures OpenTelemetry integration for the Copilot CLI process. // When non-nil, COPILOT_OTEL_ENABLED=true is set and any populated fields // are mapped to the corresponding environment variables. @@ -434,6 +438,17 @@ type InfiniteSessionConfig struct { BufferExhaustionThreshold *float64 `json:"bufferExhaustionThreshold,omitempty"` } +// SessionFsConfig configures a custom session filesystem provider. +type SessionFsConfig struct { + // InitialCwd is the initial working directory for sessions. + InitialCwd string + // SessionStatePath is the path within each session's filesystem where the runtime stores + // session-scoped files such as events, checkpoints, and temp files. + SessionStatePath string + // Conventions identifies the path conventions used by this filesystem provider. + Conventions rpc.Conventions +} + // SessionConfig configures a new session type SessionConfig struct { // SessionID is an optional custom session ID @@ -500,6 +515,9 @@ type SessionConfig struct { // handler. Equivalent to calling session.On(handler) immediately after creation, // but executes earlier in the lifecycle so no events are missed. OnEvent SessionEventHandler + // CreateSessionFsHandler supplies a handler for session filesystem operations. + // This takes effect only when ClientOptions.SessionFs is configured. + CreateSessionFsHandler func(session *Session) rpc.SessionFsHandler // Commands registers slash-commands for this session. Each command appears as // /name in the CLI TUI for the user to invoke. The Handler is called when the // command is executed. @@ -697,6 +715,9 @@ type ResumeSessionConfig struct { // OnEvent is an optional event handler registered before the session.resume RPC // is issued, ensuring early events are delivered. See SessionConfig.OnEvent. OnEvent SessionEventHandler + // CreateSessionFsHandler supplies a handler for session filesystem operations. + // This takes effect only when ClientOptions.SessionFs is configured. + CreateSessionFsHandler func(session *Session) rpc.SessionFsHandler // Commands registers slash-commands for this session. See SessionConfig.Commands. Commands []CommandDefinition // OnElicitationRequest is a handler for elicitation requests from the server. diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index e61afcacf..5fdbf0358 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -297,6 +297,10 @@ export class CopilotClient { ); } + if (options.sessionFs) { + this.validateSessionFsConfig(options.sessionFs); + } + // Parse cliUrl if provided if (options.cliUrl) { const { host, port } = this.parseCliUrl(options.cliUrl); @@ -367,6 +371,20 @@ export class CopilotClient { return { host, port }; } + private validateSessionFsConfig(config: SessionFsConfig): void { + if (!config.initialCwd) { + throw new Error("sessionFs.initialCwd is required"); + } + + if (!config.sessionStatePath) { + throw new Error("sessionFs.sessionStatePath is required"); + } + + if (config.conventions !== "windows" && config.conventions !== "posix") { + throw new Error("sessionFs.conventions must be either 'windows' or 'posix'"); + } + } + /** * Starts the CLI server and establishes a connection. * diff --git a/nodejs/test/client.test.ts b/nodejs/test/client.test.ts index cf9b63252..c3f0770cd 100644 --- a/nodejs/test/client.test.ts +++ b/nodejs/test/client.test.ts @@ -278,6 +278,34 @@ describe("CopilotClient", () => { }); }); + describe("SessionFs config", () => { + it("throws when initialCwd is missing", () => { + expect(() => { + new CopilotClient({ + sessionFs: { + initialCwd: "", + sessionStatePath: "/session-state", + conventions: "posix", + }, + logLevel: "error", + }); + }).toThrow(/sessionFs\.initialCwd is required/); + }); + + it("throws when sessionStatePath is missing", () => { + expect(() => { + new CopilotClient({ + sessionFs: { + initialCwd: "/", + sessionStatePath: "", + conventions: "posix", + }, + logLevel: "error", + }); + }).toThrow(/sessionFs\.sessionStatePath is required/); + }); + }); + describe("Auth options", () => { it("should accept githubToken option", () => { const client = new CopilotClient({ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index db9f150c8..702d35035 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -17,12 +17,15 @@ CommandContext, CommandDefinition, CopilotSession, + CreateSessionFsHandler, ElicitationContext, ElicitationHandler, ElicitationParams, ElicitationResult, InputOptions, SessionCapabilities, + SessionFsConfig, + SessionFsHandler, SessionUiApi, SessionUiCapabilities, ) @@ -35,6 +38,7 @@ "CommandDefinition", "CopilotClient", "CopilotSession", + "CreateSessionFsHandler", "ElicitationHandler", "ElicitationParams", "ElicitationContext", @@ -46,6 +50,8 @@ "ModelSupportsOverride", "ModelVisionLimitsOverride", "SessionCapabilities", + "SessionFsConfig", + "SessionFsHandler", "SessionUiApi", "SessionUiCapabilities", "SubprocessConfig", diff --git a/python/copilot/client.py b/python/copilot/client.py index df6756cfe..8be8b8220 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -32,11 +32,16 @@ from ._jsonrpc import JsonRpcClient, ProcessExitedError from ._sdk_protocol_version import get_sdk_protocol_version from ._telemetry import get_trace_context, trace_context -from .generated.rpc import ServerRpc +from .generated.rpc import ( + ClientSessionApiHandlers, + ServerRpc, + register_client_session_api_handlers, +) from .generated.session_events import PermissionRequest, SessionEvent, session_event_from_dict from .session import ( CommandDefinition, CopilotSession, + CreateSessionFsHandler, CustomAgentConfig, ElicitationHandler, InfiniteSessionConfig, @@ -44,6 +49,7 @@ ProviderConfig, ReasoningEffort, SectionTransformFn, + SessionFsConfig, SessionHooks, SystemMessageConfig, UserInputHandler, @@ -60,6 +66,15 @@ LogLevel = Literal["none", "error", "warning", "info", "debug", "all"] +def _validate_session_fs_config(config: SessionFsConfig) -> None: + if not config.get("initial_cwd"): + raise ValueError("session_fs.initial_cwd is required") + if not config.get("session_state_path"): + raise ValueError("session_fs.session_state_path is required") + if config.get("conventions") not in ("posix", "windows"): + raise ValueError("session_fs.conventions must be either 'posix' or 'windows'") + + class TelemetryConfig(TypedDict, total=False): """Configuration for OpenTelemetry integration with the Copilot CLI.""" @@ -126,6 +141,9 @@ class SubprocessConfig: telemetry: TelemetryConfig | None = None """OpenTelemetry configuration. Providing this enables telemetry — no separate flag needed.""" + session_fs: SessionFsConfig | None = None + """Connection-level session filesystem provider configuration.""" + @dataclass class ExternalServerConfig: @@ -139,6 +157,11 @@ class ExternalServerConfig: url: str """Server URL. Supports ``"host:port"``, ``"http://host:port"``, or just ``"port"``.""" + _: KW_ONLY + + session_fs: SessionFsConfig | None = None + """Connection-level session filesystem provider configuration.""" + # ============================================================================ # Response Types @@ -889,6 +912,9 @@ def __init__( self._lifecycle_handlers_lock = threading.Lock() self._rpc: ServerRpc | None = None self._negotiated_protocol_version: int | None = None + if config.session_fs is not None: + _validate_session_fs_config(config.session_fs) + self._session_fs_config = config.session_fs @property def rpc(self) -> ServerRpc: @@ -1018,6 +1044,9 @@ async def start(self) -> None: # Verify protocol version compatibility await self._verify_protocol_version() + if self._session_fs_config: + await self._set_session_fs_provider() + self._state = "connected" except ProcessExitedError as e: # Process exited with error - reraise as RuntimeError with stderr @@ -1179,6 +1208,7 @@ async def create_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + create_session_fs_handler: CreateSessionFsHandler | None = None, ) -> CopilotSession: """ Create a new conversation session with the Copilot CLI. @@ -1368,6 +1398,13 @@ async def create_session( # Create and register the session before issuing the RPC so that # events emitted by the CLI (e.g. session.start) are not dropped. session = CopilotSession(actual_session_id, self._client, workspace_path=None) + if self._session_fs_config: + if create_session_fs_handler is None: + raise ValueError( + "create_session_fs_handler is required in session config when " + "session_fs is enabled in client options." + ) + session._client_session_apis.session_fs = create_session_fs_handler(session) session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -1424,6 +1461,7 @@ async def resume_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + create_session_fs_handler: CreateSessionFsHandler | None = None, ) -> CopilotSession: """ Resume an existing conversation session by its ID. @@ -1592,6 +1630,13 @@ async def resume_session( # Create and register the session before issuing the RPC so that # events emitted by the CLI (e.g. session.start) are not dropped. session = CopilotSession(session_id, self._client, workspace_path=None) + if self._session_fs_config: + if create_session_fs_handler is None: + raise ValueError( + "create_session_fs_handler is required in session config when " + "session_fs is enabled in client options." + ) + session._client_session_apis.session_fs = create_session_fs_handler(session) session._register_tools(tools) session._register_commands(commands) session._register_permission_handler(on_permission_request) @@ -2283,6 +2328,7 @@ def handle_notification(method: str, params: dict): self._client.set_request_handler( "systemMessage.transform", self._handle_system_message_transform ) + register_client_session_api_handlers(self._client, self._get_client_session_handlers) # Start listening for messages loop = asyncio.get_running_loop() @@ -2387,11 +2433,32 @@ def handle_notification(method: str, params: dict): self._client.set_request_handler( "systemMessage.transform", self._handle_system_message_transform ) + register_client_session_api_handlers(self._client, self._get_client_session_handlers) # Start listening for messages loop = asyncio.get_running_loop() self._client.start(loop) + async def _set_session_fs_provider(self) -> None: + if not self._session_fs_config or not self._client: + return + + await self._client.request( + "sessionFs.setProvider", + { + "initialCwd": self._session_fs_config["initial_cwd"], + "sessionStatePath": self._session_fs_config["session_state_path"], + "conventions": self._session_fs_config["conventions"], + }, + ) + + def _get_client_session_handlers(self, session_id: str) -> ClientSessionApiHandlers: + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None: + raise ValueError(f"unknown session {session_id}") + return session._client_session_apis + async def _handle_user_input_request(self, params: dict) -> dict: """ Handle a user input request from the CLI server. diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index 93b80ee4f..52cc891a4 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -8,6 +8,10 @@ if TYPE_CHECKING: from .._jsonrpc import JsonRpcClient +from collections.abc import Callable +from dataclasses import dataclass +from typing import Protocol + from dataclasses import dataclass from typing import Any, TypeVar, Callable, cast @@ -2626,6 +2630,411 @@ def to_dict(self) -> dict: return result +@dataclass +class SessionFSReadFileResult: + content: str + """File content as UTF-8 string""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReadFileResult': + assert isinstance(obj, dict) + content = from_str(obj.get("content")) + return SessionFSReadFileResult(content) + + def to_dict(self) -> dict: + result: dict = {} + result["content"] = from_str(self.content) + return result + + +@dataclass +class SessionFSReadFileParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReadFileParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + return SessionFSReadFileParams(path, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + return result + + +@dataclass +class SessionFSWriteFileParams: + content: str + """Content to write""" + + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + mode: float | None = None + """Optional POSIX-style mode for newly created files""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSWriteFileParams': + assert isinstance(obj, dict) + content = from_str(obj.get("content")) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + mode = from_union([from_float, from_none], obj.get("mode")) + return SessionFSWriteFileParams(content, path, session_id, mode) + + def to_dict(self) -> dict: + result: dict = {} + result["content"] = from_str(self.content) + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + if self.mode is not None: + result["mode"] = from_union([to_float, from_none], self.mode) + return result + + +@dataclass +class SessionFSAppendFileParams: + content: str + """Content to append""" + + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + mode: float | None = None + """Optional POSIX-style mode for newly created files""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSAppendFileParams': + assert isinstance(obj, dict) + content = from_str(obj.get("content")) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + mode = from_union([from_float, from_none], obj.get("mode")) + return SessionFSAppendFileParams(content, path, session_id, mode) + + def to_dict(self) -> dict: + result: dict = {} + result["content"] = from_str(self.content) + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + if self.mode is not None: + result["mode"] = from_union([to_float, from_none], self.mode) + return result + + +@dataclass +class SessionFSExistsResult: + exists: bool + """Whether the path exists""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSExistsResult': + assert isinstance(obj, dict) + exists = from_bool(obj.get("exists")) + return SessionFSExistsResult(exists) + + def to_dict(self) -> dict: + result: dict = {} + result["exists"] = from_bool(self.exists) + return result + + +@dataclass +class SessionFSExistsParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSExistsParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + return SessionFSExistsParams(path, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + return result + + +@dataclass +class SessionFSStatResult: + birthtime: str + """ISO 8601 timestamp of creation""" + + is_directory: bool + """Whether the path is a directory""" + + is_file: bool + """Whether the path is a file""" + + mtime: str + """ISO 8601 timestamp of last modification""" + + size: float + """File size in bytes""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSStatResult': + assert isinstance(obj, dict) + birthtime = from_str(obj.get("birthtime")) + is_directory = from_bool(obj.get("isDirectory")) + is_file = from_bool(obj.get("isFile")) + mtime = from_str(obj.get("mtime")) + size = from_float(obj.get("size")) + return SessionFSStatResult(birthtime, is_directory, is_file, mtime, size) + + def to_dict(self) -> dict: + result: dict = {} + result["birthtime"] = from_str(self.birthtime) + result["isDirectory"] = from_bool(self.is_directory) + result["isFile"] = from_bool(self.is_file) + result["mtime"] = from_str(self.mtime) + result["size"] = to_float(self.size) + return result + + +@dataclass +class SessionFSStatParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSStatParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + return SessionFSStatParams(path, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + return result + + +@dataclass +class SessionFSMkdirParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + mode: float | None = None + """Optional POSIX-style mode for newly created directories""" + + recursive: bool | None = None + """Create parent directories as needed""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSMkdirParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + mode = from_union([from_float, from_none], obj.get("mode")) + recursive = from_union([from_bool, from_none], obj.get("recursive")) + return SessionFSMkdirParams(path, session_id, mode, recursive) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + if self.mode is not None: + result["mode"] = from_union([to_float, from_none], self.mode) + if self.recursive is not None: + result["recursive"] = from_union([from_bool, from_none], self.recursive) + return result + + +@dataclass +class SessionFSReaddirResult: + entries: list[str] + """Entry names in the directory""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReaddirResult': + assert isinstance(obj, dict) + entries = from_list(from_str, obj.get("entries")) + return SessionFSReaddirResult(entries) + + def to_dict(self) -> dict: + result: dict = {} + result["entries"] = from_list(from_str, self.entries) + return result + + +@dataclass +class SessionFSReaddirParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReaddirParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + return SessionFSReaddirParams(path, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + return result + + +class EntryType(Enum): + """Entry type""" + + DIRECTORY = "directory" + FILE = "file" + + +@dataclass +class Entry: + name: str + """Entry name""" + + type: EntryType + """Entry type""" + + @staticmethod + def from_dict(obj: Any) -> 'Entry': + assert isinstance(obj, dict) + name = from_str(obj.get("name")) + type = EntryType(obj.get("type")) + return Entry(name, type) + + def to_dict(self) -> dict: + result: dict = {} + result["name"] = from_str(self.name) + result["type"] = to_enum(EntryType, self.type) + return result + + +@dataclass +class SessionFSReaddirWithTypesResult: + entries: list[Entry] + """Directory entries with type information""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReaddirWithTypesResult': + assert isinstance(obj, dict) + entries = from_list(Entry.from_dict, obj.get("entries")) + return SessionFSReaddirWithTypesResult(entries) + + def to_dict(self) -> dict: + result: dict = {} + result["entries"] = from_list(lambda x: to_class(Entry, x), self.entries) + return result + + +@dataclass +class SessionFSReaddirWithTypesParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSReaddirWithTypesParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + return SessionFSReaddirWithTypesParams(path, session_id) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + return result + + +@dataclass +class SessionFSRmParams: + path: str + """Path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + force: bool | None = None + """Ignore errors if the path does not exist""" + + recursive: bool | None = None + """Remove directories and their contents recursively""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSRmParams': + assert isinstance(obj, dict) + path = from_str(obj.get("path")) + session_id = from_str(obj.get("sessionId")) + force = from_union([from_bool, from_none], obj.get("force")) + recursive = from_union([from_bool, from_none], obj.get("recursive")) + return SessionFSRmParams(path, session_id, force, recursive) + + def to_dict(self) -> dict: + result: dict = {} + result["path"] = from_str(self.path) + result["sessionId"] = from_str(self.session_id) + if self.force is not None: + result["force"] = from_union([from_bool, from_none], self.force) + if self.recursive is not None: + result["recursive"] = from_union([from_bool, from_none], self.recursive) + return result + + +@dataclass +class SessionFSRenameParams: + dest: str + """Destination path using SessionFs conventions""" + + session_id: str + """Target session identifier""" + + src: str + """Source path using SessionFs conventions""" + + @staticmethod + def from_dict(obj: Any) -> 'SessionFSRenameParams': + assert isinstance(obj, dict) + dest = from_str(obj.get("dest")) + session_id = from_str(obj.get("sessionId")) + src = from_str(obj.get("src")) + return SessionFSRenameParams(dest, session_id, src) + + def to_dict(self) -> dict: + result: dict = {} + result["dest"] = from_str(self.dest) + result["sessionId"] = from_str(self.session_id) + result["src"] = from_str(self.src) + return result + + def ping_result_from_dict(s: Any) -> PingResult: return PingResult.from_dict(s) @@ -3194,6 +3603,126 @@ def session_shell_kill_params_to_dict(x: SessionShellKillParams) -> Any: return to_class(SessionShellKillParams, x) +def session_fs_read_file_result_from_dict(s: Any) -> SessionFSReadFileResult: + return SessionFSReadFileResult.from_dict(s) + + +def session_fs_read_file_result_to_dict(x: SessionFSReadFileResult) -> Any: + return to_class(SessionFSReadFileResult, x) + + +def session_fs_read_file_params_from_dict(s: Any) -> SessionFSReadFileParams: + return SessionFSReadFileParams.from_dict(s) + + +def session_fs_read_file_params_to_dict(x: SessionFSReadFileParams) -> Any: + return to_class(SessionFSReadFileParams, x) + + +def session_fs_write_file_params_from_dict(s: Any) -> SessionFSWriteFileParams: + return SessionFSWriteFileParams.from_dict(s) + + +def session_fs_write_file_params_to_dict(x: SessionFSWriteFileParams) -> Any: + return to_class(SessionFSWriteFileParams, x) + + +def session_fs_append_file_params_from_dict(s: Any) -> SessionFSAppendFileParams: + return SessionFSAppendFileParams.from_dict(s) + + +def session_fs_append_file_params_to_dict(x: SessionFSAppendFileParams) -> Any: + return to_class(SessionFSAppendFileParams, x) + + +def session_fs_exists_result_from_dict(s: Any) -> SessionFSExistsResult: + return SessionFSExistsResult.from_dict(s) + + +def session_fs_exists_result_to_dict(x: SessionFSExistsResult) -> Any: + return to_class(SessionFSExistsResult, x) + + +def session_fs_exists_params_from_dict(s: Any) -> SessionFSExistsParams: + return SessionFSExistsParams.from_dict(s) + + +def session_fs_exists_params_to_dict(x: SessionFSExistsParams) -> Any: + return to_class(SessionFSExistsParams, x) + + +def session_fs_stat_result_from_dict(s: Any) -> SessionFSStatResult: + return SessionFSStatResult.from_dict(s) + + +def session_fs_stat_result_to_dict(x: SessionFSStatResult) -> Any: + return to_class(SessionFSStatResult, x) + + +def session_fs_stat_params_from_dict(s: Any) -> SessionFSStatParams: + return SessionFSStatParams.from_dict(s) + + +def session_fs_stat_params_to_dict(x: SessionFSStatParams) -> Any: + return to_class(SessionFSStatParams, x) + + +def session_fs_mkdir_params_from_dict(s: Any) -> SessionFSMkdirParams: + return SessionFSMkdirParams.from_dict(s) + + +def session_fs_mkdir_params_to_dict(x: SessionFSMkdirParams) -> Any: + return to_class(SessionFSMkdirParams, x) + + +def session_fs_readdir_result_from_dict(s: Any) -> SessionFSReaddirResult: + return SessionFSReaddirResult.from_dict(s) + + +def session_fs_readdir_result_to_dict(x: SessionFSReaddirResult) -> Any: + return to_class(SessionFSReaddirResult, x) + + +def session_fs_readdir_params_from_dict(s: Any) -> SessionFSReaddirParams: + return SessionFSReaddirParams.from_dict(s) + + +def session_fs_readdir_params_to_dict(x: SessionFSReaddirParams) -> Any: + return to_class(SessionFSReaddirParams, x) + + +def session_fs_readdir_with_types_result_from_dict(s: Any) -> SessionFSReaddirWithTypesResult: + return SessionFSReaddirWithTypesResult.from_dict(s) + + +def session_fs_readdir_with_types_result_to_dict(x: SessionFSReaddirWithTypesResult) -> Any: + return to_class(SessionFSReaddirWithTypesResult, x) + + +def session_fs_readdir_with_types_params_from_dict(s: Any) -> SessionFSReaddirWithTypesParams: + return SessionFSReaddirWithTypesParams.from_dict(s) + + +def session_fs_readdir_with_types_params_to_dict(x: SessionFSReaddirWithTypesParams) -> Any: + return to_class(SessionFSReaddirWithTypesParams, x) + + +def session_fs_rm_params_from_dict(s: Any) -> SessionFSRmParams: + return SessionFSRmParams.from_dict(s) + + +def session_fs_rm_params_to_dict(x: SessionFSRmParams) -> Any: + return to_class(SessionFSRmParams, x) + + +def session_fs_rename_params_from_dict(s: Any) -> SessionFSRenameParams: + return SessionFSRenameParams.from_dict(s) + + +def session_fs_rename_params_to_dict(x: SessionFSRenameParams) -> Any: + return to_class(SessionFSRenameParams, x) + + def _timeout_kwargs(timeout: float | None) -> dict: """Build keyword arguments for optional timeout forwarding.""" if timeout is not None: @@ -3536,3 +4065,105 @@ async def log(self, params: SessionLogParams, *, timeout: float | None = None) - params_dict["sessionId"] = self._session_id return SessionLogResult.from_dict(await self._client.request("session.log", params_dict, **_timeout_kwargs(timeout))) + +class SessionFsHandler(Protocol): + async def read_file(self, params: SessionFSReadFileParams) -> SessionFSReadFileResult: + pass + async def write_file(self, params: SessionFSWriteFileParams) -> None: + pass + async def append_file(self, params: SessionFSAppendFileParams) -> None: + pass + async def exists(self, params: SessionFSExistsParams) -> SessionFSExistsResult: + pass + async def stat(self, params: SessionFSStatParams) -> SessionFSStatResult: + pass + async def mkdir(self, params: SessionFSMkdirParams) -> None: + pass + async def readdir(self, params: SessionFSReaddirParams) -> SessionFSReaddirResult: + pass + async def readdir_with_types(self, params: SessionFSReaddirWithTypesParams) -> SessionFSReaddirWithTypesResult: + pass + async def rm(self, params: SessionFSRmParams) -> None: + pass + async def rename(self, params: SessionFSRenameParams) -> None: + pass + +@dataclass +class ClientSessionApiHandlers: + session_fs: SessionFsHandler | None = None + +def register_client_session_api_handlers( + client: "JsonRpcClient", + get_handlers: Callable[[str], ClientSessionApiHandlers], +) -> None: + """Register client-session request handlers on a JSON-RPC connection.""" + async def handle_session_fs_read_file(params: dict) -> dict | None: + request = SessionFSReadFileParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + result = await handler.read_file(request) + return result.to_dict() + client.set_request_handler("sessionFs.readFile", handle_session_fs_read_file) + async def handle_session_fs_write_file(params: dict) -> dict | None: + request = SessionFSWriteFileParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + await handler.write_file(request) + return None + client.set_request_handler("sessionFs.writeFile", handle_session_fs_write_file) + async def handle_session_fs_append_file(params: dict) -> dict | None: + request = SessionFSAppendFileParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + await handler.append_file(request) + return None + client.set_request_handler("sessionFs.appendFile", handle_session_fs_append_file) + async def handle_session_fs_exists(params: dict) -> dict | None: + request = SessionFSExistsParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + result = await handler.exists(request) + return result.to_dict() + client.set_request_handler("sessionFs.exists", handle_session_fs_exists) + async def handle_session_fs_stat(params: dict) -> dict | None: + request = SessionFSStatParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + result = await handler.stat(request) + return result.to_dict() + client.set_request_handler("sessionFs.stat", handle_session_fs_stat) + async def handle_session_fs_mkdir(params: dict) -> dict | None: + request = SessionFSMkdirParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + await handler.mkdir(request) + return None + client.set_request_handler("sessionFs.mkdir", handle_session_fs_mkdir) + async def handle_session_fs_readdir(params: dict) -> dict | None: + request = SessionFSReaddirParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + result = await handler.readdir(request) + return result.to_dict() + client.set_request_handler("sessionFs.readdir", handle_session_fs_readdir) + async def handle_session_fs_readdir_with_types(params: dict) -> dict | None: + request = SessionFSReaddirWithTypesParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + result = await handler.readdir_with_types(request) + return result.to_dict() + client.set_request_handler("sessionFs.readdirWithTypes", handle_session_fs_readdir_with_types) + async def handle_session_fs_rm(params: dict) -> dict | None: + request = SessionFSRmParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + await handler.rm(request) + return None + client.set_request_handler("sessionFs.rm", handle_session_fs_rm) + async def handle_session_fs_rename(params: dict) -> dict | None: + request = SessionFSRenameParams.from_dict(params) + handler = get_handlers(request.session_id).session_fs + if handler is None: raise RuntimeError(f"No session_fs handler registered for session: {request.session_id}") + await handler.rename(request) + return None + client.set_request_handler("sessionFs.rename", handle_session_fs_rename) diff --git a/python/copilot/session.py b/python/copilot/session.py index 59ec8532b..b3f62789d 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -23,6 +23,7 @@ from ._telemetry import get_trace_context, trace_context from .generated.rpc import ( Action, + ClientSessionApiHandlers, Kind, Level, Property, @@ -31,6 +32,7 @@ RequestedSchemaType, ResultResult, SessionCommandsHandlePendingCommandParams, + SessionFsHandler, SessionLogParams, SessionModelSwitchToParams, SessionPermissionsHandlePendingPermissionRequestParams, @@ -63,6 +65,14 @@ # ============================================================================ ReasoningEffort = Literal["low", "medium", "high", "xhigh"] +SessionFsConventions = Literal["posix", "windows"] + + +class SessionFsConfig(TypedDict): + initial_cwd: str + session_state_path: str + conventions: SessionFsConventions + # ============================================================================ # Attachment Types @@ -395,6 +405,8 @@ class ElicitationContext(TypedDict, total=False): ] """Handler invoked when the server dispatches an elicitation request to this client.""" +CreateSessionFsHandler = Callable[["CopilotSession"], SessionFsHandler] + # ============================================================================ # Session UI API @@ -862,6 +874,8 @@ class SessionConfig(TypedDict, total=False): # Handler for elicitation requests from the server. # When provided, the server calls back to this client for form-based UI dialogs. on_elicitation_request: ElicitationHandler + # Handler factory for session-scoped sessionFs operations. + create_session_fs_handler: CreateSessionFsHandler class ResumeSessionConfig(TypedDict, total=False): @@ -915,6 +929,8 @@ class ResumeSessionConfig(TypedDict, total=False): commands: list[CommandDefinition] # Handler for elicitation requests from the server. on_elicitation_request: ElicitationHandler + # Handler factory for session-scoped sessionFs operations. + create_session_fs_handler: CreateSessionFsHandler SessionEventHandler = Callable[[SessionEvent], None] @@ -984,6 +1000,7 @@ def __init__( self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() self._capabilities: SessionCapabilities = {} + self._client_session_apis = ClientSessionApiHandlers() self._rpc: SessionRpc | None = None self._destroyed = False diff --git a/python/e2e/test_session_fs.py b/python/e2e/test_session_fs.py new file mode 100644 index 000000000..a656ce0f8 --- /dev/null +++ b/python/e2e/test_session_fs.py @@ -0,0 +1,349 @@ +"""E2E SessionFs tests mirroring nodejs/test/e2e/session_fs.test.ts.""" + +from __future__ import annotations + +import asyncio +import datetime as dt +import os +import re +from pathlib import Path + +import pytest +import pytest_asyncio + +from copilot import CopilotClient, SessionFsConfig, define_tool +from copilot.client import ExternalServerConfig, SubprocessConfig +from copilot.generated.rpc import ( + SessionFSExistsResult, + SessionFSReaddirResult, + SessionFSReaddirWithTypesResult, + SessionFSReadFileResult, + SessionFSStatResult, +) +from copilot.generated.session_events import SessionEvent +from copilot.session import PermissionHandler + +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +SESSION_FS_CONFIG: SessionFsConfig = { + "initial_cwd": "/", + "session_state_path": "/session-state", + "conventions": "posix", +} + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def session_fs_client(ctx: E2ETestContext): + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + client = CopilotClient( + SubprocessConfig( + cli_path=ctx.cli_path, + cwd=ctx.work_dir, + env=ctx.get_env(), + github_token=github_token, + session_fs=SESSION_FS_CONFIG, + ) + ) + yield client + try: + await client.stop() + except Exception: + await client.force_stop() + + +class TestSessionFs: + async def test_should_route_file_operations_through_the_session_fs_provider( + self, ctx: E2ETestContext, session_fs_client: CopilotClient + ): + provider_root = Path(ctx.work_dir) / "provider" + session = await session_fs_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=create_test_session_fs_handler(provider_root), + ) + + msg = await session.send_and_wait("What is 100 + 200?") + assert msg is not None + assert msg.data.content is not None + assert "300" in msg.data.content + await session.disconnect() + + events_path = provider_path( + provider_root, session.session_id, "/session-state/events.jsonl" + ) + assert "300" in events_path.read_text(encoding="utf-8") + + async def test_should_load_session_data_from_fs_provider_on_resume( + self, ctx: E2ETestContext, session_fs_client: CopilotClient + ): + provider_root = Path(ctx.work_dir) / "provider" + create_session_fs_handler = create_test_session_fs_handler(provider_root) + + session1 = await session_fs_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=create_session_fs_handler, + ) + session_id = session1.session_id + + msg = await session1.send_and_wait("What is 50 + 50?") + assert msg is not None + assert msg.data.content is not None + assert "100" in msg.data.content + await session1.disconnect() + + assert provider_path(provider_root, session_id, "/session-state/events.jsonl").exists() + + session2 = await session_fs_client.resume_session( + session_id, + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=create_session_fs_handler, + ) + + msg2 = await session2.send_and_wait("What is that times 3?") + assert msg2 is not None + assert msg2.data.content is not None + assert "300" in msg2.data.content + await session2.disconnect() + + async def test_should_reject_setprovider_when_sessions_already_exist(self, ctx: E2ETestContext): + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + client1 = CopilotClient( + SubprocessConfig( + cli_path=ctx.cli_path, + cwd=ctx.work_dir, + env=ctx.get_env(), + use_stdio=False, + github_token=github_token, + ) + ) + session = None + client2 = None + + try: + session = await client1.create_session( + on_permission_request=PermissionHandler.approve_all, + ) + actual_port = client1.actual_port + assert actual_port is not None + + client2 = CopilotClient( + ExternalServerConfig( + url=f"localhost:{actual_port}", + session_fs=SESSION_FS_CONFIG, + ) + ) + + with pytest.raises(Exception): + await client2.start() + finally: + if session is not None: + await session.disconnect() + if client2 is not None: + await client2.force_stop() + await client1.force_stop() + + async def test_should_map_large_output_handling_into_sessionfs( + self, ctx: E2ETestContext, session_fs_client: CopilotClient + ): + provider_root = Path(ctx.work_dir) / "provider" + supplied_file_content = "x" * 100_000 + + @define_tool("get_big_string", description="Returns a large string") + def get_big_string() -> str: + return supplied_file_content + + session = await session_fs_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=create_test_session_fs_handler(provider_root), + tools=[get_big_string], + ) + + await session.send_and_wait( + "Call the get_big_string tool and reply with the word DONE only." + ) + + messages = await session.get_messages() + tool_result = find_tool_call_result(messages, "get_big_string") + assert tool_result is not None + assert "/session-state/temp/" in tool_result + match = re.search(r"(/session-state/temp/[^\s]+)", tool_result) + assert match is not None + + temp_file = provider_path(provider_root, session.session_id, match.group(1)) + assert temp_file.read_text(encoding="utf-8") == supplied_file_content + + async def test_should_succeed_with_compaction_while_using_sessionfs( + self, ctx: E2ETestContext, session_fs_client: CopilotClient + ): + provider_root = Path(ctx.work_dir) / "provider" + session = await session_fs_client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=create_test_session_fs_handler(provider_root), + ) + + compaction_event = asyncio.Event() + compaction_success: bool | None = None + + def on_event(event: SessionEvent): + nonlocal compaction_success + if event.type.value == "session.compaction_complete": + compaction_success = event.data.success + compaction_event.set() + + session.on(on_event) + + await session.send_and_wait("What is 2+2?") + + events_path = provider_path( + provider_root, session.session_id, "/session-state/events.jsonl" + ) + await wait_for_path(events_path) + assert "checkpointNumber" not in events_path.read_text(encoding="utf-8") + + result = await session.rpc.compaction.compact() + await asyncio.wait_for(compaction_event.wait(), timeout=5.0) + assert result.success is True + assert compaction_success is True + + await wait_for_content(events_path, "checkpointNumber") + + +class _SessionFsHandler: + def __init__(self, provider_root: Path, session_id: str): + self._provider_root = provider_root + self._session_id = session_id + + async def read_file(self, params) -> SessionFSReadFileResult: + content = provider_path(self._provider_root, self._session_id, params.path).read_text( + encoding="utf-8" + ) + return SessionFSReadFileResult.from_dict({"content": content}) + + async def write_file(self, params) -> None: + path = provider_path(self._provider_root, self._session_id, params.path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(params.content, encoding="utf-8") + + async def append_file(self, params) -> None: + path = provider_path(self._provider_root, self._session_id, params.path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(params.content) + + async def exists(self, params) -> SessionFSExistsResult: + path = provider_path(self._provider_root, self._session_id, params.path) + return SessionFSExistsResult.from_dict({"exists": path.exists()}) + + async def stat(self, params) -> SessionFSStatResult: + path = provider_path(self._provider_root, self._session_id, params.path) + info = path.stat() + timestamp = dt.datetime.fromtimestamp(info.st_mtime, tz=dt.UTC).isoformat() + if timestamp.endswith("+00:00"): + timestamp = f"{timestamp[:-6]}Z" + return SessionFSStatResult.from_dict( + { + "isFile": not path.is_dir(), + "isDirectory": path.is_dir(), + "size": info.st_size, + "mtime": timestamp, + "birthtime": timestamp, + } + ) + + async def mkdir(self, params) -> None: + path = provider_path(self._provider_root, self._session_id, params.path) + if params.recursive: + path.mkdir(parents=True, exist_ok=True) + else: + path.mkdir() + + async def readdir(self, params) -> SessionFSReaddirResult: + entries = sorted( + entry.name + for entry in provider_path(self._provider_root, self._session_id, params.path).iterdir() + ) + return SessionFSReaddirResult.from_dict({"entries": entries}) + + async def readdir_with_types(self, params) -> SessionFSReaddirWithTypesResult: + entries = [] + for entry in sorted( + provider_path(self._provider_root, self._session_id, params.path).iterdir(), + key=lambda item: item.name, + ): + entries.append( + { + "name": entry.name, + "type": "directory" if entry.is_dir() else "file", + } + ) + return SessionFSReaddirWithTypesResult.from_dict({"entries": entries}) + + async def rm(self, params) -> None: + provider_path(self._provider_root, self._session_id, params.path).unlink() + + async def rename(self, params) -> None: + src = provider_path(self._provider_root, self._session_id, params.src) + dest = provider_path(self._provider_root, self._session_id, params.dest) + dest.parent.mkdir(parents=True, exist_ok=True) + src.rename(dest) + + +def create_test_session_fs_handler(provider_root: Path): + def create_handler(session): + return _SessionFsHandler(provider_root, session.session_id) + + return create_handler + + +def provider_path(provider_root: Path, session_id: str, path: str) -> Path: + return provider_root / session_id / path.lstrip("/") + + +def find_tool_call_result(messages: list[SessionEvent], tool_name: str) -> str | None: + for message in messages: + if ( + message.type.value == "tool.execution_complete" + and message.data.tool_call_id is not None + ): + if find_tool_name(messages, message.data.tool_call_id) == tool_name: + return message.data.result.content if message.data.result is not None else None + return None + + +def find_tool_name(messages: list[SessionEvent], tool_call_id: str) -> str | None: + for message in messages: + if ( + message.type.value == "tool.execution_start" + and message.data.tool_call_id == tool_call_id + ): + return message.data.tool_name + return None + + +async def wait_for_path(path: Path, timeout: float = 5.0) -> None: + async def predicate(): + return path.exists() + + await wait_for_predicate(predicate, timeout=timeout) + + +async def wait_for_content(path: Path, expected: str, timeout: float = 5.0) -> None: + async def predicate(): + return path.exists() and expected in path.read_text(encoding="utf-8") + + await wait_for_predicate(predicate, timeout=timeout) + + +async def wait_for_predicate(predicate, timeout: float = 5.0) -> None: + deadline = asyncio.get_running_loop().time() + timeout + while asyncio.get_running_loop().time() < deadline: + if await predicate(): + return + await asyncio.sleep(0.1) + raise TimeoutError("timed out waiting for condition") diff --git a/python/test_client.py b/python/test_client.py index d655df4d4..5d0dc868e 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -122,6 +122,36 @@ def test_is_external_server_true(self): assert client._is_external_server +class TestSessionFsConfig: + def test_missing_initial_cwd(self): + with pytest.raises(ValueError, match="session_fs.initial_cwd is required"): + CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "", + "session_state_path": "/session-state", + "conventions": "posix", + }, + ) + ) + + def test_missing_session_state_path(self): + with pytest.raises(ValueError, match="session_fs.session_state_path is required"): + CopilotClient( + SubprocessConfig( + cli_path=CLI_PATH, + log_level="error", + session_fs={ + "initial_cwd": "/", + "session_state_path": "", + "conventions": "posix", + }, + ) + ) + + class TestAuthOptions: def test_accepts_github_token(self): client = CopilotClient( diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index d60cfbb96..9049cb38c 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -602,7 +602,18 @@ let rpcEnumOutput: string[] = []; function singularPascal(s: string): string { const p = toPascalCase(s); - return p.endsWith("s") ? p.slice(0, -1) : p; + if (p.endsWith("ies")) return `${p.slice(0, -3)}y`; + if (/(xes|zes|ches|shes|sses)$/i.test(p)) return p.slice(0, -2); + if (p.endsWith("s") && !/(ss|us|is)$/i.test(p)) return p.slice(0, -1); + return p; +} + +function resultTypeName(rpcMethod: string): string { + return `${typeToClassName(rpcMethod)}Result`; +} + +function paramsTypeName(rpcMethod: string): string { + return `${typeToClassName(rpcMethod)}Params`; } function resolveRpcType(schema: JSONSchema7, isRequired: boolean, parentClassName: string, propName: string, classes: string[]): string { @@ -653,7 +664,7 @@ function emitRpcClass(className: string, schema: JSONSchema7, visibility: "publi const requiredSet = new Set(schema.required || []); const lines: string[] = []; - lines.push(...xmlDocComment(schema.description || `RPC data type for ${className.replace(/Request$/, "").replace(/Result$/, "")} operations.`, "")); + lines.push(...xmlDocComment(schema.description || `RPC data type for ${className.replace(/(Request|Result|Params)$/, "")} operations.`, "")); if (experimentalRpcTypes.has(className)) { lines.push(`[Experimental(Diagnostics.Experimental)]`); } @@ -923,6 +934,131 @@ function emitSessionApiClass(className: string, node: Record, c return lines.join("\n"); } +function collectClientGroups(node: Record): Array<{ groupName: string; groupNode: Record; methods: RpcMethod[] }> { + const groups: Array<{ groupName: string; groupNode: Record; methods: RpcMethod[] }> = []; + for (const [groupName, groupNode] of Object.entries(node)) { + if (typeof groupNode === "object" && groupNode !== null) { + groups.push({ + groupName, + groupNode: groupNode as Record, + methods: collectRpcMethods(groupNode as Record), + }); + } + } + return groups; +} + +function clientHandlerInterfaceName(groupName: string): string { + return `I${toPascalCase(groupName)}Handler`; +} + +function clientHandlerMethodName(rpcMethod: string): string { + const parts = rpcMethod.split("."); + return `${toPascalCase(parts[parts.length - 1])}Async`; +} + +function emitClientSessionApiRegistration(clientSchema: Record, classes: string[]): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const { methods } of groups) { + for (const method of methods) { + if (method.result) { + const resultClass = emitRpcClass(resultTypeName(method.rpcMethod), method.result, "public", classes); + if (resultClass) classes.push(resultClass); + } + + if (method.params?.properties && Object.keys(method.params.properties).length > 0) { + const paramsClass = emitRpcClass(paramsTypeName(method.rpcMethod), method.params, "public", classes); + if (paramsClass) classes.push(paramsClass); + } + } + } + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + lines.push(`/// Handles \`${groupName}\` client session API methods.`); + if (groupExperimental) { + lines.push(`[Experimental(Diagnostics.Experimental)]`); + } + lines.push(`public interface ${interfaceName}`); + lines.push(`{`); + for (const method of methods) { + const hasParams = method.params?.properties && Object.keys(method.params.properties).length > 0; + const taskType = method.result ? `Task<${resultTypeName(method.rpcMethod)}>` : "Task"; + lines.push(` /// Handles "${method.rpcMethod}".`); + if (method.stability === "experimental" && !groupExperimental) { + lines.push(` [Experimental(Diagnostics.Experimental)]`); + } + if (hasParams) { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(${paramsTypeName(method.rpcMethod)} request, CancellationToken cancellationToken = default);`); + } else { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(CancellationToken cancellationToken = default);`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/// Provides all client session API handler groups for a session.`); + lines.push(`public class ClientSessionApiHandlers`); + lines.push(`{`); + for (const { groupName } of groups) { + lines.push(` /// Optional handler for ${toPascalCase(groupName)} client session API methods.`); + lines.push(` public ${clientHandlerInterfaceName(groupName)}? ${toPascalCase(groupName)} { get; set; }`); + lines.push(""); + } + if (lines[lines.length - 1] === "") lines.pop(); + lines.push(`}`); + lines.push(""); + + lines.push(`/// Registers client session API handlers on a JSON-RPC connection.`); + lines.push(`public static class ClientSessionApiRegistration`); + lines.push(`{`); + lines.push(` /// `); + lines.push(` /// Registers handlers for server-to-client session API calls.`); + lines.push(` /// Each incoming call includes a sessionId in its params object,`); + lines.push(` /// which is used to resolve the session's handler group.`); + lines.push(` /// `); + lines.push(` public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func getHandlers)`); + lines.push(` {`); + for (const { groupName, methods } of groups) { + for (const method of methods) { + const handlerProperty = toPascalCase(groupName); + const handlerMethod = clientHandlerMethodName(method.rpcMethod); + const hasParams = method.params?.properties && Object.keys(method.params.properties).length > 0; + const paramsClass = paramsTypeName(method.rpcMethod); + const taskType = method.result ? `Task<${resultTypeName(method.rpcMethod)}>` : "Task"; + const registrationVar = `register${typeToClassName(method.rpcMethod)}Method`; + + if (hasParams) { + lines.push(` var ${registrationVar} = (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); + lines.push(` {`); + lines.push(` var handler = getHandlers(request.SessionId).${handlerProperty};`); + lines.push(` if (handler is null) throw new InvalidOperationException($"No ${groupName} handler registered for session: {request.SessionId}");`); + if (method.result) { + lines.push(` return await handler.${handlerMethod}(request, cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(request, cancellationToken);`); + } + lines.push(` });`); + lines.push(` rpc.AddLocalRpcMethod(${registrationVar}.Method, ${registrationVar}.Target!, new JsonRpcMethodAttribute("${method.rpcMethod}")`); + lines.push(` {`); + lines.push(` UseSingleObjectParameterDeserialization = true`); + lines.push(` });`); + } else { + lines.push(` rpc.AddLocalRpcMethod("${method.rpcMethod}", (Func)(_ =>`); + lines.push(` throw new InvalidOperationException("No params provided for ${method.rpcMethod}")));`); + } + } + } + lines.push(` }`); + lines.push(`}`); + + return lines; +} + function generateRpcCode(schema: ApiSchema): string { emittedRpcClasses.clear(); experimentalRpcTypes.clear(); @@ -937,6 +1073,9 @@ function generateRpcCode(schema: ApiSchema): string { let sessionRpcParts: string[] = []; if (schema.session) sessionRpcParts = emitSessionRpcClasses(schema.session, classes); + let clientSessionParts: string[] = []; + if (schema.clientSession) clientSessionParts = emitClientSessionApiRegistration(schema.clientSession, classes); + const lines: string[] = []; lines.push(`${COPYRIGHT} @@ -962,6 +1101,7 @@ internal static class Diagnostics for (const enumCode of rpcEnumOutput) lines.push(enumCode, ""); for (const part of serverRpcParts) lines.push(part, ""); for (const part of sessionRpcParts) lines.push(part, ""); + if (clientSessionParts.length > 0) lines.push(...clientSessionParts, ""); // Add JsonSerializerContext for AOT/trimming support const typeNames = [...emittedRpcClasses].sort(); diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 5c6a71b23..5f061fbd4 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -178,7 +178,11 @@ async function generateRpc(schemaPath?: string): Promise { const resolvedPath = schemaPath ?? (await getApiSchemaPath()); const schema = JSON.parse(await fs.readFile(resolvedPath, "utf-8")) as ApiSchema; - const allMethods = [...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {})]; + const allMethods = [ + ...collectRpcMethods(schema.server || {}), + ...collectRpcMethods(schema.session || {}), + ...collectRpcMethods(schema.clientSession || {}), + ]; // Build a combined schema for quicktype - prefix types to avoid conflicts const combinedSchema: JSONSchema7 = { @@ -271,11 +275,16 @@ async function generateRpc(schemaPath?: string): Promise { lines.push(``); lines.push(`package rpc`); lines.push(``); + const imports = [`"context"`, `"encoding/json"`]; + if (schema.clientSession) { + imports.push(`"errors"`, `"fmt"`); + } + imports.push(`"github.com/github/copilot-sdk/go/internal/jsonrpc2"`); + lines.push(`import (`); - lines.push(`\t"context"`); - lines.push(`\t"encoding/json"`); - lines.push(``); - lines.push(`\t"github.com/github/copilot-sdk/go/internal/jsonrpc2"`); + for (const imp of imports) { + lines.push(`\t${imp}`); + } lines.push(`)`); lines.push(``); @@ -292,6 +301,10 @@ async function generateRpc(schemaPath?: string): Promise { emitRpcWrapper(lines, schema.session, true, resolveType, fieldNames); } + if (schema.clientSession) { + emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); + } + const outPath = await writeGeneratedFile("go/rpc/generated_rpc.go", lines.join("\n")); console.log(` ✓ ${outPath}`); @@ -430,6 +443,118 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc lines.push(``); } +interface ClientGroup { + groupName: string; + groupNode: Record; + methods: RpcMethod[]; +} + +function collectClientGroups(node: Record): ClientGroup[] { + const groups: ClientGroup[] = []; + for (const [groupName, groupNode] of Object.entries(node)) { + if (typeof groupNode === "object" && groupNode !== null) { + groups.push({ + groupName, + groupNode: groupNode as Record, + methods: collectRpcMethods(groupNode as Record), + }); + } + } + return groups; +} + +function clientHandlerInterfaceName(groupName: string): string { + return `${toPascalCase(groupName)}Handler`; +} + +function clientHandlerMethodName(rpcMethod: string): string { + return toPascalCase(rpcMethod.split(".").at(-1)!); +} + +function emitClientSessionApiRegistration(lines: string[], clientSchema: Record, resolveType: (name: string) => string): void { + const groups = collectClientGroups(clientSchema); + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + if (groupExperimental) { + lines.push(`// Experimental: ${interfaceName} contains experimental APIs that may change or be removed.`); + } + lines.push(`type ${interfaceName} interface {`); + for (const method of methods) { + if (method.stability === "experimental" && !groupExperimental) { + lines.push(`\t// Experimental: ${clientHandlerMethodName(method.rpcMethod)} is an experimental API and may change or be removed in future versions.`); + } + const paramsType = resolveType(toPascalCase(method.rpcMethod) + "Params"); + if (method.result) { + const resultType = resolveType(toPascalCase(method.rpcMethod) + "Result"); + lines.push(`\t${clientHandlerMethodName(method.rpcMethod)}(request *${paramsType}) (*${resultType}, error)`); + } else { + lines.push(`\t${clientHandlerMethodName(method.rpcMethod)}(request *${paramsType}) error`); + } + } + lines.push(`}`); + lines.push(``); + } + + lines.push(`// ClientSessionApiHandlers provides all client session API handler groups for a session.`); + lines.push(`type ClientSessionApiHandlers struct {`); + for (const { groupName } of groups) { + lines.push(`\t${toPascalCase(groupName)} ${clientHandlerInterfaceName(groupName)}`); + } + lines.push(`}`); + lines.push(``); + + lines.push(`func clientSessionHandlerError(err error) *jsonrpc2.Error {`); + lines.push(`\tif err == nil {`); + lines.push(`\t\treturn nil`); + lines.push(`\t}`); + lines.push(`\tvar rpcErr *jsonrpc2.Error`); + lines.push(`\tif errors.As(err, &rpcErr) {`); + lines.push(`\t\treturn rpcErr`); + lines.push(`\t}`); + lines.push(`\treturn &jsonrpc2.Error{Code: -32603, Message: err.Error()}`); + lines.push(`}`); + lines.push(``); + + lines.push(`// RegisterClientSessionApiHandlers registers handlers for server-to-client session API calls.`); + lines.push(`func RegisterClientSessionApiHandlers(client *jsonrpc2.Client, getHandlers func(sessionID string) *ClientSessionApiHandlers) {`); + for (const { groupName, methods } of groups) { + const handlerField = toPascalCase(groupName); + for (const method of methods) { + const paramsType = resolveType(toPascalCase(method.rpcMethod) + "Params"); + lines.push(`\tclient.SetRequestHandler("${method.rpcMethod}", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {`); + lines.push(`\t\tvar request ${paramsType}`); + lines.push(`\t\tif err := json.Unmarshal(params, &request); err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\thandlers := getHandlers(request.SessionID)`); + lines.push(`\t\tif handlers == nil || handlers.${handlerField} == nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("No ${groupName} handler registered for session: %s", request.SessionID)}`); + lines.push(`\t\t}`); + if (method.result) { + lines.push(`\t\tresult, err := handlers.${handlerField}.${clientHandlerMethodName(method.rpcMethod)}(&request)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, clientSessionHandlerError(err)`); + lines.push(`\t\t}`); + lines.push(`\t\traw, err := json.Marshal(result)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\treturn raw, nil`); + } else { + lines.push(`\t\tif err := handlers.${handlerField}.${clientHandlerMethodName(method.rpcMethod)}(&request); err != nil {`); + lines.push(`\t\t\treturn nil, clientSessionHandlerError(err)`); + lines.push(`\t\t}`); + lines.push(`\t\treturn json.RawMessage("null"), nil`); + } + lines.push(`\t})`); + } + } + lines.push(`}`); + lines.push(``); +} + // ── Main ──────────────────────────────────────────────────────────────────── async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 71e44943f..2aa593c5d 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -208,7 +208,11 @@ async function generateRpc(schemaPath?: string): Promise { const resolvedPath = schemaPath ?? (await getApiSchemaPath()); const schema = JSON.parse(await fs.readFile(resolvedPath, "utf-8")) as ApiSchema; - const allMethods = [...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {})]; + const allMethods = [ + ...collectRpcMethods(schema.server || {}), + ...collectRpcMethods(schema.session || {}), + ...collectRpcMethods(schema.clientSession || {}), + ]; // Build a combined schema for quicktype const combinedSchema: JSONSchema7 = { @@ -302,6 +306,10 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from .._jsonrpc import JsonRpcClient +from collections.abc import Callable +from dataclasses import dataclass +from typing import Protocol + `); lines.push(typesCode); lines.push(` @@ -320,6 +328,9 @@ def _timeout_kwargs(timeout: float | None) -> dict: if (schema.session) { emitRpcWrapper(lines, schema.session, true, resolveType); } + if (schema.clientSession) { + emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); + } const outPath = await writeGeneratedFile("python/copilot/generated/rpc.py", lines.join("\n")); console.log(` ✓ ${outPath}`); @@ -429,6 +440,107 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: lines.push(``); } +function emitClientSessionApiRegistration( + lines: string[], + node: Record, + resolveType: (name: string) => string +): void { + const groups = Object.entries(node).filter(([, value]) => typeof value === "object" && value !== null && !isRpcMethod(value)); + + for (const [groupName, groupNode] of groups) { + const handlerName = `${toPascalCase(groupName)}Handler`; + const groupExperimental = isNodeFullyExperimental(groupNode as Record); + if (groupExperimental) { + lines.push(`# Experimental: this API group is experimental and may change or be removed.`); + } + lines.push(`class ${handlerName}(Protocol):`); + for (const [methodName, value] of Object.entries(groupNode as Record)) { + if (!isRpcMethod(value)) continue; + emitClientSessionHandlerMethod(lines, methodName, value, resolveType, groupExperimental); + } + lines.push(``); + } + + lines.push(`@dataclass`); + lines.push(`class ClientSessionApiHandlers:`); + if (groups.length === 0) { + lines.push(` pass`); + } else { + for (const [groupName] of groups) { + lines.push(` ${toSnakeCase(groupName)}: ${toPascalCase(groupName)}Handler | None = None`); + } + } + lines.push(``); + + lines.push(`def register_client_session_api_handlers(`); + lines.push(` client: "JsonRpcClient",`); + lines.push(` get_handlers: Callable[[str], ClientSessionApiHandlers],`); + lines.push(`) -> None:`); + lines.push(` """Register client-session request handlers on a JSON-RPC connection."""`); + if (groups.length === 0) { + lines.push(` return`); + } else { + for (const [groupName, groupNode] of groups) { + for (const [methodName, value] of Object.entries(groupNode as Record)) { + if (!isRpcMethod(value)) continue; + emitClientSessionRegistrationMethod( + lines, + groupName, + methodName, + value, + resolveType + ); + } + } + } + lines.push(``); +} + +function emitClientSessionHandlerMethod( + lines: string[], + name: string, + method: RpcMethod, + resolveType: (name: string) => string, + groupExperimental = false +): void { + const paramsType = resolveType(toPascalCase(method.rpcMethod) + "Params"); + const resultType = method.result ? resolveType(toPascalCase(method.rpcMethod) + "Result") : "None"; + lines.push(` async def ${toSnakeCase(name)}(self, params: ${paramsType}) -> ${resultType}:`); + if (method.stability === "experimental" && !groupExperimental) { + lines.push(` """.. warning:: This API is experimental and may change or be removed in future versions."""`); + } + lines.push(` pass`); +} + +function emitClientSessionRegistrationMethod( + lines: string[], + groupName: string, + methodName: string, + method: RpcMethod, + resolveType: (name: string) => string +): void { + const handlerVariableName = `handle_${toSnakeCase(groupName)}_${toSnakeCase(methodName)}`; + const paramsType = resolveType(toPascalCase(method.rpcMethod) + "Params"); + const resultType = method.result ? resolveType(toPascalCase(method.rpcMethod) + "Result") : null; + const handlerField = toSnakeCase(groupName); + const handlerMethod = toSnakeCase(methodName); + + lines.push(` async def ${handlerVariableName}(params: dict) -> dict | None:`); + lines.push(` request = ${paramsType}.from_dict(params)`); + lines.push(` handler = get_handlers(request.session_id).${handlerField}`); + lines.push( + ` if handler is None: raise RuntimeError(f"No ${handlerField} handler registered for session: {request.session_id}")` + ); + if (resultType) { + lines.push(` result = await handler.${handlerMethod}(request)`); + lines.push(` return result.to_dict()`); + } else { + lines.push(` await handler.${handlerMethod}(request)`); + lines.push(` return None`); + } + lines.push(` client.set_request_handler("${method.rpcMethod}", ${handlerVariableName})`); +} + // ── Main ──────────────────────────────────────────────────────────────────── async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise {