diff --git a/src/main/kotlin/com/github/codeplangui/BridgeHandler.kt b/src/main/kotlin/com/github/codeplangui/BridgeHandler.kt index 4603895..9499cc2 100644 --- a/src/main/kotlin/com/github/codeplangui/BridgeHandler.kt +++ b/src/main/kotlin/com/github/codeplangui/BridgeHandler.kt @@ -268,12 +268,13 @@ class BridgeHandler( put("description", JsonPrimitive(description)) }) - fun notifyApprovalRequest(requestId: String, command: String, description: String) = + fun notifyApprovalRequest(requestId: String, command: String, description: String, toolName: String = "") = flushAndPush(buildEventJS("approval_request") { put("requestId", JsonPrimitive(requestId)) put("command", JsonPrimitive(command)) put("toolInput", JsonPrimitive(command)) put("description", JsonPrimitive(description)) + put("toolName", JsonPrimitive(toolName)) }).also { logger.info( "[CodePlanGUI Bridge] ide->frontend approvalRequest " + diff --git a/src/main/kotlin/com/github/codeplangui/ChatService.kt b/src/main/kotlin/com/github/codeplangui/ChatService.kt index b84b9a2..718ba6c 100644 --- a/src/main/kotlin/com/github/codeplangui/ChatService.kt +++ b/src/main/kotlin/com/github/codeplangui/ChatService.kt @@ -1,5 +1,6 @@ package com.github.codeplangui +import com.github.codeplangui.api.FunctionDefinition import com.github.codeplangui.api.OkHttpSseClient import com.github.codeplangui.api.ToolCallAccumulator import com.github.codeplangui.api.ToolCallDelta @@ -7,15 +8,7 @@ import com.github.codeplangui.api.ToolDefinition import com.github.codeplangui.api.TruncationDecision import com.github.codeplangui.api.TruncationHandler import com.github.codeplangui.execution.CommandExecutionService -import com.github.codeplangui.execution.ExecutionResult -import com.github.codeplangui.execution.FileChangeReview -import com.github.codeplangui.execution.FileWriteLock -import com.github.codeplangui.execution.PendingToolCall import com.github.codeplangui.execution.ShellPlatform -import com.github.codeplangui.execution.ToolCallDispatcher -import com.github.codeplangui.execution.ToolRegistry -import com.github.codeplangui.execution.ToolSpecs -import com.github.codeplangui.execution.hooks.ToolExecutionLogger import com.github.codeplangui.model.ChatSession import com.github.codeplangui.model.Message import com.github.codeplangui.model.MessageRole @@ -23,8 +16,14 @@ import com.github.codeplangui.model.ToolCallRecord import com.github.codeplangui.settings.ApiKeyStore import com.github.codeplangui.settings.PluginSettings import com.github.codeplangui.settings.PluginSettingsConfigurable -import com.github.codeplangui.settings.SettingsState import com.github.codeplangui.storage.SessionStore +import com.github.codeplangui.tools.Progress +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ToolUpdate +import com.github.codeplangui.tools.ToolUseBlock +import com.github.codeplangui.tools.runToolUseBatch import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.application.ReadAction import com.intellij.openapi.Disposable @@ -35,14 +34,12 @@ import com.intellij.openapi.project.Project import com.intellij.openapi.options.ShowSettingsUtil import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonObject -import kotlinx.serialization.json.jsonPrimitive -import kotlinx.serialization.json.put +import kotlinx.serialization.json.buildJsonObject import okhttp3.sse.EventSource import java.util.UUID import java.util.concurrent.CompletableFuture @@ -77,17 +74,10 @@ class ChatService(private val project: Project) : Disposable { // Approval gate: suspended coroutines wait on these futures private val pendingApprovals = ConcurrentHashMap>() + // Maps requestId → command string so onApprovalResponse can update the whitelist private val pendingApprovalCommands = ConcurrentHashMap() - - // Unified tool system (new) - private val toolRegistry = ToolRegistry(this) - private val fileChangeReview = FileChangeReview() - private val fileWriteLock = FileWriteLock() - private val dispatcher = ToolCallDispatcher(toolRegistry, fileChangeReview, fileWriteLock, project).also { - it.addHook(ToolExecutionLogger()) - // Register built-in tools - toolRegistry.addTools(ToolSpecs().allSpecs()) - } + // Tracks msgId while tool batch is executing so cancelStream() works during tools + @Volatile private var runningToolMsgId: String? = null // Tracks which msgIds have had notifyStart sent to the frontend // When tools are enabled, notifyStart is deferred until the final response round @@ -155,7 +145,6 @@ class ChatService(private val project: Project) : Disposable { val settingsState = settings.getState() val commandExecutionEnabled = settingsState.commandExecutionEnabled - val unifiedTools = settingsState.unifiedToolsEnabled && commandExecutionEnabled val contextSnapshot = if (includeContext && settingsState.contextInjectionEnabled) { capturePromptContextSnapshot() @@ -193,9 +182,7 @@ class ChatService(private val project: Project) : Disposable { temperature = settingsState.chatTemperature, maxTokens = settingsState.chatMaxTokens, stream = true, - tools = if (unifiedTools) dispatcher.buildToolsParam() - else if (commandExecutionEnabled) listOf(runCommandToolDefinition()) - else null + tools = if (commandExecutionEnabled) buildToolDefinitions() else null ) // notifyStart is now sent unconditionally in startStreamingRound() (Phase 2). @@ -203,11 +190,12 @@ class ChatService(private val project: Project) : Disposable { } fun cancelStream() { - val wasStreaming = activeMessageId != null + val msgId = activeMessageId ?: runningToolMsgId + val wasStreaming = msgId != null activeStream?.cancel() activeStream = null - val msgId = activeMessageId activeMessageId = null + runningToolMsgId = null if (wasStreaming && msgId != null) { publishStatus() bridgeHandler?.notifyEnd(msgId) @@ -224,7 +212,7 @@ class ChatService(private val project: Project) : Disposable { pendingApprovals.values.forEach { it.complete(false) } pendingApprovals.clear() pendingApprovalCommands.clear() - dispatcher.resetSession() + runningToolMsgId = null sessionStore.clearSession() contextFileCallback?.invoke("") publishStatus() @@ -257,17 +245,6 @@ class ChatService(private val project: Project) : Disposable { "[CodePlanGUI Approval] received frontend decision " + "requestId=$requestId decision=$decision addToWhitelist=$addToWhitelist hasPending=${pendingApprovals.containsKey(requestId)}" ) - - // Try unified dispatcher first (coroutine-based) - if (PluginSettings.getInstance().getState().unifiedToolsEnabled) { - if (addToWhitelist && decision == "allow") { - // Note: unified dispatcher doesn't use command-level whitelist the same way. - // For now, also handle via legacy path for backwards compatibility. - } - dispatcher.onApprovalResponse(requestId, decision) - } - - // Legacy path (CompletableFuture-based) if (addToWhitelist && decision == "allow") { val command = pendingApprovalCommands[requestId] if (command != null) { @@ -275,10 +252,11 @@ class ChatService(private val project: Project) : Disposable { val whitelist = PluginSettings.getInstance().getState().commandWhitelist if (baseCommand !in whitelist) { whitelist.add(baseCommand) - logger.info("[CodePlanGUI Approval] added '$baseCommand' to whitelist (from command: ${command.summarizeForLog()})") + logger.info("[CodePlanGUI Approval] added '$baseCommand' to whitelist") } } } + pendingApprovalCommands.remove(requestId) pendingApprovals[requestId]?.complete(decision == "allow") } @@ -286,61 +264,130 @@ class ChatService(private val project: Project) : Disposable { publishStatus() } - private fun runCommandToolDefinition(): ToolDefinition = - ShellPlatform.current().toolDefinition() + private fun buildToolDefinitions(): List { + val pool = com.github.codeplangui.tools.ToolRegistry.assembleToolPool() + return pool.map { tool -> + ToolDefinition( + type = "function", + function = FunctionDefinition( + name = tool.name, + description = tool.description, + parameters = tool.inputSchema, + ) + ) + } + } + + private fun buildPermissionContext(): ToolPermissionContext { + val state = PluginSettings.getInstance().getState() + return ToolPermissionContext( + mode = ToolPermissionContext.Mode.DEFAULT, + alwaysAllow = state.commandWhitelist.toSet(), + alwaysDeny = emptySet(), + alwaysAsk = emptySet(), + additionalWorkingDirectories = emptySet(), + ) + } private fun handleToolCallChunk(delta: ToolCallDelta) { toolCallAccumulator.append(delta) } private suspend fun handleToolCallComplete(msgId: String, responseBuffer: StringBuilder) { - val settingsState = PluginSettings.getInstance().getState() - - if (settingsState.unifiedToolsEnabled) { - // New unified dispatcher path - handleToolCallCompleteUnified(msgId, responseBuffer) - } else { - // Legacy path - handleToolCallCompleteLegacy(msgId, responseBuffer) - } - } + runningToolMsgId = msgId - private suspend fun handleToolCallCompleteUnified(msgId: String, responseBuffer: StringBuilder) { val accumulatedToolCalls = toolCallAccumulator.snapshot() if (accumulatedToolCalls.isEmpty()) { + runningToolMsgId = null abortStream(msgId, "AI sent a tool_calls finish_reason but no tool call deltas were captured") return } - val pendingCalls = accumulatedToolCalls.mapNotNull { accumulated -> + val toolUses = accumulatedToolCalls.mapNotNull { accumulated -> val toolCallId = accumulated.id ?: run { + runningToolMsgId = null abortStream(msgId, "AI sent a tool_calls finish_reason but tool call index ${accumulated.index} had no id") return } - PendingToolCall( - id = toolCallId, + val inputJson = try { + kotlinx.serialization.json.Json.parseToJsonElement(accumulated.argumentsJson) + } catch (_: Exception) { + buildJsonObject {} + } + ToolUseBlock( + toolUseId = toolCallId, name = accumulated.functionName ?: ShellPlatform.current().toolName(), - arguments = accumulated.argumentsJson, - index = accumulated.index + input = inputJson, ) } - dispatcher.resetRound() - val results = dispatcher.dispatchAll(pendingCalls, msgId, bridgeHandler ?: return) + val pool = com.github.codeplangui.tools.ToolRegistry.assembleToolPool() + val settingsState = PluginSettings.getInstance().getState() + val ctx = ToolExecutionContext( + project = project, + toolUseId = msgId, + abortJob = scope.coroutineContext[Job]!!, + permissionContext = buildPermissionContext(), + commandTimeoutSeconds = settingsState.commandTimeoutSeconds, + onPermissionAsked = { event -> + val bridge = bridgeHandler ?: return@ToolExecutionContext false + val requestId = event.toolUseId + val previewSummary = event.preview?.summary ?: event.reason + // Strip "Run: " prefix to store raw command for whitelist persistence + pendingApprovalCommands[requestId] = previewSummary.removePrefix("Run: ") + bridge.notifyApprovalRequest(requestId, previewSummary, event.reason, event.toolName) + val future = CompletableFuture() + pendingApprovals[requestId] = future + try { + withContext(Dispatchers.IO) { future.get(60, TimeUnit.SECONDS) } + } catch (_: Exception) { + false + } finally { + pendingApprovals.remove(requestId) + } + } + ) + + val startTimes = mutableMapOf() + val results = mutableMapOf() - // Build tool results for API - val toolCallRecords = results.map { (call, result) -> + runToolUseBatch(toolUses, pool, ctx).collect { update -> + val bridge = bridgeHandler + when (update) { + is ToolUpdate.Started -> { + startTimes[update.toolUseId] = System.currentTimeMillis() + bridge?.notifyToolStepStart(msgId, update.toolUseId, update.toolName, update.toolName) + } + is ToolUpdate.ProgressEmitted -> { + val (line, type) = when (val p = update.progress) { + is Progress.Stdout -> p.line to "stdout" + is Progress.Stderr -> p.line to "stderr" + is Progress.Status -> p.message to "info" + } + bridge?.notifyLog(update.toolUseId, line, type) + } + is ToolUpdate.PermissionAsked -> Unit + is ToolUpdate.Completed -> { + val durationMs = System.currentTimeMillis() - (startTimes[update.toolUseId] ?: 0) + results[update.toolUseId] = update.block + bridge?.notifyToolStepEnd(msgId, update.toolUseId, !update.block.isError, update.block.content, durationMs) + } + is ToolUpdate.Failed -> { + val durationMs = System.currentTimeMillis() - (startTimes[update.toolUseId] ?: 0) + val errorBlock = ToolResultBlock(update.toolUseId, "[${update.stage}] ${update.message}", isError = true) + results[update.toolUseId] = errorBlock + bridge?.notifyToolStepEnd(msgId, update.toolUseId, false, update.message, durationMs) + } + } + } + + val toolCallRecords = toolUses.map { tu -> ToolCallRecord( - id = call.id, - functionName = call.name, - arguments = call.arguments + id = tu.toolUseId, + functionName = tu.name, + arguments = tu.input.toString() ) } - val toolResultContents = results.map { (_, result) -> - result.output - } - - // Add assistant message with tool_calls session.add(Message( role = MessageRole.ASSISTANT, content = responseBuffer.toString(), @@ -348,85 +395,25 @@ class ChatService(private val project: Project) : Disposable { seq = session.nextSeq(), toolCalls = toolCallRecords )) - - // Add tool result messages - results.forEach { (call, result) -> + toolUses.forEach { tu -> + val block = results[tu.toolUseId] ?: ToolResultBlock(tu.toolUseId, "(no result)", isError = true) session.add(Message( role = MessageRole.TOOL, - content = result.output, - toolCallId = call.id, + content = block.content, + toolCallId = tu.toolUseId, id = UUID.randomUUID().toString(), seq = session.nextSeq() )) } persistSession() - resetToolCallState() - responseBuffer.clear() - // Re-activate so startStreamingRound's onToken/onEnd callbacks work - activeMessageId = msgId - sendMessageInternal(msgId) - } - - private suspend fun handleToolCallCompleteLegacy(msgId: String, responseBuffer: StringBuilder) { - val preparedToolCalls = prepareToolCallsForExecution(msgId) ?: return - val state = PluginSettings.getInstance().getState() - val completedToolCalls = mutableListOf() - - logger.info( - "[CodePlanGUI Approval] executing tool-call batch " + - "msgId=$msgId toolCallCount=${preparedToolCalls.size}" - ) + // Check if cancelled while tools were running + if (runningToolMsgId != msgId) return + runningToolMsgId = null - for (toolCall in preparedToolCalls) { - completedToolCalls += executeToolCallWithApproval(msgId, toolCall, state) - } - - continueWithToolResults(msgId, responseBuffer, completedToolCalls) - } - - private fun continueWithToolResults( - msgId: String, - responseBuffer: StringBuilder, - completedToolCalls: List - ) { - logger.info( - "[CodePlanGUI Approval] continuing conversation with tool results " + - "msgId=$msgId toolCallCount=${completedToolCalls.size} " + - "results=${completedToolCalls.joinToString { "index=${it.toolCall.index}:${it.result.summarizeForLog()}" }}" - ) - // Assistant message must carry tool_calls for the OpenAI API to accept the follow-up tool result - session.add(Message( - role = MessageRole.ASSISTANT, - content = responseBuffer.toString(), - id = UUID.randomUUID().toString(), - seq = session.nextSeq(), - toolCalls = completedToolCalls.map { - ToolCallRecord( - id = it.toolCall.id, - functionName = it.toolCall.functionName, - arguments = it.toolCall.argumentsJson - ) - } - )) - completedToolCalls.forEach { - session.add(Message( - role = MessageRole.TOOL, - content = it.result.toToolResultContent(), - toolCallId = it.toolCall.id, - id = UUID.randomUUID().toString(), - seq = session.nextSeq() - )) - } - persistSession() - - // Reset state machine resetToolCallState() responseBuffer.clear() - - // The next round's startStreamingRound() will send notifyStart which - // the frontend's groupReducer handles by reusing the existing assistant - // group (still streaming). Intermediate tokens are discarded via round_end. + activeMessageId = msgId sendMessageInternal(msgId) } @@ -436,7 +423,6 @@ class ChatService(private val project: Project) : Disposable { val apiKey = ApiKeyStore.load(provider.id) ?: return val settingsState = pluginSettings.getState() val commandExecutionEnabled = settingsState.commandExecutionEnabled - val unifiedTools = settingsState.unifiedToolsEnabled && commandExecutionEnabled logger.info("[CodePlanGUI Approval] starting follow-up model round msgId=$msgId") val request = client.buildRequest( @@ -446,9 +432,7 @@ class ChatService(private val project: Project) : Disposable { temperature = settingsState.chatTemperature, maxTokens = settingsState.chatMaxTokens, stream = true, - tools = if (unifiedTools) dispatcher.buildToolsParam() - else if (commandExecutionEnabled) listOf(runCommandToolDefinition()) - else null + tools = if (commandExecutionEnabled) buildToolDefinitions() else null ) startStreamingRound(msgId, request, toolsEnabled = commandExecutionEnabled) @@ -557,25 +541,15 @@ $selection }, onFinishReason = { reason -> if (toolsEnabled && reason == "tool_calls" && activeMessageId == msgId) { - val isUnified = PluginSettings.getInstance().getState().unifiedToolsEnabled - if (isUnified) { - // Unified path: create the assistant bubble for tool steps only. - // Do NOT flush round-1 text — the formal response streams after tools complete. - if (msgId !in bridgeNotifiedStart) { - bridgeHandler?.notifyStart(msgId) - bridgeNotifiedStart.add(msgId) - } - // Clear activeMessageId to prevent onEnd from finalizing this message - // — the follow-up round will continue appending to it. - activeMessageId = null - } else { - // Legacy path: remove the assistant bubble so execution cards appear before - // the final assistant bubble. - if (msgId in bridgeNotifiedStart) { - bridgeHandler?.notifyRemoveMessage(msgId) - bridgeNotifiedStart.remove(msgId) - } + // Create the assistant bubble for tool steps only. + // Do NOT flush round-1 text — the formal response streams after tools complete. + if (msgId !in bridgeNotifiedStart) { + bridgeHandler?.notifyStart(msgId) + bridgeNotifiedStart.add(msgId) } + // Clear activeMessageId to prevent onEnd from finalizing this message + // — the follow-up round will continue appending to it. + activeMessageId = null val capturedBuffer = responseBuffer scope.launch { handleToolCallComplete(msgId, capturedBuffer) } } @@ -651,181 +625,18 @@ $selection ) } - private fun prepareToolCallsForExecution(msgId: String): List? { - val accumulatedToolCalls = toolCallAccumulator.snapshot() - if (accumulatedToolCalls.isEmpty()) { - abortStream(msgId, "AI sent a tool_calls finish_reason but no tool call deltas were captured") - return null - } - - return accumulatedToolCalls.map { accumulated -> - val toolCallId = accumulated.id ?: run { - abortStream( - msgId, - "AI sent a tool_calls finish_reason but tool call index ${accumulated.index} had no id" - ) - return null - } - val argsJson = accumulated.argumentsJson - val argsObj = try { - kotlinx.serialization.json.Json.parseToJsonElement(argsJson).jsonObject - } catch (_: Exception) { - abortStream(msgId, "AI returned malformed tool call arguments for index ${accumulated.index}: '$argsJson'") - return null - } - val command = argsObj["command"]?.jsonPrimitive?.contentOrNull ?: run { - abortStream(msgId, "AI tool call index ${accumulated.index} is missing required 'command' field") - return null - } - val description = argsObj["description"]?.jsonPrimitive?.contentOrNull ?: "" - - PreparedToolCall( - index = accumulated.index, - id = toolCallId, - functionName = accumulated.functionName ?: ShellPlatform.current().toolName(), - argumentsJson = argsJson, - command = command, - description = description - ) - } - } - - private suspend fun executeToolCallWithApproval( - msgId: String, - toolCall: PreparedToolCall, - state: SettingsState - ): CompletedToolCall { - val requestId = UUID.randomUUID().toString() - logger.info( - "[CodePlanGUI Approval] prepared approval request " + - "requestId=$requestId msgId=$msgId toolCallId=${toolCall.id} index=${toolCall.index} " + - "function=${toolCall.functionName} command=${toolCall.command.summarizeForLog()} " + - "description=${toolCall.description.summarizeForLog()}" - ) - - bridgeHandler?.notifyExecutionCard(requestId, toolCall.command, toolCall.description) - - val basePath = project.basePath ?: "" - if (CommandExecutionService.hasPathsOutsideWorkspace(toolCall.command, basePath)) { - logger.info( - "[CodePlanGUI Approval] blocked by workspace path check " + - "requestId=$requestId index=${toolCall.index} command=${toolCall.command.summarizeForLog()} " + - "basePath=${basePath.summarizeForLog()}" - ) - val result = ExecutionResult.Blocked(toolCall.command, "Command accesses paths outside the project") - bridgeHandler?.notifyExecutionStatus(requestId, "blocked", result.toToolResultContent()) - return CompletedToolCall(toolCall, result) - } - - logger.info( - "[CodePlanGUI Approval] whitelist check " + - "requestId=$requestId baseCommand=${CommandExecutionService.extractBaseCommand(toolCall.command)} " + - "whitelist=${state.commandWhitelist}" - ) - - val whitelisted = CommandExecutionService.isWhitelisted(toolCall.command, state.commandWhitelist) - if (!whitelisted) { - logger.info( - "[CodePlanGUI Approval] command not in whitelist, requesting approval " + - "requestId=$requestId index=${toolCall.index} command=${toolCall.command.summarizeForLog()}" - ) - bridgeHandler?.notifyApprovalRequest(requestId, toolCall.command, toolCall.description) - bridgeHandler?.notifyExecutionStatus(requestId, "waiting", "{}") - bridgeHandler?.notifyLog(requestId, "Waiting for approval...", "info") - - val future = CompletableFuture() - pendingApprovals[requestId] = future - pendingApprovalCommands[requestId] = toolCall.command - - val approved = try { - withContext(Dispatchers.IO) { future.get(60, TimeUnit.SECONDS) } - } catch (e: Exception) { - logger.info( - "[CodePlanGUI Approval] decision wait failed " + - "requestId=$requestId index=${toolCall.index} error=${e.javaClass.simpleName}:${e.message ?: ""}" - ) - false - } finally { - pendingApprovals.remove(requestId) - pendingApprovalCommands.remove(requestId) - } - logger.info( - "[CodePlanGUI Approval] resolved user decision " + - "requestId=$requestId index=${toolCall.index} approved=$approved" - ) - - if (!approved) { - val result = ExecutionResult.Denied(toolCall.command, "User rejected the command") - bridgeHandler?.notifyExecutionStatus(requestId, "denied", result.toToolResultContent()) - return CompletedToolCall(toolCall, result) - } - } else { - bridgeHandler?.notifyLog(requestId, "Command whitelisted, auto-approved", "info") - logger.info( - "[CodePlanGUI Approval] command is whitelisted, auto-approving " + - "requestId=$requestId index=${toolCall.index} command=${toolCall.command.summarizeForLog()}" - ) - } - - bridgeHandler?.notifyExecutionStatus(requestId, "running", "{}") - bridgeHandler?.notifyLog(requestId, "Executing: ${toolCall.command}", "info") - logger.info( - "[CodePlanGUI Approval] starting command execution " + - "requestId=$requestId index=${toolCall.index} timeoutSeconds=${state.commandTimeoutSeconds} " + - "command=${toolCall.command.summarizeForLog()}" - ) - val execService = CommandExecutionService.getInstance(project) - val result = execService.executeAsyncWithStream( - toolCall.command, - state.commandTimeoutSeconds - ) { line, isError -> - bridgeHandler?.notifyLog(requestId, line, if (isError) "stderr" else "stdout") - } - val bridgeStatus = if (result is ExecutionResult.TimedOut) "timeout" else "done" - bridgeHandler?.notifyExecutionStatus(requestId, bridgeStatus, result.toToolResultContent()) - val durationMs = when (result) { - is ExecutionResult.Success -> result.durationMs - is ExecutionResult.Failed -> result.durationMs - else -> 0L - } - val exitCode = when (result) { - is ExecutionResult.Success -> result.exitCode - is ExecutionResult.Failed -> result.exitCode - else -> -1 - } - bridgeHandler?.notifyLog(requestId, "Finished: exit $exitCode, ${durationMs}ms", "info") - logger.info( - "[CodePlanGUI Approval] command execution finished " + - "requestId=$requestId index=${toolCall.index} bridgeStatus=$bridgeStatus result=${result.summarizeForLog()}" - ) - return CompletedToolCall(toolCall, result) - } - private fun String.summarizeForLog(maxLength: Int = 160): String { val singleLine = replace('\n', ' ').replace('\r', ' ').trim() return if (singleLine.length <= maxLength) singleLine else singleLine.take(maxLength) + "..." } - private fun ExecutionResult.summarizeForLog(): String = when (this) { - is ExecutionResult.Success -> - "success exit=$exitCode durationMs=$durationMs stdoutLen=${stdout.length} stderrLen=${stderr.length} truncated=$truncated" - is ExecutionResult.Failed -> - "failed exit=$exitCode durationMs=$durationMs stdoutLen=${stdout.length} stderrLen=${stderr.length} truncated=$truncated" - is ExecutionResult.Blocked -> - "blocked reason=${reason.summarizeForLog()}" - is ExecutionResult.Denied -> - "denied reason=${reason.summarizeForLog()}" - is ExecutionResult.TimedOut -> - "timeout timeoutSeconds=$timeoutSeconds stdoutLen=${stdout.length}" - } - override fun dispose() { activeStream?.cancel() pendingApprovals.values.forEach { it.complete(false) } pendingApprovals.clear() pendingApprovalCommands.clear() + runningToolMsgId = null bridgeNotifiedStart.clear() - toolRegistry.dispose() scope.cancel() } @@ -886,20 +697,6 @@ $selection val contextLabel: String? = null ) - private data class PreparedToolCall( - val index: Int, - val id: String, - val functionName: String, - val argumentsJson: String, - val command: String, - val description: String - ) - - private data class CompletedToolCall( - val toolCall: PreparedToolCall, - val result: ExecutionResult - ) - @kotlinx.serialization.Serializable private data class RestoredMessagePayload( val id: String, diff --git a/src/main/kotlin/com/github/codeplangui/execution/ExecutionResult.kt b/src/main/kotlin/com/github/codeplangui/execution/ExecutionResult.kt index 28574c1..a26e537 100644 --- a/src/main/kotlin/com/github/codeplangui/execution/ExecutionResult.kt +++ b/src/main/kotlin/com/github/codeplangui/execution/ExecutionResult.kt @@ -36,18 +36,4 @@ sealed class ExecutionResult { kotlinx.serialization.serializer(), s ) - /** Convert to unified ToolResult for the new tool system. */ - fun toToolResult(): ToolResult = when (this) { - is Success -> ToolResult(ok = true, output = buildString { - if (stdout.isNotEmpty()) append(stdout) - if (stderr.isNotEmpty()) { - if (isNotEmpty()) append("\n") - append(stderr) - } - }.ifEmpty { "Command completed with exit code $exitCode" }) - is Failed -> ToolResult(ok = false, output = stderr.ifEmpty { "Command failed with exit code $exitCode" }) - is Blocked -> ToolResult(ok = false, output = reason) - is Denied -> ToolResult(ok = false, output = reason) - is TimedOut -> ToolResult(ok = false, output = "Command timed out after ${timeoutSeconds}s") - } } diff --git a/src/main/kotlin/com/github/codeplangui/execution/FileChangeReview.kt b/src/main/kotlin/com/github/codeplangui/execution/FileChangeReview.kt deleted file mode 100644 index 222a5bd..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/FileChangeReview.kt +++ /dev/null @@ -1,131 +0,0 @@ -package com.github.codeplangui.execution - -import com.github.codeplangui.settings.SettingsState -import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.project.Project -import com.intellij.openapi.ui.Messages -import java.util.concurrent.CompletableFuture -import java.util.concurrent.TimeUnit - -/** - * Manages file change review via IDE-native dialogs. - * Supports session-level trust mode to reduce approval fatigue. - * - * @deprecated Replaced by ChangeReviewStrategy + DialogReview/EditorInlineReview. - * Kept for backward compatibility during migration — session trust state is synced. - */ -@Deprecated("Use ChangeReviewStrategy implementations instead", ReplaceWith("DialogReview")) -class FileChangeReview { - - @Volatile - var sessionFileWriteTrusted: Boolean = false - private set - - fun resetSessionTrust() { - sessionFileWriteTrusted = false - } - - fun setSessionTrusted() { - sessionFileWriteTrusted = true - } - - /** - * Review a file modification. Returns true if the change is approved. - * In trust mode, skips the dialog and returns true directly. - * - * First version: uses simple Yes/No confirmation dialog. - * Future: IntelliJ DiffDialog integration. - */ - fun reviewFileChange( - project: Project, - path: String, - oldContent: String, - newContent: String, - settings: SettingsState - ): Boolean { - if (sessionFileWriteTrusted) return true - - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - // Compute simple diff stats - val oldLines = oldContent.lines().size - val newLines = newContent.lines().size - val added = (newLines - oldLines).coerceAtLeast(0) - val removed = (oldLines - newLines).coerceAtLeast(0) - - val message = buildString { - appendLine("Apply changes to $path?") - appendLine() - appendLine("Lines: +$added / -$removed (was $oldLines, now $newLines)") - appendLine() - // Show first few changed lines as preview - val oldSet = oldContent.lines().toSet() - val newLinesList = newContent.lines() - val changed = newLinesList.filter { it !in oldSet }.take(5) - if (changed.isNotEmpty()) { - appendLine("--- New/changed lines (preview) ---") - changed.forEach { appendLine(it) } - } - } - - val result = Messages.showYesNoDialog( - project, - message, - "File Change Review: $path", - Messages.getQuestionIcon() - ) - future.complete(result == Messages.YES) - } - - return future.get(60, TimeUnit.SECONDS) - } - - /** - * Review a new file creation. Returns true if creation is approved. - * In trust mode, skips the dialog and returns true directly. - */ - fun reviewNewFile( - project: Project, - path: String, - content: String, - settings: SettingsState - ): Boolean { - if (sessionFileWriteTrusted) return true - - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - val lineCount = content.lines().size - val sizeBytes = content.toByteArray().size - - val message = buildString { - appendLine("Create new file?") - appendLine() - appendLine("Path: $path") - appendLine("Size: ${formatSize(sizeBytes)} / $lineCount lines") - appendLine() - appendLine("--- Preview (first 20 lines) ---") - content.lines().take(20).forEach { appendLine(it) } - if (lineCount > 20) appendLine("... ($lineCount lines total)") - } - - val result = Messages.showOkCancelDialog( - project, - message, - "Create New File", - "Create", "Cancel", - Messages.getQuestionIcon() - ) - future.complete(result == Messages.OK) - } - - return future.get(60, TimeUnit.SECONDS) - } - - private fun formatSize(bytes: Int): String = when { - bytes < 1024 -> "$bytes B" - bytes < 1024 * 1024 -> "${bytes / 1024} KB" - else -> "${bytes / (1024 * 1024)} MB" - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/FileWriteLock.kt b/src/main/kotlin/com/github/codeplangui/execution/FileWriteLock.kt deleted file mode 100644 index 37cb375..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/FileWriteLock.kt +++ /dev/null @@ -1,27 +0,0 @@ -package com.github.codeplangui.execution - -import java.util.concurrent.ConcurrentHashMap -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock - -/** - * File-level write lock for serializing concurrent writes to the same file. - * Prevents data races when multiple tool calls target the same path. - */ -class FileWriteLock { - private val locks = ConcurrentHashMap() - - suspend fun withFileLock(path: String, block: suspend () -> T): T { - val mutex = locks.computeIfAbsent(path) { Mutex() } - return mutex.withLock { - block() - } - // Do NOT remove mutex from map after release — another coroutine - // may be waiting, and removal + computeIfAbsent creates a new Mutex, - // breaking serialization semantics. - } - - fun clear() { - locks.clear() - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/InlineChangeHighlighter.kt b/src/main/kotlin/com/github/codeplangui/execution/InlineChangeHighlighter.kt deleted file mode 100644 index 88fa2d9..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/InlineChangeHighlighter.kt +++ /dev/null @@ -1,34 +0,0 @@ -package com.github.codeplangui.execution - -import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.editor.Document -import com.intellij.openapi.editor.Editor -import com.intellij.openapi.editor.event.DocumentEvent -import com.intellij.openapi.editor.event.DocumentListener -import com.intellij.openapi.fileEditor.FileDocumentManager -import com.intellij.openapi.fileEditor.FileEditorManager -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.VirtualFile - -/** - * Tracks inline change highlights for trusted file modifications. - * In the first version, this relies on IntelliJ's built-in VCS change highlighting - * (line markers + gutter colors), which works automatically when files are modified. - * - * Future iterations can add custom highlighting for AI-specific changes. - */ -class InlineChangeHighlighter(private val project: Project) { - - private val logger = Logger.getInstance(InlineChangeHighlighter::class.java) - - /** - * Notifies the highlighter that a file was changed by a tool. - * For now, this is a no-op — IntelliJ's built-in VCS integration - * handles gutter change markers automatically. - */ - fun onFileChanged(virtualFile: VirtualFile) { - // IntelliJ's built-in line-level change tracking (changelist-based) - // already provides gutter markers for modified files. - // No custom highlighting needed in v1. - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/PermissionMode.kt b/src/main/kotlin/com/github/codeplangui/execution/PermissionMode.kt deleted file mode 100644 index 0327626..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/PermissionMode.kt +++ /dev/null @@ -1,17 +0,0 @@ -package com.github.codeplangui.execution - -/** - * Permission levels for tool execution. Ordered: READ_ONLY < WORKSPACE_WRITE < DANGER_FULL_ACCESS. - */ -enum class PermissionMode(val level: Int) { - READ_ONLY(0), - WORKSPACE_WRITE(1), - DANGER_FULL_ACCESS(2); - - fun gte(other: PermissionMode): Boolean = this.level >= other.level - - companion object { - fun fromString(value: String?): PermissionMode = - values().find { it.name.equals(value, ignoreCase = true) } ?: WORKSPACE_WRITE - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/PostEditPipeline.kt b/src/main/kotlin/com/github/codeplangui/execution/PostEditPipeline.kt deleted file mode 100644 index 800ea7d..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/PostEditPipeline.kt +++ /dev/null @@ -1,35 +0,0 @@ -package com.github.codeplangui.execution - -import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.project.Project -import com.intellij.openapi.vfs.VirtualFile - -/** - * Post-edit quality pipeline: optimize imports → reformat → inspection. - * Runs after file write operations to maintain code quality. - * - * First version: no-op stub. IntelliJ's built-in real-time inspections - * handle code quality feedback automatically. Future iterations can add - * programmatic optimize-imports/reformat/inspection here. - */ -class PostEditPipeline(private val project: Project) { - - private val logger = Logger.getInstance(PostEditPipeline::class.java) - - data class InspectionResult( - val errors: List, - val warnings: List, - val info: List - ) - - data class Finding(val line: Int, val severity: String, val message: String) - - /** - * Best-effort post-write pipeline. Returns inspection feedback if available. - */ - fun runAfterWriteSync(virtualFile: VirtualFile): String? { - // First version: rely on IntelliJ's built-in real-time inspections. - // Future: add optimizeImports + reformat + programmatic inspection. - return null - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolCallDispatcher.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolCallDispatcher.kt deleted file mode 100644 index 31a635b..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolCallDispatcher.kt +++ /dev/null @@ -1,606 +0,0 @@ -package com.github.codeplangui.execution - -import com.github.codeplangui.BridgeHandler -import com.github.codeplangui.execution.executors.BashExecutor -import com.github.codeplangui.execution.review.ChangeReviewStrategy -import com.github.codeplangui.execution.review.EditorInlineReview -import com.github.codeplangui.settings.PluginSettings -import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.command.WriteCommandAction -import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.vfs.LocalFileSystem -import kotlinx.coroutines.CancellableContinuation -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll -import kotlinx.coroutines.cancel -import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.suspendCancellableCoroutine -import kotlinx.coroutines.withTimeout -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonObject -import kotlinx.serialization.json.jsonPrimitive -import java.io.File -import java.util.UUID -import java.util.concurrent.ConcurrentHashMap -import kotlin.coroutines.resume -import kotlin.coroutines.resumeWithException -import kotlin.math.min - -/** - * Unified tool dispatcher. Owns the complete dispatch pipeline: - * lookup -> parse -> dynamic permission -> deny_rules -> allow_rules -> - * session mode -> approval -> execution -> review -> write -> output truncation. - */ -class ToolCallDispatcher( - private val registry: ToolRegistry, - private val fileChangeReview: FileChangeReview, - private val fileWriteLock: FileWriteLock, - private val project: com.intellij.openapi.project.Project -) { - private val logger = Logger.getInstance(ToolCallDispatcher::class.java) - private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) - - // Review strategy — selected based on settings - private val reviewStrategy: ChangeReviewStrategy - - // Approval suspension - private val pendingApprovals = ConcurrentHashMap>() - - // Hooks - private val hooks = mutableListOf() - - // Rate limiting - private var roundToolCallCount = 0 - private val consecutiveCalls = mutableMapOf() - - init { - val settings = PluginSettings.getInstance().getState() - reviewStrategy = when (settings.reviewMode) { - "dialog" -> com.github.codeplangui.execution.review.DialogReview() - else -> EditorInlineReview(project) - } - // Sync existing session trust state - if (fileChangeReview.sessionFileWriteTrusted) { - reviewStrategy.sessionTrusted = true - } - } - - companion object { - const val MAX_TOOL_OUTPUT_BYTES = 50 * 1024 // 50KB - const val MAX_TOOL_CALLS_PER_ROUND = 20 - const val MAX_TOOL_CALLS_PER_RESPONSE = 10 - const val CONSECUTIVE_CALL_WARNING = 5 - const val APPROVAL_TIMEOUT_MS = 60_000L - } - - fun addHook(hook: ToolHook) { - hooks.add(hook) - } - - /** Build the tools parameter for API requests. */ - fun buildToolsParam(): List { - return registry.buildOpenAiTools() - } - - /** Reset per-session state (new chat). */ - fun resetSession() { - fileChangeReview.resetSessionTrust() - reviewStrategy.resetSessionTrust() - roundToolCallCount = 0 - consecutiveCalls.clear() - cancelAllPendingApprovals() - } - - /** Reset per-round state. */ - fun resetRound() { - roundToolCallCount = 0 - } - - /** Called by Bridge when user responds to an approval request. */ - fun onApprovalResponse(requestId: String, decision: String) { - val cont = pendingApprovals.remove(requestId) ?: return - val approved = decision == "allow" - if (approved) { - cont.resume(true) - } else { - cont.resume(false) - } - } - - /** - * Dispatch a single tool call through the full pipeline. - */ - suspend fun dispatch( - toolName: String, - argsJson: String, - msgId: String, - bridgeHandler: BridgeHandler? - ): ToolResult { - return try { - dispatchInternal(toolName, argsJson, msgId, bridgeHandler) - } catch (e: Exception) { - ToolResult(ok = false, output = "Tool execution error: ${e.message}") - } - } - - /** - * Dispatch multiple tool calls with concurrent scheduling. - * Results are returned in original order. - */ - suspend fun dispatchAll( - calls: List, - msgId: String, - bridgeHandler: BridgeHandler? - ): List> { - // Rate limit check - if (calls.size > MAX_TOOL_CALLS_PER_RESPONSE) { - return calls.map { call -> - call to ToolResult( - ok = false, - output = "Too many tool calls in a single response (${calls.size} > $MAX_TOOL_CALLS_PER_RESPONSE)" - ) - } - } - - // Partition into batches - val batches = partitionToolCalls(calls) - - // Execute batches sequentially - val results = arrayOfNulls>(calls.size) - - for (batch in batches) { - if (batch.isConcurrencySafe && batch.entries.size > 1) { - // Concurrent batch - val hasBashError = java.util.concurrent.atomic.AtomicBoolean(false) - val batchResults = coroutineScope { - batch.entries.map { (index, call) -> - async { - if (hasBashError.get() && isBashCommand(call.name)) { - index to (call to ToolResult( - ok = false, - output = "Skipped: previous bash command in batch failed" - )) - } else { - val result = dispatch(call.name, call.arguments, msgId, bridgeHandler) - if (!result.ok && isBashCommand(call.name)) { - hasBashError.set(true) - } - index to (call to result) - } - } - }.awaitAll() - } - for ((index, result) in batchResults) { - results[index] = result - } - } else { - // Serial batch - for ((index, call) in batch.entries) { - results[index] = call to dispatch(call.name, call.arguments, msgId, bridgeHandler) - } - } - } - - return results.map { it!! } - } - - private suspend fun dispatchInternal( - toolName: String, - argsJson: String, - msgId: String, - bridgeHandler: BridgeHandler? - ): ToolResult { - // Rate limit check - roundToolCallCount++ - if (roundToolCallCount > MAX_TOOL_CALLS_PER_ROUND) { - return ToolResult(ok = false, output = "Round tool call limit exceeded ($roundToolCallCount > $MAX_TOOL_CALLS_PER_ROUND)") - } - - // Consecutive call warning - val count = consecutiveCalls.getOrDefault(toolName, 0) + 1 - consecutiveCalls[toolName] = count - - // 1. Build context - val settings = PluginSettings.getInstance().getState() - - // 2. Parse arguments - val input: JsonObject = try { - Json.parseToJsonElement(argsJson).jsonObject - } catch (_: Exception) { - return ToolResult(ok = false, output = "Invalid arguments: not valid JSON") - } - - // 3. Find tool - val spec = registry.find(toolName) - ?: return ToolResult(ok = false, output = "Unknown tool: $toolName") - - // 4. Build summary and emit tool_step_start - val stepRequestId = UUID.randomUUID().toString() - val summary = buildTargetSummary(toolName, input) - bridgeHandler?.notifyToolStepStart(msgId, stepRequestId, toolName, summary) - - // 5. Dynamic permission resolution - val requiredPermission = resolvePermission(spec, input) - - // 6. Authorization - val authDecision = authorize(toolName, input, requiredPermission, settings) - when (authDecision) { - is AuthDecision.Deny -> { - bridgeHandler?.notifyToolStepEnd(msgId, stepRequestId, false, authDecision.reason, 0L) - return ToolResult(ok = false, output = authDecision.reason) - } - is AuthDecision.Ask -> { - val requestId = UUID.randomUUID().toString() - val toolInput = if (toolName == "run_command" || toolName == "run_powershell") { - input["command"]?.jsonPrimitive?.contentOrNull ?: argsJson - } else { - argsJson.take(200) - } - val description = input["description"]?.jsonPrimitive?.contentOrNull ?: "" - - bridgeHandler?.notifyApprovalRequest( - requestId = requestId, - command = toolInput, - description = description - ) - - val approved = awaitApproval(requestId) - if (!approved) { - bridgeHandler?.notifyToolStepEnd(msgId, stepRequestId, false, "User denied permission for $toolName", 0L) - return ToolResult(ok = false, output = "User denied permission for $toolName") - } - } - is AuthDecision.Allow -> { /* proceed */ } - } - - // 7. Pre-Hooks - var intercepted: ToolResult? = null - for (hook in hooks) { - try { - val result = hook.beforeExecute(toolName, input) - if (result != null) { - intercepted = result - break // Short-circuit - } - } catch (e: Exception) { - logger.warn("Pre-Hook threw exception for $toolName", e) - } - } - - // 8. Execute (if not intercepted) with timing - val startTime = System.currentTimeMillis() - val rawResult = intercepted ?: runWithFileLock(spec, input, toolName) - val durationMs = System.currentTimeMillis() - startTime - - // 8.5 File change review (pendingReview handling) - val finalResult = if (rawResult.pendingReview != null) { - handlePendingReview(rawResult, requestId = UUID.randomUUID().toString(), msgId, stepRequestId, bridgeHandler, durationMs) - } else { - rawResult - } - - - // 9. Output truncation - val truncatedResult = truncateOutput(finalResult, msgId) - - // 10. Emit tool_step_end - val diffStats = extractDiffStats(toolName, input, truncatedResult) - bridgeHandler?.notifyToolStepEnd( - msgId, stepRequestId, truncatedResult.ok, truncatedResult.output, durationMs, diffStats - ) - - // 11. Post-Hooks - for (hook in hooks) { - try { - hook.afterExecute(toolName, input, truncatedResult) - } catch (e: Exception) { - logger.warn("Post-Hook threw exception for $toolName", e) - } - } - - return truncatedResult - } - - /** - * Handle pending file change review: get user approval, then write file + run post-edit. - */ - private suspend fun handlePendingReview( - result: ToolResult, - requestId: String, - msgId: String, - stepRequestId: String, - bridgeHandler: BridgeHandler?, - durationMs: Long - ): ToolResult { - val review = result.pendingReview!! - - val approved = if (review.isNewFile) { - reviewStrategy.reviewNewFile( - project, requestId, review.path, - review.newContentForCreate ?: "" - ) - } else { - reviewStrategy.reviewFileChange( - project, requestId, review.path, - review.originalContent, review.newContent - ) - } - - // Sync trust state back to FileChangeReview - if (reviewStrategy.sessionTrusted && !fileChangeReview.sessionFileWriteTrusted) { - fileChangeReview.setSessionTrusted() - } - - if (!approved) { - bridgeHandler?.notifyToolStepEnd(msgId, stepRequestId, false, "User rejected changes", durationMs) - return ToolResult(ok = false, output = "User rejected changes") - } - - // Approved: write the file - val writeOk = writeFileAfterApproval(review) - if (!writeOk) { - return ToolResult(ok = false, output = "Failed to write file: ${review.path}") - } - - // Run post-edit pipeline - val postEditResult = runPostEditPipeline(review.path) - - val lineCount = review.newContent.lines().size - val oldLines = review.originalContent.lines().size - val diffLines = Math.abs(lineCount - oldLines) - val changeType = if (lineCount > oldLines) "+$diffLines" else "-$diffLines" - - val output = buildString { - if (review.isNewFile) { - append("File created: ${review.path} ($lineCount lines)") - } else { - append("File edited successfully: ${review.path} ($changeType lines)") - } - if (postEditResult != null) { - append("\n\n") - append(postEditResult) - } - } - return ToolResult(ok = true, output = output) - } - - private fun writeFileAfterApproval(review: FileChangeReviewData): Boolean { - if (review.isNewFile) { - val file = File(review.path) - file.parentFile?.mkdirs() - } - return try { - ApplicationManager.getApplication().invokeAndWait { - WriteCommandAction.runWriteCommandAction(project) { - val file = File(review.path) - file.writeText(review.newContent) - val vf = LocalFileSystem.getInstance().refreshAndFindFileByIoFile(file) - vf?.refresh(false, false) - } - } - true - } catch (e: Exception) { - logger.warn("Failed to write file after approval: ${review.path}", e) - false - } - } - - private fun runPostEditPipeline(path: String): String? { - return try { - val vf = LocalFileSystem.getInstance().findFileByIoFile(File(path)) ?: return null - PostEditPipeline(project).runAfterWriteSync(vf) - } catch (_: Exception) { - null - } - } - - private fun buildTargetSummary(toolName: String, input: JsonObject): String { - return when (toolName) { - "read_file" -> input["path"]?.jsonPrimitive?.contentOrNull ?: toolName - "list_files" -> input["path"]?.jsonPrimitive?.contentOrNull ?: "." - "grep_files" -> "\"${input["pattern"]?.jsonPrimitive?.contentOrNull ?: ""}\"" - "edit_file" -> input["path"]?.jsonPrimitive?.contentOrNull ?: toolName - "write_file" -> input["path"]?.jsonPrimitive?.contentOrNull ?: toolName - "run_command", "run_powershell" -> { - val cmd = input["command"]?.jsonPrimitive?.contentOrNull ?: toolName - if (cmd.length > 60) cmd.take(57) + "..." else cmd - } - else -> toolName - } - } - - private fun extractDiffStats(toolName: String, input: JsonObject, result: ToolResult): String? { - if (toolName != "edit_file" && toolName != "write_file") return null - return null - } - - private fun resolvePermission(spec: ToolSpec, input: JsonObject): PermissionMode { - // Bash commands use dynamic classification - if (spec.name == "run_command" || spec.name == "run_powershell") { - val command = input["command"]?.jsonPrimitive?.contentOrNull ?: return PermissionMode.DANGER_FULL_ACCESS - return BashExecutor().classifyPermission(command) - } - return spec.requiredPermission - } - - private fun authorize( - toolName: String, - input: JsonObject, - requiredPermission: PermissionMode, - settings: com.github.codeplangui.settings.SettingsState - ): AuthDecision { - val sessionMode = PermissionMode.fromString(settings.permissionMode) - - // deny_rules check for bash commands - if (toolName == "run_command" || toolName == "run_powershell") { - val command = input["command"]?.jsonPrimitive?.contentOrNull ?: return AuthDecision.Deny("Missing command") - val bashExecutor = BashExecutor() - - // Check deny rules via executor (already done in execute, but check here for pre-execution denial) - val denied = checkDenyRulesEarly(command) - if (denied != null) return AuthDecision.Deny(denied) - - // Whitelist check - if (CommandExecutionService.isWhitelisted(command, settings.commandWhitelist)) { - if (sessionMode >= requiredPermission) return AuthDecision.Allow - } - } - - // Path traversal check for file tools - val path = input["path"]?.jsonPrimitive?.contentOrNull - if (path != null && (path.contains("../") || path.contains("..\\"))) { - return AuthDecision.Deny("Path traversal detected") - } - - // Session mode check - if (sessionMode >= requiredPermission) { - return AuthDecision.Allow - } - - // Trusted file write - if ((toolName == "edit_file" || toolName == "write_file") && fileChangeReview.sessionFileWriteTrusted) { - return AuthDecision.Allow - } - - // Fallback: ask - return AuthDecision.Ask - } - - private fun checkDenyRulesEarly(command: String): String? { - val cmd = command.lowercase() - // Path traversal (case-insensitive, handle URL encoding) - if (cmd.contains("../") || cmd.contains("..\\") || - cmd.contains("..%2f") || cmd.contains("..%5c")) return "Path traversal detected" - // Dangerous delete - if (Regex("""rm\s+(-\w*\s*)*(-r|--recursive).*\s+(/|~)""", RegexOption.IGNORE_CASE).containsMatchIn(cmd)) - return "Dangerous delete command detected" - // Network exfiltration - if (Regex("""(\|\s*(curl|wget)\s)|(>\s*/dev/tcp/)""", RegexOption.IGNORE_CASE).containsMatchIn(cmd)) - return "Potential network exfiltration detected" - // Fork bomb - if (Regex(""":\(\)\{.*:\|:&\}|fork\s*bomb""", RegexOption.IGNORE_CASE).containsMatchIn(cmd)) - return "Fork bomb pattern detected" - // Privilege escalation - if (Regex("""sudo\s+|chmod\s+[0-7]*77|chown\s+""", RegexOption.IGNORE_CASE).containsMatchIn(cmd)) - return "Privilege escalation detected" - return null - } - - private suspend fun runWithFileLock(spec: ToolSpec, input: JsonObject, toolName: String): ToolResult { - val settings = PluginSettings.getInstance().getState() - val project = resolveProject() - val cwd = project?.basePath ?: return ToolResult(ok = false, output = "Project path unavailable") - val context = ToolContext(project = project, cwd = cwd, settings = settings) - - val needsLock = !spec.isConcurrencySafe(input) - return if (needsLock) { - val path = input["path"]?.jsonPrimitive?.contentOrNull ?: toolName - fileWriteLock.withFileLock(path) { - spec.executor.execute(input, context) - } - } else { - spec.executor.execute(input, context) - } - } - - private suspend fun awaitApproval(requestId: String): Boolean { - return try { - withTimeout(APPROVAL_TIMEOUT_MS) { - suspendCancellableCoroutine { cont -> - pendingApprovals[requestId] = cont - cont.invokeOnCancellation { - pendingApprovals.remove(requestId) - } - } - } - } catch (_: kotlinx.coroutines.TimeoutCancellationException) { - pendingApprovals.remove(requestId) - false - } - } - - private fun cancelAllPendingApprovals() { - pendingApprovals.forEach { (_, cont) -> - try { cont.cancel() } catch (_: Exception) {} - } - pendingApprovals.clear() - } - - private fun truncateOutput(result: ToolResult, msgId: String): ToolResult { - if (result.output.toByteArray().size <= MAX_TOOL_OUTPUT_BYTES) return result - - val totalBytes = result.output.toByteArray().size - val truncatedOutput = String( - result.output.toByteArray(), - 0, - min(MAX_TOOL_OUTPUT_BYTES, result.output.toByteArray().size) - ) - - // Write full output to temp file - val tmpDir = File(System.getProperty("java.io.tmpdir"), "codeplan-tool-output") - tmpDir.mkdirs() - val tmpFile = File(tmpDir, "tool-output-$msgId-${System.currentTimeMillis()}.log") - tmpFile.writeText(result.output) - - return result.copy( - output = truncatedOutput + - "\n\n... [OUTPUT TRUNCATED: $totalBytes bytes total, showing first 50KB]", - truncated = true, - totalBytes = totalBytes, - outputPath = tmpFile.absolutePath - ) - } - - private fun partitionToolCalls(calls: List): List { - val result = mutableListOf() - for ((index, call) in calls.withIndex()) { - val spec = registry.find(call.name) - val input = try { - Json.parseToJsonElement(call.arguments).jsonObject - } catch (_: Exception) { - null - } - val safe = spec?.let { s -> input?.let { i -> s.isConcurrencySafe(i) } } ?: false - - if (safe && result.isNotEmpty() && result.last().isConcurrencySafe) { - val last = result.last() - result[result.lastIndex] = last.copy( - entries = last.entries + IndexedValue(index, call) - ) - } else { - result.add(Batch(safe, listOf(IndexedValue(index, call)))) - } - } - return result - } - - private fun isBashCommand(name: String): Boolean = - name == "run_command" || name == "run_powershell" - - private fun resolveProject(): com.intellij.openapi.project.Project = project -} - -// Helper types - -data class PendingToolCall( - val id: String, - val name: String, - val arguments: String, - val index: Int -) - -data class Batch( - val isConcurrencySafe: Boolean, - val entries: List> -) - -sealed class AuthDecision { - data object Allow : AuthDecision() - data class Deny(val reason: String) : AuthDecision() - data object Ask : AuthDecision() -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolContext.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolContext.kt deleted file mode 100644 index 2ae534f..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolContext.kt +++ /dev/null @@ -1,13 +0,0 @@ -package com.github.codeplangui.execution - -import com.github.codeplangui.settings.SettingsState -import com.intellij.openapi.project.Project - -/** - * Execution context passed to every tool executor. - */ -data class ToolContext( - val project: Project, - val cwd: String, - val settings: SettingsState -) diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolExecutor.kt deleted file mode 100644 index 694d7d0..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolExecutor.kt +++ /dev/null @@ -1,12 +0,0 @@ -package com.github.codeplangui.execution - -import kotlinx.serialization.json.JsonObject - -/** - * Interface that every tool executor implements. - * Implementations must NOT throw — return ToolResult(ok=false) on error. - * All IO operations must run on Dispatchers.IO. - */ -interface ToolExecutor { - suspend fun execute(input: JsonObject, context: ToolContext): ToolResult -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolHook.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolHook.kt deleted file mode 100644 index 4d6c4b6..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolHook.kt +++ /dev/null @@ -1,25 +0,0 @@ -package com.github.codeplangui.execution - -import kotlinx.serialization.json.JsonObject - -/** - * Hook extension point for cross-cutting concerns around tool execution. - * Registered in ToolCallDispatcher via addHook(). - * - * Pre-Hooks use short-circuit semantics: the first non-null return stops the chain. - * Post-Hooks always all execute, even on failure or interception. - */ -interface ToolHook { - /** - * Called before tool execution. - * Return null → continue execution. - * Return ToolResult → intercept, skip executor, return this result. - */ - suspend fun beforeExecute(toolName: String, input: JsonObject): ToolResult? = null - - /** - * Called after tool execution (success, failure, or interception). - * For logging, metrics, result enrichment. - */ - suspend fun afterExecute(toolName: String, input: JsonObject, result: ToolResult) {} -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolRegistry.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolRegistry.kt deleted file mode 100644 index 2b2f7c1..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolRegistry.kt +++ /dev/null @@ -1,79 +0,0 @@ -package com.github.codeplangui.execution - -import com.github.codeplangui.api.FunctionDefinition -import com.github.codeplangui.api.ToolDefinition -import com.intellij.openapi.Disposable -import com.intellij.openapi.diagnostic.Logger -import com.intellij.openapi.util.Disposer -import kotlinx.serialization.json.JsonObject - -/** - * Central registry for all tools (built-in + MCP). - * Bound to IntelliJ Project lifecycle via Disposable. - */ -class ToolRegistry(private val parentDisposable: com.intellij.openapi.Disposable) : Disposable { - - private val logger = Logger.getInstance(ToolRegistry::class.java) - - private val tools = mutableMapOf() - private val disposers = mutableListOf<() -> Unit>() - - init { - Disposer.register(parentDisposable, this) - } - - /** List all registered tools. */ - fun list(): List = tools.values.toList() - - /** Find a tool by name. */ - fun find(name: String): ToolSpec? = tools[name] - - /** Register tools. Skips duplicates (same name) silently. */ - fun addTools(specs: List) { - for (spec in specs) { - if (tools.containsKey(spec.name)) { - logger.info("Tool '${spec.name}' already registered, skipping") - continue - } - tools[spec.name] = spec - logger.info("Registered tool: ${spec.name}") - } - } - - /** Remove a tool by name (for MCP server disconnect). */ - fun removeTool(name: String) { - tools.remove(name) - logger.info("Removed tool: $name") - } - - /** Register a cleanup function (called in reverse order on dispose). */ - fun addDisposer(fn: () -> Unit) { - disposers.add(fn) - } - - /** Build OpenAI API tools parameter from all registered tools. */ - fun buildOpenAiTools(): List { - return tools.values.map { spec -> - ToolDefinition( - type = "function", - function = FunctionDefinition( - name = spec.name, - description = spec.description, - parameters = spec.inputSchema - ) - ) - } - } - - override fun dispose() { - disposers.reversed().forEach { fn -> - try { - fn() - } catch (e: Exception) { - logger.warn("Disposer threw exception", e) - } - } - disposers.clear() - tools.clear() - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolResult.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolResult.kt deleted file mode 100644 index bdf4515..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolResult.kt +++ /dev/null @@ -1,42 +0,0 @@ -package com.github.codeplangui.execution - -import kotlinx.serialization.Serializable - -/** - * Unified result type for all tools (built-in and MCP). - * Every tool returns this, never throws. - */ -data class ToolResult( - val ok: Boolean, - val output: String, - val awaitUser: Boolean = false, - val backgroundTask: BackgroundTask? = null, - val truncated: Boolean = false, - val totalBytes: Int? = null, - val outputPath: String? = null, - val pendingReview: FileChangeReviewData? = null -) - -/** - * Carries pending file change data from executor to dispatcher. - * Dispatcher delegates to ChangeReviewStrategy for approval, then writes. - */ -data class FileChangeReviewData( - val path: String, - val originalContent: String, - val newContent: String, - val isNewFile: Boolean = false, - val newContentForCreate: String? = null -) - -@Serializable -data class BackgroundTask( - val id: String, - val command: String, - val status: BackgroundTaskStatus -) - -@Serializable -enum class BackgroundTaskStatus { - PENDING, RUNNING, COMPLETED, FAILED, CANCELLED -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolSpec.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolSpec.kt deleted file mode 100644 index 004f49b..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolSpec.kt +++ /dev/null @@ -1,21 +0,0 @@ -package com.github.codeplangui.execution - -import kotlinx.serialization.json.JsonObject - -/** - * Tool registration info. Each registered tool has one ToolSpec. - * - * Dynamic capabilities (isConcurrencySafe, isReadOnly, isDestructive) accept - * the parsed input and return a boolean — e.g. run_command decides based on - * the concrete command, while read_file always returns true. - */ -data class ToolSpec( - val name: String, - val description: String, - val inputSchema: JsonObject, - val requiredPermission: PermissionMode, - val executor: ToolExecutor, - val isConcurrencySafe: (input: JsonObject) -> Boolean = { false }, - val isReadOnly: (input: JsonObject) -> Boolean = { false }, - val isDestructive: (input: JsonObject) -> Boolean = { false } -) diff --git a/src/main/kotlin/com/github/codeplangui/execution/ToolSpecs.kt b/src/main/kotlin/com/github/codeplangui/execution/ToolSpecs.kt deleted file mode 100644 index 4262994..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/ToolSpecs.kt +++ /dev/null @@ -1,225 +0,0 @@ -package com.github.codeplangui.execution - -import com.github.codeplangui.execution.executors.* -import kotlinx.serialization.json.buildJsonArray -import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.jsonPrimitive -import kotlinx.serialization.json.put -import kotlinx.serialization.json.JsonObject - -/** - * ToolSpec definitions for all 6 built-in tools. - * The bashExecutor is shared for dynamic permission classification. - */ -class ToolSpecs { - - private val bashExecutor = BashExecutor() - - val readFileExecutor = ReadFileExecutor() - val listFilesExecutor = ListFilesExecutor() - val grepFilesExecutor = GrepFilesExecutor() - val editFileExecutor = EditFileExecutor() - val writeFileExecutor = WriteFileExecutor() - - fun allSpecs(): List = listOf( - runCommandSpec(), - readFileSpec(), - listFilesSpec(), - grepFilesSpec(), - editFileSpec(), - writeFileSpec() - ) - - fun runCommandSpec(): ToolSpec { - val toolName = com.github.codeplangui.execution.ShellPlatform.current().toolName() - return ToolSpec( - name = toolName, - description = "Execute a shell command in the project root directory. " + - "Only use when the user asks you to run something or when you need to " + - "inspect state to answer accurately.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("command", buildJsonObject { - put("type", "string") - put("description", "The shell command to execute") - }) - put("description", buildJsonObject { - put("type", "string") - put("description", "One-line explanation of why you are running this command") - }) - }) - put("required", buildJsonArray { - add(kotlinx.serialization.json.JsonPrimitive("command")) - add(kotlinx.serialization.json.JsonPrimitive("description")) - }) - }, - requiredPermission = PermissionMode.READ_ONLY, // Dynamic — overridden in dispatch - executor = bashExecutor, - isConcurrencySafe = { input -> - input["command"]?.let { cmd -> - bashExecutor.isConcurrencySafe(cmd.jsonPrimitive.content) - } ?: false - }, - isReadOnly = { input -> - input["command"]?.let { cmd -> - bashExecutor.isReadOnly(cmd.jsonPrimitive.content) - } ?: false - }, - isDestructive = { input -> - input["command"]?.let { cmd -> - bashExecutor.isDestructive(cmd.jsonPrimitive.content) - } ?: false - } - ) - } - - fun readFileSpec() = ToolSpec( - name = "read_file", - description = "Read file contents. Supports line-based pagination. " + - "Returns content with line numbers. Use line_number and limit for pagination.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("path", buildJsonObject { - put("type", "string") - put("description", "Path relative to project root") - }) - put("line_number", buildJsonObject { - put("type", "integer") - put("description", "Starting line number (1-based). Default: 1") - }) - put("limit", buildJsonObject { - put("type", "integer") - put("description", "Number of lines to read. Max 1000. Default: 500") - }) - }) - put("required", buildJsonArray { - add(kotlinx.serialization.json.JsonPrimitive("path")) - }) - }, - requiredPermission = PermissionMode.READ_ONLY, - executor = readFileExecutor, - isConcurrencySafe = { true }, - isReadOnly = { true }, - isDestructive = { false } - ) - - fun listFilesSpec() = ToolSpec( - name = "list_files", - description = "List directory contents. Returns files and subdirectories. " + - "Use this to explore the project structure.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("path", buildJsonObject { - put("type", "string") - put("description", "Directory path relative to project root. Default: '.'") - }) - }) - }, - requiredPermission = PermissionMode.READ_ONLY, - executor = listFilesExecutor, - isConcurrencySafe = { true }, - isReadOnly = { true }, - isDestructive = { false } - ) - - fun grepFilesSpec() = ToolSpec( - name = "grep_files", - description = "Search for text patterns in project files. " + - "Returns matching lines with file paths and line numbers.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("pattern", buildJsonObject { - put("type", "string") - put("description", "Search pattern") - }) - put("path", buildJsonObject { - put("type", "string") - put("description", "Directory to search in. Default: '.'") - }) - }) - put("required", buildJsonArray { - add(kotlinx.serialization.json.JsonPrimitive("pattern")) - }) - }, - requiredPermission = PermissionMode.READ_ONLY, - executor = grepFilesExecutor, - isConcurrencySafe = { true }, - isReadOnly = { true }, - isDestructive = { false } - ) - - fun editFileSpec() = ToolSpec( - name = "edit_file", - description = "Replace text in a file. Use for precise, targeted edits. " + - "If multiple matches exist, provide line_number to disambiguate. " + - "The change will be reviewed by the user before applying.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("path", buildJsonObject { - put("type", "string") - put("description", "File path relative to project root") - }) - put("search", buildJsonObject { - put("type", "string") - put("description", "Text to search for") - }) - put("replace", buildJsonObject { - put("type", "string") - put("description", "Text to replace with") - }) - put("replaceAll", buildJsonObject { - put("type", "boolean") - put("description", "Replace all occurrences. Default: false") - }) - put("line_number", buildJsonObject { - put("type", "integer") - put("description", "Target line number (1-based) to disambiguate multiple matches") - }) - }) - put("required", buildJsonArray { - add(kotlinx.serialization.json.JsonPrimitive("path")) - add(kotlinx.serialization.json.JsonPrimitive("search")) - add(kotlinx.serialization.json.JsonPrimitive("replace")) - }) - }, - requiredPermission = PermissionMode.WORKSPACE_WRITE, - executor = editFileExecutor, - isConcurrencySafe = { false }, - isReadOnly = { false }, - isDestructive = { false } - ) - - fun writeFileSpec() = ToolSpec( - name = "write_file", - description = "Create or overwrite a file with complete content. " + - "Use for creating new files or when changes are too large for edit_file. " + - "The change will be reviewed by the user before applying.", - inputSchema = buildJsonObject { - put("type", "object") - put("properties", buildJsonObject { - put("path", buildJsonObject { - put("type", "string") - put("description", "File path relative to project root") - }) - put("content", buildJsonObject { - put("type", "string") - put("description", "Complete file content") - }) - }) - put("required", buildJsonArray { - add(kotlinx.serialization.json.JsonPrimitive("path")) - add(kotlinx.serialization.json.JsonPrimitive("content")) - }) - }, - requiredPermission = PermissionMode.WORKSPACE_WRITE, - executor = writeFileExecutor, - isConcurrencySafe = { false }, - isReadOnly = { false }, - isDestructive = { false } - ) -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/BashExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/BashExecutor.kt deleted file mode 100644 index 7bb7baf..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/BashExecutor.kt +++ /dev/null @@ -1,122 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.* -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonPrimitive - -/** - * Wraps the existing CommandExecutionService for run_command / run_powershell. - * Adds dynamic permission classification and deny_rules checking. - */ -class BashExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val command = input["command"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: command") - - val description = input["description"]?.jsonPrimitive?.contentOrNull ?: "" - - // deny_rules check - val denied = checkDenyRules(command) - if (denied != null) return ToolResult(ok = false, output = denied) - - // Workspace path check - val basePath = context.cwd - if (CommandExecutionService.hasPathsOutsideWorkspace(command, basePath)) { - return ToolResult(ok = false, output = "Command accesses paths outside the project") - } - - return withContext(Dispatchers.IO) { - val service = CommandExecutionService.getInstance(context.project) - val result = service.executeAsync(command, context.settings.commandTimeoutSeconds) - result.toToolResult() - } - } - - /** Dynamic permission classification based on base command name. */ - fun classifyPermission(command: String): PermissionMode { - val base = CommandExecutionService.extractBaseCommand(command).lowercase() - return when { - base in READ_ONLY_COMMANDS -> PermissionMode.READ_ONLY - base in DEVELOPMENT_COMMANDS -> PermissionMode.WORKSPACE_WRITE - else -> PermissionMode.DANGER_FULL_ACCESS - } - } - - /** Whether this command is safe to run concurrently with other tools. */ - fun isConcurrencySafe(command: String): Boolean { - return classifyPermission(command) == PermissionMode.READ_ONLY - } - - fun isReadOnly(command: String): Boolean = - classifyPermission(command) == PermissionMode.READ_ONLY - - fun isDestructive(command: String): Boolean { - val cmd = command.lowercase() - return DESTRUCTIVE_PATTERNS.any { it.containsMatchIn(cmd) } - } - - private fun checkDenyRules(command: String): String? { - // Path traversal - if (command.contains("../") || command.contains("..\\")) { - return "Path traversal detected in command" - } - // Dangerous delete - if (DANGEROUS_DELETE_PATTERN.containsMatchIn(command)) { - return "Dangerous delete command detected" - } - // Network exfiltration (basic pattern matching) - if (NETWORK_EXFIL_PATTERN.containsMatchIn(command)) { - return "Potential network exfiltration detected" - } - // Fork bomb - if (FORK_BOMB_PATTERN.containsMatchIn(command)) { - return "Fork bomb pattern detected" - } - // Privilege escalation - if (PRIVILEGE_ESCALATION_PATTERN.containsMatchIn(command)) { - return "Privilege escalation detected" - } - return null - } - - companion object { - private val READ_ONLY_COMMANDS = setOf( - "pwd", "ls", "find", "rg", "grep", "cat", "head", "tail", - "wc", "echo", "df", "du", "uname", "whoami", "type", "which", - "get-childitem", "get-content", "select-string", "get-location" - ) - - private val DEVELOPMENT_COMMANDS = setOf( - "git", "npm", "node", "python3", "python", "pytest", "bash", "sh", - "bun", "cargo", "gradle", "mvn", "yarn", "pnpm", "go", "rustc", - "javac", "java", "dotnet", "make", "cmake" - ) - - private val DESTRUCTIVE_PATTERNS = listOf( - Regex("""rm\s+(-\w*\s*)*(-r|--recursive).*\s+/""", RegexOption.IGNORE_CASE), - Regex("""rm\s+(-\w*\s*)*(-r|--recursive).*\s+~""", RegexOption.IGNORE_CASE) - ) - - private val DANGEROUS_DELETE_PATTERN = - Regex("""rm\s+(-\w*\s*)*(-r|--recursive).*\s+(/|~)""", RegexOption.IGNORE_CASE) - - private val NETWORK_EXFIL_PATTERN = Regex( - """(\|\s*(curl|wget)\s)|(>\s*/dev/tcp/)""", - RegexOption.IGNORE_CASE - ) - - private val FORK_BOMB_PATTERN = Regex( - """:\(\)\{.*:\|:&\}|fork\s*bomb""", - RegexOption.IGNORE_CASE - ) - - private val PRIVILEGE_ESCALATION_PATTERN = Regex( - """sudo\s+|chmod\s+[0-7]*77|chown\s+""", - RegexOption.IGNORE_CASE - ) - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/EditFileExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/EditFileExecutor.kt deleted file mode 100644 index 5b3ed90..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/EditFileExecutor.kt +++ /dev/null @@ -1,102 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.* -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.booleanOrNull -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.intOrNull -import kotlinx.serialization.json.jsonPrimitive -import java.io.File - -/** - * Precise text replacement in files. - * Returns pendingReview for dispatcher-level approval, then writes. - */ -class EditFileExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val path = input["path"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: path") - val search = input["search"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: search") - val replace = input["replace"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: replace") - val replaceAll = input["replaceAll"]?.jsonPrimitive?.booleanOrNull ?: false - val lineNumber = input["line_number"]?.jsonPrimitive?.intOrNull - - val resolvedPath = ReadFileExecutor.resolveToolPath(path, context.cwd) - ?: return ToolResult(ok = false, output = "Path resolves outside workspace: $path") - - return withContext(Dispatchers.IO) { - val file = File(resolvedPath) - if (!file.exists()) { - return@withContext ToolResult(ok = false, output = "File not found: $path") - } - - val originalContent = file.readText() - if (!originalContent.contains(search)) { - return@withContext ToolResult(ok = false, output = "Search text not found in $path") - } - - // Count matches and find line numbers - val matchLines = originalContent.lines().mapIndexedNotNull { idx, line -> - if (line.contains(search)) idx + 1 else null - } - val matchCount = matchLines.size - - if (!replaceAll && matchCount > 1) { - if (lineNumber == null) { - return@withContext ToolResult( - ok = false, - output = "Found $matchCount matches for the search text in $path. " + - "Matching lines: ${matchLines.joinToString(", ")}. " + - "Provide 'line_number' parameter to specify which match to replace." - ) - } - val targetLine = lineNumber - if (targetLine !in matchLines) { - return@withContext ToolResult( - ok = false, - output = "No match found at line $targetLine. Matching lines: ${matchLines.joinToString(", ")}" - ) - } - } - - // Generate new content - val newContent = if (replaceAll) { - originalContent.split(search).joinToString(replace) - } else if (lineNumber != null && matchCount > 1) { - replaceAtLine(originalContent, search, replace, lineNumber) - } else { - originalContent.replaceFirst(search, replace) - } - - if (newContent == originalContent) { - return@withContext ToolResult(ok = false, output = "No changes made (replacement text same as search text)") - } - - // Return pendingReview — dispatcher handles approval + write - ToolResult( - ok = true, - output = "Pending review", - pendingReview = FileChangeReviewData( - path = resolvedPath, - originalContent = originalContent, - newContent = newContent - ) - ) - } - } - - private fun replaceAtLine(content: String, search: String, replace: String, targetLine: Int): String { - val lines = content.lines().toMutableList() - if (targetLine < 1 || targetLine > lines.size) return content - val idx = targetLine - 1 - if (lines[idx].contains(search)) { - lines[idx] = lines[idx].replaceFirst(search, replace) - } - return lines.joinToString("\n") - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/GrepFilesExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/GrepFilesExecutor.kt deleted file mode 100644 index 686ff90..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/GrepFilesExecutor.kt +++ /dev/null @@ -1,92 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.CommandExecutionService -import com.github.codeplangui.execution.ExecutionResult -import com.github.codeplangui.execution.ToolContext -import com.github.codeplangui.execution.ToolExecutor -import com.github.codeplangui.execution.ToolResult -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonPrimitive -import java.io.File - -/** - * Text search using IntelliJ FindInProjectUtil (preferred) or external rg/grep (fallback). - * Always READ_ONLY, always concurrency-safe. - * - * First version: uses external rg/grep directly. IntelliJ FindInProjectUtil - * integration requires complex setup and will be added in a future iteration. - */ -class GrepFilesExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val pattern = input["pattern"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: pattern") - - val path = input["path"]?.jsonPrimitive?.contentOrNull ?: "." - val resolvedPath = ReadFileExecutor.resolveToolPath(path, context.cwd) - ?: return ToolResult(ok = false, output = "Path resolves outside workspace: $path") - - return searchWithExternalTool(pattern, resolvedPath, context) - } - - private suspend fun searchWithExternalTool( - pattern: String, - directoryPath: String, - context: ToolContext - ): ToolResult { - return withContext(Dispatchers.IO) { - val escapedPattern = pattern.replace("'", "'\\''") - val excludeDirs = ".git,.idea,build,.gradle,node_modules,.intellijPlatform,.claude" - // Try rg first, then grep - val command = if (isRgAvailable()) { - "rg -n --no-heading --max-count 50 --glob '!{$excludeDirs}' -- '$escapedPattern' '$directoryPath'" - } else { - "grep -rn --max-count=50 --exclude-dir={$excludeDirs} -- '$escapedPattern' '$directoryPath'" - } - - val service = CommandExecutionService.getInstance(context.project) - val result = service.executeAsync(command, context.settings.commandTimeoutSeconds) - - when (result) { - is ExecutionResult.Success -> { - val output = result.stdout.trim() - if (output.isEmpty()) { - ToolResult(ok = true, output = "(no matches)") - } else { - val lines = output.lines().take(50) - ToolResult(ok = true, output = lines.joinToString("\n")) - } - } - is ExecutionResult.Failed -> { - // grep exits with 1 when no matches - if (result.exitCode == 1) { - ToolResult(ok = true, output = "(no matches)") - } else { - ToolResult(ok = false, output = result.stderr.ifEmpty { "Search failed" }) - } - } - else -> result.toToolResult() - } - } - } - - companion object { - private var rgChecked: Boolean? = null - private var rgAvailable: Boolean = false - - private fun isRgAvailable(): Boolean { - if (rgChecked != null) return rgAvailable - rgAvailable = try { - val process = ProcessBuilder("rg", "--version").start() - process.waitFor() == 0 - } catch (_: Exception) { - false - } - rgChecked = true - return rgAvailable - } - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/ListFilesExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/ListFilesExecutor.kt deleted file mode 100644 index 036d57c..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/ListFilesExecutor.kt +++ /dev/null @@ -1,54 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.ToolContext -import com.github.codeplangui.execution.ToolExecutor -import com.github.codeplangui.execution.ToolResult -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonPrimitive -import java.io.File - -/** - * Lists directory contents. Max 200 entries. - * Always READ_ONLY, always concurrency-safe. - */ -class ListFilesExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val path = input["path"]?.jsonPrimitive?.contentOrNull ?: "." - val resolvedPath = ReadFileExecutor.resolveToolPath(path, context.cwd) - ?: return ToolResult(ok = false, output = "Path resolves outside workspace: $path") - - return withContext(Dispatchers.IO) { - val dir = File(resolvedPath) - if (!dir.exists()) { - return@withContext ToolResult(ok = false, output = "Directory not found: $path") - } - if (!dir.isDirectory) { - return@withContext ToolResult(ok = false, output = "Not a directory: $path") - } - - val entries = dir.listFiles() - ?.sortedWith(compareBy({ !it.isDirectory }, { it.name })) - ?.take(200) - ?: emptyList() - - if (entries.isEmpty()) { - return@withContext ToolResult(ok = true, output = "(empty directory)") - } - - val output = entries.joinToString("\n") { entry -> - val kind = if (entry.isDirectory) "dir" else "file" - "$kind ${entry.name}" - } - - val suffix = if ((dir.listFiles()?.size ?: 0) > 200) { - "\n\n(showing first 200 entries)" - } else "" - - ToolResult(ok = true, output = output + suffix) - } - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/ReadFileExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/ReadFileExecutor.kt deleted file mode 100644 index f1b33b4..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/ReadFileExecutor.kt +++ /dev/null @@ -1,86 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.ToolContext -import com.github.codeplangui.execution.ToolExecutor -import com.github.codeplangui.execution.ToolResult -import com.intellij.openapi.application.ReadAction -import com.intellij.openapi.vfs.VirtualFileManager -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.intOrNull -import kotlinx.serialization.json.jsonPrimitive -import java.io.File - -/** - * Reads file content by line range using IntelliJ VFS. - * Always returns READ_ONLY permission; always concurrency-safe. - */ -class ReadFileExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val path = input["path"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: path") - - val lineNumber = input["line_number"]?.jsonPrimitive?.intOrNull ?: 1 - val limit = input["limit"]?.jsonPrimitive?.intOrNull ?: 500 - - // Validate range - if (lineNumber < 1) return ToolResult(ok = false, output = "line_number must be >= 1") - if (limit < 1 || limit > 1000) return ToolResult(ok = false, output = "limit must be between 1 and 1000") - - val resolvedPath = resolveToolPath(path, context.cwd) - ?: return ToolResult(ok = false, output = "Path resolves outside workspace: $path") - - return withContext(Dispatchers.IO) { - val file = File(resolvedPath) - if (!file.exists()) { - return@withContext ToolResult(ok = false, output = "File not found: $path") - } - if (!file.isFile) { - return@withContext ToolResult(ok = false, output = "Not a file: $path") - } - - // Binary check - val headBytes = file.inputStream().buffered().use { it.readNBytes(8192) } - if (headBytes.any { it == 0.toByte() }) { - return@withContext ToolResult( - ok = false, - output = "Binary file, cannot display: $path (${file.length()} bytes)" - ) - } - - val allLines = file.readLines() - val totalLines = allLines.size - - val startIdx = (lineNumber - 1).coerceIn(0, totalLines) - val endIdx = (startIdx + limit).coerceAtMost(totalLines) - val selectedLines = allLines.subList(startIdx, endIdx) - val truncated = endIdx < totalLines - - val sb = StringBuilder() - sb.appendLine("FILE: $path") - sb.appendLine("LINES: ${startIdx + 1}-${endIdx}") - sb.appendLine("TOTAL_LINES: $totalLines") - sb.appendLine("TRUNCATED: ${if (truncated) "yes" else "no"}") - sb.appendLine() - - val maxLineNumWidth = (endIdx).toString().length - for ((i, line) in selectedLines.withIndex()) { - val lineNum = (startIdx + i + 1).toString().padStart(maxLineNumWidth) - sb.appendLine("$lineNum→$line") - } - - ToolResult(ok = true, output = sb.toString()) - } - } - - companion object { - fun resolveToolPath(path: String, cwd: String): String? { - val resolved = File(cwd, path).canonicalPath - val canonicalCwd = File(cwd).canonicalPath - return if (resolved.startsWith(canonicalCwd)) resolved else null - } - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/executors/WriteFileExecutor.kt b/src/main/kotlin/com/github/codeplangui/execution/executors/WriteFileExecutor.kt deleted file mode 100644 index 3d25a0a..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/executors/WriteFileExecutor.kt +++ /dev/null @@ -1,62 +0,0 @@ -package com.github.codeplangui.execution.executors - -import com.github.codeplangui.execution.* -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.contentOrNull -import kotlinx.serialization.json.jsonPrimitive -import java.io.File - -/** - * Whole-file write (create or overwrite). - * Returns pendingReview for dispatcher-level approval, then writes. - */ -class WriteFileExecutor : ToolExecutor { - - override suspend fun execute(input: JsonObject, context: ToolContext): ToolResult { - val path = input["path"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: path") - val content = input["content"]?.jsonPrimitive?.contentOrNull - ?: return ToolResult(ok = false, output = "Missing required parameter: content") - - val resolvedPath = ReadFileExecutor.resolveToolPath(path, context.cwd) - ?: return ToolResult(ok = false, output = "Path resolves outside workspace: $path") - - return withContext(Dispatchers.IO) { - val file = File(resolvedPath) - val isNewFile = !file.exists() - - if (isNewFile) { - // New file: return pendingReview with isNewFile = true - ToolResult( - ok = true, - output = "Pending review", - pendingReview = FileChangeReviewData( - path = resolvedPath, - originalContent = "", - newContent = content, - isNewFile = true, - newContentForCreate = content - ) - ) - } else { - // Existing file: compute diff, return pendingReview - val originalContent = file.readText() - if (originalContent == content) { - return@withContext ToolResult(ok = true, output = "File unchanged: $path") - } - - ToolResult( - ok = true, - output = "Pending review", - pendingReview = FileChangeReviewData( - path = resolvedPath, - originalContent = originalContent, - newContent = content - ) - ) - } - } - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/hooks/ToolExecutionLogger.kt b/src/main/kotlin/com/github/codeplangui/execution/hooks/ToolExecutionLogger.kt deleted file mode 100644 index bfcdbce..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/hooks/ToolExecutionLogger.kt +++ /dev/null @@ -1,25 +0,0 @@ -package com.github.codeplangui.execution.hooks - -import com.github.codeplangui.execution.ToolHook -import com.github.codeplangui.execution.ToolResult -import com.intellij.openapi.diagnostic.Logger -import kotlinx.serialization.json.JsonObject - -/** - * Default Pre/Post Hook that logs tool call execution to IDE logs. - */ -class ToolExecutionLogger : ToolHook { - - private val logger = Logger.getInstance(ToolExecutionLogger::class.java) - - override suspend fun beforeExecute(toolName: String, input: JsonObject): ToolResult? { - logger.info("[ToolCall] Executing: $toolName | input size: ${input.toString().length}") - return null // Continue execution - } - - override suspend fun afterExecute(toolName: String, input: JsonObject, result: ToolResult) { - val status = if (result.ok) "OK" else "FAILED" - val outputPreview = result.output.take(200).replace("\n", " ") - logger.info("[ToolCall] Completed: $toolName | $status | output: $outputPreview") - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/review/ChangeReviewStrategy.kt b/src/main/kotlin/com/github/codeplangui/execution/review/ChangeReviewStrategy.kt deleted file mode 100644 index c7faea8..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/review/ChangeReviewStrategy.kt +++ /dev/null @@ -1,42 +0,0 @@ -package com.github.codeplangui.execution.review - -import com.intellij.openapi.project.Project - -/** - * File change review strategy. - * DialogReview: current implementation (Messages.showYesNoDialog), kept as fallback. - * EditorInlineReview: inline diff display inside IntelliJ editor (target implementation). - */ -interface ChangeReviewStrategy { - - /** - * Review a file modification. - * @return true = user accepted, false = user rejected or timeout - */ - suspend fun reviewFileChange( - project: Project, - requestId: String, - path: String, - originalContent: String, - newContent: String - ): Boolean - - /** - * Review a new file creation. - * @return true = user confirmed creation, false = user rejected or timeout - */ - suspend fun reviewNewFile( - project: Project, - requestId: String, - path: String, - content: String - ): Boolean - - /** Session trust state (in EditorInline mode, synced with settings). */ - var sessionTrusted: Boolean - - /** Reset session trust. */ - fun resetSessionTrust() { - sessionTrusted = false - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/review/DialogReview.kt b/src/main/kotlin/com/github/codeplangui/execution/review/DialogReview.kt deleted file mode 100644 index 979f32f..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/review/DialogReview.kt +++ /dev/null @@ -1,91 +0,0 @@ -package com.github.codeplangui.execution.review - -import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.project.Project -import com.intellij.openapi.ui.Messages -import java.util.concurrent.CompletableFuture -import java.util.concurrent.TimeUnit - -/** - * Fallback strategy: uses Messages.showYesNoDialog with a simple diff summary. - * Extracted from the original FileChangeReview, behavior unchanged. - */ -class DialogReview : ChangeReviewStrategy { - - override var sessionTrusted: Boolean = false - - override suspend fun reviewFileChange( - project: Project, requestId: String, path: String, - originalContent: String, newContent: String - ): Boolean { - if (sessionTrusted) return true - - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - val oldLines = originalContent.lines().size - val newLines = newContent.lines().size - val added = (newLines - oldLines).coerceAtLeast(0) - val removed = (oldLines - newLines).coerceAtLeast(0) - - val message = buildString { - appendLine("Apply changes to $path?") - appendLine() - appendLine("Lines: +$added / -$removed (was $oldLines, now $newLines)") - appendLine() - val oldSet = originalContent.lines().toSet() - val changed = newContent.lines().filter { it !in oldSet }.take(5) - if (changed.isNotEmpty()) { - appendLine("--- New/changed lines (preview) ---") - changed.forEach { appendLine(it) } - } - } - - val result = Messages.showYesNoDialog( - project, message, "File Change Review: $path", - Messages.getQuestionIcon() - ) - future.complete(result == Messages.YES) - } - - return future.get(60, TimeUnit.SECONDS) - } - - override suspend fun reviewNewFile( - project: Project, requestId: String, path: String, content: String - ): Boolean { - if (sessionTrusted) return true - - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - val lineCount = content.lines().size - val sizeBytes = content.toByteArray().size - - val message = buildString { - appendLine("Create new file?") - appendLine() - appendLine("Path: $path") - appendLine("Size: ${formatSize(sizeBytes)} / $lineCount lines") - appendLine() - appendLine("--- Preview (first 20 lines) ---") - content.lines().take(20).forEach { appendLine(it) } - if (lineCount > 20) appendLine("... ($lineCount lines total)") - } - - val result = Messages.showOkCancelDialog( - project, message, "Create New File", - "Create", "Cancel", Messages.getQuestionIcon() - ) - future.complete(result == Messages.OK) - } - - return future.get(60, TimeUnit.SECONDS) - } - - private fun formatSize(bytes: Int): String = when { - bytes < 1024 -> "$bytes B" - bytes < 1024 * 1024 -> "${bytes / 1024} KB" - else -> "${bytes / (1024 * 1024)} MB" - } -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/review/DiffHunk.kt b/src/main/kotlin/com/github/codeplangui/execution/review/DiffHunk.kt deleted file mode 100644 index 3552812..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/review/DiffHunk.kt +++ /dev/null @@ -1,146 +0,0 @@ -package com.github.codeplangui.execution.review - -/** - * A single contiguous change region in a diff. - */ -data class DiffHunk( - val startLine: Int, - val deletedLines: List, - val insertedLines: List -) - -/** - * Computes diff hunks from two text strings using LCS-based line diff. - */ -object DiffCalculator { - - fun computeHunks(original: String, new: String): List { - val oldLines = original.lines() - val newLines = new.lines() - val lcs = computeLcsTable(oldLines, newLines) - val diffOps = backtrackDiff(lcs, oldLines, newLines) - return groupIntoHunks(diffOps) - } - - private fun computeLcsTable(old: List, new: List): Array { - val m = old.size - val n = new.size - val dp = Array(m + 1) { IntArray(n + 1) } - - for (i in 1..m) { - for (j in 1..n) { - if (old[i - 1] == new[j - 1]) { - dp[i][j] = dp[i - 1][j - 1] + 1 - } else { - dp[i][j] = maxOf(dp[i - 1][j], dp[i][j - 1]) - } - } - } - return dp - } - - private fun backtrackDiff( - dp: Array, - old: List, - new: List - ): List { - val ops = mutableListOf() - var i = old.size - var j = new.size - - while (i > 0 || j > 0) { - when { - i > 0 && j > 0 && old[i - 1] == new[j - 1] -> { - ops.add(DiffOp.Equal(i - 1, j - 1)) - i-- - j-- - } - j > 0 && (i == 0 || dp[i][j - 1] >= dp[i - 1][j]) -> { - ops.add(DiffOp.Insert(j - 1)) - j-- - } - else -> { - ops.add(DiffOp.Delete(i - 1)) - i-- - } - } - } - return ops.reversed() - } - - private fun groupIntoHunks(ops: List): List { - val hunks = mutableListOf() - val currentDeletes = mutableListOf() - val currentInserts = mutableListOf() - var hunkStartLine = -1 - - fun flushHunk() { - if (currentDeletes.isNotEmpty() || currentInserts.isNotEmpty()) { - hunks.add(DiffHunk(hunkStartLine, currentDeletes.toList(), currentInserts.toList())) - currentDeletes.clear() - currentInserts.clear() - hunkStartLine = -1 - } - } - - for (op in ops) { - when (op) { - is DiffOp.Equal -> flushHunk() - is DiffOp.Delete -> { - if (hunkStartLine == -1) hunkStartLine = op.oldIndex - currentDeletes.add(/* placeholder — filled below */"") - } - is DiffOp.Insert -> { - if (hunkStartLine == -1) hunkStartLine = op.newIndex - currentInserts.add(/* placeholder — filled below */"") - } - } - } - flushHunk() - return hunks - } - - // Recompute with actual line content - fun computeHunksWithContent(original: String, new: String): List { - val oldLines = original.lines() - val newLines = new.lines() - val dp = computeLcsTable(oldLines, newLines) - val ops = backtrackDiff(dp, oldLines, newLines) - - val hunks = mutableListOf() - val currentDeletes = mutableListOf() - val currentInserts = mutableListOf() - var hunkStartLine = -1 - - fun flushHunk() { - if (currentDeletes.isNotEmpty() || currentInserts.isNotEmpty()) { - hunks.add(DiffHunk(hunkStartLine, currentDeletes.toList(), currentInserts.toList())) - currentDeletes.clear() - currentInserts.clear() - hunkStartLine = -1 - } - } - - for (op in ops) { - when (op) { - is DiffOp.Equal -> flushHunk() - is DiffOp.Delete -> { - if (hunkStartLine == -1) hunkStartLine = op.oldIndex - currentDeletes.add(oldLines[op.oldIndex]) - } - is DiffOp.Insert -> { - if (hunkStartLine == -1) hunkStartLine = op.newIndex - currentInserts.add(newLines[op.newIndex]) - } - } - } - flushHunk() - return hunks - } -} - -private sealed class DiffOp { - data class Equal(val oldIndex: Int, val newIndex: Int) : DiffOp() - data class Delete(val oldIndex: Int) : DiffOp() - data class Insert(val newIndex: Int) : DiffOp() -} diff --git a/src/main/kotlin/com/github/codeplangui/execution/review/EditorInlineReview.kt b/src/main/kotlin/com/github/codeplangui/execution/review/EditorInlineReview.kt deleted file mode 100644 index 17743f9..0000000 --- a/src/main/kotlin/com/github/codeplangui/execution/review/EditorInlineReview.kt +++ /dev/null @@ -1,167 +0,0 @@ -package com.github.codeplangui.execution.review - -import com.intellij.diff.DiffContentFactory -import com.intellij.diff.DiffManager -import com.intellij.diff.requests.SimpleDiffRequest -import com.intellij.openapi.application.ApplicationManager -import com.intellij.openapi.fileTypes.FileTypeManager -import com.intellij.openapi.project.Project -import com.intellij.openapi.ui.DialogWrapper -import com.intellij.openapi.ui.Messages -import com.intellij.openapi.vfs.LocalFileSystem -import java.awt.BorderLayout -import java.awt.event.ActionEvent -import java.util.concurrent.CompletableFuture -import java.util.concurrent.TimeUnit -import javax.swing.AbstractAction -import javax.swing.Action -import javax.swing.JComponent -import javax.swing.JPanel - -/** - * IntelliJ native diff dialog review strategy. - * Shows a diff viewer using DiffManager with Accept/Reject buttons. - */ -class EditorInlineReview( - private val project: Project -) : ChangeReviewStrategy { - - override var sessionTrusted: Boolean = false - - override suspend fun reviewFileChange( - project: Project, requestId: String, path: String, - originalContent: String, newContent: String - ): Boolean { - if (sessionTrusted) return true - - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - val contentFactory = DiffContentFactory.getInstance() - - val ext = path.substringAfterLast('.', "") - val fileType = FileTypeManager.getInstance().getFileTypeByExtension(ext) - val virtualFile = LocalFileSystem.getInstance().findFileByPath(path) - - val content1 = if (virtualFile != null) { - contentFactory.create(project, originalContent, virtualFile) - } else { - contentFactory.create(project, originalContent, fileType) - } - val content2 = if (virtualFile != null) { - contentFactory.create(project, newContent, virtualFile) - } else { - contentFactory.create(project, newContent, fileType) - } - - val fileName = path.substringAfterLast('/') - val request = SimpleDiffRequest( - "Review Changes: $fileName", - content1, content2, - "Before", "After" - ) - - val dialog = DiffReviewDialog(project, request, future) - dialog.show() - } - - return try { - future.get(120, TimeUnit.SECONDS) - } catch (_: Exception) { - false - } - } - - override suspend fun reviewNewFile( - project: Project, requestId: String, path: String, content: String - ): Boolean { - if (sessionTrusted) return true - return showCreateFileConfirmation(project, path, content) - } - - private fun showCreateFileConfirmation( - project: Project, path: String, content: String - ): Boolean { - val future = CompletableFuture() - - ApplicationManager.getApplication().invokeAndWait { - val lineCount = content.lines().size - val sizeBytes = content.toByteArray().size - - val message = buildString { - appendLine("Create new file?") - appendLine() - appendLine("Path: $path") - appendLine("Size: ${formatSize(sizeBytes)} / $lineCount lines") - appendLine() - appendLine("--- Preview (first 20 lines) ---") - content.lines().take(20).forEach { appendLine(it) } - if (lineCount > 20) appendLine("... ($lineCount lines total)") - } - - val result = Messages.showOkCancelDialog( - project, message, "Create New File", - "Create", "Cancel", Messages.getQuestionIcon() - ) - future.complete(result == Messages.OK) - } - - return future.get(60, TimeUnit.SECONDS) - } - - private fun formatSize(bytes: Int): String = when { - bytes < 1024 -> "$bytes B" - bytes < 1024 * 1024 -> "${bytes / 1024} KB" - else -> "${bytes / (1024 * 1024)} MB" - } - - /** - * Custom dialog wrapping DiffManager's diff viewer with Accept/Reject buttons. - */ - private class DiffReviewDialog( - project: Project, - request: SimpleDiffRequest, - private val resultFuture: CompletableFuture - ) : DialogWrapper(project, true) { - - private val diffRequest = request - private val dialogProject = project - - init { - title = request.title - setModal(true) - init() - } - - override fun createCenterPanel(): JComponent { - val panel = JPanel(BorderLayout()) - val diffPanel = DiffManager.getInstance().createRequestPanel(dialogProject, this.disposable, null) - diffPanel.setRequest(diffRequest) - panel.add(diffPanel.component, BorderLayout.CENTER) - panel.preferredSize = java.awt.Dimension(900, 600) - return panel - } - - override fun createActions(): Array { - return arrayOf( - object : AbstractAction("Accept") { - override fun actionPerformed(e: ActionEvent?) { - resultFuture.complete(true) - close(OK_EXIT_CODE) - } - }, - object : AbstractAction("Reject") { - override fun actionPerformed(e: ActionEvent?) { - resultFuture.complete(false) - close(CANCEL_EXIT_CODE) - } - } - ) - } - - override fun doCancelAction() { - resultFuture.complete(false) - super.doCancelAction() - } - } -} diff --git a/src/main/kotlin/com/github/codeplangui/tools/ToolExecutionContext.kt b/src/main/kotlin/com/github/codeplangui/tools/ToolExecutionContext.kt index fa4f5fc..fb7846a 100644 --- a/src/main/kotlin/com/github/codeplangui/tools/ToolExecutionContext.kt +++ b/src/main/kotlin/com/github/codeplangui/tools/ToolExecutionContext.kt @@ -13,6 +13,15 @@ data class ToolExecutionContext( val toolUseId: String, val abortJob: Job, val permissionContext: ToolPermissionContext = ToolPermissionContext.default(), + /** Default timeout for Bash commands when the model omits `timeoutSeconds`. */ + val commandTimeoutSeconds: Int = 120, + /** + * M3: real permission decision callback. Called when `checkPermissions` returns `Ask`. + * Returns true to allow, false to deny. Default auto-approves so existing callers + * that don't supply a callback keep the pre-M3 behavior. + * M5 wires this to the Bridge approval dialog. + */ + val onPermissionAsked: suspend (ToolUpdate.PermissionAsked) -> Boolean = { true }, ) /** diff --git a/src/main/kotlin/com/github/codeplangui/tools/ToolExecutor.kt b/src/main/kotlin/com/github/codeplangui/tools/ToolExecutor.kt index 7f40c44..d8a5978 100644 --- a/src/main/kotlin/com/github/codeplangui/tools/ToolExecutor.kt +++ b/src/main/kotlin/com/github/codeplangui/tools/ToolExecutor.kt @@ -1,9 +1,13 @@ package com.github.codeplangui.tools import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit /** * Single-tool execution flow. Mirrors Claude Code's `runToolUse()` @@ -57,8 +61,13 @@ fun runToolUse( return@channelFlow } is PermissionResult.Ask -> { - send(ToolUpdate.PermissionAsked(id, tool.name, p.reason, p.preview)) - // MVP: auto-approve. M3 replaces this with a real user-decision callback. + val event = ToolUpdate.PermissionAsked(id, tool.name, p.reason, p.preview) + send(event) + val allowed = context.onPermissionAsked(event) + if (!allowed) { + send(ToolUpdate.Failed(id, ToolUpdate.Failed.Stage.PERMISSION, "Permission denied by user")) + return@channelFlow + } } } @@ -90,3 +99,116 @@ fun runToolUse( send(ToolUpdate.Completed(id, block)) } + +/** Max number of concurrency-safe tools executed in parallel within one batch. */ +const val MAX_BATCH_CONCURRENCY: Int = 10 + +/** Resolved or failed entry for a single tool_use in a batch. */ +private sealed class BatchEntry { + abstract val toolUse: ToolUseBlock + data class Resolved(override val toolUse: ToolUseBlock, val tool: Tool<*, *>, val isSafe: Boolean) : BatchEntry() + data class Failed(override val toolUse: ToolUseBlock, val reason: String) : BatchEntry() +} + +/** + * Multi-tool execution flow. Mirrors Claude Code's `partitionToolCalls()` + + * `runToolsConcurrently()` (toolOrchestration.ts:91-177). + * + * The LLM may emit multiple `tool_use` blocks in one turn. This function: + * 1. Resolves each block to a tool via the pool. Blocks whose tool is missing + * emit `Started` + `Failed(LOOKUP)` immediately. + * 2. Partitions resolved blocks by `isConcurrencySafe(parsedInput)`: + * - safe → run in parallel, bounded by [maxConcurrency] + * - unsafe → run serially after the safe group + * 3. Merges all `ToolUpdate` events into a single [Flow]. Ordering across + * different tool_use_ids is not deterministic for the concurrent group; + * updates for the same id stay in order. + * + * Partition-time parsing is best-effort: if `parseInput` throws, the block is + * classified as unsafe (so it's isolated) and the real PARSE failure is emitted + * by the underlying [runToolUse] call. On the happy path, input is parsed twice + * (once for classification, once for execution) — a deliberate trade-off to keep + * `runToolUse` self-contained. + * + * MVP limitations: + * - `ToolResult.contextModifier` is not threaded through the serial unsafe + * group; tools that mutate context see the original. Add in M3 if needed. + */ +fun runToolUseBatch( + toolUses: List, + pool: List>, + context: ToolExecutionContext, + maxConcurrency: Int = MAX_BATCH_CONCURRENCY, +): Flow = channelFlow { + require(maxConcurrency > 0) { "maxConcurrency must be > 0, was $maxConcurrency" } + + val poolMap = buildMap> { + for (tool in pool) { + put(tool.name, tool) + for (alias in tool.aliases) { + put(alias, tool) + } + } + } + + val entries = toolUses.map { tu -> + val tool = poolMap[tu.name] + if (tool == null) { + BatchEntry.Failed(tu, "Tool not found: ${tu.name}") + } else { + BatchEntry.Resolved(tu, tool, isSafe = classifyConcurrencySafe(tool, tu)) + } + } + + for (e in entries) { + if (e is BatchEntry.Failed) { + send(ToolUpdate.Started(e.toolUse.toolUseId, e.toolUse.name)) + send(ToolUpdate.Failed(e.toolUse.toolUseId, ToolUpdate.Failed.Stage.LOOKUP, e.reason)) + } + } + + val resolved = entries.filterIsInstance() + val (safeGroup, unsafeGroup) = resolved.partition { it.isSafe } + + suspend fun executeEntry(e: BatchEntry.Resolved) { + runToolUseErased(e.tool, e.toolUse, context).collect { send(it) } + } + + if (safeGroup.isNotEmpty()) { + // Semaphore limits concurrent entries into executeEntry, not thread count. + // limitedParallelism(n) would only limit threads — a suspended tool releases + // its thread, allowing the dispatcher to start another, bypassing the bound. + val semaphore = Semaphore(maxConcurrency) + coroutineScope { + safeGroup.map { e -> launch { semaphore.withPermit { executeEntry(e) } } }.joinAll() + } + } + + for (e in unsafeGroup) { + executeEntry(e) + } +} + +/** + * Best-effort concurrency-safety probe. Parses input with the tool and asks it. + * Parse failures classify as unsafe — they'll surface as PARSE errors when the + * real runToolUse runs. + */ +private fun classifyConcurrencySafe(tool: Tool<*, *>, toolUse: ToolUseBlock): Boolean { + return try { + val t = tool.asErased() + val input = t.parseInput(toolUse.input) + t.isConcurrencySafe(input) + } catch (_: Exception) { + false + } +} + +@Suppress("UNCHECKED_CAST") +private fun Tool<*, *>.asErased(): Tool = this as Tool + +private fun runToolUseErased( + tool: Tool<*, *>, + toolUse: ToolUseBlock, + context: ToolExecutionContext, +): Flow = runToolUse(tool.asErased(), toolUse, context) diff --git a/src/main/kotlin/com/github/codeplangui/tools/ToolRegistry.kt b/src/main/kotlin/com/github/codeplangui/tools/ToolRegistry.kt new file mode 100644 index 0000000..873dc33 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/ToolRegistry.kt @@ -0,0 +1,80 @@ +package com.github.codeplangui.tools + +import com.github.codeplangui.tools.bash.BashTool +import com.github.codeplangui.tools.file.FileEditTool +import com.github.codeplangui.tools.file.FileListTool +import com.github.codeplangui.tools.file.FileReadTool +import com.github.codeplangui.tools.file.FileSearchTool +import com.github.codeplangui.tools.file.WriteFileTool + +/** + * Central registry of built-in tools and the plumbing that assembles the pool + * handed to the LLM on each turn. + * + * Mirrors Claude Code's `tools.ts` (L193-389): + * - `getAllBaseTools()` — hard-coded list of built-in tools + * - `filterToolsByDenyRules()` — prunes tools blocked by permission config so + * they never enter the prompt (save tokens + prevent the model from + * attempting them) + * - `assembleToolPool()` — merges built-in + MCP tools, dedupes by name with + * built-in winning, sorts for prompt-cache stability. + * + * MVP: + * - No feature flags. Every built-in tool is always registered. + * - Deny rule format is simple prefix: rule == tool.name matches exactly, + * rule == "ToolName(" matches any parametrized variant (future M3). + * - MCP tools default to empty list. + */ +object ToolRegistry { + + /** + * All built-in tools, unfiltered. Order does not matter — `assembleToolPool` + * sorts by name. + */ + fun getAllBaseTools(): List> = listOf( + BashTool, + FileEditTool, + FileListTool, + FileReadTool, + FileSearchTool, + WriteFileTool, + ) + + /** + * Pre-filter tools by the permission context's deny rules. Denied tools are + * removed from the pool entirely so the LLM never sees them. + */ + fun filterToolsByDenyRules( + tools: List>, + context: ToolExecutionContext, + ): List> { + val deny = context.permissionContext.alwaysDeny + if (deny.isEmpty()) return tools + return tools.filterNot { tool -> isDenied(tool, deny) } + } + + /** + * Build the final tool pool for a request: base ∪ mcp, dedupe by name + * (built-in wins), sorted for prompt-cache stability. + */ + fun assembleToolPool( + baseTools: List> = getAllBaseTools(), + mcpTools: List> = emptyList(), + ): List> { + val seen = mutableSetOf() + return (baseTools + mcpTools) + .filter { seen.add(it.name) } + .sortedBy { it.name } + } + + /** + * Look up a tool by name or alias within a pool. Returns null if not found. + */ + fun findByName(pool: List>, name: String): Tool<*, *>? = + pool.firstOrNull { it.name == name || name in it.aliases } + + private fun isDenied(tool: Tool<*, *>, denyRules: Set): Boolean = + denyRules.any { rule -> + rule == tool.name || rule.startsWith("${tool.name}(") + } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/bash/BashTool.kt b/src/main/kotlin/com/github/codeplangui/tools/bash/BashTool.kt index 043c205..c98bb2a 100644 --- a/src/main/kotlin/com/github/codeplangui/tools/bash/BashTool.kt +++ b/src/main/kotlin/com/github/codeplangui/tools/bash/BashTool.kt @@ -20,10 +20,9 @@ import kotlinx.serialization.json.put data class BashInput( val command: String, val description: String? = null, - val timeoutSeconds: Int = DEFAULT_TIMEOUT_SECONDS, + val timeoutSeconds: Int? = null, ) { companion object { - const val DEFAULT_TIMEOUT_SECONDS = 120 const val MAX_TIMEOUT_SECONDS = 600 } } @@ -60,6 +59,8 @@ private fun isDestructiveCommand(command: String): Boolean = val BashTool: Tool = tool { name = "Bash" + // Alias the legacy execution/ tool names so the LLM can call either naming scheme. + aliases = listOf("run_command", "run_powershell") description = """ Execute a shell command in the project's working directory. Supports piping, redirection, and subshells. Output over 20k chars is truncated. Use for git, @@ -88,10 +89,11 @@ val BashTool: Tool = tool { parse { raw: JsonElement -> json.decodeFromJsonElement(BashInput.serializer(), raw) } validate { input, _ -> + val t = input.timeoutSeconds when { input.command.isBlank() -> ValidationResult.Failed("command must not be blank", errorCode = 1) - input.timeoutSeconds !in 1..BashInput.MAX_TIMEOUT_SECONDS -> + t != null && t !in 1..BashInput.MAX_TIMEOUT_SECONDS -> ValidationResult.Failed( "timeoutSeconds must be between 1 and ${BashInput.MAX_TIMEOUT_SECONDS}", errorCode = 2, @@ -128,9 +130,11 @@ val BashTool: Tool = tool { call { input, ctx, onProgress -> val service = CommandExecutionService.getInstance(ctx.project) + val effectiveTimeout = (input.timeoutSeconds ?: ctx.commandTimeoutSeconds) + .coerceIn(1, BashInput.MAX_TIMEOUT_SECONDS) val result = service.executeAsyncWithStream( command = input.command, - timeoutSeconds = input.timeoutSeconds, + timeoutSeconds = effectiveTimeout, onOutput = { line, isError -> onProgress(if (isError) Progress.Stderr(line) else Progress.Stdout(line)) }, @@ -167,7 +171,7 @@ private fun previewFor(input: BashInput, basePath: String): PreviewResult = Prev summary = "Run: ${input.command}", details = buildString { append("Working dir: ").appendLine(basePath) - append("Timeout: ").append(input.timeoutSeconds).appendLine("s") + append("Timeout: ").append(input.timeoutSeconds ?: "default").appendLine("s") input.description?.let { append("Intent: ").appendLine(it) } }, risk = if (isDestructiveCommand(input.command)) PreviewResult.Risk.HIGH else PreviewResult.Risk.MEDIUM, diff --git a/src/main/kotlin/com/github/codeplangui/tools/file/FileEditTool.kt b/src/main/kotlin/com/github/codeplangui/tools/file/FileEditTool.kt new file mode 100644 index 0000000..f005b3d --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/file/FileEditTool.kt @@ -0,0 +1,241 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.PreviewResult +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.File + +@Serializable +data class FileEditInput( + val path: String, + val oldString: String, + val newString: String, + val replaceAll: Boolean = false, +) + +data class FileEditOutput( + val path: String, + val replacements: Int, + val diff: String, +) + +private val json = Json { ignoreUnknownKeys = true } +private const val DIFF_CONTEXT_LINES = 3 + +val FileEditTool: Tool = tool { + name = "FileEdit" + description = """ + Edit an existing file by replacing exact text. Provide the `path`, the exact + `oldString` to match (must exist verbatim), and the `newString` to substitute. + Set `replaceAll` to true to replace all occurrences (default: first only). + Returns a unified diff of the change. Use FileWrite to create new files. + """.trimIndent() + + inputSchema = buildJsonObject { + put("type", "object") + put("required", JsonArray(listOf(JsonPrimitive("path"), JsonPrimitive("oldString"), JsonPrimitive("newString")))) + put("properties", buildJsonObject { + put("path", buildJsonObject { + put("type", "string") + put("description", "Absolute path, or path relative to project root.") + }) + put("oldString", buildJsonObject { + put("type", "string") + put("description", "Exact text to search for. Must exist verbatim in the file.") + }) + put("newString", buildJsonObject { + put("type", "string") + put("description", "Replacement text.") + }) + put("replaceAll", buildJsonObject { + put("type", "boolean") + put("description", "Replace all occurrences. Defaults to false (replace first only).") + }) + }) + } + + parse { raw: JsonElement -> json.decodeFromJsonElement(FileEditInput.serializer(), raw) } + + validate { input, _ -> + when { + input.path.isBlank() -> + ValidationResult.Failed("path must not be blank", errorCode = 1) + input.oldString.isEmpty() -> + ValidationResult.Failed("oldString must not be empty", errorCode = 2) + input.oldString == input.newString -> + ValidationResult.Failed("oldString and newString are identical — no change would occur", errorCode = 3) + else -> ValidationResult.Ok + } + } + + checkPermissions { input, ctx -> + val basePath = ctx.project.basePath + ?: return@checkPermissions PermissionResult.Deny("Project path unavailable") + val resolved = resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + ?: return@checkPermissions PermissionResult.Deny("Path resolves outside workspace: ${input.path}") + + val content = withContext(Dispatchers.IO) { + val file = File(resolved) + if (!file.exists() || !file.isFile) null else file.readText(Charsets.UTF_8) + } ?: return@checkPermissions PermissionResult.Deny("File not found: ${input.path}") + + val count = content.split(input.oldString).size - 1 + if (count == 0) { + return@checkPermissions PermissionResult.Deny("oldString not found in ${input.path}") + } + + // ACCEPT_EDITS / BYPASS modes skip the confirmation dialog + if (ctx.permissionContext.mode == ToolPermissionContext.Mode.ACCEPT_EDITS || + ctx.permissionContext.mode == ToolPermissionContext.Mode.BYPASS) { + return@checkPermissions PermissionResult.Allow( + Json.encodeToJsonElement(FileEditInput.serializer(), input) + ) + } + + val effectiveCount = if (input.replaceAll) count else 1 + val diff = buildEditPreviewDiff(content, input.oldString, input.newString, input.replaceAll, input.path) + PermissionResult.Ask( + reason = "Edit ${input.path}: replace $effectiveCount occurrence(s)", + preview = PreviewResult( + summary = "Replace $effectiveCount occurrence(s) in ${input.path}", + details = diff, + risk = PreviewResult.Risk.MEDIUM, + ), + ) + } + + preview { input, ctx -> + val basePath = ctx.project.basePath ?: return@preview null + val resolved = resolveInsideWorkspace( + input.path, basePath, ctx.permissionContext.additionalWorkingDirectories + ) ?: return@preview null + val content = withContext(Dispatchers.IO) { + val f = File(resolved) + if (!f.exists() || !f.isFile) null else f.readText(Charsets.UTF_8) + } ?: return@preview null + val count = content.split(input.oldString).size - 1 + if (count == 0) return@preview null + val effectiveCount = if (input.replaceAll) count else 1 + val diff = buildEditPreviewDiff(content, input.oldString, input.newString, input.replaceAll, input.path) + PreviewResult( + summary = "Replace $effectiveCount occurrence(s) in ${input.path}", + details = diff, + risk = PreviewResult.Risk.MEDIUM, + ) + } + + call { input, ctx, _ -> + val basePath = ctx.project.basePath!! + val resolved = resolveInsideWorkspace( + input.path, basePath, ctx.permissionContext.additionalWorkingDirectories + )!! + + withContext(Dispatchers.IO) { + val file = File(resolved) + require(file.exists() && file.isFile) { "File not found: ${input.path}" } + + val original = file.readText(Charsets.UTF_8) + val count = original.split(input.oldString).size - 1 + require(count > 0) { "oldString not found in ${input.path}" } + + val effectiveCount = if (input.replaceAll) count else 1 + val modified = if (input.replaceAll) { + original.replace(input.oldString, input.newString) + } else { + original.replaceFirst(input.oldString, input.newString) + } + + val diff = buildUnifiedDiff(original, modified, input.path) + file.writeText(modified, Charsets.UTF_8) + + ToolResult(FileEditOutput(resolved, effectiveCount, diff)) + } + } + + mapResult { output, toolUseId -> + val content = buildString { + appendLine("path: ${output.path}") + appendLine("replacements: ${output.replacements}") + appendLine() + append(output.diff) + } + ToolResultBlock(toolUseId = toolUseId, content = content) + } + + isConcurrencySafe { false } + isReadOnly { false } + isDestructive { true } + + activityDescription { input -> input?.let { "Editing ${it.path}" } } +} + +/** Preview diff for the Ask permission prompt — shows first hunk plus a "N more" note. */ +internal fun buildEditPreviewDiff( + content: String, + oldString: String, + newString: String, + replaceAll: Boolean, + path: String, +): String { + val firstHunk = buildUnifiedDiff(content, content.replaceFirst(oldString, newString), path) + if (!replaceAll) return firstHunk + val extra = content.split(oldString).size - 2 // total occurrences minus the one already shown + return if (extra > 0) "$firstHunk\n\n... and $extra more replacement(s)" else firstHunk +} + +/** + * Compute a unified diff between two texts. Shows the changed line block with + * DIFF_CONTEXT_LINES of surrounding context. Returns "(no line changes)" when + * the line sequences are identical (e.g. only whitespace within a line changed + * and the line set is the same — shouldn't occur in practice but handled cleanly). + */ +internal fun buildUnifiedDiff(original: String, modified: String, path: String): String { + val origLines = original.lines() + val modLines = modified.lines() + + if (origLines == modLines) return "(no line changes)" + + // Number of leading lines that match (prefix) + val prefix = origLines.zip(modLines).takeWhile { (a, b) -> a == b }.size + + // Number of trailing lines that match (suffix), capped to avoid overlapping the prefix + var suffix = 0 + val origTail = origLines.size - prefix + val modTail = modLines.size - prefix + while (suffix < origTail && suffix < modTail && + origLines[origLines.size - 1 - suffix] == modLines[modLines.size - 1 - suffix] + ) { + suffix++ + } + + val lastOrig = origLines.size - 1 - suffix // last changed line in original (inclusive) + val lastMod = modLines.size - 1 - suffix // last changed line in modified (inclusive) + + val ctxStart = maxOf(0, prefix - DIFF_CONTEXT_LINES) + val ctxEndOrig = minOf(origLines.size, lastOrig + 1 + DIFF_CONTEXT_LINES) + val ctxEndMod = minOf(modLines.size, lastMod + 1 + DIFF_CONTEXT_LINES) + + return buildString { + appendLine("--- a/$path") + appendLine("+++ b/$path") + appendLine("@@ -${ctxStart + 1},${ctxEndOrig - ctxStart} +${ctxStart + 1},${ctxEndMod - ctxStart} @@") + for (i in ctxStart until prefix) appendLine(" ${origLines[i]}") + for (i in prefix..lastOrig) appendLine("-${origLines[i]}") + for (i in prefix..lastMod) appendLine("+${modLines[i]}") + for (i in (lastOrig + 1) until ctxEndOrig) appendLine(" ${origLines[i]}") + }.trimEnd() +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/file/FileListTool.kt b/src/main/kotlin/com/github/codeplangui/tools/file/FileListTool.kt new file mode 100644 index 0000000..ee97c97 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/file/FileListTool.kt @@ -0,0 +1,129 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.File + +@Serializable +data class FileListInput( + val path: String = ".", + val recursive: Boolean = false, + val includeHidden: Boolean = false, +) + +data class FileListOutput( + val path: String, + val entries: List, + val truncated: Boolean, +) + +private val json = Json { ignoreUnknownKeys = true } +private const val FILE_LIST_MAX_ENTRIES = 500 + +val FileListTool: Tool = tool { + name = "FileList" + aliases = listOf("list_files") + description = """ + List files and directories in a given path. Defaults to the project root. + Set recursive=true to walk subdirectories (capped at $FILE_LIST_MAX_ENTRIES entries). + """.trimIndent() + + inputSchema = buildJsonObject { + put("type", "object") + put("properties", buildJsonObject { + put("path", buildJsonObject { + put("type", "string") + put("description", "Directory to list. Defaults to '.' (project root).") + }) + put("recursive", buildJsonObject { + put("type", "boolean") + put("description", "Walk subdirectories. Defaults to false.") + }) + put("includeHidden", buildJsonObject { + put("type", "boolean") + put("description", "Include hidden files/dirs (starting with '.'). Defaults to false.") + }) + }) + } + + parse { raw: JsonElement -> json.decodeFromJsonElement(FileListInput.serializer(), raw) } + + validate { input, _ -> + if (input.path.isBlank()) ValidationResult.Failed("path must not be blank", errorCode = 1) + else ValidationResult.Ok + } + + checkPermissions { input, ctx -> + val basePath = ctx.project.basePath + ?: return@checkPermissions PermissionResult.Deny("Project path unavailable") + resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + ?: return@checkPermissions PermissionResult.Deny("Path resolves outside workspace: ${input.path}") + PermissionResult.Allow(Json.encodeToJsonElement(FileListInput.serializer(), input)) + } + + call { input, ctx, _ -> + val basePath = ctx.project.basePath!! + val resolved = resolveInsideWorkspace( + input.path, basePath, ctx.permissionContext.additionalWorkingDirectories + )!! + + withContext(Dispatchers.IO) { + val dir = File(resolved) + require(dir.exists()) { "Directory not found: ${input.path}" } + require(dir.isDirectory) { "Not a directory: ${input.path}" } + + val entries = mutableListOf() + var truncated = false + + fun collect(f: File, prefix: String) { + if (entries.size >= FILE_LIST_MAX_ENTRIES) { truncated = true; return } + if (!input.includeHidden && f.name.startsWith('.')) return + val line = buildString { + append(prefix) + append(f.name) + if (f.isDirectory) append('/') + } + entries.add(line) + if (input.recursive && f.isDirectory) { + f.listFiles()?.sortedWith(compareBy({ !it.isDirectory }, { it.name })) + ?.forEach { collect(it, "$prefix ") } + } + } + + dir.listFiles() + ?.sortedWith(compareBy({ !it.isDirectory }, { it.name })) + ?.forEach { collect(it, "") } + + ToolResult(FileListOutput(resolved, entries, truncated)) + } + } + + mapResult { output, toolUseId -> + val content = buildString { + appendLine("path: ${output.path}") + appendLine() + output.entries.forEach { appendLine(it) } + if (output.truncated) appendLine("... (truncated at $FILE_LIST_MAX_ENTRIES entries)") + }.trimEnd() + ToolResultBlock(toolUseId, content) + } + + isConcurrencySafe { true } + isReadOnly { true } + isDestructive { false } + + activityDescription { input -> input?.let { "Listing ${it.path}" } } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/file/FileReadTool.kt b/src/main/kotlin/com/github/codeplangui/tools/file/FileReadTool.kt new file mode 100644 index 0000000..49a1f6e --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/file/FileReadTool.kt @@ -0,0 +1,198 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.Progress +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.File + +@Serializable +data class FileReadInput( + val path: String, + val offset: Int? = null, + val limit: Int? = null, +) + +data class FileReadOutput( + val content: String, + val path: String, + val totalLines: Int, + val returnedLines: Int, + val truncated: Boolean, +) + +private val json = Json { ignoreUnknownKeys = true } + +// Hard caps — mirrored to Claude Code's FileReadTool limits. A single tool call +// must not blow out the model context window. +internal const val FILE_READ_MAX_LINES = 10_000 +internal const val FILE_READ_MAX_BYTES = 2L * 1024 * 1024 // 2 MiB +internal const val FILE_READ_DEFAULT_LIMIT = 2_000 + +val FileReadTool: Tool = tool { + name = "FileRead" + description = """ + Read text file contents with optional offset/limit. Returns lines prefixed by + line numbers (format: " 1→content"). Hard caps: 10000 lines or 2MB per call; + use offset/limit to page through larger files. + """.trimIndent() + + inputSchema = buildJsonObject { + put("type", "object") + put("required", JsonArray(listOf(JsonPrimitive("path")))) + put("properties", buildJsonObject { + put("path", buildJsonObject { + put("type", "string") + put("description", "Absolute path, or path relative to project root.") + }) + put("offset", buildJsonObject { + put("type", "integer") + put("description", "1-indexed starting line. Defaults to 1.") + }) + put("limit", buildJsonObject { + put("type", "integer") + put("description", "Number of lines to return. Defaults to 2000, max 10000.") + }) + }) + } + + parse { raw: JsonElement -> json.decodeFromJsonElement(FileReadInput.serializer(), raw) } + + validate { input, _ -> + when { + input.path.isBlank() -> + ValidationResult.Failed("path must not be blank", errorCode = 1) + input.offset != null && input.offset < 1 -> + ValidationResult.Failed("offset must be >= 1", errorCode = 2) + input.limit != null && (input.limit < 1 || input.limit > FILE_READ_MAX_LINES) -> + ValidationResult.Failed( + "limit must be between 1 and $FILE_READ_MAX_LINES", + errorCode = 3, + ) + else -> ValidationResult.Ok + } + } + + checkPermissions { input, ctx -> + val basePath = ctx.project.basePath + ?: return@checkPermissions PermissionResult.Deny("Project path unavailable") + val resolved = resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + if (resolved == null) { + PermissionResult.Deny("Path resolves outside workspace: ${input.path}") + } else { + PermissionResult.Allow(Json.encodeToJsonElement(FileReadInput.serializer(), input)) + } + } + + // preview() returns null — read operations have no side effects. + + call { input, ctx, _: (Progress) -> Unit -> + val basePath = ctx.project.basePath + ?: error("Project path unavailable; permission layer should have rejected this") + val resolvedPath = resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + ?: error("Path outside workspace; permission layer should have rejected this") + + withContext(Dispatchers.IO) { + val file = File(resolvedPath) + require(file.exists()) { "File not found: ${input.path}" } + require(file.isFile) { "Not a regular file: ${input.path}" } + + val allLines = if (file.length() > FILE_READ_MAX_BYTES) { + // Stream up to the byte cap; dropping trailing partial line to avoid + // handing the LLM a half-word. + val bytes = file.inputStream().buffered().use { it.readNBytes(FILE_READ_MAX_BYTES.toInt()) } + String(bytes, Charsets.UTF_8).lines().dropLast(1) + } else { + file.readText(Charsets.UTF_8).lines() + } + // `readText` + `lines()` produces a trailing empty element for files ending + // in newline. Drop it so totalLines matches intuition. + val totalLines = if (allLines.isNotEmpty() && allLines.last().isEmpty()) allLines.size - 1 else allLines.size + val effectiveLines = if (allLines.isNotEmpty() && allLines.last().isEmpty()) allLines.dropLast(1) else allLines + + val startIdx = (input.offset ?: 1) - 1 // 0-indexed + val limit = input.limit ?: FILE_READ_DEFAULT_LIMIT + val endIdx = minOf(startIdx + limit, effectiveLines.size) + val returned = if (startIdx >= effectiveLines.size) emptyList() else effectiveLines.subList(startIdx, endIdx) + + val truncatedByByteCap = file.length() > FILE_READ_MAX_BYTES + val truncatedByLimit = endIdx < effectiveLines.size + val truncated = truncatedByByteCap || truncatedByLimit + + val width = maxOf(1, endIdx.toString().length) + val content = buildString { + returned.forEachIndexed { i, line -> + val ln = (startIdx + i + 1).toString().padStart(width) + append(ln).append('→').append(line).append('\n') + } + }.trimEnd('\n') + + ToolResult( + FileReadOutput( + content = content, + path = resolvedPath, + totalLines = totalLines, + returnedLines = returned.size, + truncated = truncated, + ) + ) + } + } + + mapResult { output, toolUseId -> + val header = buildString { + appendLine("path: ${output.path}") + appendLine("total_lines: ${output.totalLines}") + appendLine("returned_lines: ${output.returnedLines}") + if (output.truncated) appendLine("truncated: true — use offset/limit to continue") + appendLine() + } + ToolResultBlock( + toolUseId = toolUseId, + content = header + output.content, + isError = false, + ) + } + + isConcurrencySafe { true } + isReadOnly { true } + isDestructive { false } + + activityDescription { input -> input?.let { "Reading ${it.path}" } } +} + +/** + * Resolve a tool-supplied path to an absolute canonical path that is inside the + * workspace (or one of the explicitly allowed additional directories). Returns + * null if the path escapes via `..`, symlink, or an unrelated absolute path. + */ +internal fun resolveInsideWorkspace( + rawPath: String, + basePath: String, + additionalWorkingDirs: Set = emptySet(), +): String? { + val resolved = File(basePath, rawPath).canonicalPath + val canonicalBase = File(basePath).canonicalPath + val allowedRoots = buildList { + add(canonicalBase) + additionalWorkingDirs.forEach { add(File(it).canonicalPath) } + } + return if (allowedRoots.any { root -> resolved == root || resolved.startsWith("$root${File.separator}") }) { + resolved + } else { + null + } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/file/FileSearchTool.kt b/src/main/kotlin/com/github/codeplangui/tools/file/FileSearchTool.kt new file mode 100644 index 0000000..d6ef0b6 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/file/FileSearchTool.kt @@ -0,0 +1,173 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.File + +@Serializable +data class FileSearchInput( + val pattern: String, + val path: String = ".", + val glob: String = "*", + val ignoreCase: Boolean = false, + val maxResults: Int = 200, +) + +data class FileSearchMatch( + val file: String, + val line: Int, + val text: String, +) + +data class FileSearchOutput( + val matches: List, + val truncated: Boolean, +) + +private val json = Json { ignoreUnknownKeys = true } +private const val FILE_SEARCH_MAX_RESULTS = 500 +private const val FILE_SEARCH_MAX_LINE_LEN = 400 + +val FileSearchTool: Tool = tool { + name = "FileSearch" + aliases = listOf("grep_files", "search_files") + description = """ + Search file contents for a regex or literal pattern. Returns matching lines with + file path and line number. Defaults to searching the project root recursively. + Use `glob` to restrict to specific file types (e.g. "*.kt", "*.ts"). + """.trimIndent() + + inputSchema = buildJsonObject { + put("type", "object") + put("required", JsonArray(listOf(JsonPrimitive("pattern")))) + put("properties", buildJsonObject { + put("pattern", buildJsonObject { + put("type", "string") + put("description", "Regex or literal string to search for.") + }) + put("path", buildJsonObject { + put("type", "string") + put("description", "Root directory to search. Defaults to '.' (project root).") + }) + put("glob", buildJsonObject { + put("type", "string") + put("description", "Glob pattern to filter files, e.g. '*.kt'. Defaults to '*'.") + }) + put("ignoreCase", buildJsonObject { + put("type", "boolean") + put("description", "Case-insensitive matching. Defaults to false.") + }) + put("maxResults", buildJsonObject { + put("type", "integer") + put("description", "Max matches to return. Defaults to 200, max $FILE_SEARCH_MAX_RESULTS.") + }) + }) + } + + parse { raw: JsonElement -> json.decodeFromJsonElement(FileSearchInput.serializer(), raw) } + + validate { input, _ -> + when { + input.pattern.isBlank() -> ValidationResult.Failed("pattern must not be blank", errorCode = 1) + input.maxResults !in 1..FILE_SEARCH_MAX_RESULTS -> + ValidationResult.Failed("maxResults must be between 1 and $FILE_SEARCH_MAX_RESULTS", errorCode = 2) + runCatching { if (input.ignoreCase) Regex(input.pattern, RegexOption.IGNORE_CASE) else Regex(input.pattern) }.isFailure -> + ValidationResult.Failed("pattern is not a valid regex", errorCode = 3) + else -> ValidationResult.Ok + } + } + + checkPermissions { input, ctx -> + val basePath = ctx.project.basePath + ?: return@checkPermissions PermissionResult.Deny("Project path unavailable") + resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + ?: return@checkPermissions PermissionResult.Deny("Path resolves outside workspace: ${input.path}") + PermissionResult.Allow(Json.encodeToJsonElement(FileSearchInput.serializer(), input)) + } + + call { input, ctx, _ -> + val basePath = ctx.project.basePath!! + val resolved = resolveInsideWorkspace( + input.path, basePath, ctx.permissionContext.additionalWorkingDirectories + )!! + + withContext(Dispatchers.IO) { + val regex = if (input.ignoreCase) Regex(input.pattern, RegexOption.IGNORE_CASE) + else Regex(input.pattern) + val globRegex = globToRegex(input.glob) + + val matches = mutableListOf() + var truncated = false + + File(resolved).walkTopDown() + .filter { it.isFile && globRegex.matches(it.name) } + .sortedBy { it.path } + .forEach { file -> + if (truncated) return@forEach + try { + file.bufferedReader(Charsets.UTF_8).useLines { lines -> + lines.forEachIndexed { idx, raw -> + if (truncated) return@forEachIndexed + val line = if (raw.length > FILE_SEARCH_MAX_LINE_LEN) + raw.take(FILE_SEARCH_MAX_LINE_LEN) + "…" + else raw + if (regex.containsMatchIn(line)) { + matches.add(FileSearchMatch(file.path, idx + 1, line)) + if (matches.size >= input.maxResults) truncated = true + } + } + } + } catch (_: Exception) { + // Skip unreadable files (binary, permission denied, etc.) + } + } + + ToolResult(FileSearchOutput(matches, truncated)) + } + } + + mapResult { output, toolUseId -> + val content = buildString { + output.matches.forEach { m -> + appendLine("${m.file}:${m.line}:${m.text}") + } + if (output.matches.isEmpty()) appendLine("(no matches)") + if (output.truncated) appendLine("... (truncated)") + }.trimEnd() + ToolResultBlock(toolUseId, content) + } + + isConcurrencySafe { true } + isReadOnly { true } + isDestructive { false } + + activityDescription { input -> input?.let { "Searching for '${it.pattern}' in ${it.path}" } } +} + +/** Convert a simple glob pattern (*, ?) to a Regex that matches filenames. */ +private fun globToRegex(glob: String): Regex { + val sb = StringBuilder("^") + for (c in glob) { + when (c) { + '*' -> sb.append(".*") + '?' -> sb.append(".") + '.' -> sb.append("\\.") + else -> sb.append(Regex.escape(c.toString())) + } + } + sb.append('$') + return Regex(sb.toString()) +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/file/WriteFileTool.kt b/src/main/kotlin/com/github/codeplangui/tools/file/WriteFileTool.kt new file mode 100644 index 0000000..0ed14bc --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/file/WriteFileTool.kt @@ -0,0 +1,121 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.PreviewResult +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.File + +@Serializable +data class WriteFileInput( + val path: String, + val content: String, +) + +data class WriteFileOutput( + val path: String, + val bytesWritten: Long, + val isNewFile: Boolean, +) + +private val json = Json { ignoreUnknownKeys = true } + +val WriteFileTool: Tool = tool { + name = "WriteFile" + description = """ + Write content to a file, creating it if it doesn't exist or overwriting it if it does. + Prefer FileEdit for targeted in-place edits. Use WriteFile when creating new files or + when replacing the entire file content. Returns the number of bytes written. + """.trimIndent() + + inputSchema = buildJsonObject { + put("type", "object") + put("required", JsonArray(listOf(JsonPrimitive("path"), JsonPrimitive("content")))) + put("properties", buildJsonObject { + put("path", buildJsonObject { + put("type", "string") + put("description", "Absolute path, or path relative to project root.") + }) + put("content", buildJsonObject { + put("type", "string") + put("description", "Full content to write to the file.") + }) + }) + } + + parse { raw: JsonElement -> json.decodeFromJsonElement(WriteFileInput.serializer(), raw) } + + validate { input, _ -> + if (input.path.isBlank()) ValidationResult.Failed("path must not be blank", errorCode = 1) + else ValidationResult.Ok + } + + checkPermissions { input, ctx -> + val basePath = ctx.project.basePath + ?: return@checkPermissions PermissionResult.Deny("Project path unavailable") + val resolved = resolveInsideWorkspace(input.path, basePath, ctx.permissionContext.additionalWorkingDirectories) + ?: return@checkPermissions PermissionResult.Deny("Path resolves outside workspace: ${input.path}") + + if (ctx.permissionContext.mode == ToolPermissionContext.Mode.ACCEPT_EDITS || + ctx.permissionContext.mode == ToolPermissionContext.Mode.BYPASS) { + return@checkPermissions PermissionResult.Allow( + Json.encodeToJsonElement(WriteFileInput.serializer(), input) + ) + } + + val isNew = !File(resolved).exists() + val action = if (isNew) "Create" else "Overwrite" + PermissionResult.Ask( + reason = "$action ${input.path}", + preview = PreviewResult( + summary = "$action file: ${input.path}", + details = "Content length: ${input.content.length} chars\n" + + if (isNew) "(new file)" else "(file will be overwritten)", + risk = if (isNew) PreviewResult.Risk.LOW else PreviewResult.Risk.HIGH, + ), + ) + } + + call { input, ctx, _ -> + val basePath = ctx.project.basePath!! + val resolved = resolveInsideWorkspace( + input.path, basePath, ctx.permissionContext.additionalWorkingDirectories + )!! + + withContext(Dispatchers.IO) { + val file = File(resolved) + val isNew = !file.exists() + file.parentFile?.mkdirs() + val bytes = input.content.toByteArray(Charsets.UTF_8) + file.writeBytes(bytes) + ToolResult(WriteFileOutput(resolved, bytes.size.toLong(), isNew)) + } + } + + mapResult { output, toolUseId -> + val action = if (output.isNewFile) "Created" else "Wrote" + ToolResultBlock( + toolUseId = toolUseId, + content = "$action ${output.path} (${output.bytesWritten} bytes)", + ) + } + + isConcurrencySafe { false } + isReadOnly { false } + isDestructive { true } + + activityDescription { input -> input?.let { "Writing ${it.path}" } } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/mcp/McpClient.kt b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpClient.kt new file mode 100644 index 0000000..51f9199 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpClient.kt @@ -0,0 +1,165 @@ +package com.github.codeplangui.tools.mcp + +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.io.Closeable +import java.io.InputStream +import java.io.OutputStream +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +/** + * Stdio JSON-RPC 2.0 client for an MCP server process. + * + * Lifecycle: + * 1. [connect] — sends `initialize` + `notifications/initialized`, then calls `tools/list`. + * 2. [call] — sends `tools/call` and awaits the response. + * 3. [close] — cancels the reader and destroys the process. + * + * Inject [input]/[output] directly for unit tests (see [fromStreams]). + * Production callers use [fromConfig] which launches the child process. + */ +class McpClient internal constructor( + val serverName: String, + private val process: Process?, + private val input: InputStream, + private val output: OutputStream, + private val scope: CoroutineScope, +) : Closeable { + + private val json = Json { ignoreUnknownKeys = true; encodeDefaults = false } + private val pending = ConcurrentHashMap>() + private val idSeq = AtomicInteger(0) + private val writer = output.bufferedWriter(Charsets.UTF_8) + private var readerJob: Job? = null + + /** Start the reader loop and perform the MCP handshake. Returns the server's tool list. */ + suspend fun connect(): List { + startReader() + + sendRequest( + "initialize", + buildJsonObject { + put("protocolVersion", MCP_PROTOCOL_VERSION) + put("capabilities", buildJsonObject {}) + put("clientInfo", buildJsonObject { + put("name", "CodePlanGUI") + put("version", "1.0") + }) + }, + ) + sendNotification("notifications/initialized") + + val result = sendRequest("tools/list") + return json.decodeFromJsonElement(ToolListResult.serializer(), result).tools + } + + /** Execute a tool on the MCP server and return the call result. */ + suspend fun call(toolName: String, arguments: JsonElement): McpCallResult { + val result = sendRequest( + "tools/call", + buildJsonObject { + put("name", toolName) + put("arguments", arguments) + }, + ) + return json.decodeFromJsonElement(McpCallResult.serializer(), result) + } + + override fun close() { + readerJob?.cancel() + runCatching { writer.close() } + process?.destroyForcibly() + pending.values.forEach { it.cancel() } + pending.clear() + } + + // ─── Internal ──────────────────────────────────────────────────────────── + + private fun startReader() { + readerJob = scope.launch(Dispatchers.IO) { + input.bufferedReader(Charsets.UTF_8).use { reader -> + reader.forEachLine { line -> + if (line.isBlank()) return@forEachLine + try { + val resp = json.decodeFromString(JsonRpcResponse.serializer(), line) + val id = resp.id ?: return@forEachLine + val deferred = pending.remove(id) ?: return@forEachLine + when { + resp.error != null -> + deferred.completeExceptionally(McpException(resp.error.message, resp.error.code)) + else -> + deferred.complete(resp.result ?: JsonNull) + } + } catch (_: Exception) { + // Ignore malformed lines (stderr leakage, debug output, etc.) + } + } + } + // Process closed — fail any requests that never got a response. + val err = McpException("MCP server '$serverName' process ended unexpectedly", -32000) + pending.values.forEach { it.completeExceptionally(err) } + pending.clear() + } + } + + private suspend fun sendRequest(method: String, params: JsonElement? = null): JsonElement { + val id = idSeq.incrementAndGet() + val deferred = CompletableDeferred() + pending[id] = deferred + + val req = JsonRpcRequest(id = id, method = method, params = params) + withContext(Dispatchers.IO) { + writer.write(json.encodeToString(JsonRpcRequest.serializer(), req)) + writer.newLine() + writer.flush() + } + + return withTimeout(MCP_CALL_TIMEOUT_MS) { deferred.await() } + } + + private suspend fun sendNotification(method: String, params: JsonElement? = null) { + withContext(Dispatchers.IO) { + val obj = buildJsonObject { + put("jsonrpc", "2.0") + put("method", method) + params?.let { put("params", it) } + } + writer.write(json.encodeToString(kotlinx.serialization.json.JsonObject.serializer(), obj)) + writer.newLine() + writer.flush() + } + } + + companion object { + /** + * Spawn a child process from [config] and wire its stdio to a new client. + * The caller is responsible for eventually calling [close]. + */ + fun fromConfig(config: McpServerConfig, scope: CoroutineScope): McpClient { + val proc = ProcessBuilder(buildList { add(config.command); addAll(config.args) }) + .apply { config.env.forEach { (k, v) -> environment()[k] = v } } + .redirectErrorStream(false) + .start() + return McpClient(config.name, proc, proc.inputStream, proc.outputStream, scope) + } + + /** For unit tests: inject arbitrary streams instead of a real process. */ + fun fromStreams( + serverName: String, + input: InputStream, + output: OutputStream, + scope: CoroutineScope, + ): McpClient = McpClient(serverName, process = null, input, output, scope) + } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/mcp/McpConnectionManager.kt b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpConnectionManager.kt new file mode 100644 index 0000000..9a089c6 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpConnectionManager.kt @@ -0,0 +1,73 @@ +package com.github.codeplangui.tools.mcp + +import com.github.codeplangui.tools.Tool +import com.intellij.openapi.Disposable +import com.intellij.openapi.project.Project +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.serialization.json.JsonElement +import java.util.concurrent.ConcurrentHashMap + +/** + * Manages the lifecycle of external MCP server connections for a single project. + * + * Responsibilities: + * - Spawn child processes from [McpServerConfig] + * - Run the MCP handshake and collect the server's tool list + * - Expose all active MCP tools as [Tool] instances (for use by [ToolRegistry.assembleToolPool]) + * - Tear down connections when the project closes (via [Disposable]) + * + * M5 will extend this with: + * - Reconnection with exponential backoff on process death + * - Heartbeat / ping monitoring + * - Config reload without restarting the plugin + */ +class McpConnectionManager( + @Suppress("unused") private val project: Project, +) : Disposable { + + private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + private val clients = ConcurrentHashMap() + private val _tools = ConcurrentHashMap>>() + + /** All currently active MCP proxy tools, flat across all connected servers. */ + val tools: List> + get() = _tools.values.flatten() + + /** + * Connect to an MCP server: spawns the process, performs the handshake, and + * registers the remote tools. Safe to call from a coroutine. + * Replaces any existing connection for [config.name] without error. + */ + suspend fun addServer(config: McpServerConfig) { + removeServer(config.name) + + val client = McpClient.fromConfig(config, scope) + val specs = try { + client.connect() + } catch (t: Throwable) { + client.close() + throw McpException("Failed to connect to MCP server '${config.name}': ${t.message}") + } + + clients[config.name] = client + _tools[config.name] = specs.map { spec -> mcpProxyTool(config.name, spec, client) } + } + + /** + * Disconnect and remove all tools from [serverName]. + * No-op if the server was not registered. + */ + fun removeServer(serverName: String) { + clients.remove(serverName)?.close() + _tools.remove(serverName) + } + + /** Tear down all connections — called by the IntelliJ plugin lifecycle on project close. */ + override fun dispose() { + clients.keys.toList().forEach { removeServer(it) } + scope.cancel() + } +} diff --git a/src/main/kotlin/com/github/codeplangui/tools/mcp/McpProxyTool.kt b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpProxyTool.kt new file mode 100644 index 0000000..2236446 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpProxyTool.kt @@ -0,0 +1,96 @@ +package com.github.codeplangui.tools.mcp + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.PreviewResult +import com.github.codeplangui.tools.Tool +import com.github.codeplangui.tools.ToolResult +import com.github.codeplangui.tools.ToolResultBlock +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.tool +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonObject + +/** + * Wrap a remote MCP tool as a local [Tool] instance. + * + * Naming convention: `mcp__{serverName}__{toolName}` — matches Claude Code's + * `__` separator so the LLM can distinguish MCP tools from built-ins. + * + * Design decisions: + * - Input/Output are both [JsonElement]: the remote schema is forwarded as-is so + * the LLM sees the exact schema the MCP server advertises. + * - `parseInput` is the identity — input arrives from the LLM as JSON and is + * forwarded without re-serialization. + * - `checkPermissions` always returns Ask (MCP tools are opaque; we don't know + * whether they mutate state). M5 may add per-server permission rules. + * - `isConcurrencySafe` is false (conservative default for remote tools). + * - `isDestructive` is true (conservative default). + */ +fun mcpProxyTool( + serverName: String, + spec: McpToolSpec, + client: McpClient, +): Tool { + val toolId = "mcp__${serverName}__${spec.name}" + val json = Json { ignoreUnknownKeys = true } + + return tool { + name = toolId + description = spec.description.ifBlank { "MCP tool '$toolId'" } + inputSchema = spec.inputSchema + + parse { raw: JsonElement -> raw } + + validate { input, _ -> + val obj = runCatching { input.jsonObject }.getOrNull() + ?: return@validate ValidationResult.Failed("MCP tool input must be a JSON object", errorCode = 1) + val totalBytes = json.encodeToString(JsonElement.serializer(), input).length + if (totalBytes > MAX_INPUT_BYTES) { + ValidationResult.Failed("Input exceeds ${MAX_INPUT_BYTES / 1024}KB limit", errorCode = 2) + } else { + ValidationResult.Ok + } + } + + checkPermissions { input, ctx -> + PermissionResult.Ask( + reason = "Call MCP tool $toolId", + preview = PreviewResult( + summary = "Invoke $toolId on server '$serverName'", + details = "Arguments:\n${json.encodeToString(JsonElement.serializer(), input)}", + risk = PreviewResult.Risk.MEDIUM, + ), + ) + } + + // No preview() — MCP servers don't expose a dry-run mechanism in the MVP. + + call { input, _, onProgress -> + onProgress(com.github.codeplangui.tools.Progress.Status("Calling $toolId…")) + val result = client.call(spec.name, input) + ToolResult(result) + } + + mapResult { result, toolUseId -> + val content = buildString { + if (result.isError) appendLine("error: true") + append(result.textContent()) + } + ToolResultBlock( + toolUseId = toolUseId, + content = content, + isError = result.isError, + ) + } + + isConcurrencySafe { false } + isReadOnly { false } + isDestructive { true } + + activityDescription { _ -> "Calling $toolId" } + } +} + +private const val MAX_INPUT_BYTES = 1024 * 1024 // 1MB diff --git a/src/main/kotlin/com/github/codeplangui/tools/mcp/McpTypes.kt b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpTypes.kt new file mode 100644 index 0000000..59006f4 --- /dev/null +++ b/src/main/kotlin/com/github/codeplangui/tools/mcp/McpTypes.kt @@ -0,0 +1,78 @@ +package com.github.codeplangui.tools.mcp + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject + +/** Configuration for an external MCP server launched via stdio. */ +data class McpServerConfig( + val name: String, + val command: String, + val args: List = emptyList(), + val env: Map = emptyMap(), +) + +// ─── JSON-RPC 2.0 wire types ──────────────────────────────────────────────── + +@Serializable +data class JsonRpcRequest( + val jsonrpc: String = "2.0", + val id: Int, + val method: String, + val params: JsonElement? = null, +) + +@Serializable +data class JsonRpcResponse( + val jsonrpc: String = "2.0", + val id: Int? = null, + val result: JsonElement? = null, + val error: JsonRpcError? = null, +) + +@Serializable +data class JsonRpcError( + val code: Int, + val message: String, + val data: JsonElement? = null, +) + +// ─── MCP protocol types ────────────────────────────────────────────────────── + +/** A single tool advertised by an MCP server. */ +@Serializable +data class McpToolSpec( + val name: String, + val description: String = "", + val inputSchema: JsonObject = JsonObject(emptyMap()), +) + +@Serializable +internal data class ToolListResult(val tools: List = emptyList()) + +/** A single content item in a tools/call response. */ +@Serializable +data class McpContentItem( + val type: String, + val text: String? = null, +) + +/** The result block returned by tools/call. */ +@Serializable +data class McpCallResult( + val content: List = emptyList(), + val isError: Boolean = false, +) { + /** Flatten all text content items into a single string. */ + fun textContent(): String = content + .filter { it.type == "text" } + .joinToString("\n") { it.text ?: "" } + .trimEnd() +} + +/** Thrown when the MCP server returns a JSON-RPC error or the process dies. */ +class McpException(message: String, val code: Int = -1) : RuntimeException(message) + +internal const val MCP_CALL_TIMEOUT_MS: Long = 30_000 +internal const val MCP_PROTOCOL_VERSION: String = "2024-11-05" diff --git a/src/test/kotlin/com/github/codeplangui/tools/ToolExecutorBatchTest.kt b/src/test/kotlin/com/github/codeplangui/tools/ToolExecutorBatchTest.kt new file mode 100644 index 0000000..2504309 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/ToolExecutorBatchTest.kt @@ -0,0 +1,245 @@ +package com.github.codeplangui.tools + +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Job +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.jsonPrimitive +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import java.util.concurrent.atomic.AtomicInteger + +class ToolExecutorBatchTest { + + private data class In(val tag: String) + private data class Out(val echoed: String) + + private fun makeTool( + n: String, + safe: Boolean, + onCall: suspend () -> Unit = {}, + ): Tool = tool { + name = n + description = "test tool $n" + parse { raw -> + val obj = raw as JsonObject + In(tag = obj["tag"]?.jsonPrimitive?.content ?: "") + } + isConcurrencySafe { safe } + call { input, _, _ -> + onCall() + ToolResult(Out(echoed = input.tag)) + } + mapResult { out, id -> ToolResultBlock(id, out.echoed) } + } + + private fun ctx(): ToolExecutionContext = ToolExecutionContext( + project = mockk().also { every { it.basePath } returns "/tmp" }, + toolUseId = "batch", + abortJob = Job(), + ) + + private fun use(name: String, tag: String, id: String) = ToolUseBlock( + toolUseId = id, + name = name, + input = buildJsonObject { put("tag", JsonPrimitive(tag)) }, + ) + + @Test + fun `empty batch emits no updates`() = runBlocking { + val updates = runToolUseBatch( + toolUses = emptyList(), + pool = ToolRegistry.getAllBaseTools(), + context = ctx(), + ).toList() + assertTrue(updates.isEmpty()) + } + + @Test + fun `unknown tool emits Started then Failed(LOOKUP)`() = runBlocking { + val updates = runToolUseBatch( + toolUses = listOf(use("NopeTool", "x", "u1")), + pool = ToolRegistry.getAllBaseTools(), + context = ctx(), + ).toList() + + assertEquals(2, updates.size) + assertTrue(updates[0] is ToolUpdate.Started) + val failed = updates[1] as ToolUpdate.Failed + assertEquals(ToolUpdate.Failed.Stage.LOOKUP, failed.stage) + assertTrue(failed.message.contains("NopeTool")) + } + + @Test + fun `concurrency-safe tools execute in parallel`() = runBlocking { + // Each tool blocks on a barrier, then completes. If they ran serially, the + // second would observe peakInFlight == 1. If parallel, peak must reach 2. + val inFlight = AtomicInteger(0) + val peak = AtomicInteger(0) + val barrier = CompletableDeferred() + + val tool: Tool = makeTool(n = "P", safe = true) { + val now = inFlight.incrementAndGet() + peak.updateAndGet { maxOf(it, now) } + // First task triggers barrier release on its own schedule; both must + // be in flight for peak to hit 2. Keep simple: release after a short + // delay so both can enter. + if (now == 2) barrier.complete(Unit) + barrier.await() + inFlight.decrementAndGet() + } + + val updates = runToolUseBatch( + toolUses = listOf(use("P", "a", "1"), use("P", "b", "2")), + pool = listOf(tool), + context = ctx(), + ).toList() + + assertEquals(2, peak.get(), "safe tools must run in parallel; peak in-flight should be 2") + assertEquals( + 2, + updates.filterIsInstance().size, + "both tool_uses should complete", + ) + } + + @Test + fun `concurrency-unsafe tools execute serially`() = runBlocking { + val inFlight = AtomicInteger(0) + val peak = AtomicInteger(0) + + val tool: Tool = makeTool(n = "U", safe = false) { + val now = inFlight.incrementAndGet() + peak.updateAndGet { maxOf(it, now) } + delay(20) + inFlight.decrementAndGet() + } + + val updates = runToolUseBatch( + toolUses = listOf(use("U", "a", "1"), use("U", "b", "2"), use("U", "c", "3")), + pool = listOf(tool), + context = ctx(), + ).toList() + + assertEquals(1, peak.get(), "unsafe tools must serialize; peak in-flight should be 1") + assertEquals(3, updates.filterIsInstance().size) + } + + @Test + fun `mixed batch runs safe group concurrently then unsafe serially`() = runBlocking { + val safeCounter = AtomicInteger(0) + val unsafeCounter = AtomicInteger(0) + val safeTool: Tool = makeTool("S", safe = true) { safeCounter.incrementAndGet() } + val unsafeTool: Tool = makeTool("U", safe = false) { unsafeCounter.incrementAndGet() } + + val updates = runToolUseBatch( + toolUses = listOf( + use("S", "a", "1"), + use("U", "b", "2"), + use("S", "c", "3"), + use("U", "d", "4"), + ), + pool = listOf(safeTool, unsafeTool), + context = ctx(), + ).toList() + + assertEquals(2, safeCounter.get()) + assertEquals(2, unsafeCounter.get()) + + val completed = updates.filterIsInstance() + assertEquals(4, completed.size) + + // By design, safe tools complete before unsafe ones start. Unsafe completions + // should appear after all safe completions. + val safeIds = setOf("1", "3") + val firstUnsafeIndex = completed.indexOfFirst { it.toolUseId !in safeIds } + val lastSafeIndex = completed.indexOfLast { it.toolUseId in safeIds } + assertTrue( + lastSafeIndex < firstUnsafeIndex, + "safe tools should finish before unsafe group starts", + ) + } + + @Test + fun `maxConcurrency bound is respected`() = runBlocking { + val inFlight = AtomicInteger(0) + val peak = AtomicInteger(0) + val tool: Tool = makeTool("S", safe = true) { + val now = inFlight.incrementAndGet() + peak.updateAndGet { maxOf(it, now) } + delay(15) + inFlight.decrementAndGet() + } + + runToolUseBatch( + toolUses = (1..5).map { use("S", "t$it", "u$it") }, + pool = listOf(tool), + context = ctx(), + maxConcurrency = 2, + ).toList() + + assertTrue(peak.get() <= 2, "peak in-flight ${peak.get()} must respect maxConcurrency=2") + assertTrue(peak.get() >= 2, "should actually reach the cap with 5 tasks") + } + + @Test + fun `parse failure in one tool_use does not block others`() = runBlocking { + val tool: Tool = tool { + name = "Brittle" + description = "parse fails for special tag" + parse { raw -> + val obj = raw as JsonObject + val tag = obj["tag"]?.jsonPrimitive?.content ?: "" + if (tag == "BOOM") throw IllegalArgumentException("bad tag") + In(tag = tag) + } + isConcurrencySafe { true } + call { input, _, _ -> ToolResult(Out(echoed = input.tag)) } + mapResult { out, id -> ToolResultBlock(id, out.echoed) } + } + + val updates = runToolUseBatch( + toolUses = listOf(use("Brittle", "ok", "1"), use("Brittle", "BOOM", "2"), use("Brittle", "ok2", "3")), + pool = listOf(tool), + context = ctx(), + ).toList() + + val completed = updates.filterIsInstance().map { it.toolUseId }.toSet() + assertEquals(setOf("1", "3"), completed) + + val failed = updates.filterIsInstance() + assertEquals(1, failed.size) + assertEquals(ToolUpdate.Failed.Stage.PARSE, failed[0].stage) + assertEquals("2", failed[0].toolUseId) + } + + @Test + fun `per-tool_use updates remain ordered for the same id`() = runBlocking { + val tool: Tool = makeTool("S", safe = true) + val updates = runToolUseBatch( + toolUses = listOf(use("S", "a", "1"), use("S", "b", "2")), + pool = listOf(tool), + context = ctx(), + ).toList() + + // Started must come before Completed for each id, even though id order across + // the whole flow is non-deterministic in the concurrent group. + for (id in listOf("1", "2")) { + val startedIdx = updates.indexOfFirst { it is ToolUpdate.Started && it.toolUseId == id } + val completedIdx = updates.indexOfFirst { it is ToolUpdate.Completed && it.toolUseId == id } + assertTrue(startedIdx >= 0, "Started missing for id=$id") + assertTrue(completedIdx >= 0, "Completed missing for id=$id") + assertTrue(startedIdx < completedIdx, "Started must precede Completed for id=$id") + } + assertFalse(updates.any { it is ToolUpdate.Failed }) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/ToolRegistryTest.kt b/src/test/kotlin/com/github/codeplangui/tools/ToolRegistryTest.kt new file mode 100644 index 0000000..eb91d65 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/ToolRegistryTest.kt @@ -0,0 +1,103 @@ +package com.github.codeplangui.tools + +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.serialization.json.JsonObject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +class ToolRegistryTest { + + private data class Dummy(val x: String) + + private fun dummyTool(n: String, vararg extraAliases: String): Tool = tool { + name = n + description = "dummy" + aliases = extraAliases.toList() + parse { _ -> Dummy("") } + call { d, _, _ -> ToolResult(d) } + mapResult { _, id -> ToolResultBlock(id, "") } + } + + private fun ctx(deny: Set = emptySet()): ToolExecutionContext = ToolExecutionContext( + project = mockk().also { every { it.basePath } returns "/tmp" }, + toolUseId = "t", + abortJob = Job(), + permissionContext = ToolPermissionContext.default().copy(alwaysDeny = deny), + ) + + @Test + fun `getAllBaseTools returns non-empty registry including Bash FileEdit and FileRead`() { + val tools = ToolRegistry.getAllBaseTools() + val names = tools.map { it.name }.toSet() + assertTrue("Bash" in names, "Bash must be registered") + assertTrue("FileEdit" in names, "FileEdit must be registered") + assertTrue("FileRead" in names, "FileRead must be registered") + } + + @Test + fun `assembleToolPool sorts by name and dedupes built-in over mcp on name collision`() { + val builtin = listOf(dummyTool("Bash"), dummyTool("FileRead")) + val mcp = listOf(dummyTool("FileRead"), dummyTool("MCPExtra")) // duplicate name + val pool = ToolRegistry.assembleToolPool(baseTools = builtin, mcpTools = mcp) + + assertEquals(listOf("Bash", "FileRead", "MCPExtra"), pool.map { it.name }) + // The FileRead in the pool must be the built-in, not the MCP one. + // Identity check is fine here — dedupe should keep the first occurrence. + assertTrue(pool[1] === builtin[1]) + } + + @Test + fun `filterToolsByDenyRules removes exact-name matches`() { + val tools = listOf(dummyTool("Bash"), dummyTool("FileRead")) + val filtered = ToolRegistry.filterToolsByDenyRules(tools, ctx(deny = setOf("Bash"))) + assertEquals(listOf("FileRead"), filtered.map { it.name }) + } + + @Test + fun `filterToolsByDenyRules removes parametrized rule matches like Bash(git push)`() { + val tools = listOf(dummyTool("Bash"), dummyTool("FileRead")) + val filtered = ToolRegistry.filterToolsByDenyRules(tools, ctx(deny = setOf("Bash(git push *)"))) + assertEquals(listOf("FileRead"), filtered.map { it.name }) + } + + @Test + fun `filterToolsByDenyRules noop when rules empty`() { + val tools = ToolRegistry.getAllBaseTools() + val filtered = ToolRegistry.filterToolsByDenyRules(tools, ctx()) + assertEquals(tools.size, filtered.size) + } + + @Test + fun `findByName returns tool by primary name`() { + val pool: List> = listOf(dummyTool("Foo"), dummyTool("Bar")) + assertNotNull(ToolRegistry.findByName(pool, "Foo")) + assertNotNull(ToolRegistry.findByName(pool, "Bar")) + assertNull(ToolRegistry.findByName(pool, "Missing")) + } + + @Test + fun `findByName returns tool by alias`() { + val pool: List> = listOf(dummyTool("Canonical", "LegacyName", "OldName")) + assertEquals("Canonical", ToolRegistry.findByName(pool, "LegacyName")?.name) + assertEquals("Canonical", ToolRegistry.findByName(pool, "OldName")?.name) + assertNull(ToolRegistry.findByName(pool, "Nope")) + } + + @Test + fun `assembleToolPool with default args includes real built-in tools`() { + val pool = ToolRegistry.assembleToolPool() + val names = pool.map { it.name } + assertEquals(names.sorted(), names, "pool must be sorted for prompt-cache stability") + assertTrue("Bash" in names && "FileRead" in names) + } + + // Ensure Dummy's schema compiles — dummyTool uses JsonObject default + @Suppress("unused") + private val schemaSanity: JsonObject = JsonObject(emptyMap()) +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/file/FileEditToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/file/FileEditToolTest.kt new file mode 100644 index 0000000..0fef897 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/file/FileEditToolTest.kt @@ -0,0 +1,274 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolUpdate +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.runToolUse +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.io.File +import java.nio.file.Path + +class FileEditToolTest { + + @TempDir + lateinit var tmp: Path + + private fun ctx( + mode: ToolPermissionContext.Mode = ToolPermissionContext.Mode.DEFAULT, + onPermission: suspend (ToolUpdate.PermissionAsked) -> Boolean = { true }, + ): ToolExecutionContext { + val project = mockk() + every { project.basePath } returns tmp.toString() + return ToolExecutionContext( + project = project, + toolUseId = "test-id", + abortJob = Job(), + permissionContext = ToolPermissionContext.default().copy(mode = mode), + onPermissionAsked = onPermission, + ) + } + + private fun writeFile(name: String, content: String): File = + tmp.resolve(name).toFile().also { it.writeText(content, Charsets.UTF_8) } + + // ─── validate ──────────────────────────────────────────────────── + + @Test + fun `validate rejects blank path`() = runBlocking { + val result = FileEditTool.validateInput(FileEditInput("", "old", "new"), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate rejects empty oldString`() = runBlocking { + val result = FileEditTool.validateInput(FileEditInput("file.txt", "", "new"), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate rejects identical oldString and newString`() = runBlocking { + val result = FileEditTool.validateInput(FileEditInput("file.txt", "same", "same"), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate passes for valid input`() = runBlocking { + val result = FileEditTool.validateInput(FileEditInput("file.txt", "old", "new"), ctx()) + assertEquals(ValidationResult.Ok, result) + } + + // ─── checkPermissions ──────────────────────────────────────────── + + @Test + fun `checkPermissions denies when file not found`() = runBlocking { + val result = FileEditTool.checkPermissions( + FileEditInput("nonexistent.txt", "old", "new"), ctx() + ) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `checkPermissions denies when path escapes workspace`() = runBlocking { + val result = FileEditTool.checkPermissions( + FileEditInput("../../etc/passwd", "old", "new"), ctx() + ) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `checkPermissions denies when oldString not found in file`() = runBlocking { + writeFile("target.txt", "hello world") + val result = FileEditTool.checkPermissions( + FileEditInput("target.txt", "MISSING", "new"), ctx() + ) + val denied = result as PermissionResult.Deny + assertTrue(denied.reason.contains("not found")) + } + + @Test + fun `checkPermissions returns Ask in DEFAULT mode`() = runBlocking { + writeFile("target.txt", "foo bar baz") + val result = FileEditTool.checkPermissions( + FileEditInput("target.txt", "bar", "qux"), ctx() + ) + assertTrue(result is PermissionResult.Ask) + val ask = result as PermissionResult.Ask + assertFalse(ask.preview?.details.isNullOrBlank()) + } + + @Test + fun `checkPermissions returns Allow in ACCEPT_EDITS mode`() = runBlocking { + writeFile("target.txt", "foo bar baz") + val result = FileEditTool.checkPermissions( + FileEditInput("target.txt", "bar", "qux"), + ctx(mode = ToolPermissionContext.Mode.ACCEPT_EDITS), + ) + assertTrue(result is PermissionResult.Allow) + } + + @Test + fun `checkPermissions returns Allow in BYPASS mode`() = runBlocking { + writeFile("target.txt", "foo bar baz") + val result = FileEditTool.checkPermissions( + FileEditInput("target.txt", "bar", "qux"), + ctx(mode = ToolPermissionContext.Mode.BYPASS), + ) + assertTrue(result is PermissionResult.Allow) + } + + // ─── preview ───────────────────────────────────────────────────── + + @Test + fun `preview returns null when oldString absent`() = runBlocking { + writeFile("target.txt", "no match here") + val result = FileEditTool.preview(FileEditInput("target.txt", "MISSING", "x"), ctx()) + assertNull(result) + } + + @Test + fun `preview returns diff details when oldString found`() = runBlocking { + writeFile("target.txt", "line1\nfoo\nline3") + val result = FileEditTool.preview(FileEditInput("target.txt", "foo", "bar"), ctx()) + assertTrue(result != null) + assertTrue(result!!.details!!.contains("-foo")) + assertTrue(result.details!!.contains("+bar")) + } + + // ─── call ──────────────────────────────────────────────────────── + + @Test + fun `call replaces first occurrence only by default`() = runBlocking { + val file = writeFile("edit.txt", "aaa bbb aaa") + val result = FileEditTool.call( + FileEditInput("edit.txt", "aaa", "ZZZ"), + ctx(mode = ToolPermissionContext.Mode.BYPASS), + ) + assertEquals("ZZZ bbb aaa", file.readText()) + assertEquals(1, result.data.replacements) + } + + @Test + fun `call replaces all occurrences when replaceAll is true`() = runBlocking { + val file = writeFile("edit.txt", "aaa bbb aaa") + val result = FileEditTool.call( + FileEditInput("edit.txt", "aaa", "ZZZ", replaceAll = true), + ctx(mode = ToolPermissionContext.Mode.BYPASS), + ) + assertEquals("ZZZ bbb ZZZ", file.readText()) + assertEquals(2, result.data.replacements) + } + + @Test + fun `call produces a unified diff in output`() = runBlocking { + writeFile("edit.txt", "line1\nold\nline3") + val result = FileEditTool.call( + FileEditInput("edit.txt", "old", "new"), + ctx(mode = ToolPermissionContext.Mode.BYPASS), + ) + assertTrue(result.data.diff.contains("-old")) + assertTrue(result.data.diff.contains("+new")) + } + + // ─── metadata ──────────────────────────────────────────────────── + + @Test + fun `FileEditTool metadata flags are correct`() = runBlocking { + val input = FileEditInput("f.txt", "a", "b") + assertFalse(FileEditTool.isConcurrencySafe(input)) + assertFalse(FileEditTool.isReadOnly(input)) + assertTrue(FileEditTool.isDestructive(input)) + } + + // ─── permission denied via callback ────────────────────────────── + + @Test + fun `runToolUse emits Failed(PERMISSION) when callback denies`() = runBlocking { + writeFile("edit.txt", "hello world") + val toolUse = com.github.codeplangui.tools.ToolUseBlock( + toolUseId = "u1", + name = "FileEdit", + input = buildJsonObject { + put("path", "edit.txt") + put("oldString", "hello") + put("newString", "bye") + }, + ) + val updates = runToolUse(FileEditTool, toolUse, ctx(onPermission = { false })).toList() + val failed = updates.filterIsInstance().firstOrNull() + assertTrue(failed != null, "Expected a Failed update") + assertEquals(ToolUpdate.Failed.Stage.PERMISSION, failed!!.stage) + } + + @Test + fun `runToolUse emits PermissionAsked before invoking callback`() = runBlocking { + writeFile("edit.txt", "hello world") + val toolUse = com.github.codeplangui.tools.ToolUseBlock( + toolUseId = "u2", + name = "FileEdit", + input = buildJsonObject { + put("path", "edit.txt") + put("oldString", "hello") + put("newString", "bye") + }, + ) + val asked = runToolUse(FileEditTool, toolUse, ctx()) + .filterIsInstance() + .first() + assertEquals("FileEdit", asked.toolName) + } + + // ─── buildUnifiedDiff unit tests ────────────────────────────────── + + @Test + fun `buildUnifiedDiff shows changed lines with context`() { + val original = "line1\nline2\nline3\nline4\nline5" + val modified = "line1\nline2\nchanged3\nline4\nline5" + val diff = buildUnifiedDiff(original, modified, "test.txt") + assertTrue(diff.contains("-line3")) + assertTrue(diff.contains("+changed3")) + assertTrue(diff.contains(" line2")) + assertTrue(diff.contains(" line4")) + } + + @Test + fun `buildUnifiedDiff returns no-change message for identical content`() { + val text = "same\ncontent" + val diff = buildUnifiedDiff(text, text, "test.txt") + assertEquals("(no line changes)", diff) + } + + @Test + fun `buildUnifiedDiff handles insertion`() { + val original = "a\nb" + val modified = "a\ninserted\nb" + val diff = buildUnifiedDiff(original, modified, "test.txt") + assertTrue(diff.contains("+inserted")) + assertFalse(diff.contains("-inserted")) + } + + @Test + fun `buildUnifiedDiff handles deletion`() { + val original = "a\ndelete_me\nb" + val modified = "a\nb" + val diff = buildUnifiedDiff(original, modified, "test.txt") + assertTrue(diff.contains("-delete_me")) + assertFalse(diff.contains("+delete_me")) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/file/FileListToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/file/FileListToolTest.kt new file mode 100644 index 0000000..e6f6335 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/file/FileListToolTest.kt @@ -0,0 +1,140 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolUpdate +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.runToolUse +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import com.github.codeplangui.tools.ToolUseBlock +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.nio.file.Path + +private fun toolUseBlock(id: String, input: JsonElement) = + ToolUseBlock(toolUseId = id, name = FileListTool.name, input = input) + +class FileListToolTest { + + @TempDir + lateinit var tmp: Path + + private fun ctx(): ToolExecutionContext { + val project = mockk() + every { project.basePath } returns tmp.toString() + return ToolExecutionContext( + project = project, + toolUseId = "test-id", + abortJob = Job(), + permissionContext = ToolPermissionContext.default(), + ) + } + + private fun mkFile(rel: String): java.io.File = + tmp.resolve(rel).toFile().also { it.parentFile?.mkdirs(); it.createNewFile() } + + // ─── validate ──────────────────────────────────────────────────── + + @Test + fun `validate rejects blank path`() = runBlocking { + val result = FileListTool.validateInput(FileListInput(""), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate accepts dot path`() = runBlocking { + val result = FileListTool.validateInput(FileListInput("."), ctx()) + assertTrue(result is ValidationResult.Ok) + } + + // ─── permissions ───────────────────────────────────────────────── + + @Test + fun `permissions deny when project path unavailable`() = runBlocking { + val project = mockk() + every { project.basePath } returns null + val ctx = ToolExecutionContext(project = project, toolUseId = "t", abortJob = Job()) + val result = FileListTool.checkPermissions(FileListInput("."), ctx) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions deny for path outside workspace`() = runBlocking { + val result = FileListTool.checkPermissions(FileListInput("../../etc"), ctx()) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions allow for workspace path`() = runBlocking { + val result = FileListTool.checkPermissions(FileListInput("."), ctx()) + assertTrue(result is PermissionResult.Allow) + } + + // ─── call ──────────────────────────────────────────────────────── + + @Test + fun `lists files in project root`() = runBlocking { + mkFile("alpha.txt") + mkFile("beta.txt") + val input = buildJsonObject { put("path", ".") } + val update = runToolUse(FileListTool, toolUseBlock("l1", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("alpha.txt")) + assertTrue(update.block.content.contains("beta.txt")) + } + + @Test + fun `lists recursively when recursive=true`() = runBlocking { + mkFile("sub/deep.kt") + val input = buildJsonObject { put("path", "."); put("recursive", true) } + val update = runToolUse(FileListTool, toolUseBlock("l2", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("deep.kt")) + } + + @Test + fun `excludes hidden files by default`() = runBlocking { + mkFile(".hidden") + mkFile("visible.txt") + val input = buildJsonObject { put("path", ".") } + val update = runToolUse(FileListTool, toolUseBlock("l3", input), ctx()) + .filterIsInstance().first() + assertTrue(!update.block.content.contains(".hidden")) + assertTrue(update.block.content.contains("visible.txt")) + } + + @Test + fun `includes hidden files when includeHidden=true`() = runBlocking { + mkFile(".hidden") + val input = buildJsonObject { put("path", "."); put("includeHidden", true) } + val update = runToolUse(FileListTool, toolUseBlock("l4", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains(".hidden")) + } + + @Test + fun `fails for non-existent path`() = runBlocking { + val input = buildJsonObject { put("path", "no_such_dir") } + val updates = runToolUse(FileListTool, toolUseBlock("l5", input), ctx()).toList() + assertTrue(updates.any { it is ToolUpdate.Failed }) + } + + @Test + fun `fails for path outside workspace`() = runBlocking { + val input = buildJsonObject { put("path", "../../") } + val updates = runToolUse(FileListTool, toolUseBlock("l6", input), ctx()).toList() + assertTrue(updates.any { it is ToolUpdate.Failed }) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/file/FileReadToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/file/FileReadToolTest.kt new file mode 100644 index 0000000..e4e9297 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/file/FileReadToolTest.kt @@ -0,0 +1,168 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.Progress +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ValidationResult +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.nio.file.Path +import kotlin.io.path.writeText + +class FileReadToolTest { + + private fun contextFor( + @TempDir base: Path, + additionalDirs: Set = emptySet(), + ): ToolExecutionContext { + val permissionCtx = ToolPermissionContext.default().copy(additionalWorkingDirectories = additionalDirs) + return ToolExecutionContext( + project = mockk().also { every { it.basePath } returns base.toString() }, + toolUseId = "t1", + abortJob = Job(), + permissionContext = permissionCtx, + ) + } + + private fun input(path: String, offset: Int? = null, limit: Int? = null) = buildJsonObject { + put("path", JsonPrimitive(path)) + offset?.let { put("offset", JsonPrimitive(it)) } + limit?.let { put("limit", JsonPrimitive(it)) } + } + + @Test + fun `happy path returns line-numbered content with correct totals`(@TempDir base: Path) = runBlocking { + val file = base.resolve("hello.txt") + file.writeText("alpha\nbeta\ngamma\n") + val ctx = contextFor(base) + + val parsed = FileReadTool.parseInput(input("hello.txt")) + val result = FileReadTool.call(parsed, ctx) { _: Progress -> } + val out = result.data + + assertEquals(3, out.totalLines) + assertEquals(3, out.returnedLines) + assertFalse(out.truncated) + // Line format: right-padded width, arrow, content + assertTrue(out.content.contains("1→alpha")) + assertTrue(out.content.contains("3→gamma")) + } + + @Test + fun `offset and limit page through the file correctly`(@TempDir base: Path) = runBlocking { + val file = base.resolve("big.txt") + file.writeText((1..20).joinToString("\n") { "line$it" }) + val ctx = contextFor(base) + + val parsed = FileReadTool.parseInput(input("big.txt", offset = 5, limit = 3)) + val out = FileReadTool.call(parsed, ctx) { _: Progress -> }.data + + assertEquals(3, out.returnedLines) + assertEquals(20, out.totalLines) + assertTrue(out.truncated) // endIdx = 8 < 20 + assertTrue(out.content.contains("5→line5")) + assertTrue(out.content.contains("7→line7")) + assertFalse(out.content.contains("line4")) + assertFalse(out.content.contains("line8")) + } + + @Test + fun `validation rejects blank path and out-of-range offset-limit`(@TempDir base: Path) = runBlocking { + val ctx = contextFor(base) + + val blank = FileReadTool.validateInput(FileReadInput(path = ""), ctx) + assertTrue(blank is ValidationResult.Failed) + + val badOffset = FileReadTool.validateInput(FileReadInput(path = "x", offset = 0), ctx) + assertTrue(badOffset is ValidationResult.Failed) + + val badLimit = FileReadTool.validateInput(FileReadInput(path = "x", limit = FILE_READ_MAX_LINES + 1), ctx) + assertTrue(badLimit is ValidationResult.Failed) + + val good = FileReadTool.validateInput(FileReadInput(path = "x", offset = 1, limit = 500), ctx) + assertEquals(ValidationResult.Ok, good) + } + + @Test + fun `permission denies path outside workspace`(@TempDir base: Path) = runBlocking { + val ctx = contextFor(base) + val outsideInput = FileReadInput(path = "../escape.txt") + val perm = FileReadTool.checkPermissions(outsideInput, ctx) + assertTrue(perm is PermissionResult.Deny, "expected Deny, got $perm") + } + + @Test + fun `permission allows path inside workspace`(@TempDir base: Path) = runBlocking { + val ctx = contextFor(base) + val perm = FileReadTool.checkPermissions(FileReadInput(path = "sub/file.txt"), ctx) + assertTrue(perm is PermissionResult.Allow, "expected Allow, got $perm") + } + + @Test + fun `permission allows path inside additional working directory`(@TempDir base: Path, @TempDir extra: Path) = runBlocking { + val ctx = contextFor(base, additionalDirs = setOf(extra.toString())) + val externalInput = FileReadInput(path = extra.resolve("x.txt").toString()) + val perm = FileReadTool.checkPermissions(externalInput, ctx) + assertTrue(perm is PermissionResult.Allow, "expected Allow for additional dir, got $perm") + } + + @Test + fun `file larger than byte cap returns truncated output`(@TempDir base: Path) = runBlocking { + val file = base.resolve("huge.txt") + // Write > 2MiB of text. Each line ~50 bytes → 50k lines ≈ 2.5MiB + val sb = StringBuilder() + repeat(50_000) { sb.append("padding-line-").append(it).append('\n') } + file.writeText(sb.toString()) + val ctx = contextFor(base) + + val parsed = FileReadTool.parseInput(input("huge.txt")) + val out = FileReadTool.call(parsed, ctx) { _: Progress -> }.data + + assertTrue(out.truncated, "huge file must report truncated") + assertTrue(out.returnedLines <= FILE_READ_DEFAULT_LIMIT) + } + + @Test + fun `metadata predicates mark tool as read-only and concurrency-safe`() { + val input = FileReadInput(path = "x") + assertTrue(FileReadTool.isReadOnly(input)) + assertTrue(FileReadTool.isConcurrencySafe(input)) + assertFalse(FileReadTool.isDestructive(input)) + } + + @Test + fun `preview returns null — read operations have no side effects`(@TempDir base: Path) = runBlocking { + val ctx = contextFor(base) + val preview = FileReadTool.preview(FileReadInput(path = "x"), ctx) + assertEquals(null, preview, "FileRead should not expose a preview") + } + + @Test + fun `mapResult includes path and total_lines header`(@TempDir base: Path) = runBlocking { + val file = base.resolve("z.txt") + file.writeText("one\ntwo\n") + val ctx = contextFor(base) + + val parsed = FileReadTool.parseInput(input("z.txt")) + val result = FileReadTool.call(parsed, ctx) { _: Progress -> } + val block = FileReadTool.mapResultToApiBlock(result.data, "tid") + + assertEquals("tid", block.toolUseId) + assertFalse(block.isError) + assertTrue(block.content.contains("path: ")) + assertTrue(block.content.contains("total_lines: 2")) + assertNotNull(block.content.contains("1→one")) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/file/FileSearchToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/file/FileSearchToolTest.kt new file mode 100644 index 0000000..a77d0d9 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/file/FileSearchToolTest.kt @@ -0,0 +1,164 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolUpdate +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.runToolUse +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import com.github.codeplangui.tools.ToolUseBlock +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.nio.file.Path + +private fun toolUseBlock(id: String, input: JsonElement) = + ToolUseBlock(toolUseId = id, name = FileSearchTool.name, input = input) + +class FileSearchToolTest { + + @TempDir + lateinit var tmp: Path + + private fun ctx(): ToolExecutionContext { + val project = mockk() + every { project.basePath } returns tmp.toString() + return ToolExecutionContext( + project = project, + toolUseId = "test-id", + abortJob = Job(), + permissionContext = ToolPermissionContext.default(), + ) + } + + private fun mkFile(rel: String, content: String): java.io.File = + tmp.resolve(rel).toFile().also { it.parentFile?.mkdirs(); it.writeText(content) } + + // ─── validate ──────────────────────────────────────────────────── + + @Test + fun `validate rejects blank pattern`() = runBlocking { + val result = FileSearchTool.validateInput(FileSearchInput(""), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate rejects invalid regex`() = runBlocking { + val result = FileSearchTool.validateInput(FileSearchInput("[invalid"), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate rejects maxResults out of range`() = runBlocking { + val result = FileSearchTool.validateInput(FileSearchInput("ok", maxResults = 0), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate accepts valid pattern`() = runBlocking { + val result = FileSearchTool.validateInput(FileSearchInput("hello"), ctx()) + assertTrue(result is ValidationResult.Ok) + } + + // ─── permissions ───────────────────────────────────────────────── + + @Test + fun `permissions deny when project path unavailable`() = runBlocking { + val project = mockk() + every { project.basePath } returns null + val ctx = ToolExecutionContext(project = project, toolUseId = "t", abortJob = Job()) + val result = FileSearchTool.checkPermissions(FileSearchInput("x"), ctx) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions deny for path outside workspace`() = runBlocking { + val result = FileSearchTool.checkPermissions(FileSearchInput("x", path = "../../etc"), ctx()) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions allow for workspace path`() = runBlocking { + val result = FileSearchTool.checkPermissions(FileSearchInput("x"), ctx()) + assertTrue(result is PermissionResult.Allow) + } + + // ─── call ──────────────────────────────────────────────────────── + + @Test + fun `finds matching lines`() = runBlocking { + mkFile("a.txt", "hello world\nfoo bar") + val input = buildJsonObject { put("pattern", "hello") } + val update = runToolUse(FileSearchTool, toolUseBlock("s1", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("hello world")) + } + + @Test + fun `returns no matches message when nothing found`() = runBlocking { + mkFile("a.txt", "nothing here") + val input = buildJsonObject { put("pattern", "xyzzy") } + val update = runToolUse(FileSearchTool, toolUseBlock("s2", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("no matches")) + } + + @Test + fun `glob restricts to matching file types`() = runBlocking { + mkFile("code.kt", "fun main() {}") + mkFile("readme.md", "fun main() {}") + val input = buildJsonObject { put("pattern", "fun main"); put("glob", "*.kt") } + val update = runToolUse(FileSearchTool, toolUseBlock("s3", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("code.kt")) + assertFalse(update.block.content.contains("readme.md")) + } + + @Test + fun `ignoreCase matches regardless of case`() = runBlocking { + mkFile("b.txt", "Hello World") + val input = buildJsonObject { put("pattern", "hello"); put("ignoreCase", true) } + val update = runToolUse(FileSearchTool, toolUseBlock("s4", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("Hello World")) + } + + @Test + fun `case sensitive search does not match wrong case`() = runBlocking { + mkFile("c.txt", "Hello World") + val input = buildJsonObject { put("pattern", "hello"); put("ignoreCase", false) } + val update = runToolUse(FileSearchTool, toolUseBlock("s5", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("no matches")) + } + + @Test + fun `maxResults caps the result count`() = runBlocking { + mkFile("big.txt", (1..20).joinToString("\n") { "line $it match" }) + val input = buildJsonObject { put("pattern", "match"); put("maxResults", 5) } + val update = runToolUse(FileSearchTool, toolUseBlock("s6", input), ctx()) + .filterIsInstance().first() + val lineCount = update.block.content.lines().count { it.contains("match") } + assertTrue(lineCount <= 5) + assertTrue(update.block.content.contains("truncated")) + } + + @Test + fun `fails for path outside workspace`() = runBlocking { + val input = buildJsonObject { put("pattern", "x"); put("path", "../../") } + val updates = runToolUse(FileSearchTool, toolUseBlock("s7", input), ctx()).toList() + assertTrue(updates.any { it is ToolUpdate.Failed }) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/file/WriteFileToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/file/WriteFileToolTest.kt new file mode 100644 index 0000000..718d60d --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/file/WriteFileToolTest.kt @@ -0,0 +1,145 @@ +package com.github.codeplangui.tools.file + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ToolUpdate +import com.github.codeplangui.tools.ValidationResult +import com.github.codeplangui.tools.runToolUse +import com.intellij.openapi.project.Project +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import com.github.codeplangui.tools.ToolUseBlock +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.io.File +import java.nio.file.Path + +private fun toolUseBlock(id: String, input: JsonElement) = + ToolUseBlock(toolUseId = id, name = WriteFileTool.name, input = input) + +class WriteFileToolTest { + + @TempDir + lateinit var tmp: Path + + private fun ctx( + mode: ToolPermissionContext.Mode = ToolPermissionContext.Mode.ACCEPT_EDITS, + onPermission: suspend (ToolUpdate.PermissionAsked) -> Boolean = { true }, + ): ToolExecutionContext { + val project = mockk() + every { project.basePath } returns tmp.toString() + return ToolExecutionContext( + project = project, + toolUseId = "test-id", + abortJob = Job(), + permissionContext = ToolPermissionContext.default().copy(mode = mode), + onPermissionAsked = onPermission, + ) + } + + // ─── validate ──────────────────────────────────────────────────── + + @Test + fun `validate rejects blank path`() = runBlocking { + val result = WriteFileTool.validateInput(WriteFileInput("", "content"), ctx()) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate accepts valid input`() = runBlocking { + val result = WriteFileTool.validateInput(WriteFileInput("file.txt", "hello"), ctx()) + assertEquals(ValidationResult.Ok, result) + } + + // ─── permissions ───────────────────────────────────────────────── + + @Test + fun `permissions deny when project path is unavailable`() = runBlocking { + val project = mockk() + every { project.basePath } returns null + val ctx = ToolExecutionContext(project = project, toolUseId = "t", abortJob = Job()) + val result = WriteFileTool.checkPermissions(WriteFileInput("a.txt", "x"), ctx) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions deny for path outside workspace`() = runBlocking { + val result = WriteFileTool.checkPermissions(WriteFileInput("../../etc/passwd", "bad"), ctx()) + assertTrue(result is PermissionResult.Deny) + } + + @Test + fun `permissions allow in ACCEPT_EDITS mode`() = runBlocking { + val result = WriteFileTool.checkPermissions(WriteFileInput("new.txt", "hi"), ctx(ToolPermissionContext.Mode.ACCEPT_EDITS)) + assertTrue(result is PermissionResult.Allow) + } + + @Test + fun `permissions allow in BYPASS mode`() = runBlocking { + val result = WriteFileTool.checkPermissions(WriteFileInput("new.txt", "hi"), ctx(ToolPermissionContext.Mode.BYPASS)) + assertTrue(result is PermissionResult.Allow) + } + + @Test + fun `permissions ask in DEFAULT mode`() = runBlocking { + val result = WriteFileTool.checkPermissions(WriteFileInput("new.txt", "hi"), ctx(ToolPermissionContext.Mode.DEFAULT)) + assertTrue(result is PermissionResult.Ask) + } + + // ─── call ──────────────────────────────────────────────────────── + + @Test + fun `creates new file with correct content`() = runBlocking { + val input = buildJsonObject { put("path", "hello.txt"); put("content", "world") } + val update = runToolUse(WriteFileTool, toolUseBlock("w1", input), ctx()) + .filterIsInstance().first() + val file = tmp.resolve("hello.txt").toFile() + assertTrue(file.exists()) + assertEquals("world", file.readText()) + assertTrue(update.block.content.contains("Created")) + } + + @Test + fun `overwrites existing file`() = runBlocking { + val file = tmp.resolve("existing.txt").toFile().also { it.writeText("old") } + val input = buildJsonObject { put("path", "existing.txt"); put("content", "new") } + val update = runToolUse(WriteFileTool, toolUseBlock("w2", input), ctx()) + .filterIsInstance().first() + assertEquals("new", file.readText()) + assertTrue(update.block.content.contains("Wrote")) + } + + @Test + fun `creates parent directories`() = runBlocking { + val input = buildJsonObject { put("path", "a/b/c/deep.txt"); put("content", "deep") } + runToolUse(WriteFileTool, toolUseBlock("w3", input), ctx()) + .filterIsInstance().first() + assertTrue(tmp.resolve("a/b/c/deep.txt").toFile().exists()) + } + + @Test + fun `reports bytes written`() = runBlocking { + val input = buildJsonObject { put("path", "bytes.txt"); put("content", "abc") } + val update = runToolUse(WriteFileTool, toolUseBlock("w4", input), ctx()) + .filterIsInstance().first() + assertTrue(update.block.content.contains("3 bytes")) + } + + @Test + fun `fails for path outside workspace`() = runBlocking { + val input = buildJsonObject { put("path", "../../evil.txt"); put("content", "x") } + val updates = runToolUse(WriteFileTool, toolUseBlock("w5", input), ctx()).toList() + assertTrue(updates.any { it is ToolUpdate.Failed }) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/mcp/McpClientTest.kt b/src/test/kotlin/com/github/codeplangui/tools/mcp/McpClientTest.kt new file mode 100644 index 0000000..8085411 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/mcp/McpClientTest.kt @@ -0,0 +1,207 @@ +package com.github.codeplangui.tools.mcp + +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.io.PipedInputStream +import java.io.PipedOutputStream + +/** + * Unit tests for [McpClient] using piped streams as a fake MCP server. + * + * Architecture: each test has two sides: + * - client side: the [McpClient] under test (writes to serverIn, reads from serverOut) + * - server side: the test writes responses to serverOut and reads requests from serverIn + */ +class McpClientTest { + + private val json = Json { ignoreUnknownKeys = true; encodeDefaults = false } + + // Pipes: client reads from clientIn (server writes there), client writes to clientOut (server reads there) + private val serverToClientOut = PipedOutputStream() + private val clientIn = PipedInputStream(serverToClientOut) + private val clientToServerOut = PipedOutputStream() + private val serverIn = PipedInputStream(clientToServerOut) + + private lateinit var scope: CoroutineScope + private lateinit var client: McpClient + private val serverWriter = serverToClientOut.bufferedWriter(Charsets.UTF_8) + private val serverReader = serverIn.bufferedReader(Charsets.UTF_8) + + @BeforeEach + fun setup() { + scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + client = McpClient.fromStreams("test-server", clientIn, clientToServerOut, scope) + } + + @AfterEach + fun tearDown() { + client.close() + scope.cancel() + } + + // ─── Helpers ───────────────────────────────────────────────────────────── + + private fun respond(id: Int, result: kotlinx.serialization.json.JsonElement) { + val resp = buildJsonObject { + put("jsonrpc", "2.0") + put("id", id) + put("result", result) + } + serverWriter.write(json.encodeToString(kotlinx.serialization.json.JsonObject.serializer(), resp)) + serverWriter.newLine() + serverWriter.flush() + } + + private fun respondError(id: Int, code: Int, message: String) { + val resp = buildJsonObject { + put("jsonrpc", "2.0") + put("id", id) + put("error", buildJsonObject { + put("code", code) + put("message", message) + }) + } + serverWriter.write(json.encodeToString(kotlinx.serialization.json.JsonObject.serializer(), resp)) + serverWriter.newLine() + serverWriter.flush() + } + + private fun readRequest(): kotlinx.serialization.json.JsonObject { + val line = serverReader.readLine() ?: error("Server stream closed unexpectedly") + return json.decodeFromString(kotlinx.serialization.json.JsonObject.serializer(), line) + } + + /** Drain initialize + notifications/initialized, then respond to them. */ + private fun handshake() { + // 1. initialize request + val init = readRequest() + assertEquals("initialize", init["method"]?.let { (it as JsonPrimitive).content }) + val initId = (init["id"] as JsonPrimitive).content.toInt() + respond(initId, buildJsonObject { + put("protocolVersion", MCP_PROTOCOL_VERSION) + put("capabilities", buildJsonObject {}) + put("serverInfo", buildJsonObject { put("name", "test-server") }) + }) + + // 2. notifications/initialized (no response needed — it's a notification) + val notif = readRequest() + assertEquals("notifications/initialized", notif["method"]?.let { (it as JsonPrimitive).content }) + } + + // ─── Tests ─────────────────────────────────────────────────────────────── + + @Test + fun `connect performs handshake and returns tool list`() = runBlocking { + // Drive the server in background + Thread { + handshake() + // tools/list request + val listReq = readRequest() + assertEquals("tools/list", (listReq["method"] as JsonPrimitive).content) + val listId = (listReq["id"] as JsonPrimitive).content.toInt() + respond(listId, buildJsonObject { + put("tools", buildJsonArray { + add(buildJsonObject { + put("name", "echo") + put("description", "Echoes input") + put("inputSchema", buildJsonObject { put("type", "object") }) + }) + }) + }) + }.also { it.isDaemon = true }.start() + + val tools = client.connect() + + assertEquals(1, tools.size) + assertEquals("echo", tools[0].name) + assertEquals("Echoes input", tools[0].description) + } + + @Test + fun `call sends tools-call request and returns parsed result`() = runBlocking { + Thread { + handshake() + // tools/list + val listReq = readRequest() + respond((listReq["id"] as JsonPrimitive).content.toInt(), buildJsonObject { + put("tools", buildJsonArray {}) + }) + // tools/call + val callReq = readRequest() + assertEquals("tools/call", (callReq["method"] as JsonPrimitive).content) + val callId = (callReq["id"] as JsonPrimitive).content.toInt() + respond(callId, buildJsonObject { + put("content", buildJsonArray { + add(buildJsonObject { put("type", "text"); put("text", "hello") }) + }) + put("isError", false) + }) + }.also { it.isDaemon = true }.start() + + client.connect() + val result = client.call("echo", buildJsonObject { put("msg", "hello") }) + + assertFalse(result.isError) + assertEquals("hello", result.textContent()) + } + + @Test + fun `call propagates JSON-RPC error as McpException`() = runBlocking { + Thread { + handshake() + val listReq = readRequest() + respond((listReq["id"] as JsonPrimitive).content.toInt(), buildJsonObject { put("tools", buildJsonArray {}) }) + val callReq = readRequest() + val callId = (callReq["id"] as JsonPrimitive).content.toInt() + respondError(callId, -32600, "Invalid request") + }.also { it.isDaemon = true }.start() + + client.connect() + val ex = runCatching { client.call("bad", JsonNull) }.exceptionOrNull() + assertTrue(ex is McpException, "Expected McpException, got $ex") + assertTrue(ex!!.message!!.contains("Invalid request")) + } + + @Test + fun `malformed lines in server output are silently ignored`() = runBlocking { + Thread { + handshake() + val listReq = readRequest() + val listId = (listReq["id"] as JsonPrimitive).content.toInt() + // Inject garbage before the real response + serverWriter.write("NOT JSON AT ALL\n") + serverWriter.flush() + respond(listId, buildJsonObject { put("tools", buildJsonArray {}) }) + }.also { it.isDaemon = true }.start() + + val tools = client.connect() + assertEquals(0, tools.size) // garbage was ignored, real response arrived + } + + @Test + fun `textContent joins multiple text items`() { + val result = McpCallResult( + content = listOf( + McpContentItem("text", "line1"), + McpContentItem("image", null), // non-text, skipped + McpContentItem("text", "line2"), + ) + ) + assertEquals("line1\nline2", result.textContent()) + } +} diff --git a/src/test/kotlin/com/github/codeplangui/tools/mcp/McpProxyToolTest.kt b/src/test/kotlin/com/github/codeplangui/tools/mcp/McpProxyToolTest.kt new file mode 100644 index 0000000..3cc4c99 --- /dev/null +++ b/src/test/kotlin/com/github/codeplangui/tools/mcp/McpProxyToolTest.kt @@ -0,0 +1,149 @@ +package com.github.codeplangui.tools.mcp + +import com.github.codeplangui.tools.PermissionResult +import com.github.codeplangui.tools.ToolExecutionContext +import com.github.codeplangui.tools.ToolPermissionContext +import com.github.codeplangui.tools.ValidationResult +import com.intellij.openapi.project.Project +import io.mockk.coEvery +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.Job +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class McpProxyToolTest { + + private lateinit var mockClient: McpClient + private lateinit var ctx: ToolExecutionContext + + @BeforeEach + fun setup() { + mockClient = mockk() + ctx = ToolExecutionContext( + project = mockk().also { every { it.basePath } returns "/tmp" }, + toolUseId = "test-id", + abortJob = Job(), + permissionContext = ToolPermissionContext.default(), + ) + } + + private fun spec(name: String = "echo", description: String = "Echoes text") = McpToolSpec( + name = name, + description = description, + inputSchema = buildJsonObject { put("type", "object") }, + ) + + // ─── Name and schema ───────────────────────────────────────────────────── + + @Test + fun `tool name follows mcp__server__tool convention`() { + val tool = mcpProxyTool("my_server", spec("do_thing"), mockClient) + assertEquals("mcp__my_server__do_thing", tool.name) + } + + @Test + fun `tool description inherits from spec`() { + val tool = mcpProxyTool("srv", spec(description = "Does stuff"), mockClient) + assertTrue(tool.description.contains("Does stuff")) + } + + @Test + fun `tool uses blank fallback description when spec description is empty`() { + val tool = mcpProxyTool("srv", spec(description = ""), mockClient) + assertTrue(tool.description.isNotBlank()) + } + + @Test + fun `inputSchema is forwarded from spec`() { + val schema = buildJsonObject { put("type", "object"); put("required", buildJsonObject {}) } + val tool = mcpProxyTool("srv", McpToolSpec("t", "d", schema as JsonObject), mockClient) + assertEquals(schema, tool.inputSchema) + } + + // ─── Metadata ──────────────────────────────────────────────────────────── + + @Test + fun `metadata flags are conservative`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val input = JsonNull + assertFalse(tool.isConcurrencySafe(input)) + assertFalse(tool.isReadOnly(input)) + assertTrue(tool.isDestructive(input)) + } + + // ─── validate ──────────────────────────────────────────────────────────── + + @Test + fun `validate rejects non-object input`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val result = tool.validateInput(JsonPrimitive("not an object"), ctx) + assertTrue(result is ValidationResult.Failed) + } + + @Test + fun `validate accepts object input`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val result = tool.validateInput(buildJsonObject { put("text", "hi") }, ctx) + assertEquals(ValidationResult.Ok, result) + } + + // ─── checkPermissions ──────────────────────────────────────────────────── + + @Test + fun `checkPermissions always returns Ask`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val perm = tool.checkPermissions(JsonNull, ctx) + assertTrue(perm is PermissionResult.Ask) + } + + @Test + fun `Ask preview contains tool name and server name`() = runBlocking { + val tool = mcpProxyTool("my_srv", spec("my_tool"), mockClient) + val ask = tool.checkPermissions(buildJsonObject {}, ctx) as PermissionResult.Ask + assertTrue(ask.preview?.summary?.contains("my_srv") == true) + assertTrue(ask.reason.contains("mcp__my_srv__my_tool")) + } + + // ─── call ──────────────────────────────────────────────────────────────── + + @Test + fun `call delegates to client and returns result`() = runBlocking { + val tool = mcpProxyTool("srv", spec("echo"), mockClient) + val args = buildJsonObject { put("text", "hello") } + val mcpResult = McpCallResult(content = listOf(McpContentItem("text", "hello")), isError = false) + + coEvery { mockClient.call("echo", args) } returns mcpResult + + val result = tool.call(args, ctx) + assertEquals(mcpResult, result.data) + } + + // ─── mapResult ─────────────────────────────────────────────────────────── + + @Test + fun `mapResult sets isError from call result`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val errResult = McpCallResult(content = listOf(McpContentItem("text", "boom")), isError = true) + val block = tool.mapResultToApiBlock(errResult, "u1") + assertTrue(block.isError) + assertTrue(block.content.contains("boom")) + } + + @Test + fun `mapResult includes error flag in content for error results`() = runBlocking { + val tool = mcpProxyTool("srv", spec(), mockClient) + val errResult = McpCallResult(content = listOf(McpContentItem("text", "fail")), isError = true) + val block = tool.mapResultToApiBlock(errResult, "u1") + assertTrue(block.content.contains("error: true")) + } +}