From ecacbb5e083dc88d68502deb7238f1b6848c1530 Mon Sep 17 00:00:00 2001 From: GitGuru7 <128375421+GitGuru7@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:37:12 +0530 Subject: [PATCH 1/4] feat: add TokenVault --- contracts/TokenVault/TokenVault.sol | 479 +++++++++++++++++++++ contracts/TokenVault/TokenVaultStorage.sol | 67 +++ 2 files changed, 546 insertions(+) create mode 100644 contracts/TokenVault/TokenVault.sol create mode 100644 contracts/TokenVault/TokenVaultStorage.sol diff --git a/contracts/TokenVault/TokenVault.sol b/contracts/TokenVault/TokenVault.sol new file mode 100644 index 000000000..faa9c0046 --- /dev/null +++ b/contracts/TokenVault/TokenVault.sol @@ -0,0 +1,479 @@ +// SPDX-License-Identifier: BSD-3-Clause + +pragma solidity 0.8.25; + +import { ReentrancyGuard } from "@openzeppelin/contracts/security/ReentrancyGuard.sol"; +import { Pausable } from "@openzeppelin/contracts/security/Pausable.sol"; +import { ensureNonzeroAddress } from "@venusprotocol/solidity-utilities/contracts/validators.sol"; +import { IERC20Upgradeable } from "@openzeppelin/contracts-upgradeable/token/ERC20/IERC20Upgradeable.sol"; +import { IAccessControlManagerV8 } from "@venusprotocol/governance-contracts/contracts/Governance/IAccessControlManagerV8.sol"; +import { Initializable } from "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import { SafeERC20Upgradeable } from "@openzeppelin/contracts-upgradeable/token/ERC20/utils/SafeERC20Upgradeable.sol"; +import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import { TimeManagerV8 } from "@venusprotocol/solidity-utilities/contracts/TimeManagerV8.sol"; +import { TokenVaultStorage } from "./TokenVaultStorage.sol"; + +/** + * @title Token Vault + * @author Venus + * @notice Token vault is a generic vault that can support multiple token. User can lock their supported token in the TokenVault to receive voting rights in Venus governance. + */ +contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, TokenVaultStorage { + /// @notice Event emitted when deposit + event Deposit(address indexed user, address indexed token, uint256 indexed amount); + + /// @notice Event emitted when execute withrawal + event ExecutedWithdrawal(address indexed user, address indexed token, uint256 indexed amount); + + /// @notice Event emitted when request withrawal + event RequestedWithdrawal(address indexed user, address indexed token, uint256 indexed amount); + + /// @notice An event thats emitted when an account changes its delegate + event DelegateChangedV2(address indexed delegator, address indexed fromDelegate, address indexed toDelegate); + + /// @notice An event thats emitted when a delegate account's vote balance changes + event DelegateVotesChangedV2(address indexed delegate, uint256 previousBalance, uint256 newBalance); + + /// @notice Event emitted when tokens are updated + event UpdateTokens(address token, bool isAdded); + + using SafeERC20Upgradeable for IERC20Upgradeable; + + constructor(address _token, bool _timeBased, uint256 _blocksPerYear) TimeManagerV8(_timeBased, _blocksPerYear) { + ensureNonzeroAddress(_token); + tokens[_token] = true; + } + + /** + * @notice Initialize the contract + * @param _accessControlManager Address of access control manager + */ + function initialize(address _accessControlManager) external initializer { + ensureNonzeroAddress(_accessControlManager); + accessControlManager = _accessControlManager; + } + + /** + * @notice Update tokens supported by the vault + * @param _token Address of token + * @param _isAdded Bool value, should be true to add token + * @custom:access Controlled by Access Control Manager + * @custom:event Emit UpdateTokens with address of token and its bool value + */ + function updateTokens(address _token, bool _isAdded) external { + _ensureAllowed("updateTokens(address,bool)"); + ensureNonzeroAddress(address(_token)); + tokens[_token] = _isAdded; + emit UpdateTokens(_token, _isAdded); + } + + /** + * @notice Deposit token to TokenVault + * @param _token Address of token to be deposited + * @param _amount Amount of token to be deposited + * @custom:event Emit Deposit with msg.sender, token and amount + */ + function deposit(address _token, uint256 _amount) external nonReentrant whenNotPaused { + require(tokens[_token], "TokenVault::deposit: token is not registered"); + require(_amount > 0, "TokenVault::deposit: invalid amount"); + UserInfo storage user = userInfos[msg.sender][_token]; + IERC20Upgradeable(_token).safeTransferFrom(msg.sender, address(this), _amount); + userInfos[msg.sender][_token].amount = user.amount + _amount; + _moveDelegates(address(0), delegates[msg.sender], _amount, _token); + emit Deposit(msg.sender, _token, _amount); + } + + /** + * @notice Execute withdrawal of given token + * @param _token Address of token to be withdrawal. It should be a registered token + * @custom:event Emit ExecutedWithdrawal with msg.sender, token and withdrawal amount + */ + function executeWithdrawal(address _token) external nonReentrant whenNotPaused { + require(tokens[_token], "TokenVault::executeWithdrawal: token is not registered"); + UserInfo storage user = userInfos[msg.sender][_token]; + WithdrawalRequest[] storage requests = withdrawalRequests[msg.sender][_token]; + + uint256 withdrawalAmount; + + withdrawalAmount = popEligibleWithdrawalRequests(user, requests); + require(withdrawalAmount > 0, "nothing to withdraw"); + + user.amount = user.amount - withdrawalAmount; + IERC20Upgradeable(_token).safeTransfer(address(msg.sender), withdrawalAmount); + totalPendingWithdrawals[_token] = totalPendingWithdrawals[_token] - withdrawalAmount; + emit ExecutedWithdrawal(msg.sender, _token, withdrawalAmount); + } + + /** + * @notice Request withdrawal to TokenVault for token allocation + * @param _token Address of token to be withdrawal + * @param _amount The amount to withdraw from the vault + * @custom:event Emit RequestedWithdrawal with msg.sender, token and withdrawal amount + */ + function requestWithdrawal(address _token, uint256 _amount) external nonReentrant whenNotPaused { + require(tokens[_token], "TokenVault::requestWithdrawal: token is not registered"); + require(_amount > 0, "TokenVault::requestWithdrawal: requested amount cannot be zero"); + UserInfo storage user = userInfos[_token][msg.sender]; + WithdrawalRequest[] storage requests = withdrawalRequests[_token][msg.sender]; + + require( + user.amount >= user.pendingWithdrawals + _amount, + "TokenVault::requestWithdrawal: requested amount is invalid" + ); + + uint256 lockedUntil = tokenLockPeriod[_token] + block.timestamp; + + pushWithdrawalRequest(user, requests, _amount, lockedUntil); + totalPendingWithdrawals[_token] = totalPendingWithdrawals[_token] + _amount; + + // Update Delegate Amount + _moveDelegates(delegates[msg.sender], address(0), _amount, _token); + + emit RequestedWithdrawal(msg.sender, _token, _amount); + } + + /** + * @notice Get unlocked withdrawal amount + * @param _token Address of token + * @param _user The User Address + * @return withdrawalAmount Amount that the user can withdraw + */ + function getEligibleWithdrawalAmount( + address _token, + address _user + ) external view returns (uint256 withdrawalAmount) { + require(tokens[_token], "TokenVault::getEligibleWithdrawalAmount: token is not registered"); + WithdrawalRequest[] storage requests = withdrawalRequests[_token][_user]; + // Since the requests are sorted by their unlock time, we can take + // the entries from the end of the array and stop at the first + // not-yet-eligible one + for (uint256 i = requests.length; i > 0 && isUnlocked(requests[i - 1]); --i) { + withdrawalAmount = withdrawalAmount + requests[i - 1].amount; + } + return withdrawalAmount; + } + + /** + * @notice Get requested amount + * @param _token Address of token + * @param _user The User Address + * @return Total amount of requested but not yet executed withdrawals (including both executable and locked ones) + */ + function getRequestedAmount(address _token, address _user) external view returns (uint256) { + require(tokens[_token], "TokenVault::getRequestedAmount: token is not reistered"); + UserInfo storage user = userInfos[_token][_user]; + return user.pendingWithdrawals; + } + + /** + * @notice Returns the array of withdrawal requests that have not been executed yet + * @param _token Address of token + * @param _user The User Address + * @return An array of withdrawal requests + */ + function getWithdrawalRequests(address _token, address _user) external view returns (WithdrawalRequest[] memory) { + require(tokens[_token], "TokenVault::getWithdrawalRequests: token is not reistered"); + return withdrawalRequests[_token][_user]; + } + + /** + * @notice Determine the token stake balance for an account + * @param _account The address of the account to check + * @param _blockNumberOrSecond The block number or second to get the vote balance at + * @param _token Address of token + * @return The balance that user staked + */ + function getPriorVotes( + address _account, + uint256 _blockNumberOrSecond, + address _token + ) external view returns (uint256) { + require(_blockNumberOrSecond < getBlockNumberOrTimestamp(), "TokenVault::getPriorVotes: not yet determined"); + + uint32 nCheckpoints = numCheckpoints[_token][_account]; + if (nCheckpoints == 0) { + return 0; + } + + // First check most recent balance + if (checkpoints[_token][_account][nCheckpoints - 1].fromBlockOrSecond <= _blockNumberOrSecond) { + return checkpoints[_token][_account][nCheckpoints - 1].votes; + } + + // Next check implicit zero balance + if (checkpoints[_token][_account][0].fromBlockOrSecond > _blockNumberOrSecond) { + return 0; + } + + uint32 lower = 0; + uint32 upper = nCheckpoints - 1; + while (upper > lower) { + uint32 center = upper - (upper - lower) / 2; // ceil, avoiding overflow + Checkpoint memory cp = checkpoints[_token][_account][center]; + if (cp.fromBlockOrSecond == _blockNumberOrSecond) { + return cp.votes; + } else if (cp.fromBlockOrSecond < _blockNumberOrSecond) { + lower = center; + } else { + upper = center - 1; + } + } + return checkpoints[_token][_account][lower].votes; + } + + /** + * @notice Get user info with reward token address and pid + * @param _token Reward token address + * @param _user User address + * @return amount Deposited amount + * @return pendingWithdrawals Requested but not yet executed withdrawals + */ + function getUserInfo( + address _token, + address _user + ) external view returns (uint256 amount, uint256 pendingWithdrawals) { + require(tokens[_token], "TokenVault::getUserInfo: token is not reistered"); + UserInfo storage user = userInfos[_token][_user]; + amount = user.amount; + pendingWithdrawals = user.pendingWithdrawals; + } + + /** + * @notice Delegate votes from `msg.sender` to `delegatee` + * @param _delegatee The address to delegate votes to + * @param _token Address of token + */ + function delegate(address _delegatee, address _token) external whenNotPaused { + return _delegate(msg.sender, _delegatee, _token); + } + + /** + * @notice Delegates votes from signatory to `delegatee` + * @param _delegatee The address to delegate votes to + * @param _nonce The contract state required to match the signature + * @param _expiry The time at which to expire the signature + * @param v The recovery byte of the signature + * @param r Half of the ECDSA signature pair + * @param s Half of the ECDSA signature pair + */ + function delegateBySig( + address _delegatee, + uint256 _nonce, + uint256 _expiry, + uint8 v, + bytes32 r, + bytes32 s, + address _token + ) external whenNotPaused { + bytes32 domainSeparator = keccak256( + abi.encode(DOMAIN_TYPEHASH, keccak256(bytes("XVSVault")), block.chainid, address(this)) + ); + bytes32 structHash = keccak256(abi.encode(DELEGATION_TYPEHASH, _delegatee, _nonce, _expiry)); + bytes32 digest = keccak256(abi.encodePacked("\x19\x01", domainSeparator, structHash)); + address signatory = ECDSA.recover(digest, v, r, s); + require(_nonce == nonces[signatory]++, "XVSVault::delegateBySig: invalid nonce"); + require(block.timestamp <= _expiry, "XVSVault::delegateBySig: signature expired"); + return _delegate(signatory, _delegatee, _token); + } + + /** + * @notice Set Access Control Manager + * @param _accessControlManager Address of Access Control Manager + */ + function setAccessControlManager(address _accessControlManager) external { + _ensureAllowed("setAccessControlManager(address)"); + ensureNonzeroAddress(_accessControlManager); + accessControlManager = _accessControlManager; + } + + /** + * @notice Gets the current votes balance for `account` + * @param _account The address to get votes balance + * @param _token Address of token + * @return The number of current votes for `account` + */ + function getCurrentVotes(address _account, address _token) external view returns (uint256) { + uint32 nCheckpoints = numCheckpoints[_token][_account]; + return nCheckpoints > 0 ? checkpoints[_token][_account][nCheckpoints - 1].votes : 0; + } + + /** + * @notice Pushes withdrawal request to the requests array and updates + * the pending withdrawals amount. The requests are always sorted + * by unlock time (descending) so that the earliest to execute requests + * are always at the end of the array + * @param _user The user struct storage pointer + * @param _requests The user's requests array storage pointer + * @param _amount The amount being requested + */ + function pushWithdrawalRequest( + UserInfo storage _user, + WithdrawalRequest[] storage _requests, + uint256 _amount, + uint256 _lockedUntil + ) internal { + uint256 i = _requests.length; + _requests.push(WithdrawalRequest(0, 0)); + // Keep it sorted so that the first to get unlocked request is always at the end + for (; i > 0 && _requests[i - 1].lockedUntil <= _lockedUntil; ) { + _requests[i] = _requests[i - 1]; + unchecked { + --i; + } + } + _requests[i] = WithdrawalRequest(_amount, uint128(_lockedUntil)); + _user.pendingWithdrawals = _user.pendingWithdrawals + _amount; + } + + /** + * @notice Pops the requests with unlock time < now from the requests + * array and deducts the computed amount from the user's pending + * withdrawals counter. Assumes that the requests array is sorted + * by unclock time (descending). + * @dev This function **removes** the eligible requests from the requests + * array. If this function is called, the withdrawal should actually + * happen (or the transaction should be reverted). + * @param _user The user struct storage pointer + * @param _requests The user's requests array storage pointer + * @return withdrawalAmount The amount eligible for withdrawal + */ + function popEligibleWithdrawalRequests( + UserInfo storage _user, + WithdrawalRequest[] storage _requests + ) internal returns (uint256 withdrawalAmount) { + // Since the requests are sorted by their unlock time, we can just + // pop them from the array and stop at the first not-yet-eligible one + for (uint256 i = _requests.length; i > 0 && isUnlocked(_requests[i - 1]); ) { + withdrawalAmount = withdrawalAmount + (_requests[i - 1].amount); + + _requests.pop(); + unchecked { + --i; + } + } + _user.pendingWithdrawals = _user.pendingWithdrawals - withdrawalAmount; + return withdrawalAmount; + } + + /** + * @dev Delegate user votes + * @param _delegator Address of delegator + * @param _delegatee Address of delegatee + * @param _token Address of token + * @custom:event Emit DelegateChangedV2 with current delegate, new delegatee and token + */ + function _delegate(address _delegator, address _delegatee, address _token) internal { + address currentDelegate = delegates[_delegator]; + uint256 delegatorBalance = getStakeAmount(_delegator, _token); + delegates[_delegator] = _delegatee; + + emit DelegateChangedV2(_delegator, currentDelegate, _delegatee); + + _moveDelegates(currentDelegate, _delegatee, delegatorBalance, _token); + } + + /** + * @dev Internal function to moves voting power from one representative to another based on the given parameters + * @param _srcRep The address of the current representative whose voting power is being transferred + * @param _dstRep The address of the new representative who will receive the transferred voting power + * @param _amount The amount of voting power to be transferred + * @param _token The address of the token associated with the voting power + */ + function _moveDelegates(address _srcRep, address _dstRep, uint256 _amount, address _token) internal { + if (_srcRep != _dstRep && _amount > 0) { + if (_srcRep != address(0)) { + uint32 srcRepNum = numCheckpoints[_token][_srcRep]; + uint256 srcRepOld = srcRepNum > 0 ? checkpoints[_token][_srcRep][srcRepNum - 1].votes : 0; + uint256 srcRepNew = srcRepOld - _amount; + _writeCheckpoint(_srcRep, srcRepNum, srcRepOld, srcRepNew, _token); + } + + if (_dstRep != address(0)) { + uint32 dstRepNum = numCheckpoints[_token][_dstRep]; + uint256 dstRepOld = dstRepNum > 0 ? checkpoints[_token][_dstRep][dstRepNum - 1].votes : 0; + uint256 dstRepNew = dstRepOld + _amount; + _writeCheckpoint(_dstRep, dstRepNum, dstRepOld, dstRepNew, _token); + } + } + } + + /** + * @dev Updates the voting checkpoint for a delegatee with the given parameters + * If there are existing checkpoints for the delegatee at the current block number or timestamp, + * the function updates the votes in the most recent checkpoint + * Otherwise, it creates a new checkpoint with the current block number or timestamp and the new votes + * @param delegatee The address of the delegatee whose voting checkpoint is being updated + * @param nCheckpoints The number of existing voting checkpoints for the delegatee + * @param oldVotes The previous number of votes held by the delegatee + * @param newVotes The new number of votes to be assigned to the delegatee + * @param _token The address of the token associated with the voting power + * @custom:event Emits a DelegateVotesChangedV2 event to signal the change in voting power for the delegatee + */ + function _writeCheckpoint( + address delegatee, + uint32 nCheckpoints, + uint256 oldVotes, + uint256 newVotes, + address _token + ) internal { + uint32 blockNumberOrSecond = uint32(getBlockNumberOrTimestamp()); + + if ( + nCheckpoints > 0 && + checkpoints[_token][delegatee][nCheckpoints - 1].fromBlockOrSecond == blockNumberOrSecond + ) { + checkpoints[_token][delegatee][nCheckpoints - 1].votes = newVotes; + } else { + checkpoints[_token][delegatee][nCheckpoints] = Checkpoint(blockNumberOrSecond, newVotes); + numCheckpoints[_token][delegatee] = nCheckpoints + 1; + } + + emit DelegateVotesChangedV2(delegatee, oldVotes, newVotes); + } + + /** + * @dev Returns before and after upgrade pending withdrawal amount + * @param _requests The user's requests array storage pointer + * @return withdrawalAmount The amount eligible for withdrawal + */ + function getRequestedWithdrawalAmount( + WithdrawalRequest[] storage _requests + ) internal view returns (uint256 withdrawalAmount) { + for (uint256 i = _requests.length; i > 0; --i) { + withdrawalAmount = withdrawalAmount + (_requests[i - 1].amount); + } + return withdrawalAmount; + } + + /** + * @notice Get the XVS stake balance of an account (excluding the pending withdrawals) + * @param _account The address of the account to check + * @param _token Address of token + * @return The balance that user staked + */ + function getStakeAmount(address _account, address _token) internal view returns (uint256) { + require(tokens[_token], "TokenVault::getStakeAmount: token is not reistered"); + UserInfo storage user = userInfos[_token][_account]; + return user.amount - (user.pendingWithdrawals); + } + + /** + * @dev Ensure that the caller has permission to execute a specific function + * @param functionSig_ Function signature to be checked for permission + */ + function _ensureAllowed(string memory functionSig_) internal view { + require( + IAccessControlManagerV8(accessControlManager).isAllowedToCall(msg.sender, functionSig_), + "access denied" + ); + } + + /** + * @dev Checks if the request is eligible for withdrawal. + * @param _request The request struct storage pointer + * @return True if the request is eligible for withdrawal, false otherwise + */ + function isUnlocked(WithdrawalRequest storage _request) private view returns (bool) { + return _request.lockedUntil <= block.timestamp; + } +} diff --git a/contracts/TokenVault/TokenVaultStorage.sol b/contracts/TokenVault/TokenVaultStorage.sol new file mode 100644 index 000000000..578909dc6 --- /dev/null +++ b/contracts/TokenVault/TokenVaultStorage.sol @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: BSD-3-Clause + +pragma solidity ^0.8.25; + +contract TokenVaultStorage { + /// @notice Info of each user. + struct UserInfo { + uint256 amount; + uint256 pendingWithdrawals; + } + // Infomation about a withdrawal request + struct WithdrawalRequest { + uint256 amount; + uint128 lockedUntil; + } + + /// @notice A checkpoint for marking number of votes from a given block or second + struct Checkpoint { + uint32 fromBlockOrSecond; + uint256 votes; + } + + // Access Control Manager + address public accessControlManager; + + /// @notice The EIP-712 typehash for the contract's domain + bytes32 public constant DOMAIN_TYPEHASH = + keccak256("EIP712Domain(string name,uint256 chainId,address verifyingContract)"); + + /// @notice The EIP-712 typehash for the delegation struct used by the contract + bytes32 public constant DELEGATION_TYPEHASH = + keccak256("Delegation(address delegatee,uint256 nonce,uint256 expiry)"); + + // @notice ERC20 tokens along with a bool value to indicate support of each token + mapping(address => bool) public tokens; + + /// @notice A record of each accounts delegate + mapping(address => address) public delegates; + + /// @notice A record of votes checkpoints for each account, by index for each token + mapping(address => mapping(address => mapping(uint32 => Checkpoint))) public checkpoints; + + /// @notice The number of checkpoints for each account for each token + mapping(address => mapping(address => uint32)) public numCheckpoints; + + /// @notice Tracks pending withdrawals for all users for a particular token + mapping(address => uint256) public totalPendingWithdrawals; + + /// @notice Indicate lock period of each token + mapping(address => uint128) public tokenLockPeriod; + + // Info of requested but not yet executed withdrawals for each token + mapping(address => mapping(address => WithdrawalRequest[])) internal withdrawalRequests; + + // Info of each user that stakes tokens, for each token + mapping(address => mapping(address => UserInfo)) internal userInfos; + + /// @notice A record of states for signing / validating signatures + mapping(address => uint) public nonces; + + /** + * @dev This empty reserved space is put in place to allow future versions to add new + * variables without shifting down storage in the inheritance chain. + * See https://docs.openzeppelin.com/contracts/4.x/upgradeable#storage_gaps + */ + uint256[47] private __gap; +} From 56e1a8498122752e852d18d4c901fc8e556f0927 Mon Sep 17 00:00:00 2001 From: GitGuru7 <128375421+GitGuru7@users.noreply.github.com> Date: Wed, 1 May 2024 16:23:43 +0530 Subject: [PATCH 2/4] refactor: add more functionality in TokenVault and tests --- contracts/TokenVault/TokenVault.sol | 63 +++++++-- tests/hardhat/TokenVault/tokenVaultTest.ts | 154 +++++++++++++++++++++ 2 files changed, 204 insertions(+), 13 deletions(-) create mode 100644 tests/hardhat/TokenVault/tokenVaultTest.ts diff --git a/contracts/TokenVault/TokenVault.sol b/contracts/TokenVault/TokenVault.sol index faa9c0046..b78607596 100644 --- a/contracts/TokenVault/TokenVault.sol +++ b/contracts/TokenVault/TokenVault.sol @@ -37,19 +37,23 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, /// @notice Event emitted when tokens are updated event UpdateTokens(address token, bool isAdded); + /// @notice Event Emitted when lock period of token is set + event SetLockPeriod(address token, uint128 lockPeriod); + using SafeERC20Upgradeable for IERC20Upgradeable; - constructor(address _token, bool _timeBased, uint256 _blocksPerYear) TimeManagerV8(_timeBased, _blocksPerYear) { - ensureNonzeroAddress(_token); - tokens[_token] = true; - } + /// @custom:oz-upgrades-unsafe-allow constructor + constructor(bool _timeBased, uint256 _blocksPerYear) TimeManagerV8(_timeBased, _blocksPerYear) {} /** * @notice Initialize the contract * @param _accessControlManager Address of access control manager + * @param _token Address of token */ - function initialize(address _accessControlManager) external initializer { + function initialize(address _accessControlManager, address _token) external initializer { ensureNonzeroAddress(_accessControlManager); + ensureNonzeroAddress(_token); + tokens[_token] = true; accessControlManager = _accessControlManager; } @@ -67,6 +71,20 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, emit UpdateTokens(_token, _isAdded); } + /** + * @notice Sets Lock period of particular token + * @param _token Address of token + * @param _lockPeriod Minimum time between withdrawal request and its execution + * @custom:event Emit SetLockPeriod with token and its lock period + * @custom:access Controlled by Access Control Manager + */ + function setLockPeriod(address _token, uint128 _lockPeriod) external { + _ensureAllowed("setLockPeriod(address,uint128)"); + require(tokens[_token], "TokenVault::setLockPeriod: token is not registered"); + tokenLockPeriod[_token] = _lockPeriod; + emit SetLockPeriod(_token, _lockPeriod); + } + /** * @notice Deposit token to TokenVault * @param _token Address of token to be deposited @@ -76,9 +94,9 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, function deposit(address _token, uint256 _amount) external nonReentrant whenNotPaused { require(tokens[_token], "TokenVault::deposit: token is not registered"); require(_amount > 0, "TokenVault::deposit: invalid amount"); - UserInfo storage user = userInfos[msg.sender][_token]; + UserInfo storage user = userInfos[_token][msg.sender]; IERC20Upgradeable(_token).safeTransferFrom(msg.sender, address(this), _amount); - userInfos[msg.sender][_token].amount = user.amount + _amount; + userInfos[_token][msg.sender].amount = user.amount + _amount; _moveDelegates(address(0), delegates[msg.sender], _amount, _token); emit Deposit(msg.sender, _token, _amount); } @@ -90,8 +108,8 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, */ function executeWithdrawal(address _token) external nonReentrant whenNotPaused { require(tokens[_token], "TokenVault::executeWithdrawal: token is not registered"); - UserInfo storage user = userInfos[msg.sender][_token]; - WithdrawalRequest[] storage requests = withdrawalRequests[msg.sender][_token]; + UserInfo storage user = userInfos[_token][msg.sender]; + WithdrawalRequest[] storage requests = withdrawalRequests[_token][msg.sender]; uint256 withdrawalAmount; @@ -104,6 +122,24 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, emit ExecutedWithdrawal(msg.sender, _token, withdrawalAmount); } + /** + * @notice Pause the vault + * @custom:access Controlled by Access Controlled Manager + */ + function pause() external { + _ensureAllowed("pause()"); + _pause(); + } + + /** + * @notice Unpause the vault + * @custom:access Controlled by Access Controlled Manager + */ + function unpause() external { + _ensureAllowed("unpause()"); + _unpause(); + } + /** * @notice Request withdrawal to TokenVault for token allocation * @param _token Address of token to be withdrawal @@ -160,7 +196,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @return Total amount of requested but not yet executed withdrawals (including both executable and locked ones) */ function getRequestedAmount(address _token, address _user) external view returns (uint256) { - require(tokens[_token], "TokenVault::getRequestedAmount: token is not reistered"); + require(tokens[_token], "TokenVault::getRequestedAmount: token is not registered"); UserInfo storage user = userInfos[_token][_user]; return user.pendingWithdrawals; } @@ -172,7 +208,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @return An array of withdrawal requests */ function getWithdrawalRequests(address _token, address _user) external view returns (WithdrawalRequest[] memory) { - require(tokens[_token], "TokenVault::getWithdrawalRequests: token is not reistered"); + require(tokens[_token], "TokenVault::getWithdrawalRequests: token is not registered"); return withdrawalRequests[_token][_user]; } @@ -232,7 +268,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, address _token, address _user ) external view returns (uint256 amount, uint256 pendingWithdrawals) { - require(tokens[_token], "TokenVault::getUserInfo: token is not reistered"); + require(tokens[_token], "TokenVault::getUserInfo: token is not registered"); UserInfo storage user = userInfos[_token][_user]; amount = user.amount; pendingWithdrawals = user.pendingWithdrawals; @@ -244,6 +280,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @param _token Address of token */ function delegate(address _delegatee, address _token) external whenNotPaused { + require(tokens[_token], "TokenVault::delegate: token is not registered"); return _delegate(msg.sender, _delegatee, _token); } @@ -452,7 +489,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @return The balance that user staked */ function getStakeAmount(address _account, address _token) internal view returns (uint256) { - require(tokens[_token], "TokenVault::getStakeAmount: token is not reistered"); + require(tokens[_token], "TokenVault::getStakeAmount: token is not registered"); UserInfo storage user = userInfos[_token][_account]; return user.amount - (user.pendingWithdrawals); } diff --git a/tests/hardhat/TokenVault/tokenVaultTest.ts b/tests/hardhat/TokenVault/tokenVaultTest.ts new file mode 100644 index 000000000..67ab4a45a --- /dev/null +++ b/tests/hardhat/TokenVault/tokenVaultTest.ts @@ -0,0 +1,154 @@ +import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; +import { loadFixture, mine } from "@nomicfoundation/hardhat-network-helpers"; +import { expect } from "chai"; +import { BigNumber } from "ethers"; +import { parseUnits } from "ethers/lib/utils"; +import { ethers, upgrades } from "hardhat"; + +import { MockToken, TokenVault } from "../../../typechain"; + +describe("TokenVault", async () => { + let deployer: SignerWithAddress; + let signer1: SignerWithAddress; + let tokenVault: TokenVault; + let accessControlManager; + let token: MockToken; + let amount: BigNumber; + + const tokenVaultFixture = async () => { + [deployer, signer1] = await ethers.getSigners(); + amount = parseUnits("10", 18); + const accessControlManagerFactory = await ethers.getContractFactory("AccessControlManager"); + accessControlManager = await accessControlManagerFactory.deploy(); + const tokenFactory = await ethers.getContractFactory("MockToken"); + token = await tokenFactory.deploy("HARD_Token", "HARD", 18); + const tokenVaultFactory = await ethers.getContractFactory("TokenVault"); + tokenVault = await upgrades.deployProxy(tokenVaultFactory, [accessControlManager.address, token.address], { + constructorArgs: [false, 10512000], + initializer: "initialize", + unsafeAllow: ["constructor"], + }); + + let tx = await accessControlManager.giveCallPermission( + tokenVault.address, + "updateTokens(address,bool)", + deployer.address, + ); + await tx.wait(); + tx = await accessControlManager.giveCallPermission( + tokenVault.address, + "setLockPeriod(address,uint128)", + deployer.address, + ); + await tx.wait(); + + tx = await accessControlManager.giveCallPermission(tokenVault.address, "pause()", deployer.address); + await tx.wait(); + + tx = await accessControlManager.giveCallPermission(tokenVault.address, "unpause()", deployer.address); + await tx.wait(); + + await tokenVault.setLockPeriod(token.address, 300); + await token.faucet(parseUnits("100", 18)); + }; + + beforeEach("Configure Vault", async () => { + await loadFixture(tokenVaultFixture); + }); + + describe("Deposit", async () => { + it("User can deposit registered token", async () => { + await token.approve(tokenVault.address, amount); + await expect(tokenVault.deposit(token.address, amount)).to.emit(tokenVault, "Deposit"); + expect(await token.balanceOf(tokenVault.address)).equals(amount); + }); + it("Reverts if token is not registered or zero amount is given ", async () => { + const tokenFactory = await ethers.getContractFactory("MockToken"); + const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); + await expect(tokenVault.deposit(dummyToken.address, amount)).to.be.revertedWith( + "TokenVault::deposit: token is not registered", + ); + await expect(tokenVault.deposit(token.address, 0)).to.be.revertedWith("TokenVault::deposit: invalid amount"); + }); + it("Reverts if vault is paused", async () => { + await tokenVault.pause(); + await expect(tokenVault.deposit(token.address, amount)).to.be.revertedWith("Pausable: paused"); + }); + }); + + describe("Delegate", async () => { + it("Reverts if vault is paused", async () => { + await tokenVault.pause(); + await expect(tokenVault.delegate(signer1.address, token.address)).to.be.revertedWith("Pausable: paused"); + }); + it("Reverts if token is not registered", async () => { + const tokenFactory = await ethers.getContractFactory("MockToken"); + const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); + await expect(tokenVault.delegate(signer1.address, dummyToken.address)).to.be.revertedWith( + "TokenVault::delegate: token is not registered", + ); + }); + it("Delegate successfully", async () => { + expect(await tokenVault.delegates(deployer.address)).to.equals(ethers.constants.AddressZero); + await expect(tokenVault.delegate(signer1.address, token.address)).to.emit(tokenVault, "DelegateChangedV2"); + const amount = parseUnits("10", 18); + await token.approve(tokenVault.address, amount); + await expect(tokenVault.deposit(token.address, amount)).to.emit(tokenVault, "Deposit"); + expect(await tokenVault.numCheckpoints(token.address, signer1.address)).equals(1); + expect((await tokenVault.checkpoints(token.address, signer1.address, 0))[1]).equals(amount); + let latestBlock = (await ethers.provider.getBlock("latest")).number; + await mine(); + expect(await tokenVault.getPriorVotes(signer1.address, latestBlock, token.address)).equals(amount); + expect(await tokenVault.getPriorVotes(deployer.address, latestBlock, token.address)).equals(0); + await token.approve(tokenVault.address, amount); + + // Deposit again + await expect(tokenVault.deposit(token.address, amount)).to.emit(tokenVault, "Deposit"); + expect(await tokenVault.numCheckpoints(token.address, signer1.address)).equals(2); + expect((await tokenVault.checkpoints(token.address, signer1.address, 1))[1]).equals(amount.mul(2)); + latestBlock = (await ethers.provider.getBlock("latest")).number; + await mine(); + expect(await tokenVault.getPriorVotes(signer1.address, latestBlock, token.address)).equals(amount.mul(2)); + expect(await tokenVault.getPriorVotes(deployer.address, latestBlock, token.address)).equals(0); + }); + }); + describe("Withdraw", async () => { + it("Withdraw tokens", async () => { + const amount = parseUnits("10", 18); + await token.approve(tokenVault.address, amount); + await expect(tokenVault.deposit(token.address, amount)).to.emit(tokenVault, "Deposit"); + expect(await token.balanceOf(deployer.address)).equals(parseUnits("90", 18)); + await expect(tokenVault.requestWithdrawal(token.address, amount)).to.emit(tokenVault, "RequestedWithdrawal"); + await expect(tokenVault.executeWithdrawal(token.address)).to.be.revertedWith("nothing to withdraw"); + await mine(300); + await expect(tokenVault.executeWithdrawal(token.address)).to.emit(tokenVault, "ExecutedWithdrawal"); + expect(await token.balanceOf(deployer.address)).equals(parseUnits("100", 18)); + expect(await token.balanceOf(tokenVault.address)).equals(0); + }); + it("Reverts if vault is paused", async () => { + await tokenVault.pause(); + await expect(tokenVault.requestWithdrawal(token.address, amount)).to.be.revertedWith("Pausable: paused"); + await expect(tokenVault.executeWithdrawal(token.address)).to.be.revertedWith("Pausable: paused"); + }); + it("Reverts if token is not registered", async () => { + const tokenFactory = await ethers.getContractFactory("MockToken"); + const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); + await expect(tokenVault.requestWithdrawal(dummyToken.address, amount)).to.be.revertedWith( + "TokenVault::requestWithdrawal: token is not registered", + ); + await expect(tokenVault.executeWithdrawal(dummyToken.address)).to.be.revertedWith( + "TokenVault::executeWithdrawal: token is not registered", + ); + }); + it("Reverts if zero amount is passed for withdrawal", async () => { + await expect(tokenVault.requestWithdrawal(token.address, 0)).to.be.revertedWith( + "TokenVault::requestWithdrawal: requested amount cannot be zero", + ); + }); + it("User cannot withdrawal more than deposit", async () => { + await expect(tokenVault.requestWithdrawal(token.address, amount)).to.be.revertedWith( + "TokenVault::requestWithdrawal: requested amount is invalid", + ); + }); + }); +}); From b3deb988270ce41c0aa42269bd7d2f65c7143333 Mon Sep 17 00:00:00 2001 From: GitGuru7 <128375421+GitGuru7@users.noreply.github.com> Date: Thu, 2 May 2024 12:30:04 +0530 Subject: [PATCH 3/4] refactor: use custom error --- contracts/TokenVault/TokenVault.sol | 73 +++++++++++++--------- contracts/TokenVault/TokenVaultStorage.sol | 9 +++ tests/hardhat/TokenVault/tokenVaultTest.ts | 37 ++++++----- 3 files changed, 77 insertions(+), 42 deletions(-) diff --git a/contracts/TokenVault/TokenVault.sol b/contracts/TokenVault/TokenVault.sol index b78607596..e70b923a8 100644 --- a/contracts/TokenVault/TokenVault.sol +++ b/contracts/TokenVault/TokenVault.sol @@ -43,7 +43,9 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, using SafeERC20Upgradeable for IERC20Upgradeable; /// @custom:oz-upgrades-unsafe-allow constructor - constructor(bool _timeBased, uint256 _blocksPerYear) TimeManagerV8(_timeBased, _blocksPerYear) {} + constructor(bool _timeBased, uint256 _blocksPerYear) TimeManagerV8(_timeBased, _blocksPerYear) { + _disableInitializers(); + } /** * @notice Initialize the contract @@ -80,7 +82,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, */ function setLockPeriod(address _token, uint128 _lockPeriod) external { _ensureAllowed("setLockPeriod(address,uint128)"); - require(tokens[_token], "TokenVault::setLockPeriod: token is not registered"); + isTokenRegistered(_token); tokenLockPeriod[_token] = _lockPeriod; emit SetLockPeriod(_token, _lockPeriod); } @@ -90,10 +92,13 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @param _token Address of token to be deposited * @param _amount Amount of token to be deposited * @custom:event Emit Deposit with msg.sender, token and amount + * @custom:error ZeroAmountNotAllowed is thrown when zero amount is passed */ function deposit(address _token, uint256 _amount) external nonReentrant whenNotPaused { - require(tokens[_token], "TokenVault::deposit: token is not registered"); - require(_amount > 0, "TokenVault::deposit: invalid amount"); + isTokenRegistered(_token); + if (_amount == 0) { + revert ZeroAmountNotAllowed(); + } UserInfo storage user = userInfos[_token][msg.sender]; IERC20Upgradeable(_token).safeTransferFrom(msg.sender, address(this), _amount); userInfos[_token][msg.sender].amount = user.amount + _amount; @@ -107,13 +112,11 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @custom:event Emit ExecutedWithdrawal with msg.sender, token and withdrawal amount */ function executeWithdrawal(address _token) external nonReentrant whenNotPaused { - require(tokens[_token], "TokenVault::executeWithdrawal: token is not registered"); + isTokenRegistered(_token); UserInfo storage user = userInfos[_token][msg.sender]; WithdrawalRequest[] storage requests = withdrawalRequests[_token][msg.sender]; - uint256 withdrawalAmount; - - withdrawalAmount = popEligibleWithdrawalRequests(user, requests); + uint256 withdrawalAmount = popEligibleWithdrawalRequests(user, requests); require(withdrawalAmount > 0, "nothing to withdraw"); user.amount = user.amount - withdrawalAmount; @@ -145,18 +148,20 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @param _token Address of token to be withdrawal * @param _amount The amount to withdraw from the vault * @custom:event Emit RequestedWithdrawal with msg.sender, token and withdrawal amount + * @custom:error ZeroAmountNotAllowed is thrown when zero amount is passed + * @custom:error InvalidAmount is thrown when given amount and pending withdrawals are greater than deposited amount. */ function requestWithdrawal(address _token, uint256 _amount) external nonReentrant whenNotPaused { - require(tokens[_token], "TokenVault::requestWithdrawal: token is not registered"); - require(_amount > 0, "TokenVault::requestWithdrawal: requested amount cannot be zero"); + isTokenRegistered(_token); + if (_amount == 0) { + revert ZeroAmountNotAllowed(); + } UserInfo storage user = userInfos[_token][msg.sender]; WithdrawalRequest[] storage requests = withdrawalRequests[_token][msg.sender]; - require( - user.amount >= user.pendingWithdrawals + _amount, - "TokenVault::requestWithdrawal: requested amount is invalid" - ); - + if (user.amount < user.pendingWithdrawals + _amount) { + revert InvalidAmount(); + } uint256 lockedUntil = tokenLockPeriod[_token] + block.timestamp; pushWithdrawalRequest(user, requests, _amount, lockedUntil); @@ -178,7 +183,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, address _token, address _user ) external view returns (uint256 withdrawalAmount) { - require(tokens[_token], "TokenVault::getEligibleWithdrawalAmount: token is not registered"); + isTokenRegistered(_token); WithdrawalRequest[] storage requests = withdrawalRequests[_token][_user]; // Since the requests are sorted by their unlock time, we can take // the entries from the end of the array and stop at the first @@ -196,7 +201,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @return Total amount of requested but not yet executed withdrawals (including both executable and locked ones) */ function getRequestedAmount(address _token, address _user) external view returns (uint256) { - require(tokens[_token], "TokenVault::getRequestedAmount: token is not registered"); + isTokenRegistered(_token); UserInfo storage user = userInfos[_token][_user]; return user.pendingWithdrawals; } @@ -208,7 +213,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @return An array of withdrawal requests */ function getWithdrawalRequests(address _token, address _user) external view returns (WithdrawalRequest[] memory) { - require(tokens[_token], "TokenVault::getWithdrawalRequests: token is not registered"); + isTokenRegistered(_token); return withdrawalRequests[_token][_user]; } @@ -224,7 +229,8 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, uint256 _blockNumberOrSecond, address _token ) external view returns (uint256) { - require(_blockNumberOrSecond < getBlockNumberOrTimestamp(), "TokenVault::getPriorVotes: not yet determined"); + require(_blockNumberOrSecond < getBlockNumberOrTimestamp(), "Not yet determined"); + isTokenRegistered(_token); uint32 nCheckpoints = numCheckpoints[_token][_account]; if (nCheckpoints == 0) { @@ -258,8 +264,8 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, } /** - * @notice Get user info with reward token address and pid - * @param _token Reward token address + * @notice Get user info + * @param _token Address of token * @param _user User address * @return amount Deposited amount * @return pendingWithdrawals Requested but not yet executed withdrawals @@ -268,7 +274,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, address _token, address _user ) external view returns (uint256 amount, uint256 pendingWithdrawals) { - require(tokens[_token], "TokenVault::getUserInfo: token is not registered"); + isTokenRegistered(_token); UserInfo storage user = userInfos[_token][_user]; amount = user.amount; pendingWithdrawals = user.pendingWithdrawals; @@ -280,7 +286,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @param _token Address of token */ function delegate(address _delegatee, address _token) external whenNotPaused { - require(tokens[_token], "TokenVault::delegate: token is not registered"); + isTokenRegistered(_token); return _delegate(msg.sender, _delegatee, _token); } @@ -302,14 +308,15 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, bytes32 s, address _token ) external whenNotPaused { + isTokenRegistered(_token); bytes32 domainSeparator = keccak256( - abi.encode(DOMAIN_TYPEHASH, keccak256(bytes("XVSVault")), block.chainid, address(this)) + abi.encode(DOMAIN_TYPEHASH, keccak256(bytes("TokenVault")), block.chainid, address(this)) ); bytes32 structHash = keccak256(abi.encode(DELEGATION_TYPEHASH, _delegatee, _nonce, _expiry)); bytes32 digest = keccak256(abi.encodePacked("\x19\x01", domainSeparator, structHash)); address signatory = ECDSA.recover(digest, v, r, s); - require(_nonce == nonces[signatory]++, "XVSVault::delegateBySig: invalid nonce"); - require(block.timestamp <= _expiry, "XVSVault::delegateBySig: signature expired"); + require(_nonce == nonces[signatory]++, "Invalid nonce"); + require(block.timestamp <= _expiry, "Signature expired"); return _delegate(signatory, _delegatee, _token); } @@ -483,13 +490,12 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, } /** - * @notice Get the XVS stake balance of an account (excluding the pending withdrawals) + * @notice Get the token stake balance of an account (excluding the pending withdrawals) * @param _account The address of the account to check * @param _token Address of token * @return The balance that user staked */ function getStakeAmount(address _account, address _token) internal view returns (uint256) { - require(tokens[_token], "TokenVault::getStakeAmount: token is not registered"); UserInfo storage user = userInfos[_token][_account]; return user.amount - (user.pendingWithdrawals); } @@ -505,6 +511,17 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, ); } + /** + * @dev This function reverts if token is not registered + * @param _token Address of the token + * @custom:error UnregisteredToken is thrown when token is not registered in TokenVault + */ + function isTokenRegistered(address _token) private view { + if (!tokens[_token]) { + revert UnregisteredToken(_token); + } + } + /** * @dev Checks if the request is eligible for withdrawal. * @param _request The request struct storage pointer diff --git a/contracts/TokenVault/TokenVaultStorage.sol b/contracts/TokenVault/TokenVaultStorage.sol index 578909dc6..f9557f59a 100644 --- a/contracts/TokenVault/TokenVaultStorage.sol +++ b/contracts/TokenVault/TokenVaultStorage.sol @@ -58,6 +58,15 @@ contract TokenVaultStorage { /// @notice A record of states for signing / validating signatures mapping(address => uint) public nonces; + /// @notice Thrown when token is not registered + error UnregisteredToken(address token); + + /// @notice Thrown when zero amount is passed + error ZeroAmountNotAllowed(); + + /// @notice Thrown when given amount is invalid + error InvalidAmount(); + /** * @dev This empty reserved space is put in place to allow future versions to add new * variables without shifting down storage in the inheritance chain. diff --git a/tests/hardhat/TokenVault/tokenVaultTest.ts b/tests/hardhat/TokenVault/tokenVaultTest.ts index 67ab4a45a..6d3b3204a 100644 --- a/tests/hardhat/TokenVault/tokenVaultTest.ts +++ b/tests/hardhat/TokenVault/tokenVaultTest.ts @@ -21,7 +21,7 @@ describe("TokenVault", async () => { const accessControlManagerFactory = await ethers.getContractFactory("AccessControlManager"); accessControlManager = await accessControlManagerFactory.deploy(); const tokenFactory = await ethers.getContractFactory("MockToken"); - token = await tokenFactory.deploy("HARD_Token", "HARD", 18); + token = await tokenFactory.deploy("MockToken", "MT", 18); const tokenVaultFactory = await ethers.getContractFactory("TokenVault"); tokenVault = await upgrades.deployProxy(tokenVaultFactory, [accessControlManager.address, token.address], { constructorArgs: [false, 10512000], @@ -65,10 +65,14 @@ describe("TokenVault", async () => { it("Reverts if token is not registered or zero amount is given ", async () => { const tokenFactory = await ethers.getContractFactory("MockToken"); const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); - await expect(tokenVault.deposit(dummyToken.address, amount)).to.be.revertedWith( - "TokenVault::deposit: token is not registered", + await expect(tokenVault.deposit(dummyToken.address, amount)).to.be.revertedWithCustomError( + tokenVault, + "UnregisteredToken", + ); + await expect(tokenVault.deposit(token.address, 0)).to.be.revertedWithCustomError( + tokenVault, + "ZeroAmountNotAllowed", ); - await expect(tokenVault.deposit(token.address, 0)).to.be.revertedWith("TokenVault::deposit: invalid amount"); }); it("Reverts if vault is paused", async () => { await tokenVault.pause(); @@ -84,8 +88,9 @@ describe("TokenVault", async () => { it("Reverts if token is not registered", async () => { const tokenFactory = await ethers.getContractFactory("MockToken"); const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); - await expect(tokenVault.delegate(signer1.address, dummyToken.address)).to.be.revertedWith( - "TokenVault::delegate: token is not registered", + await expect(tokenVault.delegate(signer1.address, dummyToken.address)).to.be.revertedWithCustomError( + tokenVault, + "UnregisteredToken", ); }); it("Delegate successfully", async () => { @@ -133,21 +138,25 @@ describe("TokenVault", async () => { it("Reverts if token is not registered", async () => { const tokenFactory = await ethers.getContractFactory("MockToken"); const dummyToken = await tokenFactory.deploy("DUMMY_Token", "DUMMY", 18); - await expect(tokenVault.requestWithdrawal(dummyToken.address, amount)).to.be.revertedWith( - "TokenVault::requestWithdrawal: token is not registered", + await expect(tokenVault.requestWithdrawal(dummyToken.address, amount)).to.be.revertedWithCustomError( + tokenVault, + "UnregisteredToken", ); - await expect(tokenVault.executeWithdrawal(dummyToken.address)).to.be.revertedWith( - "TokenVault::executeWithdrawal: token is not registered", + await expect(tokenVault.executeWithdrawal(dummyToken.address)).to.be.revertedWithCustomError( + tokenVault, + "UnregisteredToken", ); }); it("Reverts if zero amount is passed for withdrawal", async () => { - await expect(tokenVault.requestWithdrawal(token.address, 0)).to.be.revertedWith( - "TokenVault::requestWithdrawal: requested amount cannot be zero", + await expect(tokenVault.requestWithdrawal(token.address, 0)).to.be.revertedWithCustomError( + tokenVault, + "ZeroAmountNotAllowed", ); }); it("User cannot withdrawal more than deposit", async () => { - await expect(tokenVault.requestWithdrawal(token.address, amount)).to.be.revertedWith( - "TokenVault::requestWithdrawal: requested amount is invalid", + await expect(tokenVault.requestWithdrawal(token.address, amount)).to.be.revertedWithCustomError( + tokenVault, + "InvalidAmount", ); }); }); From 22a3474cb3394d12497e33b369349664fe62acfb Mon Sep 17 00:00:00 2001 From: GitGuru7 <128375421+GitGuru7@users.noreply.github.com> Date: Thu, 2 May 2024 13:27:09 +0530 Subject: [PATCH 4/4] refactor: inherit AccessControlledV8 & ReentrancyGuardUpgradeable & PausableUpgradeable --- contracts/TokenVault/TokenVault.sol | 46 +++++++--------------- contracts/TokenVault/TokenVaultStorage.sol | 5 +-- tests/hardhat/TokenVault/tokenVaultTest.ts | 28 +++---------- 3 files changed, 21 insertions(+), 58 deletions(-) diff --git a/contracts/TokenVault/TokenVault.sol b/contracts/TokenVault/TokenVault.sol index e70b923a8..19704bd80 100644 --- a/contracts/TokenVault/TokenVault.sol +++ b/contracts/TokenVault/TokenVault.sol @@ -2,12 +2,11 @@ pragma solidity 0.8.25; -import { ReentrancyGuard } from "@openzeppelin/contracts/security/ReentrancyGuard.sol"; -import { Pausable } from "@openzeppelin/contracts/security/Pausable.sol"; +import { ReentrancyGuardUpgradeable } from "@openzeppelin/contracts-upgradeable/security/ReentrancyGuardUpgradeable.sol"; +import { PausableUpgradeable } from "@openzeppelin/contracts-upgradeable/security/PausableUpgradeable.sol"; import { ensureNonzeroAddress } from "@venusprotocol/solidity-utilities/contracts/validators.sol"; import { IERC20Upgradeable } from "@openzeppelin/contracts-upgradeable/token/ERC20/IERC20Upgradeable.sol"; -import { IAccessControlManagerV8 } from "@venusprotocol/governance-contracts/contracts/Governance/IAccessControlManagerV8.sol"; -import { Initializable } from "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import { AccessControlledV8 } from "@venusprotocol/governance-contracts/contracts/Governance/AccessControlledV8.sol"; import { SafeERC20Upgradeable } from "@openzeppelin/contracts-upgradeable/token/ERC20/utils/SafeERC20Upgradeable.sol"; import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import { TimeManagerV8 } from "@venusprotocol/solidity-utilities/contracts/TimeManagerV8.sol"; @@ -18,7 +17,13 @@ import { TokenVaultStorage } from "./TokenVaultStorage.sol"; * @author Venus * @notice Token vault is a generic vault that can support multiple token. User can lock their supported token in the TokenVault to receive voting rights in Venus governance. */ -contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, TokenVaultStorage { +contract TokenVault is + PausableUpgradeable, + ReentrancyGuardUpgradeable, + TimeManagerV8, + AccessControlledV8, + TokenVaultStorage +{ /// @notice Event emitted when deposit event Deposit(address indexed user, address indexed token, uint256 indexed amount); @@ -56,7 +61,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, ensureNonzeroAddress(_accessControlManager); ensureNonzeroAddress(_token); tokens[_token] = true; - accessControlManager = _accessControlManager; + __AccessControlled_init(_accessControlManager); } /** @@ -67,7 +72,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @custom:event Emit UpdateTokens with address of token and its bool value */ function updateTokens(address _token, bool _isAdded) external { - _ensureAllowed("updateTokens(address,bool)"); + _checkAccessAllowed("updateTokens(address,bool)"); ensureNonzeroAddress(address(_token)); tokens[_token] = _isAdded; emit UpdateTokens(_token, _isAdded); @@ -81,7 +86,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @custom:access Controlled by Access Control Manager */ function setLockPeriod(address _token, uint128 _lockPeriod) external { - _ensureAllowed("setLockPeriod(address,uint128)"); + _checkAccessAllowed("setLockPeriod(address,uint128)"); isTokenRegistered(_token); tokenLockPeriod[_token] = _lockPeriod; emit SetLockPeriod(_token, _lockPeriod); @@ -130,7 +135,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @custom:access Controlled by Access Controlled Manager */ function pause() external { - _ensureAllowed("pause()"); + _checkAccessAllowed("pause()"); _pause(); } @@ -139,7 +144,7 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, * @custom:access Controlled by Access Controlled Manager */ function unpause() external { - _ensureAllowed("unpause()"); + _checkAccessAllowed("unpause()"); _unpause(); } @@ -320,16 +325,6 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, return _delegate(signatory, _delegatee, _token); } - /** - * @notice Set Access Control Manager - * @param _accessControlManager Address of Access Control Manager - */ - function setAccessControlManager(address _accessControlManager) external { - _ensureAllowed("setAccessControlManager(address)"); - ensureNonzeroAddress(_accessControlManager); - accessControlManager = _accessControlManager; - } - /** * @notice Gets the current votes balance for `account` * @param _account The address to get votes balance @@ -500,17 +495,6 @@ contract TokenVault is Pausable, ReentrancyGuard, Initializable, TimeManagerV8, return user.amount - (user.pendingWithdrawals); } - /** - * @dev Ensure that the caller has permission to execute a specific function - * @param functionSig_ Function signature to be checked for permission - */ - function _ensureAllowed(string memory functionSig_) internal view { - require( - IAccessControlManagerV8(accessControlManager).isAllowedToCall(msg.sender, functionSig_), - "access denied" - ); - } - /** * @dev This function reverts if token is not registered * @param _token Address of the token diff --git a/contracts/TokenVault/TokenVaultStorage.sol b/contracts/TokenVault/TokenVaultStorage.sol index f9557f59a..2df394c20 100644 --- a/contracts/TokenVault/TokenVaultStorage.sol +++ b/contracts/TokenVault/TokenVaultStorage.sol @@ -1,6 +1,6 @@ // SPDX-License-Identifier: BSD-3-Clause -pragma solidity ^0.8.25; +pragma solidity 0.8.25; contract TokenVaultStorage { /// @notice Info of each user. @@ -20,9 +20,6 @@ contract TokenVaultStorage { uint256 votes; } - // Access Control Manager - address public accessControlManager; - /// @notice The EIP-712 typehash for the contract's domain bytes32 public constant DOMAIN_TYPEHASH = keccak256("EIP712Domain(string name,uint256 chainId,address verifyingContract)"); diff --git a/tests/hardhat/TokenVault/tokenVaultTest.ts b/tests/hardhat/TokenVault/tokenVaultTest.ts index 6d3b3204a..6ffb2c8ad 100644 --- a/tests/hardhat/TokenVault/tokenVaultTest.ts +++ b/tests/hardhat/TokenVault/tokenVaultTest.ts @@ -1,3 +1,4 @@ +import { FakeContract, smock } from "@defi-wonderland/smock"; import { SignerWithAddress } from "@nomicfoundation/hardhat-ethers/signers"; import { loadFixture, mine } from "@nomicfoundation/hardhat-network-helpers"; import { expect } from "chai"; @@ -5,21 +6,21 @@ import { BigNumber } from "ethers"; import { parseUnits } from "ethers/lib/utils"; import { ethers, upgrades } from "hardhat"; -import { MockToken, TokenVault } from "../../../typechain"; +import { IAccessControlManagerV8, MockToken, TokenVault } from "../../../typechain"; describe("TokenVault", async () => { let deployer: SignerWithAddress; let signer1: SignerWithAddress; let tokenVault: TokenVault; - let accessControlManager; + let accessControlManager: FakeContract; let token: MockToken; let amount: BigNumber; const tokenVaultFixture = async () => { [deployer, signer1] = await ethers.getSigners(); amount = parseUnits("10", 18); - const accessControlManagerFactory = await ethers.getContractFactory("AccessControlManager"); - accessControlManager = await accessControlManagerFactory.deploy(); + accessControlManager = await smock.fake("AccessControlManager"); + accessControlManager.isAllowedToCall.returns(true); const tokenFactory = await ethers.getContractFactory("MockToken"); token = await tokenFactory.deploy("MockToken", "MT", 18); const tokenVaultFactory = await ethers.getContractFactory("TokenVault"); @@ -29,25 +30,6 @@ describe("TokenVault", async () => { unsafeAllow: ["constructor"], }); - let tx = await accessControlManager.giveCallPermission( - tokenVault.address, - "updateTokens(address,bool)", - deployer.address, - ); - await tx.wait(); - tx = await accessControlManager.giveCallPermission( - tokenVault.address, - "setLockPeriod(address,uint128)", - deployer.address, - ); - await tx.wait(); - - tx = await accessControlManager.giveCallPermission(tokenVault.address, "pause()", deployer.address); - await tx.wait(); - - tx = await accessControlManager.giveCallPermission(tokenVault.address, "unpause()", deployer.address); - await tx.wait(); - await tokenVault.setLockPeriod(token.address, 300); await token.faucet(parseUnits("100", 18)); };