diff --git a/code-rs/core/src/codex.rs b/code-rs/core/src/codex.rs index ac4358bca5b..6af6590ae3b 100644 --- a/code-rs/core/src/codex.rs +++ b/code-rs/core/src/codex.rs @@ -935,6 +935,7 @@ mod tests { content: "manual body".to_string(), policy: Some(SkillPolicy { allow_implicit_invocation: Some(false), + command_policies: Vec::new(), }), }]; let input = vec![InputItem::Text { diff --git a/code-rs/core/src/codex/session.rs b/code-rs/core/src/codex/session.rs index ccef1b03c34..5ae12691310 100644 --- a/code-rs/core/src/codex/session.rs +++ b/code-rs/core/src/codex/session.rs @@ -469,6 +469,7 @@ pub(crate) struct Session { pub(super) tools_config: ToolsConfig, pub(super) dynamic_tools: Vec, pub(super) skills: Vec, + pub(super) skill_command_policies: crate::skills::command_policy::SkillCommandPolicyRuntime, /// Manager for external MCP servers/tools. pub(super) mcp_connection_manager: McpConnectionManager, diff --git a/code-rs/core/src/codex/streaming.rs b/code-rs/core/src/codex/streaming.rs index 6a6885204aa..cc6e7ad4493 100644 --- a/code-rs/core/src/codex/streaming.rs +++ b/code-rs/core/src/codex/streaming.rs @@ -802,16 +802,22 @@ pub(super) async fn submission_loop( remote.refresh_remote_models().await; }); } + let session_skills = skills_outcome + .as_ref() + .map(|outcome| outcome.skills.clone()) + .unwrap_or_default(); + let skill_command_policies = + crate::skills::command_policy::SkillCommandPolicyRuntime::from_skills( + &session_skills, + ); let mut new_session = Arc::new(Session { id: session_id, client, remote_models_manager, tools_config, dynamic_tools, - skills: skills_outcome - .as_ref() - .map(|outcome| outcome.skills.clone()) - .unwrap_or_default(), + skills: session_skills, + skill_command_policies, tx_event: tx_event.clone(), user_instructions: effective_user_instructions.clone(), base_instructions, @@ -9818,6 +9824,29 @@ async fn handle_list_agents( ).await } +async fn command_guard_output( + sess: &Session, + sub_id: &str, + call_id: String, + attempt_req: u64, + output_index: Option, + guidance: String, +) -> ResponseInputItem { + let order = sess.next_background_order(sub_id, attempt_req, output_index); + sess + .notify_background_event_with_order( + sub_id, + order, + format!("Command guard: {}", guidance.clone()), + ) + .await; + + ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload::from_text(guidance), + } +} + async fn handle_container_exec_with_params( params: ExecParams, sess: &Session, @@ -10092,6 +10121,22 @@ async fn handle_container_exec_with_params( .iter() .any(|p| trimmed.starts_with(p)); + if let Some(policy_match) = sess + .skill_command_policies + .check(¶ms.command, has_confirm_prefix) + { + let guidance = policy_match.guidance("original_script", &script); + return command_guard_output( + sess, + &sub_id, + call_id, + attempt_req, + output_index, + guidance, + ) + .await; + } + // If no confirm prefix and it looks like a sensitive git command, reject with guidance. if !has_confirm_prefix { if let Some(pattern) = if sess.confirm_guard.is_empty() { @@ -10256,6 +10301,21 @@ async fn handle_container_exec_with_params( // If no shell script is present, perform a lightweight argv inspection for sensitive git commands. if extract_shell_script(¶ms.command).is_none() { let joined = params.command.join(" "); + if let Some(policy_match) = sess.skill_command_policies.check(¶ms.command, false) { + let guidance = policy_match.guidance( + "original_argv", + &format!("{:?}", params.command), + ); + return command_guard_output( + sess, + &sub_id, + call_id, + attempt_req, + output_index, + guidance, + ) + .await; + } if !sess.confirm_guard.is_empty() { if let Some(pattern) = sess.confirm_guard.matched_pattern(&joined) { let suggested = serde_json::to_string(&vec![ diff --git a/code-rs/core/src/skills/command_policy.rs b/code-rs/core/src/skills/command_policy.rs new file mode 100644 index 00000000000..72851ff33e2 --- /dev/null +++ b/code-rs/core/src/skills/command_policy.rs @@ -0,0 +1,454 @@ +use crate::skills::model::SkillCommandMatcher; +use crate::skills::model::SkillCommandPolicy; +use crate::skills::model::SkillCommandPolicyAction; +use crate::skills::model::SkillCommandPolicyPreferred; +use crate::skills::model::SkillCommandPolicyPreferredKind; +use crate::skills::model::SkillMetadata; +use regex_lite::Regex; +use std::path::PathBuf; + +#[derive(Debug, Clone)] +pub(crate) struct SkillCommandPolicyRuntime { + policies: Vec, +} + +#[derive(Debug, Clone)] +pub(crate) struct CompiledSkillCommandPolicy { + pub(crate) skill_name: String, + pub(crate) skill_path: PathBuf, + pub(crate) id: String, + pub(crate) matcher: SkillCommandMatcher, + pub(crate) shell_regex: Option, + pub(crate) action: SkillCommandPolicyAction, + pub(crate) message: Option, + pub(crate) preferred: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MatchSpecificity { + Exact, + Prefix, + Regex, +} + +#[derive(Debug, Clone)] +pub(crate) struct SkillCommandPolicyMatch<'a> { + pub(crate) policy: &'a CompiledSkillCommandPolicy, + matched_command: String, + specificity: MatchSpecificity, +} + +impl SkillCommandPolicyRuntime { + pub(crate) fn from_skills(skills: &[SkillMetadata]) -> Self { + let mut policies = Vec::new(); + for skill in skills { + let Some(policy) = skill.policy.as_ref() else { + continue; + }; + for command_policy in &policy.command_policies { + if let Some(compiled) = compile_policy(skill, command_policy) { + policies.push(compiled); + } + } + } + Self { policies } + } + + pub(crate) fn check<'a>( + &'a self, + command: &[String], + has_confirm_prefix: bool, + ) -> Option> { + if self.policies.is_empty() { + return None; + } + let normalized = NormalizedCommand::from_command(command); + let mut best: Option> = None; + for policy in &self.policies { + if has_confirm_prefix && policy.action == SkillCommandPolicyAction::RequireConfirm { + continue; + } + let Some(candidate) = policy_matches(policy, &normalized) else { + continue; + }; + if best + .as_ref() + .is_none_or(|current| candidate.is_more_specific_than(current)) + { + best = Some(candidate); + } + } + best + } +} + +impl<'a> SkillCommandPolicyMatch<'a> { + fn is_more_specific_than(&self, other: &Self) -> bool { + specificity_rank(self.specificity) < specificity_rank(other.specificity) + } + + pub(crate) fn guidance(&self, original_label: &str, original_value: &str) -> String { + let policy = self.policy; + let message = policy.message.as_deref().unwrap_or(match policy.action { + SkillCommandPolicyAction::RequirePreferred => { + "A loaded skill owns this workflow; use one of the preferred actions instead." + } + SkillCommandPolicyAction::RequireConfirm => { + "A loaded skill requires explicit confirmation before this raw command runs." + } + SkillCommandPolicyAction::Reject => { + "A loaded skill rejects this raw command shape." + } + }); + let mut lines = vec![ + format!( + "Command policy matched skill `{}` ({}) at {}.", + policy.skill_name, + policy.id, + policy.skill_path.display() + ), + String::new(), + message.to_string(), + String::new(), + ]; + + if !policy.preferred.is_empty() { + lines.push("Preferred actions:".to_string()); + for preferred in &policy.preferred { + lines.push(format_preferred(preferred)); + } + lines.push(String::new()); + } + + if policy.action == SkillCommandPolicyAction::RequireConfirm { + lines.push("Resend with `confirm:` only if the user explicitly asked for this raw command.".to_string()); + lines.push(String::new()); + } + + lines.push(format!("matched_command: {}", self.matched_command)); + lines.push(format!("{original_label}: {original_value}")); + lines.join("\n") + } +} + +fn compile_policy( + skill: &SkillMetadata, + policy: &SkillCommandPolicy, +) -> Option { + let shell_regex = policy + .matcher + .shell_regex + .as_ref() + .and_then(|pattern| Regex::new(pattern).ok()); + Some(CompiledSkillCommandPolicy { + skill_name: skill.name.clone(), + skill_path: skill.path.clone(), + id: policy.id.clone(), + matcher: policy.matcher.clone(), + shell_regex, + action: policy.action, + message: policy.message.clone(), + preferred: policy.preferred.clone(), + }) +} + +fn policy_matches<'a>( + policy: &'a CompiledSkillCommandPolicy, + normalized: &NormalizedCommand, +) -> Option> { + if let Some(argv_exact) = policy.matcher.argv_exact.as_ref() { + for argv in normalized.argv_candidates() { + if argv == argv_exact.as_slice() { + return Some(SkillCommandPolicyMatch { + policy, + matched_command: argv.join(" "), + specificity: MatchSpecificity::Exact, + }); + } + } + } + + if let Some(argv_prefix) = policy.matcher.argv_prefix.as_ref() { + for argv in normalized.argv_candidates() { + if argv.starts_with(argv_prefix) { + return Some(SkillCommandPolicyMatch { + policy, + matched_command: argv.join(" "), + specificity: MatchSpecificity::Prefix, + }); + } + } + } + + if let Some(regex) = policy.shell_regex.as_ref() { + for text in normalized.text_candidates() { + if regex.is_match(&text) { + return Some(SkillCommandPolicyMatch { + policy, + matched_command: text.clone(), + specificity: MatchSpecificity::Regex, + }); + } + } + } + + None +} + +fn specificity_rank(specificity: MatchSpecificity) -> u8 { + match specificity { + MatchSpecificity::Exact => 0, + MatchSpecificity::Prefix => 1, + MatchSpecificity::Regex => 2, + } +} + +#[derive(Debug)] +struct NormalizedCommand { + original_argv: Vec, + shell_script: Option, + shell_argv: Option>, +} + +impl NormalizedCommand { + fn from_command(command: &[String]) -> Self { + let shell_script = extract_shell_script(command).map(|(_, script)| script.to_string()); + let shell_argv = shell_script + .as_ref() + .and_then(|script| shlex::split(script)); + Self { + original_argv: command.to_vec(), + shell_script, + shell_argv, + } + } + + fn argv_candidates(&self) -> Vec<&[String]> { + let mut candidates = Vec::new(); + if let Some(shell_argv) = self.shell_argv.as_ref() { + candidates.push(shell_argv.as_slice()); + } + candidates.push(self.original_argv.as_slice()); + candidates + } + + fn text_candidates(&self) -> Vec { + let mut candidates = Vec::new(); + if let Some(shell_script) = self.shell_script.as_ref() { + candidates.push(shell_script.clone()); + } + candidates.push(self.original_argv.join(" ")); + candidates + } +} + +fn extract_shell_script(command: &[String]) -> Option<(usize, &str)> { + if command.len() < 3 { + return None; + } + let shell = command.first()?.rsplit('/').next()?; + if !matches!(shell, "bash" | "sh" | "zsh" | "fish") { + return None; + } + for (index, arg) in command.iter().enumerate().skip(1) { + if matches!(arg.as_str(), "-c" | "-lc" | "-ic") { + return command.get(index + 1).map(|script| (index + 1, script.as_str())); + } + if !arg.starts_with('-') { + break; + } + } + None +} + +fn format_preferred(preferred: &SkillCommandPolicyPreferred) -> String { + let kind = match preferred.kind { + SkillCommandPolicyPreferredKind::Script => "script", + SkillCommandPolicyPreferredKind::Skill => "skill", + SkillCommandPolicyPreferredKind::Command => "command", + }; + let mut parts = vec![format!("- {kind}")]; + if let Some(name) = preferred.name.as_deref() { + parts.push(format!("name: {name}")); + } + if let Some(path) = preferred.path.as_deref() { + parts.push(format!("path: {}", path.display())); + } + if !preferred.example_argv.is_empty() { + parts.push(format!("example: {}", preferred.example_argv.join(" "))); + } + if let Some(purpose) = preferred.purpose.as_deref() { + parts.push(format!("purpose: {purpose}")); + } + parts.join("; ") +} + +pub(crate) fn render_command_policy_summary(skill: &SkillMetadata) -> Vec { + let mut lines = Vec::new(); + let Some(policy) = skill.policy.as_ref() else { + return lines; + }; + for command_policy in &policy.command_policies { + let matcher = describe_matcher(&command_policy.matcher); + let preferred = command_policy + .preferred + .first() + .map(format_preferred) + .unwrap_or_else(|| "preferred action declared in skill body".to_string()); + lines.push(format!( + " - `{matcher}`: use {preferred} instead." + )); + } + lines +} + +fn describe_matcher(matcher: &SkillCommandMatcher) -> String { + if let Some(argv_exact) = matcher.argv_exact.as_ref() { + return argv_exact.join(" "); + } + if let Some(argv_prefix) = matcher.argv_prefix.as_ref() { + return format!("{} ...", argv_prefix.join(" ")); + } + if let Some(shell_regex) = matcher.shell_regex.as_ref() { + return format!("/{shell_regex}/"); + } + "".to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::skills::model::SkillCommandMatcher; + use crate::skills::model::SkillCommandPolicy; + use crate::skills::model::SkillCommandPolicyAction; + use crate::skills::model::SkillCommandPolicyPreferred; + use crate::skills::model::SkillCommandPolicyPreferredKind; + use crate::skills::model::SkillPolicy; + use crate::skills::model::SkillScope; + + fn runtime(policy: SkillCommandPolicy) -> SkillCommandPolicyRuntime { + SkillCommandPolicyRuntime::from_skills(&[SkillMetadata { + name: "github".to_string(), + description: "GitHub workflows".to_string(), + short_description: None, + path: PathBuf::from("/tmp/github/SKILL.md"), + scope: SkillScope::User, + content: String::new(), + policy: Some(SkillPolicy { + allow_implicit_invocation: Some(false), + command_policies: vec![policy], + }), + }]) + } + + fn policy(matcher: SkillCommandMatcher) -> SkillCommandPolicy { + SkillCommandPolicy { + id: "prefer-helper".to_string(), + matcher, + action: SkillCommandPolicyAction::RequirePreferred, + message: Some("use the helper".to_string()), + preferred: vec![SkillCommandPolicyPreferred { + kind: SkillCommandPolicyPreferredKind::Script, + path: Some(PathBuf::from("/tmp/github/scripts/gh-pr.py")), + name: None, + example_argv: vec!["scripts/gh-pr.py".to_string(), "merge".to_string()], + purpose: Some("merge safely".to_string()), + }], + } + } + + #[test] + fn argv_prefix_matches_direct_command() { + let runtime = runtime(policy(SkillCommandMatcher { + argv_prefix: Some(vec!["gh".to_string(), "pr".to_string(), "merge".to_string()]), + ..Default::default() + })); + + let matched = runtime + .check( + &["gh".to_string(), "pr".to_string(), "merge".to_string(), "234".to_string()], + false, + ) + .expect("match"); + + assert_eq!(matched.policy.id, "prefer-helper"); + } + + #[test] + fn argv_prefix_matches_shell_script() { + let runtime = runtime(policy(SkillCommandMatcher { + argv_prefix: Some(vec!["gh".to_string(), "pr".to_string(), "merge".to_string()]), + ..Default::default() + })); + + let matched = runtime + .check( + &[ + "bash".to_string(), + "-lc".to_string(), + "gh pr merge 234 --delete-branch".to_string(), + ], + false, + ) + .expect("match"); + + assert_eq!(matched.matched_command, "gh pr merge 234 --delete-branch"); + } + + #[test] + fn shell_regex_matches_script_text() { + let runtime = runtime(policy(SkillCommandMatcher { + shell_regex: Some("^gh\\s+pr\\s+merge\\b".to_string()), + ..Default::default() + })); + + assert!( + runtime + .check( + &[ + "bash".to_string(), + "-lc".to_string(), + "gh pr merge 234".to_string(), + ], + false, + ) + .is_some() + ); + } + + #[test] + fn require_confirm_is_bypassed_by_confirm_prefix() { + let mut command_policy = policy(SkillCommandMatcher { + argv_prefix: Some(vec!["gh".to_string(), "pr".to_string(), "merge".to_string()]), + ..Default::default() + }); + command_policy.action = SkillCommandPolicyAction::RequireConfirm; + let runtime = runtime(command_policy); + + assert!( + runtime + .check( + &["gh".to_string(), "pr".to_string(), "merge".to_string()], + true, + ) + .is_none() + ); + } + + #[test] + fn guidance_includes_preferred_action() { + let runtime = runtime(policy(SkillCommandMatcher { + argv_prefix: Some(vec!["gh".to_string(), "pr".to_string(), "merge".to_string()]), + ..Default::default() + })); + + let guidance = runtime + .check(&["gh".to_string(), "pr".to_string(), "merge".to_string()], false) + .expect("match") + .guidance("original_argv", "[\"gh\", \"pr\", \"merge\"]"); + + assert!(guidance.contains("Command policy matched skill `github`")); + assert!(guidance.contains("scripts/gh-pr.py")); + assert!(guidance.contains("merge safely")); + } +} diff --git a/code-rs/core/src/skills/loader.rs b/code-rs/core/src/skills/loader.rs index b8ed00a4c5c..0f537e0351b 100644 --- a/code-rs/core/src/skills/loader.rs +++ b/code-rs/core/src/skills/loader.rs @@ -3,6 +3,11 @@ use crate::config::SkillConfigRuleSelector; use crate::config::resolve_code_path_for_read; use crate::git_info::resolve_root_git_project_for_trust; use crate::skills::model::SkillError; +use crate::skills::model::SkillCommandMatcher; +use crate::skills::model::SkillCommandPolicy; +use crate::skills::model::SkillCommandPolicyAction; +use crate::skills::model::SkillCommandPolicyPreferred; +use crate::skills::model::SkillCommandPolicyPreferredKind; use crate::skills::model::SkillLoadOutcome; use crate::skills::model::SkillMetadata; use crate::skills::model::SkillPolicy; @@ -40,6 +45,59 @@ struct SkillFrontmatterMetadata { struct SkillFrontmatterPolicy { #[serde(default)] allow_implicit_invocation: Option, + #[serde(default)] + command_policies: Vec, +} + +#[derive(Debug, Deserialize)] +struct SkillFrontmatterCommandPolicy { + id: String, + #[serde(rename = "match")] + matcher: SkillFrontmatterCommandMatcher, + action: SkillFrontmatterCommandPolicyAction, + #[serde(default)] + message: Option, + #[serde(default)] + preferred: Vec, +} + +#[derive(Debug, Deserialize)] +struct SkillFrontmatterCommandMatcher { + #[serde(default)] + argv_exact: Option>, + #[serde(default)] + argv_prefix: Option>, + #[serde(default)] + shell_regex: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +enum SkillFrontmatterCommandPolicyAction { + RequirePreferred, + RequireConfirm, + Reject, +} + +#[derive(Debug, Deserialize)] +struct SkillFrontmatterCommandPolicyPreferred { + kind: SkillFrontmatterCommandPolicyPreferredKind, + #[serde(default)] + path: Option, + #[serde(default)] + name: Option, + #[serde(default)] + example_argv: Vec, + #[serde(default)] + purpose: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "snake_case")] +enum SkillFrontmatterCommandPolicyPreferredKind { + Script, + Skill, + Command, } const SKILLS_FILENAME: &str = "SKILL.md"; @@ -50,6 +108,14 @@ const ADMIN_SKILLS_ROOT: &str = "/etc/codex/skills"; const MAX_NAME_LEN: usize = 64; const MAX_DESCRIPTION_LEN: usize = 1024; const MAX_SHORT_DESCRIPTION_LEN: usize = 160; +const MAX_COMMAND_POLICY_ID_LEN: usize = 96; +const MAX_COMMAND_POLICY_MESSAGE_LEN: usize = 512; +const MAX_COMMAND_POLICY_TOKEN_LEN: usize = 160; +const MAX_COMMAND_POLICY_REGEX_LEN: usize = 512; +const MAX_COMMAND_POLICY_PURPOSE_LEN: usize = 256; +const MAX_COMMAND_POLICIES_PER_SKILL: usize = 64; +const MAX_COMMAND_POLICY_PREFERRED: usize = 8; +const MAX_COMMAND_POLICY_ARGV_TOKENS: usize = 32; #[derive(Debug)] enum SkillParseError { @@ -368,6 +434,11 @@ fn parse_skill_file(path: &Path, scope: SkillScope) -> Result Result Result { + if policy.command_policies.len() > MAX_COMMAND_POLICIES_PER_SKILL { + return Err(SkillParseError::InvalidField { + field: "policy.command_policies", + reason: format!( + "must contain at most {MAX_COMMAND_POLICIES_PER_SKILL} entries" + ), + }); + } + + let mut command_policies = Vec::with_capacity(policy.command_policies.len()); + for (index, command_policy) in policy.command_policies.into_iter().enumerate() { + command_policies.push(parse_command_policy(command_policy, skill_dir, index)?); + } + + Ok(SkillPolicy { + allow_implicit_invocation: policy.allow_implicit_invocation, + command_policies, + }) +} + +fn parse_command_policy( + policy: SkillFrontmatterCommandPolicy, + skill_dir: &Path, + index: usize, +) -> Result { + let id = sanitize_single_line(&policy.id); + validate_field(&id, MAX_COMMAND_POLICY_ID_LEN, "policy.command_policies.id")?; + let matcher = parse_command_matcher(policy.matcher, index)?; + let message = policy + .message + .map(|message| sanitize_single_line(&message)) + .filter(|message| !message.is_empty()) + .map(|message| { + validate_field( + &message, + MAX_COMMAND_POLICY_MESSAGE_LEN, + "policy.command_policies.message", + )?; + Ok::(message) + }) + .transpose()?; + + if policy.preferred.len() > MAX_COMMAND_POLICY_PREFERRED { + return Err(SkillParseError::InvalidField { + field: "policy.command_policies.preferred", + reason: format!("must contain at most {MAX_COMMAND_POLICY_PREFERRED} entries"), + }); + } + + let mut preferred = Vec::with_capacity(policy.preferred.len()); + for preferred_entry in policy.preferred { + preferred.push(parse_command_policy_preferred(preferred_entry, skill_dir)?); + } + + Ok(SkillCommandPolicy { + id, + matcher, + action: match policy.action { + SkillFrontmatterCommandPolicyAction::RequirePreferred => { + SkillCommandPolicyAction::RequirePreferred + } + SkillFrontmatterCommandPolicyAction::RequireConfirm => { + SkillCommandPolicyAction::RequireConfirm + } + SkillFrontmatterCommandPolicyAction::Reject => SkillCommandPolicyAction::Reject, + }, + message, + preferred, + }) +} + +fn parse_command_matcher( + matcher: SkillFrontmatterCommandMatcher, + index: usize, +) -> Result { + let argv_exact = parse_match_argv( + matcher.argv_exact, + "policy.command_policies.match.argv_exact", + )?; + let argv_prefix = parse_match_argv( + matcher.argv_prefix, + "policy.command_policies.match.argv_prefix", + )?; + let shell_regex = matcher + .shell_regex + .map(|regex| sanitize_single_line(®ex)) + .filter(|regex| !regex.is_empty()) + .map(|regex| { + validate_field( + ®ex, + MAX_COMMAND_POLICY_REGEX_LEN, + "policy.command_policies.match.shell_regex", + )?; + regex_lite::Regex::new(®ex).map_err(|err| SkillParseError::InvalidField { + field: "policy.command_policies.match.shell_regex", + reason: format!("invalid regex in entry {index}: {err}"), + })?; + Ok::(regex) + }) + .transpose()?; + + let matcher_count = usize::from(argv_exact.is_some()) + + usize::from(argv_prefix.is_some()) + + usize::from(shell_regex.is_some()); + if matcher_count == 0 { + return Err(SkillParseError::InvalidField { + field: "policy.command_policies.match", + reason: format!( + "entry {index} must set one of argv_exact, argv_prefix, or shell_regex" + ), + }); + } + if matcher_count > 1 { + return Err(SkillParseError::InvalidField { + field: "policy.command_policies.match", + reason: format!( + "entry {index} must set only one of argv_exact, argv_prefix, or shell_regex" + ), + }); + } + + Ok(SkillCommandMatcher { + argv_exact, + argv_prefix, + shell_regex, + }) +} + +fn parse_match_argv( + argv: Option>, + field: &'static str, +) -> Result>, SkillParseError> { + let Some(argv) = argv else { + return Ok(None); + }; + validate_argv_tokens(argv, field).map(Some) +} + +fn validate_argv_tokens( + argv: Vec, + field: &'static str, +) -> Result, SkillParseError> { + if argv.is_empty() || argv.len() > MAX_COMMAND_POLICY_ARGV_TOKENS { + return Err(SkillParseError::InvalidField { + field, + reason: format!( + "must contain between 1 and {MAX_COMMAND_POLICY_ARGV_TOKENS} tokens" + ), + }); + } + + let mut tokens = Vec::with_capacity(argv.len()); + for token in argv { + let token = sanitize_single_line(&token); + validate_field(&token, MAX_COMMAND_POLICY_TOKEN_LEN, field)?; + tokens.push(token); + } + Ok(tokens) +} + +fn parse_command_policy_preferred( + preferred: SkillFrontmatterCommandPolicyPreferred, + skill_dir: &Path, +) -> Result { + let purpose = preferred + .purpose + .map(|purpose| sanitize_single_line(&purpose)) + .filter(|purpose| !purpose.is_empty()) + .map(|purpose| { + validate_field( + &purpose, + MAX_COMMAND_POLICY_PURPOSE_LEN, + "policy.command_policies.preferred.purpose", + )?; + Ok::(purpose) + }) + .transpose()?; + let name = preferred + .name + .map(|name| sanitize_single_line(&name)) + .filter(|name| !name.is_empty()) + .map(|name| { + validate_field( + &name, + MAX_NAME_LEN, + "policy.command_policies.preferred.name", + )?; + Ok::(name) + }) + .transpose()?; + let path = preferred + .path + .map(|path| resolve_skill_relative_path(skill_dir, path)); + let example_argv = if preferred.example_argv.is_empty() { + Vec::new() + } else { + validate_argv_tokens( + preferred.example_argv, + "policy.command_policies.preferred.example_argv", + )? + }; + + Ok(SkillCommandPolicyPreferred { + kind: match preferred.kind { + SkillFrontmatterCommandPolicyPreferredKind::Script => { + SkillCommandPolicyPreferredKind::Script + } + SkillFrontmatterCommandPolicyPreferredKind::Skill => { + SkillCommandPolicyPreferredKind::Skill + } + SkillFrontmatterCommandPolicyPreferredKind::Command => { + SkillCommandPolicyPreferredKind::Command + } + }, + path, + name, + example_argv, + purpose, }) } +fn resolve_skill_relative_path(skill_dir: &Path, path: PathBuf) -> PathBuf { + let path = if path.is_absolute() { + path + } else { + skill_dir.join(path) + }; + normalize_path(&path).unwrap_or(path) +} + fn sanitize_single_line(raw: &str) -> String { raw.split_whitespace().collect::>().join(" ") } @@ -542,6 +845,20 @@ mod tests { skill_path } + fn write_command_policy_skill_at(skills_root: &Path) -> PathBuf { + let skill_dir = skills_root.join("github"); + fs::create_dir_all(skill_dir.join("scripts")).expect("create skill dir"); + let script_path = skill_dir.join("scripts").join("gh-pr.py"); + fs::write(&script_path, "#!/usr/bin/env python3\n").expect("write script"); + let skill_path = skill_dir.join(SKILLS_FILENAME); + fs::write( + &skill_path, + "---\nname: github\ndescription: GitHub workflows\npolicy:\n command_policies:\n - id: prefer-pr-merge\n match:\n argv_prefix: [\"gh\", \"pr\", \"merge\"]\n action: require_preferred\n message: Raw gh pr merge bypasses helper flow.\n preferred:\n - kind: script\n path: scripts/gh-pr.py\n example_argv: [\"scripts/gh-pr.py\", \"merge\", \"\"]\n purpose: Use helper script.\n - kind: skill\n name: github\n---\n\n# github\n", + ) + .expect("write skill file"); + skill_path + } + fn normalized(path: &Path) -> PathBuf { normalize_path(path).unwrap_or_else(|_| path.to_path_buf()) } @@ -573,11 +890,98 @@ mod tests { skill.policy, Some(SkillPolicy { allow_implicit_invocation: Some(false), + command_policies: Vec::new(), }) ); assert!(!skill.allow_implicit_invocation()); } + #[test] + fn loads_command_policy_from_frontmatter() { + let skills_root = tempfile::tempdir().expect("tempdir"); + let skill_path = write_command_policy_skill_at(skills_root.path()); + + let outcome = load_skills_from_roots(vec![SkillRoot { + path: skills_root.path().to_path_buf(), + scope: SkillScope::User, + }]); + + assert!( + outcome.errors.is_empty(), + "unexpected errors: {:?}", + outcome.errors + ); + let skill = outcome.skills.first().expect("skill"); + let policy = skill.policy.as_ref().expect("policy"); + assert!(skill.allow_implicit_invocation()); + assert_eq!(policy.command_policies.len(), 1); + let command_policy = &policy.command_policies[0]; + assert_eq!(command_policy.id, "prefer-pr-merge"); + assert_eq!( + command_policy.matcher.argv_prefix.as_deref(), + Some(&["gh".to_string(), "pr".to_string(), "merge".to_string()][..]) + ); + assert_eq!( + command_policy.action, + SkillCommandPolicyAction::RequirePreferred + ); + assert_eq!(command_policy.preferred.len(), 2); + assert_eq!( + command_policy.preferred[0].path.as_deref(), + Some(normalized(&skill_path).parent().unwrap().join("scripts/gh-pr.py").as_path()) + ); + } + + #[test] + fn invalid_command_policy_reports_skill_error() { + let skills_root = tempfile::tempdir().expect("tempdir"); + let skill_dir = skills_root.path().join("bad"); + fs::create_dir_all(&skill_dir).expect("create skill dir"); + fs::write( + skill_dir.join(SKILLS_FILENAME), + "---\nname: bad\ndescription: Bad policy\npolicy:\n command_policies:\n - id: bad-regex\n match:\n shell_regex: \"[\"\n action: require_preferred\n---\n\n# bad\n", + ) + .expect("write skill file"); + + let outcome = load_skills_from_roots(vec![SkillRoot { + path: skills_root.path().to_path_buf(), + scope: SkillScope::User, + }]); + + assert!(outcome.skills.is_empty()); + assert_eq!(outcome.errors.len(), 1); + assert!( + outcome.errors[0].message.contains("invalid regex"), + "unexpected error: {}", + outcome.errors[0].message + ); + } + + #[test] + fn command_policy_rejects_multiple_matchers() { + let skills_root = tempfile::tempdir().expect("tempdir"); + let skill_dir = skills_root.path().join("bad"); + fs::create_dir_all(&skill_dir).expect("create skill dir"); + fs::write( + skill_dir.join(SKILLS_FILENAME), + "---\nname: bad\ndescription: Bad policy\npolicy:\n command_policies:\n - id: too-many\n match:\n argv_exact: [\"gh\", \"pr\", \"merge\"]\n argv_prefix: [\"gh\", \"pr\"]\n action: require_preferred\n---\n\n# bad\n", + ) + .expect("write skill file"); + + let outcome = load_skills_from_roots(vec![SkillRoot { + path: skills_root.path().to_path_buf(), + scope: SkillScope::User, + }]); + + assert!(outcome.skills.is_empty()); + assert_eq!(outcome.errors.len(), 1); + assert!( + outcome.errors[0].message.contains("must set only one"), + "unexpected error: {}", + outcome.errors[0].message + ); + } + #[test] fn config_name_selector_disables_implicit_invocation() { let skills_root = tempfile::tempdir().expect("tempdir"); diff --git a/code-rs/core/src/skills/mod.rs b/code-rs/core/src/skills/mod.rs index 712e5ba91e3..bba502eeb77 100644 --- a/code-rs/core/src/skills/mod.rs +++ b/code-rs/core/src/skills/mod.rs @@ -1,3 +1,4 @@ +pub mod command_policy; pub mod loader; pub mod model; pub mod render; diff --git a/code-rs/core/src/skills/model.rs b/code-rs/core/src/skills/model.rs index 169993d2b28..6a6d72bfa08 100644 --- a/code-rs/core/src/skills/model.rs +++ b/code-rs/core/src/skills/model.rs @@ -31,6 +31,46 @@ impl SkillMetadata { #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct SkillPolicy { pub allow_implicit_invocation: Option, + pub command_policies: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SkillCommandPolicy { + pub id: String, + pub matcher: SkillCommandMatcher, + pub action: SkillCommandPolicyAction, + pub message: Option, + pub preferred: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct SkillCommandMatcher { + pub argv_exact: Option>, + pub argv_prefix: Option>, + pub shell_regex: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SkillCommandPolicyAction { + RequirePreferred, + RequireConfirm, + Reject, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SkillCommandPolicyPreferred { + pub kind: SkillCommandPolicyPreferredKind, + pub path: Option, + pub name: Option, + pub example_argv: Vec, + pub purpose: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SkillCommandPolicyPreferredKind { + Script, + Skill, + Command, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/code-rs/core/src/skills/render.rs b/code-rs/core/src/skills/render.rs index 5d3ec2436de..d1a9b139b58 100644 --- a/code-rs/core/src/skills/render.rs +++ b/code-rs/core/src/skills/render.rs @@ -1,4 +1,5 @@ use crate::skills::model::SkillMetadata; +use crate::skills::command_policy::render_command_policy_summary; pub fn render_skills_section(skills: &[SkillMetadata]) -> Option { let implicit_skills: Vec<&SkillMetadata> = skills @@ -26,6 +27,7 @@ pub fn render_skills_section(skills: &[SkillMetadata]) -> Option { let name = skill.name.as_str(); let description = skill.description.as_str(); lines.push(format!("- {name}: {description} (file: {path_str})")); + lines.extend(render_command_policy_summary(skill)); } } @@ -66,6 +68,11 @@ pub fn render_skills_section(skills: &[SkillMetadata]) -> Option { #[cfg(test)] mod tests { use super::*; + use crate::skills::model::SkillCommandMatcher; + use crate::skills::model::SkillCommandPolicy; + use crate::skills::model::SkillCommandPolicyAction; + use crate::skills::model::SkillCommandPolicyPreferred; + use crate::skills::model::SkillCommandPolicyPreferredKind; use crate::skills::model::SkillPolicy; use crate::skills::model::SkillScope; use std::path::PathBuf; @@ -80,6 +87,7 @@ mod tests { content: String::new(), policy: allow_implicit_invocation.map(|allow_implicit_invocation| SkillPolicy { allow_implicit_invocation: Some(allow_implicit_invocation), + command_policies: Vec::new(), }), } } @@ -135,6 +143,39 @@ mod tests { assert!(!rendered.contains("compact UI summary")); } + #[test] + fn render_skills_section_includes_command_policy_guidance() { + let mut skill = skill("github", None); + skill.policy = Some(SkillPolicy { + allow_implicit_invocation: None, + command_policies: vec![SkillCommandPolicy { + id: "prefer-pr-merge".to_string(), + matcher: SkillCommandMatcher { + argv_prefix: Some(vec![ + "gh".to_string(), + "pr".to_string(), + "merge".to_string(), + ]), + ..Default::default() + }, + action: SkillCommandPolicyAction::RequirePreferred, + message: Some("use helper".to_string()), + preferred: vec![SkillCommandPolicyPreferred { + kind: SkillCommandPolicyPreferredKind::Script, + path: Some(PathBuf::from("/tmp/github/scripts/gh-pr.py")), + name: None, + example_argv: vec!["scripts/gh-pr.py".to_string(), "merge".to_string()], + purpose: Some("merge through helper".to_string()), + }], + }], + }); + + let rendered = render_skills_section(&[skill]).expect("skill should render"); + + assert!(rendered.contains("`gh pr merge ...`: use - script")); + assert!(rendered.contains("scripts/gh-pr.py")); + } + #[test] fn render_skills_section_resolves_relative_paths_from_skill_dir() { let rendered = render_skills_section(&[skill("helper", None)]) diff --git a/code-rs/core/tests/skill_command_policy.rs b/code-rs/core/tests/skill_command_policy.rs new file mode 100644 index 00000000000..381985a41b3 --- /dev/null +++ b/code-rs/core/tests/skill_command_policy.rs @@ -0,0 +1,203 @@ +#![allow(clippy::unwrap_used)] + +mod common; + +use common::{load_default_config_for_test, wait_for_event}; + +use code_core::model_family::find_family_for_model; +use code_core::protocol::{AskForApproval, EventMsg, InputItem, Op, SandboxPolicy}; +use code_core::{built_in_model_providers, CodexAuth, ConversationManager, ModelProviderInfo}; +use serde_json::Value; +use serde_json::json; +use tempfile::TempDir; +use wiremock::matchers::{method, path_regex}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +fn sse_response(body: String) -> ResponseTemplate { + ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_string(body) +} + +fn tool_call_sse() -> String { + let function_call_args = json!({ + "command": ["bash", "-lc", "gh pr merge 236"], + "workdir": null, + "timeout_ms": null, + "sandbox_permissions": null, + "justification": null, + }); + let function_call_item = json!({ + "type": "response.output_item.done", + "item": { + "type": "function_call", + "id": "call-policy", + "call_id": "call-policy", + "name": "shell", + "arguments": function_call_args.to_string(), + } + }); + let completed = completed_event("resp-policy-1"); + format!( + "event: response.output_item.done\ndata: {function_call_item}\n\n\ +event: response.completed\ndata: {completed}\n\n" + ) +} + +fn completed_message_sse() -> String { + let message_item = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "id": "msg-policy", + "role": "assistant", + "content": [{"type": "output_text", "text": "done"}], + } + }); + let completed = completed_event("resp-policy-2"); + format!( + "event: response.output_item.done\ndata: {message_item}\n\n\ +event: response.completed\ndata: {completed}\n\n" + ) +} + +fn completed_event(id: &str) -> Value { + json!({ + "type": "response.completed", + "response": { + "id": id, + "usage": { + "input_tokens": 0, + "input_tokens_details": null, + "output_tokens": 0, + "output_tokens_details": null, + "total_tokens": 0 + } + } + }) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn skill_command_policy_blocks_before_exec_begin() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(sse_response(tool_call_sse())) + .up_to_n_times(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .respond_with(sse_response(completed_message_sse())) + .up_to_n_times(1) + .mount(&server) + .await; + + let cwd = TempDir::new().unwrap(); + let skill_dir = cwd.path().join(".codex").join("skills").join("github"); + std::fs::create_dir_all(&skill_dir).unwrap(); + std::fs::write( + skill_dir.join("SKILL.md"), + r#"--- +name: github +description: GitHub workflow helper +policy: + command_policies: + - id: prefer-pr-merge-helper + match: + argv_prefix: ["gh", "pr", "merge"] + action: require_preferred + message: Raw gh pr merge bypasses the helper flow. + preferred: + - kind: script + path: scripts/gh-pr.py + example_argv: ["scripts/gh-pr.py", "merge", ""] + purpose: Merge through the helper. +--- +Use the GitHub helper. +"#, + ) + .unwrap(); + + let code_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&code_home); + config.cwd = cwd.path().to_path_buf(); + config.model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers(None)["openai"].clone() + }; + config.model = "gpt-5.1-codex".to_string(); + config.model_family = find_family_for_model(&config.model).unwrap(); + config.approval_policy = AskForApproval::Never; + config.sandbox_policy = SandboxPolicy::DangerFullAccess; + config.include_apply_patch_tool = false; + config.include_view_image_tool = false; + config.tools_web_search_request = false; + config.include_plan_tool = false; + + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); + let codex = conversation_manager + .new_conversation(config) + .await + .expect("create new conversation") + .conversation; + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "merge the pr".to_string(), + }], + final_output_json_schema: None, + }) + .await + .unwrap(); + + let mut saw_policy_background = false; + let mut saw_task_complete = false; + let mut saw_exec_begin = false; + for _ in 0..20 { + let event = wait_for_event(&codex, |_| true).await; + match event { + EventMsg::BackgroundEvent(ev) => { + if ev.message.contains("Command guard: Command policy matched skill `github`") { + saw_policy_background = true; + } + } + EventMsg::ExecCommandBegin(_) => saw_exec_begin = true, + EventMsg::TaskComplete(_) => { + saw_task_complete = true; + break; + } + _ => {} + } + } + + assert!(saw_policy_background, "policy guard background event was not emitted"); + assert!(saw_task_complete, "task did not complete"); + assert!(!saw_exec_begin, "matched policy should block before ExecCommandBegin"); + + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 2, "expected tool call and follow-up request"); + let follow_up_body: Value = requests[1].body_json().unwrap(); + let output = find_function_call_output_text(&follow_up_body) + .expect("follow-up request should contain policy tool output"); + assert!(output.contains("Command policy matched skill `github`")); + assert!(output.contains("Raw gh pr merge bypasses the helper flow.")); + assert!(output.contains("scripts/gh-pr.py")); +} + +fn find_function_call_output_text(value: &Value) -> Option<&str> { + match value { + Value::Object(map) => { + if map.get("type").and_then(Value::as_str) == Some("function_call_output") { + if let Some(text) = map.get("output").and_then(Value::as_str) { + return Some(text); + } + } + map.values().find_map(find_function_call_output_text) + } + Value::Array(items) => items.iter().find_map(find_function_call_output_text), + _ => None, + } +}