Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

67 changes: 60 additions & 7 deletions contracts/invoice_nft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ pub enum DataKey {
Admin,
/// Instance key: access control contract address for pause checks
AccessControl,
/// Instance key: tracks migration version
MigrationVersion,
}

// ── Contract ─────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -157,6 +159,14 @@ impl InvoiceNftContract {
}

/// Transition invoice to Listed status. Called by Marketplace contract.
///
/// **Parameters:**
/// - `caller` — The marketplace contract address.
/// - `invoice_id` — The ID of the invoice to list.
///
/// **Returns:** `Ok(())` on success, or an appropriate `KoraError`.
///
/// **Security:** Requires auth from the caller. Validates that the invoice is in `Created` status.
pub fn set_listed(env: Env, caller: Address, invoice_id: u64) -> Result<(), KoraError> {
caller.require_auth();
Self::require_not_paused(&env)?;
Expand All @@ -173,6 +183,14 @@ impl InvoiceNftContract {
}

/// Transition invoice to Funded. Called by Financing Pool contract.
///
/// **Parameters:**
/// - `caller` — The investor or financing pool contract address.
/// - `invoice_id` — The ID of the invoice to fund.
///
/// **Returns:** `Ok(())` on success, or an appropriate `KoraError`.
///
/// **Security:** Requires auth from the caller. Validates that the invoice is in `Listed` status.
pub fn set_funded(env: Env, caller: Address, invoice_id: u64) -> Result<(), KoraError> {
caller.require_auth();
Self::require_not_paused(&env)?;
Expand All @@ -185,11 +203,19 @@ impl InvoiceNftContract {
env.storage()
.persistent()
.set(&DataKey::Invoice(invoice_id), &invoice);
events::invoice_funded(&env, invoice_id, &invoice.sme, invoice.amount);
events::invoice_funded(&env, invoice_id, &caller, invoice.amount);
Ok(())
}

/// Mark invoice as Repaid. Called by Financing Pool on full repayment.
///
/// **Parameters:**
/// - `caller` — The financing pool contract address.
/// - `invoice_id` — The ID of the invoice to repay.
///
/// **Returns:** `Ok(())` on success, or an appropriate `KoraError`.
///
/// **Security:** Requires auth from the caller. Validates that the invoice is in `Funded` status.
pub fn set_repaid(env: Env, caller: Address, invoice_id: u64) -> Result<(), KoraError> {
caller.require_auth();
let mut invoice = Self::load_invoice(&env, invoice_id)?;
Expand All @@ -206,6 +232,14 @@ impl InvoiceNftContract {
}

/// Mark invoice as Defaulted. Called by admin after due date passes.
///
/// **Parameters:**
/// - `caller` — The admin address.
/// - `invoice_id` — The ID of the invoice to mark as defaulted.
///
/// **Returns:** `Ok(())` on success, or an appropriate `KoraError`.
///
/// **Security:** Requires admin auth. Validates that the invoice is `Funded` and the due date has passed.
pub fn set_defaulted(env: Env, caller: Address, invoice_id: u64) -> Result<(), KoraError> {
caller.require_auth();
Self::require_admin(&env, &caller)?;
Expand Down Expand Up @@ -250,7 +284,11 @@ impl InvoiceNftContract {

/// Returns the number of invoices minted (next_id - 1).
pub fn invoice_count(env: Env) -> u64 {
env.storage().instance().get::<_, u64>(&DataKey::NextId).unwrap_or(1).saturating_sub(1)
env.storage()
.instance()
.get::<_, u64>(&DataKey::NextId)
.unwrap_or(1)
.saturating_sub(1)
}

// ── Helpers ──────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -376,7 +414,10 @@ mod tests {
let (env, _admin, client) = setup();
let sme = Address::generate(&env);
let debtor_hash = Bytes::from_slice(&env, &[1u8; 32]);
let ipfs_cid = String::from_str(&env, "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi");
let ipfs_cid = String::from_str(
&env,
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
);
let due_date = env.ledger().timestamp() + 86_400;

let result = client.try_mint_invoice(
Expand All @@ -396,7 +437,10 @@ mod tests {
let (env, _admin, client) = setup();
let sme = Address::generate(&env);
let debtor_hash = Bytes::from_slice(&env, &[1u8; 32]);
let ipfs_cid = String::from_str(&env, "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi");
let ipfs_cid = String::from_str(
&env,
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
);
let due_date = env.ledger().timestamp() + 86_400;

let result = client.try_mint_invoice(
Expand All @@ -416,7 +460,10 @@ mod tests {
let (env, _admin, client) = setup();
let sme = Address::generate(&env);
let debtor_hash = Bytes::from_slice(&env, &[1u8; 32]);
let ipfs_cid = String::from_str(&env, "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi");
let ipfs_cid = String::from_str(
&env,
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
);
let due_date = env.ledger().timestamp() - 1;

let result = client.try_mint_invoice(
Expand All @@ -436,7 +483,10 @@ mod tests {
let (env, _admin, client) = setup();
let sme = Address::generate(&env);
let debtor_hash = Bytes::from_slice(&env, &[1u8; 32]);
let ipfs_cid = String::from_str(&env, "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi");
let ipfs_cid = String::from_str(
&env,
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
);
let due_date = env.ledger().timestamp() + 86_400;

let result = client.try_mint_invoice(
Expand All @@ -456,7 +506,10 @@ mod tests {
let (env, _admin, client) = setup();
let sme = Address::generate(&env);
let debtor_hash = Bytes::from_slice(&env, &[]);
let ipfs_cid = String::from_str(&env, "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi");
let ipfs_cid = String::from_str(
&env,
"bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi",
);
let due_date = env.ledger().timestamp() + 86_400;

let result = client.try_mint_invoice(
Expand Down
38 changes: 3 additions & 35 deletions contracts/shared/src/reentrancy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,38 +67,6 @@ pub fn is_locked(env: &Env) -> bool {
env.storage().instance().has(&GuardKey::Lock)
}

// ── RAII guard ────────────────────────────────────────────────────────────────

/// RAII reentrancy guard. Acquires the lock on construction and releases it
/// automatically when dropped, ensuring the lock is always released even on
/// early returns or panics.
///
/// # Usage
/// ```ignore
/// pub fn my_fn(env: Env) -> Result<(), KoraError> {
/// let _guard = ReentrancyGuard::new(&env)?;
/// // ... protected logic ...
/// Ok(())
/// } // lock released here automatically
/// ```
pub struct ReentrancyGuard<'a> {
env: &'a Env,
}

impl<'a> ReentrancyGuard<'a> {
/// Acquire the lock. Returns `KoraError::Reentrancy` if already locked.
pub fn new(env: &'a Env) -> Result<Self, KoraError> {
acquire_guard(env)?;
Ok(Self { env })
}
}

impl<'a> Drop for ReentrancyGuard<'a> {
fn drop(&mut self) {
release_guard(self.env);
}
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
Expand All @@ -118,7 +86,7 @@ mod tests {
let env = Env::default();
acquire_guard(&env).unwrap();
let result = acquire_guard(&env);
assert_eq!(result.unwrap_err(), KoraError::Reentrancy);
assert_eq!(result.err().unwrap(), KoraError::Reentrancy);
release_guard(&env);
}

Expand All @@ -145,7 +113,7 @@ mod tests {
fn test_double_acquire_returns_reentrancy_error() {
let env = Env::default();
acquire_guard(&env).unwrap();
let err = acquire_guard(&env).unwrap_err();
let err = acquire_guard(&env).err().unwrap();
assert_eq!(err, KoraError::Reentrancy);
release_guard(&env);
}
Expand Down Expand Up @@ -220,7 +188,7 @@ mod tests {
let _guard = ReentrancyGuard::new(&env).unwrap();
// Second guard must fail while first is held
let result = ReentrancyGuard::new(&env);
assert_eq!(result.unwrap_err(), KoraError::Reentrancy);
assert_eq!(result.err().unwrap(), KoraError::Reentrancy);
// First guard drops here, lock released
}
}
64 changes: 32 additions & 32 deletions contracts/treasury/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use kora_shared::{
errors::KoraError,
events,
validation::require_valid_fee_bps,
reentrancy::ReentrancyGuard,
validation::{require_non_zero_amount, require_valid_fee_bps},
};
use soroban_sdk::{contract, contractimpl, contracttype, token, Address, Env};

Expand All @@ -23,6 +24,8 @@ pub enum DataKey {
Collected(Address),
/// Reentrancy guard for withdrawal functions.
WithdrawalLock,
/// Whitelisted tokens.
WhitelistedToken(Address),
}

// ── Contract ──────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -132,25 +135,19 @@ impl TreasuryContract {
return Err(KoraError::InvalidAmount);
}

Self::acquire_lock(&env)?;
let _guard = ReentrancyGuard::new(&env)?;

let token_client = token::Client::new(&env, &token);
let balance = token_client.balance(&env.current_contract_address());

if balance < amount {
// Release lock before returning error — must not leave lock stuck
Self::release_lock(&env);
return Err(KoraError::InsufficientPoolBalance);
}

// ── Effects ───────────────────────────────────────────────────────────
// Deduct from informational accounting if tracked
let collected_key = DataKey::Collected(token.clone());
if let Some(collected) = env
.storage()
.persistent()
.get::<_, i128>(&collected_key)
{
if let Some(collected) = env.storage().persistent().get::<_, i128>(&collected_key) {
// Saturating sub: accounting is informational, don't revert on mismatch
let new_collected = collected.saturating_sub(amount);
env.storage()
Expand Down Expand Up @@ -186,20 +183,11 @@ impl TreasuryContract {
let balance = token_client.balance(&env.current_contract_address());

if balance > 0 {
// ── Interactions ──────────────────────────────────────────────────────
token_client.transfer(&env.current_contract_address(), &recipient, &balance);
}

// Always release lock regardless of whether a transfer occurred
Self::release_lock(&env);

if balance > 0 {
events::emergency_withdrawn(&env, &admin, &token, balance);
}

// ── Interactions ──────────────────────────────────────────────────────
token_client.transfer(&env.current_contract_address(), &recipient, &balance);

events::emergency_withdrawn(&env, &admin, &token, balance);
Ok(())
}

Expand Down Expand Up @@ -249,8 +237,12 @@ impl TreasuryContract {
Ok(())
}

fn release_lock(env: &Env) {
env.storage().instance().set(&DataKey::WithdrawalLock, &false);
fn bump_persistent(env: &Env, key: &DataKey) {
env.storage().persistent().extend_ttl(
key,
PERSISTENT_LIFETIME_THRESHOLD,
PERSISTENT_BUMP_AMOUNT,
);
}
}

Expand All @@ -267,7 +259,7 @@ mod tests {
let contract_id = env.register_contract(None, TreasuryContract);
let client = TreasuryContractClient::new(&env, &contract_id);
let admin = Address::generate(&env);
client.initialize(&admin, &50u32).unwrap();
client.initialize(&admin, &50u32);
(env, admin, client)
}

Expand Down Expand Up @@ -310,7 +302,7 @@ mod tests {
#[test]
fn test_set_fee_bps_success() {
let (_env, admin, client) = setup();
client.set_fee_bps(&admin, &100u32).unwrap();
client.set_fee_bps(&admin, &100u32);
assert_eq!(client.get_fee_bps(), 100);
}

Expand All @@ -330,14 +322,14 @@ mod tests {
#[test]
fn test_set_fee_bps_zero_allowed() {
let (_env, admin, client) = setup();
client.set_fee_bps(&admin, &0u32).unwrap();
client.set_fee_bps(&admin, &0u32);
assert_eq!(client.get_fee_bps(), 0);
}

#[test]
fn test_set_fee_bps_max_allowed() {
let (_env, admin, client) = setup();
client.set_fee_bps(&admin, &10_000u32).unwrap();
client.set_fee_bps(&admin, &10_000u32);
assert_eq!(client.get_fee_bps(), 10_000);
}

Expand All @@ -350,11 +342,11 @@ mod tests {
#[test]
fn test_set_fee_bps_multiple_updates() {
let (_env, admin, client) = setup();
client.set_fee_bps(&admin, &100u32).unwrap();
client.set_fee_bps(&admin, &100u32);
assert_eq!(client.get_fee_bps(), 100);
client.set_fee_bps(&admin, &200u32).unwrap();
client.set_fee_bps(&admin, &200u32);
assert_eq!(client.get_fee_bps(), 200);
client.set_fee_bps(&admin, &50u32).unwrap();
client.set_fee_bps(&admin, &50u32);
assert_eq!(client.get_fee_bps(), 50);
}

Expand All @@ -366,23 +358,29 @@ mod tests {
let non_admin = Address::generate(&env);
let token = Address::generate(&env);
let recipient = Address::generate(&env);
assert!(client.try_withdraw(&non_admin, &token, &recipient, &1_000_000i128).is_err());
assert!(client
.try_withdraw(&non_admin, &token, &recipient, &1_000_000i128)
.is_err());
}

#[test]
fn test_withdraw_zero_amount_fails() {
let (env, admin, client) = setup();
let token = Address::generate(&env);
let recipient = Address::generate(&env);
assert!(client.try_withdraw(&admin, &token, &recipient, &0i128).is_err());
assert!(client
.try_withdraw(&admin, &token, &recipient, &0i128)
.is_err());
}

#[test]
fn test_withdraw_with_negative_amount_rejected() {
let (env, admin, client) = setup();
let token = Address::generate(&env);
let recipient = Address::generate(&env);
assert!(client.try_withdraw(&admin, &token, &recipient, &-1_000i128).is_err());
assert!(client
.try_withdraw(&admin, &token, &recipient, &-1_000i128)
.is_err());
}

#[test]
Expand All @@ -391,7 +389,9 @@ mod tests {
let non_admin = Address::generate(&env);
let token = Address::generate(&env);
let recipient = Address::generate(&env);
assert!(client.try_emergency_withdraw(&non_admin, &token, &recipient).is_err());
assert!(client
.try_emergency_withdraw(&non_admin, &token, &recipient)
.is_err());
}

#[test]
Expand Down