From c29e13274f46fd34a425880af923a356a1539f3a Mon Sep 17 00:00:00 2001 From: brozorec <9572072+brozorec@users.noreply.github.com> Date: Sun, 19 Apr 2026 18:24:24 +0200 Subject: [PATCH 1/3] feat: add invocation tree utilities for context rule ID resolution Add pure utility functions for inspecting Soroban invocation trees and resolving context_rule_ids for multi-contract operations (e.g. deposit that internally calls transfer). Includes kit.hintContextRuleIds() and kit.resolveContextRuleIds() convenience methods, comprehensive tests, and public exports. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/index.ts | 13 ++ src/kit.ts | 90 +++++++++ src/kit/invocation-utils.test.ts | 311 ++++++++++++++++++++++++++++ src/kit/invocation-utils.ts | 336 +++++++++++++++++++++++++++++++ 4 files changed, 750 insertions(+) create mode 100644 src/kit/invocation-utils.test.ts create mode 100644 src/kit/invocation-utils.ts diff --git a/src/index.ts b/src/index.ts index 0382513..b1cc3c0 100644 --- a/src/index.ts +++ b/src/index.ts @@ -185,5 +185,18 @@ export type { export { StellarWalletsKitAdapter } from "./wallet-adapter"; export type { StellarWalletsKitAdapterConfig } from "./wallet-adapter"; +// Invocation tree utilities — for inspecting auth_context structure and +// determining contextRuleIds before signing transactions with sub-invocations. +export { + countAuthContexts, + walkInvocationTree, + validateContextRuleIds, + hintContextRuleIds, + resolveContextRuleIds, + type InvocationNode, + type ContextRuleMatch, + type InvocationContextHint, +} from "./kit/invocation-utils"; + // Re-export stellar-sdk types for convenience export type { AssembledTransaction } from "@stellar/stellar-sdk/contract"; diff --git a/src/kit.ts b/src/kit.ts index fedb3fa..2e23d75 100644 --- a/src/kit.ts +++ b/src/kit.ts @@ -122,8 +122,13 @@ import { import { convertPolicyParams, buildPoliciesScVal } from "./kit/policies-ops"; import { findWebAuthnSignerForCredential, + listContextRules, resolveContextRuleIdsForEntry, } from "./kit/context-rules"; +import { + hintContextRuleIds as hintContextRuleIdsFromTree, + type InvocationContextHint, +} from "./kit/invocation-utils"; import { validateAddress, validateAmount, xlmToStroops } from "./utils"; @@ -1213,6 +1218,91 @@ export class SmartAccountKit { return this.signAndSubmit(transaction, options); } + // ========================================================================== + // Invocation Tree Utilities + // ========================================================================== + + /** + * Get per-node rule suggestions for a multi-contract invocation tree. + * + * Walks the invocation tree depth-first, fetches all on-chain context rules, + * and matches each node against the rules by specificity: + * 1. `CallContract(address)` — most specific + * 2. `CreateContract(wasmHash)` — intermediate + * 3. `Default` — catch-all + * + * Use this to inspect which rules apply before signing, or to let users + * choose when multiple rules match a single node. + * + * @param authEntry - The authorization entry to inspect + * @param options - Optional settings + * @returns Array of hints, one per auth_context node in the tree + * + * @example + * ```typescript + * const hints = await kit.hintContextRuleIds(authEntry); + * // hints[0].suggestedRuleId = 1 (CallContract rule for deposit contract) + * // hints[1].suggestedRuleId = 0 (Default rule for transfer) + * + * const ruleIds = hints.map(h => h.suggestedRuleId); + * await kit.signAndSubmit(tx, { resolveContextRuleIds: () => ruleIds }); + * ``` + */ + async hintContextRuleIds( + authEntry: xdr.SorobanAuthorizationEntry, + options?: { + /** Rule ID used when no rule matches a node (default 0) */ + defaultRuleId?: number; + } + ): Promise { + const { wallet, contractId } = this.requireWallet(); + const discoveryDeps = { + getContractDetailsFromIndexer: () => this.getContractDetailsFromIndexer(contractId), + probeRuleIds: this.probeRuleIds, + rpc: this.rpc, + contractId, + networkPassphrase: this.networkPassphrase, + timeoutInSeconds: this.timeoutInSeconds, + }; + const rules = await listContextRules(wallet, discoveryDeps); + + return hintContextRuleIdsFromTree( + authEntry.rootInvocation(), + rules, + options?.defaultRuleId + ); + } + + /** + * Auto-resolve context rule IDs for every auth_context in an invocation tree. + * + * This is the non-interactive version of {@link hintContextRuleIds} — it + * returns only the suggested IDs. Use `hintContextRuleIds` when you need to + * inspect or override individual suggestions. + * + * @param authEntry - The authorization entry to resolve + * @param options - Optional settings + * @returns Array of rule IDs, one per auth_context, ready to use as `contextRuleIds` + * + * @example + * ```typescript + * const ruleIds = await kit.resolveContextRuleIds(authEntry); + * // -> [1, 0] (deposit contract: rule 1, transfer: default rule 0) + * + * await kit.signAndSubmit(tx, { resolveContextRuleIds: () => ruleIds }); + * ``` + */ + async resolveContextRuleIds( + authEntry: xdr.SorobanAuthorizationEntry, + options?: { + /** Rule ID used when no rule matches a node (default 0) */ + defaultRuleId?: number; + } + ): Promise { + const hints = await this.hintContextRuleIds(authEntry, options); + return hints.map((h) => h.suggestedRuleId); + } + // ========================================================================== // Private Helpers // ========================================================================== diff --git a/src/kit/invocation-utils.test.ts b/src/kit/invocation-utils.test.ts new file mode 100644 index 0000000..f60cc31 --- /dev/null +++ b/src/kit/invocation-utils.test.ts @@ -0,0 +1,311 @@ +import { Address, Keypair, xdr } from "@stellar/stellar-sdk"; +import { describe, expect, it } from "vitest"; +import type { ContextRule, ContextRuleType, Signer } from "smart-account-kit-bindings"; +import { + countAuthContexts, + walkInvocationTree, + validateContextRuleIds, + hintContextRuleIds, + resolveContextRuleIds, +} from "./invocation-utils"; + +function makeRule( + id: number, + contextType: ContextRuleType, + name?: string +): ContextRule { + return { + id, + context_type: contextType, + name: name ?? `rule-${id}`, + signers: [], + signer_ids: [], + policies: [], + policy_ids: [], + valid_until: undefined, + }; +} + +function makeAccount(seedByte: number): string { + return Keypair.fromRawEd25519Seed(Buffer.alloc(32, seedByte)).publicKey(); +} + +function makeContractInvocation( + contractId: string, + functionName: string, + subInvocations: xdr.SorobanAuthorizedInvocation[] = [] +): xdr.SorobanAuthorizedInvocation { + return new xdr.SorobanAuthorizedInvocation({ + function: + xdr.SorobanAuthorizedFunction.sorobanAuthorizedFunctionTypeContractFn( + new xdr.InvokeContractArgs({ + contractAddress: Address.fromString(contractId).toScAddress(), + functionName, + args: [], + }) + ), + subInvocations, + }); +} + +function makeAuthEntry( + rootInvocation: xdr.SorobanAuthorizedInvocation +): xdr.SorobanAuthorizationEntry { + return new xdr.SorobanAuthorizationEntry({ + credentials: xdr.SorobanCredentials.sorobanCredentialsAddress( + new xdr.SorobanAddressCredentials({ + address: Address.fromString(makeAccount(99)).toScAddress(), + nonce: xdr.Int64.fromString("1"), + signatureExpirationLedger: 1, + signature: xdr.ScVal.scvVoid(), + }) + ), + rootInvocation, + }); +} + +// Stable contract addresses for tests +const CONTRACT_A = "CDANWYENKH6PTTY6GDTMDAMYRHMU4SBRPX5NUDYDMTYVOIF32ASZFU4Y"; +const CONTRACT_B = "CBSHV66WG7UV6FQVUTB67P3DZUEJ2KJ5X6JKQH5MFRAAFNFJUAJVXJYV"; +const CONTRACT_C = "CDLZFC3SYJYDZT7K67VZ75HPJVIEUVNIXF47ZG2FB2RMQQVU2HHGCYSC"; + +describe("invocation-utils", () => { + // ========================================================================== + // countAuthContexts + // ========================================================================== + + it("counts a single invocation as 1", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + + expect(countAuthContexts(invocation)).toBe(1); + }); + + it("counts root + 2 sub-invocations as 3", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + makeContractInvocation(CONTRACT_C, "approve"), + ]); + + expect(countAuthContexts(invocation)).toBe(3); + }); + + it("counts a deeply nested chain (3 levels) as 3", () => { + const invocation = makeContractInvocation(CONTRACT_A, "swap", [ + makeContractInvocation(CONTRACT_B, "deposit", [ + makeContractInvocation(CONTRACT_C, "transfer"), + ]), + ]); + + expect(countAuthContexts(invocation)).toBe(3); + }); + + // ========================================================================== + // walkInvocationTree + // ========================================================================== + + it("walks a single contract call and returns one node", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const nodes = walkInvocationTree(invocation); + + expect(nodes).toEqual([ + { index: 0, contractAddress: CONTRACT_A, functionName: "transfer" }, + ]); + }); + + it("walks nested calls in depth-first order with sequential indices", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + makeContractInvocation(CONTRACT_C, "approve"), + ]); + + const nodes = walkInvocationTree(invocation); + + expect(nodes).toEqual([ + { index: 0, contractAddress: CONTRACT_A, functionName: "deposit" }, + { index: 1, contractAddress: CONTRACT_B, functionName: "transfer" }, + { index: 2, contractAddress: CONTRACT_C, functionName: "approve" }, + ]); + }); + + it("walks a deeply nested chain in depth-first order", () => { + const invocation = makeContractInvocation(CONTRACT_A, "swap", [ + makeContractInvocation(CONTRACT_B, "deposit", [ + makeContractInvocation(CONTRACT_C, "transfer"), + ]), + ]); + + const nodes = walkInvocationTree(invocation); + + expect(nodes).toEqual([ + { index: 0, contractAddress: CONTRACT_A, functionName: "swap" }, + { index: 1, contractAddress: CONTRACT_B, functionName: "deposit" }, + { index: 2, contractAddress: CONTRACT_C, functionName: "transfer" }, + ]); + }); + + // ========================================================================== + // validateContextRuleIds + // ========================================================================== + + it("passes silently when contextRuleIds length matches", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + + expect(() => validateContextRuleIds([0, 1], invocation)).not.toThrow(); + }); + + it("throws a descriptive error when length mismatches", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + + expect(() => validateContextRuleIds([0], invocation)).toThrow( + /contextRuleIds length \(1\) does not match auth_contexts count \(2\)/ + ); + }); + + it("includes the tree dump and tip in the error message", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + + try { + validateContextRuleIds([], invocation); + expect.unreachable("should have thrown"); + } catch (error) { + const message = (error as Error).message; + expect(message).toContain(CONTRACT_A); + expect(message).toContain("deposit"); + expect(message).toContain(CONTRACT_B); + expect(message).toContain("transfer"); + expect(message).toContain("contextRuleIds: [0, 0]"); + expect(message).toContain("kit.hintContextRuleIds"); + } + }); + + // ========================================================================== + // hintContextRuleIds + // ========================================================================== + + it("prefers a CallContract rule over a Default rule", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }, "default"), + makeRule(1, { tag: "CallContract", values: [CONTRACT_A] }, "token-transfer"), + ]; + + const hints = hintContextRuleIds(invocation, rules); + + expect(hints).toHaveLength(1); + expect(hints[0].suggestedRuleId).toBe(1); + expect(hints[0].matchingRules[0].contextType).toBe("CallContract"); + expect(hints[0].matchingRules[1].contextType).toBe("Default"); + }); + + it("falls back to Default when no specific rule matches", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }, "default"), + makeRule(1, { tag: "CallContract", values: [CONTRACT_B] }, "other-contract"), + ]; + + const hints = hintContextRuleIds(invocation, rules); + + expect(hints).toHaveLength(1); + expect(hints[0].suggestedRuleId).toBe(0); + expect(hints[0].matchingRules).toHaveLength(1); + expect(hints[0].matchingRules[0].contextType).toBe("Default"); + }); + + it("sorts matchingRules by specificity (CallContract > Default)", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }), + makeRule(1, { tag: "CallContract", values: [CONTRACT_A] }), + makeRule(2, { tag: "Default", values: undefined }), + ]; + + const hints = hintContextRuleIds(invocation, rules); + const types = hints[0].matchingRules.map((m) => m.contextType); + + expect(types).toEqual(["CallContract", "Default", "Default"]); + }); + + it("falls back to defaultRuleId when no rules match at all", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + + const hints = hintContextRuleIds(invocation, [], 42); + + expect(hints).toHaveLength(1); + expect(hints[0].suggestedRuleId).toBe(42); + expect(hints[0].matchingRules).toHaveLength(0); + }); + + it("resolves a nested tree with mixed specific and default rules", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }, "default"), + makeRule(1, { tag: "CallContract", values: [CONTRACT_A] }, "deposit-rule"), + ]; + + const hints = hintContextRuleIds(invocation, rules); + + expect(hints).toHaveLength(2); + // Deposit contract has a specific CallContract rule + expect(hints[0].suggestedRuleId).toBe(1); + expect(hints[0].contractAddress).toBe(CONTRACT_A); + expect(hints[0].functionName).toBe("deposit"); + // Transfer contract only matches the Default rule + expect(hints[1].suggestedRuleId).toBe(0); + expect(hints[1].contractAddress).toBe(CONTRACT_B); + expect(hints[1].functionName).toBe("transfer"); + }); + + it("includes node metadata in each hint", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const hints = hintContextRuleIds(invocation, [], 0); + + expect(hints[0].index).toBe(0); + expect(hints[0].contractAddress).toBe(CONTRACT_A); + expect(hints[0].functionName).toBe("transfer"); + }); + + // ========================================================================== + // resolveContextRuleIds + // ========================================================================== + + it("returns an array of suggested rule IDs", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }), + makeRule(1, { tag: "CallContract", values: [CONTRACT_A] }), + ]; + + const ruleIds = resolveContextRuleIds(invocation, rules); + + expect(ruleIds).toEqual([1, 0]); + }); + + it("uses defaultRuleId for all nodes when no rules are provided", () => { + const invocation = makeContractInvocation(CONTRACT_A, "deposit", [ + makeContractInvocation(CONTRACT_B, "transfer"), + ]); + + const ruleIds = resolveContextRuleIds(invocation, [], 5); + + expect(ruleIds).toEqual([5, 5]); + }); + + it("uses 0 as the default fallback rule ID", () => { + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + + const ruleIds = resolveContextRuleIds(invocation, []); + + expect(ruleIds).toEqual([0]); + }); +}); diff --git a/src/kit/invocation-utils.ts b/src/kit/invocation-utils.ts new file mode 100644 index 0000000..81b6bdf --- /dev/null +++ b/src/kit/invocation-utils.ts @@ -0,0 +1,336 @@ +/** + * Utilities for inspecting Soroban invocation trees and resolving context rule IDs. + * + * When the smart account's `__check_auth` is called, it receives one `auth_context` + * per node in the `SorobanAuthorizedInvocation` tree, traversed depth-first. The + * `AuthPayload.context_rule_ids` array must have exactly one entry per `auth_context`, + * aligned by index. + * + * For a simple transfer there is one node -> `contextRuleIds: [0]`. + * For a deposit that internally calls transfer there are two nodes -> + * `contextRuleIds: [ruleForDeposit, ruleForTransfer]`. + * + * @example + * ```typescript + * // 1. Simulate the transaction to get auth entries + * const simResult = await rpc.simulateTransaction(tx); + * const authEntry = simResult.result!.auth[0]; + * + * // 2. Inspect the tree to see how many auth_contexts there are + * const nodes = walkInvocationTree(authEntry.rootInvocation()); + * // nodes[0] = { index: 0, contractAddress: "CB7Z...", functionName: "deposit" } + * // nodes[1] = { index: 1, contractAddress: "CDSL...", functionName: "transfer" } + * + * // 3a. Let the kit auto-suggest rule IDs from on-chain rules + * const hints = await kit.hintContextRuleIds(authEntry); + * // hints[0].suggestedRuleId = 1 (CallContract rule for the deposit contract) + * // hints[1].suggestedRuleId = 0 (Default rule) + * + * // 3b. Or get just the IDs if you trust the suggestions + * const ruleIds = await kit.resolveContextRuleIds(authEntry); + * // -> [1, 0] + * + * // 4. Pass to the signing operation + * await kit.signAndSubmit(assembledTx, { + * resolveContextRuleIds: () => ruleIds, + * }); + * ``` + * + * @packageDocumentation + */ + +import { Address, xdr } from "@stellar/stellar-sdk"; +import type { ContextRule } from "smart-account-kit-bindings"; + +// ============================================================================ +// Types +// ============================================================================ + +/** + * Metadata about a single node in a `SorobanAuthorizedInvocation` tree. + * Each node corresponds to one `auth_context` passed to `__check_auth`. + */ +export interface InvocationNode { + /** Zero-based position in the depth-first traversal — equals the index into `context_rule_ids`. */ + index: number; + /** Contract address for contract-function invocations, undefined otherwise. */ + contractAddress?: string; + /** Function name for contract-function invocations, undefined otherwise. */ + functionName?: string; +} + +/** + * A single on-chain context rule that matches an invocation node. + */ +export interface ContextRuleMatch { + /** On-chain rule ID. */ + ruleId: number; + /** Human-readable rule name. */ + ruleName: string; + /** Why this rule matched: specific contract, wasm hash, or catch-all default. */ + contextType: "Default" | "CallContract" | "CreateContract"; + /** Human-readable description of why this rule was selected. */ + reason: string; +} + +/** + * Hint for a single auth_context node: which rules match and which is suggested. + */ +export interface InvocationContextHint { + /** Zero-based index — use as the position in `contextRuleIds`. */ + index: number; + /** Contract address (undefined for non-contract-function invocations). */ + contractAddress?: string; + /** Function name (undefined for non-contract-function invocations). */ + functionName?: string; + /** + * Recommended rule ID for this context. + * The most specific matching rule is preferred (CallContract > Default). + * Falls back to `defaultRuleId` when no rule explicitly matches. + */ + suggestedRuleId: number; + /** + * All rules that match this context, ordered by specificity. + * Inspect this to detect ambiguity or choose a different rule. + */ + matchingRules: ContextRuleMatch[]; +} + +// ============================================================================ +// Tree Traversal +// ============================================================================ + +/** + * Count the total number of auth_contexts that will be produced from an + * invocation tree (depth-first, one per node). + * + * @param invocation - Root of the invocation tree + * @returns Total number of auth_context entries + * + * @example + * ```typescript + * const count = countAuthContexts(authEntry.rootInvocation()); + * // For deposit -> transfer: count === 2 + * ``` + */ +export function countAuthContexts( + invocation: xdr.SorobanAuthorizedInvocation +): number { + let count = 1; + for (const sub of invocation.subInvocations()) { + count += countAuthContexts(sub); + } + return count; +} + +/** + * Walk an invocation tree depth-first and return a flat list of node metadata. + * The list index matches the `context_rule_ids` position. + * + * @param invocation - Root of the invocation tree + * @returns Flat array of node metadata in depth-first order + * + * @example + * ```typescript + * const nodes = walkInvocationTree(authEntry.rootInvocation()); + * // nodes[0] -> root invocation + * // nodes[1] -> first sub-invocation + * ``` + */ +export function walkInvocationTree( + invocation: xdr.SorobanAuthorizedInvocation +): InvocationNode[] { + const nodes: InvocationNode[] = []; + walkRecursive(invocation, nodes); + return nodes; +} + +function walkRecursive( + invocation: xdr.SorobanAuthorizedInvocation, + nodes: InvocationNode[] +): void { + const fn = invocation.function(); + const node: InvocationNode = { index: nodes.length }; + + if (fn.switch().name === "sorobanAuthorizedFunctionTypeContractFn") { + const contractFn = fn.contractFn(); + node.contractAddress = Address.fromScAddress( + contractFn.contractAddress() + ).toString(); + node.functionName = contractFn.functionName().toString(); + } + + nodes.push(node); + + for (const sub of invocation.subInvocations()) { + walkRecursive(sub, nodes); + } +} + +// ============================================================================ +// Validation +// ============================================================================ + +/** + * Validate that `contextRuleIds` length exactly matches the auth_contexts count + * of the invocation tree. Throws a descriptive error on mismatch, including + * the full tree so the caller can build the correct array. + * + * @param contextRuleIds - Array of context rule IDs to validate + * @param invocation - Root of the invocation tree + * @throws Error if the array length does not match the number of auth_contexts + * + * @example + * ```typescript + * // Throws if contextRuleIds.length !== auth_contexts count + * validateContextRuleIds([0], authEntry.rootInvocation()); + * ``` + */ +export function validateContextRuleIds( + contextRuleIds: number[], + invocation: xdr.SorobanAuthorizedInvocation +): void { + const contextCount = countAuthContexts(invocation); + if (contextRuleIds.length === contextCount) return; + + const nodes = walkInvocationTree(invocation); + const treeLines = nodes.map( + (n) => + ` [${n.index}] ${ + n.contractAddress + ? `${n.contractAddress}::${n.functionName ?? "?"}` + : "" + }` + ); + + throw new Error( + `contextRuleIds length (${contextRuleIds.length}) does not match ` + + `auth_contexts count (${contextCount}).\n` + + `Invocation tree — ${contextCount} auth_context${contextCount === 1 ? "" : "s"} (depth-first):\n` + + treeLines.join("\n") + + "\n" + + `Pass exactly ${contextCount} rule ID${contextCount === 1 ? "" : "s"}, e.g. ` + + `contextRuleIds: [${nodes.map(() => "0").join(", ")}]\n` + + `Tip: use kit.hintContextRuleIds(authEntry) to get per-node suggestions.` + ); +} + +// ============================================================================ +// Hint / Resolution +// ============================================================================ + +/** + * For each auth_context node in the invocation tree, find all on-chain context + * rules that match it and return a prioritised hint. + * + * Matching priority (most-specific first): + * 1. `CallContract(address)` — matches when the node's contract address equals the rule value + * 2. `CreateContract(wasmHash)` — matches when the node is a non-contract invocation + * 3. `Default` — catch-all, matches every node + * + * When multiple rules of the same priority match, all are listed in `matchingRules` + * so the caller can detect and resolve ambiguity manually. + * + * @param invocation - Root of the invocation tree (from `authEntry.rootInvocation()`) + * @param rules - On-chain context rules to match against + * @param defaultRuleId - Rule ID used when no rule explicitly matches (default 0) + * @returns Array of hints, one per auth_context node + * + * @example + * ```typescript + * const hints = hintContextRuleIds(authEntry.rootInvocation(), rules); + * // hints[0].suggestedRuleId = 1 (CallContract rule for deposit contract) + * // hints[0].matchingRules = [{ ruleId: 1, contextType: "CallContract", ... }, ...] + * ``` + */ +export function hintContextRuleIds( + invocation: xdr.SorobanAuthorizedInvocation, + rules: ContextRule[], + defaultRuleId: number = 0 +): InvocationContextHint[] { + const nodes = walkInvocationTree(invocation); + + return nodes.map((node) => { + const matchingRules: ContextRuleMatch[] = []; + + for (const rule of rules) { + const ct = rule.context_type; + + if (ct.tag === "CallContract") { + if (node.contractAddress && ct.values[0] === node.contractAddress) { + matchingRules.push({ + ruleId: rule.id, + ruleName: rule.name, + contextType: "CallContract", + reason: `CallContract rule for ${node.contractAddress}`, + }); + } + } else if (ct.tag === "CreateContract") { + if (!node.contractAddress) { + matchingRules.push({ + ruleId: rule.id, + ruleName: rule.name, + contextType: "CreateContract", + reason: "CreateContract rule", + }); + } + } else { + matchingRules.push({ + ruleId: rule.id, + ruleName: rule.name, + contextType: "Default", + reason: "Default rule (matches any context)", + }); + } + } + + const specificity: Record = { + CallContract: 0, + CreateContract: 1, + Default: 2, + }; + matchingRules.sort( + (a, b) => specificity[a.contextType] - specificity[b.contextType] + ); + + const suggestedRuleId = + matchingRules.length > 0 ? matchingRules[0].ruleId : defaultRuleId; + + return { + index: node.index, + contractAddress: node.contractAddress, + functionName: node.functionName, + suggestedRuleId, + matchingRules, + }; + }); +} + +/** + * Resolve context rule IDs for every auth_context in an invocation tree by + * matching each node against on-chain rules. + * + * This is the non-interactive version of {@link hintContextRuleIds} — it returns + * only the suggested IDs. Use `hintContextRuleIds` when you need to inspect + * or override individual suggestions. + * + * @param invocation - Root of the invocation tree + * @param rules - On-chain context rules + * @param defaultRuleId - Fallback rule ID when no rule matches (default 0) + * @returns Array of rule IDs, one per auth_context, ready to pass as `contextRuleIds` + * + * @example + * ```typescript + * const ruleIds = resolveContextRuleIds(authEntry.rootInvocation(), rules); + * // -> [1, 0] (one ID per auth_context) + * ``` + */ +export function resolveContextRuleIds( + invocation: xdr.SorobanAuthorizedInvocation, + rules: ContextRule[], + defaultRuleId: number = 0 +): number[] { + return hintContextRuleIds(invocation, rules, defaultRuleId).map( + (h) => h.suggestedRuleId + ); +} From f6ff361c7d0bc3821c11455de342a3ff9b795550 Mon Sep 17 00:00:00 2001 From: brozorec <9572072+brozorec@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:15:20 -0600 Subject: [PATCH 2/3] fix: optimize utilities --- src/kit/invocation-utils.ts | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/kit/invocation-utils.ts b/src/kit/invocation-utils.ts index 81b6bdf..73573dc 100644 --- a/src/kit/invocation-utils.ts +++ b/src/kit/invocation-utils.ts @@ -42,6 +42,12 @@ import { Address, xdr } from "@stellar/stellar-sdk"; import type { ContextRule } from "smart-account-kit-bindings"; +const CONTEXT_TYPE_SPECIFICITY: Record = { + CallContract: 0, + CreateContract: 1, + Default: 2, +}; + // ============================================================================ // Types // ============================================================================ @@ -116,11 +122,7 @@ export interface InvocationContextHint { export function countAuthContexts( invocation: xdr.SorobanAuthorizedInvocation ): number { - let count = 1; - for (const sub of invocation.subInvocations()) { - count += countAuthContexts(sub); - } - return count; + return walkInvocationTree(invocation).length; } /** @@ -190,10 +192,9 @@ export function validateContextRuleIds( contextRuleIds: number[], invocation: xdr.SorobanAuthorizedInvocation ): void { - const contextCount = countAuthContexts(invocation); - if (contextRuleIds.length === contextCount) return; - const nodes = walkInvocationTree(invocation); + if (contextRuleIds.length === nodes.length) return; + const treeLines = nodes.map( (n) => ` [${n.index}] ${ @@ -205,11 +206,11 @@ export function validateContextRuleIds( throw new Error( `contextRuleIds length (${contextRuleIds.length}) does not match ` + - `auth_contexts count (${contextCount}).\n` + - `Invocation tree — ${contextCount} auth_context${contextCount === 1 ? "" : "s"} (depth-first):\n` + + `auth_contexts count (${nodes.length}).\n` + + `Invocation tree — ${nodes.length} auth_context${nodes.length === 1 ? "" : "s"} (depth-first):\n` + treeLines.join("\n") + "\n" + - `Pass exactly ${contextCount} rule ID${contextCount === 1 ? "" : "s"}, e.g. ` + + `Pass exactly ${nodes.length} rule ID${nodes.length === 1 ? "" : "s"}, e.g. ` + `contextRuleIds: [${nodes.map(() => "0").join(", ")}]\n` + `Tip: use kit.hintContextRuleIds(authEntry) to get per-node suggestions.` ); @@ -284,13 +285,10 @@ export function hintContextRuleIds( } } - const specificity: Record = { - CallContract: 0, - CreateContract: 1, - Default: 2, - }; matchingRules.sort( - (a, b) => specificity[a.contextType] - specificity[b.contextType] + (a, b) => + CONTEXT_TYPE_SPECIFICITY[a.contextType] - + CONTEXT_TYPE_SPECIFICITY[b.contextType] ); const suggestedRuleId = From 791b16e9bb1f4610e5e0199b809471d556ad604d Mon Sep 17 00:00:00 2001 From: brozorec <9572072+brozorec@users.noreply.github.com> Date: Wed, 22 Apr 2026 05:13:58 -0600 Subject: [PATCH 3/3] fix: unify traversal --- src/kit/context-rules.ts | 92 ++++---------------------------- src/kit/invocation-utils.test.ts | 83 +++++++++++++++++++++++++++- src/kit/invocation-utils.ts | 66 ++++++++++++++++++++++- 3 files changed, 157 insertions(+), 84 deletions(-) diff --git a/src/kit/context-rules.ts b/src/kit/context-rules.ts index 665e35e..bbf9ea9 100644 --- a/src/kit/context-rules.ts +++ b/src/kit/context-rules.ts @@ -22,6 +22,7 @@ import { } from "../signer-utils"; import type { ContractDetailsResponse } from "../indexer"; import { BASE_FEE } from "../constants"; +import { walkInvocationTree } from "./invocation-utils"; type ContextRuleQueryClient = { get_context_rule: (args: { context_rule_id: number }) => Promise>; @@ -78,37 +79,17 @@ export function contextRuleTypeMatches( export function buildInvocationContextTypes( entry: xdr.SorobanAuthorizationEntry ): ContextRuleType[] { - const contexts: ContextRuleType[] = []; - - const walk = (invocation: xdr.SorobanAuthorizedInvocation) => { - const fn = invocation.function(); - const switchName = fn.switch().name; - - if (switchName === "sorobanAuthorizedFunctionTypeContractFn") { - const args = fn.contractFn(); - contexts.push({ - tag: "CallContract", - values: [Address.fromScAddress(args.contractAddress()).toString()], - }); - } else if (switchName.startsWith("sorobanAuthorizedFunctionTypeCreateContract")) { - const wasmHash = extractCreateContractWasmHash(fn); - if (!wasmHash) { - throw new Error("Unable to extract WASM hash from create-contract authorization entry"); - } - - contexts.push({ - tag: "CreateContract", - values: [wasmHash], - }); + return walkInvocationTree(entry.rootInvocation()).map((node) => { + if (node.contractAddress) { + return { tag: "CallContract", values: [node.contractAddress] } as ContextRuleType; } - - for (const sub of invocation.subInvocations()) { - walk(sub); + if (node.wasmHash) { + return { tag: "CreateContract", values: [node.wasmHash] } as ContextRuleType; } - }; - - walk(entry.rootInvocation()); - return contexts; + throw new Error( + "Unable to determine context type for invocation node" + ); + }); } function hasRpcReadConfig( @@ -558,59 +539,6 @@ export async function resolveContextRuleIdsForEntry( }); } -function extractCreateContractWasmHash( - fn: xdr.SorobanAuthorizedFunction -): Buffer | null { - const candidates: Array = []; - const fnAny = fn as unknown as { - createContractHostFn?: () => unknown; - createContractWithCtorHostFn?: () => unknown; - createContractWithConstructorHostFn?: () => unknown; - }; - - if (typeof fnAny.createContractHostFn === "function") { - candidates.push(fnAny.createContractHostFn()); - } - if (typeof fnAny.createContractWithCtorHostFn === "function") { - candidates.push(fnAny.createContractWithCtorHostFn()); - } - if (typeof fnAny.createContractWithConstructorHostFn === "function") { - candidates.push(fnAny.createContractWithConstructorHostFn()); - } - - for (const candidate of candidates) { - if (!candidate || typeof candidate !== "object") { - continue; - } - - const ctx = candidate as { executable?: unknown }; - const executable = typeof ctx.executable === "function" - ? (ctx.executable as () => unknown)() - : ctx.executable; - - if (!executable || typeof executable !== "object") { - continue; - } - - const execAny = executable as { - switch?: () => { name: string }; - wasm?: (() => Buffer) | Buffer; - }; - const execSwitch = execAny.switch?.(); - - if (execSwitch?.name !== "contractExecutableWasm") { - continue; - } - - const wasm = typeof execAny.wasm === "function" ? execAny.wasm() : execAny.wasm; - if (wasm) { - return Buffer.from(wasm); - } - } - - return null; -} - function isMissingContextRuleError(error: unknown): boolean { if (!(error instanceof Error)) { return false; diff --git a/src/kit/invocation-utils.test.ts b/src/kit/invocation-utils.test.ts index f60cc31..7489434 100644 --- a/src/kit/invocation-utils.test.ts +++ b/src/kit/invocation-utils.test.ts @@ -1,4 +1,4 @@ -import { Address, Keypair, xdr } from "@stellar/stellar-sdk"; +import { Address, Keypair, hash, xdr } from "@stellar/stellar-sdk"; import { describe, expect, it } from "vitest"; import type { ContextRule, ContextRuleType, Signer } from "smart-account-kit-bindings"; import { @@ -64,6 +64,31 @@ function makeAuthEntry( }); } +function makeCreateContractInvocation( + wasmHash: Buffer, + subInvocations: xdr.SorobanAuthorizedInvocation[] = [] +): xdr.SorobanAuthorizedInvocation { + return new xdr.SorobanAuthorizedInvocation({ + function: + xdr.SorobanAuthorizedFunction.sorobanAuthorizedFunctionTypeCreateContractHostFn( + new xdr.CreateContractArgs({ + contractIdPreimage: + xdr.ContractIdPreimage.contractIdPreimageFromAddress( + new xdr.ContractIdPreimageFromAddress({ + address: xdr.ScAddress.scAddressTypeAccount( + xdr.PublicKey.publicKeyTypeEd25519(Buffer.alloc(32)) + ), + salt: Buffer.alloc(32), + }) + ), + executable: + xdr.ContractExecutable.contractExecutableWasm(wasmHash), + }) + ), + subInvocations, + }); +} + // Stable contract addresses for tests const CONTRACT_A = "CDANWYENKH6PTTY6GDTMDAMYRHMU4SBRPX5NUDYDMTYVOIF32ASZFU4Y"; const CONTRACT_B = "CBSHV66WG7UV6FQVUTB67P3DZUEJ2KJ5X6JKQH5MFRAAFNFJUAJVXJYV"; @@ -143,6 +168,31 @@ describe("invocation-utils", () => { ]); }); + it("walks a create-contract invocation and populates wasmHash", () => { + const wasmHash = hash(Buffer.from("test-wasm")); + const invocation = makeCreateContractInvocation(wasmHash); + const nodes = walkInvocationTree(invocation); + + expect(nodes).toHaveLength(1); + expect(nodes[0].contractAddress).toBeUndefined(); + expect(nodes[0].functionName).toBeUndefined(); + expect(nodes[0].wasmHash).toEqual(wasmHash); + }); + + it("walks a mixed tree with contract calls and create-contract", () => { + const wasmHash = hash(Buffer.from("deploy")); + const invocation = makeContractInvocation(CONTRACT_A, "deploy", [ + makeCreateContractInvocation(wasmHash), + ]); + const nodes = walkInvocationTree(invocation); + + expect(nodes).toHaveLength(2); + expect(nodes[0].contractAddress).toBe(CONTRACT_A); + expect(nodes[0].wasmHash).toBeUndefined(); + expect(nodes[1].contractAddress).toBeUndefined(); + expect(nodes[1].wasmHash).toEqual(wasmHash); + }); + // ========================================================================== // validateContextRuleIds // ========================================================================== @@ -264,6 +314,37 @@ describe("invocation-utils", () => { expect(hints[1].functionName).toBe("transfer"); }); + it("matches a CreateContract rule to a create-contract invocation", () => { + const wasmHash = hash(Buffer.from("test-wasm")); + const invocation = makeCreateContractInvocation(wasmHash); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }, "default"), + makeRule(1, { tag: "CreateContract", values: [wasmHash] }, "deployer"), + ]; + + const hints = hintContextRuleIds(invocation, rules); + + expect(hints).toHaveLength(1); + expect(hints[0].suggestedRuleId).toBe(1); + expect(hints[0].matchingRules[0].contextType).toBe("CreateContract"); + expect(hints[0].matchingRules[1].contextType).toBe("Default"); + }); + + it("does not match a CreateContract rule to a contract-call invocation", () => { + const wasmHash = hash(Buffer.from("test-wasm")); + const invocation = makeContractInvocation(CONTRACT_A, "transfer"); + const rules = [ + makeRule(0, { tag: "Default", values: undefined }), + makeRule(1, { tag: "CreateContract", values: [wasmHash] }), + ]; + + const hints = hintContextRuleIds(invocation, rules); + + expect(hints[0].suggestedRuleId).toBe(0); + expect(hints[0].matchingRules).toHaveLength(1); + expect(hints[0].matchingRules[0].contextType).toBe("Default"); + }); + it("includes node metadata in each hint", () => { const invocation = makeContractInvocation(CONTRACT_A, "transfer"); const hints = hintContextRuleIds(invocation, [], 0); diff --git a/src/kit/invocation-utils.ts b/src/kit/invocation-utils.ts index 73573dc..1e01e84 100644 --- a/src/kit/invocation-utils.ts +++ b/src/kit/invocation-utils.ts @@ -63,6 +63,8 @@ export interface InvocationNode { contractAddress?: string; /** Function name for contract-function invocations, undefined otherwise. */ functionName?: string; + /** WASM hash for create-contract invocations, undefined otherwise. */ + wasmHash?: Buffer; } /** @@ -153,13 +155,19 @@ function walkRecursive( ): void { const fn = invocation.function(); const node: InvocationNode = { index: nodes.length }; + const switchName = fn.switch().name; - if (fn.switch().name === "sorobanAuthorizedFunctionTypeContractFn") { + if (switchName === "sorobanAuthorizedFunctionTypeContractFn") { const contractFn = fn.contractFn(); node.contractAddress = Address.fromScAddress( contractFn.contractAddress() ).toString(); node.functionName = contractFn.functionName().toString(); + } else if (switchName.startsWith("sorobanAuthorizedFunctionTypeCreateContract")) { + const wasmHash = extractCreateContractWasmHash(fn); + if (wasmHash) { + node.wasmHash = wasmHash; + } } nodes.push(node); @@ -169,6 +177,62 @@ function walkRecursive( } } +function extractCreateContractWasmHash( + fn: xdr.SorobanAuthorizedFunction +): Buffer | null { + const candidates: Array = []; + const fnAny = fn as unknown as { + createContractHostFn?: () => unknown; + createContractWithCtorHostFn?: () => unknown; + createContractWithConstructorHostFn?: () => unknown; + }; + + if (typeof fnAny.createContractHostFn === "function") { + candidates.push(fnAny.createContractHostFn()); + } + if (typeof fnAny.createContractWithCtorHostFn === "function") { + candidates.push(fnAny.createContractWithCtorHostFn()); + } + if (typeof fnAny.createContractWithConstructorHostFn === "function") { + candidates.push(fnAny.createContractWithConstructorHostFn()); + } + + for (const candidate of candidates) { + if (!candidate || typeof candidate !== "object") { + continue; + } + + const ctx = candidate as { executable?: unknown }; + const executable = typeof ctx.executable === "function" + ? (ctx.executable as () => unknown)() + : ctx.executable; + + if (!executable || typeof executable !== "object") { + continue; + } + + const execAny = executable as { + switch?: () => { name: string }; + wasm?: (() => Buffer) | Buffer; + wasmHash?: (() => Buffer) | Buffer; + }; + const execSwitch = execAny.switch?.(); + + if (execSwitch?.name !== "contractExecutableWasm") { + continue; + } + + const hash = + (typeof execAny.wasmHash === "function" ? execAny.wasmHash() : execAny.wasmHash) ?? + (typeof execAny.wasm === "function" ? execAny.wasm() : execAny.wasm); + if (hash) { + return Buffer.from(hash); + } + } + + return null; +} + // ============================================================================ // Validation // ============================================================================