Skip to content

Commit f6d11bc

Browse files
committed
refactor(policy): align mcp deny rules
Signed-off-by: ddurst <267424412+ddurst-nvidia@users.noreply.github.com>
1 parent 6a15197 commit f6d11bc

6 files changed

Lines changed: 76 additions & 131 deletions

File tree

crates/openshell-policy/src/lib.rs

Lines changed: 53 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ struct NetworkEndpointDef {
108108
enforcement: String,
109109
#[serde(default, skip_serializing_if = "String::is_empty")]
110110
access: String,
111-
#[serde(default, skip_serializing_if = "RulesDef::is_empty")]
112-
rules: RulesDef,
111+
#[serde(default, skip_serializing_if = "Vec::is_empty")]
112+
rules: Vec<L7RuleDef>,
113113
#[serde(default, skip_serializing_if = "Vec::is_empty")]
114114
allowed_ips: Vec<String>,
115115
#[serde(default, skip_serializing_if = "Vec::is_empty")]
@@ -195,57 +195,6 @@ struct L7RuleDef {
195195
allow: L7AllowDef,
196196
}
197197

198-
// Preserve the original `rules: [{ allow: ... }]` shape while accepting the
199-
// newer grouped shape (`rules.allow` / `rules.deny`) used by MCP examples.
200-
#[derive(Debug, Serialize, Deserialize)]
201-
#[serde(untagged)]
202-
enum RulesDef {
203-
Legacy(Vec<L7RuleDef>),
204-
Grouped(L7RuleGroupsDef),
205-
}
206-
207-
impl Default for RulesDef {
208-
fn default() -> Self {
209-
Self::Legacy(Vec::new())
210-
}
211-
}
212-
213-
impl RulesDef {
214-
// Serde needs this for `skip_serializing_if`; keeping it on the enum keeps
215-
// call sites from having to know which rules shape was parsed.
216-
fn is_empty(&self) -> bool {
217-
match self {
218-
Self::Legacy(rules) => rules.is_empty(),
219-
Self::Grouped(groups) => groups.allow.is_empty() && groups.deny.is_empty(),
220-
}
221-
}
222-
223-
// The proto model still carries allow rules and deny rules separately. This
224-
// folds both YAML spellings back into that stable internal representation.
225-
fn into_parts(self) -> (Vec<L7RuleDef>, Vec<L7DenyRuleDef>) {
226-
match self {
227-
Self::Legacy(rules) => (rules, Vec::new()),
228-
Self::Grouped(groups) => (
229-
groups
230-
.allow
231-
.into_iter()
232-
.map(|allow| L7RuleDef { allow })
233-
.collect(),
234-
groups.deny,
235-
),
236-
}
237-
}
238-
}
239-
240-
#[derive(Debug, Serialize, Deserialize)]
241-
#[serde(deny_unknown_fields)]
242-
struct L7RuleGroupsDef {
243-
#[serde(default, skip_serializing_if = "Vec::is_empty")]
244-
allow: Vec<L7AllowDef>,
245-
#[serde(default, skip_serializing_if = "Vec::is_empty")]
246-
deny: Vec<L7DenyRuleDef>,
247-
}
248-
249198
#[derive(Debug, Serialize, Deserialize)]
250199
#[serde(deny_unknown_fields)]
251200
struct L7AllowDef {
@@ -459,15 +408,14 @@ fn params_with_tool(
459408
}
460409

461410
fn allow_def_to_proto(protocol: &str, allow: L7AllowDef) -> L7Allow {
462-
let (method, rpc_method) = if is_mcp_protocol(protocol) {
463-
let rpc_method = if allow.method.is_empty() {
411+
let method = if is_jsonrpc_family_protocol(protocol) {
412+
if allow.method.is_empty() {
464413
allow.rpc_method
465414
} else {
466415
allow.method
467-
};
468-
(String::new(), rpc_method)
416+
}
469417
} else {
470-
(allow.method, allow.rpc_method)
418+
allow.method
471419
};
472420

473421
L7Allow {
@@ -477,7 +425,6 @@ fn allow_def_to_proto(protocol: &str, allow: L7AllowDef) -> L7Allow {
477425
operation_type: allow.operation_type,
478426
operation_name: allow.operation_name,
479427
fields: allow.fields,
480-
rpc_method,
481428
query: allow
482429
.query
483430
.into_iter()
@@ -491,15 +438,14 @@ fn allow_def_to_proto(protocol: &str, allow: L7AllowDef) -> L7Allow {
491438
}
492439

493440
fn deny_def_to_proto(protocol: &str, deny: L7DenyRuleDef) -> L7DenyRule {
494-
let (method, rpc_method) = if is_mcp_protocol(protocol) {
495-
let rpc_method = if deny.method.is_empty() {
441+
let method = if is_jsonrpc_family_protocol(protocol) {
442+
if deny.method.is_empty() {
496443
deny.rpc_method
497444
} else {
498445
deny.method
499-
};
500-
(String::new(), rpc_method)
446+
}
501447
} else {
502-
(deny.method, deny.rpc_method)
448+
deny.method
503449
};
504450

505451
L7DenyRule {
@@ -509,7 +455,6 @@ fn deny_def_to_proto(protocol: &str, deny: L7DenyRuleDef) -> L7DenyRule {
509455
operation_type: deny.operation_type,
510456
operation_name: deny.operation_name,
511457
fields: deny.fields,
512-
rpc_method,
513458
query: deny
514459
.query
515460
.into_iter()
@@ -535,6 +480,14 @@ fn is_mcp_protocol(protocol: &str) -> bool {
535480
protocol.eq_ignore_ascii_case("mcp")
536481
}
537482

483+
fn is_jsonrpc_protocol(protocol: &str) -> bool {
484+
protocol.eq_ignore_ascii_case("json-rpc")
485+
}
486+
487+
fn is_jsonrpc_family_protocol(protocol: &str) -> bool {
488+
is_mcp_protocol(protocol) || is_jsonrpc_protocol(protocol)
489+
}
490+
538491
fn split_tool_param(
539492
protocol: &str,
540493
params: BTreeMap<String, QueryMatcherDef>,
@@ -566,9 +519,11 @@ fn allow_proto_to_def(protocol: &str, allow: L7Allow) -> L7AllowDef {
566519
let (tool, params) = split_tool_param(protocol, params);
567520
let params = flat_params_to_def(protocol, params);
568521
let (method, rpc_method) = if is_mcp_protocol(protocol) {
569-
(allow.rpc_method, String::new())
522+
(allow.method, String::new())
523+
} else if is_jsonrpc_protocol(protocol) {
524+
(String::new(), allow.method)
570525
} else {
571-
(allow.method, allow.rpc_method)
526+
(allow.method, String::new())
572527
};
573528
L7AllowDef {
574529
method,
@@ -597,9 +552,11 @@ fn deny_proto_to_def(protocol: &str, deny: &L7DenyRule) -> L7DenyRuleDef {
597552
let (tool, params) = split_tool_param(protocol, params);
598553
let params = flat_params_to_def(protocol, params);
599554
let (method, rpc_method) = if is_mcp_protocol(protocol) {
600-
(deny.rpc_method.clone(), String::new())
555+
(deny.method.clone(), String::new())
556+
} else if is_jsonrpc_protocol(protocol) {
557+
(String::new(), deny.method.clone())
601558
} else {
602-
(deny.method.clone(), deny.rpc_method.clone())
559+
(deny.method.clone(), String::new())
603560
};
604561
L7DenyRuleDef {
605562
method,
@@ -635,9 +592,8 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy {
635592
.into_iter()
636593
.map(|e| {
637594
let protocol = e.protocol;
638-
let (allow_rules, grouped_deny_rules) = e.rules.into_parts();
639-
let mut deny_rules = grouped_deny_rules;
640-
deny_rules.extend(e.deny_rules);
595+
let allow_rules = e.rules;
596+
let deny_rules = e.deny_rules;
641597
// Normalize port/ports: ports takes precedence, else
642598
// single port is promoted to ports array.
643599
let normalized_ports: Vec<u32> = if !e.ports.is_empty() {
@@ -770,39 +726,21 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile {
770726
(clamp(e.ports.first().copied().unwrap_or(e.port)), vec![])
771727
};
772728
let protocol = e.protocol.clone();
773-
let allow_defs: Vec<L7AllowDef> = e
729+
let rules = e
774730
.rules
775731
.iter()
776-
.map(|r| {
777-
allow_proto_to_def(&protocol, r.allow.clone().unwrap_or_default())
732+
.map(|r| L7RuleDef {
733+
allow: allow_proto_to_def(
734+
&protocol,
735+
r.allow.clone().unwrap_or_default(),
736+
),
778737
})
779738
.collect();
780-
let deny_defs: Vec<L7DenyRuleDef> = e
739+
let deny_rules: Vec<L7DenyRuleDef> = e
781740
.deny_rules
782741
.iter()
783742
.map(|d| deny_proto_to_def(&protocol, d))
784743
.collect();
785-
let (rules, deny_rules) = if is_mcp_protocol(&protocol)
786-
&& (!allow_defs.is_empty() || !deny_defs.is_empty())
787-
{
788-
(
789-
RulesDef::Grouped(L7RuleGroupsDef {
790-
allow: allow_defs,
791-
deny: deny_defs,
792-
}),
793-
Vec::new(),
794-
)
795-
} else {
796-
(
797-
RulesDef::Legacy(
798-
allow_defs
799-
.into_iter()
800-
.map(|allow| L7RuleDef { allow })
801-
.collect(),
802-
),
803-
deny_defs,
804-
)
805-
};
806744
let (json_rpc, mcp) = if is_mcp_protocol(&protocol) {
807745
(None, mcp_config_from_proto(e.json_rpc_max_body_bytes))
808746
} else {
@@ -2058,7 +1996,7 @@ network_policies:
20581996
}
20591997

20601998
#[test]
2061-
fn parse_grouped_mcp_rules_to_runtime_fields() {
1999+
fn parse_mcp_rules_to_runtime_fields() {
20622000
let yaml = r"
20632001
version: 1
20642002
network_policies:
@@ -2073,17 +2011,19 @@ network_policies:
20732011
mcp:
20742012
max_body_bytes: 131072
20752013
rules:
2076-
deny:
2077-
- method: tools/call
2078-
tool: send_email
2079-
allow:
2080-
- method: initialize
2081-
- method: tools/list
2082-
- method: tools/call
2014+
- allow:
2015+
method: initialize
2016+
- allow:
2017+
method: tools/list
2018+
- allow:
2019+
method: tools/call
20832020
tool: search_web
20842021
params:
20852022
arguments:
20862023
repository: NVIDIA/OpenShell
2024+
deny_rules:
2025+
- method: tools/call
2026+
tool: send_email
20872027
binaries:
20882028
- path: /usr/bin/curl
20892029
";
@@ -2093,7 +2033,7 @@ network_policies:
20932033
assert_eq!(ep.protocol, "mcp");
20942034
assert_eq!(ep.json_rpc_max_body_bytes, 131_072);
20952035
assert_eq!(ep.rules.len(), 3);
2096-
assert_eq!(ep.rules[2].allow.as_ref().unwrap().rpc_method, "tools/call");
2036+
assert_eq!(ep.rules[2].allow.as_ref().unwrap().method, "tools/call");
20972037
assert_eq!(
20982038
ep.rules[2].allow.as_ref().unwrap().params["name"].glob,
20992039
"search_web"
@@ -2103,7 +2043,7 @@ network_policies:
21032043
"NVIDIA/OpenShell"
21042044
);
21052045
assert_eq!(ep.deny_rules.len(), 1);
2106-
assert_eq!(ep.deny_rules[0].rpc_method, "tools/call");
2046+
assert_eq!(ep.deny_rules[0].method, "tools/call");
21072047
assert_eq!(ep.deny_rules[0].params["name"].glob, "send_email");
21082048
}
21092049

@@ -2121,15 +2061,15 @@ network_policies:
21212061
mcp:
21222062
max_body_bytes: 131072
21232063
rules:
2124-
allow:
2125-
- method: tools/call
2064+
- allow:
2065+
method: tools/call
21262066
tool: search_web
21272067
params:
21282068
arguments:
21292069
repository: NVIDIA/OpenShell
2130-
deny:
2131-
- method: tools/call
2132-
tool: send_email
2070+
deny_rules:
2071+
- method: tools/call
2072+
tool: send_email
21332073
binaries:
21342074
- path: /usr/bin/curl
21352075
";
@@ -2141,6 +2081,7 @@ network_policies:
21412081
assert!(yaml_out.contains("method: tools/call"));
21422082
assert!(yaml_out.contains("tool: search_web"));
21432083
assert!(yaml_out.contains("tool: send_email"));
2084+
assert!(yaml_out.contains("deny_rules:"));
21442085
assert!(yaml_out.contains("arguments:"));
21452086
assert!(yaml_out.contains("repository: NVIDIA/OpenShell"));
21462087
assert!(!yaml_out.contains("arguments.repository"));

crates/openshell-policy/src/merge.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,14 +736,13 @@ fn expand_access_preset(protocol: &str, access: &str) -> Option<Vec<L7Rule>> {
736736
.into_iter()
737737
.map(|rpc_method| L7Rule {
738738
allow: Some(L7Allow {
739-
method: String::new(),
739+
method: rpc_method.to_string(),
740740
path: String::new(),
741741
command: String::new(),
742742
query: HashMap::default(),
743743
operation_type: String::new(),
744744
operation_name: String::new(),
745745
fields: Vec::new(),
746-
rpc_method: rpc_method.to_string(),
747746
params: HashMap::default(),
748747
}),
749748
})

crates/openshell-supervisor-network/src/l7/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,6 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec<String>, Vec<
10301030
"{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket"
10311031
));
10321032
}
1033-
10341033
}
10351034
}
10361035
}

crates/openshell-supervisor-network/src/l7/relay.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ fn engine_type_for_protocol(protocol: L7Protocol) -> &'static str {
140140
match protocol {
141141
L7Protocol::Graphql => "l7-graphql",
142142
L7Protocol::JsonRpc => "l7-jsonrpc",
143+
L7Protocol::Mcp => "l7-mcp",
143144
L7Protocol::Websocket => "l7-websocket",
144145
L7Protocol::Rest | L7Protocol::Sql => "l7",
145146
}
@@ -567,7 +568,7 @@ fn l7_protocol_log_summary(
567568
if let Some(info) = jsonrpc_info {
568569
return format!(
569570
" rpc_methods={} params_sha256={}",
570-
jsonrpc_methods_for_log(info),
571+
rule_method_names_for_log(info),
571572
info.params_sha256()
572573
.unwrap_or_else(|| "<empty>".to_string())
573574
);
@@ -2457,14 +2458,16 @@ network_policies:
24572458
mcp:
24582459
max_body_bytes: 131072
24592460
rules:
2460-
deny:
2461-
- method: tools/call
2462-
tool: delete_resource
2463-
allow:
2464-
- method: initialize
2465-
- method: tools/list
2466-
- method: tools/call
2461+
- allow:
2462+
method: initialize
2463+
- allow:
2464+
method: tools/list
2465+
- allow:
2466+
method: tools/call
24672467
tool: read_status
2468+
deny_rules:
2469+
- method: tools/call
2470+
tool: delete_resource
24682471
binaries:
24692472
- { path: /usr/bin/node }
24702473
"#;

0 commit comments

Comments
 (0)