diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 07502ee2d..a15644a39 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -73,6 +73,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private readonly string? _optionsHost; private int? _actualPort; private int? _negotiatedProtocolVersion; + private readonly SessionFsConfig? _sessionFsConfig; private List? _modelsCache; private readonly SemaphoreSlim _modelsCacheLock = new(1, 1); private readonly Func>>? _onListModels; @@ -143,6 +144,7 @@ public CopilotClient(CopilotClientOptions? options = null) _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; + _sessionFsConfig = _options.SessionFs; // Parse CliUrl if provided if (!string.IsNullOrEmpty(_options.CliUrl)) @@ -227,6 +229,12 @@ async Task StartCoreAsync(CancellationToken ct) // Verify protocol version compatibility await VerifyProtocolVersionAsync(connection, ct); + // Register sessionFs provider if configured + if (_sessionFsConfig is not null) + { + await RegisterSessionFsProviderAsync(connection, ct); + } + _logger.LogInformation("Copilot client connected"); return connection; } @@ -462,6 +470,18 @@ public async Task CreateSessionAsync(SessionConfig config, Cance { session.RegisterUserInputHandler(config.OnUserInputRequest); } + if (_sessionFsConfig is not null) + { + if (config.CreateSessionFsHandler is not null) + { + session.RegisterSessionFsHandler(config.CreateSessionFsHandler(session)); + } + else + { + throw new InvalidOperationException( + "CreateSessionFsHandler is required in session config when SessionFs is enabled in client options."); + } + } if (config.Hooks != null) { session.RegisterHooks(config.Hooks); @@ -582,6 +602,18 @@ public async Task ResumeSessionAsync(string sessionId, ResumeSes { session.RegisterUserInputHandler(config.OnUserInputRequest); } + if (_sessionFsConfig is not null) + { + if (config.CreateSessionFsHandler is not null) + { + session.RegisterSessionFsHandler(config.CreateSessionFsHandler(session)); + } + else + { + throw new InvalidOperationException( + "CreateSessionFsHandler is required in session config when SessionFs is enabled in client options."); + } + } if (config.Hooks != null) { session.RegisterHooks(config.Hooks); @@ -1104,6 +1136,20 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio _negotiatedProtocolVersion = serverVersion; } + private async Task RegisterSessionFsProviderAsync(Connection connection, CancellationToken cancellationToken) + { + var config = _sessionFsConfig!; + await _rpc!.SessionFs.SetProviderAsync( + config.InitialCwd, + config.SessionStatePath, + config.Conventions switch + { + SessionFsConventions.Windows => SessionFsSetProviderRequestConventions.Windows, + _ => SessionFsSetProviderRequestConventions.Posix, + }, + cancellationToken); + } + private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) { // Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback @@ -1319,6 +1365,19 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? rpc.AddLocalRpcMethod("userInput.request", handler.OnUserInputRequest); rpc.AddLocalRpcMethod("hooks.invoke", handler.OnHooksInvoke); rpc.AddLocalRpcMethod("systemMessage.transform", handler.OnSystemMessageTransform); + + // SessionFs client session API handlers + rpc.AddLocalRpcMethod("sessionFs.readFile", handler.OnSessionFsReadFile); + rpc.AddLocalRpcMethod("sessionFs.writeFile", handler.OnSessionFsWriteFile); + rpc.AddLocalRpcMethod("sessionFs.appendFile", handler.OnSessionFsAppendFile); + rpc.AddLocalRpcMethod("sessionFs.exists", handler.OnSessionFsExists); + rpc.AddLocalRpcMethod("sessionFs.stat", handler.OnSessionFsStat); + rpc.AddLocalRpcMethod("sessionFs.mkdir", handler.OnSessionFsMkdir); + rpc.AddLocalRpcMethod("sessionFs.readdir", handler.OnSessionFsReaddir); + rpc.AddLocalRpcMethod("sessionFs.readdirWithTypes", handler.OnSessionFsReaddirWithTypes); + rpc.AddLocalRpcMethod("sessionFs.rm", handler.OnSessionFsRm); + rpc.AddLocalRpcMethod("sessionFs.rename", handler.OnSessionFsRename); + rpc.StartListening(); // Transition state to Disconnected if the JSON-RPC connection drops @@ -1554,6 +1613,67 @@ public async Task OnPermissionRequestV2(string sess }); } } + + // SessionFs handler methods + public async Task OnSessionFsReadFile(string sessionId, string path) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return await session.HandleSessionFsReadFileAsync(new SessionFsReadFileParams { SessionId = sessionId, Path = path }); + } + + public async Task OnSessionFsWriteFile(string sessionId, string path, string content, int? mode = null) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + await session.HandleSessionFsWriteFileAsync(new SessionFsWriteFileParams { SessionId = sessionId, Path = path, Content = content, Mode = mode }); + } + + public async Task OnSessionFsAppendFile(string sessionId, string path, string content, int? mode = null) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + await session.HandleSessionFsAppendFileAsync(new SessionFsAppendFileParams { SessionId = sessionId, Path = path, Content = content, Mode = mode }); + } + + public async Task OnSessionFsExists(string sessionId, string path) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return await session.HandleSessionFsExistsAsync(new SessionFsExistsParams { SessionId = sessionId, Path = path }); + } + + public async Task OnSessionFsStat(string sessionId, string path) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return await session.HandleSessionFsStatAsync(new SessionFsStatParams { SessionId = sessionId, Path = path }); + } + + public async Task OnSessionFsMkdir(string sessionId, string path, bool? recursive = null, int? mode = null) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + await session.HandleSessionFsMkdirAsync(new SessionFsMkdirParams { SessionId = sessionId, Path = path, Recursive = recursive, Mode = mode }); + } + + public async Task OnSessionFsReaddir(string sessionId, string path) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return await session.HandleSessionFsReaddirAsync(new SessionFsReaddirParams { SessionId = sessionId, Path = path }); + } + + public async Task OnSessionFsReaddirWithTypes(string sessionId, string path) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + return await session.HandleSessionFsReaddirWithTypesAsync(new SessionFsReaddirWithTypesParams { SessionId = sessionId, Path = path }); + } + + public async Task OnSessionFsRm(string sessionId, string path, bool? recursive = null, bool? force = null) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + await session.HandleSessionFsRmAsync(new SessionFsRmParams { SessionId = sessionId, Path = path, Recursive = recursive, Force = force }); + } + + public async Task OnSessionFsRename(string sessionId, string src, string dest) + { + var session = client.GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); + await session.HandleSessionFsRenameAsync(new SessionFsRenameParams { SessionId = sessionId, Src = src, Dest = dest }); + } } private class Connection( diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 6d0a78d4c..2d26ba414 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -63,6 +63,7 @@ public sealed partial class CopilotSession : IAsyncDisposable private volatile PermissionRequestHandler? _permissionHandler; private volatile UserInputHandler? _userInputHandler; private volatile ElicitationHandler? _elicitationHandler; + private volatile ISessionFsHandler? _sessionFsHandler; private ImmutableArray _eventHandlers = ImmutableArray.Empty; private SessionHooks? _hooks; @@ -664,6 +665,29 @@ internal void RegisterElicitationHandler(ElicitationHandler? handler) _elicitationHandler = handler; } + /// + /// Registers a session filesystem handler for this session. + /// + /// The handler to invoke for filesystem operations. + internal void RegisterSessionFsHandler(ISessionFsHandler handler) + { + _sessionFsHandler = handler; + } + + internal ISessionFsHandler GetSessionFsHandler() => + _sessionFsHandler ?? throw new InvalidOperationException($"No sessionFs handler registered for session: {SessionId}"); + + internal Task HandleSessionFsReadFileAsync(SessionFsReadFileParams request) => GetSessionFsHandler().ReadFileAsync(request); + internal Task HandleSessionFsWriteFileAsync(SessionFsWriteFileParams request) => GetSessionFsHandler().WriteFileAsync(request); + internal Task HandleSessionFsAppendFileAsync(SessionFsAppendFileParams request) => GetSessionFsHandler().AppendFileAsync(request); + internal Task HandleSessionFsExistsAsync(SessionFsExistsParams request) => GetSessionFsHandler().ExistsAsync(request); + internal Task HandleSessionFsStatAsync(SessionFsStatParams request) => GetSessionFsHandler().StatAsync(request); + internal Task HandleSessionFsMkdirAsync(SessionFsMkdirParams request) => GetSessionFsHandler().MkdirAsync(request); + internal Task HandleSessionFsReaddirAsync(SessionFsReaddirParams request) => GetSessionFsHandler().ReaddirAsync(request); + internal Task HandleSessionFsReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request) => GetSessionFsHandler().ReaddirWithTypesAsync(request); + internal Task HandleSessionFsRmAsync(SessionFsRmParams request) => GetSessionFsHandler().RmAsync(request); + internal Task HandleSessionFsRenameAsync(SessionFsRenameParams request) => GetSessionFsHandler().RenameAsync(request); + /// /// Sets the capabilities reported by the host for this session. /// diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 265781bac..021cb918a 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -64,6 +64,7 @@ protected CopilotClientOptions(CopilotClientOptions? other) Logger = other.Logger; LogLevel = other.LogLevel; Port = other.Port; + SessionFs = other.SessionFs; Telemetry = other.Telemetry; UseLoggedInUser = other.UseLoggedInUser; UseStdio = other.UseStdio; @@ -150,6 +151,14 @@ public string? GithubToken /// public Func>>? OnListModels { get; set; } + /// + /// Custom session filesystem provider. + /// When provided, the client registers as the session filesystem provider + /// on connection, routing all session scoped file I/O through callbacks + /// instead of the server's default local filesystem storage. + /// + public SessionFsConfig? SessionFs { get; set; } + /// /// OpenTelemetry configuration for the CLI server. /// When set to a non- instance, the CLI server is started with OpenTelemetry instrumentation enabled. @@ -1579,6 +1588,7 @@ protected SessionConfig(SessionConfig? other) ? new Dictionary(other.McpServers, other.McpServers.Comparer) : null; Model = other.Model; + CreateSessionFsHandler = other.CreateSessionFsHandler; ModelCapabilities = other.ModelCapabilities; OnElicitationRequest = other.OnElicitationRequest; OnEvent = other.OnEvent; @@ -1732,11 +1742,17 @@ protected SessionConfig(SessionConfig? other) /// /// Equivalent to calling immediately /// after creation, but executes earlier in the lifecycle so no events are missed. - /// Using this property rather than guarantees that early events emitted + /// Using this property rather than guarantees that early events emitted /// by the CLI during session creation (e.g. session.start) are delivered to the handler. /// public SessionEventHandler? OnEvent { get; set; } + /// + /// Factory that creates a session filesystem handler for this session. + /// Required when is configured. + /// + public Func? CreateSessionFsHandler { get; set; } + /// /// Creates a shallow clone of this instance. /// @@ -1778,6 +1794,7 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) CustomAgents = other.CustomAgents is not null ? [.. other.CustomAgents] : null; Agent = other.Agent; DisabledSkills = other.DisabledSkills is not null ? [.. other.DisabledSkills] : null; + CreateSessionFsHandler = other.CreateSessionFsHandler; DisableResume = other.DisableResume; ExcludedTools = other.ExcludedTools is not null ? [.. other.ExcludedTools] : null; Hooks = other.Hooks; @@ -1941,6 +1958,12 @@ protected ResumeSessionConfig(ResumeSessionConfig? other) /// public SessionEventHandler? OnEvent { get; set; } + /// + /// Factory that creates a session filesystem handler for this session. + /// Required when is configured. + /// + public Func? CreateSessionFsHandler { get; set; } + /// /// Creates a shallow clone of this instance. /// @@ -2435,6 +2458,277 @@ public class SystemMessageTransformRpcResponse public Dictionary? Sections { get; set; } } +/// +/// Connection level configuration for the session filesystem provider. +/// When set on , the client registers +/// as the session filesystem provider on connection, routing all session scoped +/// file I/O through callbacks instead of the +/// server's default local filesystem storage. +/// +public class SessionFsConfig +{ + /// + /// Initial working directory for sessions (user's project directory). + /// + public string InitialCwd { get; set; } = string.Empty; + + /// + /// Path within each session's SessionFs where the runtime stores + /// session scoped files (events, workspace, checkpoints, etc.). + /// + public string SessionStatePath { get; set; } = string.Empty; + + /// + /// Path conventions used by this filesystem provider. + /// + public SessionFsConventions Conventions { get; set; } +} + +/// +/// Path conventions used by a session filesystem provider. +/// +public enum SessionFsConventions +{ + /// POSIX-style paths (forward slashes). + Posix, + /// Windows-style paths (backslashes). + Windows, +} + +/// +/// Handler interface for session filesystem operations. +/// Implement this interface to provide a custom virtual filesystem for session data. +/// +public interface ISessionFsHandler +{ + /// Reads the contents of a file. + Task ReadFileAsync(SessionFsReadFileParams request, CancellationToken cancellationToken = default); + /// Writes content to a file, creating it if it does not exist. + Task WriteFileAsync(SessionFsWriteFileParams request, CancellationToken cancellationToken = default); + /// Appends content to a file, creating it if it does not exist. + Task AppendFileAsync(SessionFsAppendFileParams request, CancellationToken cancellationToken = default); + /// Checks whether a path exists. + Task ExistsAsync(SessionFsExistsParams request, CancellationToken cancellationToken = default); + /// Returns metadata about a file or directory. + Task StatAsync(SessionFsStatParams request, CancellationToken cancellationToken = default); + /// Creates a directory, optionally creating parent directories. + Task MkdirAsync(SessionFsMkdirParams request, CancellationToken cancellationToken = default); + /// Lists entries in a directory. + Task ReaddirAsync(SessionFsReaddirParams request, CancellationToken cancellationToken = default); + /// Lists entries in a directory with type information. + Task ReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request, CancellationToken cancellationToken = default); + /// Removes a file or directory. + Task RmAsync(SessionFsRmParams request, CancellationToken cancellationToken = default); + /// Renames (moves) a file or directory. + Task RenameAsync(SessionFsRenameParams request, CancellationToken cancellationToken = default); +} + +/// Parameters for a sessionFs.readFile request. +public class SessionFsReadFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// Result of a sessionFs.readFile request. +public class SessionFsReadFileResult +{ + /// File content as UTF-8 string. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; +} + +/// Parameters for a sessionFs.writeFile request. +public class SessionFsWriteFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + /// Content to write. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + /// Optional POSIX-style mode for newly created files. + [JsonPropertyName("mode")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? Mode { get; set; } +} + +/// Parameters for a sessionFs.appendFile request. +public class SessionFsAppendFileParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + /// Content to append. + [JsonPropertyName("content")] + public string Content { get; set; } = string.Empty; + /// Optional POSIX-style mode for newly created files. + [JsonPropertyName("mode")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? Mode { get; set; } +} + +/// Parameters for a sessionFs.exists request. +public class SessionFsExistsParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// Result of a sessionFs.exists request. +public class SessionFsExistsResult +{ + /// Whether the path exists. + [JsonPropertyName("exists")] + public bool Exists { get; set; } +} + +/// Parameters for a sessionFs.stat request. +public class SessionFsStatParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// Result of a sessionFs.stat request. +public class SessionFsStatResult +{ + /// Whether the path is a file. + [JsonPropertyName("isFile")] + public bool IsFile { get; set; } + /// Whether the path is a directory. + [JsonPropertyName("isDirectory")] + public bool IsDirectory { get; set; } + /// File size in bytes. + [JsonPropertyName("size")] + public long Size { get; set; } + /// ISO 8601 timestamp of last modification. + [JsonPropertyName("mtime")] + public string Mtime { get; set; } = string.Empty; + /// ISO 8601 timestamp of creation. + [JsonPropertyName("birthtime")] + public string Birthtime { get; set; } = string.Empty; +} + +/// Parameters for a sessionFs.mkdir request. +public class SessionFsMkdirParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + /// Whether to create parent directories. + [JsonPropertyName("recursive")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public bool? Recursive { get; set; } + /// Optional POSIX-style mode for the directory. + [JsonPropertyName("mode")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? Mode { get; set; } +} + +/// Parameters for a sessionFs.readdir request. +public class SessionFsReaddirParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// Result of a sessionFs.readdir request. +public class SessionFsReaddirResult +{ + /// Entry names in the directory. + [JsonPropertyName("entries")] + public List Entries { get; set; } = []; +} + +/// Parameters for a sessionFs.readdirWithTypes request. +public class SessionFsReaddirWithTypesParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; +} + +/// A directory entry with type information. +public class SessionFsDirEntry +{ + /// Entry name. + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + /// Entry type: "file" or "directory". + [JsonPropertyName("type")] + public string Type { get; set; } = string.Empty; +} + +/// Result of a sessionFs.readdirWithTypes request. +public class SessionFsReaddirWithTypesResult +{ + /// Entries with type information. + [JsonPropertyName("entries")] + public List Entries { get; set; } = []; +} + +/// Parameters for a sessionFs.rm request. +public class SessionFsRmParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Path using SessionFs conventions. + [JsonPropertyName("path")] + public string Path { get; set; } = string.Empty; + /// Whether to remove directories recursively. + [JsonPropertyName("recursive")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public bool? Recursive { get; set; } + /// Whether to ignore errors if the path does not exist. + [JsonPropertyName("force")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public bool? Force { get; set; } +} + +/// Parameters for a sessionFs.rename request. +public class SessionFsRenameParams +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; + /// Source path. + [JsonPropertyName("src")] + public string Src { get; set; } = string.Empty; + /// Destination path. + [JsonPropertyName("dest")] + public string Dest { get; set; } = string.Empty; +} + [JsonSourceGenerationOptions( JsonSerializerDefaults.Web, AllowOutOfOrderMetadataProperties = true, @@ -2462,6 +2756,22 @@ public class SystemMessageTransformRpcResponse [JsonSerializable(typeof(PingResponse))] [JsonSerializable(typeof(ProviderConfig))] [JsonSerializable(typeof(SessionContext))] +[JsonSerializable(typeof(SessionFsAppendFileParams))] +[JsonSerializable(typeof(SessionFsDirEntry))] +[JsonSerializable(typeof(SessionFsExistsParams))] +[JsonSerializable(typeof(SessionFsExistsResult))] +[JsonSerializable(typeof(SessionFsMkdirParams))] +[JsonSerializable(typeof(SessionFsReadFileParams))] +[JsonSerializable(typeof(SessionFsReadFileResult))] +[JsonSerializable(typeof(SessionFsReaddirParams))] +[JsonSerializable(typeof(SessionFsReaddirResult))] +[JsonSerializable(typeof(SessionFsReaddirWithTypesParams))] +[JsonSerializable(typeof(SessionFsReaddirWithTypesResult))] +[JsonSerializable(typeof(SessionFsRenameParams))] +[JsonSerializable(typeof(SessionFsRmParams))] +[JsonSerializable(typeof(SessionFsStatParams))] +[JsonSerializable(typeof(SessionFsStatResult))] +[JsonSerializable(typeof(SessionFsWriteFileParams))] [JsonSerializable(typeof(SessionLifecycleEvent))] [JsonSerializable(typeof(SessionLifecycleEventMetadata))] [JsonSerializable(typeof(SessionListFilter))] diff --git a/dotnet/test/SessionFsE2ETests.cs b/dotnet/test/SessionFsE2ETests.cs new file mode 100644 index 000000000..dbd17815b --- /dev/null +++ b/dotnet/test/SessionFsE2ETests.cs @@ -0,0 +1,379 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; +using Xunit.Abstractions; + +namespace GitHub.Copilot.SDK.Test; + +/// +/// Custom fixture that creates a CopilotClient with SessionFs enabled. +/// +public class SessionFsE2EFixture : IAsyncLifetime +{ + public E2ETestContext Ctx { get; private set; } = null!; + public CopilotClient Client { get; private set; } = null!; + + public async Task InitializeAsync() + { + Ctx = await E2ETestContext.CreateAsync(); + Client = new CopilotClient(new CopilotClientOptions + { + Cwd = Ctx.WorkDir, + CliPath = Environment.GetEnvironmentVariable("COPILOT_CLI_PATH") + ?? Path.Combine(FindRepoRoot(), "nodejs/node_modules/@github/copilot/index.js"), + Environment = Ctx.GetEnvironment(), + UseStdio = true, + GitHubToken = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("GITHUB_ACTIONS")) + ? "fake-token-for-e2e-tests" + : null, + SessionFs = new SessionFsConfig + { + InitialCwd = "/", + SessionStatePath = "/session-state", + Conventions = SessionFsConventions.Posix, + }, + }); + } + + public async Task DisposeAsync() + { + if (Client is not null) await Client.ForceStopAsync(); + await Ctx.DisposeAsync(); + } + + private static string FindRepoRoot() + { + var dir = new DirectoryInfo(AppContext.BaseDirectory); + while (dir != null) + { + if (Directory.Exists(Path.Combine(dir.FullName, "nodejs"))) + return dir.FullName; + dir = dir.Parent; + } + throw new InvalidOperationException("Could not find repository root"); + } +} + +/// +/// In memory filesystem implementation for session filesystem E2E tests. +/// +internal class InMemorySessionFsHandler : ISessionFsHandler +{ + private readonly string _sessionId; + private readonly InMemoryFileSystem _fs; + + public InMemorySessionFsHandler(string sessionId, InMemoryFileSystem fs) + { + _sessionId = sessionId; + _fs = fs; + } + + private string Sp(string path) => $"/{_sessionId}{(path.StartsWith('/') ? path : "/" + path)}"; + + public Task ReadFileAsync(SessionFsReadFileParams request, CancellationToken ct = default) + => Task.FromResult(new SessionFsReadFileResult { Content = _fs.ReadFile(Sp(request.Path)) }); + + public Task WriteFileAsync(SessionFsWriteFileParams request, CancellationToken ct = default) + { + _fs.WriteFile(Sp(request.Path), request.Content); + return Task.CompletedTask; + } + + public Task AppendFileAsync(SessionFsAppendFileParams request, CancellationToken ct = default) + { + _fs.AppendFile(Sp(request.Path), request.Content); + return Task.CompletedTask; + } + + public Task ExistsAsync(SessionFsExistsParams request, CancellationToken ct = default) + => Task.FromResult(new SessionFsExistsResult { Exists = _fs.Exists(Sp(request.Path)) }); + + public Task StatAsync(SessionFsStatParams request, CancellationToken ct = default) + { + var (isFile, size, mtime) = _fs.Stat(Sp(request.Path)); + return Task.FromResult(new SessionFsStatResult + { + IsFile = isFile, + IsDirectory = !isFile, + Size = size, + Mtime = mtime.ToString("o"), + Birthtime = mtime.ToString("o"), + }); + } + + public Task MkdirAsync(SessionFsMkdirParams request, CancellationToken ct = default) + { + _fs.Mkdir(Sp(request.Path), request.Recursive ?? false); + return Task.CompletedTask; + } + + public Task ReaddirAsync(SessionFsReaddirParams request, CancellationToken ct = default) + => Task.FromResult(new SessionFsReaddirResult { Entries = _fs.Readdir(Sp(request.Path)) }); + + public Task ReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request, CancellationToken ct = default) + { + var entries = _fs.ReaddirWithTypes(Sp(request.Path)); + return Task.FromResult(new SessionFsReaddirWithTypesResult + { + Entries = entries.Select(e => new SessionFsDirEntry { Name = e.Name, Type = e.IsDirectory ? "directory" : "file" }).ToList() + }); + } + + public Task RmAsync(SessionFsRmParams request, CancellationToken ct = default) + { + _fs.Remove(Sp(request.Path)); + return Task.CompletedTask; + } + + public Task RenameAsync(SessionFsRenameParams request, CancellationToken ct = default) + { + _fs.Rename(Sp(request.Src), Sp(request.Dest)); + return Task.CompletedTask; + } +} + +/// +/// Simple in memory filesystem for testing. Stores files as path to content entries. +/// Directories are inferred from file paths (mkdir is tracked separately). +/// +internal class InMemoryFileSystem +{ + private readonly Dictionary _files = new(); + private readonly HashSet _directories = new() { "/" }; + private readonly Dictionary _mtimes = new(); + private readonly object _lock = new(); + + public string ReadFile(string path) + { + lock (_lock) + { + if (!_files.TryGetValue(NormalizePath(path), out var content)) + throw new FileNotFoundException($"File not found: {path}"); + return content; + } + } + + public void WriteFile(string path, string content) + { + lock (_lock) + { + var p = NormalizePath(path); + EnsureParentDirs(p); + _files[p] = content; + _mtimes[p] = DateTime.UtcNow; + } + } + + public void AppendFile(string path, string content) + { + lock (_lock) + { + var p = NormalizePath(path); + EnsureParentDirs(p); + _files[p] = _files.TryGetValue(p, out var existing) ? existing + content : content; + _mtimes[p] = DateTime.UtcNow; + } + } + + public bool Exists(string path) + { + lock (_lock) + { + var p = NormalizePath(path); + return _files.ContainsKey(p) || _directories.Contains(p); + } + } + + public (bool IsFile, long Size, DateTime Mtime) Stat(string path) + { + lock (_lock) + { + var p = NormalizePath(path); + if (_files.TryGetValue(p, out var content)) + return (true, content.Length, _mtimes.GetValueOrDefault(p, DateTime.UtcNow)); + if (_directories.Contains(p)) + return (false, 0, DateTime.UtcNow); + throw new FileNotFoundException($"Path not found: {path}"); + } + } + + public void Mkdir(string path, bool recursive) + { + lock (_lock) + { + var p = NormalizePath(path); + if (recursive) + EnsureParentDirs(p + "/placeholder"); + _directories.Add(p); + } + } + + public List Readdir(string path) + { + lock (_lock) + { + var p = NormalizePath(path); + if (!p.EndsWith('/')) p += "/"; + var entries = new HashSet(); + foreach (var key in _files.Keys) + { + if (key.StartsWith(p) && key.Length > p.Length) + { + var rest = key[p.Length..]; + var slash = rest.IndexOf('/'); + entries.Add(slash >= 0 ? rest[..slash] : rest); + } + } + foreach (var dir in _directories) + { + if (dir.StartsWith(p) && dir.Length > p.Length) + { + var rest = dir[p.Length..]; + var slash = rest.IndexOf('/'); + entries.Add(slash >= 0 ? rest[..slash] : rest); + } + } + return entries.Order().ToList(); + } + } + + public List<(string Name, bool IsDirectory)> ReaddirWithTypes(string path) + { + lock (_lock) + { + var names = Readdir(path); + var p = NormalizePath(path); + if (!p.EndsWith('/')) p += "/"; + return names.Select(n => + { + var full = p + n; + var isDir = _directories.Contains(full) || _files.Keys.Any(k => k.StartsWith(full + "/")); + return (n, isDir); + }).ToList(); + } + } + + public void Remove(string path) + { + lock (_lock) + { + var p = NormalizePath(path); + _files.Remove(p); + _directories.Remove(p); + _mtimes.Remove(p); + } + } + + public void Rename(string src, string dest) + { + lock (_lock) + { + var s = NormalizePath(src); + var d = NormalizePath(dest); + if (_files.TryGetValue(s, out var content)) + { + _files.Remove(s); + EnsureParentDirs(d); + _files[d] = content; + _mtimes[d] = _mtimes.GetValueOrDefault(s, DateTime.UtcNow); + _mtimes.Remove(s); + } + } + } + + private static string NormalizePath(string path) + { + if (string.IsNullOrEmpty(path)) + return "/"; + + var normalized = path.TrimEnd('/'); + return normalized.Length == 0 ? "/" : normalized; + } + + private void EnsureParentDirs(string path) + { + var parts = path.Split('/'); + for (int i = 1; i < parts.Length - 1; i++) + { + var dir = string.Join("/", parts[..( i + 1)]); + _directories.Add(dir); + } + } +} + +public class SessionFsE2ETests(SessionFsE2EFixture fixture, ITestOutputHelper output) : IClassFixture, IAsyncLifetime +{ + private readonly SessionFsE2EFixture _fixture = fixture; + private readonly string _testName = GetTestName(output); + + private E2ETestContext Ctx => _fixture.Ctx; + private CopilotClient Client => _fixture.Client; + + // Shared in memory filesystem across tests in this class + private static readonly InMemoryFileSystem SharedFs = new(); + + private static string GetTestName(ITestOutputHelper output) + { + var type = output.GetType(); + var testField = type.GetField("test", System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.NonPublic); + var test = (ITest?)testField?.GetValue(output); + return test?.TestCase.TestMethod.Method.Name ?? throw new InvalidOperationException("Couldn't find test name"); + } + + public async Task InitializeAsync() + { + await Ctx.ConfigureForTestAsync("session_fs", _testName); + } + + public Task DisposeAsync() => Task.CompletedTask; + + [Fact] + public async Task Should_Route_File_Operations_Through_The_Session_Fs_Provider() + { + var session = await Client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new InMemorySessionFsHandler(s.SessionId, SharedFs), + }); + + var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 100 + 200?" }); + Assert.NotNull(msg); + Assert.Contains("300", msg!.Data.Content); + await session.DisposeAsync(); + + var content = SharedFs.ReadFile($"/{session.SessionId}/session-state/events.jsonl"); + Assert.Contains("300", content); + } + + [Fact] + public async Task Should_Load_Session_Data_From_Fs_Provider_On_Resume() + { + var session1 = await Client.CreateSessionAsync(new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new InMemorySessionFsHandler(s.SessionId, SharedFs), + }); + var sessionId = session1.SessionId; + + var msg = await session1.SendAndWaitAsync(new MessageOptions { Prompt = "What is 50 + 50?" }); + Assert.NotNull(msg); + Assert.Contains("100", msg!.Data.Content); + await session1.DisposeAsync(); + + Assert.True(SharedFs.Exists($"/{sessionId}/session-state/events.jsonl")); + + var session2 = await Client.ResumeSessionAsync(sessionId, new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = s => new InMemorySessionFsHandler(s.SessionId, SharedFs), + }); + + var msg2 = await session2.SendAndWaitAsync(new MessageOptions { Prompt = "What is that times 3?" }); + await session2.DisposeAsync(); + Assert.NotNull(msg2); + Assert.Contains("300", msg2!.Data.Content); + } +} diff --git a/dotnet/test/SessionFsTests.cs b/dotnet/test/SessionFsTests.cs new file mode 100644 index 000000000..68f0bbf23 --- /dev/null +++ b/dotnet/test/SessionFsTests.cs @@ -0,0 +1,121 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using Xunit; + +namespace GitHub.Copilot.SDK.Test; + +public class SessionFsTests +{ + [Fact] + public void SessionFsConfig_CanBeSetOnClientOptions() + { + var options = new CopilotClientOptions + { + SessionFs = new SessionFsConfig + { + InitialCwd = "/home/user/project", + SessionStatePath = "/session-state", + Conventions = SessionFsConventions.Posix, + } + }; + + Assert.NotNull(options.SessionFs); + Assert.Equal("/home/user/project", options.SessionFs.InitialCwd); + Assert.Equal("/session-state", options.SessionFs.SessionStatePath); + Assert.Equal(SessionFsConventions.Posix, options.SessionFs.Conventions); + } + + [Fact] + public void SessionFsConfig_CopiedInClone() + { + var original = new CopilotClientOptions + { + SessionFs = new SessionFsConfig + { + InitialCwd = "/", + SessionStatePath = "/state", + Conventions = SessionFsConventions.Windows, + } + }; + + var clone = original.Clone(); + + Assert.NotNull(clone.SessionFs); + Assert.Same(original.SessionFs, clone.SessionFs); + } + + [Fact] + public void SessionConfig_HasCreateSessionFsHandler() + { + var config = new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = _ => new TestSessionFsHandler(), + }; + + Assert.NotNull(config.CreateSessionFsHandler); + } + + [Fact] + public void ResumeSessionConfig_HasCreateSessionFsHandler() + { + var config = new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = _ => new TestSessionFsHandler(), + }; + + Assert.NotNull(config.CreateSessionFsHandler); + } + + [Fact] + public void CreateSessionFsHandler_CopiedInSessionConfigClone() + { + Func factory = _ => new TestSessionFsHandler(); + var original = new SessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = factory, + }; + + var clone = original.Clone(); + + Assert.Same(factory, clone.CreateSessionFsHandler); + } + + [Fact] + public void CreateSessionFsHandler_CopiedInResumeSessionConfigClone() + { + Func factory = _ => new TestSessionFsHandler(); + var original = new ResumeSessionConfig + { + OnPermissionRequest = PermissionHandler.ApproveAll, + CreateSessionFsHandler = factory, + }; + + var clone = original.Clone(); + + Assert.Same(factory, clone.CreateSessionFsHandler); + } + + private class TestSessionFsHandler : ISessionFsHandler + { + public Task ReadFileAsync(SessionFsReadFileParams request, CancellationToken cancellationToken = default) + => Task.FromResult(new SessionFsReadFileResult { Content = "" }); + public Task WriteFileAsync(SessionFsWriteFileParams request, CancellationToken cancellationToken = default) => Task.CompletedTask; + public Task AppendFileAsync(SessionFsAppendFileParams request, CancellationToken cancellationToken = default) => Task.CompletedTask; + public Task ExistsAsync(SessionFsExistsParams request, CancellationToken cancellationToken = default) + => Task.FromResult(new SessionFsExistsResult { Exists = false }); + public Task StatAsync(SessionFsStatParams request, CancellationToken cancellationToken = default) + => Task.FromResult(new SessionFsStatResult()); + public Task MkdirAsync(SessionFsMkdirParams request, CancellationToken cancellationToken = default) => Task.CompletedTask; + public Task ReaddirAsync(SessionFsReaddirParams request, CancellationToken cancellationToken = default) + => Task.FromResult(new SessionFsReaddirResult()); + public Task ReaddirWithTypesAsync(SessionFsReaddirWithTypesParams request, CancellationToken cancellationToken = default) + => Task.FromResult(new SessionFsReaddirWithTypesResult()); + public Task RmAsync(SessionFsRmParams request, CancellationToken cancellationToken = default) => Task.CompletedTask; + public Task RenameAsync(SessionFsRenameParams request, CancellationToken cancellationToken = default) => Task.CompletedTask; + } +} diff --git a/go/client.go b/go/client.go index 731efbe24..749246b75 100644 --- a/go/client.go +++ b/go/client.go @@ -97,6 +97,7 @@ type Client struct { osProcess atomic.Pointer[os.Process] negotiatedProtocolVersion int onListModels func(ctx context.Context) ([]ModelInfo, error) + sessionFsConfig *SessionFsConfig // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). @@ -192,6 +193,9 @@ func NewClient(options *ClientOptions) *Client { if options.OnListModels != nil { client.onListModels = options.OnListModels } + if options.SessionFs != nil { + client.sessionFsConfig = options.SessionFs + } } // Default Env to current environment if not set @@ -305,6 +309,15 @@ func (c *Client) Start(ctx context.Context) error { return errors.Join(err, killErr) } + // Register sessionFs provider if configured + if c.sessionFsConfig != nil { + if err := c.registerSessionFsProvider(ctx); err != nil { + killErr := c.killProcess() + c.state = StateError + return errors.Join(err, killErr) + } + } + c.state = StateConnected return nil } @@ -618,6 +631,13 @@ func (c *Client) CreateSession(ctx context.Context, config *SessionConfig) (*Ses if config.OnElicitationRequest != nil { session.registerElicitationHandler(config.OnElicitationRequest) } + if c.sessionFsConfig != nil { + if config.CreateSessionFsHandler != nil { + session.registerSessionFsHandler(config.CreateSessionFsHandler(session)) + } else { + return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") + } + } c.sessionsMux.Lock() c.sessions[sessionID] = session @@ -758,6 +778,13 @@ func (c *Client) ResumeSessionWithOptions(ctx context.Context, sessionID string, if config.OnElicitationRequest != nil { session.registerElicitationHandler(config.OnElicitationRequest) } + if c.sessionFsConfig != nil { + if config.CreateSessionFsHandler != nil { + session.registerSessionFsHandler(config.CreateSessionFsHandler(session)) + } else { + return nil, fmt.Errorf("CreateSessionFsHandler is required in session config when SessionFs is enabled in client options") + } + } c.sessionsMux.Lock() c.sessions[sessionID] = session @@ -1262,6 +1289,145 @@ func (c *Client) verifyProtocolVersion(ctx context.Context) error { return nil } +func (c *Client) registerSessionFsProvider(ctx context.Context) error { + cfg := c.sessionFsConfig + _, err := c.RPC.SessionFs.SetProvider(ctx, &rpc.SessionFSSetProviderParams{ + InitialCwd: cfg.InitialCwd, + SessionStatePath: cfg.SessionStatePath, + Conventions: rpc.Conventions(cfg.Conventions), + }) + return err +} + +func (c *Client) getSessionFsHandler(sessionID string) (SessionFsHandler, *jsonrpc2.Error) { + c.sessionsMux.Lock() + session, ok := c.sessions[sessionID] + c.sessionsMux.Unlock() + if !ok { + return nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("unknown session %s", sessionID)} + } + h := session.getSessionFsHandler() + if h == nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("no sessionFs handler registered for session: %s", sessionID)} + } + return h, nil +} + +func (c *Client) handleSessionFsReadFile(req SessionFsReadFileParams) (*SessionFsReadFileResult, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + result, err := h.ReadFile(req) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return result, nil +} + +func (c *Client) handleSessionFsWriteFile(req SessionFsWriteFileParams) (*struct{}, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + if err := h.WriteFile(req); err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &struct{}{}, nil +} + +func (c *Client) handleSessionFsAppendFile(req SessionFsAppendFileParams) (*struct{}, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + if err := h.AppendFile(req); err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &struct{}{}, nil +} + +func (c *Client) handleSessionFsExists(req SessionFsExistsParams) (*SessionFsExistsResult, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + result, err := h.Exists(req) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return result, nil +} + +func (c *Client) handleSessionFsStat(req SessionFsStatParams) (*SessionFsStatResult, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + result, err := h.Stat(req) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return result, nil +} + +func (c *Client) handleSessionFsMkdir(req SessionFsMkdirParams) (*struct{}, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + if err := h.Mkdir(req); err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &struct{}{}, nil +} + +func (c *Client) handleSessionFsReaddir(req SessionFsReaddirParams) (*SessionFsReaddirResult, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + result, err := h.Readdir(req) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return result, nil +} + +func (c *Client) handleSessionFsReaddirWithTypes(req SessionFsReaddirWithTypesParams) (*SessionFsReaddirWithTypesResult, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + result, err := h.ReaddirWithTypes(req) + if err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return result, nil +} + +func (c *Client) handleSessionFsRm(req SessionFsRmParams) (*struct{}, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + if err := h.Rm(req); err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &struct{}{}, nil +} + +func (c *Client) handleSessionFsRename(req SessionFsRenameParams) (*struct{}, *jsonrpc2.Error) { + h, rpcErr := c.getSessionFsHandler(req.SessionID) + if rpcErr != nil { + return nil, rpcErr + } + if err := h.Rename(req); err != nil { + return nil, &jsonrpc2.Error{Code: -32603, Message: err.Error()} + } + return &struct{}{}, nil +} + // startCLIServer starts the CLI server process. // // This spawns the CLI server as a subprocess using the configured transport @@ -1526,6 +1692,18 @@ func (c *Client) setupNotificationHandler() { c.client.SetRequestHandler("userInput.request", jsonrpc2.RequestHandlerFor(c.handleUserInputRequest)) c.client.SetRequestHandler("hooks.invoke", jsonrpc2.RequestHandlerFor(c.handleHooksInvoke)) c.client.SetRequestHandler("systemMessage.transform", jsonrpc2.RequestHandlerFor(c.handleSystemMessageTransform)) + + // SessionFs client session API handlers + c.client.SetRequestHandler("sessionFs.readFile", jsonrpc2.RequestHandlerFor(c.handleSessionFsReadFile)) + c.client.SetRequestHandler("sessionFs.writeFile", jsonrpc2.RequestHandlerFor(c.handleSessionFsWriteFile)) + c.client.SetRequestHandler("sessionFs.appendFile", jsonrpc2.RequestHandlerFor(c.handleSessionFsAppendFile)) + c.client.SetRequestHandler("sessionFs.exists", jsonrpc2.RequestHandlerFor(c.handleSessionFsExists)) + c.client.SetRequestHandler("sessionFs.stat", jsonrpc2.RequestHandlerFor(c.handleSessionFsStat)) + c.client.SetRequestHandler("sessionFs.mkdir", jsonrpc2.RequestHandlerFor(c.handleSessionFsMkdir)) + c.client.SetRequestHandler("sessionFs.readdir", jsonrpc2.RequestHandlerFor(c.handleSessionFsReaddir)) + c.client.SetRequestHandler("sessionFs.readdirWithTypes", jsonrpc2.RequestHandlerFor(c.handleSessionFsReaddirWithTypes)) + c.client.SetRequestHandler("sessionFs.rm", jsonrpc2.RequestHandlerFor(c.handleSessionFsRm)) + c.client.SetRequestHandler("sessionFs.rename", jsonrpc2.RequestHandlerFor(c.handleSessionFsRename)) } func (c *Client) handleSessionEvent(req sessionEventRequest) { diff --git a/go/client_test.go b/go/client_test.go index 8f302f338..c88b0d84d 100644 --- a/go/client_test.go +++ b/go/client_test.go @@ -846,3 +846,47 @@ func TestCreateSessionResponse_Capabilities(t *testing.T) { } }) } + +func TestClient_SessionFsConfig(t *testing.T) { + t.Run("stores session fs config from options", func(t *testing.T) { + client := NewClient(&ClientOptions{ + CLIUrl: "localhost:9999", + SessionFs: &SessionFsConfig{ + InitialCwd: "/home/user", + SessionStatePath: "/session-state", + Conventions: "posix", + }, + }) + + if client.sessionFsConfig == nil { + t.Fatal("Expected sessionFsConfig to be set") + } + if client.sessionFsConfig.InitialCwd != "/home/user" { + t.Errorf("Expected InitialCwd '/home/user', got '%s'", client.sessionFsConfig.InitialCwd) + } + if client.sessionFsConfig.Conventions != "posix" { + t.Errorf("Expected Conventions 'posix', got '%s'", client.sessionFsConfig.Conventions) + } + }) + + t.Run("returns error when sessionFs enabled but no handler on CreateSession", func(t *testing.T) { + client := NewClient(&ClientOptions{ + CLIUrl: "localhost:9999", + AutoStart: Bool(false), + SessionFs: &SessionFsConfig{ + InitialCwd: "/", + SessionStatePath: "/state", + Conventions: "posix", + }, + }) + + _, err := client.CreateSession(context.Background(), &SessionConfig{ + OnPermissionRequest: PermissionHandler.ApproveAll, + // CreateSessionFsHandler intentionally omitted + }) + + if err == nil { + t.Fatal("Expected error when CreateSessionFsHandler is missing") + } + }) +} diff --git a/go/internal/e2e/session_fs_test.go b/go/internal/e2e/session_fs_test.go new file mode 100644 index 000000000..ce36c15b3 --- /dev/null +++ b/go/internal/e2e/session_fs_test.go @@ -0,0 +1,321 @@ +package e2e + +import ( + "fmt" + "sort" + "strings" + "sync" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// inMemoryFS is a simple in memory filesystem for testing. +type inMemoryFS struct { + mu sync.Mutex + files map[string]string + dirs map[string]bool + mtime map[string]time.Time +} + +func newInMemoryFS() *inMemoryFS { + return &inMemoryFS{ + files: make(map[string]string), + dirs: map[string]bool{"/": true}, + mtime: make(map[string]time.Time), + } +} + +func (fs *inMemoryFS) ensureParents(p string) { + parts := strings.Split(p, "/") + for i := 1; i < len(parts)-1; i++ { + dir := strings.Join(parts[:i+1], "/") + fs.dirs[dir] = true + } +} + +// sessionFsHandler adapts the in memory FS for a specific session. +type sessionFsHandler struct { + sessionID string + fs *inMemoryFS +} + +func (h *sessionFsHandler) sp(p string) string { + if strings.HasPrefix(p, "/") { + return "/" + h.sessionID + p + } + return "/" + h.sessionID + "/" + p +} + +func (h *sessionFsHandler) ReadFile(params copilot.SessionFsReadFileParams) (*copilot.SessionFsReadFileResult, error) { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + content, ok := h.fs.files[h.sp(params.Path)] + if !ok { + return nil, fmt.Errorf("file not found: %s", params.Path) + } + return &copilot.SessionFsReadFileResult{Content: content}, nil +} + +func (h *sessionFsHandler) WriteFile(params copilot.SessionFsWriteFileParams) error { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + h.fs.ensureParents(p) + h.fs.files[p] = params.Content + h.fs.mtime[p] = time.Now() + return nil +} + +func (h *sessionFsHandler) AppendFile(params copilot.SessionFsAppendFileParams) error { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + h.fs.ensureParents(p) + h.fs.files[p] += params.Content + h.fs.mtime[p] = time.Now() + return nil +} + +func (h *sessionFsHandler) Exists(params copilot.SessionFsExistsParams) (*copilot.SessionFsExistsResult, error) { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + _, fileOk := h.fs.files[p] + _, dirOk := h.fs.dirs[p] + return &copilot.SessionFsExistsResult{Exists: fileOk || dirOk}, nil +} + +func (h *sessionFsHandler) Stat(params copilot.SessionFsStatParams) (*copilot.SessionFsStatResult, error) { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + if content, ok := h.fs.files[p]; ok { + mt := h.fs.mtime[p] + return &copilot.SessionFsStatResult{ + IsFile: true, + IsDirectory: false, + Size: int64(len(content)), + Mtime: mt.Format(time.RFC3339Nano), + Birthtime: mt.Format(time.RFC3339Nano), + }, nil + } + if h.fs.dirs[p] { + return &copilot.SessionFsStatResult{ + IsFile: false, + IsDirectory: true, + Mtime: time.Now().Format(time.RFC3339Nano), + Birthtime: time.Now().Format(time.RFC3339Nano), + }, nil + } + return nil, fmt.Errorf("path not found: %s", params.Path) +} + +func (h *sessionFsHandler) Mkdir(params copilot.SessionFsMkdirParams) error { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + if params.Recursive != nil && *params.Recursive { + h.fs.ensureParents(p + "/x") + } + h.fs.dirs[p] = true + return nil +} + +func (h *sessionFsHandler) Readdir(params copilot.SessionFsReaddirParams) (*copilot.SessionFsReaddirResult, error) { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + if !strings.HasSuffix(p, "/") { + p += "/" + } + entries := map[string]bool{} + for k := range h.fs.files { + if strings.HasPrefix(k, p) && len(k) > len(p) { + rest := k[len(p):] + if idx := strings.Index(rest, "/"); idx >= 0 { + entries[rest[:idx]] = true + } else { + entries[rest] = true + } + } + } + for k := range h.fs.dirs { + if strings.HasPrefix(k, p) && len(k) > len(p) { + rest := k[len(p):] + if idx := strings.Index(rest, "/"); idx >= 0 { + entries[rest[:idx]] = true + } else { + entries[rest] = true + } + } + } + result := make([]string, 0, len(entries)) + for e := range entries { + result = append(result, e) + } + sort.Strings(result) + return &copilot.SessionFsReaddirResult{Entries: result}, nil +} + +func (h *sessionFsHandler) ReaddirWithTypes(params copilot.SessionFsReaddirWithTypesParams) (*copilot.SessionFsReaddirWithTypesResult, error) { + dirResult, err := h.Readdir(copilot.SessionFsReaddirParams{SessionID: params.SessionID, Path: params.Path}) + if err != nil { + return nil, err + } + p := h.sp(params.Path) + if !strings.HasSuffix(p, "/") { + p += "/" + } + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + var entries []copilot.SessionFsDirEntry + for _, name := range dirResult.Entries { + full := p + name + entryType := "file" + if h.fs.dirs[full] { + entryType = "directory" + } else { + // Check if any file has this as prefix (implicit directory) + for k := range h.fs.files { + if strings.HasPrefix(k, full+"/") { + entryType = "directory" + break + } + } + } + entries = append(entries, copilot.SessionFsDirEntry{Name: name, Type: entryType}) + } + return &copilot.SessionFsReaddirWithTypesResult{Entries: entries}, nil +} + +func (h *sessionFsHandler) Rm(params copilot.SessionFsRmParams) error { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + p := h.sp(params.Path) + delete(h.fs.files, p) + delete(h.fs.dirs, p) + delete(h.fs.mtime, p) + return nil +} + +func (h *sessionFsHandler) Rename(params copilot.SessionFsRenameParams) error { + h.fs.mu.Lock() + defer h.fs.mu.Unlock() + src := h.sp(params.Src) + dest := h.sp(params.Dest) + if content, ok := h.fs.files[src]; ok { + h.fs.ensureParents(dest) + h.fs.files[dest] = content + h.fs.mtime[dest] = h.fs.mtime[src] + delete(h.fs.files, src) + delete(h.fs.mtime, src) + } + return nil +} + +func TestSessionFs(t *testing.T) { + ctx := testharness.NewTestContext(t) + + // Shared in memory filesystem across tests + memFS := newInMemoryFS() + + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.SessionFs = &copilot.SessionFsConfig{ + InitialCwd: "/", + SessionStatePath: "/session-state", + Conventions: "posix", + } + }) + t.Cleanup(func() { client.ForceStop() }) + + t.Run("should route file operations through the session fs provider", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: func(s *copilot.Session) copilot.SessionFsHandler { + return &sessionFsHandler{sessionID: s.SessionID, fs: memFS} + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + msg, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 100 + 200?"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if msg == nil || msg.Data.Content == nil || !strings.Contains(*msg.Data.Content, "300") { + t.Fatalf("Expected response containing '300', got: %v", msg) + } + session.Disconnect() + + // Verify the events file was written through our provider + eventsPath := "/" + session.SessionID + "/session-state/events.jsonl" + memFS.mu.Lock() + content, ok := memFS.files[eventsPath] + memFS.mu.Unlock() + if !ok { + t.Fatal("Expected events.jsonl to exist in in memory filesystem") + } + if !strings.Contains(content, "300") { + t.Errorf("Expected events.jsonl to contain '300', got: %s", content[:min(200, len(content))]) + } + }) + + t.Run("should load session data from fs provider on resume", func(t *testing.T) { + ctx.ConfigureForTest(t) + + session1, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: func(s *copilot.Session) copilot.SessionFsHandler { + return &sessionFsHandler{sessionID: s.SessionID, fs: memFS} + }, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sessionID := session1.SessionID + + msg, err := session1.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is 50 + 50?"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + if msg == nil || msg.Data.Content == nil || !strings.Contains(*msg.Data.Content, "100") { + t.Fatalf("Expected response containing '100', got: %v", msg) + } + session1.Disconnect() + + // Verify events file exists + eventsPath := "/" + sessionID + "/session-state/events.jsonl" + memFS.mu.Lock() + _, exists := memFS.files[eventsPath] + memFS.mu.Unlock() + if !exists { + t.Fatal("Expected events.jsonl to exist before resume") + } + + // Resume the session + session2, err := client.ResumeSession(t.Context(), sessionID, &copilot.ResumeSessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + CreateSessionFsHandler: func(s *copilot.Session) copilot.SessionFsHandler { + return &sessionFsHandler{sessionID: s.SessionID, fs: memFS} + }, + }) + if err != nil { + t.Fatalf("Failed to resume session: %v", err) + } + + msg2, err := session2.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "What is that times 3?"}) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + session2.Disconnect() + if msg2 == nil || msg2.Data.Content == nil || !strings.Contains(*msg2.Data.Content, "300") { + t.Fatalf("Expected response containing '300', got: %v", msg2) + } + }) +} diff --git a/go/session.go b/go/session.go index 71facb03b..c72c19e1a 100644 --- a/go/session.go +++ b/go/session.go @@ -70,6 +70,8 @@ type Session struct { commandHandlersMu sync.RWMutex elicitationHandler ElicitationHandler elicitationMu sync.RWMutex + sessionFsHandler SessionFsHandler + sessionFsMu sync.RWMutex capabilities SessionCapabilities capabilitiesMu sync.RWMutex @@ -575,6 +577,20 @@ func (s *Session) getElicitationHandler() ElicitationHandler { return s.elicitationHandler } +// registerSessionFsHandler registers a session filesystem handler for this session. +func (s *Session) registerSessionFsHandler(handler SessionFsHandler) { + s.sessionFsMu.Lock() + defer s.sessionFsMu.Unlock() + s.sessionFsHandler = handler +} + +// getSessionFsHandler returns the currently registered session filesystem handler, or nil. +func (s *Session) getSessionFsHandler() SessionFsHandler { + s.sessionFsMu.RLock() + defer s.sessionFsMu.RUnlock() + return s.sessionFsHandler +} + // handleElicitationRequest dispatches an elicitation.requested event to the registered handler // and sends the result back via the RPC layer. Auto-cancels on error. func (s *Session) handleElicitationRequest(elicitCtx ElicitationContext, requestID string) { diff --git a/go/types.go b/go/types.go index ff9b4aed3..b6762ae3b 100644 --- a/go/types.go +++ b/go/types.go @@ -67,6 +67,11 @@ type ClientOptions struct { // When non-nil, COPILOT_OTEL_ENABLED=true is set and any populated fields // are mapped to the corresponding environment variables. Telemetry *TelemetryConfig + // SessionFs configures the session filesystem provider. + // When non-nil, the client registers as the session filesystem provider + // on connection, routing all session scoped file I/O through SessionFsHandler + // callbacks instead of the server's default local filesystem storage. + SessionFs *SessionFsConfig } // TelemetryConfig configures OpenTelemetry integration for the Copilot CLI process. @@ -508,6 +513,9 @@ type SessionConfig struct { // When provided, the server may call back to this client for form-based UI dialogs // (e.g. from MCP tools). Also enables the elicitation capability on the session. OnElicitationRequest ElicitationHandler + // CreateSessionFsHandler creates a session filesystem handler for this session. + // Required when ClientOptions.SessionFs is configured. + CreateSessionFsHandler func(session *Session) SessionFsHandler } type Tool struct { Name string `json:"name"` @@ -702,6 +710,9 @@ type ResumeSessionConfig struct { // OnElicitationRequest is a handler for elicitation requests from the server. // See SessionConfig.OnElicitationRequest. OnElicitationRequest ElicitationHandler + // CreateSessionFsHandler creates a session filesystem handler for this session. + // Required when ClientOptions.SessionFs is configured. + CreateSessionFsHandler func(session *Session) SessionFsHandler } type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". @@ -1107,3 +1118,133 @@ type userInputResponse struct { Answer string `json:"answer"` WasFreeform bool `json:"wasFreeform"` } + +// SessionFsConfig configures the session filesystem provider at the connection level. +type SessionFsConfig struct { + // InitialCwd is the initial working directory for sessions. + InitialCwd string + // SessionStatePath is the path within each session's SessionFs where the + // runtime stores session scoped files. + SessionStatePath string + // Conventions is the path convention: "posix" or "windows". + Conventions string +} + +// SessionFsHandler handles session filesystem operations. +// Implement this interface to provide a custom virtual filesystem for session data. +type SessionFsHandler interface { + ReadFile(params SessionFsReadFileParams) (*SessionFsReadFileResult, error) + WriteFile(params SessionFsWriteFileParams) error + AppendFile(params SessionFsAppendFileParams) error + Exists(params SessionFsExistsParams) (*SessionFsExistsResult, error) + Stat(params SessionFsStatParams) (*SessionFsStatResult, error) + Mkdir(params SessionFsMkdirParams) error + Readdir(params SessionFsReaddirParams) (*SessionFsReaddirResult, error) + ReaddirWithTypes(params SessionFsReaddirWithTypesParams) (*SessionFsReaddirWithTypesResult, error) + Rm(params SessionFsRmParams) error + Rename(params SessionFsRenameParams) error +} + +// SessionFsReadFileParams are the params for a sessionFs.readFile request. +type SessionFsReadFileParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +// SessionFsReadFileResult is the result of a sessionFs.readFile request. +type SessionFsReadFileResult struct { + Content string `json:"content"` +} + +// SessionFsWriteFileParams are the params for a sessionFs.writeFile request. +type SessionFsWriteFileParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Content string `json:"content"` + Mode *int `json:"mode,omitempty"` +} + +// SessionFsAppendFileParams are the params for a sessionFs.appendFile request. +type SessionFsAppendFileParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Content string `json:"content"` + Mode *int `json:"mode,omitempty"` +} + +// SessionFsExistsParams are the params for a sessionFs.exists request. +type SessionFsExistsParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +// SessionFsExistsResult is the result of a sessionFs.exists request. +type SessionFsExistsResult struct { + Exists bool `json:"exists"` +} + +// SessionFsStatParams are the params for a sessionFs.stat request. +type SessionFsStatParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +// SessionFsStatResult is the result of a sessionFs.stat request. +type SessionFsStatResult struct { + IsFile bool `json:"isFile"` + IsDirectory bool `json:"isDirectory"` + Size int64 `json:"size"` + Mtime string `json:"mtime"` + Birthtime string `json:"birthtime"` +} + +// SessionFsMkdirParams are the params for a sessionFs.mkdir request. +type SessionFsMkdirParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Recursive *bool `json:"recursive,omitempty"` + Mode *int `json:"mode,omitempty"` +} + +// SessionFsReaddirParams are the params for a sessionFs.readdir request. +type SessionFsReaddirParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +// SessionFsReaddirResult is the result of a sessionFs.readdir request. +type SessionFsReaddirResult struct { + Entries []string `json:"entries"` +} + +// SessionFsReaddirWithTypesParams are the params for a sessionFs.readdirWithTypes request. +type SessionFsReaddirWithTypesParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +// SessionFsDirEntry is a directory entry with type information. +type SessionFsDirEntry struct { + Name string `json:"name"` + Type string `json:"type"` // "file" or "directory" +} + +// SessionFsReaddirWithTypesResult is the result of a sessionFs.readdirWithTypes request. +type SessionFsReaddirWithTypesResult struct { + Entries []SessionFsDirEntry `json:"entries"` +} + +// SessionFsRmParams are the params for a sessionFs.rm request. +type SessionFsRmParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Recursive *bool `json:"recursive,omitempty"` + Force *bool `json:"force,omitempty"` +} + +// SessionFsRenameParams are the params for a sessionFs.rename request. +type SessionFsRenameParams struct { + SessionID string `json:"sessionId"` + Src string `json:"src"` + Dest string `json:"dest"` +} diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index db9f150c8..5faf75942 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -11,6 +11,8 @@ ModelLimitsOverride, ModelSupportsOverride, ModelVisionLimitsOverride, + SessionFsConfig, + SessionFsHandler, SubprocessConfig, ) from .session import ( @@ -46,6 +48,8 @@ "ModelSupportsOverride", "ModelVisionLimitsOverride", "SessionCapabilities", + "SessionFsConfig", + "SessionFsHandler", "SessionUiApi", "SessionUiCapabilities", "SubprocessConfig", diff --git a/python/copilot/client.py b/python/copilot/client.py index df6756cfe..0d0e90020 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -140,6 +140,93 @@ class ExternalServerConfig: """Server URL. Supports ``"host:port"``, ``"http://host:port"``, or just ``"port"``.""" +@dataclass +class SessionFsConfig: + """Connection level configuration for the session filesystem provider. + + When provided to :class:`CopilotClient`, the client registers as the session + filesystem provider on connection, routing all session scoped file I/O through + :class:`SessionFsHandler` callbacks instead of the server's default local + filesystem storage. + """ + + initial_cwd: str + """Initial working directory for sessions (user's project directory).""" + + session_state_path: str + """Path within each session's SessionFs where the runtime stores + session scoped files (events, workspace, checkpoints, etc.).""" + + conventions: Literal["posix", "windows"] + """Path conventions used by this filesystem provider.""" + + +class SessionFsHandler: + """Handler interface for session filesystem operations. + + Implement this class to provide a custom virtual filesystem for session data. + All methods receive their parameters as keyword arguments. + """ + + async def read_file(self, *, session_id: str, path: str) -> dict[str, Any]: + """Read the contents of a file. Return ``{"content": str}``.""" + raise NotImplementedError + + async def write_file( + self, *, session_id: str, path: str, content: str, mode: int | None = None + ) -> None: + """Write content to a file.""" + raise NotImplementedError + + async def append_file( + self, *, session_id: str, path: str, content: str, mode: int | None = None + ) -> None: + """Append content to a file.""" + raise NotImplementedError + + async def exists(self, *, session_id: str, path: str) -> dict[str, Any]: + """Check whether a path exists. Return ``{"exists": bool}``.""" + raise NotImplementedError + + async def stat(self, *, session_id: str, path: str) -> dict[str, Any]: + """Return metadata about a file or directory.""" + raise NotImplementedError + + async def mkdir( + self, + *, + session_id: str, + path: str, + recursive: bool | None = None, + mode: int | None = None, + ) -> None: + """Create a directory.""" + raise NotImplementedError + + async def readdir(self, *, session_id: str, path: str) -> dict[str, Any]: + """List entries in a directory. Return ``{"entries": list[str]}``.""" + raise NotImplementedError + + async def readdir_with_types(self, *, session_id: str, path: str) -> dict[str, Any]: + """List entries with type info. Return ``{"entries": list[dict]}``.""" + raise NotImplementedError + + async def rm( + self, + *, + session_id: str, + path: str, + recursive: bool | None = None, + force: bool | None = None, + ) -> None: + """Remove a file or directory.""" + raise NotImplementedError + + async def rename(self, *, session_id: str, src: str, dest: str) -> None: + """Rename (move) a file or directory.""" + raise NotImplementedError + + # ============================================================================ # Response Types # ============================================================================ @@ -810,6 +897,7 @@ def __init__( *, auto_start: bool = True, on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] | None = None, + session_fs: SessionFsConfig | None = None, ): """ Initialize a new CopilotClient. @@ -844,6 +932,7 @@ def __init__( self._config: SubprocessConfig | ExternalServerConfig = config self._auto_start = auto_start self._on_list_models = on_list_models + self._session_fs_config = session_fs # Resolve connection-mode-specific state self._actual_host: str = "localhost" @@ -1018,6 +1107,10 @@ async def start(self) -> None: # Verify protocol version compatibility await self._verify_protocol_version() + # Register sessionFs provider if configured + if self._session_fs_config is not None: + await self._register_session_fs_provider() + self._state = "connected" except ProcessExitedError as e: # Process exited with error - reraise as RuntimeError with stderr @@ -1179,6 +1272,7 @@ async def create_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + create_session_fs_handler: Callable[[CopilotSession], SessionFsHandler] | None = None, ) -> CopilotSession: """ Create a new conversation session with the Copilot CLI. @@ -1375,6 +1469,14 @@ async def create_session( session._register_user_input_handler(on_user_input_request) if on_elicitation_request: session._register_elicitation_handler(on_elicitation_request) + if self._session_fs_config is not None: + if create_session_fs_handler is not None: + session._register_session_fs_handler(create_session_fs_handler(session)) + else: + raise ValueError( + "create_session_fs_handler is required in session config " + "when session_fs is enabled in client options." + ) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -1424,6 +1526,7 @@ async def resume_session( on_event: Callable[[SessionEvent], None] | None = None, commands: list[CommandDefinition] | None = None, on_elicitation_request: ElicitationHandler | None = None, + create_session_fs_handler: Callable[[CopilotSession], SessionFsHandler] | None = None, ) -> CopilotSession: """ Resume an existing conversation session by its ID. @@ -1599,6 +1702,14 @@ async def resume_session( session._register_user_input_handler(on_user_input_request) if on_elicitation_request: session._register_elicitation_handler(on_elicitation_request) + if self._session_fs_config is not None: + if create_session_fs_handler is not None: + session._register_session_fs_handler(create_session_fs_handler(session)) + else: + raise ValueError( + "create_session_fs_handler is required in session config " + "when session_fs is enabled in client options." + ) if hooks: session._register_hooks(hooks) if transform_callbacks: @@ -2284,6 +2395,22 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) + # SessionFs client session API handlers + self._client.set_request_handler("sessionFs.readFile", self._handle_session_fs_read_file) + self._client.set_request_handler("sessionFs.writeFile", self._handle_session_fs_write_file) + self._client.set_request_handler( + "sessionFs.appendFile", self._handle_session_fs_append_file + ) + self._client.set_request_handler("sessionFs.exists", self._handle_session_fs_exists) + self._client.set_request_handler("sessionFs.stat", self._handle_session_fs_stat) + self._client.set_request_handler("sessionFs.mkdir", self._handle_session_fs_mkdir) + self._client.set_request_handler("sessionFs.readdir", self._handle_session_fs_readdir) + self._client.set_request_handler( + "sessionFs.readdirWithTypes", self._handle_session_fs_readdir_with_types + ) + self._client.set_request_handler("sessionFs.rm", self._handle_session_fs_rm) + self._client.set_request_handler("sessionFs.rename", self._handle_session_fs_rename) + # Start listening for messages loop = asyncio.get_running_loop() self._client.start(loop) @@ -2388,6 +2515,22 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) + # SessionFs client session API handlers + self._client.set_request_handler("sessionFs.readFile", self._handle_session_fs_read_file) + self._client.set_request_handler("sessionFs.writeFile", self._handle_session_fs_write_file) + self._client.set_request_handler( + "sessionFs.appendFile", self._handle_session_fs_append_file + ) + self._client.set_request_handler("sessionFs.exists", self._handle_session_fs_exists) + self._client.set_request_handler("sessionFs.stat", self._handle_session_fs_stat) + self._client.set_request_handler("sessionFs.mkdir", self._handle_session_fs_mkdir) + self._client.set_request_handler("sessionFs.readdir", self._handle_session_fs_readdir) + self._client.set_request_handler( + "sessionFs.readdirWithTypes", self._handle_session_fs_readdir_with_types + ) + self._client.set_request_handler("sessionFs.rm", self._handle_session_fs_rm) + self._client.set_request_handler("sessionFs.rename", self._handle_session_fs_rename) + # Start listening for messages loop = asyncio.get_running_loop() self._client.start(loop) @@ -2574,3 +2717,92 @@ async def _handle_permission_request_v2(self, params: dict) -> dict: "kind": "denied-no-approval-rule-and-could-not-request-from-user", } } + + async def _register_session_fs_provider(self) -> None: + """Register this client as the session filesystem provider.""" + cfg = self._session_fs_config + assert cfg is not None + assert self._client is not None + await self._client.request( + "sessionFs.setProvider", + { + "initialCwd": cfg.initial_cwd, + "sessionStatePath": cfg.session_state_path, + "conventions": cfg.conventions, + }, + ) + + def _get_session_fs_handler(self, session_id: str) -> SessionFsHandler: + with self._sessions_lock: + session = self._sessions.get(session_id) + if session is None: + raise ValueError(f"Unknown session {session_id}") + handler = session._get_session_fs_handler() + if handler is None: + raise ValueError(f"No sessionFs handler registered for session: {session_id}") + return handler + + async def _handle_session_fs_read_file(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + return await h.read_file(session_id=params["sessionId"], path=params["path"]) + + async def _handle_session_fs_write_file(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + await h.write_file( + session_id=params["sessionId"], + path=params["path"], + content=params["content"], + mode=params.get("mode"), + ) + return {} + + async def _handle_session_fs_append_file(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + await h.append_file( + session_id=params["sessionId"], + path=params["path"], + content=params["content"], + mode=params.get("mode"), + ) + return {} + + async def _handle_session_fs_exists(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + return await h.exists(session_id=params["sessionId"], path=params["path"]) + + async def _handle_session_fs_stat(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + return await h.stat(session_id=params["sessionId"], path=params["path"]) + + async def _handle_session_fs_mkdir(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + await h.mkdir( + session_id=params["sessionId"], + path=params["path"], + recursive=params.get("recursive"), + mode=params.get("mode"), + ) + return {} + + async def _handle_session_fs_readdir(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + return await h.readdir(session_id=params["sessionId"], path=params["path"]) + + async def _handle_session_fs_readdir_with_types(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + return await h.readdir_with_types(session_id=params["sessionId"], path=params["path"]) + + async def _handle_session_fs_rm(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + await h.rm( + session_id=params["sessionId"], + path=params["path"], + recursive=params.get("recursive"), + force=params.get("force"), + ) + return {} + + async def _handle_session_fs_rename(self, params: dict) -> dict: + h = self._get_session_fs_handler(params["sessionId"]) + await h.rename(session_id=params["sessionId"], src=params["src"], dest=params["dest"]) + return {} diff --git a/python/copilot/session.py b/python/copilot/session.py index 59ec8532b..d6032c60e 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -983,6 +983,8 @@ def __init__( self._command_handlers_lock = threading.Lock() self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() + self._session_fs_handler: Any = None + self._session_fs_handler_lock = threading.Lock() self._capabilities: SessionCapabilities = {} self._rpc: SessionRpc | None = None self._destroyed = False @@ -1523,6 +1525,16 @@ def _register_elicitation_handler(self, handler: ElicitationHandler | None) -> N with self._elicitation_handler_lock: self._elicitation_handler = handler + def _register_session_fs_handler(self, handler: Any) -> None: + """Register a session filesystem handler for this session.""" + with self._session_fs_handler_lock: + self._session_fs_handler = handler + + def _get_session_fs_handler(self) -> Any: + """Return the currently registered session filesystem handler, or None.""" + with self._session_fs_handler_lock: + return self._session_fs_handler + def _set_capabilities(self, capabilities: SessionCapabilities | None) -> None: """Set the host capabilities for this session. diff --git a/python/e2e/test_session_fs.py b/python/e2e/test_session_fs.py new file mode 100644 index 000000000..8df0b7e08 --- /dev/null +++ b/python/e2e/test_session_fs.py @@ -0,0 +1,306 @@ +"""E2E tests for SessionFs virtual filesystem support.""" + +from __future__ import annotations + +import os +import re +import shutil +import tempfile +from pathlib import Path +from typing import Any + +import pytest +import pytest_asyncio + +from copilot import CopilotClient, SessionFsConfig, SessionFsHandler +from copilot.client import SubprocessConfig +from copilot.session import CopilotSession, PermissionHandler + +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class InMemoryFS: + """Simple in memory filesystem for testing.""" + + def __init__(self): + self._files: dict[str, str] = {} + self._dirs: set[str] = {"/"} + + def _ensure_parents(self, path: str) -> None: + parts = path.split("/") + for i in range(1, len(parts) - 1): + self._dirs.add("/".join(parts[: i + 1])) + + def read_file(self, path: str) -> str: + if path not in self._files: + raise FileNotFoundError(f"File not found: {path}") + return self._files[path] + + def write_file(self, path: str, content: str) -> None: + self._ensure_parents(path) + self._files[path] = content + + def append_file(self, path: str, content: str) -> None: + self._ensure_parents(path) + self._files[path] = self._files.get(path, "") + content + + def exists(self, path: str) -> bool: + p = path.rstrip("/") or "/" + return p in self._files or p in self._dirs + + def mkdir(self, path: str, recursive: bool = False) -> None: + if recursive: + self._ensure_parents(path + "/x") + self._dirs.add(path.rstrip("/")) + + def readdir(self, path: str) -> list[str]: + prefix = path if path.endswith("/") else path + "/" + entries: set[str] = set() + for key in list(self._files.keys()) + list(self._dirs): + if key.startswith(prefix) and len(key) > len(prefix): + rest = key[len(prefix) :] + slash = rest.find("/") + entries.add(rest[:slash] if slash >= 0 else rest) + return sorted(entries) + + def remove(self, path: str) -> None: + p = path.rstrip("/") or "/" + self._files.pop(p, None) + self._dirs.discard(p) + + def rename(self, src: str, dest: str) -> None: + if src in self._files: + self._ensure_parents(dest) + self._files[dest] = self._files.pop(src) + + +class InMemorySessionFsHandler(SessionFsHandler): + """SessionFs handler backed by an in memory filesystem.""" + + def __init__(self, session_id: str, fs: InMemoryFS): + self._session_id = session_id + self._fs = fs + + def _sp(self, path: str) -> str: + if path.startswith("/"): + return f"/{self._session_id}{path}" + return f"/{self._session_id}/{path}" + + async def read_file(self, *, session_id: str, path: str) -> dict[str, Any]: + return {"content": self._fs.read_file(self._sp(path))} + + async def write_file( + self, *, session_id: str, path: str, content: str, mode: int | None = None + ) -> None: + self._fs.write_file(self._sp(path), content) + + async def append_file( + self, *, session_id: str, path: str, content: str, mode: int | None = None + ) -> None: + self._fs.append_file(self._sp(path), content) + + async def exists(self, *, session_id: str, path: str) -> dict[str, Any]: + return {"exists": self._fs.exists(self._sp(path))} + + async def stat(self, *, session_id: str, path: str) -> dict[str, Any]: + p = self._sp(path) + if p in self._fs._files: + content = self._fs._files[p] + return { + "isFile": True, + "isDirectory": False, + "size": len(content), + "mtime": "2026-01-01T00:00:00.000Z", + "birthtime": "2026-01-01T00:00:00.000Z", + } + if p.rstrip("/") in self._fs._dirs: + return { + "isFile": False, + "isDirectory": True, + "size": 0, + "mtime": "2026-01-01T00:00:00.000Z", + "birthtime": "2026-01-01T00:00:00.000Z", + } + raise FileNotFoundError(f"Path not found: {path}") + + async def mkdir( + self, + *, + session_id: str, + path: str, + recursive: bool | None = None, + mode: int | None = None, + ) -> None: + self._fs.mkdir(self._sp(path), recursive=bool(recursive)) + + async def readdir(self, *, session_id: str, path: str) -> dict[str, Any]: + return {"entries": self._fs.readdir(self._sp(path))} + + async def readdir_with_types(self, *, session_id: str, path: str) -> dict[str, Any]: + p = self._sp(path) + names = self._fs.readdir(p) + prefix = p if p.endswith("/") else p + "/" + entries = [] + for name in names: + full = prefix + name + is_dir = full in self._fs._dirs or any( + k.startswith(full + "/") for k in self._fs._files + ) + entries.append({"name": name, "type": "directory" if is_dir else "file"}) + return {"entries": entries} + + async def rm( + self, + *, + session_id: str, + path: str, + recursive: bool | None = None, + force: bool | None = None, + ) -> None: + self._fs.remove(self._sp(path)) + + async def rename(self, *, session_id: str, src: str, dest: str) -> None: + self._fs.rename(self._sp(src), self._sp(dest)) + + +# Shared in memory filesystem for all tests in this module +_shared_fs = InMemoryFS() + +SESSION_FS_CONFIG = SessionFsConfig( + initial_cwd="/", + session_state_path="/session-state", + conventions="posix", +) + + +def _make_handler(session: CopilotSession) -> SessionFsHandler: + return InMemorySessionFsHandler(session.session_id, _shared_fs) + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def ctx(request): + """Custom context that creates a CopilotClient with SessionFs enabled.""" + context = E2ETestContext() + # Override setup to inject session_fs config + context.cli_path = context.cli_path or str( + ( + Path(__file__).parents[2] + / "nodejs" + / "node_modules" + / "@github" + / "copilot" + / "index.js" + ).resolve() + ) + env_cli = os.environ.get("COPILOT_CLI_PATH") + if env_cli and Path(env_cli).exists(): + context.cli_path = str(Path(env_cli).resolve()) + else: + base = Path(__file__).parents[2] + cli = base / "nodejs" / "node_modules" / "@github" / "copilot" / "index.js" + if cli.exists(): + context.cli_path = str(cli.resolve()) + else: + pytest.skip("CLI not found") + + context.home_dir = tempfile.mkdtemp(prefix="copilot-test-config-") + context.work_dir = tempfile.mkdtemp(prefix="copilot-test-work-") + + from .testharness.proxy import CapiProxy + + context._proxy = CapiProxy() + context.proxy_url = await context._proxy.start() + + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + env = os.environ.copy() + env.update( + { + "COPILOT_API_URL": context.proxy_url, + "XDG_CONFIG_HOME": context.home_dir, + "XDG_STATE_HOME": context.home_dir, + } + ) + context._client = CopilotClient( + SubprocessConfig( + cli_path=context.cli_path, + cwd=context.work_dir, + env=env, + github_token=github_token, + ), + session_fs=SESSION_FS_CONFIG, + ) + + yield context + any_failed = request.session.stash.get("any_test_failed", False) + await context.teardown(test_failed=any_failed) + + +@pytest_asyncio.fixture(autouse=True, loop_scope="module") +async def configure_test(request, ctx): + """Configure the proxy for each test using session_fs snapshot dir.""" + test_name = request.node.name + if test_name.startswith("test_"): + test_name = test_name[5:] + sanitized = re.sub(r"[^a-zA-Z0-9]", "_", test_name).lower() + + snapshots_dir = Path(__file__).parents[2] / "test" / "snapshots" + snapshot_path = snapshots_dir / "session_fs" / f"{sanitized}.yaml" + + await ctx._proxy.configure(str(snapshot_path.resolve()), ctx.work_dir) + + # Clean temp dirs between tests + for item in Path(ctx.home_dir).iterdir(): + if item.is_dir(): + shutil.rmtree(item, ignore_errors=True) + else: + item.unlink(missing_ok=True) + yield + + +class TestSessionFs: + async def test_should_route_file_operations_through_the_session_fs_provider( + self, ctx: E2ETestContext + ): + session = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=_make_handler, + ) + + msg = await session.send_and_wait("What is 100 + 200?") + assert msg is not None + assert "300" in msg.data.content + await session.disconnect() + + events_path = f"/{session.session_id}/session-state/events.jsonl" + content = _shared_fs.read_file(events_path) + assert "300" in content + + async def test_should_load_session_data_from_fs_provider_on_resume(self, ctx: E2ETestContext): + session1 = await ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=_make_handler, + ) + session_id = session1.session_id + + msg = await session1.send_and_wait("What is 50 + 50?") + assert msg is not None + assert "100" in msg.data.content + await session1.disconnect() + + events_path = f"/{session_id}/session-state/events.jsonl" + assert _shared_fs.exists(events_path) + + session2 = await ctx.client.resume_session( + session_id, + on_permission_request=PermissionHandler.approve_all, + create_session_fs_handler=_make_handler, + ) + + msg2 = await session2.send_and_wait("What is that times 3?") + await session2.disconnect() + assert msg2 is not None + assert "300" in msg2.data.content diff --git a/python/test_client.py b/python/test_client.py index d655df4d4..b5f3e82c9 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -528,3 +528,38 @@ async def test_aexit_calls_disconnect(self): with patch.object(session, "disconnect", new_callable=AsyncMock) as mock_disconnect: await session.__aexit__(None, None, None) mock_disconnect.assert_awaited_once() + + +class TestSessionFsConfig: + def test_session_fs_config_stored_on_client(self): + from copilot.client import SessionFsConfig + + config = SessionFsConfig( + initial_cwd="/home/user", + session_state_path="/session-state", + conventions="posix", + ) + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH), session_fs=config) + assert client._session_fs_config is not None + assert client._session_fs_config.initial_cwd == "/home/user" + assert client._session_fs_config.conventions == "posix" + + @pytest.mark.asyncio + async def test_create_session_raises_when_session_fs_enabled_but_no_handler(self): + from copilot.client import SessionFsConfig + + config = SessionFsConfig( + initial_cwd="/", + session_state_path="/session-state", + conventions="posix", + ) + client = CopilotClient(SubprocessConfig(cli_path=CLI_PATH), session_fs=config) + await client.start() + try: + with pytest.raises(ValueError, match="create_session_fs_handler is required"): + await client.create_session( + on_permission_request=PermissionHandler.approve_all, + # create_session_fs_handler intentionally omitted + ) + finally: + await client.stop()