diff --git a/.gitignore b/.gitignore index fd47912f..f4c452e9 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,8 @@ src/dataset/agent-v1-c4/datasets.json .DS_Store src/.DS_Store .env -tests/* \ No newline at end of file +tests/* +uv.lock +callgraph.json +lancedbsrc1/lancedb_src* +FiniteMonkey*/* diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..f3fe474a --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.9 diff --git a/.vscode/code-explorer.json b/.vscode/code-explorer.json new file mode 100644 index 00000000..3503e221 --- /dev/null +++ b/.vscode/code-explorer.json @@ -0,0 +1,22 @@ +{ + "#": "NOT recommend to edit manually. Write carefully! Generated by tianjianchn.code-explorer vscode extension.", + "#markerCount": 1, + "stacks": [ + { + "title": "class root(BaseM 2025-02-14", + "isActive": true, + "createdAt": "2025-02-15T00:17:34.074Z", + "id": "98d72628b1bb5d4b6be2bb35a6f73501", + "markers": [ + { + "code": "class root(BaseModel):", + "file": "src/root.py", + "line": 3, + "column": 0, + "createdAt": "2025-02-15T00:17:53.463Z", + "id": "156f801ac1de21e620e787f199161418" + } + ] + } + ] +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..23182ecb --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File with Arguments", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + }, + { + "name": "Python Debugger: Python File", + "type": "debugpy", + "request": "launch", + "program": "${file}" + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..7a73a41b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,2 @@ +{ +} \ No newline at end of file diff --git a/app/core.py b/app/core.py new file mode 100644 index 00000000..52e5de64 --- /dev/null +++ b/app/core.py @@ -0,0 +1,174 @@ +from textual import work +from textual.app import App +from textual.widgets import Footer, Header +from models.schemas import LogCtxData +from ui.widgets import ThinSpinner, BilingualTable, TranslationToggle +import asyncio +from textual.containers import VerticalScroll, Horizontal +from textual.widgets import Input +import asyncio +import sys +import io +from textual import app, work +from textual.widgets import Input, RichLog +from textual.widgets import ProgressBar +from textual.reactive import reactive +#from rich.layout import VerticalScroll +from textual.app import ComposeResult +from textual.containers import Vertical, VerticalScroll, Horizontal + + +from textual.containers import Container, Horizontal, Vertical +from textual.widgets import DataTable, Switch, Static +from textual import work, on +from textual.reactive import var +from app.translation import TranslationService + + +class CoreApp(App): + CSS = """ + #main-content { + layout: grid; + grid-size: 2; + grid-columns: 1fr 1fr; + height: 1fr; + } + """ + translation_enabled = var(True) # Enabled by default + + async def on_mount(self): + self.main_loop = asyncio.get_running_loop() + self.start_ipython() + + @work(thread=True, 
exclusive=True) + def start_ipython(self): + """Start IPython in background thread""" + import sys + from IPython import start_ipython + + class IPythonWrapper: + def __init__(self, app): + self.app = app + + def write(self, data): + asyncio.run_coroutine_threadsafe( + self.app.post_output(data), + loop=self.app.main_loop + ) + def flush(self): + pass + + def isatty(self): + return False # Crucial fix for prompt_toolkit integration + + def fileno(self): + return -1 # Indicate no real file descriptor + + + sys.stdout = IPythonWrapper(self) + start_ipython(argv=[], user_ns={ + "app": self, + "toggle_translation": self.toggle_translation + }) + + async def post_output(self, data: str): + """Handle output from IPython thread""" + if self.translation_enabled: + # Add translation logic here + self.query_one(BilingualTable).add_row(data) + else: + self.query_one(BilingualTable).add_row(data) + self.refresh() + + def action_quit(self): + self.exit() + + def toggle_translation(self, enabled: bool): + """Toggle translation mode from IPython""" + self.translation_enabled = enabled + self.notify(f"Translation {'enabled' if enabled else 'disabled'}") + + CSS = """ + #main-content { + layout: grid; + grid-size: 2; + grid-columns: 1fr 1fr; + height: 1fr; + } + """ + translation_enabled = var(True) # Enabled by default + + def compose(self): + # Top controls + yield Horizontal( + ThinSpinner(), + TranslationToggle(value=True), # Toggle enabled by default + id="spinner-container" + ) + + # Main content - side by side tables + yield Horizontal( + BilingualTable("English", lang="EN"), + BilingualTable("中文", lang="CN"), + id="main-content" + ) + + # Input at bottom + yield Horizontal( + Input(placeholder="Input text/commands >>> "), + id="input-container" + ) + + async def process_output(self, data: str): + """Process all output with bidirectional translation""" + try: + # Detect input language + is_english = self._is_english(data) + source_lang = "EN" if is_english else "CN" + 
target_lang = "CN" if is_english else "EN" + + # Create translation context + ctx = LogCtxData( + txtENG=data if is_english else "", + txtCN=data if not is_english else "", + detail="Auto-translation" + ) + + # Always show original + translation when enabled + if self.translation_enabled: + translated = await self.translation.translate(ctx, target_lang) + self._update_tables( + source_text=data, + translated_text=translated, + source_lang=source_lang + ) + else: + self._update_single_table(data, source_lang) + + except Exception as e: + self.log_error(f"Translation error: {str(e)}") + + def _update_tables(self, source_text: str, translated_text: str, source_lang: str): + """Update both tables with original and translation""" + en_table = self.query_one("#english-table") + zh_table = self.query_one("#chinese-table") + + if source_lang == "EN": + en_table.add_row(source_text) + zh_table.add_row(translated_text) + else: + zh_table.add_row(source_text) + en_table.add_row(translated_text) + + def _update_single_table(self, text: str, source_lang: str): + """Update only the relevant table when translation off""" + table = self.query_one("#english-table" if source_lang == "EN" else "#chinese-table") + table.add_row(text) + + @on(Input.Submitted) + async def handle_input(self, event: Input.Submitted): + """Handle user input with immediate translation""" + input_text = event.value + if input_text: + await self.process_output(input_text) + event.input.clear() \ No newline at end of file diff --git a/app/ipython.py b/app/ipython.py new file mode 100644 index 00000000..7140e25c --- /dev/null +++ b/app/ipython.py @@ -0,0 +1,200 @@ +# app/ipython.py +import asyncio +import sys +from typing import Callable, Optional +from textual.widgets import Input +from textual.app import App + # Rest of your mount logic +class IPythonIO(io.TextIOBase): + """Thread-safe I/O redirection for IPython""" + def __init__(self, queue: asyncio.Queue, main_loop: asyncio.AbstractEventLoop): + self.queue 
= queue + self.main_loop = main_loop + + def write(self, data: str) -> int: + asyncio.run_coroutine_threadsafe( + self.queue.put(data), + loop=self.main_loop + ) + return len(data) + + def flush(self) -> None: + pass +class IPythonConsole(App): +# class IPythonConsole(BaseWindow): + """IPython-integrated console inheriting from base window""" + def on_mount(self) -> None: + # Initialize IPython components + self.main_loop = asyncio.get_running_loop() + self.start_ipython() + self.set_interval(0.05, self.process_output) + + # Keep previous IPython integration methods + @work(thread=True) + def start_ipython(self) -> None: + # Same thread setup as before + ... + + def process_output(self) -> None: + # Same output handling + ... + + # Rest of your IPython methods +# class IPythonConsole(app.App): + CSS = """ + Screen { + layout: vertical; + } + + Vertical { + height: auto; + } + + ThinSpinner { + height: 1; + margin: 0; + } + + /* Rest of your CSS */ + """ + + async def on_mount(self) -> None: + # Mount spinner first in the layout + await self.mount( + ThinSpinner(show_eta=False, total=100, show_bar=True ), # Explicit init + CustomRichUI(), + VerticalScroll( + RichLog(id="output-view"), + Input(placeholder=">>> ") + ) + ) + +# class IPythonConsole(app.App): + """Textual app hosting an IPython REPL with async integration""" + CSS = """ + RichLog { + height: 1fr; + overflow-y: auto; + } + Input { + dock: bottom; + } + """ + + def __init__(self): + super().__init__() + self.input_queue = asyncio.Queue() + self.output_queue = asyncio.Queue() + self.main_loop = None + + async def on_mount(self) -> None: + """Initialize application components""" + self.main_loop = asyncio.get_running_loop() + self.rich_log = RichLog() + self.input_widget = Input(placeholder=">>> ") + await self.mount(self.rich_log, self.input_widget) + self.start_ipython() + self.set_interval(0.05, self.process_output) + + @work(thread=True) + def start_ipython(self) -> None: + """Launch IPython in a 
background thread""" + ipy_loop = asyncio.new_event_loop() + asyncio.set_event_loop(ipy_loop) + + sys.stdin = self + sys.stdout = IPythonIO(self.output_queue, self.main_loop) + sys.stderr = IPythonIO(self.output_queue, self.main_loop) + + from IPython import start_ipython + try: + start_ipython( + argv=[], + user_ns=self.get_ipython_namespace(), + display_banner=False + ) + finally: + sys.stdin = sys.__stdin__ + sys.stdout = sys.__stdout__ + + def get_ipython_namespace(self) -> dict: + """Provide objects accessible in IPython REPL""" + return { + "app": self, + "run_async": self.run_in_main_loop, + "fetch_data": self.sample_async_method + } + + def run_in_main_loop(self, coro) -> any: + """Execute async code in main thread's event loop""" + return asyncio.run_coroutine_threadsafe( + coro, + loop=self.main_loop + ).result() + + async def sample_async_method(self) -> str: + """Example async method callable from IPython""" + await asyncio.sleep(1) + return "Data fetched successfully!" + + async def process_output(self) -> None: + """Update UI with output from IPython""" + while not self.output_queue.empty(): + data = await self.output_queue.get() + self.rich_log.write(data) + self.rich_log.scroll_end(animate=False) + + def readline(self, size: int = -1) -> str: + """Get input from queue (blocking in IPython thread)""" + return asyncio.run_coroutine_threadsafe( + self.input_queue.get(), + loop=self.main_loop + ).result() + + async def on_input_submitted(self, event: Input.Submitted) -> None: + """Handle user input submissions""" + await self.input_queue.put(event.value + "\n") + self.input_widget.clear() + + +# class IPythonIO: +# """Thread-safe I/O redirection with proper event loop handling""" +# def __init__(self, callback: Callable, main_loop: asyncio.AbstractEventLoop): +# self.callback = callback +# self.main_loop = main_loop + +# def write(self, data: str): +# if data.strip(): +# # Use the main thread's event loop +# asyncio.run_coroutine_threadsafe( +# 
self.callback(data), +# loop=self.main_loop +# ) + +# def flush(self): +# pass + +# class IPythonHost: +# """Managed IPython integration with lifecycle control""" +# def __init__(self, translation_callback: Callable): +# self.translation_callback = translation_callback +# self.main_loop: Optional[asyncio.AbstractEventLoop] = None + +# def start(self, main_loop: asyncio.AbstractEventLoop): +# """Start IPython with proper loop reference""" +# self.main_loop = main_loop +# sys.stdout = IPythonIO(self.handle_output, main_loop) + +# from IPython import start_ipython +# start_ipython(argv=[], user_ns={ +# "app": self, +# "translate": self.translation_callback +# }) + +# async def handle_output(self, data: str): +# """Process output through translation pipeline""" +# try: +# if self.main_loop and self.main_loop.is_running(): +# await self.translation_callback(data) +# except Exception as e: +# print(f"Output handling error: {e}") diff --git a/app/translation.py b/app/translation.py new file mode 100644 index 00000000..2364ce64 --- /dev/null +++ b/app/translation.py @@ -0,0 +1,43 @@ +from textwrap import dedent + + +from models.schemas import LogCtxData + +class TranslationService: + async def translate(self, ctx: LogCtxData, target_lang: str) -> str: + """Bidirectional translation with quality checks""" + if target_lang == "CN": + return await self._translate_english_to_chinese(ctx) + return await self._translate_chinese_to_english(ctx) + + async def _translate_english_to_chinese(self, ctx: LogCtxData) -> str: + prompt = dedent(f""" + Translate this technical text to colloquial Chinese: + {ctx.txtENG} + + Requirements: + - Maintain technical accuracy + - Use natural spoken Chinese + - Preserve numbers and proper nouns + """) + response = await self.ingress.chat.completions.create( + model="towerinstruct", + messages=[{"role": "user", "content": prompt}] + ) + return response.choices[0].message.content + + async def _translate_chinese_to_english(self, ctx: LogCtxData) -> 
str: + prompt = dedent(f""" + Translate this Chinese text to clear English: + {ctx.txtCN} + + Requirements: + - Keep technical terms in English + - Use simple, direct phrasing + - Maintain original formatting + """) + response = await self.egress.chat.completions.create( + model="towerinstruct", + messages=[{"role": "user", "content": prompt}] + ) + return response.choices[0].message.content \ No newline at end of file diff --git a/docs/flows.md b/docs/flows.md new file mode 100644 index 00000000..895b47fd --- /dev/null +++ b/docs/flows.md @@ -0,0 +1,32 @@ + +#internals +```mermaid +sequenceDiagram + participant M as Main __main__ + participant AM as async_main + participant TT as Thread Tasks (to_thread) + participant TG as TaskGroup + participant MP as message_pump + participant AG as Agent Stream + participant LS as Live Display + + M->>AM: asyncio.run async_main + AM->>AM: Load config, create engine, init Context + AM->>TT: await to_threadproject_audit.parse... + TT-->>AM: Parsing complete + AM->>TT: await to_threadplanning.do_planning + TT-->>AM: Planning complete + AM->>TT: await to_thread ai_engine.do_scan + TT-->>AM: Scan complete + AM->>TT: await to_thread ai_engine.check_function_vul + TT-->>AM: Vulnerability check complete + AM->>TG: Enter TaskGroup + TG->>MP: Create task: message_pump + MP->>AG: agent.run_streamprompt + loop For each message + AG-->>MP: yield message + MP->>LS: live.updateMarkdownmessage + end + AG-->>MP: Stream complete + TG-->>AM: TaskGroup exits all tasks done +``` \ No newline at end of file diff --git a/liquidity_pool.move b/liquidity_pool.move deleted file mode 100644 index 79b1118d..00000000 --- a/liquidity_pool.move +++ /dev/null @@ -1,2072 +0,0 @@ -module yuzuswap::liquidity_pool { - - use std::bcs; - use std::signer; - use std::string; - use std::vector; - - use aptos_std::aptos_hash; - use aptos_std::math128; - use aptos_std::math64; - use aptos_std::string_utils; - use aptos_std::table::{Self, Table}; - use 
aptos_framework::account::{Self, SignerCapability}; - use aptos_framework::event; - use aptos_framework::fungible_asset::{Self, FungibleStore, Metadata, FungibleAsset}; - use aptos_framework::object::{Self, Object, ExtendRef}; - use aptos_framework::primary_fungible_store; - - use yuzuswap::fixed_point; - use yuzuswap::config; - use yuzuswap::emergency; - use yuzuswap::fa_helper; - use yuzuswap::fee_tier; - use yuzuswap::i128::{Self, I128}; - use yuzuswap::sqrt_price_math; - use yuzuswap::swap_math; - use yuzuswap::tick_bitmap; - use yuzuswap::tick_math; - use yuzuswap::tick; - - friend yuzuswap::position_nft_manager; - friend yuzuswap::router; - friend yuzuswap::reward_manager; - - // Errors codes. - - /// Wrong tokens ordering. - const E_UNREACHABLE_CODE: u64 = 100; - /// Wrong tokens ordering. - const E_WRONG_TOKENS_ORDERING: u64 = 101; - /// The pool is locked. - const E_LOCKED_POOL: u64 = 102; - /// The tick lower must be less than the tick upper. - const E_TICK_LOWER_MUST_LESS_THAN_TICK_UPPER: u64 = 103; - /// The tick is not spaced. - const E_TICK_NOT_SPACED: u64 = 104; - /// The tick exceeds the maximum tick. - const E_EXCEED_MAX_TICK: u64 = 105; - /// The liquidity exceeds the maximum liquidity per tick. - const E_EXCEED_MAX_LIQUIDITY_PER_TICK: u64 = 106; - /// Mismatch token. - const E_TOKEN_MISMATCH: u64 = 107; - /// Not enough token to add liquidity. - const E_NOT_ENOUGH_TOKEN_TO_ADD_LIQUIDITY: u64 = 108; - /// The swap amount must be greater than zero. - const E_SWAP_AMOUNT_MUST_GREATER_THAN_ZERO: u64 = 109; - /// Invalid limit sqrt price. - const E_INVALID_LIMIT_SQRT_PRICE: u64 = 110; - /// Invalid token to pay swap. - const E_INVALID_PAY_SWAP_TOKEN: u64 = 111; - /// Invalid pay swap amount. - const E_INVALID_PAY_SWAP_AMOUNT: u64 = 112; - /// The position does not exist. - const E_POSITION_NOT_EXIST: u64 = 113; - /// The position still has liquidity. - const E_NOT_EMPTY_LIQUIDITY_POSITION: u64 = 114; - /// The position still has owed fees. 
- const E_NOT_EMPTY_FEE_POSITION: u64 = 115; - /// The position still has rewards. - const E_NOT_EMPTY_REWARD_POSITION: u64 = 116; - /// Exceed max rewards per pool. - const E_EXCEED_MAX_REWARDS_PER_POOL: u64 = 117; - /// The user is not reward manager. - const E_NOT_REWARD_MANAGER: u64 = 118; - /// Invalid reward token. - const E_INVALID_REWARD_TOKEN: u64 = 119; - /// Not enough reward. - const E_NOT_ENOUGH_REWARD: u64 = 120; - - // Constants. - - const MAX_U128: u256 = 0xffffffffffffffffffffffffffffffff; - - const MAX_REWARDS_PER_POOL: u64 = 3; - - // Structs. - - /// Stores resource account signer capability under Yuzuswap account. - struct PoolAccountCap has key { - signer_cap: SignerCapability, - } - - struct LiquidityPools has key { - all_pools: vector>, - } - - #[resource_group_member(group = aptos_framework::object::ObjectGroup)] - struct LiquidityPool has key { - token_0_reserve: Object, - token_1_reserve: Object, - current_tick: u32, - current_sqrt_price: u128, - liquidity: u128, - tick_bitmap: Table, - ticks: Table, - positions: Table, Position>, - next_position_id: u64, - fee_growth_global_0_x64: u128, - fee_growth_global_1_x64: u128, - protocol_fee_amount_0: u64, - protocol_fee_amount_1: u64, - reward_infos: vector, - reward_last_updated_at_seconds: u64, - fee_rate: u64, - tick_spacing: u32, - max_liquidity_per_tick: u128, - unlocked: bool, - extend_ref: ExtendRef, - } - - struct Position has store, drop { - id: u64, - tick_lower: u32, - tick_upper: u32, - liquidity: u128, - fee_growth_inside_0_last_x64: u128, - fee_growth_inside_1_last_x64: u128, - tokens_owed_0: u64, - tokens_owed_1: u64, - reward_infos: vector, - } - - struct TickInfo has store, drop { - liquditiy_gross: u128, - liquidity_net: I128, - fee_growth_outside_0_x64: u128, - fee_growth_outside_1_x64: u128, - reward_growths_outside: vector, - initialized: bool, - } - - struct PoolRewardInfo has copy, drop, store { - token_metadata: Object, - remaining_reward: u64, - emissions_per_second: 
u64, - growth_global: u128, - manager: address, - } - - struct PositionRewardInfo has copy, drop, store { - reward_growth_inside_last: u128, - amount_owed: u64, - } - - struct SwapReciept { - pool: Object, - token_metadata: Object, - amount_in: u64, - } - - // Events. - - #[event] - struct CreatePoolEvent has drop, store { - creator: address, - pool: address, - token_x: Object, - token_y: Object, - fee: u64, - tick_spacing: u32, - } - - #[event] - struct AddLiquidityEvent has drop, store { - user: address, - pool: address, - position_id: u64, - liquidity: u128, - amount_0: u64, - amount_1: u64, - } - - #[event] - struct RemoveLiquidityEvent has drop, store { - user: address, - pool: address, - position_id: u64, - liquidity: u128, - amount_0: u64, - amount_1: u64, - } - - #[event] - struct SwapEvent has drop, store { - pool: address, - zero_for_one: bool, - is_exact_in: bool, - amount_in: u64, - amount_out: u64, - fee_amount: u64, - sqrt_price_after: u128, - liquidity_after: u128, - tick_after: u32, - } - - #[event] - struct CollectFeeEvent has drop, store { - user: address, - pool: address, - position_id: u64, - amount_0: u64, - amount_1: u64, - } - - #[event] - struct CollectProtocolFee has drop, store { - admin: address, - pool: address, - amount_0: u64, - amount_1: u64, - } - - #[event] - struct InitRewardEvent has drop, store { - pool: address, - reward_index: u64, - manager: address, - } - - #[event] - struct UpdateRewardManagerEvent has drop, store { - pool: address, - reward_index: u64, - manager: address, - } - - #[event] - struct UpdateRewardEmissionsEvent has drop, store { - pool: address, - reward_index: u64, - manager: address, - emissions_per_second: u64, - } - - #[event] - struct AddRewardEvent has drop, store { - pool: address, - reward_index: u64, - manager: address, - amount: u64, - } - - #[event] - struct RemoveRewardEvent has drop, store { - pool: address, - reward_index: u64, - manager: address, - amount: u64, - } - - #[event] - struct 
CollectRewardEvent has drop, store { - user: address, - pool: address, - position_id: u64, - reward_index: u64, - amount: u64, - } - - // Module initialization. - - fun init_module(owner: &signer) { - let (_, signer_cap) = account::create_resource_account(owner, b"pool_account"); - move_to(owner, PoolAccountCap { signer_cap }); - move_to(owner, LiquidityPools { all_pools: vector[] }); - } - - // Public functions. - - public(friend) fun create_pool( - sender: &signer, - token_0: Object, - token_1: Object, - fee_rate: u64, - sqrt_price: u128, - ): Object - acquires PoolAccountCap, LiquidityPools { - emergency::assert_no_emergency(); - - assert!(fa_helper::is_sorted(token_0, token_1), E_WRONG_TOKENS_ORDERING); - - let tick_spacing = fee_tier::get_tick_spacing(fee_rate); - let current_tick = tick_math::get_tick_at_sqrt_price(sqrt_price); - - let pool_account_cap = borrow_global(@yuzuswap); - let pool_account = account::create_signer_with_capability(&pool_account_cap.signer_cap); - - let pool_seed = get_pool_seed(token_0, token_1, fee_rate); - let pool_constructor_ref = &object::create_named_object(&pool_account, pool_seed); - let pool_signer = &object::generate_signer(pool_constructor_ref); - - move_to(pool_signer, LiquidityPool { - token_0_reserve: create_token_store(pool_signer, token_0), - token_1_reserve: create_token_store(pool_signer, token_1), - current_tick, - current_sqrt_price: sqrt_price, - liquidity: 0, - tick_bitmap: table::new(), - ticks: table::new(), - positions: table::new(), - next_position_id: 1, - fee_growth_global_0_x64: 0, - fee_growth_global_1_x64: 0, - protocol_fee_amount_0: 0, - protocol_fee_amount_1: 0, - reward_infos: vector[], - reward_last_updated_at_seconds: 0, - fee_rate, - tick_spacing, - max_liquidity_per_tick: tick::tick_spacing_to_max_liquidity_per_tick(tick_spacing), - unlocked: true, - extend_ref: object::generate_extend_ref(pool_constructor_ref), - }); - - vector::push_back( - &mut borrow_global_mut(@yuzuswap).all_pools, - 
object::object_from_constructor_ref(pool_constructor_ref), - ); - - event::emit( - CreatePoolEvent { - creator: signer::address_of(sender), - pool: object::address_from_constructor_ref(pool_constructor_ref), - token_x: token_0, - token_y: token_1, - fee: fee_rate, - tick_spacing, - }, - ); - - object::object_from_constructor_ref(pool_constructor_ref) - } - - public fun open_position( - user: &signer, - pool: Object, - tick_lower: u32, - tick_upper: u32, - ): u64 - acquires LiquidityPool { - emergency::assert_no_emergency(); - - let pool_data = pool_data_mut(&pool); - assert!(pool_data.unlocked, E_LOCKED_POOL); - - assert_ticks(tick_lower, tick_upper, pool_data.tick_spacing); - - let position_id = pool_data.next_position_id; - let position_key = get_position_key(&signer::address_of(user), position_id); - let position = Position { - id: position_id, - tick_lower, - tick_upper, - liquidity: 0, - fee_growth_inside_0_last_x64: 0, - fee_growth_inside_1_last_x64: 0, - tokens_owed_0: 0, - tokens_owed_1: 0, - reward_infos: vector[], - }; - - table::add(&mut pool_data.positions, position_key, position); - pool_data.next_position_id = pool_data.next_position_id + 1; - - position_id - } - - public fun close_position( - user: &signer, - pool: Object, - position_id: u64, - ) acquires LiquidityPool { - emergency::assert_no_emergency(); - - let pool_data = pool_data_mut(&pool); - assert!(pool_data.unlocked, E_LOCKED_POOL); - - let user_address = signer::address_of(user); - let position_key = get_position_key(&user_address, position_id); - assert!(table::contains(&pool_data.positions, position_key), E_POSITION_NOT_EXIST); - - let position = table::borrow_mut(&mut pool_data.positions, position_key); - - assert!(position.liquidity == 0, E_NOT_EMPTY_LIQUIDITY_POSITION); - assert!(position.tokens_owed_0 == 0, E_NOT_EMPTY_FEE_POSITION); - assert!(position.tokens_owed_1 == 0, E_NOT_EMPTY_FEE_POSITION); - for (i in 0..vector::length(&position.reward_infos)) { - let reward_info = 
vector::borrow(&position.reward_infos, i); - assert!(reward_info.amount_owed == 0, E_NOT_EMPTY_REWARD_POSITION); - }; - - table::remove(&mut pool_data.positions, position_key); - } - - public fun add_liquidity( - user: &signer, - pool: Object, - position_id: u64, - liquidity_delta: u128, - token_0: &mut FungibleAsset, - token_1: &mut FungibleAsset, - ) acquires LiquidityPool { - emergency::assert_no_emergency(); - - let pool_data = pool_data_mut(&pool); - assert!(pool_data.unlocked, E_LOCKED_POOL); - - assert!( - fungible_asset::metadata_from_asset(token_0) == fungible_asset::store_metadata(pool_data.token_0_reserve) - && fungible_asset::metadata_from_asset(token_1) == fungible_asset::store_metadata(pool_data.token_1_reserve), - E_TOKEN_MISMATCH, - ); - - let user_address = signer::address_of(user); - - let (amount_0, amount_1) = modify_position( - pool_data, - user_address, - position_id, - i128::new(liquidity_delta, false), - ); - - assert!(fungible_asset::amount(token_0) >= amount_0, E_NOT_ENOUGH_TOKEN_TO_ADD_LIQUIDITY); - assert!(fungible_asset::amount(token_1) >= amount_1, E_NOT_ENOUGH_TOKEN_TO_ADD_LIQUIDITY); - - let token_0_in = fungible_asset::extract(token_0, amount_0); - let token_1_in = fungible_asset::extract(token_1, amount_1); - fungible_asset::deposit(pool_data.token_0_reserve, token_0_in); - fungible_asset::deposit(pool_data.token_1_reserve, token_1_in); - - event::emit( - AddLiquidityEvent { - user: user_address, - pool: object::object_address(&pool), - position_id, - liquidity: liquidity_delta, - amount_0, - amount_1, - }, - ) - } - - public fun remove_liquidity( - user: &signer, - pool: Object, - position_id: u64, - liquidity_delta: u128, - ): (FungibleAsset, FungibleAsset) - acquires LiquidityPool, PoolAccountCap { - emergency::assert_no_emergency(); - - let pool_data = pool_data_mut(&pool); - - assert!(pool_data.unlocked, E_LOCKED_POOL); - - let user_address = signer::address_of(user); - - let (amount_0_out, amount_1_out) = modify_position( - 
pool_data, - user_address, - position_id, - i128::new(liquidity_delta, true), - ); - - event::emit( - RemoveLiquidityEvent { - user: user_address, - pool: object::object_address(&pool), - position_id, - liquidity: liquidity_delta, - amount_0: amount_0_out, - amount_1: amount_1_out, - }, - ); - - let pool_account_signer = get_pool_account_signer(); - ( - fungible_asset::withdraw(&pool_account_signer, pool_data.token_0_reserve, amount_0_out), - fungible_asset::withdraw(&pool_account_signer, pool_data.token_1_reserve, amount_1_out), - ) - } - - public fun swap( - trader: &signer, - pool: Object, - zero_for_one: bool, - is_exact_in: bool, - specified_amount: u64, - sqrt_price_limit: u128, - ): (FungibleAsset, SwapReciept) - acquires LiquidityPool, PoolAccountCap { - emergency::assert_no_emergency(); - - assert!(specified_amount > 0, E_SWAP_AMOUNT_MUST_GREATER_THAN_ZERO); - - let pool_data = pool_data_mut(&pool); - - assert!(pool_data.unlocked, E_LOCKED_POOL); - // lock here and unlock in pay_swap function to guarantee that amount in reserves is correct, avoid error when - // the user swaps and then immediately modifies the pool (swap in opposite direction, remove liquidity, etc.) - // before paying the swap -> the token in reserves could not be enough to do those actions. 
- pool_data.unlocked = false; - - if (zero_for_one) { - assert!( - sqrt_price_limit < pool_data.current_sqrt_price - && sqrt_price_limit >= tick_math::min_sqrt_price(), - E_INVALID_LIMIT_SQRT_PRICE, - ); - } else { - assert!( - sqrt_price_limit > pool_data.current_sqrt_price - && sqrt_price_limit <= tick_math::max_sqrt_price(), - E_INVALID_LIMIT_SQRT_PRICE, - ) - }; - - let reward_growths_global = update_pool_reward_infos(pool_data); - - let tick_spacing = pool_data.tick_spacing; - let current_sqrt_price = pool_data.current_sqrt_price; - let current_tick = pool_data.current_tick; - let liquidity = pool_data.liquidity; - let remaining_amount = specified_amount; - let calculated_amount = 0; - - let protocol_fee_rate = config::protocol_fee_rate(); - let fee_scale = config::fee_scale(); - let total_protocol_fee_amount = 0; - let total_fee_amount = 0; - let fee_rate = get_fee_rate(signer::address_of(trader), pool_data); - let fee_growth_global_x64 = if (zero_for_one) { - pool_data.fee_growth_global_0_x64 - } else { - pool_data.fee_growth_global_1_x64 - }; - - while (remaining_amount > 0 && current_sqrt_price != sqrt_price_limit) { - let price_sqrt_start = current_sqrt_price; - - let (tick_next, is_initialized_tick) = tick_bitmap::get_next_initialized_tick_within_one_word( - &pool_data.tick_bitmap, - current_tick, - tick_spacing, - zero_for_one, - ); - - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if (tick_next < tick::min_tick()) { - tick_next = tick::min_tick(); - } else if (tick_next > tick::max_tick()) { - tick_next = tick::max_tick(); - }; - - let sqrt_price_next = tick_math::get_sqrt_price_at_tick(tick_next); - - let target_sqrt_price = if (zero_for_one) { - math128::max(sqrt_price_limit, sqrt_price_next) - } else { - math128::min(sqrt_price_limit, sqrt_price_next) - }; - let (sqrt_price, amount_in, amount_out, fee_amount) = swap_math::compute_swap_step( - current_sqrt_price, - target_sqrt_price, - 
liquidity, - remaining_amount, - is_exact_in, - fee_rate, - ); - current_sqrt_price = sqrt_price; - - if (is_exact_in) { - remaining_amount = remaining_amount - (amount_in + fee_amount); - calculated_amount = calculated_amount + amount_out; - } else { - remaining_amount = remaining_amount - amount_out; - calculated_amount = calculated_amount + (amount_in + fee_amount); - }; - - if (protocol_fee_rate > 0) { - let protocol_fee_amount = math64::mul_div(fee_amount, protocol_fee_rate, fee_scale); - total_protocol_fee_amount = total_protocol_fee_amount + protocol_fee_amount; - - fee_amount = fee_amount - protocol_fee_amount; - }; - total_fee_amount = total_fee_amount + fee_amount; - - if (liquidity != 0) { - fee_growth_global_x64 = fee_growth_global_x64 + ((fee_amount as u128) << 64) / liquidity; - }; - - // shift tick if we reached the next price - if (sqrt_price == sqrt_price_next) { - // if the tick is initialized, run the tick transition - if (is_initialized_tick) { - // TODO: calculate oracle - - let next_liquidity_net = cross_tick( - &mut pool_data.ticks, - tick_next, - if (zero_for_one) fee_growth_global_x64 else pool_data.fee_growth_global_0_x64, - if (zero_for_one) pool_data.fee_growth_global_1_x64 else fee_growth_global_x64, - &reward_growths_global, - ); - // if we're moving leftward, we interpret liquidityNet as the opposite sign - // safe because liquidityNet cannot be type(int128).min - if (zero_for_one) { - // next_liquidity_net = -next_liquidity_net - next_liquidity_net = i128::new( - i128::abs(&next_liquidity_net), - !i128::is_negative(&next_liquidity_net), - ); - }; - - liquidity = add_delta_liquidity(liquidity, &next_liquidity_net); - }; - - current_tick = if (zero_for_one) tick_next - 1 else tick_next; - } else if (current_sqrt_price != price_sqrt_start) { - current_tick = tick_math::get_tick_at_sqrt_price(sqrt_price); - } - }; - - if (current_tick != pool_data.current_tick) { - pool_data.current_tick = current_tick; - pool_data.current_sqrt_price = 
current_sqrt_price; - } else { - pool_data.current_sqrt_price = current_sqrt_price; - }; - - pool_data.liquidity = liquidity; - - if (zero_for_one) { - pool_data.fee_growth_global_0_x64 = fee_growth_global_x64; - pool_data.protocol_fee_amount_0 = pool_data.protocol_fee_amount_0 + total_protocol_fee_amount; - } else { - pool_data.fee_growth_global_1_x64 = fee_growth_global_x64; - pool_data.protocol_fee_amount_1 = pool_data.protocol_fee_amount_1 + total_protocol_fee_amount; - }; - - let (amount_in, amount_out) = if (is_exact_in) { - (specified_amount - remaining_amount, calculated_amount) - } else { - (calculated_amount, specified_amount - remaining_amount) - }; - - event::emit( - SwapEvent { - pool: object::object_address(&pool), - zero_for_one, - is_exact_in, - amount_in, - amount_out, - fee_amount: total_fee_amount, - sqrt_price_after: pool_data.current_sqrt_price, - liquidity_after: pool_data.liquidity, - tick_after: pool_data.current_tick, - } - ); - - // withdraw expected amount from reserves. 
- let pool_account_signer = get_pool_account_signer(); - if (zero_for_one) { - ( - fungible_asset::withdraw(&pool_account_signer, pool_data.token_1_reserve, amount_out), - SwapReciept { - pool, - token_metadata: fungible_asset::store_metadata(pool_data.token_0_reserve), - amount_in, - }, - ) - } else { - ( - fungible_asset::withdraw(&pool_account_signer, pool_data.token_0_reserve, amount_out), - SwapReciept { - pool, - token_metadata: fungible_asset::store_metadata(pool_data.token_1_reserve), - amount_in, - }, - ) - } - } - - public fun get_swap_receipt_amount(swap_receipt: &SwapReciept): u64 { - swap_receipt.amount_in - } - - public fun get_swap_receipt_token_metadata(swap_receipt: &SwapReciept): Object { - swap_receipt.token_metadata - } - - public fun pay_swap( - token_in: FungibleAsset, - reciept: SwapReciept, - ) acquires LiquidityPool { - let SwapReciept { - pool, - token_metadata, - amount_in, - } = reciept; - - assert!(token_metadata == fungible_asset::metadata_from_asset(&token_in), E_INVALID_PAY_SWAP_TOKEN); - assert!(fungible_asset::amount(&token_in) == amount_in, E_INVALID_PAY_SWAP_AMOUNT); - - let pool_data = pool_data_mut(&pool); - - if (token_metadata == fungible_asset::store_metadata(pool_data.token_0_reserve)) { - fungible_asset::deposit(pool_data.token_0_reserve, token_in); - } else { - fungible_asset::deposit(pool_data.token_1_reserve, token_in); - }; - - pool_data.unlocked = true; - } - - public fun collect_fee( - user: &signer, - pool: Object, - position_id: u64, - amount_0_requested: u64, - amount_1_requested: u64, - ): (FungibleAsset, FungibleAsset) - acquires LiquidityPool, PoolAccountCap { - emergency::assert_no_emergency(); - - let pool_data = pool_data_mut(&pool); - let user_address = signer::address_of(user); - let position = get_position_mut(&mut pool_data.positions, user_address, position_id); - - // only update fee if the position has liquidity to avoid unnecessary computation - if (position.liquidity > 0) { - let 
(fee_growth_inside_0_x64, fee_growth_inside_1_x64) = get_fee_growth_inside_tick( - &pool_data.ticks, - position.tick_lower, - position.tick_upper, - pool_data.current_tick, - pool_data.fee_growth_global_0_x64, - pool_data.fee_growth_global_1_x64, - ); - update_position_fee(position, fee_growth_inside_0_x64, fee_growth_inside_1_x64); - }; - - let amount_0 = math64::min(amount_0_requested, position.tokens_owed_0); - position.tokens_owed_0 = position.tokens_owed_0 - amount_0; - - let amount_1 = math64::min(amount_1_requested, position.tokens_owed_1); - position.tokens_owed_1 = position.tokens_owed_1 - amount_1; - - event::emit( - CollectFeeEvent { - user: user_address, - pool: object::object_address(&pool), - position_id, - amount_0, - amount_1, - }, - ); - - let pool_signer = get_pool_account_signer(); - ( - fungible_asset::withdraw(&pool_signer, pool_data.token_0_reserve, amount_0), - fungible_asset::withdraw(&pool_signer, pool_data.token_1_reserve, amount_1), - ) - } - - public fun collect_protocol_fee( - admin: &signer, - pool: Object, - amount_0_requested: u64, - amount_1_requested: u64, - ): (FungibleAsset, FungibleAsset) - acquires LiquidityPool, PoolAccountCap { - config::assert_pool_admin(admin); - - let pool_data = pool_data_mut(&pool); - - let amount_0 = math64::min(amount_0_requested, pool_data.protocol_fee_amount_0); - pool_data.protocol_fee_amount_0 = pool_data.protocol_fee_amount_0 - amount_0; - - let amount_1 = math64::min(amount_1_requested, pool_data.protocol_fee_amount_1); - pool_data.protocol_fee_amount_1 = pool_data.protocol_fee_amount_1 - amount_1; - - event::emit( - CollectProtocolFee { - admin: signer::address_of(admin), - pool: object::object_address(&pool), - amount_0, - amount_1, - }, - ); - - let pool_signer = get_pool_account_signer(); - ( - fungible_asset::withdraw(&pool_signer, pool_data.token_0_reserve, amount_0), - fungible_asset::withdraw(&pool_signer, pool_data.token_1_reserve, amount_1), - ) - } - - public fun update_reward_manager( 
- user: &signer, - pool: Object, - reward_index: u64, - new_manager: address, - ) acquires LiquidityPool { - let pool_data = pool_data_mut(&pool); - let reward_info = vector::borrow(&pool_data.reward_infos, reward_index); - assert!(reward_info.manager == signer::address_of(user), E_NOT_REWARD_MANAGER); - - let reward_info = vector::borrow_mut(&mut pool_data.reward_infos, reward_index); - reward_info.manager = new_manager; - - event::emit( - UpdateRewardManagerEvent { - pool: object::object_address(&pool), - reward_index, - manager: new_manager, - }, - ); - } - - public fun update_reward_emissions( - user: &signer, - pool: Object, - reward_index: u64, - emissions_per_second: u64, - ) acquires LiquidityPool { - let pool_data = pool_data_mut(&pool); - let reward_info = vector::borrow(&pool_data.reward_infos, reward_index); - assert!(reward_info.manager == signer::address_of(user), E_NOT_REWARD_MANAGER); - - update_pool_reward_infos(pool_data); - - let reward_info = vector::borrow_mut(&mut pool_data.reward_infos, reward_index); - reward_info.emissions_per_second = emissions_per_second; - - event::emit( - UpdateRewardEmissionsEvent { - pool: object::object_address(&pool), - reward_index, - manager: reward_info.manager, - emissions_per_second, - }, - ); - } - - public fun add_reward( - user: &signer, - pool: Object, - reward_index: u64, - token: FungibleAsset, - ) acquires LiquidityPool, PoolAccountCap { - let pool_data = pool_data_mut(&pool); - let reward_info = vector::borrow(&pool_data.reward_infos, reward_index); - assert!(reward_info.manager == signer::address_of(user), E_NOT_REWARD_MANAGER); - assert!(reward_info.token_metadata == fungible_asset::metadata_from_asset(&token), E_INVALID_REWARD_TOKEN); - - update_pool_reward_infos(pool_data); - - let reward_info = vector::borrow_mut(&mut pool_data.reward_infos, reward_index); - let added_amount = fungible_asset::amount(&token); - reward_info.remaining_reward = reward_info.remaining_reward + added_amount; - - let 
pool_acccount_address = get_pool_account_address(); - primary_fungible_store::deposit(pool_acccount_address, token); - - event::emit( - AddRewardEvent { - pool: object::object_address(&pool), - reward_index, - manager: reward_info.manager, - amount: added_amount, - }, - ); - } - - public fun remove_reward( - user: &signer, - pool: Object, - reward_index: u64, - amount: u64, - ): FungibleAsset - acquires LiquidityPool, PoolAccountCap { - let pool_data = pool_data_mut(&pool); - let reward_info = vector::borrow(&pool_data.reward_infos, reward_index); - assert!(reward_info.manager == signer::address_of(user), E_NOT_REWARD_MANAGER); - - update_pool_reward_infos(pool_data); - - let reward_info = vector::borrow_mut(&mut pool_data.reward_infos, reward_index); - let real_amount = math64::min(reward_info.remaining_reward, amount); - - let pool_signer = get_pool_account_signer(); - let removed_reward = primary_fungible_store::withdraw(&pool_signer, reward_info.token_metadata, real_amount); - - reward_info.remaining_reward = reward_info.remaining_reward - real_amount; - - event::emit( - RemoveRewardEvent { - pool: object::object_address(&pool), - reward_index, - manager: reward_info.manager, - amount: real_amount, - }, - ); - - removed_reward - } - - public fun collect_reward( - user: &signer, - pool: Object, - position_id: u64, - reward_index: u64, - amount_requested: u64, - ): FungibleAsset - acquires LiquidityPool, PoolAccountCap { - let pool_data = pool_data_mut(&pool); - let reward_growths_global = update_pool_reward_infos(pool_data); - - let user_address = signer::address_of(user); - let position = get_position_mut(&mut pool_data.positions, user_address, position_id); - - // only update position rewards if the position has liquidity to avoid unnecessary computation - if (position.liquidity > 0) { - let reward_growths_inside = get_reward_growths_inside( - &pool_data.ticks, - position.tick_lower, - position.tick_upper, - pool_data.current_tick, - &reward_growths_global, - 
); - update_position_rewards(position, &reward_growths_inside); - }; - - let position_reward = vector::borrow_mut(&mut position.reward_infos, reward_index); - - let amount = math64::min(amount_requested, position_reward.amount_owed); - position_reward.amount_owed = position_reward.amount_owed - amount; - - event::emit( - CollectRewardEvent { - user: user_address, - pool: object::object_address(&pool), - position_id, - reward_index, - amount, - }, - ); - - let pool_signer = get_pool_account_signer(); - let pool_reward = vector::borrow(&pool_data.reward_infos, reward_index); - primary_fungible_store::withdraw(&pool_signer, pool_reward.token_metadata, amount) - } - - fun modify_position( - pool: &mut LiquidityPool, - owner: address, - position_id: u64, - liquidity_delta: I128, - ): (u64, u64) { - if (i128::is_zero(&liquidity_delta)) { - return (0, 0) - }; - - let position = update_position( - pool, - owner, - position_id, - liquidity_delta, - ); - - let (tick_lower, tick_upper) = (position.tick_lower, position.tick_upper); - - let amount0: u64 = 0; - let amount1: u64 = 0; - - if (pool.current_tick < tick_lower) { - amount0 = sqrt_price_math::get_amount_0_delta( - tick_math::get_sqrt_price_at_tick(tick_lower), - tick_math::get_sqrt_price_at_tick(tick_upper), - i128::abs(&liquidity_delta), - i128::is_positive(&liquidity_delta), - ); - } else if (pool.current_tick < tick_upper) { - // TODO: write oracle - - amount0 = sqrt_price_math::get_amount_0_delta( - pool.current_sqrt_price, - tick_math::get_sqrt_price_at_tick(tick_upper), - i128::abs(&liquidity_delta), - i128::is_positive(&liquidity_delta), - ); - amount1 = sqrt_price_math::get_amount_1_delta( - tick_math::get_sqrt_price_at_tick(tick_lower), - pool.current_sqrt_price, - i128::abs(&liquidity_delta), - i128::is_positive(&liquidity_delta), - ); - - pool.liquidity = add_delta_liquidity(pool.liquidity, &liquidity_delta); - } else { - amount1 = sqrt_price_math::get_amount_1_delta( - 
tick_math::get_sqrt_price_at_tick(tick_lower), - tick_math::get_sqrt_price_at_tick(tick_upper), - i128::abs(&liquidity_delta), - i128::is_positive(&liquidity_delta), - ); - }; - - (amount0, amount1) - } - - fun update_position( - pool: &mut LiquidityPool, - owner: address, - position_id: u64, - liquidity_delta: I128, - ): &Position { - let reward_growths_global = update_pool_reward_infos(pool); - - let position = get_position_mut(&mut pool.positions, owner, position_id); - - let flipped_tick_lower = false; - let flipped_tick_upper = false; - if (!i128::is_zero(&liquidity_delta)) { - flipped_tick_lower = update_tick( - &mut pool.ticks, - position.tick_lower, - pool.current_tick, - liquidity_delta, - pool.fee_growth_global_0_x64, - pool.fee_growth_global_1_x64, - false, - pool.max_liquidity_per_tick, - ); - flipped_tick_upper = update_tick( - &mut pool.ticks, - position.tick_upper, - pool.current_tick, - liquidity_delta, - pool.fee_growth_global_0_x64, - pool.fee_growth_global_1_x64, - true, - pool.max_liquidity_per_tick, - ); - - if (flipped_tick_lower) { - tick_bitmap::flip_tick(&mut pool.tick_bitmap, position.tick_lower, pool.tick_spacing); - }; - if (flipped_tick_upper) { - tick_bitmap::flip_tick(&mut pool.tick_bitmap, position.tick_upper, pool.tick_spacing); - }; - }; - - let (fee_growth_inside_0_x64, fee_growth_inside_1_x64) = - get_fee_growth_inside_tick( - &pool.ticks, - position.tick_lower, - position.tick_upper, - pool.current_tick, - pool.fee_growth_global_0_x64, - pool.fee_growth_global_1_x64, - ); - let reward_growths_inside = get_reward_growths_inside( - &pool.ticks, - position.tick_lower, - position.tick_upper, - pool.current_tick, - &reward_growths_global, - ); - update_position_fee(position, fee_growth_inside_0_x64, fee_growth_inside_1_x64); - update_position_rewards(position, &reward_growths_inside); - - position.liquidity = i128::as_u128(&i128::add(&liquidity_delta, &i128::new(position.liquidity, false))); - - // clear any tick data that is no 
longer needed - if (i128::is_negative(&liquidity_delta)) { - if (flipped_tick_lower) { - clear_tick(&mut pool.ticks, position.tick_lower); - }; - if (flipped_tick_upper) { - clear_tick(&mut pool.ticks, position.tick_upper); - }; - }; - - position - } - - fun update_position_fee( - position: &mut Position, - fee_growth_inside_0_x64: u128, - fee_growth_inside_1_x64: u128, - ) { - let tokens_owed_0 = math128::mul_div( - fee_growth_inside_0_x64 - position.fee_growth_inside_0_last_x64, - position.liquidity, - fixed_point::q64(), - ); - let tokens_owed_1 = math128::mul_div( - fee_growth_inside_1_x64 - position.fee_growth_inside_1_last_x64, - position.liquidity, - fixed_point::q64(), - ); - - position.fee_growth_inside_0_last_x64 = fee_growth_inside_0_x64; - position.fee_growth_inside_1_last_x64 = fee_growth_inside_1_x64; - - // overflow is acceptable, have to withdraw before you hit "maximum of uint64" fees - position.tokens_owed_0 = position.tokens_owed_0 + (tokens_owed_0 as u64); - position.tokens_owed_1 = position.tokens_owed_1 + (tokens_owed_1 as u64); - } - - fun update_tick( - ticks: &mut Table, - tick: u32, - current_tick: u32, - liquidity_delta: I128, - fee_growth_global_0_x64: u128, - fee_growth_global_1_x64: u128, - is_tick_upper: bool, - max_liquidity_per_tick: u128, - ): bool { - let tick_info = table::borrow_mut_with_default(ticks, tick, TickInfo { - liquditiy_gross: 0, - liquidity_net: i128::zero(), - fee_growth_outside_0_x64: 0, - fee_growth_outside_1_x64: 0, - reward_growths_outside: vector[], - initialized: false, - }); - - let liquidity_gross_before = tick_info.liquditiy_gross; - let liquidity_gross_after = i128::as_u128( - &i128::add(&liquidity_delta, &i128::new(liquidity_gross_before, false)) - ); - - assert!(liquidity_gross_after <= max_liquidity_per_tick, E_EXCEED_MAX_LIQUIDITY_PER_TICK); - - let flipped = (liquidity_gross_before == 0) != (liquidity_gross_after == 0); - if (liquidity_gross_before == 0) { - // by convention, we assume that all growth 
before a tick was initialized happened _below_ the tick - if (tick <= current_tick) { - tick_info.fee_growth_outside_0_x64 = fee_growth_global_0_x64; - tick_info.fee_growth_outside_1_x64 = fee_growth_global_1_x64; - - // TODO: calculate oracle - }; - - tick_info.initialized == true; - }; - - tick_info.liquditiy_gross = liquidity_gross_after; - - // when the lower (upper) tick is crossed left to right (right to left), liquidity must be added (removed) - if (is_tick_upper) { - tick_info.liquidity_net = i128::sub(&tick_info.liquidity_net, &liquidity_delta); - } else { - tick_info.liquidity_net = i128::add(&liquidity_delta, &tick_info.liquidity_net); - }; - - flipped - } - - fun get_fee_growth_inside_tick( - ticks: &Table, - tick_lower: u32, - tick_upper: u32, - current_tick: u32, - fee_growth_global_0_x64: u128, - fee_growth_global_1_x64: u128, - ): (u128, u128) { - let tick_lower_info = borrow_tick_or_empty(ticks, tick_lower); - let tick_upper_info = borrow_tick_or_empty(ticks, tick_upper); - - // calculate fee growth below - let fee_growth_below_0_x64: u128; - let fee_growth_below_1_x64: u128; - if (current_tick >= tick_lower) { - fee_growth_below_0_x64 = tick_lower_info.fee_growth_outside_0_x64; - fee_growth_below_1_x64 = tick_lower_info.fee_growth_outside_1_x64; - } else { - fee_growth_below_0_x64 = fee_growth_global_0_x64 - tick_lower_info.fee_growth_outside_0_x64; - fee_growth_below_1_x64 = fee_growth_global_1_x64 - tick_lower_info.fee_growth_outside_1_x64; - }; - - // calculate fee growth above - let fee_growth_above_0_x64: u128; - let fee_growth_above_1_x64: u128; - if (current_tick < tick_upper) { - fee_growth_above_0_x64 = tick_upper_info.fee_growth_outside_0_x64; - fee_growth_above_1_x64 = tick_upper_info.fee_growth_outside_1_x64; - } else { - fee_growth_above_0_x64 = fee_growth_global_0_x64 - tick_upper_info.fee_growth_outside_0_x64; - fee_growth_above_1_x64 = fee_growth_global_1_x64 - tick_upper_info.fee_growth_outside_1_x64; - }; - - let 
fee_growth_inside_0_x64 = fee_growth_global_0_x64 - fee_growth_below_0_x64 - fee_growth_above_0_x64; - let fee_growth_inside_1_x64 = fee_growth_global_1_x64 - fee_growth_below_1_x64 - fee_growth_above_1_x64; - - (fee_growth_inside_0_x64, fee_growth_inside_1_x64) - } - - fun cross_tick( - ticks: &mut Table, - tick: u32, - fee_growth_global_0_x64: u128, - fee_growth_global_1_x64: u128, - reward_growths_global: &vector, - ): I128 { - let tick_info = table::borrow_mut(ticks, tick); - tick_info.fee_growth_outside_0_x64 = fee_growth_global_0_x64 - tick_info.fee_growth_outside_0_x64; - tick_info.fee_growth_outside_1_x64 = fee_growth_global_1_x64 - tick_info.fee_growth_outside_1_x64; - - update_reward_growths(&mut tick_info.reward_growths_outside, reward_growths_global); - - tick_info.liquidity_net - } - - fun update_reward_growths( - reward_growths_outside: &mut vector, - reward_growths_global: &vector, - ) { - let reward_growths_outside_length = vector::length(reward_growths_outside); - for (i in 0..vector::length(reward_growths_global)) { - if (i >= reward_growths_outside_length) { - vector::push_back( - reward_growths_outside, - *vector::borrow(reward_growths_global, i), - ); - } else { - let reward_growth_outside = vector::borrow_mut(reward_growths_outside, i); - *reward_growth_outside = *vector::borrow(reward_growths_global, i) - *reward_growth_outside; - }; - }; - } - - fun sub_reward_growths( - reward_growths_global: &vector, - reward_growths_outside: &vector, - ): vector { - let result = vector[]; - - let reward_growths_outside_length = vector::length(reward_growths_outside); - for (i in 0..vector::length(reward_growths_global)) { - let reward_growth_outside = if (i >= reward_growths_outside_length) { - 0 - } else { - *vector::borrow(reward_growths_outside, i) - }; - - vector::push_back( - &mut result, - *vector::borrow(reward_growths_global, i) - reward_growth_outside, - ); - }; - - result - } - - public(friend) fun initialize_reward( - user: &signer, - pool: 
Object, - token_metadata: Object, - manager: address, - ) acquires LiquidityPool { - config::assert_reward_admin(user); - - let pool_data = pool_data_mut(&pool); - - assert!(vector::length(&pool_data.reward_infos) < MAX_REWARDS_PER_POOL, E_EXCEED_MAX_REWARDS_PER_POOL); - - let poolRewardInfo = PoolRewardInfo { - token_metadata, - remaining_reward: 0, - emissions_per_second: 0, - growth_global: 0, - manager, - }; - vector::push_back(&mut pool_data.reward_infos, poolRewardInfo); - - event::emit( - InitRewardEvent { - pool: object::object_address(&pool), - reward_index: vector::length(&pool_data.reward_infos) - 1, - manager, - }, - ); - } - - fun update_pool_reward_infos(pool: &mut LiquidityPool): vector { - let current_time = 0x1::timestamp::now_seconds(); - // This should never happen. - assert!(current_time >= pool.reward_last_updated_at_seconds, E_UNREACHABLE_CODE); - - let reward_infos = &mut pool.reward_infos; - - let reward_growths_global = 0x1::vector::empty(); - let elapsed_seconds = current_time - pool.reward_last_updated_at_seconds; - - for (i in 0..vector::length(reward_infos)) { - let reward_info = vector::borrow_mut(reward_infos, i); - if (pool.liquidity != 0 && elapsed_seconds != 0 - && reward_info.emissions_per_second != 0 && reward_info.remaining_reward != 0 - ) { - let emitted_reward = elapsed_seconds * reward_info.emissions_per_second; - emitted_reward = math64::min(emitted_reward, reward_info.remaining_reward); - - reward_info.remaining_reward = reward_info.remaining_reward - emitted_reward; - - let growth_reward = fixed_point::u64_to_x64_u128(emitted_reward) / pool.liquidity; - reward_info.growth_global = reward_info.growth_global + growth_reward; - }; - vector::push_back(&mut reward_growths_global, reward_info.growth_global); - }; - - pool.reward_last_updated_at_seconds = current_time; - - reward_growths_global - } - - fun get_pool_reward_infos(pool: &LiquidityPool): vector { - let current_time = 0x1::timestamp::now_seconds(); - // This should 
never happen. - assert!(current_time >= pool.reward_last_updated_at_seconds, E_UNREACHABLE_CODE); - - let reward_infos = &pool.reward_infos; - - let reward_growths_global = 0x1::vector::empty(); - let elapsed_seconds = current_time - pool.reward_last_updated_at_seconds; - - for (i in 0..vector::length(reward_infos)) { - let reward_info = vector::borrow(reward_infos, i); - let reward_growth_global = reward_info.growth_global; - if (pool.liquidity != 0 && elapsed_seconds != 0 - && reward_info.emissions_per_second != 0 && reward_info.remaining_reward != 0 - ) { - let emitted_reward = elapsed_seconds * reward_info.emissions_per_second; - emitted_reward = math64::min(emitted_reward, reward_info.remaining_reward); - - let growth_reward = fixed_point::u64_to_x64_u128(emitted_reward) / pool.liquidity; - reward_growth_global = reward_info.growth_global + growth_reward; - }; - vector::push_back(&mut reward_growths_global, reward_growth_global); - }; - - reward_growths_global - } - - fun get_reward_growths_inside( - ticks: &Table, - tick_lower: u32, - tick_upper: u32, - current_tick: u32, - reward_growths_global: &vector - ): (vector) { - let tick_lower_info = table::borrow(ticks, tick_lower); - let tick_upper_info = table::borrow(ticks, tick_upper); - - // calculate reward growth below - let reward_growths_below = if (current_tick >= tick_lower) { - tick_lower_info.reward_growths_outside - } else { - sub_reward_growths(reward_growths_global, &tick_lower_info.reward_growths_outside) - }; - - // calculate fee growth above - let reward_growths_above = if (current_tick < tick_upper) { - tick_upper_info.reward_growths_outside - } else { - sub_reward_growths(reward_growths_global, &tick_upper_info.reward_growths_outside) - }; - - sub_reward_growths( - &sub_reward_growths(reward_growths_global, &reward_growths_below), - &reward_growths_above, - ) - } - - fun update_position_rewards( - position: &mut Position, - reward_growths_inside: &vector, - ) { - for (i in 
0..vector::length(reward_growths_inside)) { - let liquidity = position.liquidity; - let reward_growth_inside = *vector::borrow(reward_growths_inside, i); - - let position_reward = try_borrow_mut_reward_info(position, i); - let amount_owed = math128::mul_div( - reward_growth_inside - position_reward.reward_growth_inside_last, - liquidity, - fixed_point::q64() - ); - - position_reward.reward_growth_inside_last = reward_growth_inside; - // overflow is acceptable, have to withdraw before you hit "maximum of uint64" fees - position_reward.amount_owed = position_reward.amount_owed + (amount_owed as u64); - }; - } - - fun try_borrow_mut_reward_info(position: &mut Position, i: u64): &mut PositionRewardInfo { - let len = vector::length(&position.reward_infos); - if (i >= len) { - vector::push_back(&mut position.reward_infos, PositionRewardInfo { - reward_growth_inside_last: 0, - amount_owed: 0 - }); - }; - - vector::borrow_mut(&mut position.reward_infos, i) - } - - /// Returns the word and bit position of the tick within the bitmap. 
- fun tick_bitmap_position(tick: u32): (u16, u8) { - let word_position = ((tick >> 8) as u16); - let bit_position = ((tick % 256) as u8); - - (word_position, bit_position) - } - - fun clear_tick(ticks: &mut Table, tick: u32) { - table::remove(ticks, tick); - } - - fun assert_ticks(tick_lower: u32, tick_upper: u32, tick_spacing: u32) { - assert!(tick_lower < tick_upper, E_TICK_LOWER_MUST_LESS_THAN_TICK_UPPER); - assert!( - tick::is_spaced_tick(tick_lower, tick_spacing) && tick::is_spaced_tick(tick_upper, tick_spacing), - E_TICK_NOT_SPACED - ); - assert!(tick_upper <= tick::max_tick(), E_EXCEED_MAX_TICK); - } - - fun add_delta_liquidity(liquidity: u128, delta_liquidity: &I128): u128 { - let liqudity_after = if (i128::is_positive(delta_liquidity)) { - (liquidity as u256) + (i128::abs(delta_liquidity) as u256) - } else { - assert!(liquidity >= i128::abs(delta_liquidity), E_EXCEED_MAX_LIQUIDITY_PER_TICK); - (liquidity as u256) - (i128::abs(delta_liquidity) as u256) - }; - assert!(liqudity_after <= MAX_U128, 1); - - (liqudity_after as u128) - } - - fun get_position_mut( - positions: &mut Table, Position>, - owner: address, - position_id: u64, - ): &mut Position { - let position_key = get_position_key(&owner, position_id); - assert!(table::contains(positions, position_key), E_POSITION_NOT_EXIST); - - table::borrow_mut(positions, position_key) - } - - fun get_position_key( - owner: &address, - position_id: u64, - ): vector { - let position_key_raw_data = string::bytes( - &string_utils::format2(&b"{}-{}", *owner, position_id) - ); - aptos_hash::keccak256(*position_key_raw_data) - } - - fun get_fee_rate(trader: address, pool: &LiquidityPool): u64 { - config::get_trader_fee_rate(trader, pool.fee_rate) - } - - fun get_pool_account_address(): address acquires PoolAccountCap { - account::get_signer_capability_address(&borrow_global(@yuzuswap).signer_cap) - } - - fun get_pool_account_signer(): signer acquires PoolAccountCap { - 
account::create_signer_with_capability(&borrow_global(@yuzuswap).signer_cap) - } - - // Inline functions. - - inline fun pool_data_mut(pool: &Object): &mut LiquidityPool { - borrow_global_mut(object::object_address(pool)) - } - - inline fun pool_data(pool: &Object): &LiquidityPool { - borrow_global(object::object_address(pool)) - } - - inline fun get_pool_seed(token_0: Object, token_1: Object, fee: u64): vector { - let seed = vector[]; - vector::append(&mut seed, b"pool"); - vector::append(&mut seed, bcs::to_bytes(&object::object_address(&token_0))); - vector::append(&mut seed, bcs::to_bytes(&object::object_address(&token_1))); - vector::append(&mut seed, bcs::to_bytes(&fee)); - seed - } - - inline fun create_token_store(pool_signer: &signer, token: Object): Object { - let constructor_ref = &object::create_object_from_object(pool_signer); - fungible_asset::create_store(constructor_ref, token) - } - - inline fun borrow_tick_or_empty(ticks: &Table, tick: u32): &TickInfo { - if (table::contains(ticks, tick)) { - table::borrow(ticks, tick) - } else { - &TickInfo { - liquditiy_gross: 0, - liquidity_net: i128::zero(), - fee_growth_outside_0_x64: 0, - fee_growth_outside_1_x64: 0, - reward_growths_outside: vector[], - initialized: false, - } - } - } - - public(friend) fun get_pool_signer(pool: Object): signer acquires LiquidityPool { - object::generate_signer_for_extending(&pool_data(&pool).extend_ref) - } - - // View functions. 
- - #[view] - public fun get_pool( - token_x: Object, - token_y: Object, - fee: u64, - ): Object acquires PoolAccountCap { - object::address_to_object(get_pool_address(token_x, token_y, fee)) - } - - #[view] - public fun get_all_pools(): vector> acquires LiquidityPools { - borrow_global(@yuzuswap).all_pools - } - - #[view] - public fun count_pool(): u64 acquires LiquidityPools { - vector::length(&borrow_global(@yuzuswap).all_pools) - } - - #[view] - public fun get_pool_info( - pool: Object, - ): ( - Object, - Object, - u128, - u32, - u128, - u64, - u32, - ) - acquires LiquidityPool { - let pool_data = pool_data(&pool); - - ( - fungible_asset::store_metadata(pool_data.token_0_reserve), - fungible_asset::store_metadata(pool_data.token_1_reserve), - pool_data.current_sqrt_price, - pool_data.current_tick, - pool_data.liquidity, - pool_data.fee_rate, - pool_data.tick_spacing, - ) - } - - #[view] - public fun rewards_count( - pool: Object, - ): u64 - acquires LiquidityPool { - vector::length(&pool_data(&pool).reward_infos) - } - - struct TickView has copy, drop, store { - tick: u32, - liquidity_gross: u128, - liquidity_net: I128, - fee_growth_outside_0_x64: u128, - fee_growth_outside_1_x64: u128, - reward_growths_outside: vector, - } - - #[view] - public fun get_ticks( - pool: Object, - start_tick: u32, - limit: u32, - ): (vector) - acquires LiquidityPool { - let pool_data = pool_data(&pool); - - let tick_count = 0; - let ticks = vector[]; - let max_count = (((tick::max_tick() / pool_data.tick_spacing) >> 8) as u16); - - let tick_spacing = pool_data.tick_spacing; - let tick_adjustment = tick::tick_adjustment(tick_spacing); - let adjusted_start_tick = if (start_tick >= tick_adjustment) start_tick - tick_adjustment else 0; - let compessed_start_tick = (math64::ceil_div((adjusted_start_tick as u64), (tick_spacing as u64)) as u32); - - let i = ((compessed_start_tick >> 8) as u16); - let first_position_in_word = compessed_start_tick % 256; - while (i <= max_count && 
tick_count < limit) { - let word = *table::borrow_with_default(&pool_data.tick_bitmap, i, &0); - if (word == 0) { - i = i + 1; - continue - }; - - let mask: u256; - let j = first_position_in_word; - while (j < 256) { - mask = 1 << (j as u8); - - if (mask & word != 0) { - let tick: u32 = (((i as u32) << 8) + j) * pool_data.tick_spacing + tick_adjustment; - let tick_info = table::borrow(&pool_data.ticks, tick); - - vector::push_back(&mut ticks, TickView { - tick, - liquidity_gross: tick_info.liquditiy_gross, - liquidity_net: tick_info.liquidity_net, - fee_growth_outside_0_x64: tick_info.fee_growth_outside_0_x64, - fee_growth_outside_1_x64: tick_info.fee_growth_outside_1_x64, - reward_growths_outside: tick_info.reward_growths_outside, - }); - tick_count = tick_count + 1; - - if (tick_count >= limit) { - break - }; - }; - - j = j + 1; - }; - first_position_in_word = 0; - - i = i + 1; - }; - - return ticks - } - - #[view] - public fun get_pool_address( - token_x: Object, - token_y: Object, - fee: u64, - ): address acquires PoolAccountCap { - object::create_object_address(&get_pool_account_address(), get_pool_seed(token_x, token_y, fee)) - } - - #[view] - public fun get_tokens( - pool: Object, - ): (Object, Object) - acquires LiquidityPool { - ( - fungible_asset::store_metadata(pool_data(&pool).token_0_reserve), - fungible_asset::store_metadata(pool_data(&pool).token_1_reserve), - ) - } - - #[view] - public fun get_reserves_size(pool: Object): (u64, u64) acquires LiquidityPool { - let pool_data = pool_data_mut(&pool); - - let amount_0 = fungible_asset::balance(pool_data.token_0_reserve); - let amount_1 = fungible_asset::balance(pool_data.token_1_reserve); - - (amount_0, amount_1) - } - - // Struct for view purpose. 
- struct LiquidityPoolView has drop { - pool_addr: address, - token_0: address, - token_1: address, - token_0_decimals: u8, - token_1_decimals: u8, - token_0_reserve: u64, - token_1_reserve: u64, - current_tick: u32, - current_sqrt_price: u128, - liquidity: u128, - fee_growth_global_0_x64: u128, - fee_growth_global_1_x64: u128, - reward_infos: vector, - fee_rate: u64, - tick_spacing: u32, - } - - #[view] - public fun get_pool_view(pool: Object): LiquidityPoolView acquires LiquidityPool { - map_pool_view(&pool) - } - - #[view] - public fun get_pool_views( - offset: u64, - limit: u64, - ): vector - acquires LiquidityPool, LiquidityPools { - let pools = &borrow_global(@yuzuswap).all_pools; - - let pool_views = vector[]; - for (i in offset..math64::min(vector::length(pools), offset + limit)) { - let pool = vector::borrow(pools, i); - - vector::push_back(&mut pool_views, map_pool_view(pool)); - }; - - return pool_views - } - - fun map_pool_view(pool: &Object): LiquidityPoolView acquires LiquidityPool { - let pool_data = pool_data(pool); - - let token_0_metadata = fungible_asset::store_metadata(pool_data.token_0_reserve); - let token_1_metadata = fungible_asset::store_metadata(pool_data.token_1_reserve); - - LiquidityPoolView { - pool_addr: object::object_address(pool), - token_0: object::object_address( &token_0_metadata), - token_1: object::object_address( &token_1_metadata), - token_0_decimals: fungible_asset::decimals(token_0_metadata), - token_1_decimals: fungible_asset::decimals(token_1_metadata), - token_0_reserve: fungible_asset::balance(pool_data.token_0_reserve), - token_1_reserve: fungible_asset::balance(pool_data.token_1_reserve), - current_tick: pool_data.current_tick, - current_sqrt_price: pool_data.current_sqrt_price, - liquidity: pool_data.liquidity, - fee_growth_global_0_x64: pool_data.fee_growth_global_0_x64, - fee_growth_global_1_x64: pool_data.fee_growth_global_1_x64, - reward_infos: pool_data.reward_infos, - fee_rate: pool_data.fee_rate, - 
tick_spacing: pool_data.tick_spacing, - } - } - - #[view] - public fun get_position( - owner: address, - pool: Object, - position_id: u64, - ): Position - acquires LiquidityPool { - let position = table::borrow(&pool_data(&pool).positions, get_position_key(&owner, position_id)); - - Position { - id: position.id, - tick_lower: position.tick_lower, - tick_upper: position.tick_upper, - liquidity: position.liquidity, - fee_growth_inside_0_last_x64: position.fee_growth_inside_0_last_x64, - fee_growth_inside_1_last_x64: position.fee_growth_inside_1_last_x64, - tokens_owed_0: position.tokens_owed_0, - tokens_owed_1: position.tokens_owed_1, - reward_infos: position.reward_infos, - } - } - - #[view] - public fun get_position_with_pending_fees_and_rewards( - owner: address, - pool: Object, - position_id: u64, - ): Position acquires LiquidityPool { - let position = get_position(owner, pool, position_id); - let pool_data = pool_data(&pool); - - let position_view = Position { - id: position.id, - liquidity: position.liquidity, - tick_lower: position.tick_lower, - tick_upper: position.tick_upper, - fee_growth_inside_0_last_x64: position.fee_growth_inside_0_last_x64, - fee_growth_inside_1_last_x64: position.fee_growth_inside_1_last_x64, - tokens_owed_0: position.tokens_owed_0, - tokens_owed_1: position.tokens_owed_1, - reward_infos: position.reward_infos, - }; - - if (position.liquidity == 0) { - return position_view - }; - - let (fee_growth_inside_0_x64, fee_growth_inside_1_x64) = get_fee_growth_inside_tick( - &pool_data.ticks, - position.tick_lower, - position.tick_upper, - pool_data.current_tick, - pool_data.fee_growth_global_0_x64, - pool_data.fee_growth_global_1_x64, - ); - update_position_fee(&mut position_view, fee_growth_inside_0_x64, fee_growth_inside_1_x64); - - let reward_growths_global = get_pool_reward_infos(pool_data); - let reward_growths_inside = get_reward_growths_inside( - &pool_data.ticks, - position.tick_lower, - position.tick_upper, - pool_data.current_tick, 
- &reward_growths_global, - ); - update_position_rewards(&mut position_view, &reward_growths_inside); - - position_view - } - - #[view] - public fun get_position_info( - owner: address, - pool: Object, - position_id: u64, - ): (u128, u32, u32, u64, u64) acquires LiquidityPool { - let pool_data = pool_data_mut(&pool); - let position = get_position_mut(&mut pool_data.positions, owner, position_id); - - ( - position.liquidity, - position.tick_lower, - position.tick_upper, - position.tokens_owed_0, - position.tokens_owed_1, - ) - } - - #[view] - public fun quote_swap( - trader: address, - pool: Object, - zero_for_one: bool, - is_exact_in: bool, - specified_amount: u64, - sqrt_price_limit: u128, - ): (u64, u64, u64) acquires LiquidityPool { - assert!(specified_amount > 0, E_SWAP_AMOUNT_MUST_GREATER_THAN_ZERO); - - let pool_data = pool_data(&pool); - - if (zero_for_one) { - assert!( - sqrt_price_limit < pool_data.current_sqrt_price - && sqrt_price_limit >= tick_math::min_sqrt_price(), - E_INVALID_LIMIT_SQRT_PRICE, - ); - } else { - assert!( - sqrt_price_limit > pool_data.current_sqrt_price - && sqrt_price_limit <= tick_math::max_sqrt_price(), - E_INVALID_LIMIT_SQRT_PRICE, - ) - }; - - let tick_spacing = pool_data.tick_spacing; - let current_sqrt_price = pool_data.current_sqrt_price; - let current_tick = pool_data.current_tick; - let liquidity = pool_data.liquidity; - let remaining_amount = specified_amount; - let calculated_amount = 0; - let total_fee_amount = 0; - let fee_rate = get_fee_rate(trader, pool_data); - - while (remaining_amount > 0 && current_sqrt_price != sqrt_price_limit) { - let price_sqrt_start = current_sqrt_price; - - let (tick_next, is_initialized_tick) = tick_bitmap::get_next_initialized_tick_within_one_word( - &pool_data.tick_bitmap, - current_tick, - tick_spacing, - zero_for_one, - ); - - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if (tick_next < tick::min_tick()) { - tick_next = 
tick::min_tick(); - } else if (tick_next > tick::max_tick()) { - tick_next = tick::max_tick(); - }; - - let sqrt_price_next = tick_math::get_sqrt_price_at_tick(tick_next); - - let target_sqrt_price = if (zero_for_one) { - math128::max(sqrt_price_limit, sqrt_price_next) - } else { - math128::min(sqrt_price_limit, sqrt_price_next) - }; - let (sqrt_price, amount_in, amount_out, fee_amount) = swap_math::compute_swap_step( - current_sqrt_price, - target_sqrt_price, - liquidity, - remaining_amount, - is_exact_in, - fee_rate, - ); - current_sqrt_price = sqrt_price; - - if (is_exact_in) { - remaining_amount = remaining_amount - (amount_in + fee_amount); - calculated_amount = calculated_amount + amount_out; - } else { - remaining_amount = remaining_amount - amount_out; - calculated_amount = calculated_amount + (amount_in + fee_amount); - }; - - total_fee_amount = total_fee_amount + fee_amount; - - // shift tick if we reached the next price - if (sqrt_price == sqrt_price_next) { - if (is_initialized_tick) { - let next_liquidity_net = table::borrow(&pool_data.ticks, tick_next).liquidity_net; - - if (zero_for_one) { - // next_liquidity_net = -next_liquidity_net - next_liquidity_net = i128::new( - i128::abs(&next_liquidity_net), - !i128::is_negative(&next_liquidity_net), - ); - }; - - liquidity = add_delta_liquidity(liquidity, &next_liquidity_net); - }; - - current_tick = if (zero_for_one) tick_next - 1 else tick_next; - } else if (current_sqrt_price != price_sqrt_start) { - current_tick = tick_math::get_tick_at_sqrt_price(sqrt_price); - } - }; - - let (amount_in, amount_out) = if (is_exact_in) { - (specified_amount - remaining_amount, calculated_amount) - } else { - (calculated_amount, specified_amount - remaining_amount) - }; - - (amount_in, amount_out, total_fee_amount) - } - - // Tests. 
- - #[test_only] - friend yuzuswap::test_pool; - #[test_only] - friend yuzuswap::liquidity_pool_tests; - #[test_only] - friend yuzuswap::liquidity_pool_liquidity_tests; - #[test_only] - friend yuzuswap::liqudity_pool_swap_tests; - #[test_only] - friend yuzuswap::liquidity_pool_reward_tests; - #[test_only] - friend yuzuswap::position_nft_manager_tests; - #[test_only] - friend yuzuswap::router_tests; - #[test_only] - friend yuzuswap::router_swap_tests; - - #[test_only] - public fun initialize_for_test(owner: &signer) { - init_module(owner); - } - - #[test_only] - public fun get_fee(pool: Object): u64 acquires LiquidityPool { - pool_data(&pool).fee_rate - } - - #[test_only] - public fun get_tick_spacing(pool: Object): u32 acquires LiquidityPool { - pool_data(&pool).tick_spacing - } - - #[test_only] - public fun get_current_tick(pool: Object): u32 acquires LiquidityPool { - pool_data(&pool).current_tick - } - - #[test_only] - public fun get_current_sqrt_price(pool: Object): u128 acquires LiquidityPool { - pool_data(&pool).current_sqrt_price - } - - #[test_only] - public fun get_liquidity(pool: Object): u128 acquires LiquidityPool { - pool_data(&pool).liquidity - } - - #[test_only] - public fun get_fee_growth_global(pool: Object): (u128, u128) acquires LiquidityPool { - let pool_data = pool_data(&pool); - - ( - pool_data.fee_growth_global_0_x64, - pool_data.fee_growth_global_1_x64, - ) - } - - #[test_only] - public fun is_position_exists( - owner: address, - pool: Object, - position_id: u64, - ): bool - acquires LiquidityPool { - let pool_data = pool_data(&pool); - let position_key = get_position_key(&owner, position_id); - - table::contains(&pool_data.positions, position_key) - } - - #[test_only] - public fun extract_position( - position: &Position, - ): ( - u64, u32, u32, u128, u128, u128, u64, u64, - vector, - ) { - ( - position.id, - position.tick_lower, - position.tick_upper, - position.liquidity, - position.fee_growth_inside_0_last_x64, - 
@dataclass
class LogCtxData:
    """One log entry carrying English text, Chinese text and a detail string.

    Every field defaults to the empty string, so an instance can be built
    with any subset of the three values.
    """

    txtENG: str = ""
    txtCN: str = ""
    detail: str = ""

    @property
    def source(self) -> str:
        """Return ``txtENG`` when it is non-empty, otherwise ``txtCN``."""
        if self.txtENG:
            return self.txtENG
        return self.txtCN

    @property
    def source_lang(self) -> str:
        """Return ``"EN"`` when ``txtENG`` is non-empty, otherwise ``"CN"``."""
        if not self.txtENG:
            return "CN"
        return "EN"
class SolidityParser:
    """Collect subcontract entries from a token stream."""

    def __init__(self):
        pass

    def parse_stream(self, stream: CommonTokenStream) -> list:
        """Scan ``stream.tokens`` and return the entry of every subcontract found.

        Each token is stringified and passed to ``parseToken``. When the id it
        returns equals ``SUBCONTRACT_ID``, scanning is handed to
        ``getSubcontract`` starting at the following token; that call supplies
        the index to resume from and the entry that is appended to the result.
        Any other token is skipped.

        Args:
            stream: Object whose ``tokens`` sequence is scanned.

        Returns:
            The subcontract entries, in the order they were encountered.
        """
        entries = []
        tokens = stream.tokens
        max_length = len(tokens)
        i = 0

        while i < max_length:
            token_id, _content, loc = parseToken(str(tokens[i]))
            if token_id == SUBCONTRACT_ID:
                # getSubcontract reports the index at which to resume.
                # NOTE(review): assumes that index is always greater than the
                # current one; otherwise this loop cannot make progress.
                i, _subcontract, entry = getSubcontract(
                    i + 1, tokens, max_length, token_id, loc
                )
                entries.append(entry)
            else:
                # The index used to stay unchanged on this path, so the first
                # token that did not open a subcontract made the loop spin
                # forever. Step past it instead.
                i += 1

        return entries
class TranslationToggle(Switch):
    """Switch subclass carrying a "translate on/off" caption.

    NOTE(review): the body mixes a ``Switch`` with table-style calls
    (``add_column`` / ``addRow``). Presumably these were carried over from a
    table widget -- confirm the base class provides them before using this
    widget.
    """

    def __init__(self, title: str = "🌐 TRANSLATE OFF") -> None:
        # NOTE(review): ``title`` is forwarded as the first positional
        # argument of the base constructor; presumably that slot is the
        # switch's boolean value rather than a caption -- TODO confirm.
        super().__init__(title)
        # Reuse the caption as the border title.
        self.border_title = title
        # NOTE(review): looks like a table-widget call; verify that the base
        # class defines ``add_column``.
        self.add_column("Content")

    def add_row(self, text: str) -> None:
        """Forward ``text`` to the parent class's ``addRow``."""
        # NOTE(review): camelCase ``addRow`` does not match this method's own
        # snake_case name -- confirm the parent method actually exists.
        super().addRow(text)

    def on_change(self, value: bool) -> None:
        """Set ``self.label`` to the ON caption when ``value`` is true, else OFF."""
        # NOTE(review): nothing in this class calls ``on_change``; presumably
        # it is meant to run when the switch flips -- verify the handler name
        # the framework dispatches to.
        self.label = "🌐 TRANSLATE ON" if value else "🌐 TRANSLATE OFF"
迭代 + +## 3.1 迭代的第一步:validation与检测的粒度 + +首先迭代的是validation过程和检测的粒度,在整个工作流中我发现,它与我个人的工作流相差最大的就是我在提问的时候是从来不进行上下文考虑的 + +换句话说,当时的上下文(128k)足够我粗暴的把整个合约扔进去直接提问,但是实际上在工具中,我却用了函数,这种上下文的差别导致了一点,就是它的检出性非常的低 + +我曾经考虑过使用一些静态的方法进行一个合理的上下文提取,比如说针对某个函数,使用slither提取相应的上下文,但是并没有采用(后面会解释),而是使用了antlrv4 + +使用antlrv4,我将每一个文件进行了一个行为,称之为业务流的构建 + +业务流的构建的想法,曾经跟yuqiang讨论过,当时他提到,使用slither进行变量读写的提取或者函数调用的提取其实并不好,相应的,我便放弃了slither(这是其中一个原因),转而采用直接gpt提问的方式来抽取业务流,抽取业务流时候,使用的prompt为: + +``` +Based on the code above, analyze the business flows that start with the {function_name} function, consisting of multiple function calls. The analysis should adhere to the following requirements: + 1. only output the one sub-business flows, and must start from {function_name}. + 2. The output business flows should only involve the list of functions of the contract itself (ignoring calls to other contracts or interfaces, as well as events). + 3. After step-by-step analysis, output one result in JSON format, with the structure: {{"{function_name}":[function1,function2,function3....]}} + 4. 
The business flows must include all involved functions without any omissions +``` +通过这种方式提取业务流然后进行提问,基于此,检测的粒度变得相对reasonable,但是仍然存在问题,现在仍未解决 + +接下来,validation也是同样的道理,不过validation最开始的处理方案是使用多种prompt(比如说提问是否有漏洞/提问是否有patch),经过大批量测试后发现效果并不好,后来发现实际上关键问题还是在于上下文的完整性,毕竟你不能让llm去验证一些它完全不知道的事情 + +## 3.2 迭代第二步:上下文 + +对于上下文的问题,我曾经想过粗暴的来扩展prompt输入解决,但是实际证明效果并不好,因为当项目过大的时候,你并不能直接把整个项目扔进去,而需要进行一个合理的上下文提取。 + +在这个时候,cursor出现了,我算是cursor的第一批用户,它的codebase QA功能让我眼前一亮,我意识到,我需要一个合理的上下文提取,而cursor的codebase QA功能恰好可以满足这一点 + +因此,我详细调查了一下cursor的codebase QA功能,发现它实际上是一个基于问题的RAG抽取,这好办了,RAG我熟,之前在llm4vuln中进行dev的时候用了非常多的RAG + +因此自己实现一个codebaseQA的中间组件,便成了很重要的一步,现在它是这样的: + +![alt text](./pics/image.png) + +简单来说,这个组件的功能就是通过对codebase,也就是项目代码的预处理,然后根据某一个漏洞输入,提取出跟这个漏洞最相关的一个上下文 + +在这之中,我自己实现了一个相对复杂的call tree,以及覆盖整个项目code的RAG,最终它会产出一个上下文text,这个text我称之为context funnel(上下文漏斗) + +基于这个context funnel,我便可以进行一个合理的上下文提取,然后进行漏洞的validation + +## 3.3 迭代第三步:模型的选择 + +在整个项目的开发之中,包括3.1和3.2,我尝试过几乎所有模型,包含了gpt全家桶,claude全家桶,以及最新的o1,o3,和r1,在实际的工程实现中,我们并不能说哪个漏洞检测模型的使用是绝对的,它有很多因素要考虑,其中最重要的两个就是: + +时间和成本 + +因为目标是一个可以直接产品化的工具,因此并不能说越强大的模型就越要用,并不能说reasoning模型(就是o1,o3,r1这些会有cot的模型)就一定要比unreasoning模型要好,因为时间和成本问题,你不可能要求一个task花费几分钟甚至接近10分钟才能得到答案,因此当前的模型选择是: + +检测用claude,validation用deepseek o1 + +## 3.4 迭代第四步:回到检测 + +感谢ret2basic的启发,我意识到,最初始的prompt并非完美,一个完全依赖于大模型能力的prompt,高度受限于大模型本身的训练数据,即使你能够触发它最强大的推理和逻辑能力,但有些漏洞它是不会注意到的 + +因此,回到最初的prompt,我需要一个checklist,那么它从哪里来 + +感谢solodit,我爬取了24000个审计漏洞,感谢dacien整理的那么多checklist,我对这些漏洞和checklist进行了处理,把它们加入到了prompt里,并形成了一个简单的组件,称之为S.P.A.R.T(Smart Prompting for Automated Risk Tailoring,自动化风险prompt组合) + +这样,检测的prompt似乎就变得完整了,当然还有很多需要调整的 + +# 4.
疑惑的解答:为什么使用antlrv4而不是slither + +这个引擎有一些预设的条件,其中有一条比较关键的就是:"最好的编译器是大模型" + +我并非否认slither的强大,我也曾经花费非常多的时间在slither上并开发了很多规则和优化,但是slither面临一个问题就是它基于了solc,并且变得越来越重,对项目的完整性要求的越来越高,这不但不利于我进行测试,更不利于我扩展到更多语言 + +而antlrv4则没有这个问题,或者换句话说,一个简单的函数拆解,然后基于拆解的各种组合,不但可以满足各类的业务流抽取需求,更可以将相同的代码和架构扩展到更多的语言上,毕竟你只需要让代码进行函数粒度拆解就可以了 + +因此,这个引擎实际当前支持的语言非常多:不仅仅包括solidity,甚至包括solidity反编译的伪代码,rust,move,go,python,甚至可以扩展到任何语言 + +# 5. 设计哲学 + +这一点才是最重要的一个部分,不同于市面上大多数检测工具,不同于我之前和yuqiang,daoyuan,liuye的工作,它并非是基于传统的确定性模式识别 + +传统的确定性模式识别需要对漏洞有非常细粒度的拆解和定义,然后通过大量的数据集作为打底,然后进行模式识别,不论是gptscan也好,propertygpt也好,还是llm4vuln也好,都是基于这个思路 + +这种思路都是基于专家经验,将漏洞解读为cot,解读为functionality和keyconcept,进行匹配,引导大模型进行一步步思考,但这带来了一个问题: + +1. 它需要大量的专家经验和数据集,而专家经验是稀缺的,因此很难覆盖到所有漏洞 +2. 它是规则驱动的,意味着它需要大量的规则,而规则的开发和维护成本非常高,就像gptscan一样,天知道yuqiang为了写那11个detector花费了多少精力:P +3. 它是规则驱动的,而这就意味着规则越严格,就越难覆盖到所有漏洞,人力会回到无限的规则回测中去 +4. 在整个过程中需要经历极多次的llm提问,在这里感谢alan@secure3的文章,它让我意识到,这种复杂流程中,会产生一个致命问题:错误累积定律,简单来说就是,错误会随着流程的复杂度指数级增长,当llm action为30时,即使每一个action都有99%的准确率,最终的准确率会降低到73%,而实际上当前的llm准确率是惨不忍睹的,简单粗暴的猜测一下,每一个action大概只有60-70%的准确率,这也就意味着只需要经过10个action,准确率就会来到2%,即使是5个action,也只有16%的准确率,这种情况下,完全无法接受,导致长链检测任务必然存在准确率天花板 + +因此,为了解决这个问题,至少在ai audit领域内,要跳出agent action link的限制,通过以下方法论 + +从"寻找正确答案"转向"管理可能性空间" +A[代码输入] --> B{可能性空间构建} +B --> C[漏洞假说云] +C --> D[验证收敛] +D --> E[确定性结论] +(这里可能有更多的理论基础等待探索,后面会逐步扩展) + +这种方式好处在于,我们不需要花费大力气去构建一个非常复杂的规则,而是通过构建一个可能性空间,然后通过验证收敛,最终得到一个确定性结论 + +另外,这种方式可以非常有效的解决错误累积定律的问题,我通过降低复杂度,从而降低错误累积定律的影响 \ No newline at end of file diff --git a/pics/image.png b/pics/image.png new file mode 100644 index 00000000..c9644321 Binary files /dev/null and b/pics/image.png differ diff --git "a/pics/\346\234\252\345\221\275\345\220\215\346\226\207\344\273\266(48).png" "b/pics/\346\234\252\345\221\275\345\220\215\346\226\207\344\273\266(48).png" new file mode 100644 index 00000000..2e1bf8db Binary files /dev/null and "b/pics/\346\234\252\345\221\275\345\220\215\346\226\207\344\273\266(48).png" differ diff --git a/prompt.txt b/prompt.txt deleted file 
mode 100644 index b581d70f..00000000 --- a/prompt.txt +++ /dev/null @@ -1,309 +0,0 @@ -:function _updateCreditDelegations( - Data storage self, - uint128[] memory connectedMarketsIdsCache, - bool shouldRehydrateCache - ) - private - returns (uint128[] memory rehydratedConnectedMarketsIdsCache, SD59x18 vaultCreditCapacityUsdX18) - { - rehydratedConnectedMarketsIdsCache = new uint128[](connectedMarketsIdsCache.length); - // cache the vault id - uint128 vaultId = self.id; - - // cache the connected markets length - uint256 connectedMarketsConfigLength = self.connectedMarkets.length; - - // loads the connected markets storage pointer by taking the last configured market ids uint set - EnumerableSet.UintSet storage connectedMarkets = self.connectedMarkets[connectedMarketsConfigLength - 1]; - - // loop over each connected market id that has been cached once again in order to update this vault's - // credit delegations - for (uint256 i; i < connectedMarketsIdsCache.length; i++) { - // rehydrate the markets ids cache if needed - if (shouldRehydrateCache) { - rehydratedConnectedMarketsIdsCache[i] = connectedMarkets.at(i).toUint128(); - } else { - rehydratedConnectedMarketsIdsCache[i] = connectedMarketsIdsCache[i]; - } - - // loads the memory cached market id - uint128 connectedMarketId = rehydratedConnectedMarketsIdsCache[i]; - - // load the credit delegation to the given market id - CreditDelegation.Data storage creditDelegation = CreditDelegation.load(vaultId, connectedMarketId); - - // cache the previous credit delegation value - UD60x18 previousCreditDelegationUsdX18 = ud60x18(creditDelegation.valueUsd); - - // cache the latest credit delegation share of the vault's credit capacity - uint128 totalCreditDelegationWeightCache = self.totalCreditDelegationWeight; - - if (totalCreditDelegationWeightCache != 0) { - // get the latest credit delegation share of the vault's credit capacity - UD60x18 creditDelegationShareX18 = - 
ud60x18(creditDelegation.weight).div(ud60x18(totalCreditDelegationWeightCache)); - - // stores the vault's total credit capacity to be returned - vaultCreditCapacityUsdX18 = getTotalCreditCapacityUsd(self); - - // if the vault's credit capacity went to zero or below, we set its credit delegation to that market - // to zero - UD60x18 newCreditDelegationUsdX18 = vaultCreditCapacityUsdX18.gt(SD59x18_ZERO) - ? vaultCreditCapacityUsdX18.intoUD60x18().mul(creditDelegationShareX18) - : UD60x18_ZERO; - - // calculate the delta applied to the market's total delegated credit - UD60x18 creditDeltaUsdX18 = newCreditDelegationUsdX18.sub(previousCreditDelegationUsdX18); - - // loads the market's storage pointer and update total delegated credit - Market.Data storage market = Market.load(connectedMarketId); - market.updateTotalDelegatedCredit(creditDeltaUsdX18); - - // if new credit delegation is zero, we clear the credit delegation storage - if (newCreditDelegationUsdX18.isZero()) { - creditDelegation.clear(); - } else { - // update the credit delegation stored usd value - creditDelegation.valueUsd = newCreditDelegationUsdX18.intoUint128(); - } - } - } - } -function updateVaultAndCreditDelegationWeight( - Data storage self, - uint128[] memory connectedMarketsIdsCache - ) - internal - { - // cache the connected markets length - uint256 connectedMarketsConfigLength = self.connectedMarkets.length; - - // loads the connected markets storage pointer by taking the last configured market ids uint set - EnumerableSet.UintSet storage connectedMarkets = self.connectedMarkets[connectedMarketsConfigLength - 1]; - - // get the total of shares - uint128 newWeight = uint128(IERC4626(self.indexToken).totalAssets()); - - for (uint256 i; i < connectedMarketsIdsCache.length; i++) { - // load the credit delegation to the given market id - CreditDelegation.Data storage creditDelegation = - CreditDelegation.load(self.id, connectedMarkets.at(i).toUint128()); - - // update the credit delegation weight 
- creditDelegation.weight = newWeight; - } - - // update the vault weight - self.totalCreditDelegationWeight = newWeight; - } -function _recalculateConnectedMarketsState( - Data storage self, - uint128[] memory connectedMarketsIdsCache, - bool shouldRehydrateCache - ) - private - returns ( - uint128[] memory rehydratedConnectedMarketsIdsCache, - SD59x18 vaultTotalRealizedDebtChangeUsdX18, - SD59x18 vaultTotalUnrealizedDebtChangeUsdX18, - UD60x18 vaultTotalUsdcCreditChangeX18, - UD60x18 vaultTotalWethRewardChangeX18 - ) - { - RecalculateConnectedMarketsState_Context memory ctx; - rehydratedConnectedMarketsIdsCache = new uint128[](connectedMarketsIdsCache.length); - - // cache the vault id - ctx.vaultId = self.id; - - // cache the connected markets length - uint256 connectedMarketsConfigLength = self.connectedMarkets.length; - - // loads the connected markets storage pointer by taking the last configured market ids uint set - EnumerableSet.UintSet storage connectedMarkets = self.connectedMarkets[connectedMarketsConfigLength - 1]; - - for (uint256 i; i < connectedMarketsIdsCache.length; i++) { - if (shouldRehydrateCache) { - rehydratedConnectedMarketsIdsCache[i] = connectedMarkets.at(i).toUint128(); - } else { - rehydratedConnectedMarketsIdsCache[i] = connectedMarketsIdsCache[i]; - } - - // loads the market storage pointer - Market.Data storage market = Market.load(rehydratedConnectedMarketsIdsCache[i]); - - // first we cache the market's unrealized and realized debt - ctx.marketUnrealizedDebtUsdX18 = market.getUnrealizedDebtUsd(); - ctx.marketRealizedDebtUsdX18 = market.getRealizedDebtUsd(); - - // if market has debt distribute it - if (!ctx.marketUnrealizedDebtUsdX18.isZero() || !ctx.marketRealizedDebtUsdX18.isZero()) { - // distribute the market's debt to its connected vaults - market.distributeDebtToVaults(ctx.marketUnrealizedDebtUsdX18, ctx.marketRealizedDebtUsdX18); - } - - // load the credit delegation to the given market id - CreditDelegation.Data storage 
creditDelegation = - CreditDelegation.load(ctx.vaultId, rehydratedConnectedMarketsIdsCache[i]); - - // prevent division by zero - if (!market.getTotalDelegatedCreditUsd().isZero()) { - // get the vault's accumulated debt, credit and reward changes from the market to update its stored - // values - ( - ctx.realizedDebtChangeUsdX18, - ctx.unrealizedDebtChangeUsdX18, - ctx.usdcCreditChangeX18, - ctx.wethRewardChangeX18 - ) = market.getVaultAccumulatedValues( - ud60x18(creditDelegation.valueUsd), - sd59x18(creditDelegation.lastVaultDistributedRealizedDebtUsdPerShare), - sd59x18(creditDelegation.lastVaultDistributedUnrealizedDebtUsdPerShare), - ud60x18(creditDelegation.lastVaultDistributedUsdcCreditPerShare), - ud60x18(creditDelegation.lastVaultDistributedWethRewardPerShare) - ); - } - - // if there's been no change in any of the returned values, we can iterate to the next - // market id - if ( - ctx.realizedDebtChangeUsdX18.isZero() && ctx.unrealizedDebtChangeUsdX18.isZero() - && ctx.usdcCreditChangeX18.isZero() && ctx.wethRewardChangeX18.isZero() - ) { - continue; - } - - // update the vault's state by adding its share of the market's latest state variables - vaultTotalRealizedDebtChangeUsdX18 = vaultTotalRealizedDebtChangeUsdX18.add(ctx.realizedDebtChangeUsdX18); - vaultTotalUnrealizedDebtChangeUsdX18 = - vaultTotalUnrealizedDebtChangeUsdX18.add(ctx.unrealizedDebtChangeUsdX18); - vaultTotalUsdcCreditChangeX18 = vaultTotalUsdcCreditChangeX18.add(ctx.usdcCreditChangeX18); - vaultTotalWethRewardChangeX18 = vaultTotalWethRewardChangeX18.add(ctx.wethRewardChangeX18); - - // update the last distributed debt, credit and reward values to the vault's credit delegation to the - // given market id, in order to keep next calculations consistent - creditDelegation.updateVaultLastDistributedValues( - sd59x18(market.realizedDebtUsdPerVaultShare), - sd59x18(market.unrealizedDebtUsdPerVaultShare), - ud60x18(market.usdcCreditPerVaultShare), - ud60x18(market.wethRewardPerVaultShare) - 
); - } - } -function load(uint128 vaultId) internal pure returns (Data storage vault) { - bytes32 slot = keccak256(abi.encode(VAULT_LOCATION, vaultId)); - assembly { - vault.slot := slot - } - } -function recalculateVaultsCreditCapacity(uint256[] memory vaultsIds) internal { - for (uint256 i; i < vaultsIds.length; i++) { - // uint256 -> uint128 - uint128 vaultId = vaultsIds[i].toUint128(); - - // load the vault storage pointer - Data storage self = load(vaultId); - - // make sure there are markets connected to the vault - uint256 connectedMarketsConfigLength = self.connectedMarkets.length; - if (connectedMarketsConfigLength == 0) continue; - - // loads the connected markets storage pointer by taking the last configured market ids uint set - EnumerableSet.UintSet storage connectedMarkets = self.connectedMarkets[connectedMarketsConfigLength - 1]; - - // cache the connected markets ids to avoid multiple storage reads, as we're going to loop over them twice - // at `_recalculateConnectedMarketsState` and `_updateCreditDelegations` - uint128[] memory connectedMarketsIdsCache = new uint128[](connectedMarkets.length()); - - // update vault and credit delegation weight - updateVaultAndCreditDelegationWeight(self, connectedMarketsIdsCache); - - // iterate over each connected market id and distribute its debt so we can have the latest credit - // delegation of the vault id being iterated to the provided `marketId` - ( - uint128[] memory updatedConnectedMarketsIdsCache, - SD59x18 vaultTotalRealizedDebtChangeUsdX18, - SD59x18 vaultTotalUnrealizedDebtChangeUsdX18, - UD60x18 vaultTotalUsdcCreditChangeX18, - UD60x18 vaultTotalWethRewardChangeX18 - ) = _recalculateConnectedMarketsState(self, connectedMarketsIdsCache, true); - - // gas optimization: only write to storage if values have changed - // - // updates the vault's stored unsettled realized debt distributed from markets - if (!vaultTotalRealizedDebtChangeUsdX18.isZero()) { - self.marketsRealizedDebtUsd = 
sd59x18(self.marketsRealizedDebtUsd).add( - vaultTotalRealizedDebtChangeUsdX18 - ).intoInt256().toInt128(); - } - - // updates the vault's stored unrealized debt distributed from markets - if (!vaultTotalUnrealizedDebtChangeUsdX18.isZero()) { - self.marketsUnrealizedDebtUsd = sd59x18(self.marketsUnrealizedDebtUsd).add( - vaultTotalUnrealizedDebtChangeUsdX18 - ).intoInt256().toInt128(); - } - - // adds the vault's total USDC credit change, earned from its connected markets, to the - // `depositedUsdc` variable - if (!vaultTotalUsdcCreditChangeX18.isZero()) { - self.depositedUsdc = ud60x18(self.depositedUsdc).add(vaultTotalUsdcCreditChangeX18).intoUint128(); - } - - // distributes the vault's total WETH reward change, earned from its connected markets - if (!vaultTotalWethRewardChangeX18.isZero() && self.wethRewardDistribution.totalShares != 0) { - SD59x18 vaultTotalWethRewardChangeSD59X18 = - sd59x18(int256(vaultTotalWethRewardChangeX18.intoUint256())); - self.wethRewardDistribution.distributeValue(vaultTotalWethRewardChangeSD59X18); - } - - // update the vault's credit delegations - (, SD59x18 vaultNewCreditCapacityUsdX18) = - _updateCreditDelegations(self, updatedConnectedMarketsIdsCache, false); - - emit LogUpdateVaultCreditCapacity( - vaultId, - vaultTotalRealizedDebtChangeUsdX18.intoInt256(), - vaultTotalUnrealizedDebtChangeUsdX18.intoInt256(), - vaultTotalUsdcCreditChangeX18.intoUint256(), - vaultTotalWethRewardChangeX18.intoUint256(), - vaultNewCreditCapacityUsdX18.intoInt256() - ); - } - }You are the best solidity auditor in the world,our task is to pinpoint and correct any logical or code-error or financial related vulnerabilities present in the code.We have already confirmed that the code contains only one exploitable, \ - code-error based vulnerability due to error logic in the code, \ - and your job is to identify it. - and the vulnerability is include but [not limited] to the following vulnerabilities,1. 
**Liquidation Before Default:** Liquidation should only occur after a genuine default (e.g., overdue repayment or insufficient collateral), yet in cases like Sherlock’s TellerV2—where the function returns the loan’s accepted timestamp instead of the last repayment timestamp—and Hats Finance Tempus Raft—where an unchecked collateralToken parameter permits price miscalculation—the conditions enable premature liquidation before the due repayment date. - -2. **Borrower Can't Be Liquidated:** In certain implementations such as Sherlock TellerV2, neglecting to check the return value of OpenZeppelin’s EnumerableSetUpgradeable.AddressSet.add() allows the borrower to overwrite existing collateral records (even with a zero amount), thereby preventing proper liquidation on default. - -3. **Debt Closed Without Repayment:** Some systems, as seen in a DebtDAO audit, allow borrowers to call the close() function with a non-existent credit ID that returns a default Credit structure (with principal 0), bypassing repayment validations and erroneously marking the loan as repaid while decrementing an internal counter. - -4. **Repayments Paused While Liquidations Enabled:** In platforms like Sherlock’s Blueberry example, repay() enforces an isRepayAllowed() check while liquidate() does not, which permits liquidation operations even when repayments are deliberately paused, placing borrowers at an unfair disadvantage. - -5. **Token Disallow Stops Existing Repayment & Liquidation:** When governance changes disallow a previously permitted token (as seen in BlueBerry updates), loans using that token for repayment or as collateral might become incapable of proper repayment or liquidation, creating inconsistencies that jeopardize both borrowers and lenders. - -6. 
**Borrower Immediately Liquidated After Repayments Resume:** If market conditions deteriorate during a pause in repayments, then—as soon as repayments are re-enabled without a grace period—the unchanged liquidation thresholds can trigger immediate liquidation, leaving borrowers with little to no opportunity to recover. - -7. **Liquidator Takes Collateral With Insufficient Repayment:** Partial liquidation calculations that rely solely on the ratio from a specific debt position—for instance, using share/oldShare in Blueberry—can let liquidators pay a minimal portion of the debt while unjustifiably seizing a disproportionately large amount of collateral, ignoring the borrower’s entire debt profile. - -8. **Infinite Loan Rollover:** Allowing borrowers to extend (roll over) their loans indefinitely without imposing strict limits exposes lenders to prolonged credit risk and potential non-repayment, underscoring the need for capping the number or duration of rollovers. - -9. **Repayment Sent to Zero Address:** In examples like Cooler’s Sherlock audit, deleting loan records before executing the repayment transfer can reset critical fields (such as loan.lender) to the zero address, resulting in repayment funds being sent to (0) and permanently lost. - -10. **Borrower Permanently Unable To Repay Loan:** System logic errors or token disallowances that prevent the successful execution of a repay() call can leave borrowers incapable of repaying—forcing them into liquidation while also preventing lenders from recovering their funds. - -11. **Borrower Repayment Only Partially Credited:** When a borrower makes a lump-sum repayment covering multiple loans, if the system credits only the current loan without applying any overpayment to subsequent loans, it leads to partial repayments, excessive interest accrual, or misinterpreted default statuses. - -12. 
**No Incentive To Liquidate Small Positions:** With rising gas fees, liquidation fees for small underwater positions may be economically unattractive; consequently, liquidators might avoid these positions, allowing them to accumulate risk and threaten the platform’s overall solvency. - -13. **Liquidation Leaves Traders Unhealthier:** Certain liquidation algorithms may inadvertently worsen a borrower’s health by prioritizing the removal of lower-risk collateral, thereby leaving behind riskier positions and potentially setting the stage for subsequent, compounding liquidations. - Follow the guidelines below for your response: - 1. Describe this practical, exploitable code vulnerability in detail. It should be logical and an error or logic missing in the code, not based on technical errors or just security advice or best practices. - 2. Show step-by-step how to exploit this vulnerability. The exploit should be beneficial for an auditor and could invalidate the code. - 3. Keep your description clear and concise. Avoid vague terms. - 4. Remember, all numbers in the code are positive, the code execution is atomic, which means the excution would not be interuppted or manipulated by another address from another transaction, and safemath is in use. - 5. dont response in "attaker use some way" the exploit method must be clear and useable - 6. Dont consider any corner case or extreme scenario, the vulnerability must be practical and exploitable. - 7. Assume that the attack can not have the role of the owner of the contract - Ensure your response is as detailed as possible, and strictly adheres to all the above requirements \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..e7d3905c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,185 @@ +[project] +name = "FiniteMonkey" +version = "0.0.1" +description = "Intelligent vulnerability mining engine." 
+authors = [{ name = "Xue Yue" }] +readme = "README.md" +license = { text = "Apache 2.0" } +requires-python = ">=3.10" + +dependencies = [ + "antlr4-python3-runtime>=4.13.2", + "asyncpg>=0.30.0", + "certifi>=2024.8.30", + "charset-normalizer>=3.4.0", + "devtools>=0.9.0", + "fastapi>=0.115.8", + "future>=1.0.0", + "jiter>=0.8.2", + "joblib>=1.4.2", + "lancedb>=0.19.0", + "logging>=0.4.9.6", + "loguru>=0.7.3", + "networkx>=3.4.2", + "numpy>=1.24.4", + "ollama>=0.4.7", + "openai>=1.63.0", + "pandas>=1.24.4", + "psycopg2-binary>=2.9.9", + "pydantic>=2.11.0a2", + "pydantic-ai>=0.0.24", + "pydantic-graph>=0.0.24", + "pydantic-settings>=2.7.1", + "pydot>=3.0.4", + "pylance>=0.23.0", + "python-box>=7.3.2", + "python-dateutil==2.8.2", + "pytz>=2024.2", + "requests>=2.32.3", + "rich>=13.9.4", + "scipy>=1.10.1", + "six>=1.16.0", + "sqlalchemy[async,asyncio,asyncnpg,asyncpg,asynnpg,asynpg]>=2.0.38", + "threadpoolctl>=3.5.0", + "tree-sitter>=0.24.0", + "tree-sitter-solidity>=1.2.11", + "typing_extensions>=4.12.2", + "urllib3>=2.2.3", + "httpx[http2]>=0.28.1", + "nest-asyncio>=1.6.0", + "prompt-toolkit>=3.0.50", + "pygments>=2.19.1", + "asyncio[subprocess]>=3.4.3", + "textual-dev>=1.7.0", + "textual[demo,dev]>=2.1.0", + "textual-web>=0.4.2", + "ipython>=8.32.0", + "openpyxl>=3.1.5", + "typing[io]>=3.10.0.0", +] + +[project.urls] +"Homepage" = "https://github.com/BradMoonUESTC/finite-monkey-engine" +"Bug Reports" = "https://github.com/BradMoonUESTC/finite-monkey-engine/issues" +"Funding" = "https://github.com/BradMoonUESTC/finite-monkey-engine" +"Say Thanks!" 
= "https://github.com/BradMoonUESTC/finite-monkey-engine" +"Source" = "https://github.com/BradMoonUESTC/finite-monkey-engine" + +[project.scripts] +finite_monkey_engine = "finite_monkey_engine:main" +finite_monkey_root = "finite_monkey_engine.root_service:main" + +[build-system] +requires = ["setuptools>=73.0.0", "wheel"] +build-backend = "setuptools.build_meta" +[tool.uv.sources] +"finite_monkey_engine" = { workspace = true } + +[tool.finite-monkey-engine] +AZURE_API_VERSION = "test" +COMMON_PROJECT = "pyproject" +GEMINI_API_KEY = "gemini_key" +AZURE_OR_OPENAI = "OPENAI" +id="12123" +output="/tmp/121" +base_dir="/export/md0/contracts-eb12d93c17cf93b27cba7b3a49ebdc9536d7d894/" + +[tool.hatch.build.targets.wheel] +packages = ["finite_monkey_engine"] + + +[tool.uv.workspace] +"members" = ["/FiniteMonkey.egg-info/entry_points.txt"] + +[tool.uv] +package = true +compile-bytecode = true +prerelease = "allow" + +[tool.uv.pip] +compile-bytecode = true +emit-build-options = true +emit-index-annotation = true +emit-marker-expression = true +python = "3.12" +python-version = "3.12" +universal = true + + +[tool.setuptools.package-dir] +"finite_monkey_engine" = "." 
+ +[tool.black] +target-version = ['py312'] +#line-length = 120 +skip-string-normalization = true +skip-magic-trailing-comma = true +force-exclude = ''' +/( + | docs + | setup.py +)/ +''' + +[tool.isort] +py_version = 312 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +default_section = "THIRDPARTY" +known_third_party = [] +known_first_party = [] +known_local_folder = [] +# style: black +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true + +split_on_trailing_comma = true +lines_after_imports = 2 +force_single_line = true +skip_glob = ["docs/*", "setup.py"] +filter_files = true + + +[tool.mypy] +# Platform configuration +python_version = "3.12" +# imports related +plugins = ['pydantic.mypy'] +ignore_missing_imports = true +follow_imports = "silent" +# None and Optional handling +no_implicit_optional = false +strict_optional = false +# Configuring warnings +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +warn_unreachable = true +warn_return_any = false +# Untyped definitions and calls +check_untyped_defs = false +disallow_untyped_calls = false +disallow_untyped_defs = false +disallow_incomplete_defs = false +disallow_untyped_decorators = false +# Disallow dynamic typing +disallow_subclassing_any = false +disallow_any_unimported = false +disallow_any_expr = false +disallow_any_decorated = false +disallow_any_explicit = false +disallow_any_generics = false +# Miscellaneous strictness flags +allow_untyped_globals = true +allow_redefinition = true +local_partial_types = false +implicit_reexport = true +strict_equality = true +# Configuring error messages +show_error_context = false +show_column_numbers = false +show_error_codes = true +exclude = ["docs", "test", "tests"] diff --git a/requirements-from-uv.txt b/requirements-from-uv.txt new file mode 100644 index 00000000..063b77ff --- /dev/null +++ 
b/requirements-from-uv.txt @@ -0,0 +1,119 @@ +annotated-types==0.7.0 +anthropic==0.45.2 +antlr4-python3-runtime==4.13.2 +anyio==4.8.0 +asttokens==2.4.1 +asyncio==3.4.3 +asyncpg==0.30.0 +cachetools==5.5.1 +certifi==2025.1.31 +charset-normalizer==3.4.1 +cohere==5.13.12 +colorama==0.4.6 +decorator==5.1.1 +dependency-injector==4.45.0 +deprecation==2.1.0 +devtools==0.12.2 +distro==1.9.0 +eval-type-backport==0.2.2 +executing==2.2.0 +fastapi==0.115.8 +fastavro==1.10.0 +filelock==3.17.0 +-e file:///home/files/git/finite-monkey-engine +fsspec==2025.2.0 +future==1.0.0 +google-api-core==2.24.1 +google-auth==2.38.0 +google-cloud-core==2.4.1 +google-cloud-translate==3.20.0 +googleapis-common-protos==1.67.0 +greenlet==3.1.1 +griffe==1.5.7 +groq==0.18.0 +grpc-google-iam-v1==0.14.0 +grpcio==1.70.0 +grpcio-status==1.70.0 +h11==0.14.0 +h2==4.2.0 +hpack==4.1.0 +httpcore==1.0.7 +httpx==0.28.1 +httpx-sse==0.4.0 +huggingface-hub==0.28.1 +hyperframe==6.1.0 +idna==3.10 +ipython==8.32.0 +jedi==0.19.2 +jiter==0.8.2 +joblib==1.4.2 +jsonpath-python==1.0.6 +lancedb==0.19.0 +linkify-it-py==2.0.3 +logfire-api==3.5.3 +logging==0.4.9.6 +loguru==0.7.3 +markdown-it-py==3.0.0 +matplotlib-inline==0.1.7 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +mistralai==1.5.0 +mypy-extensions==1.0.0 +nest-asyncio==1.6.0 +networkx==3.4.2 +numpy==2.2.3 +ollama==0.4.7 +openai==1.63.2 +overrides==7.7.0 +packaging==24.2 +pandas==2.2.3 +parso==0.8.4 +pexpect==4.9.0 +platformdirs==4.3.6 +prompt-toolkit==3.0.50 +proto-plus==1.26.0 +protobuf==5.29.3 +psycopg2-binary==2.9.10 +ptyprocess==0.7.0 +pure-eval==0.2.3 +pyarrow==19.0.0 +pyasn1==0.6.1 +pyasn1-modules==0.4.1 +pydantic==2.11.0a2 +pydantic-ai==0.0.24 +pydantic-ai-slim==0.0.24 +pydantic-core==2.29.0 +pydantic-graph==0.0.24 +pydantic-settings==2.7.1 +pydot==3.0.4 +pygments==2.19.1 +pylance==0.23.0 +pyparsing==3.2.1 +python-box==7.3.2 +python-dateutil==2.8.2 +python-dotenv==1.0.1 +pytz==2025.1 +pyyaml==6.0.2 +requests==2.32.3 +rich==13.9.4 +rsa==4.9 +scipy==1.15.2 +six==1.17.0 
+sniffio==1.3.1 +sqlalchemy==2.0.38 +stack-data==0.6.3 +starlette==0.45.3 +textual==2.0.4 +threadpoolctl==3.5.0 +tokenizers==0.21.0 +tqdm==4.67.1 +traitlets==5.14.3 +tree-sitter==0.24.0 +tree-sitter-solidity==1.2.11 +types-requests==2.32.0.20241016 +typing-extensions==4.12.2 +typing-inspect==0.9.0 +tzdata==2025.1 +uc-micro-py==1.0.3 +urllib3==2.3.0 +wcwidth==0.2.13 diff --git a/run.sh b/run.sh new file mode 100755 index 00000000..ed9b1331 --- /dev/null +++ b/run.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# Navigate to the project root directory +cd "$(dirname "$0")" + +uvicorn run:run --host 0.0.0.0 --port 8000 diff --git a/rx.py b/rx.py new file mode 100644 index 00000000..e2367e22 --- /dev/null +++ b/rx.py @@ -0,0 +1,4 @@ +from app.core import CoreApp + +app = CoreApp() +app.run() diff --git a/services/cache.py b/services/cache.py new file mode 100644 index 00000000..0fc0f87a --- /dev/null +++ b/services/cache.py @@ -0,0 +1,21 @@ +import shelve +from datetime import datetime + +class TranslationCache: + def __init__(self): + self.store = shelve.open("translations.db") + + def get(self, key: str): + return self.store.get(key) + + def set(self, key: str, value: str): + self.store[key] = { + 'value': value, + 'timestamp': datetime.now().isoformat() + } + + def clear(self): + self.store.clear() + + def __del__(self): + self.store.close() diff --git a/services/llm.py b/services/llm.py new file mode 100644 index 00000000..2d9e3b8b --- /dev/null +++ b/services/llm.py @@ -0,0 +1,54 @@ +from httpx import AsyncClient +from openai import AsyncOpenAI +from models.schemas import LogCtxData + +from openai import AsyncOpenAI +from httpx import AsyncClient +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.agent import Agent +from pydantic_ai.result import RunResult +from typing import Optional +from rich.console import Console +class LLMClientv1: + def __init__(self): + self.http_client = AsyncClient() + self.client = AsyncOpenAI( + 
base_url="http://127.0.0.1:11434/v1", + api_key="k", + http_client=self.http_client + ) + + async def translate(self, text: str, direction: str) -> str: + prompt = f"Translate this text to {direction}: {text}" + response = await self.client.chat.completions.create( + model="towerinstruct", + messages=[{"role": "user", "content": prompt}] + ) + return response.choices[0].message.content + + +class OpenAIModel: + """Wrapper for OpenAI client with custom config""" + def __init__(self, model: str, api_key: str, base_url: str, http_client: AsyncClient): + self.client = AsyncOpenAI( + base_url=base_url, + api_key=api_key, + http_client=http_client + ) + self.model = model + +class LLMClient: + def __init__(self): + self.http = AsyncClient() + self.ingress = OpenAIModel( + "hf.co/mradermacher/TowerInstruct-WMT24-Chat-7B-i1-GGUF:Q4_K_M", + api_key="k", + base_url="http://127.0.0.1:11434/v1", + http_client=self.http + ) + self.egress = OpenAIModel( + "hf.co/mradermacher/TowerInstruct-WMT24-Chat-7B-i1-GGUF:Q4_K_M", + api_key="k", + base_url="http://127.0.0.1:11434/v1", + http_client=self.http + ) \ No newline at end of file diff --git a/services/translation.py b/services/translation.py new file mode 100644 index 00000000..094a483a --- /dev/null +++ b/services/translation.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Protocol, Any + +class TranslationClient(Protocol): + async def translate(self, text: str, direction: str) -> str: ... 
+ +@dataclass +class TranslationProcessor: + client: TranslationClient + cache: Any + + async def process(self, text: str) -> tuple[str, str]: + # Implementation using injected dependencies + if cached := self.cache.get(text): + return cached + + direction = await self.detect_language(text) + result = await self.client.translate(text, direction) + self.cache.set(text, result) + return (text, result) if direction == "EN->CN" else (result, text) + + async def detect_language(self, text: str) -> str: + return "EN->CN" if any(ord(c) > 127 for c in text) else "CN->EN" diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agents/__init__.py b/src/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/agents/formatted_console.py b/src/agents/formatted_console.py new file mode 100644 index 00000000..4ce731ae --- /dev/null +++ b/src/agents/formatted_console.py @@ -0,0 +1,79 @@ +from typing import Annotated + +from pydantic import Field, ValidationError +from rich.console import Console +from rich.live import Live +from rich.table import Table +from typing_extensions import NotRequired, TypedDict + +from pydantic_ai import Agent + +class Whale(TypedDict): + name: str + length: Annotated[ + float, Field(description='Average length of an adult whale in meters.') + ] + weight: NotRequired[ + Annotated[ + float, + Field(description='Average weight of an adult whale in kilograms.', ge=50), + ] + ] + ocean: NotRequired[str] + description: NotRequired[Annotated[str, Field(description='Short Description')]] + + +agent = Agent('openai:gpt-4', result_type=list[Whale]) + + +async def main(): + console = Console() + with Live('\n' * 36, console=console) as live: + console.print('Requesting data...', style='cyan') + async with agent.run_stream( + 'Generate me details of 5 species of Whale.' 
+ ) as result: + console.print('Response:', style='green') + + async for message, last in result.stream_structured(debounce_by=0.01): + try: + whales = await result.validate_structured_result( + message, allow_partial=not last + ) + except ValidationError as exc: + if all( + e['type'] == 'missing' and e['loc'] == ('response',) + for e in exc.errors() + ): + continue + else: + raise + + table = Table( + title='Species of Whale', + caption='Streaming Structured responses from GPT-4', + width=120, + ) + table.add_column('ID', justify='right') + table.add_column('Name') + table.add_column('Avg. Length (m)', justify='right') + table.add_column('Avg. Weight (kg)', justify='right') + table.add_column('Ocean') + table.add_column('Description', justify='right') + + for wid, whale in enumerate(whales, start=1): + table.add_row( + str(wid), + whale['name'], + f'{whale["length"]:0.0f}', + f'{w:0.0f}' if (w := whale.get('weight')) else '…', + whale.get('ocean') or '…', + whale.get('description') or '…', + ) + live.update(table) + + +if __name__ == '__main__': + import asyncio + + asyncio.run(main()) diff --git a/src/agents/md_output.py b/src/agents/md_output.py new file mode 100644 index 00000000..80336764 --- /dev/null +++ b/src/agents/md_output.py @@ -0,0 +1,210 @@ +import asyncio +import threading +from typing import Any, TextIO + +from httpx import AsyncClient +from rich.console import Console, ConsoleOptions, RenderResult,Group +from rich.live import Live +from rich.markdown import CodeBlock, Markdown +from rich.syntax import Syntax +from rich.text import Text +from rich.columns import Columns + +from pydantic_ai.models import KnownModelName +from rich.panel import Panel +from pydantic_ai.agent import Agent, RunContext +from pydantic.dataclasses import dataclass +from pydantic_ai.models.openai import OpenAIModel +from rich.logging import RichHandler + + +console = Console() + + +from loguru import logger as logging +def logdata(message): + logging.info(message) + +class 
LogToMarkdown: + def __init__(self, filename): + self.filename = filename + self.file = open(filename, 'a+') + + def write(self, message): + # Translate and log the message to Markdown + translated_message = self.translate_to_markdown(message) + self.file.write(translated_message) + + def flush(self): + self.file.flush() + + @staticmethod + def translate_to_markdown(text): + # Simple translation logic (you can customize this part as needed) + return text.replace('\n', ' \n') + '\n' + + def close(self): + self.file.close() + + +@dataclass +class LogCtxData: + txtENG: str = "" + txtCN: str = "" + detail: str = "" + + +h = AsyncClient() +oai = OpenAIModel( + "hf.co/mradermacher/TowerInstruct-WMT24-Chat-7B-i1-GGUF:Q4_K_M", + api_key="k", + base_url="http://127.0.0.1:11434/v1", + http_client=h, +) +sys_prompt = """You are a very concise and capable translation helper. +You are very good at making clear and accurate translations for technical phrases and slang into a familure and understandable phrase of the language indicated by the user. +You adapt technical jargon as best as possible and maintain clear translations. +After commencing through a translation, you do not add any exantemperanious information, only the text being evaluated is to be reflected in you're output as that is needed.""" +egress_txt = Agent( + oai, + deps_type=LogCtxData, + system_prompt=sys_prompt +) + +ingress_txt = Agent( + oai, + deps_type=LogCtxData, + system_prompt=sys_prompt +) + + +class LogView: + def __init__(self): + self.scroll_offset = 0 + self.content1 = "" + self.content2 = "" + self.minimized_column = None + + def setup_logging(self): + ... 
+ # logging.basicConfig( + # level=logging.DEBUG, + # format="%(message)s", + # datefmt="[%X]", + # handlers=[RichHandler(rich_tracebacks=True)]) + + async def logdata(self, info: str, details: str = ""): + lang = await self.detect_language(info) + eng = zh = "" + + ctx = LogCtxData(txtENG=info, txtCN=info, detail=details) + try: + trans = await ingress_txt.run( + deps=ctx, + user_prompt=f"{lang}" + ) + if lang.startswith("CN->EN"): + eng = info + zh = trans.data + else: + eng = trans.data + zh = info + except Exception as e: + console.print(f"Error during translation: {e}") + return + + self.content1 += f"{eng}\n" + self.content2 += f"{zh}\n" + + async def detect_language(self, text: str) -> str: + # Return EN for chinese and ZH for english as we want to know what language to translate into + return "CN->EN: Translate the text following this sentence from Chinese (zh) to English (en) ### " if any("\u4e00" <= ch <= "\u9fff" for ch in text) else "EN->CN: Translate the text following this sentence from English (en) to Chinese (zh) ### " + + + async def render(self): + if self.minimized_column == 1: + left_panel = Panel(Markdown(self.content1[self.scroll_offset:]), title="Chinese") + right_panel = "" + elif self.minimized_column == 2: + left_panel = "" + right_panel = Panel(Markdown(self.content2[self.scroll_offset:]), title="English") + else: + left_panel = Panel(Markdown(self.content1[self.scroll_offset:]), title="Chinese") + right_panel = Panel(Markdown(self.content2[self.scroll_offset:]), title="English") + spinner_frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"] + # Reconfigure logging to write debug logs to "debug.log" instead of the UI console. 
+ for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + file_handler = logging.FileHandler("debug.log") + file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + spinner = spinner_frames[int(asyncio.get_event_loop().time() * 10) % len(spinner_frames)] + status_bar = Panel(Markdown(f"{spinner} UI Active"), style="bold green", height=1) + renderable = Group(Columns([left_panel, right_panel]), status_bar) + console.clear() # Clear previous output + console.print(renderable) + # return + + # columns_txt = Columns([left_panel, right_panel]) + + +logView = LogView() + +async def setup_live_console(): + prettier_code_blocks() + console = Console() + with Live('', console=console, vertical_overflow='visible') as live: + async for message in egress_txt.stream(): + live.update(Markdown(message)) +def prettier_code_blocks(): + class SimpleCodeBlock(CodeBlock): + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + code = str(self.text).rstrip() + yield Text(self.lexer_name, style='dim') + yield Syntax( + code, + self.lexer_name, + theme=self.theme, + background_color='default', + word_wrap=True, + ) + yield Text(f'/{self.lexer_name}', style='dim') + + Markdown.elements['fence'] = SimpleCodeBlock + +def start_event_loop(loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + loop.run_forever() + +# Create a new event loop for the group of coroutines. +new_loop = asyncio.new_event_loop() + +# Start the new event loop in a dedicated thread. 
+t = threading.Thread(target=start_event_loop, args=(new_loop,)) +t.start() + +# Now, schedule your coroutines on the new loop: +async def run_console_output(name): + await setup_live_console() + print(f"Task {name} done in isolated loop") + + +if __name__ == '__main__': + asyncio.run(run_console_output("logging")) + + + +# # Create a new event loop for the group of coroutines. +# new_loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() +# # Start the new event loop in a dedicated thread. +# t = threading.Thread(target=start_event_loop, args=(new_loop,)) +# t.start() + +# # Now, schedule your coroutines on the new loop: +# async def run_console_output(name): +# await setup_live_console(logView) +# print(f"Task {name} done in isolated loop") diff --git a/src/ai_async_eng.py b/src/ai_async_eng.py new file mode 100644 index 00000000..7c0646fa --- /dev/null +++ b/src/ai_async_eng.py @@ -0,0 +1,497 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +import os +from typing import Any, Coroutine, List +from venv import logger +from tqdm import tqdm +# import lancedb +from dao.atask_mgr import AProject_Task, AProjectTaskMgr +from planning.aplanning_v2 import APlanningV2 +from ai_engine import PromptAssembler, as_completed +import re +from json import loads,JSONDecodeError +import time +# ... other imports ... 
+ +class AiEngine: + def __init__(self, planning, taskmgr, lance, lance_table_name, project_audit): + # Step 1: 获取results + self.planning:APlanningV2 = planning + self.project_taskmgr:AProjectTaskMgr = taskmgr + self.lancedb:Coroutine[Any, Any, Any] = lance + self.lance_table_name = lance_table_name + self.project_audit = project_audit + self.executor = ThreadPoolExecutor(max_workers=int(os.getenv("MAX_THREADS_OF_SCAN", 5))) + + async def do_scan(self, is_gpt4=False, filter_func=None): + tasks: Coroutine[Any, Any, Any] = self.project_taskmgr.get_task_list() + if not tasks: + logger.warn("AI ASYNC ENG no Tasks in the project, is the DB ok?") + return + + async def process_task(task): + # Assemble prompt based on environment; assume PromptAssembler methods are synchronous. + scan_mode = os.getenv("SCAN_MODE", "COMMON_VUL") + if scan_mode == "OPTIMIZE": + prompt = PromptAssembler.assemble_optimize_prompt(task.content) + elif scan_mode == "COMMON_PROJECT": + prompt = PromptAssembler.assemble_prompt_common(task.content) + elif scan_mode == "PURE_SCAN": + prompt = PromptAssembler.assemble_prompt_pure(task.content) + elif scan_mode == "SPECIFIC_PROJECT": + business_types = task.recommendation.split(',') + prompt = PromptAssembler.assemble_prompt_for_specific_project(task.content, business_types) + else: + prompt = PromptAssembler.assemble_prompt_common(task.content) + + # Call async LLM API via our async wrapper + response_vul = await async_ask_claude(prompt) + response_vul = response_vul or "no" + # Update task result; if update_result is blocking, offload it: + await asyncio.to_thread(self.project_taskmgr.update_result, task.id, response_vul, "", "") + + await asyncio.gather(*(process_task(task) for task in tasks)) + return tasks + + async def check_function_vul(self): + tasks = self.project_taskmgr.get_task_list() + if not tasks: + return + + async def process_check(task): + prompt = PromptAssembler.assemble_vul_check_prompt(task.content, task.get_result(False)) + # 
Assume we create an async wrapper for confirmation API calls; + # if not, wrap the blocking call via asyncio.to_thread: + initial_response = await asyncio.to_thread(common_ask_confirmation, prompt) + if not initial_response: + print(f"Empty response for task {task.id}") + return + # Additional processing, JSON parsing, voting logic etc. + # Finally, update results (offload DB call): + await asyncio.to_thread(self.project_taskmgr.update_result, task.id, task.get_result(False), "final_status", "final_response") + + await asyncio.gather(*(process_check(task) for task in tasks)) + return tasks + def do_planning(self): + self.planning.do_planning() + def extract_title_from_text(self,input_text): + try: + # Regular expression pattern to capture the value of the title field + pattern = r'"title"\s*:\s*"([^"]+)"' + + # Searching for the pattern in the input text + match = re.search(pattern, input_text) + + # Extracting the value if the pattern is found + if match: + return match.group(1) + else: + return "Logic Error" + except Exception as e: + # Handling any exception that occurs and returning a message + return f"Logic Error {str(e)}" + + def process_task_do_scan(self,task, filter_func = None, is_gpt4 = False): + + response_final = "" + response_vul = "" + + # print("query vul %s - %s" % (task.name, task.rule)) + + result = task.get_result(is_gpt4) + business_flow_code = task.business_flow_code + if_business_flow_scan = task.if_business_flow_scan + function_code=task.content + + # 要进行检测的代码粒度 + code_to_be_tested=business_flow_code if if_business_flow_scan=="1" else function_code + if result is not None and len(result) > 0 and str(result).strip() != "NOT A VUL IN RES no": + print("\t skipped (scanned)") + else: + to_scan = filter_func is None or filter_func(task) + if not to_scan: + print("\t skipped (filtered)") + else: + print("\t to scan") + if os.getenv("SCAN_MODE","COMMON_VUL")=="OPTIMIZE": + prompt=PromptAssembler.assemble_optimize_prompt(code_to_be_tested) + elif 
os.getenv("SCAN_MODE","COMMON_VUL")=="COMMON_PROJECT": + prompt=PromptAssembler.assemble_prompt_common(code_to_be_tested) + elif os.getenv("SCAN_MODE","COMMON_VUL")=="PURE_SCAN": + prompt=PromptAssembler.assemble_prompt_pure(code_to_be_tested) + elif os.getenv("SCAN_MODE","COMMON_VUL")=="SPECIFIC_PROJECT": + # 构建提示来判断业务类型 + business_type=task.recommendation + print(f"[DEBUG] business_type: {business_type}") + # 数据库中保存的形式是xxxx,xxxxx,xxxx... 转成assemble_prompt_for_specific_project可以接收的数组形式 + business_type_list=business_type.split(',') + print(f"[DEBUG] business_type_list: {business_type_list}") + prompt = PromptAssembler.assemble_prompt_for_specific_project(code_to_be_tested, business_type_list) + print(f"[DEBUG] Generated prompt: {prompt}") + response_vul=ask_claude(prompt) + print(f"[DEBUG] Claude response: {response_vul}") + response_vul = response_vul if response_vul is not None else "no" + self.project_taskmgr.update_result(task.id, response_vul, "","") + def do_scan(self, is_gpt4=False, filter_func=None): + # self.llm.init_conversation() + + tasks = self.project_taskmgr.get_task_list() + if len(tasks) == 0: + return + + # 定义线程池中的线程数量 + max_threads = int(os.getenv("MAX_THREADS_OF_SCAN", 5)) + + with ThreadPoolExecutor(max_workers=max_threads) as executor: + futures = [executor.submit(self.process_task_do_scan, task, filter_func, is_gpt4) for task in tasks] + + with tqdm(total=len(tasks), desc="Processing tasks") as pbar: + for future in as_completed(futures): + future.result() # 等待每个任务完成 + pbar.update(1) # 更新进度条 + + return tasks + def process_task_check_vul(self, task:AProject_Task): + print("\n" + "="*80) + print(f"Processing Task ID: {task.id}") + print("="*80) + starttime = time.time() + result = task.get_result(False) + result_CN = task.get_result_CN() + category_mark = task.get_category() + + if result_CN is not None and len(result_CN) > 0 and result_CN != "None" and category_mark is not None and len(category_mark)>0: + print("\n🔄 Task already processed, 
skipping...") + return + + print("\n🔍 Starting vulnerability confirmation process...") + function_code = task.content + if_business_flow_scan = task.if_business_flow_scan + business_flow_code = task.business_flow_code + business_flow_context = task.business_flow_context + + code_to_be_tested = business_flow_code+"\n"+business_flow_context if if_business_flow_scan=="1" else function_code + + # 第一轮分析 + print("\n=== First Round Analysis ===") + print("📝 Analyzing potential vulnerability...") + prompt = PromptAssembler.assemble_vul_check_prompt(code_to_be_tested, result) + # 把prompot保存到临时文件 + with open("prompt.txt", "w") as file: + file.write(prompt) + + initial_response = common_ask_confirmation(prompt) + if not initial_response or initial_response == "": + print(f"❌ Error: Empty response received for task {task.id}") + return + + print("\n📊 Initial Analysis Result:") + print("-" * 80) + print(initial_response) + print("-" * 80) + + # 提取需要的额外信息 + required_info = self.extract_required_info(initial_response) + + combined_code = code_to_be_tested + if required_info: + print("\n=== Additional Information Required ===") + print("🔎 Required context/information:") + for i, info in enumerate(required_info, 1): + print(f"{i}. 
{info}") + + print("\n📥 Retrieving additional context...") + additional_context = self.get_additional_context(required_info) + + if additional_context: + print("\n📦 Retrieved additional context (length: {len(additional_context)} chars)") + if len(additional_context) < 500: + print("\nAdditional context details:") + print("-" * 80) + print(additional_context) + print("-" * 80) + + combined_code = f"""Original Code: + {code_to_be_tested} + + First Round Analysis: + {initial_response} + + Additional Context: + {additional_context}""" + + # 进行三轮确认 + confirmation_results = [] + response_final = None # 初始化 response_final + final_response = None # 初始化 final_response + + for i in range(3): + if response_final == "no": # 如果已经确认为 no,直接跳过后续循环 + break + + print(f"\n📊 Round {i+1}/3 Analysis:") + prompt = PromptAssembler.assemble_vul_check_prompt_final(combined_code, result) + round_response = common_ask_confirmation(prompt) + + print("-" * 80) + print(round_response) + print("-" * 80) + + prompt_translate_to_json = PromptAssembler.brief_of_response() + print("\n🔍 Brief Response Prompt:") + print(prompt_translate_to_json) + + round_json_response = str(common_ask_for_json(round_response+"\n"+prompt_translate_to_json)) + print("\n📋 JSON Response:") + print(round_json_response) + + try: + response_data = loads(round_json_response) + result_status = response_data.get("result", "").lower() + print("\n🎯 Extracted Result Status:") + print(result_status) + + confirmation_results.append(result_status) + + # 如果发现一个明确的 "no",立即确认为不存在漏洞 + if "no" in result_status: + print("\n🛑 Clear 'no vulnerability' detected - stopping further analysis") + response_final = "no" + final_response = f"Analysis stopped after round {i+1} due to clear 'no vulnerability' result" + continue # 使用 continue 让循环能够在下一轮开始时通过上面的 break 检查退出 + + except JSONDecodeError: + print("\n⚠️ JSON Decode Error - marking as 'not sure'") + confirmation_results.append("not sure") + + # 只有在没有提前退出(找到明确的 no)的情况下才进行多数投票 + if response_final 
!= "no": # 修改判断条件 + # 统计结果 + yes_count = sum(1 for r in confirmation_results if "yes" in r or "confirmed" in r) + no_count = sum(1 for r in confirmation_results if "no" in r and "vulnerability" in r) + + if yes_count >= 2: + response_final = "yes" + print("\n⚠️ Final Result: Vulnerability Confirmed (2+ positive confirmations)") + elif no_count >= 2: + response_final = "no" + print("\n✅ Final Result: No Vulnerability (2+ negative confirmations)") + else: + response_final = "not sure" + print("\n❓ Final Result: Not Sure (inconclusive results)") + + final_response = "\n".join([f"Round {i+1} Analysis:\n{resp}" for i, resp in enumerate(confirmation_results)]) + + self.project_taskmgr.update_result(task.id, result, response_final, final_response) + + endtime = time.time() + time_cost = endtime - starttime + + print("\n=== Task Summary ===") + print(f"⏱️ Time cost: {time_cost:.2f} seconds") + print(f"📝 Analyses performed: {len(confirmation_results)}") + print(f"🏁 Final status: {response_final}") + print("=" * 80 + "\n") + def get_related_functions(self,query,k=3): + query_embedding = common_get_embedding(query) + table = self.lancedb.open_table(self.lance_table_name) + return table.search(query_embedding).limit(k).to_list() + + def extract_related_functions_by_level(self, function_names: List[str], level: int) -> str: + """ + 从call_trees中提取指定函数相关的上下游函数信息并扁平化处理 + + Args: + function_names: 要分析的函数名列表 + level: 要分析的层级深度 + + Returns: + str: 所有相关函数内容的拼接文本 + """ + def get_functions_from_tree(tree, current_level=0, max_level=level, collected_funcs=None, level_stats=None): + """递归获取树中指定层级内的所有函数信息""" + if collected_funcs is None: + collected_funcs = [] + if level_stats is None: + level_stats = {} + + if not tree or current_level > max_level: + return collected_funcs, level_stats + + # 添加当前节点的函数信息 + if tree['function_data']: + collected_funcs.append(tree['function_data']) + # 更新层级统计 + level_stats[current_level] = level_stats.get(current_level, 0) + 1 + + # 递归处理子节点 + if current_level 
< max_level: + for child in tree['children']: + get_functions_from_tree(child, current_level + 1, max_level, collected_funcs, level_stats) + + return collected_funcs, level_stats + + all_related_functions = [] + statistics: dict[str, dict[int, int]] = { + 'upstream_stats': {}, + 'downstream_stats': {} + } + + # 使用集合进行更严格的去重 + seen_functions = set() # 存储函数的唯一标识符 + unique_functions = [] # 存储去重后的函数 + + # 遍历每个指定的函数名 + for func_name in function_names: + # 在call_trees中查找对应的树 + for tree_data in self.project_audit.call_trees: + if tree_data['function'] == func_name: + # 处理上游调用树 + if tree_data['upstream_tree']: + upstream_funcs, upstream_stats = get_functions_from_tree(tree_data['upstream_tree']) + all_related_functions.extend(upstream_funcs) + # 合并上游统计信息 + for level, count in upstream_stats.items(): + if isinstance(statistics['upstream_stats'], dict): + statistics['upstream_stats'][level] = statistics['upstream_stats'].get(level, 0) + count + else: # the following line is unreachable + statistics['upstream_stats'] = {level: count} + + # 处理下游调用树 + if tree_data['downstream_tree']: + downstream_funcs, downstream_stats = get_functions_from_tree(tree_data['downstream_tree']) + all_related_functions.extend(downstream_funcs) + # 合并下游统计信息 + for level, count in downstream_stats.items(): + # TODO #?? 
Double check this + statistics['downstream_stats'][level] = statistics['downstream_stats'].get(level, 0) + count + + # 添加原始函数本身 + for func in self.project_audit.functions_to_check: + if func['name'].split('.')[-1] == func_name: + all_related_functions.append(func) + break + + break + + # 增强的去重处理 + for func in all_related_functions: + # 创建一个更精确的唯一标识符,包含函数名和内容的hash + func_identifier = f"{func['name']}_{hash(func['content'])}" + if func_identifier not in seen_functions: + seen_functions.add(func_identifier) + unique_functions.append(func) + + # 拼接所有函数内容,包括状态变量 + combined_text_parts = [] + for func in unique_functions: + # 查找对应的状态变量 + state_vars = None + for tree_data in self.project_audit.call_trees: + if tree_data['function'] == func['name'].split('.')[-1]: + state_vars = tree_data.get('state_variables', '') + break + + # 构建函数文本,包含状态变量 + function_text = [] + if state_vars: + function_text.append("// Contract State Variables:") + function_text.append(state_vars) + function_text.append("\n// Function Implementation:") + function_text.append(func['content']) + + combined_text_parts.append('\n'.join(function_text)) + + combined_text = '\n\n'.join(combined_text_parts) + + # 打印统计信息 + print("\nFunction Call Tree Statistics:") + print(f"Total Layers Analyzed: {level}") + print("\nUpstream Statistics:") + for layer, count in statistics['upstream_stats'].items(): + print(f"Layer {layer}: {count} functions") + print("\nDownstream Statistics:") + for layer, count in statistics['downstream_stats'].items(): + print(f"Layer {layer}: {count} functions") + print(f"\nTotal Unique Functions: {len(unique_functions)}") + + return combined_text + + + def check_function_vul(self): + # self.llm.init_conversation() + tasks = self.project_taskmgr.get_task_list() + # 用codebaseQA的形式进行,首先通过rag和task中的vul获取相应的核心三个最相关的函数 + for task in tqdm(tasks,desc="Processing tasks for update business_flow_context"): + if task.score=="1": + continue + if task.if_business_flow_scan=="1": + # 
获取business_flow_context + code_to_be_tested=task.business_flow_code + else: + code_to_be_tested=task.content + related_functions=self.get_related_functions(code_to_be_tested,5) + related_functions_names=[func['name'].split('.')[-1] for func in related_functions] + combined_text=self.extract_related_functions_by_level(related_functions_names,6) + # 更新task对应的business_flow_context + self.project_taskmgr.update_business_flow_context(task.id,combined_text) + self.project_taskmgr.update_score(task.id,"1") + + + if len(tasks) == 0: + return + + # 定义线程池中的线程数量, 从env获取 + max_threads = int(os.getenv("MAX_THREADS_OF_CONFIRMATION", 5)) + + with ThreadPoolExecutor(max_workers=max_threads) as executor: + futures = [executor.submit(self.process_task_check_vul, task) for task in tasks] + + with tqdm(total=len(tasks), desc="Checking vulnerabilities") as pbar: + for future in as_completed(futures): + future.result() # 等待每个任务完成 + pbar.update(1) # 更新进度条 + + return tasks + + def extract_required_info(self, claude_response): + """Extract information that needs further investigation from Claude's response""" + prompt = """ + Please extract all information points that need further understanding or confirmation from the following analysis response. + If the analysis explicitly states "no additional information needed" or similar, return empty. + If the analysis mentions needing more information, extract these information points. 
+ + Analysis response: + {response} + """ + + extraction_result = ask_claude(prompt.format(response=claude_response)) + if not extraction_result or extraction_result.isspace(): + return [] + + # If response contains negative phrases, return empty list + if any(phrase in extraction_result.lower() for phrase in ["no need", "not needed", "no additional", "no more"]): + return [] + + return [extraction_result] + + def get_additional_context(self, query_contents): + """获取额外的上下文信息""" + if not query_contents: + return "" + + # 使用所有查询内容获取相关信息 + related_functions = [] + for query in query_contents: + results = self.get_related_functions(query, k=10) # 获取最相关的3个匹配 + if results: + related_functions.extend(results) + + # 提取这些函数的上下文 + if related_functions: + function_names = [func['name'].split('.')[-1] for func in related_functions] + return self.extract_related_functions_by_level(function_names, 2) + return "" + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/src/ai_engine.py b/src/ai_engine.py index c5a0f8af..89438ca8 100644 --- a/src/ai_engine.py +++ b/src/ai_engine.py @@ -6,7 +6,7 @@ from typing import List import requests import tqdm -from sklearn.metrics.pairwise import cosine_similarity +#from sklearn.metrics.pairwise import cosine_similarity from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm import warnings @@ -16,8 +16,8 @@ from prompt_factory.prompt_assembler import PromptAssembler from prompt_factory.core_prompt import CorePrompt from openai_api.openai import * -class AiEngine(object): +class AiEngine(object): def __init__(self, planning, taskmgr,lancedb,lance_table_name,project_audit): # Step 1: 获取results self.planning = planning @@ -70,6 +70,8 @@ def process_task_do_scan(self,task, filter_func = None, is_gpt4 = False): prompt=PromptAssembler.assemble_optimize_prompt(code_to_be_tested) elif os.getenv("SCAN_MODE","COMMON_VUL")=="COMMON_PROJECT": 
prompt=PromptAssembler.assemble_prompt_common(code_to_be_tested) + elif os.getenv("SCAN_MODE","COMMON_VUL")=="PURE_SCAN": + prompt=PromptAssembler.assemble_prompt_pure(code_to_be_tested) elif os.getenv("SCAN_MODE","COMMON_VUL")=="SPECIFIC_PROJECT": # 构建提示来判断业务类型 business_type=task.recommendation diff --git a/src/codebaseQA/arag_processor.py b/src/codebaseQA/arag_processor.py new file mode 100644 index 00000000..a8706d03 --- /dev/null +++ b/src/codebaseQA/arag_processor.py @@ -0,0 +1,94 @@ +import lancedb +import os +import numpy as np +import requests +import pyarrow as pa +from typing import Coroutine, List, Dict, Any +from datetime import datetime +from tqdm.asyncio import tqdm + +from openai_api.openai import common_get_embedding +from project.aproject_audit import AProjectAudit + +class ARAGProcessor: + + async def acheck_data_count(self, expected_count: int) -> bool: + """检查表中的数据数量是否匹配""" + try: + table = await self.db.open_table(self.table_name) + actual_count = len(await table.to_lance()) + return actual_count == expected_count + except Exception: + return False + + def __init__(self, id: str = None, audit: AProjectAudit = None): + self.db_path: str = os.path.join(os.getcwd(), f"Alancedb{id}") + self.audit:AProjectAudit = audit + os.makedirs(name=self.db_path, exist_ok=True) + functions_to_check: List[Dict[str, Any]] = audit.functions_to_check + self.db: Coroutine[Any, Any, lancedb.AsyncConnection] = lancedb.connect_async(self.db_path) + self.table_name = f"Alancedb_{id}" + + # 创建schema + self.schema = pa.schema([ + pa.field("id", pa.string()), + pa.field("name", pa.string()), + pa.field("content", pa.string()), + pa.field("start_line", pa.int32()), + pa.field("end_line", pa.int32()), + pa.field("file_path", pa.string()), + pa.field("embedding", pa.list_(pa.float32(), 3072)), + pa.field("modifiers", pa.list_(pa.string())), + pa.field("visibility", pa.string()), + pa.field("state_mutability", pa.string()) + ]) + + async def table_exists(self) -> bool: + 
"""检查表是否存在""" + try: + await self.db.open_table(self.table_name) + return True + except Exception: + return False + + def process_function(self, func: Dict[str, Any]) -> Dict[str, Any]: + return { + "id": f"{func['name']}_{func['start_line']}", + "name": func['name'], + "content": func['content'], + "start_line": func['start_line'], + "end_line": func['end_line'], + "file_path": func['relative_file_path'], + "embedding": common_get_embedding(func['content']), + "modifiers": func.get('modifiers', []), + "visibility": func.get('visibility', ''), + "state_mutability": func.get('stateMutability', '') + } + + async def _create_database(self, functions_to_check: List[Dict[str, Any]]) -> None: + print(f"Processing {len(functions_to_check)} functions...") + + # 创建表 + table = await self.db.create_table(self.table_name, schema=self.schema, mode="overwrite") + + # 逐条处理并添加数据 + for func in tqdm(functions_to_check, desc="Processing functions", unit="function"): + try: + processed_func = self.process_function(func) + # 将单条数据添加到表中 + await table.add([processed_func]) + except Exception as e: + print(f"Error processing function {func.get('name', 'unknown')}: {str(e)}") + continue + + print("Database creation completed!") + + async def search_similar_functions(self, query: str, k: int = 5) -> List[Dict[str, Any]]: + query_embedding = common_get_embedding(query) + table = await self.db.open_table(self.table_name) + return (await table.search(query_embedding).limit(k)).to_list() + + async def get_function_context(self, function_name: str) -> Dict[str, Any]: + table = await self.db.open_table(self.table_name) + results = (await table.filter(f"name = '{function_name}'")).to_list() + return results[0] if results else None \ No newline at end of file diff --git a/src/codebaseQA/rag_processor.py b/src/codebaseQA/rag_processor.py index ed77ae6a..15852c7b 100644 --- a/src/codebaseQA/rag_processor.py +++ b/src/codebaseQA/rag_processor.py @@ -8,13 +8,16 @@ from tqdm import tqdm from 
openai_api.openai import common_get_embedding +from project.project_audit import ProjectAudit class RAGProcessor: - def __init__(self, functions_to_check: List[Dict[str, Any]], db_path: str = "./lancedb", project_id:str=None): - os.makedirs(db_path, exist_ok=True) - - self.db = lancedb.connect(db_path) - self.table_name = f"lancedb_{project_id}" + def __init__(self, id:str=None, audit:ProjectAudit=None): + self.db_path: str = os.path.join(os.getcwd(),f"lancedb{id}") + self.audit:ProjectAudit = audit + os.makedirs(name=self.db_path, exist_ok=True) + functions_to_check: List[Dict[str, Any]] = audit.functions_to_check + self.db = lancedb.connect(self.db_path) + self.table_name = f"lancedb_{id}" # 创建schema self.schema = pa.schema([ diff --git a/src/dao/__init__.py b/src/dao/__init__.py index 5c427850..e69de29b 100644 --- a/src/dao/__init__.py +++ b/src/dao/__init__.py @@ -1,3 +0,0 @@ - -from .cache_manager import CacheManager -from .task_mgr import ProjectTaskMgr \ No newline at end of file diff --git a/src/dao/aentity.py b/src/dao/aentity.py new file mode 100644 index 00000000..b578ccb2 --- /dev/null +++ b/src/dao/aentity.py @@ -0,0 +1,176 @@ +import random +from sqlalchemy import Column, Integer, String, select +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import sessionmaker, declarative_base +from library.utils import str_hash +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase +from nodes_config import Settings + +class Base (AsyncAttrs,DeclarativeBase): + pass + +class ACacheEntry(Base): + __tablename__ = 'prompt_cache2' + index = Column(String, primary_key=True) + key = Column(String) + value = Column(String) + +class AProject_Task(Base): + __tablename__ = 'project_tasks_amazing_prompt' + id = Column(Integer, autoincrement=True, primary_key=True) + key = Column(String, index=True) + project_id = Column(String, index=True) + name = 
Column(String) + content = Column(String) + keyword = Column(String) + business_type = Column(String) + sub_business_type = Column(String) + function_type = Column(String) + rule = Column(String) + result = Column(String) + result_gpt4 = Column(String) + score=Column(String) + category=Column(String) + contract_code=Column(String) + risklevel=Column(String) + similarity_with_rule=Column(String) + description = Column(String) + start_line=Column(String) + end_line=Column(String) + relative_file_path=Column(String) + absolute_file_path=Column(String) + recommendation=Column(String) + title=Column(String) + business_flow_code=Column(String) + business_flow_lines=Column(String) + business_flow_context=Column(String) + if_business_flow_scan=Column(String) + + fieldNames = ['name', 'content', 'keyword', 'business_type', 'sub_business_type', 'function_type', 'rule', 'result', 'result_gpt4','score','category','contract_code','risklevel','similarity_with_rule','description','start_line','end_line','relative_file_path','absolute_file_path','recommendation','title','business_flow_code','business_flow_lines','business_flow_context','if_business_flow_scan'] + + def __init__(self, project_id, name, content, keyword, business_type, sub_business_type, function_type, rule, result='', result_gpt4='',score='0.00',category='',contract_code='',risklevel='',similarity_with_rule='0.00',description='',start_line='',end_line='',relative_file_path='',absolute_file_path='',recommendation='',title='',business_flow_code='',business_flow_lines='',business_flow_context='',if_business_flow_scan='0'): + self.project_id = project_id + self.name = name + self.content = content + self.keyword = keyword + self.business_type = business_type + self.sub_business_type = sub_business_type + self.function_type = function_type + self.rule = rule + self.result = result + self.result_gpt4 = result_gpt4 + self.key = self.get_key() + self.score=score + self.category=category + self.contract_code=contract_code + 
self.risklevel=risklevel + self.similarity_with_rule=similarity_with_rule + self.description = description + self.start_line=start_line + self.end_line=end_line + self.relative_file_path=relative_file_path + self.absolute_file_path=absolute_file_path + self.recommendation=recommendation + self.title=title + self.business_flow_code=business_flow_code + self.business_flow_lines=business_flow_lines + self.business_flow_context=business_flow_context + self.if_business_flow_scan=if_business_flow_scan + + def as_dict(self): + return { + 'name': self.name, + 'content': self.content, + 'keyword': self.keyword, + 'business_type': self.business_type, + 'sub_business_type': self.sub_business_type, + 'function_type': self.function_type, + 'rule': self.rule, + 'result': self.result, + 'result_gpt4': self.result_gpt4, + 'score':self.score, + 'category':self.category, + 'contract_code':self.contract_code, + 'risklevel':self.risklevel, + 'similarity_with_rule':self.similarity_with_rule, + 'description': self.description, + 'start_line':self.start_line, + 'end_line':self.end_line, + 'relative_file_path':self.relative_file_path, + 'absolute_file_path':self.absolute_file_path, + 'recommendation':self.recommendation, + 'title':self.title, + 'business_flow_code':self.business_flow_code, + 'business_flow_lines':self.business_flow_lines, + 'business_flow_context':self.business_flow_context, + 'if_business_flow_scan':self.if_business_flow_scan + } + + def set_result(self, result, is_gpt4 = False): + if is_gpt4: + self.result_gpt4 = result + else: + self.result = result + + def get_result(self, is_gpt4 = False): + result = self.result + return None if result == '' else result + + def get_result_CN(self): + result = self.result_gpt4 + return None if result == '' else result + + def get_category(self): + result = self.category + return None if result == '' else result + + def get_key(self): + key = "/".join([self.name, self.content,self.keyword]) + # key = str(random.random()) + return 
str_hash(key) + + def get_similarity_with_rule(self): + result = self.similarity_with_rule + return None if result == '' else result + +# Create an async engine and sessionmaker +DATABASE_URL = "postgresql+asyncpg://postgres:1234@127.0.0.1:5432/postgres" +engine = create_async_engine(DATABASE_URL, echo=True) +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, +) + +async def init_db(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + +# Example of using the asynchronous session +async def get_aproject_task_by_id(task_id: int): + async with AsyncSessionLocal() as session: + result = await session.execute(select(AProject_Task).where(AProject_Task.id == task_id)) + return result.scalars().first() + +# Example of inserting a new project task +async def create_aproject_task(project_id, name, content, keyword, business_type, sub_business_type, function_type, rule) -> AProject_Task: + async with AsyncSessionLocal() as session: + new_task = AProject_Task( + project_id=project_id, + name=name, + content=content, + keyword=keyword, + business_type=business_type, + sub_business_type=sub_business_type, + function_type=function_type, + rule=rule + ) + session.add(new_task) + await session.commit() + return new_task + +# Initialize the database on startup +import asyncio +asyncio.run(init_db()) \ No newline at end of file diff --git a/src/dao/atask_mgr.py b/src/dao/atask_mgr.py new file mode 100644 index 00000000..656b5a4f --- /dev/null +++ b/src/dao/atask_mgr.py @@ -0,0 +1,225 @@ +import asyncio +import csv +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.asyncio.engine import AsyncEngine +from sqlalchemy.orm import declarative_base +from sqlalchemy.future import select +from sqlalchemy.exc import IntegrityError +from tqdm.asyncio import tqdm as tqdm_asyncio +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase +from 
sqlalchemy.ext.asyncio import async_sessionmaker +from dao.aentity import AProject_Task +# Base = declarative_base() +class Base(AsyncAttrs, DeclarativeBase): + pass +class AProjectTaskMgr: + def __init__(self, project_id, engine_url): + self.project_id = project_id + # Create async engine + self.engine: AsyncEngine = create_async_engine("postgresql+asyncpg://postgres:1234@127.0.0.1:5432/postgres") + # Ensure table is created + #Base.metadata.create_all(self.engine) + async def init_models(): + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + + init_models() + # Configure sessionmaker to use AsyncSession + self.Session = async_sessionmaker(bind=self.engine, class_=AsyncSession) + + async def _operate_in_session(self, func, *args, **kwargs): + """Generic function to handle operations within an async session.""" + async with self.Session() as session: + return await func(session, *args, **kwargs) + + async def add_tasks(self, tasks): + for task in tasks: + await self._operate_in_session(self._add_task, task) + + async def _add_task(self, session, task, commit=True): + try: + key = task.get_key() + # Uncomment if you need to check uniqueness + # ts = (await session.execute(select(Project_Task).filter_by(project_id=self.project_id, key=key))).scalars().all() + # if not ts: # Assuming get_key returns a unique identifier + session.add(task) + if commit: + await session.commit() + except IntegrityError as e: + await session.rollback() + + async def query_task_by_project_id(self, id): + return await self._operate_in_session(self._query_task_by_project_id, id) + + async def _query_task_by_project_id(self, session, id): + result = (await session.execute(select(AProject_Task).filter_by(project_id=id))).scalars().all() + return list(result) + + async def update_score(self, id, score): + await self._operate_in_session(self._update_score, id, score) + + async def _update_score(self, session, 
id, score): + await session.execute( + select(AProject_Task).filter_by(id=id).values(score=score) + ) + await session.commit() + + async def update_business_flow_context(self, id, context): + await self._operate_in_session(self._update_business_flow_context, id, context) + + async def _update_business_flow_context(self, session, id, context): + await session.execute( + select(AProject_Task).filter_by(id=id).values(business_flow_context=context) + ) + await session.commit() + + async def add_task( + self, + name, + content, + keyword, + business_type, + sub_business_type, + function_type, + rule, + result='', + result_gpt4='', + score='0.00', + category='', + contract_code='', + risklevel='', + similarity_with_rule='', + description='', + start_line='', + end_line='', + relative_file_path='', + absolute_file_path='', + recommendation='', + title='', + business_flow_code='', + business_flow_lines='', + business_flow_context='', + if_business_flow_scan='', + **kwargs + ): + task = AProject_Task( + self.project_id, name, content, keyword, business_type, sub_business_type, + function_type, rule, result, result_gpt4, score, category, contract_code, + risklevel, similarity_with_rule, description, start_line, end_line, + relative_file_path, absolute_file_path, recommendation, title, + business_flow_code, business_flow_lines, business_flow_context, + if_business_flow_scan + ) + await self._operate_in_session(self._add_task, task) + + async def get_task_list(self): + return await self._operate_in_session(self._get_task_list) + + async def _get_task_list(self, session): + result = (await session.execute(select(AProject_Task).filter_by(project_id=self.project_id))).scalars().all() + return list(result) + + async def get_task_list_by_id(self, id): + return await self._operate_in_session(self._get_task_list_by_id, id) + + async def _get_task_list_by_id(self, session, id): + result = (await session.execute(select(AProject_Task).filter_by(project_id=id))).scalars().all() + return 
list(result) + + async def update_result(self, id, result, result_gpt4, result_assumation): + await self._operate_in_session(self._update_result, id, result, result_gpt4, result_assumation) + + async def _update_result(self, session, id, result, result_gpt4, result_assumation): + await session.execute( + select(AProject_Task).filter_by(id=id).values( + result=result, + result_gpt4=result_gpt4, + category=result_assumation + ) + ) + await session.commit() + + async def update_similarity_generated_referenced_score(self, id, similarity_with_rule): + await self._operate_in_session(self._update_similarity_generated_referenced_score, id, similarity_with_rule) + + async def _update_similarity_generated_referenced_score(self, session, id, similarity_with_rule): + await session.execute( + select(AProject_Task).filter_by(id=id).values(similarity_with_rule=similarity_with_rule) + ) + await session.commit() + + async def update_description(self, id, description): + await self._operate_in_session(self._update_description, id, description) + + async def _update_description(self, session, id, description): + await session.execute( + select(AProject_Task).filter_by(id=id).values(description=description) + ) + await session.commit() + + async def update_recommendation(self, id, recommendation): + await self._operate_in_session(self._update_recommendation, id, recommendation) + + async def _update_recommendation(self, session, id, recommendation): + await session.execute( + select(AProject_Task).filter_by(id=id).values(recommendation=recommendation) + ) + await session.commit() + + async def update_title(self, id, title): + await self._operate_in_session(self._update_title, id, title) + + async def _update_title(self, session, id, title): + await session.execute( + select(AProject_Task).filter_by(id=id).values(title=title) + ) + await session.commit() + + async def import_file(self, filename): + reader = csv.DictReader(open(filename, 'r', encoding='utf-8')) + + processed = 0 + for 
row in tqdm_asyncio(list(reader), "import tasks"): + await self.add_task(**row) + processed += 1 + if processed % 10 == 0: + await self._operate_in_session(lambda s: s.commit()) + await self._operate_in_session(lambda s: s.commit()) + + def dump_file(self, filename): + writer = self.get_writer(filename) + + async def write_rows(): + ts = await self._operate_in_session(self._get_task_list) + for row in ts: + writer.writerow(row.as_dict()) + + # Run the asynchronous task within an event loop + import asyncio + asyncio.run(write_rows()) + + del writer + + def get_writer(self, filename): + file = open(filename, 'w', newline='', encoding='utf-8') + writer = csv.DictWriter(file, fieldnames=AProject_Task.fieldNames) + writer.writeheader() # write header + return writer + + def merge_results(self, function_rules): + rule_map = {} + for rule in function_rules: + keys = [ + rule['name'], + rule['content'], + rule['BusinessType'], + rule['Sub-BusinessType'], + rule['FunctionType'], + rule['KeySentence'] + ] + key = "/".join(keys) + rule_map[key] = rule + + return rule_map.values() \ No newline at end of file diff --git a/src/dataset/agent-v1-c4/datasets.json b/src/dataset/agent-v1-c4/datasets.json index a937f41d..2eaa64ea 100644 --- a/src/dataset/agent-v1-c4/datasets.json +++ b/src/dataset/agent-v1-c4/datasets.json @@ -213,8 +213,8 @@ "files":[], "functions":[] }, - "zaros":{ - "path":"zaros", + "gamma2":{ + "path":"gamma", "files":[], "functions":[] } diff --git a/src/dynamic_governor.py b/src/dynamic_governor.py new file mode 100644 index 00000000..8ff3885f --- /dev/null +++ b/src/dynamic_governor.py @@ -0,0 +1,48 @@ +import asyncio +import time +from collections import deque + +class DynamicGovernor: + def __init__(self, initial_limit=5, window_size=10, fast_threshold=1.0, slow_threshold=3.0): + self.limit = initial_limit + self.semaphore = asyncio.Semaphore(self.limit) + self.window_size = window_size + # Store the last window_size durations + self.task_durations = 
deque(maxlen=window_size) + self.fast_threshold = fast_threshold + self.slow_threshold = slow_threshold + + async def acquire(self): + await self.semaphore.acquire() + + def release(self, duration: float): + self.task_durations.append(duration) + self.semaphore.release() + self.adjust_limit() + + def adjust_limit(self): + if not self.task_durations: + return + # Compute the moving average of task durations + avg_duration = sum(self.task_durations) / len(self.task_durations) + new_limit = self.limit + if avg_duration < self.fast_threshold: + new_limit = self.limit * 2 + elif avg_duration > self.slow_threshold: + new_limit = max(1, int(self.limit * 2 / 3)) + # Only adjust if there is a change + if new_limit != self.limit: + print(f"Adjusting concurrency limit from {self.limit} to {new_limit} (avg_duration={avg_duration:.2f}s)") + self.limit = new_limit + # Reinitialize the semaphore to the new limit. + self.semaphore = asyncio.Semaphore(self.limit) + +async def governed_task(task_fn, governor: DynamicGovernor, *args, **kwargs): + await governor.acquire() + start = time.monotonic() + try: + result = await task_fn(*args, **kwargs) + return result + finally: + duration = time.monotonic() - start + governor.release(duration) diff --git a/src/knowledges/__init__.py b/src/knowledges/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/library/__init__.py b/src/library/__init__.py index 8ceee087..e69de29b 100644 --- a/src/library/__init__.py +++ b/src/library/__init__.py @@ -1,4 +0,0 @@ -import sys, os -sys.path.append(os.path.abspath(os.path.dirname(__file__))) - -from .parsing import * diff --git a/src/library/app.css b/src/library/app.css new file mode 100644 index 00000000..9e5c0188 --- /dev/null +++ b/src/library/app.css @@ -0,0 +1,45 @@ +/* app.css */ + +/* Style for the input widget */ +#text_input { + height: 3; + padding: 1; + border: tall #00FF00; + background: #1E1E1E; + color: white; +} + +/* Style for the dual language view container */ 
+#dual_view { + margin-top: 1; + margin-bottom: 1; +} + +/* Style for the detail view widget */ +#detail_view { + margin-top: 1; + margin-bottom: 1; + background: #222222; + border: round #888888; + padding: 1; + color: #CCCCCC; +} + +/* Header and Footer styling */ +Header { + background: #333333; + color: #FFFFFF; + padding: 1; +} + +Footer { + background: #333333; + color: #FFFFFF; + padding: 1; +} + +/* Optional: Style for panels within the dual view */ +Panel { + border: round #555555; + padding: 1; +} diff --git a/src/library/parsing/SolidityLexer.py b/src/library/parsing/SolidityLexer.py index cfba6b5b..ce7116db 100644 --- a/src/library/parsing/SolidityLexer.py +++ b/src/library/parsing/SolidityLexer.py @@ -1,11 +1,9 @@ # Generated from .\SolidityLexer.g4 by ANTLR 4.12.0 -from antlr4 import * +# Removed wildcard import; explicit imports follow from io import StringIO import sys -if sys.version_info[1] > 5: - from typing import TextIO -else: - from typing.io import TextIO +from antlr4 import Lexer, LexerATNSimulator, ATNDeserializer, DFA, PredictionContextCache +from typing import TextIO def serializedATN(): diff --git a/src/library/parsing/__init__.py b/src/library/parsing/__init__.py index 3c52448f..71f8e89c 100644 --- a/src/library/parsing/__init__.py +++ b/src/library/parsing/__init__.py @@ -1,14 +1 @@ -from antlr4 import CommonTokenStream, InputStream -from .SolidityLexer import SolidityLexer -from .SolidityParser import parseString -from .callgraph import CallGraph - -def parse(source:str): - lexer = SolidityLexer(InputStream(source)) - stream = CommonTokenStream(lexer) - parser = SolidityParser(stream) - return parser.sourceUnit() - -def get_tokens(source:str): - lexer = SolidityLexer(InputStream(source)) - return lexer.getAllTokens() +from .SolidityParser import parseString \ No newline at end of file diff --git a/src/library/parsing/callgraph.py b/src/library/parsing/callgraph.py index 68c4b241..12fbcc14 100644 --- a/src/library/parsing/callgraph.py 
+++ b/src/library/parsing/callgraph.py @@ -90,10 +90,8 @@ def __init__(self, root:str) -> None: self.root = root self.files = {} self.call_data = {} - self.__parse_all_files() self.__run_jar() - self.__clean() def get_rel_path(self, path:str)->str: @@ -137,6 +135,7 @@ def __parse_all_files(self): continue self.files[os.path.abspath(os.path.join(root, file))] = parseString(open(os.path.join(root, file), "r", encoding="utf-8", errors="ignore").read()) + ### TODO!! asyncio.create_subprocess_exec def __run_jar(self): dir_name = os.path.abspath(os.path.dirname(__file__)) jar_file = os.path.join(dir_name, "jars/SolidityCallgraph-1.0-SNAPSHOT-standalone.jar") diff --git a/src/library/sgp/parser/SolidityLexer.py b/src/library/sgp/parser/SolidityLexer.py index f7dfdf27..e8bd5833 100644 --- a/src/library/sgp/parser/SolidityLexer.py +++ b/src/library/sgp/parser/SolidityLexer.py @@ -1,11 +1,9 @@ # Generated from Solidity.g4 by ANTLR 4.13.1 -from antlr4 import * +from antlr4 import DFA, ATNDeserializer, LexerATNSimulator, PredictionContextCache +from antlr4.Lexer import Lexer from io import StringIO import sys -if sys.version_info[1] > 5: - from typing import TextIO -else: - from typing.io import TextIO +from typing import TextIO def serializedATN(): diff --git a/src/library/sgp/parser/SolidityParser.py b/src/library/sgp/parser/SolidityParser.py index 8db8235d..a11a3e0f 100644 --- a/src/library/sgp/parser/SolidityParser.py +++ b/src/library/sgp/parser/SolidityParser.py @@ -1,12 +1,13 @@ # Generated from Solidity.g4 by ANTLR 4.13.1 # encoding: utf-8 -from antlr4 import * +from antlr4 import DFA, Parser, RuleContext, Token, TokenStream, ParserRuleContext, PredictionContextCache, ParseTreeVisitor, ParseTreeListener, ParserATNSimulator from io import StringIO import sys -if sys.version_info[1] > 5: - from typing import TextIO -else: - from typing.io import TextIO +from antlr4.atn.ATNDeserializer import ATNDeserializer +from antlr4 import RecognitionException +from antlr4.atn.ATN 
import ATN +from antlr4.error.Errors import NoViableAltException +from typing import TextIO def serializedATN(): return [ @@ -9011,7 +9012,7 @@ def stringLiteral(self): def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): - if self._predicates == None: + if not hasattr(self, '_predicates') or self._predicates is None: self._predicates = dict() self._predicates[38] = self.typeName_sempred self._predicates[70] = self.expression_sempred diff --git a/src/library/sgp/sgp_parser.py b/src/library/sgp/sgp_parser.py index 771755f8..d2ba9ae3 100644 --- a/src/library/sgp/sgp_parser.py +++ b/src/library/sgp/sgp_parser.py @@ -4,15 +4,13 @@ import re from antlr4.CommonTokenStream import CommonTokenStream from antlr4.InputStream import InputStream as ANTLRInputStream - -from .parser.SolidityLexer import SolidityLexer -from .parser.SolidityParser import SolidityParser - -from .sgp_visitor import SGPVisitorOptions, SGPVisitor,SolidityInfoVisitor -from .sgp_error_listener import SGPErrorListener -from .ast_node_types import SourceUnit -from .tokens import build_token_list -from .utils import string_from_snake_to_camel_case +from library.sgp.parser.SolidityLexer import SolidityLexer +from library.sgp.parser.SolidityParser import SolidityParser +from library.sgp.sgp_visitor import SGPVisitorOptions, SGPVisitor,SolidityInfoVisitor +from library.sgp.sgp_error_listener import SGPErrorListener +from library.sgp.ast_node_types import SourceUnit +from library.sgp.tokens import build_token_list +from library.sgp.utils import string_from_snake_to_camel_case class ParserError(Exception): diff --git a/src/library/sgp/sgp_visitor.py b/src/library/sgp/sgp_visitor.py index da78b411..bce441ea 100644 --- a/src/library/sgp/sgp_visitor.py +++ b/src/library/sgp/sgp_visitor.py @@ -1,20 +1,103 @@ -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Union, cast from typing_extensions import override +from library.sgp.ast_node_types import ( + 
UNARY_OP_VALUES, + InheritanceSpecifier, + NameValueExpression, + BINARY_OP_VALUES, + Conditional, + NameValueList, + ArrayTypeName, + AssemblyBlock, + AssemblyCase, + AssemblyCall, + AssemblyFunctionDefinition, + AssemblyIf, + AssemblyFor, + AssemblyAssignment, + AssemblyItem, + AssemblyLocalDefinition, + AssemblyMemberAccess, + AssemblyStackAssignment, + AssemblySwitch, + BaseASTNode, + Block, + BinaryOperation, + BooleanLiteral, + Break, + BreakStatement, + CatchClause, + Continue, + ContinueStatement, + ContractDefinition, + CustomErrorDefinition, + DecimalNumber, + DoWhileStatement, + ElementaryTypeName, + EmitStatement, + EnumDefinition, + EnumValue, + EventDefinition, + Expression, + ExpressionStatement, + FileLevelConstant, + ForStatement, + FunctionCall, + FunctionDefinition, + FunctionTypeName, + HexLiteral, + HexNumber, + Identifier, + IfStatement, + ImportDirective, + IndexAccess, + IndexRangeAccess, + LabelDefinition, + Location, + Mapping, + MemberAccess, + ModifierDefinition, + ModifierInvocation, + NewExpression, + NumberLiteral, + PrimaryExpression, + PragmaDirective, + Range, + ReturnStatement, + RevertStatement, + SimpleStatement, + SourceUnit, + StateVariableDeclaration, + StateVariableDeclarationVariable, + Statement, + StringLiteral, + StructDefinition, + ThrowStatement, + TryStatement, + TupleExpression, + TypeDefinition, + TypeName, + UnaryOperation, + UncheckedStatement, + UserDefinedTypeName, + UsingForDeclaration, + VariableDeclaration, + VariableDeclarationStatement, + WhileStatement, + InlineAssemblyStatement, # added inline assembly type + AssemblyExpression, # added assembly expression type +) from antlr4.tree.Tree import ErrorNode from antlr4 import ParserRuleContext from antlr4.tree.Tree import ParseTree,ParseTreeVisitor from antlr4.error.Errors import RecognitionException -from sgp.utilities.contract_extractor import extract_contract_with_name +from library.sgp.utilities.contract_extractor import extract_contract_with_name -from 
.parser.SolidityParser import SolidityParser as SP -from .parser.SolidityVisitor import SolidityVisitor - - -from .ast_node_types import * - +from library.sgp.parser.SolidityParser import SolidityParser as SP +from library.sgp.parser.SolidityVisitor import SolidityVisitor class SGPVisitorOptions: def __init__( @@ -43,7 +126,7 @@ def __init__( class SGPVisitor(SolidityVisitor): def __init__(self, options: SGPVisitorOptions): super().__init__() - self._current_contract = None + self._current_contract: Optional[str] = None self._options = options @override @@ -169,7 +252,7 @@ def visitVariableDeclaration( identifierCtx = ctx.identifier() node = VariableDeclaration( - type_name=self.visitTypeName(ctx.typeName()), + type_name=cast(TypeName, self.visitTypeName(ctx.typeName())), name=self._to_text(identifierCtx), identifier=self.visitIdentifier(identifierCtx), storage_location=storageLocation, @@ -199,11 +282,11 @@ def visitVariableDeclarationStatement( ctxExpression = ctx.expression() if ctxExpression: initialValue = self.visitExpression(ctxExpression) - - node = VariableDeclarationStatement( - variables=variables, + node = VariableDeclarationStatement( + variables=[cast(Optional[BaseASTNode], v) for v in variables], initial_value=initialValue, - ) + ) + return self._add_meta(node, ctx) @@ -213,7 +296,7 @@ def visitStatement(self, ctx: SP.StatementContext) -> Statement: def visitSimpleStatement( self, ctx: SP.SimpleStatementContext ) -> SimpleStatement: - if ctx.children == None: + if ctx.children is None: return self.visitErrorNode(ctx.start) return self.visit(ctx.getChild(0)) # Assuming the child type is SimpleStatement @@ -223,7 +306,7 @@ def visitEventDefinition( parameters = [ self._add_meta( VariableDeclaration( - type_name=self.visitTypeName(paramCtx.typeName()), + type_name=cast(TypeName, self.visitTypeName(paramCtx.typeName())), name=self._to_text(paramCtx.identifier()) if paramCtx.identifier() else None, @@ -472,7 +555,7 @@ def visitUsingForDeclaration( # using 
Lib for ... node = UsingForDeclaration( is_global=isGlobal, - type_name=typeName, + type_name=cast(TypeName, typeName), library_name=self._to_text(userDefinedTypeNameCtx), functions=[], operators=[], @@ -492,7 +575,7 @@ def visitUsingForDeclaration( node = UsingForDeclaration( is_global=isGlobal, - type_name=typeName, + type_name=cast(TypeName, typeName), library_name=None, functions=functions, operators=operators, @@ -608,7 +691,7 @@ def visitFunctionTypeParameter( return self._add_meta(node, ctx) def visitThrowStatement(self, ctx: SP.ThrowStatementContext) -> ThrowStatement: - node = ThrowStatement(type="ThrowStatement") + node = ThrowStatement() return self._add_meta(node, ctx) @@ -916,9 +999,9 @@ def visitExpression(self, ctx: SP.ExpressionContext) -> Expression: ): node = TupleExpression( components=[ - self.visitExpression( + cast(BaseASTNode, self.visitExpression( ctx.getTypedRuleContext(SP.ExpressionContext, 0) - ) + )) ], isArray=False, ) @@ -1144,7 +1227,7 @@ def visitPrimaryExpression(self, ctx: SP.PrimaryExpressionContext) -> Union[Prim fragments_info.append({"value": value, "is_unicode": is_unicode}) parts = [x["value"] for x in fragments_info] - node = StringLiteral(value="".join(parts), parts=parts, is_unicode=[x["is_unicode"] for x in fragments_info]) + node = StringLiteral(value="".join(str(p) for p in parts), parts=parts, is_unicode=[x["is_unicode"] for x in fragments_info]) return self._add_meta(node, ctx) if ctx.numberLiteral(): @@ -1157,7 +1240,7 @@ def visitPrimaryExpression(self, ctx: SP.PrimaryExpressionContext) -> Union[Prim if ctx.typeName(): return self.visitTypeName(ctx.typeName()) - if ctx.children == None: + if ctx.children is None: return self.visitErrorNode(ctx.start) return self.visit(ctx.getChild(0)) @@ -1594,7 +1677,7 @@ def visitBreakStatement(self, ctx: SP.BreakStatementContext) -> BreakStatement: return self._add_meta(node, ctx) - def _to_text(self, ctx: ParserRuleContext or ParseTree) -> str: + def _to_text(self, ctx: 
def enterFunctionCall(self, ctx:SolidityParser.FunctionCallContext):
    """Record a call edge when a function call is met inside a function body.

    The callee name is the text of the call expression with its trailing
    parenthesised argument list removed, e.g. ``token.transfer(to,amt)``
    becomes ``token.transfer``.  The name is added to the current
    function's ``calls`` set, to ``call_graph[current]`` and, in the
    reverse direction, to ``callers_graph[callee]``.

    Calls seen while no function is being visited are ignored.
    """
    if not self.current_function:
        return

    # The parentheses must be escaped: an unescaped '(.*)$' is a capture
    # group that matches the whole text and would reduce every callee
    # name to the empty string.
    called_function = re.sub(r'\(.*\)$', '', ctx.getText())  # get the name of the called function
    self.functions[self.current_function]['calls'].add(called_function)
    self.call_graph[self.current_function].add(called_function)
    self.callers_graph[called_function].add(self.current_function)
match.start() open_braces = 0 diff --git a/src/main.py b/src/main.py index a5cd6727..624824a4 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,4 @@ import argparse -import ast import os import time import audit_config @@ -7,18 +6,14 @@ from project import ProjectAudit from library.dataset_utils import load_dataset, Project from planning import PlanningV2 -from prompts import prompts from sqlalchemy import create_engine from dao import CacheManager, ProjectTaskMgr -import os import pandas as pd from openpyxl import Workbook,load_workbook from openpyxl.utils.dataframe import dataframe_to_rows from codebaseQA.rag_processor import RAGProcessor from res_processor.res_processor import ResProcessor -import dotenv -dotenv.load_dotenv() def scan_project(project, db_engine): # 1. parsing projects @@ -136,19 +131,19 @@ def generate_excel(output_path, project_id): if switch_production_or_test == 'test': start_time=time.time() db_url_from = os.environ.get("DATABASE_URL") - engine = create_engine(db_url_from) + engine = create_engine(db_url_from, echo=True) dataset_base = "./src/dataset/agent-v1-c4" projects = load_dataset(dataset_base) - project_id = 'zaros' + project_id = 'gamma2' project_path = '' project = Project(project_id, projects[project_id]) cmd = 'detect_vul' if cmd == 'detect_vul': lancedb,lance_table_name,project_audit=scan_project(project, engine) # scan - if os.getenv("SCAN_MODE","SPECIFIC_PROJECT")=="SPECIFIC_PROJECT" or os.getenv("SCAN_MODE","SPECIFIC_PROJECT")=="COMMON_PROJECT": + if os.getenv("SCAN_MODE","SPECIFIC_PROJECT") in ["SPECIFIC_PROJECT","COMMON_PROJECT","PURE_SCAN"]: check_function_vul(engine,lancedb,lance_table_name,project_audit) # confirm @@ -186,7 +181,7 @@ def generate_excel(output_path, project_id): # Database setup db_url_from = os.environ.get("DATABASE_URL") - engine = create_engine(db_url_from) + engine = create_engine(db_url_from, echo=True) # Load projects projects = load_dataset(dataset_base, args.id, folder_name) @@ -199,7 +194,7 @@ 
_LOCAL_LLM_BASE_URL = "http://127.0.0.1:11434/v1"


def _make_agent(model_name: str, result_type):
    """Build an Agent (3 retries) on the local OpenAI-compatible endpoint."""
    model = OpenAIModel(model_name=model_name, base_url=_LOCAL_LLM_BASE_URL)
    return Agent(model, retries=3, result_type=result_type)


# --- Phase 1: Planning Phase ---
async def planning_phase(code_snippet: str) -> dict:
    """Ask the planning model to extract the business flow of *code_snippet*.

    The agent's run result is printed and returned as-is.
    """
    agent = _make_agent("openai-planning-model", dict)
    request = f"Analyze the following code and extract its business flow as JSON:\n\n{code_snippet}"
    outcome = await agent.run(request)
    print("[Planning] Result:", outcome)
    return outcome


# --- Phase 2: Scanning Phase with Dynamic Governor ---
# Import the dynamic governor from above
from dynamic_governor import DynamicGovernor, governed_task


async def scanning_phase(business_flow: dict) -> dict:
    """Ask the scanning model for vulnerabilities in *business_flow*.

    The model call is routed through ``governed_task`` with a fresh
    ``DynamicGovernor`` (initial limit 5, window size 10).
    """
    agent = _make_agent("openai-scanning-model", dict)
    request = (
        f"Given the business flow JSON:\n{business_flow}\n"
        f"Identify potential vulnerabilities and output a JSON summary."
    )
    # One governor per scanning phase invocation.
    limiter = DynamicGovernor(initial_limit=5, window_size=10)

    async def scan_call(p):
        return await agent.run(p)

    outcome = await governed_task(scan_call, limiter, request)
    print("[Scanning] Result:", outcome)
    return outcome


# --- Phase 3: Confirmation Phase ---
async def confirmation_phase(scan_result: dict) -> dict:
    """Ask the confirmation model to confirm or refute *scan_result*."""
    agent = _make_agent("openai-confirmation-model", dict)
    request = (
        f"Review the following scan result:\n{scan_result}\n"
        f"Confirm or refute the findings and output your conclusion in JSON."
    )
    outcome = await agent.run(request)
    print("[Confirmation] Result:", outcome)
    return outcome


# Phase 4: Final Aggregation / Message Pump
async def final_phase(confirmation_data: dict, run_context):
    """Stream a Markdown report for *confirmation_data* into *run_context*.

    Every message produced by the stream is passed to
    ``run_context.update_markdown``; nothing is returned.
    """
    agent = _make_agent("openai-final-model", str)
    request = (
        f"Aggregate the following confirmation data into a final report in Markdown:\n{confirmation_data}\n "
        f"Output the final report."
    )
    async with agent.run_stream(request) as stream:
        async for message in stream.stream():
            # Inject output into the local RunContext for rendering.
            run_context.update_markdown(message)


# Main Pipeline: Chaining all phases together.
async def main_pipeline(run_context):
    """Run planning, scanning, confirmation and the final report in sequence."""
    # In practice, the code snippet might come from your project audit/planning module.
    code_snippet = """
    function transfer(address recipient, uint256 amount) public returns (bool) {
        // business logic here...
    }
    """
    flow = await planning_phase(code_snippet)
    findings = await scanning_phase(flow)
    verdict = await confirmation_phase(findings)
    await final_phase(verdict, run_context)
class Settings(BaseSettings):
    """Application settings, one string field per configuration key.

    Values are collected from the sources returned by
    ``settings_customise_sources``; the defaults below apply when no source
    supplies a value.
    """
    # Enables CLI parsing, `.env` loading and the
    # [tool.finite-monkey-engine] table of pyproject.toml.
    model_config:SettingsConfigDict = SettingsConfigDict(
        cli_parse_args=True,
        cli_prog_name='finite-monkey-engine',
        pyproject_toml_depth=0,
        pyproject_toml_table_header=('tool', 'finite-monkey-engine'),
        toml_file='pyproject.toml',
        extra='ignore',
        env_file ='.env',
        strict=False
    )
    #database_dsn:PostgresDsn = None
    id:str=""
    base_dir:str=""
    src_dir:str=""
    output:str="."
    AZURE_OR_OPENAI:str=""
    AZURE_API_BASE:str=""
    AZURE_API_KEY:str=""
    AZURE_API_VERSION:str=""
    AZURE_DEPLOYMENT_NAME:str=""
    BUSINESS_FLOW_COUNT:str="10"
    CLAUDE_MODEL:str=""
    COMMON_PROJECT:str=""
    COMMON_VUL:str="all"
    CONFIRMATION_MODEL:str=""
    DATABASE_SQLITE:str=""
    DATABASE_SETTINGS_URL:str=""
    # NOTE(review): the two default URLs below embed a user name and
    # password; presumably local-development values only - confirm.
    DATABASE_URL:str="postgresql://postgres:1234@127.0.0.1:5432/postgres"
    ASYNC_DB_URL:str="postgresql+asyncpg://postgres:1234@127.0.0.1:5432/postgres"
    IGNORE_FOLDERS:str="test"
    MAX_THREADS_OF_CONFIRMATION:str="8"
    MAX_THREADS_OF_SCAN:str="8"
    OPENAI_API_BASE:str=""
    OPENAI_API_KEY:str=""
    OPENAI_MODEL:str=""
    OPTIMIZE:str=""
    PRE_TRAIN_MODEL:str=""
    SCAN_MODE:str="all"
    SPECIFIC_PROJECT:str=""
    SWITCH_BUSINESS_CODE:str="True"
    SWITCH_FUNCTION_CODE:str="False"
    GEMINI_API_KEY:str="k"
    PYTHONASYNCIODEBUG:str="1"
    FORCE_COLOR:str="1"
    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        env_settings: PydanticBaseSettingsSource,
        init_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,

    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the sources to load, in this order: init arguments,
        environment, `.env` file, then pyproject.toml.

        ``file_secret_settings`` is accepted but not included in the result.
        """
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            PyprojectTomlConfigSettingsSource(settings_cls),
        )

class nodes_config(Settings):
    """Settings variant that lists the pyproject.toml source ahead of the
    environment and `.env` sources."""
    def nodes_config(self):
        # NOTE(review): this is a regular method that shares the class name,
        # not a constructor; nothing visible in this file calls it.
        # NOTE(review): `model_config` below is a local variable that is never
        # read, so the `env_ignore_empty=True` option has no effect from here;
        # presumably it was meant to be a class attribute - confirm.
        model_config:SettingsConfigDict = SettingsConfigDict(
            cli_parse_args=True,
            cli_prog_name='finite-monkey-engine',
            pyproject_toml_depth=0,
            pyproject_toml_table_header=('tool', 'finite-monkey-engine'),
            toml_file='pyproject.toml',
            extra='ignore',
            env_file ='.env',
            env_ignore_empty=True,
            strict=False,
        )
        # NOTE(review): `settings` is not a declared field of this model;
        # verify that pydantic accepts this assignment at runtime.
        self.settings = Settings()
    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[Settings],
        env_settings: PydanticBaseSettingsSource,
        init_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,

    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the sources to load, in this order: init arguments,
        pyproject.toml, environment, then `.env` file.

        ``file_secret_settings`` is accepted but not included in the result.
        """
        return (
            init_settings,
            PyprojectTomlConfigSettingsSource(settings_cls),
            env_settings,
            dotenv_settings,
        )
00000000..45a65efb --- /dev/null +++ b/src/planning/aplanning_v2.py @@ -0,0 +1,329 @@ +import asyncio +import json +import os +from typing import AsyncGenerator, Callable, Any + +from dao.atask_mgr import AProjectTaskMgr + +# Assume AProjectTaskMgr and other dependencies are imported + +# class APlanningV2: +# def __init__(self, project, mgr:AProjectTaskMgr ): +# self.project = project +# self.taskmgr = AProjectTaskMgr +# self.scan_list_for_larget_context: list[Any] = [] + +# async def do_planning(self) -> AsyncGenerator[Callable[[], Any], None]: +# """ +# Asynchronously plan the business flows for each function. +# Instead of returning a final result, yield a lambda that, when invoked, +# returns a future (a coroutine) for each iteration of planning. +# This allows an external agent to evaluate each stage as a future. +# """ +# # For example, fetch tasks from the DB. +# tasks = await self.taskmgr.get_task_list_by_id(self.project.project_id) +# if tasks: +# # If tasks already exist, we consider planning already done. +# return + +# # Filter functions (example logic). +# functions_to_check = [f for f in self.project.functions_to_check if "test" not in f['name']] +# self.project.functions_to_check = functions_to_check + +# # If business code switching is enabled, get all business flows. +# switch_business_code = eval(os.environ.get('SWITCH_BUSINESS_CODE', 'True')) +# if switch_business_code: +# all_business_flow, all_business_flow_line, all_business_flow_context = await self.get_all_business_flow(self.project.functions_to_check) +# else: +# all_business_flow = all_business_flow_line = all_business_flow_context = {} + +# # Iterate over each function to plan its business flow. +# for function in self.project.functions_to_check: +# name = function['name'] +# contract_name = function['contract_name'] +# # Log processing +# print(f"Processing function: {name}") + +# # Use your existing search_business_flow to get the flow code. 
+# business_flow_code, line_info_list, other_contract_context = await self.search_business_flow( +# all_business_flow, all_business_flow_line, all_business_flow_context, +# name.split(".")[1], contract_name +# ) +# # Build a prompt for type checking. +# type_check_prompt = ( +# "分析以下智能合约代码,判断它属于哪些业务类型。可能的类型包括:\n" +# "chainlink, dao, inline assembly, lending, liquidation, liquidity manager, signature, " +# "slippage, univ3, other\n" +# "请以JSON格式返回结果,格式为: {\"business_types\": [\"type1\", \"type2\"]}\n\n" +# "代码:\n{0}" +# ) +# formatted_prompt = type_check_prompt.format(business_flow_code + "\n" + other_contract_context + "\n" + function['content']) + +# # Instead of directly calling a REST method to get the answer, yield a lambda +# # that returns a coroutine to be evaluated later by the caller. +# yield lambda: self.common_ask_for_json(formatted_prompt) + +# # Optionally, you could also yield the current intermediate planning context. +# # For example: +# # yield lambda: asyncio.sleep(0, result={"function": name, "prompt": formatted_prompt}) + +# async def common_ask_for_json(self, prompt: str) -> dict: +# """ +# Simulate an asynchronous call to an LLM API that returns a JSON result. +# In your production code, this would be replaced with a pydantic-ai Agent call. +# """ +# # Simulated delay +# await asyncio.sleep(0.5) +# # For demonstration, return a dummy JSON response. +# # In reality, your agent would process the prompt. +# dummy_response = {"business_types": ["example_type1", "example_type2"]} +# print(f"Processed prompt: {prompt[:60]}... -> {dummy_response}") +# return dummy_response + +# # Placeholder methods for get_all_business_flow and search_business_flow +# async def get_all_business_flow(self, functions_to_check): +# # Simulated async processing. +# await asyncio.sleep(0.5) +# # Return dummy dictionaries. 
+# return {}, {}, {} + +# async def search_business_flow(self, all_flow, all_flow_line, all_flow_context, function_name, contract_name): +# # Simulated async search. +# await asyncio.sleep(0.5) +# # Return dummy business flow code and context. +# return f"business flow code for {function_name}", [], f"context for {contract_name}" +import asyncio +import json +import os +import re +from typing import AsyncGenerator, Callable, Any, List, Tuple + +# Simulated pydantic model for LLM responses (for illustration) +from pydantic import BaseModel + +class TranslationResult(BaseModel): + english: str + chinese: str + +# Dummy implementations for demonstration. +# In production, replace these with real pydantic‑ai Agent calls. + +class APlanningV2: + def __init__(self, project, tm): + self.project = project + self.taskmgr = tm # Assume this is defined elsewhere + self.scan_list_for_larget_context = [] + + async def ask_openai_for_business_flow(self, function_name: str, contract_code_without_comment: str) -> List[str]: + """ + Asynchronously call the LLM to get business flows starting with function_name. + Returns a list of function names (strings). + """ + prompt = f""" + Based on the code below, analyze the business flows starting with the function {function_name}. + Only output a JSON list of function names that are part of the business flow. + Code: + {contract_code_without_comment} + """ + # Simulate API latency. + await asyncio.sleep(0.5) + # For demonstration, return dummy function names. + return [function_name, function_name + "_helper", function_name + "_final"] + + async def extract_filtered_functions(self, business_flow_list: List[str]) -> List[str]: + """ + Process the list of function names from the LLM output. + For demonstration, simply remove duplicates. 
+ """ + await asyncio.sleep(0.1) + return list(set(business_flow_list)) + + async def extract_and_concatenate_functions_content(self, function_lists: List[str], contract_info: dict) -> str: + """ + Given a list of function names and contract info, concatenate their contents. + For demonstration, we simulate by returning dummy content for each function. + """ + await asyncio.sleep(0.1) + contents = [f"Content for {fn}" for fn in function_lists] + return "\n".join(contents) + + async def extract_results(self, text: str) -> List[dict]: + """ + Extract results from a text output (dummy implementation). + """ + await asyncio.sleep(0.2) + # For demonstration, return a dummy list. + return [{"result": "yes"}] + + async def merge_and_sort_rulesets(self, high: List[dict], medium: List[dict]) -> List[dict]: + """ + Merge two rulesets and sort by sim_score. + """ + combined = high + medium + combined.sort(key=lambda x: x.get("sim_score", 0), reverse=True) + return combined + + async def decode_business_flow_list_from_response(self, response: str) -> List[str]: + """ + Decode a JSON-formatted business flow list from a response. + """ + pattern = r'({\s*\"[a-zA-Z0-9_]+\"\s*:\s*\[[^\]]*\]\s*})' + matches = re.findall(pattern, response) + functions = set() + for match in matches: + try: + data = json.loads(match) + for key, value in data.items(): + functions.add(key) + functions.update(value) + except Exception: + continue + return list(functions) + + async def identify_contexts(self, functions_to_check: List[dict]) -> dict: + """ + For each function in functions_to_check, identify its sub_calls and parent_calls. + Returns a dictionary mapping function names to context information. 
+ """ + await asyncio.sleep(0.2) + contexts = {} + for function in functions_to_check: + name = function.get("name", "unknown") + contexts[name] = { + "sub_calls": [{"name": name + "_sub", "content": "dummy sub-call content"}], + "parent_calls": [{"name": name + "_parent", "content": "dummy parent-call content"}] + } + return contexts + + async def search_business_flow(self, + all_business_flow: dict, + all_business_flow_line: dict, + all_business_flow_context: dict, + function_name: str, + contract_name: str) -> Tuple[str, List[Tuple[int, int]], str]: + """ + Search for business flow information given a function name and contract name. + Returns a tuple of: + - business_flow_code (str) + - line_info_list (list of tuples) + - context_info (str) + """ + await asyncio.sleep(0.3) + business_flow_code = f"Business flow code for {function_name} in contract {contract_name}" + line_info_list = [(1, 10)] + context_info = f"Extended context for {contract_name}" + return business_flow_code, line_info_list, context_info + + async def common_ask_for_json(self, prompt: str) -> dict: + """ + Simulate an asynchronous call to an LLM API that returns a JSON result. + """ + await asyncio.sleep(0.5) + dummy_response = {"business_types": ["type1", "type2"]} + print(f"[LLM] Processed prompt (first 60 chars): {prompt[:60]}... -> {dummy_response}") + return dummy_response + + async def do_planning(self) -> AsyncGenerator[Callable[[], Any], None]: + """ + Asynchronously process planning for each function. + Yield a lambda that returns a future (coroutine) for each planning iteration. + """ + # Simulate fetching tasks – if tasks already exist, planning may be skipped. + tasks = await self.taskmgr.get_task_list_by_id(self.project.project_id) + if tasks: + return # Planning already done. + + # Filter functions to check. 
+ functions_to_check = [f for f in self.project.functions_to_check if "test" not in f.get('name', "")] + self.project.functions_to_check = functions_to_check + + # Optionally, if business code switching is enabled, get all business flows. + switch_business_code = eval(os.environ.get('SWITCH_BUSINESS_CODE', 'True')) + if switch_business_code: + all_business_flow, all_business_flow_line, all_business_flow_context = await self.get_all_business_flow(self.project.functions_to_check) + else: + all_business_flow = all_business_flow_line = all_business_flow_context = {} + + # Process each function. + for function in self.project.functions_to_check: + name = function.get('name', "unknown") + contract_name = function.get('contract_name', "default") + print(f"Processing function: {name}") + + # Retrieve business flow code and context. + business_flow_code, line_info_list, other_contract_context = await self.search_business_flow( + all_business_flow, all_business_flow_line, all_business_flow_context, + name.split(".")[1] if "." in name else name, contract_name + ) + # Build a prompt. + type_check_prompt = ( + "分析以下智能合约代码,判断它属于哪些业务类型。可能的类型包括:\n" + "chainlink, dao, inline assembly, lending, liquidation, liquidity manager, signature, " + "slippage, univ3, other\n" + "请以JSON格式返回结果,格式为: {\"business_types\": [\"type1\", \"type2\"]}\n\n" + "代码:\n{0}" + ) + formatted_prompt = type_check_prompt.format(business_flow_code + "\n" + other_contract_context + "\n" + function.get('content', "")) + + # Yield a lambda that, when called, returns the coroutine to get the JSON response. + yield lambda: self.common_ask_for_json(formatted_prompt) + + # Optionally, you might also yield additional intermediate context. + # For example: + # yield lambda: asyncio.sleep(0, result={"function": name, "prompt": formatted_prompt}) + + # Dummy implementation of get_all_business_flow for completeness. 
async def get_all_business_flow(self, functions_to_check: List[dict]) -> Tuple[dict, dict, dict]:
    """
    Placeholder for business-flow extraction.

    Waits briefly to imitate real work, ignores ``functions_to_check`` and
    returns three empty dictionaries: flows, flow line ranges and flow
    contexts, in that order.
    """
    await asyncio.sleep(0.5)
    flows: dict = {}
    flow_lines: dict = {}
    flow_contexts: dict = {}
    return flows, flow_lines, flow_contexts
# Configure logger (if not already configured elsewhere)
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    logger.setLevel(logging.INFO)  # Adjust as needed
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)


class BusinessFlowExtractor:
    """
    Build per-function "business flows" for the functions under review.

    For each contract, the entry functions that pass a language-specific
    visibility filter are sent to an LLM, which names the functions each one
    calls. The extractor then uses the call graph to collect the line ranges,
    source text and surrounding call context of those functions.
    """

    # Visibility predicates keyed by source-file suffix. A suffix with no entry
    # means "keep every function". The historical '.rust' / '.python' / '.func'
    # keys are kept for compatibility; '.rs' / '.py' / '.fc' are the suffixes
    # real source files carry, so without them the lookup never matched.
    LANGUAGE_PATTERNS = {
        '.rust': lambda f: True,  # No visibility filter for Rust
        '.rs': lambda f: True,
        '.python': lambda f: True,  # No visibility filter for Python
        '.py': lambda f: True,
        '.move': lambda f: f.get('visibility') == 'public',
        '.fr': lambda f: f.get('visibility') == 'public',
        '.java': lambda f: f.get('visibility') in ['public', 'protected'],
        '.cairo': lambda f: f.get('visibility') == 'public',
        '.tact': lambda f: f.get('visibility') == 'public',
        '.func': lambda f: f.get('visibility') == 'public',
        '.fc': lambda f: f.get('visibility') == 'public',
    }

    def __init__(self, call_graph: "CallGraph") -> None:
        """
        Initialize the BusinessFlowExtractor with an instance of CallGraph.

        :param call_graph: An instance of the CallGraph class.
        """
        self.call_graph = call_graph

    def get_all_business_flow(
        self,
        functions_to_check: List[Dict[str, Any]]
    ) -> Tuple[
        Dict[str, Dict[str, Any]],
        Dict[str, Dict[str, List[Tuple[int, int]]]],
        Dict[str, Dict[str, str]]
    ]:
        """
        Extracts all business flows for a list of functions.

        :param functions_to_check: Function records (dicts) to extract business flows for.
        :return:
            - all_business_flow: Dict[contract_name, Dict[function_name, business_flow_code]]
            - all_business_flow_line: Dict[contract_name, Dict[function_name, List[Tuple[start_line, end_line]]]]
            - all_business_flow_context: Dict[contract_name, Dict[function_name, extended_flow_code]]
        """
        if not functions_to_check:
            logger.warning("No functions provided for business flow extraction.")
            return {}, {}, {}

        # Group functions by their respective contracts
        grouped_functions = group_functions_by_contract(functions_to_check)
        # Identify contexts for the functions
        contexts = self.call_graph.identify_contexts(functions_to_check)

        # Initialize dictionaries to store business flows
        all_business_flow: Dict[str, Dict[str, Any]] = defaultdict(dict)
        all_business_flow_line: Dict[str, Dict[str, List[Tuple[int, int]]]] = defaultdict(dict)
        all_business_flow_context: Dict[str, Dict[str, str]] = defaultdict(dict)

        logger.info(f"Grouped contract count: {len(grouped_functions)}")

        for contract_info in grouped_functions:
            contract_name = contract_info.get('contract_name')
            functions = contract_info.get('functions', [])
            contract_code_without_comments = contract_info.get('contract_code_without_comment', '')
            file_path = contract_info.get('file_path')  # Assuming 'file_path' key exists

            if not contract_name:
                logger.warning("Contract info missing 'contract_name'. Skipping.")
                continue

            logger.info(f"Processing contract: {contract_name}")

            # Determine file extension and corresponding visibility filter
            file_ext = self._get_file_extension(functions)
            visibility_filter = self._get_visibility_filter(file_ext)

            # Filter public/external functions based on visibility
            all_public_external_function_names = [
                self._extract_function_name(function.get('name', ''))
                for function in functions
                if visibility_filter(function)
            ]

            logger.info(f"Public/External functions count in {contract_name}: {len(all_public_external_function_names)}")

            for function_name in all_public_external_function_names:
                if not function_name:
                    logger.warning(f"Encountered empty function name in contract '{contract_name}'. Skipping.")
                    continue

                logger.debug(f"Processing function: {function_name}")

                # Special handling for Python contracts with a single public/external function:
                # skip the LLM call and build the response in the same shape the LLM returns.
                if "_python" in contract_name.lower() and len(all_public_external_function_names) == 1:
                    business_flow_list = {
                        "BusinessFlow": {"flow": list(all_public_external_function_names)}
                    }
                else:
                    try:
                        business_flow_list = self.ask_openai_for_business_flow(function_name, contract_code_without_comments)
                    except Exception as e:
                        logger.error(f"Error fetching business flow for {function_name}: {e}")
                        business_flow_list = {}

                if not business_flow_list:
                    logger.warning(f"No business flow data for function: {function_name}")
                    continue

                # Extract the callee names; a malformed response skips this function.
                function_lists = self._flow_from_response(business_flow_list, function_name)
                if function_lists is None:
                    continue

                logger.debug(f"Business flow list for {function_name}: {function_lists}")

                # Retrieve line information for each function in the flow
                line_info_list = []
                for fn in function_lists:
                    if fn == "-1":
                        continue
                    func_struct = self.call_graph.get_function_detail(file=file_path, contract=contract_name, function=fn)
                    if func_struct:
                        line_info_list.append((func_struct.get('start_line'), func_struct.get('end_line')))

                # Extract and concatenate function contents
                business_flow_code = self.extract_and_concatenate_functions_content(function_lists, contract_info)

                # Build extended flow code from contexts
                extended_flow_code = self._build_extended_flow_code(contract_name, function_lists, contexts)

                # Assign to respective dictionaries
                all_business_flow[contract_name][function_name] = business_flow_code
                all_business_flow_line[contract_name][function_name] = line_info_list
                all_business_flow_context[contract_name][function_name] = extended_flow_code.strip()

        return all_business_flow, all_business_flow_line, all_business_flow_context

    def _flow_from_response(self, response: Any, function_name: str) -> Optional[List[str]]:
        """
        Pull the callee names out of a parsed business-flow response.

        :param response: Parsed response, expected as {"BusinessFlow": {"flow": [...]}}.
        :param function_name: Entry function; it is removed from the returned flow.
        :return: The string entries of the flow without ``function_name``,
                 or None when the response does not have the expected shape.
        """
        try:
            flow = response.get("BusinessFlow", {}).get("flow", [])
        except AttributeError as e:
            logger.error(f"Error processing business_flow_list for {function_name}: {e}")
            return None
        if not isinstance(flow, (list, tuple)):
            logger.error(f"Error processing business_flow_list for {function_name}: 'flow' is not a list")
            return None
        # Non-string entries cannot be looked up in the call graph, so drop them.
        return [fn for fn in flow if isinstance(fn, str) and fn != function_name]

    def _get_file_extension(self, functions: List[Dict[str, Any]]) -> Optional[str]:
        """
        Determine the file extension based on the functions' relative file paths.

        The first known suffix is returned for which some function both lives
        in a file with that suffix and passes that suffix's visibility filter.

        :param functions: List of function dictionaries.
        :return: File extension if found, else None.
        """
        for func in functions:
            file_path = func.get('relative_file_path', '')
            for ext, filter_func in self.LANGUAGE_PATTERNS.items():
                if file_path.endswith(ext) and filter_func(func):
                    return ext
        return None

    def _get_visibility_filter(self, file_ext: Optional[str]):
        """
        Retrieve the visibility filter lambda based on the file extension.

        :param file_ext: File extension.
        :return: A predicate for visibility filtering; accepts everything for unknown extensions.
        """
        return self.LANGUAGE_PATTERNS.get(file_ext, lambda f: True)

    def _extract_function_name(self, full_name: str) -> str:
        """
        Extract the function name from its full name.

        :param full_name: Full function name (e.g., "Contract.Function").
        :return: Function name (e.g., "Function"); the input unchanged when it has no period.
        """
        if "." in full_name:
            return full_name.split(".")[-1]
        logger.warning(f"Function name '{full_name}' does not contain a period. Returning as is.")
        return full_name

    def _build_extended_flow_code(
        self,
        contract_name: str,
        function_lists: List[str],
        contexts: Dict[str, Dict[str, Any]]
    ) -> str:
        """
        Build the extended flow code by aggregating context content.

        :param contract_name: Name of the contract.
        :param function_lists: List of function names involved in the business flow.
        :param contexts: Contexts dictionary keyed by "Contract.function", each holding
                         "sub_calls" and "parent_calls" lists.
        :return: Concatenated extended flow code.
        """
        extended_flow_parts = []
        for func in function_lists:
            key = f"{contract_name}.{func}"
            context = contexts.get(key, {})
            combined_calls = context.get("sub_calls", []) + context.get("parent_calls", [])
            if not combined_calls:
                logger.debug(f"No sub_calls or parent_calls found for key '{key}'.")
                continue

            context_content = "\n".join(call.get("content", "") for call in combined_calls if call.get("content"))
            if context_content:
                extended_flow_parts.append(context_content)
            else:
                logger.debug(f"No content found in sub_calls or parent_calls for key '{key}'.")

        return "\n".join(extended_flow_parts).strip()

    def ask_openai_for_business_flow(self, function_name: str, contract_code: str) -> Dict[str, Any]:
        """
        Interface with OpenAI to retrieve business flow for a given function.

        :param function_name: Name of the function.
        :param contract_code: Source code of the contract; appended to the prompt.
        :return: Parsed JSON response from OpenAI, or {} when the client library
                 is unavailable, the call fails, or the reply is not valid JSON.
        """
        # Imported lazily: the module never imported `openai`, so the call below
        # (and the old `except openai.error...` clause) raised NameError.
        try:
            import openai
        except ImportError as e:
            logger.error(f"OpenAI client library is not available for '{function_name}': {e}")
            return {}

        # NOTE(review): Completion / openai.error look like the pre-1.0 client API —
        # confirm against the installed version. Resolved defensively so a client
        # without them still ends up in the generic handler below.
        api_error = getattr(getattr(openai, "error", None), "OpenAIError", None)

        # The prompt refers to "the following contract code", so the code must
        # actually be part of it (it used to be left out).
        prompt = (
            f"Analyze the business flow for the function '{function_name}' in the following contract code. "
            f"Identify all functions that are called by '{function_name}' and the sequence of these calls. "
            "Provide the output in the following JSON format:\n\n"
            "{\n"
            '    "BusinessFlow": {\n'
            f'        "flow": ["{function_name}", "FunctionA", "FunctionB", "..."]\n'
            "    }\n"
            "}\n\n"
            f"Contract code:\n{contract_code}"
        )

        try:
            logger.info(f"Asking OpenAI for business flow of function '{function_name}'.")
            response = openai.Completion.create(
                engine="text-davinci-003",  # Replace with the desired engine
                prompt=prompt,
                max_tokens=500,
                n=1,
                stop=None,
                temperature=0.5
            )
            response_text = response.choices[0].text.strip()
            business_flow = json.loads(response_text)
            logger.debug(f"Received business flow from OpenAI for function '{function_name}': {business_flow}")
            return business_flow
        except json.JSONDecodeError as e:
            logger.error(f"JSON decoding error for function '{function_name}': {e}")
            return {}
        except Exception as e:
            if api_error is not None and isinstance(e, api_error):
                logger.error(f"OpenAI API error while fetching business flow for '{function_name}': {e}")
            else:
                logger.error(f"Unexpected error while fetching business flow for '{function_name}': {e}")
            return {}

    def extract_and_concatenate_functions_content(
        self,
        function_names: List[str],
        contract_info: Dict[str, Any]
    ) -> str:
        """
        Extracts the content of functions based on a given function list and contract info,
        and concatenates them into a single string.

        :param function_names: List of function names to extract.
        :param contract_info: Information about the contract containing the functions;
                              must provide 'file_path' and 'contract_name'.
        :return: Concatenated source code of the specified functions ("" if none is found).
        """
        file_path = contract_info.get('file_path')
        contract_name = contract_info.get('contract_name')
        if not file_path or not contract_name:
            logger.error("Contract information missing 'file_path' or 'contract_name'.")
            return ""

        concatenated_code_parts = []

        for func_name in function_names:
            if not func_name:
                logger.warning(f"Encountered empty function name in contract '{contract_name}'. Skipping.")
                continue

            func_detail = self.call_graph.get_function_detail(file=file_path, contract=contract_name, function=func_name)
            if not func_detail:
                logger.warning(f"Function '{func_name}' not found in contract '{contract_name}'.")
                continue

            func_src = self.call_graph.get_function_src(file=file_path, func=func_detail)
            if func_src:
                concatenated_code_parts.append(func_src)
            else:
                logger.warning(f"Source code for function '{func_name}' in contract '{contract_name}' is empty.")

        return "\n".join(concatenated_code_parts).strip()

    def merge_and_sort_rulesets(
        self,
        high: List[Dict[str, Any]],
        medium: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Merge two rulesets based on sim_score and sort them in descending order.

        Rules without a 'sim_score' are ranked as if their score were 0.

        :param high: List of high-priority rules.
        :param medium: List of medium-priority rules.
        :return: Sorted combined ruleset.
        """
        sorted_ruleset = sorted(high + medium, key=lambda x: x.get('sim_score', 0), reverse=True)
        logger.debug(f"Merged and sorted ruleset with {len(sorted_ruleset)} rules.")
        return sorted_ruleset

    def decode_business_flow_list_from_response(self, response: str) -> List[str]:
        """
        Extracts unique function names from a JSON response.

        Dotted names are reduced to their last segment; non-string entries are
        logged and ignored. Any parsing failure yields an empty list.

        :param response: JSON string containing business flow information.
        :return: A sorted list of unique function names.
        """
        unique_functions = set()
        try:
            json_obj = json.loads(response)
            business_flow = json_obj.get("BusinessFlow", {}).get("flow", [])
            for func in business_flow:
                if isinstance(func, str):
                    unique_functions.add(func.rsplit(".", 1)[-1])
                else:
                    logger.warning(f"Unexpected function format in business flow: {func}")
        except json.JSONDecodeError as e:
            logger.error(f"JSON decoding error: {e}")
        except Exception as e:
            logger.error(f"Unexpected error during business flow decoding: {e}")
        return sorted(unique_functions)

    def search_business_flow(
        self,
        all_business_flow: Dict[str, Dict[str, Any]],
        all_business_flow_line: Dict[str, Dict[str, Any]],
        all_business_flow_context: Dict[str, Dict[str, Any]],
        function_name: str,
        contract_name: str
    ) -> 'BusinessFlowResult':
        """
        Search for the business flow code based on a function name and contract name.

        :param all_business_flow: The dictionary containing all business flows.
        :param all_business_flow_line: The dictionary containing business flow lines.
        :param all_business_flow_context: The dictionary containing business flow contexts.
        :param function_name: The name of the function to search for.
        :param contract_name: The name of the contract where the function is located.
        :return: BusinessFlowResult containing (business_flow_code, business_flow_line, business_flow_context)
                 if found, otherwise (None, [], None).
        """
        contract_flows = all_business_flow.get(contract_name)
        if not contract_flows:
            logger.warning(f"Contract '{contract_name}' not found in all_business_flow.")
            return BusinessFlowResult(None, [], None)

        business_flow_code = contract_flows.get(function_name)
        if business_flow_code is None:
            logger.warning(f"Function '{function_name}' not found in contract '{contract_name}'.")
            return BusinessFlowResult(None, [], None)

        business_flow_line = all_business_flow_line.get(contract_name, {}).get(function_name, [])
        business_flow_context = all_business_flow_context.get(contract_name, {}).get(function_name, "")

        return BusinessFlowResult(business_flow_code, business_flow_line, business_flow_context)


from typing import NamedTuple

class BusinessFlowResult(NamedTuple):
    # Concatenated source of the flow's functions; None when the lookup failed.
    business_flow_code: Optional[str]
    # (start_line, end_line) pairs, one per function in the flow.
    business_flow_line: List[Tuple[int, int]]
    # Source of related callers/callees; None when the lookup failed.
    business_flow_context: Optional[str]
# A one-element tuple needs the trailing comma; a bare string here makes
# `from <module> import *` look up each character as a name. The exported
# name must also be the class this module actually defines.
__all__ = ('AProjectAudit',)


class AProjectAudit(object):
    """
    Asynchronous project audit state: parsed functions plus their call trees.

    ``parse`` fills ``functions``, ``functions_to_check`` and ``call_trees``;
    the other methods are the building blocks it uses.
    """

    def __init__(self, config: "nodes_config") -> None:
        """
        :param config: Project configuration; ``id``, ``base_dir`` and
                       ``src_dir`` are read from it.
        """
        self.config: nodes_config = config
        self.project_id: str = config.id
        self.project_path: str = config.base_dir
        self.cg = CallGraph(root=path.join(config.base_dir, config.src_dir))

        self.functions_to_check: list = []
        self.functions: list = []
        self.tasks: list = []
        self.taskkeys: set = set()

    async def analyze_function_relationships(self, functions_to_check: List[Dict]) -> Tuple[Dict[str, Dict[str, Set]], Dict[str, Dict]]:
        """
        Infer caller/callee relationships by searching function bodies for names.

        A function A is treated as calling B when B's short name occurs as a
        whole word (case-insensitively) in A's content. Each record in
        ``functions_to_check`` additionally gets a 'func_name' key holding its
        short name.

        :param functions_to_check: Records with at least 'name' and 'content'.
        :return: ``(relationships, func_map)`` where ``relationships`` has the
                 keys 'upstream' (callers) and 'downstream' (callees), each
                 mapping a short name to a set of short names, and ``func_map``
                 maps a short name to ``{'index': position, 'data': record}``.
        """
        # Construct a mapping and calling relationship dictionary from function name to function information
        func_map: Dict[str, Dict] = {}
        relationships: Dict[str, Dict[str, Set]] = {'upstream': {}, 'downstream': {}}
        upstream = relationships['upstream']
        downstream = relationships['downstream']

        short_names = []
        for idx, func in enumerate(functions_to_check):
            func_name = func['name'].split('.')[-1]
            func['func_name'] = func_name
            func_map[func_name] = {'index': idx, 'data': func}
            short_names.append(func_name)

        # Lower-cased bodies and whole-word patterns are prepared once instead
        # of once per pair of functions.
        contents = [func['content'].lower() for func in functions_to_check]
        patterns = [re.compile(r'\b' + re.escape(name.lower()) + r'\b') for name in short_names]

        # Analyze the calling relationship of each function
        for i, func in enumerate(functions_to_check):
            func_name = short_names[i]
            upstream.setdefault(func_name, set())
            downstream.setdefault(func_name, set())

            for j, other_func in enumerate(functions_to_check):
                if other_func == func:
                    continue
                other_name = short_names[j]

                # If other functions call the current function
                if patterns[i].search(contents[j]):
                    upstream[func_name].add(other_name)
                    downstream.setdefault(other_name, set()).add(func_name)

                # If the current function calls other functions
                if patterns[j].search(contents[i]):
                    downstream[func_name].add(other_name)
                    upstream.setdefault(other_name, set()).add(func_name)

        return relationships, func_map

    async def build_call_tree(self, func_name: str, relationships: Dict[str, Dict[str, Set]], direction: str, func_map: Dict[str, Dict], visited: Optional[Set[str]] = None) -> Optional[Dict[str, Any]]:
        """
        Recursively build the call tree rooted at ``func_name``.

        :param func_name: Short name of the root function.
        :param relationships: Result of ``analyze_function_relationships``.
        :param direction: 'upstream' or 'downstream'.
        :param func_map: Short name -> ``{'index', 'data'}`` mapping.
        :param visited: Names already on the current path; used to stop cycles.
        :return: A node dict with 'name', 'index', 'function_data' and
                 'children', or None when ``func_name`` was already visited.
        """
        if visited is None:
            visited = set()

        if func_name in visited:
            return None

        visited.add(func_name)

        # Look up the function's full record; unknown names get a placeholder.
        func_info = func_map.get(func_name, {'index': -1, 'data': None})
        node = {
            'name': func_name,
            'index': func_info['index'],
            'function_data': func_info['data'],  # the complete function record
            'children': []
        }

        # All direct calls in the requested direction.
        related_funcs = relationships[direction].get(func_name, set())

        # Recurse into each related function; every branch gets its own copy of
        # the visited set, so only cycles along one path are cut.
        for related_func in related_funcs:
            child_tree: Optional[Dict[str, Any]] = await self.build_call_tree(related_func, relationships, direction, func_map, visited.copy())
            if child_tree:
                node['children'].append(child_tree)

        return node

    def print_call_tree(self, node: Dict[str, Any], level: int = 0, prefix: str = ''):
        """
        Print a call tree produced by ``build_call_tree`` to stdout.

        :param node: Tree node (falsy nodes are ignored).
        :param level: Depth of ``node``; 0 for the root.
        :param prefix: Text printed before the node's own marker.
        """
        if not node:
            return

        # Print the current node's basic information.
        marker = '└─' if level > 0 else ''
        func_data = node['function_data']
        if func_data:
            print(f"{prefix}{marker}{node['name']} (index: {node['index']}, "
                  f"lines: {func_data['start_line']}-{func_data['end_line']})")
        else:
            print(f"{prefix}{marker}{node['name']} (index: {node['index']})")

        # Print the children.
        last_index = len(node['children']) - 1
        for i, child in enumerate(node['children']):
            is_last = i == last_index
            new_prefix = prefix + (' ' if level == 0 else '│ ' if not is_last else ' ')
            self.print_call_tree(child, level + 1, new_prefix + ('└─' if is_last else '├─'))

    async def parse(self, white_files: List[str], white_functions: List[str]) -> None:
        """
        Parse the project and compute call trees for every function to check.

        Sets ``self.functions``, ``self.functions_to_check`` and
        ``self.call_trees`` (one entry per function to check, holding its
        upstream tree, downstream tree and state-variable text).

        :param white_files: File allow-list passed to the parser filter.
        :param white_functions: Function allow-list passed to the parser filter.
        """
        parser_filter = ABaseProjectFilter(white_files, white_functions)
        functions, functions_to_check = await parse_project_async(self.project_path, parser_filter)
        self.functions = functions
        self.functions_to_check = functions_to_check

        relationships: Dict[str, Dict]
        func_map: Dict

        # Analyze function relationships.
        relationships, func_map = await self.analyze_function_relationships(functions_to_check)

        # Build the call trees for every function.
        call_trees: List[Dict] = []
        for func in functions_to_check:
            func_name = func['name'].split('.')[-1]

            upstream_tree = await self.build_call_tree(func_name, relationships, 'upstream', func_map)
            downstream_tree = await self.build_call_tree(func_name, relationships, 'downstream', func_map)

            # State variables are only extracted for Move, Solidity and .fr sources.
            state_variables: List[str] = []
            file_path = func['relative_file_path']
            if file_path.endswith('.move'):
                state_variables = extract_state_variables_from_code_move(func['contract_code'], file_path)
            if file_path.endswith(('.sol', '.fr')):
                state_variables = extract_state_variables_from_code(func['contract_code'])

            state_variables_text = '\n'.join(state_variables) if state_variables else ''
            call_trees.append({
                'function': func_name,
                'upstream_tree': upstream_tree,
                'downstream_tree': downstream_tree,
                'state_variables': state_variables_text
            })

        self.call_trees: List[Dict] = call_trees

    def get_function_names(self) -> Set[str]:
        """Return the full names ('Contract.function') of all parsed functions."""
        return {function['name'] for function in self.functions}
class ABaseProjectFilter(object):
    """
    Decide which files, contracts and functions are excluded from an audit.

    Every ``filter_*`` coroutine returns True when the item must be SKIPPED
    and False when it should be kept.
    """

    def __init__(self, white_files=None, white_functions=None):
        """
        :param white_files: Optional allow-list of files; empty or None means
                            "no restriction".
        :param white_functions: Optional allow-list of full function names;
                                empty or None means "no restriction".
        """
        # None instead of [] as the default: a list default is created once and
        # would be shared (and mutated) across every instance.
        self.white_files = white_files if white_files is not None else []
        self.white_functions = white_functions if white_functions is not None else []

    async def filter_file_async(self, path: str, filename: str) -> bool:
        """
        Return True when ``filename`` should not be parsed.

        A file is skipped when its suffix is not a supported source suffix,
        when it ends with '.t.sol', or when an allow-list is set and no
        allow-list entry contains the file's base name.
        """
        # Check the file suffix.
        valid_extensions = ('.sol', '.rs', '.py', '.move', '.cairo', '.tact', '.fc', '.fr', '.java')
        if not filename.endswith(valid_extensions) or filename.endswith('.t.sol'):
            return True

        # With a non-empty allow-list, the file must appear in one of its entries.
        if self.white_files:
            base_name = os.path.basename(filename)
            return not any(base_name in white_file for white_file in self.white_files)

        return False

    async def filter_contract_async(self, function: Dict[str, Any]) -> bool:
        """
        Return True when the function should be skipped because of its contract.

        Names carrying a non-Solidity language marker are never skipped.
        Otherwise a function is skipped when its contract name looks like an
        interface ("I" followed by an upper-case letter), when its name
        contains "test", or when its content contains an initializer,
        constructor, receive or fallback definition.
        """
        # No filtering for the non-Solidity languages.
        non_solidity_markers = ('_rust', '_python', '_move', '_cairo', '_tact', '_func', '_fa')
        if any(marker in function["name"] for marker in non_solidity_markers):
            return False

        # Solidity: apply the filters below.
        contract_name = str(function["contract_name"])
        # Slicing keeps a one-character contract name such as "I" from raising IndexError.
        if contract_name.startswith("I") and contract_name[1:2].isupper():
            print("function ", function['name'], " skipped for interface contract")
            return True
        if "test" in str(function["name"]).lower():
            print("function ", function['name'], " skipped for test function")
            return True

        content = str(function["content"]).lower()
        setup_markers = ("function init", "function initialize", "constructor(", "receive()", "fallback()")
        if any(marker in content for marker in setup_markers):
            print("function ", function['name'], " skipped for constructor")
            return True

        return False

    async def filter_functions_async(self, function: Dict[str, Any]) -> bool:
        """
        Return True when a function allow-list is set and the function's full
        name is not in it.
        """
        # Step 3: function filtering (allow-list check).
        if not self.white_functions:
            return False
        return function['name'] not in self.white_functions
extract_state_variables_from_code_move - +__all__ = ('ProjectAudit') class ProjectAudit(object): - def analyze_function_relationships(self, functions_to_check): + def analyze_function_relationships(self, functions_to_check: List[Dict]) -> Tuple[Dict[str, Dict[str, Set]], Dict[str, Dict]]: + # Construct a mapping and calling relationship dictionary from function name to function information # 构建函数名到函数信息的映射和调用关系字典 func_map = {} relationships = {'upstream': {}, 'downstream': {}} - + for idx, func in enumerate(functions_to_check): func_name = func['name'].split('.')[-1] + func['func_name'] = func_name func_map[func_name] = { 'index': idx, 'data': func } - + + # Analyze the calling relationship of each function # 分析每个函数的调用关系 - for func in functions_to_check: + for idx,func in enumerate(functions_to_check): func_name = func['name'].split('.')[-1] content = func['content'].lower() - + if func_name not in relationships['upstream']: relationships['upstream'][func_name] = set() if func_name not in relationships['downstream']: relationships['downstream'][func_name] = set() - + + # Check whether other functions call the current function # 检查其他函数是否调用了当前函数 for other_func in functions_to_check: if other_func == func: continue - + other_name = other_func['name'].split('.')[-1] other_content = other_func['content'].lower() - - # 如果其他函数调用了当前函数 + + # If other functions call the current function if re.search(r'\b' + re.escape(func_name.lower()) + r'\b', other_content): relationships['upstream'][func_name].add(other_name) + if other_name not in relationships['downstream']: relationships['downstream'][other_name] = set() relationships['downstream'][other_name].add(func_name) - - # 如果当前函数调用了其他函数 + + # If the current function calls other functions if re.search(r'\b' + re.escape(other_name.lower()) + r'\b', content): relationships['downstream'][func_name].add(other_name) + if other_name not in relationships['upstream']: relationships['upstream'][other_name] = set() 
relationships['upstream'][other_name].add(func_name) - - return relationships, func_map + return relationships, func_map + def build_call_tree(self, func_name, relationships, direction, func_map, visited=None): if visited is None: visited = set() @@ -74,7 +84,7 @@ def build_call_tree(self, func_name, relationships, direction, func_map, visited # 递归构建每个相关函数的调用树 for related_func in related_funcs: - child_tree = self.build_call_tree(related_func, relationships, direction, func_map, visited.copy()) + child_tree: None | dict[str, Any] = self.build_call_tree(related_func, relationships, direction, func_map, visited.copy()) if child_tree: node['children'].append(child_tree) @@ -98,26 +108,32 @@ def print_call_tree(self, node, level=0, prefix=''): new_prefix = prefix + (' ' if level == 0 else '│ ' if not is_last else ' ') self.print_call_tree(child, level + 1, new_prefix + ('└─' if is_last else '├─')) - def __init__(self, project_id, project_path, db_engine): - self.project_id = project_id - self.project_path = project_path - self.functions = [] - self.functions_to_check = [] - self.tasks = [] + #project_id, project_path, db_engine + def __init__(self, config:nodes_config) -> None: + self.config:nodes_config = config + self.project_id:str = config.id + self.project_path:str = config.base_dir + self.cg = CallGraph(root=path.join(config.base_dir, config.src_dir)) + + self.functions_to_check:list = [] + self.functions:list = [] + self.tasks:list = [] self.taskkeys = set() + - def parse(self, white_files, white_functions): + def parse(self, white_files, white_functions) -> None: parser_filter = BaseProjectFilter(white_files, white_functions) functions, functions_to_check = parse_project(self.project_path, parser_filter) self.functions = functions self.functions_to_check = functions_to_check - + relationships:Dict[str,Dict] + func_map:Dict # 分析函数关系 - relationships, func_map = self.analyze_function_relationships(functions_to_check) + relationships,func_map = 
self.analyze_function_relationships(functions_to_check) # 为每个函数构建并打印调用树 - call_trees = [] + call_trees:list = [] for func in functions_to_check: func_name = func['name'].split('.')[-1] # print(f"\nAnalyzing function: {func_name}") @@ -140,7 +156,7 @@ def parse(self, white_files, white_functions): 'state_variables': state_variables_text }) - self.call_trees = call_trees + self.call_trees:list = call_trees def get_function_names(self): return set([function['name'] for function in self.functions]) \ No newline at end of file diff --git a/src/project/project_parser.py b/src/project/project_parser.py index 2225efdc..31c27a3f 100644 --- a/src/project/project_parser.py +++ b/src/project/project_parser.py @@ -1,12 +1,13 @@ -from library.sgp.sgp_parser import get_antlr_parsing from library.parsing.callgraph import CallGraph import os import re - from library.sgp.utilities.contract_extractor import extract_state_variables_from_code from .project_settings import FILE_PARTIAL_WHITE_LIST, PATH_PARTIAL_WHITE_LIST, PATH_WHITE_LIST, OPENZEPPELIN_CONTRACTS,OPENZEPPELIN_FUNCTIONS +from library.sgp.sgp_parser import get_antlr_parsing +from dataclasses import * +from box import Box -class Function(dict): +class Function(Box): def __init__(self, file, contract, func): self.file = file self.contract = contract @@ -156,7 +157,6 @@ def filter_functions(self, function): def parse_project(project_path, project_filter = None): - if project_filter is None: project_filter = BaseProjectFilter([], []) @@ -184,9 +184,19 @@ def parse_project(project_path, project_filter = None): # fix func name fs = [] for func in functions: - name = func['name'][8:] # remove special_前缀,具体为啥我也忘了,似乎是为了考虑特定的function name + if func['name'][8:] != "tor": + name = func['name'][8:] # remove SPECIAL_ Prefix,I forgot the specific reason, it seems to be to consider a specificfunction name + else: + name = "constructor" func['name'] = "%s.%s" % (func['contract_name'], name) fs.append(func) + + + + # for func in functions: + 
# name = func['name'][8:] # remove special_前缀,具体为啥我也忘了,似乎是为了考虑特定的function name + # func['name'] = "%s.%s" % (func['contract_name'], name) + # fs.append(func) fs_filtered = fs[:] # 2. filter contract diff --git a/src/prompt_factory/__init__.py b/src/prompt_factory/__init__.py new file mode 100644 index 00000000..d365068f --- /dev/null +++ b/src/prompt_factory/__init__.py @@ -0,0 +1 @@ +print(f"running for {__file__}") \ No newline at end of file diff --git a/src/prompt_factory/prompt_assembler.py b/src/prompt_factory/prompt_assembler.py index bcb6d410..4f308e19 100644 --- a/src/prompt_factory/prompt_assembler.py +++ b/src/prompt_factory/prompt_assembler.py @@ -13,6 +13,16 @@ def assemble_prompt_common(code): +PeripheryPrompt.jailbreak_prompt() + return ret_prompt + def assemble_prompt_pure(code): + ret_prompt=code+"\n"\ + +PeripheryPrompt.role_set_solidity_common()+"\n"\ + +PeripheryPrompt.task_set_blockchain_common()+"\n"\ + +CorePrompt.core_prompt_pure()+"\n"\ + +PeripheryPrompt.guidelines()+"\n"\ + +PeripheryPrompt.jailbreak_prompt() + + return ret_prompt def assemble_prompt_for_specific_project(code, business_type): vul_prompts = [] diff --git a/src/res_processor/ares_processor.py b/src/res_processor/ares_processor.py new file mode 100644 index 00000000..d11ac97d --- /dev/null +++ b/src/res_processor/ares_processor.py @@ -0,0 +1,148 @@ +import pandas as pd +from tqdm import tqdm +import json +from openai_api.openai import ask_claude, common_ask_for_json +import asyncio + +class AResProcessor: + def __init__(self, df): + self.df = df + + async def process(self): + self.df['flow_code_len'] = self.df['业务流程代码'].str.len() + grouped = list(self.df.groupby('业务流程代码')) + tasks = [] + + for flow_code, group in grouped: + task = asyncio.create_task(self._process_group(flow_code, group)) + tasks.append(task) + + processed_results = await asyncio.gather(*tasks) + new_df = pd.DataFrame(processed_results) + + if 'flow_code_len' in new_df.columns: + new_df = 
new_df.drop('flow_code_len', axis=1) + + original_columns = [col for col in self.df.columns if col != 'flow_code_len'] + new_df = new_df[original_columns] + + return new_df + + async def _process_group(self, flow_code, group): + if len(group) <= 1: + return await self._process_single_vulnerability(group.iloc[0]) + return await self._merge_vulnerabilities(group) + + async def _process_single_vulnerability(self, row): + translate_prompt = f"""请对以下漏洞描述翻译,用中文输出 +原漏洞描述: +{row['漏洞结果']} +""" + translated_description = await ask_claude(translate_prompt) + return { + '漏洞结果': translated_description, + 'ID': row['ID'], + '项目名称': row['项目名称'], + '合同编号': row['合同编号'], + 'UUID': row['UUID'], + '函数名称': row['函数名称'], + '函数代码': row['函数代码'], + '开始行': row['开始行'], + '结束行': row['结束行'], + '相对路径': row['相对路径'], + '绝对路径': row['绝对路径'], + '业务流程代码': row['业务流程代码'], + '业务流程行': row['业务流程行'], + '业务流程上下文': row['业务流程上下文'], + '确认结果': row['确认结果'], + '确认细节': row['确认细节'] + } + + async def _merge_vulnerabilities(self, group): + base_info = group.iloc[0].copy() + merge_prompt = """ +合并一下这几个漏洞,将相同或本质相似漏洞合并到一起,用中文输出,如果存在多个不同漏洞,则分开描述, +但要保证如下:1. 合并的结果相比于原来的漏洞结果不能有任何信息或漏洞缺失,且合并的相同或本质相似的漏洞描述要全面不能有遗漏; +2. 每个合并后的漏洞描述必须非常详细,不能少于600个字用原来的原文翻译来描述漏洞,不能有任何表述遗漏,否则你会受到惩罚, +输出格式如下,必须严格遵循格式输出 +{ +"merged_vulnerabilities": [ +{"description": "合并后的漏洞1描述"}, +{"description": "合并后的漏洞2描述"}, +... 
+] +} +""" + for _, row in group.iterrows(): + merge_prompt += f"漏洞结果:{row['漏洞结果']}\n" + merge_prompt += "---\n" + + merged_result = await common_ask_for_json(merge_prompt) + try: + merged_data = json.loads(merged_result) + vulnerabilities = merged_data.get('merged_vulnerabilities', []) + if not vulnerabilities: + return { + '漏洞结果': merged_result, + 'ID': base_info['ID'], + '项目名称': base_info['项目名称'], + '合同编号': base_info['合同编号'], + 'UUID': base_info['UUID'], + '函数名称': base_info['函数名称'], + '函数代码': row['函数代码'], + '开始行': min(group['开始行']), + '结束行': max(group['结束行']), + '相对路径': base_info['相对路径'], + '绝对路径': base_info['绝对路径'], + '业务流程代码': base_info['业务流程代码'], + '业务流程行': base_info['业务流程行'], + '业务流程上下文': base_info['业务流程上下文'], + '确认结果': base_info['确认结果'], + '确认细节': base_info['确认细节'] + } + results = [] + for vuln in vulnerabilities: + results.append({ + '漏洞结果': vuln['description'], + 'ID': base_info['ID'], + '项目名称': base_info['项目名称'], + '合同编号': base_info['合同编号'], + 'UUID': base_info['UUID'], + '函数名称': base_info['函数名称'], + '函数代码': row['函数代码'], + '开始行': min(group['开始行']), + '结束行': max(group['结束行']), + '相对路径': base_info['相对路径'], + '绝对路径': base_info['绝对路径'], + '业务流程代码': base_info['业务流程代码'], + '业务流程行': base_info['业务流程行'], + '业务流程上下文': base_info['业务流程上下文'], + '确认结果': base_info['确认结果'], + '确认细节': base_info['确认细节'] + }) + return results + except json.JSONDecodeError: + # Fallback in case of JSON parsing error + return { + '漏洞结果': merged_result, + 'ID': base_info['ID'], + '项目名称': base_info['项目名称'], + '合同编号': base_info['合同编号'], + 'UUID': base_info['UUID'], + '函数名称': base_info['函数名称'], + '函数代码': row['函数代码'], + '开始行': min(group['开始行']), + '结束行': max(group['结束行']), + '相对路径': base_info['相对路径'], + '绝对路径': base_info['绝对路径'], + '业务流程代码': base_info['业务流程代码'], + '业务流程行': base_info['业务流程行'], + '业务流程上下文': base_info['业务流程上下文'], + '确认结果': base_info['确认结果'], + '确认细节': base_info['确认细节'] + } + + def _clean_text(self, text): + if pd.isna(text): + return '' + return str(text).strip() \ No newline at end 
of file diff --git a/src/root_service.py b/src/root_service.py new file mode 100644 index 00000000..08ffe858 --- /dev/null +++ b/src/root_service.py @@ -0,0 +1,22 @@ +import os +import subprocess + + + +def main(): + # Set the PYTHONPATH environment variable + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + python_path = os.environ.get('PYTHONPATH', '') + print(python_path) + if project_root not in python_path.split(os.pathsep): + new_python_path = f"{project_root}:{python_path}" + os.environ['PYTHONPATH'] = new_python_path + print(f"Updated PYTHONPATH: {new_python_path}") + else: + print("PYTHONPATH is already set correctly.") + print(new_python_path) + from run import run + subprocess.run("run.py") + +if __name__ == "__main__": + main() diff --git a/src/run.py b/src/run.py new file mode 100644 index 00000000..3db4c873 --- /dev/null +++ b/src/run.py @@ -0,0 +1,209 @@ +import asyncio +from functools import partial +from getpass import getpass +import os +from contextlib import redirect_stdout, redirect_stderr +from typing import Optional +from httpx import AsyncClient +from openai import AsyncOpenAI +from rich.console import Console +from rich.live import Live +from pydantic_ai import Agent, RunContext +from nodes_config import nodes_config, Settings +from main_pipeline import DynamicGovernor, governed_task, main_pipeline + +from tracing import logger as logging # Import the logger from the logView module +from agents.md_output import LogToMarkdown, logView +from dataclasses import dataclass +from sqlalchemy.ext.asyncio.engine import create_async_engine, AsyncEngine +from dao.atask_mgr import AProjectTaskMgr +from codebaseQA.arag_processor import ARAGProcessor +from res_processor.ares_processor import AResProcessor +from project.aproject_audit import AProjectAudit +from planning.aplanning_v2 import APlanningV2 +from ai_engine import AiEngine + +import asyncio + + +logger = logging.opt(colors=True) + +# Global configurations +cfg = 
nodes_config() + +console = Console() + + + +class Project: + def __init__(self, config: nodes_config) -> None: + self.id = config.id + self.output = config.output + self.path = config.base_dir + self.project = self.start_project(dataset_path=config.base_dir, code_path=config.src_dir) + self.functions: list[str] = [] + self.white_files = self.project["files"] + self.white_functions = self.project.get("functions", []) + self.files = self.project["files"] + + def __repr__(self) -> str: + return f"Project(id={self.id}, path={self.path})" + + def start_project(self, dataset_path:str, code_path:str): + self.AllFiles = False + self.AllMySource = True + srcs: str = f"{dataset_path}/{code_path}" + return { + "path": code_path, + "files": [ + os.path.relpath(os.path.join(root, file), dataset_path) + for root, _, files in os.walk(srcs) + for file in files + if file.endswith(".sol") + ], + "functions": [], + "base_path": dataset_path, + "AllFiles": False, + "AllMySource": True, + } + + +@dataclass +class Context: + def __init__(self, config: nodes_config, db_engine: AsyncEngine) -> None: + self.config: nodes_config = config + self.project: Project = Project(config) + self.id: str = config.id + self.path: str = self.project.path + self.db_engine: AsyncEngine = db_engine + self.all_files: bool = self.project.AllFiles + self.statefile: str = f"{config.id}-funcs.json" + self.output: str = config.output + self.tags: list[str] = [""] + logger.log(31, "Context: CallGraph") + # self.call_graph = CallGraph(root=self.path) + logger.log(31, "Context: ProjectAudit") + self.project_audit = AProjectAudit(config) + self.aproject_audit = AProjectAudit(config) + + logger.log(31, "Contex: RAGProcesssor") + self.rag_processor = ARAGProcessor(config.id, audit=self.project_audit) + + self.project_taskmgr = AProjectTaskMgr(self.id, self.db_engine) + self.planning = APlanningV2( + self.project_audit, + self.project_taskmgr + ) + self.ai_engine: AiEngine = AiEngine(self.planning, 
self.project_taskmgr, self.rag_processor.db, "lancedb_" + config.id, self.project_audit) + + # classes with complicated startups need to be in a secondary init routine since init is sync + # and these have nothing todo with the pipeline + async def startup(self): + logger.log(31, "Contex: RAGProcesssor") + self.rag_processor = ARAGProcessor(self.config.id, audit=self.project_audit) + if await self.rag_processor.table_exists() and await self.rag_processor.acheck_data_count(len(self.project_audit.functions_to_check)): + print(f"Table {self.rag_processor.table_name} already exists with correct data count. Skipping processing.") + else: + self.rag_processor._create_database(self.project_audit.functions_to_check) + + self.project_taskmgr = AProjectTaskMgr(self.id, self.db_engine) + self.planning = APlanningV2( + self.project_audit, + self.project_taskmgr + ) + self.ai_engine: AiEngine = AiEngine(self.planning, self.project_taskmgr, self.rag_processor.db, "lancedb_" + self.config.id, self.project_audit) + + +hx = AsyncClient(base_url="http://127.0.0.1:11434/api") +client = AsyncOpenAI(http_client=hx) + +async def integrated_pipeline(): + await logView.logdata("test startup logging") + + config = nodes_config() + engine: AsyncEngine = create_async_engine(config.settings.ASYNC_DB_URL, echo=True) + context = Context(config, engine) + + with Live("", console=console, vertical_overflow="visible") as live: + await asyncio.to_thread(context.project_audit.parse, context.project.white_files, context.project.white_functions) + async for stage_callable in context.planning.do_planning(): + result = await stage_callable() + print("Orchestrator received planning stage result:", result) + + while True: + await logView.logdata("The pipeline has completed execution.") + await asyncio.sleep(-1) + + +async def integrated_pipeline(): + await logView.logdata("test startup logging") + + config = nodes_config() + engine: AsyncEngine = create_async_engine(config.ASYNC_DB_URL, echo=True) + 
context = Context(config, engine) + + with Live("", console=console, vertical_overflow="visible") as live: + await asyncio.to_thread(context.project_audit.parse, context.project.white_files, context.project.white_functions) + async for stage_callable in context.planning.do_planning(): + result = await stage_callable() + print("Orchestrator received planning stage result:", result) + + while True: + await logView.logdata("What time is it???") + await asyncio.sleep(1) + +async def gather_tasks(): + governor = DynamicGovernor(initial_limit=5, window_size=10) + ... + # Simulate a list of tasks with varying delays + # tasks = [governed_task(simulated_task, governor, i, delay) + # for i, delay in enumerate([0.5, 0.7, 1.2, 0.9, 0.6, 0.8, 1.5, 0.4, 0.3, 1.0, 0.5, 0.9])] + + # results = await asyncio.gather(*tasks) + # print("Results:", results) + + +async def async_input(prompt: str, /, *, hide: Optional[bool] = None) -> str: + """Asynchronous equivalent of `input()` function. + Arguments: + prompt (`str`): A prompt message to be displayed. + Keyword Arguments: + hide (`bool`, optional): If `True`, the input text will be hidden. + Returns: + `str`: The input text. 
+ """ + #with ThreadPoolExecutor(1) as executor: + #wrapped_input = partial(getpass if hide else input, prompt) + # with ThreadPoolExecutor(1) as executor: + wrapped_input = partial(getpass if hide else input, prompt) + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, wrapped_input) + +async def setup_live_console(logView): + await logView.logdata("Setting up live console...") + md_logger = LogToMarkdown('output.md') + with redirect_stdout(md_logger), redirect_stderr(md_logger): + await integrated_pipeline() + +async def main(): + await setup_live_console(logView=logView) + + + +# async def main(): +# # Set up the Markdown logger +# md_logger = LogToMarkdown('output.md') + +# with redirect_stdout(md_logger), redirect_stderr(md_logger): +# setup_logging() # Ensure logging is set up before running the pipeline +# await integrated_pipeline() + + +async def main(): + await setup_live_console(logView=logView) + # Set up the Markdown logger +# await integrated_pipeline() + +if __name__ == "__main__": + logView.setup_logging() + asyncio.run(main()) \ No newline at end of file diff --git a/src/run_context.py b/src/run_context.py new file mode 100644 index 00000000..8cb509d1 --- /dev/null +++ b/src/run_context.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown +from trace import logging +@dataclass +class RunContext: + console: Console + live: Live + + def update_markdown(self, text: str): + # Update the live display with new Markdown content + self.live.update(Markdown(text)) + logging.log("[RunContext] Updated Markdown output.") diff --git a/src/runui.py b/src/runui.py new file mode 100644 index 00000000..70d4f969 --- /dev/null +++ b/src/runui.py @@ -0,0 +1,513 @@ +import asyncio +import sys +import io +from typing import Any +from textual import app, work +from textual.widgets import Input, RichLog +from textual.widgets import ProgressBar +from 
textual.reactive import reactive +from textual.app import ComposeResult +from textual.containers import Vertical, VerticalScroll, Horizontal, Container +from textual.widgets import DataTable, Switch, Static +from models.schemas import LogCtxData +from src.agents.md_output import LogView, ingress_txt +from pydantic_ai import Agent + + +from textual.widgets import Button, Sparkline +from datetime import datetime +import shelve + + +class TranslationFilter: + """Translation engine integration""" + def __init__(self): + self.enabled = False + self.log_ctx = LogView() + + async def process_line(self, text: str) -> str | tuple[str, float]: + """Apply translation if enabled""" + if not self.enabled: + return text + + lang = await self.log_ctx.detect_language(text) + ctx = LogCtxData(txtENG=text, txtCN=text) + + try: + trans = await ingress_txt.run( + f"{lang}", + deps=ctx.dict() + ) + return f"{text} → {trans.data}" + except Exception as e: + return f"{text} [Translation Error: {str(e)}]" + +class TranslationUI(Container): + """Bilingual output interface""" + def compose(self) -> ComposeResult: + # yield Horizontal( + en_table: DataTable = DataTable() + zh_table: DataTable = DataTable() + + yield Horizontal( + Vertical(en_table), + Vertical(zh_table) + ) + + # Add rows to the tables + style_class = "quality-poor" + source = "Hello, World!" + translation = "你好,世界!" 
+ + en_table.add_row(f"[{style_class}]{source}") + zh_table.add_row(f"[{style_class}]{translation}") + DataTable(id="output-english", zebra_stripes=True), + DataTable(id="output-chinese", zebra_stripes=True), + classes="translation-columns" + # ) + + def add_line(self, en: str, zh: str): + self.query_one("#output-english", DataTable).add_row(en) + self.query_one("#output-chinese", DataTable).add_row(zh) + +class BaseWindow(app.App): + """Main application window with composed layout""" + CSS = """ + Screen { + layout: grid; + grid-size: 1; + grid-rows: 1fr 8fr 1fr; + } + + #spinner-container { + height: 1; + background: $surface; + } + + #main-content { + height: 100%; + overflow: hidden; + } + + #ipython-container { + height: 1fr; + } + """ + + def compose(self) -> ComposeResult: + """Declarative layout composition""" + yield Vertical( + Horizontal( + ThinSpinner(show_eta=False, id="spinner"), + id="spinner-container" + ), + Vertical( + CustomRichUI(), + id="custom-ui" + ), + VerticalScroll( + RichLog(id="output-view"), + Input(placeholder=">>> "), + id="ipython-container" + ) + ) + +class EnhancedIPythonConsole(BaseWindow): + """IPython console with translation capabilities""" + CSS = """ + .translation-columns { + width: 100%; + height: 1fr; + grid-size: 2; + } + + #translation-toggle { + dock: top; + height: 1; + background: $surface; + } + """ + + def __init__(self): + super().__init__() + self.translation_filter = TranslationFilter() + self.translation_ui = TranslationUI() + self.output_queue = asyncio.Queue() + + def compose(self) -> ComposeResult: + yield Vertical( + Horizontal( + Switch(id="translation-toggle"), + Static("实时翻译", classes="toggle-label"), + id="translation-control" + ), + self.translation_ui + ) + yield from super().compose() + + async def process_output(self) -> None: + """Handle output with translation support""" + while not self.output_queue.empty(): + data = await self.output_queue.get() + + if self.translation_filter.enabled: + 
translated = await self.translation_filter.process_line(data) + en, zh = self.parse_translation(translated) + self.translation_ui.add_line(en, zh) + else: + self.query_one("#output-view", RichLog).write(data) + self.query_one("#output-view", RichLog).scroll_end(animate=False) + + def parse_translation(self, text: str) -> tuple[str, str]: + """Split translated text into EN/CN components""" + if "→" in text: + parts = text.split("→", 1) + return parts[0].strip(), parts[1].strip() + return text, text + + async def on_switch_changed(self, event: Switch.Changed) -> None: + """Handle translation toggle""" + self.translation_filter.enabled = event.value + self.translation_ui.display = event.value + self.refresh_layout() + + def refresh_layout(self): + """Adjust layout based on translation state""" + if self.translation_filter.enabled: + self.query_one("RichLog").display = False + self.translation_ui.display = True + else: + self.query_one("RichLog").display = True + self.translation_ui.display = False + +class EnhancedTranslationFilter(TranslationFilter): + """Translation engine with caching and quality metrics""" + def __init__(self): + super().__init__() + self.cache = shelve.open("translation_cache") + self.quality_metrics = {} + + async def process_line(self, text: str) -> tuple[str, float]: + """Return translated text with quality score (0-1)""" + cache_key = f"{hash(text)}-{self.target_lang}" + + # Check cache first + if cache_key in self.cache: + cached = self.cache[cache_key] + self.quality_metrics[cache_key] = cached['quality'] + return cached['text'], cached['quality'] + + # New translation + start_time = datetime.now() + try: + ctx = LogCtxData(txtENG=text, txtCN=text) + user_prompt = await self.log_ctx.detect_language(text) + trans = await ingress_txt.run(user_prompt, deps=ctx.dict()) + translation = trans.data + quality = self._calculate_quality(text, translation, start_time) + + # Cache results + self.cache[cache_key] = { + 'text': translation, + 'quality': 
quality, + 'timestamp': datetime.now().isoformat() + } + return translation, quality + except Exception as e: + return f"[Error: {str(e)}]", 0.0 + + def _calculate_quality(self, source: str, translation: str, start: datetime) -> float: + """Calculate translation quality score (mock implementation)""" + time_diff = (datetime.now() - start).total_seconds() + length_factor = min(len(source) / 100, 1.0) + return max(0.0, min(1.0 - (time_diff * 0.1), 0.9)) * length_factor + +class QualityIndicator(Sparkline): + """Visual quality indicator sparkline""" + def __init__(self): + super().__init__([], summary_function="max") + self.border_title = "Quality" + + def update_quality(self, score: float): + new_values = list(self.data[-9:]) + [score * 100] + self.data = new_values[-10:] + +class TranslationToggle(Button): + """Custom toggle button with status indicators""" + def __init__(self): + super().__init__("🌐 TRANSLATE", id="translation-toggle") + self.quality_indicator = QualityIndicator() + + def compose(self) -> ComposeResult: + yield self.quality_indicator + yield from super().compose() + +class EnhancedTranslationUI(Container): + """Enhanced UI with quality visualization""" + CSS = """ + #quality-header { + height: 3; + border-bottom: heavy $accent; + } + .quality-good { color: green; } + .quality-medium { color: yellow; } + .quality-poor { color: red; } + """ + + def compose(self) -> ComposeResult: + yield Horizontal( + Vertical( + Static("[b]English", id="en-header"), + DataTable(id="output-english"), + classes="column" + ), + Vertical( + Static("[b]中文", id="zh-header"), + DataTable(id="output-chinese"), + classes="column" + ), + Vertical( + Static("Quality Metrics", id="quality-header"), + QualityIndicator(), + classes="quality-panel" + ), + classes="translation-grid" + ) + + def add_translation(self, source: str, translation: str, quality: float): + # Add to tables with quality coloring + en_table = self.query_one("#output-english", DataTable) + zh_table = 
self.query_one("#output-chinese", DataTable) + + style_class = ( + "quality-good" if quality > 0.7 else + "quality-medium" if quality > 0.4 else + "quality-poor" + ) + en_table.add_row(f"[{style_class}]{source}") + zh_table.add_row(f"[{style_class}]{translation}") + # en_table.add_row(f"[{style_class}]{source}") + # zh_table.add_row(f"[{style_class}]{translation}") + + # Update quality sparkline + self.query_one(QualityIndicator).update_quality(quality) + +# class EnhancedIPythonConsole(BaseWindow): +# """Final integrated console with all features""" +# def compose(self) -> ComposeResult: +# yield Horizontal( +# TranslationToggle(), +# Static("|"), +# Button("Clear Cache", id="clear-cache"), +# id="control-bar" +# ) +# yield EnhancedTranslationUI() +# yield super().compose() + +# async def on_button_pressed(self, event: Button.Pressed): +# if event.button.id == "translation-toggle": +# self.translation_filter.enabled = not self.translation_filter.enabled +# event.button.label = "🌐 TRANSLATING" if self.translation_filter.enabled else "🌐 TRANSLATE" +# self.refresh_layout() +# elif event.button.id == "clear-cache": +# self.translation_filter.cache.clear() +# self.notify("Translation cache cleared!", severity="information") + +# async def process_output(self) -> None: +# while not self.output_queue.empty(): +# data = await self.output_queue.get() + +# if self.translation_filter.enabled: +# translated, quality = await self.translation_filter.process_line(data) +# self.translation_ui.add_translation(data, translated, quality) +# else: +# self.rich_log.write(data) + +# self.refresh_quality_display() + +# def refresh_quality_display(self): +# """Update quality indicators based on recent metrics""" +# if self.translation_filter.quality_metrics: +# avg_quality = sum(self.translation_filter.quality_metrics.values()) / len(self.translation_filter.quality_metrics) +# self.query_one(QualityIndicator).update_quality(avg_quality) + +class CustomRichUI(RichLog): + """Placeholder 
for custom rich UI widget""" + pass + +class ThinSpinner(ProgressBar): + """Animated thin progress bar spinner""" + _animation_progress = reactive(0.0) + + CSS = """ + ThinSpinner { + height: 1; + width: 100%; + background: $surface; + color: red; + border: none; + margin: 0; + } + + ThinSpinner > .progress--bar { + background: red; + min-height: 1; + width: 100%; + background: $surface; + margin: 0; + color: red; + border: solid yellow; + } + """ + + def on_mount(self) -> None: + # Start animation with slight delay to ensure CSS loads + self.call_later(lambda: self.animate("_animation_progress", 1.0, duration=1.5, on_complete=self.on_mount)) + self.set_interval(0.05, self.update) + + def updte_animation_progress(self, progress: float) -> None: + self.progress = int(progress * 100) + + # Rest of your mount logic +class IPythonIO(io.TextIOBase): + """Thread-safe I/O redirection for IPython""" + def __init__(self, queue: asyncio.Queue, main_loop: asyncio.AbstractEventLoop): + self.queue = queue + self.main_loop = main_loop + + def write(self, data: str) -> int: + asyncio.run_coroutine_threadsafe( + self.queue.put(data), + loop=self.main_loop + ) + return len(data) + + def flush(self) -> None: + pass + +# class IPythonConsole(BaseWindow): +# """IPython-integrated console inheriting from base window""" +# def on_mount(self) -> None: +# # Initialize IPython components +# self.main_loop = asyncio.get_running_loop() +# self.start_ipython() +# self.set_interval(0.05, self.process_output) + +# # Keep previous IPython integration methods +# @work(thread=True) +# # def start_ipython(self) -> None: + # # Same thread setup as before + # ... + + # def process_output(self) -> None: + # # Same output handling + # ... 
+
+    # Rest of your IPython methods
+class IPythonConsole(app.App):
+    """Textual app hosting an IPython REPL that runs on a worker thread.
+
+    The REPL reads its input through ``readline`` (fed by the Input widget via
+    ``input_queue``); its stdout/stderr are replaced by ``IPythonIO`` objects
+    built around ``output_queue``, which a timer drains into the RichLog.
+    """
+
+    CSS = """
+    Screen {
+        layout: vertical;
+    }
+
+    Vertical {
+        height: auto;
+    }
+
+    ThinSpinner {
+        height: 1;
+        margin: 0;
+    }
+
+    /* Rest of your CSS */
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.input_queue = asyncio.Queue()
+        self.output_queue = asyncio.Queue()
+        self.main_loop = None  # set in on_mount once the UI loop is running
+
+    async def on_mount(self) -> None:
+        """Initialize application components"""
+        self.main_loop = asyncio.get_running_loop()
+        self.rich_log = RichLog()
+        self.input_widget = Input(placeholder=">>> ")
+        await self.mount(self.rich_log, self.input_widget)
+        self.start_ipython()
+        # Poll every 50 ms so queued REPL output is written to the log.
+        self.set_interval(0.05, self.process_output)
+
+    @work(thread=True)
+    def start_ipython(self) -> None:
+        """Launch IPython in a background thread"""
+        # Import first: a missing IPython must not leave the streams swapped.
+        from IPython import start_ipython
+
+        ipy_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(ipy_loop)
+
+        # Remember what was installed so every stream (stderr too) is put back.
+        saved_streams = sys.stdin, sys.stdout, sys.stderr
+        try:
+            sys.stdin = self
+            sys.stdout = IPythonIO(self.output_queue, self.main_loop)
+            sys.stderr = IPythonIO(self.output_queue, self.main_loop)
+            start_ipython(
+                argv=[],
+                user_ns=self.get_ipython_namespace(),
+                display_banner=False
+            )
+        finally:
+            sys.stdin, sys.stdout, sys.stderr = saved_streams
+
+    def get_ipython_namespace(self) -> dict:
+        """Provide objects accessible in IPython REPL"""
+        return {
+            "app": self,
+            "run_async": self.run_in_main_loop,
+            "fetch_data": self.sample_async_method
+        }
+
+    def run_in_main_loop(self, coro) -> Any:
+        """Execute async code in main thread's event loop"""
+        # .result() blocks the caller, so only call this from another thread.
+        return asyncio.run_coroutine_threadsafe(
+            coro,
+            loop=self.main_loop
+        ).result()
+
+    async def sample_async_method(self) -> str:
+        """Example async method callable from IPython"""
+        await asyncio.sleep(1)
+        return "Data fetched successfully!"
+
+    async def process_output(self) -> None:
+        """Update UI with output from IPython"""
+        while not self.output_queue.empty():
+            data = await self.output_queue.get()
+            self.rich_log.write(data)
+            self.rich_log.scroll_end(animate=False)
+
+    def readline(self, size: int = -1) -> str:
+        """Get input from queue (blocking in IPython thread)"""
+        # ``size`` is accepted for file-API compatibility but is not used.
+        return asyncio.run_coroutine_threadsafe(
+            self.input_queue.get(),
+            loop=self.main_loop
+        ).result()
+
+    async def on_input_submitted(self, event: Input.Submitted) -> None:
+        """Handle user input submissions"""
+        await self.input_queue.put(event.value + "\n")
+        self.input_widget.clear()
+
+if __name__ == "__main__":
+    IPythonConsole().run()
\ No newline at end of file
diff --git a/src/scripts/__init__.py b/src/scripts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/tracing.py b/src/tracing.py
new file mode 100644
index 00000000..521401a8
--- /dev/null
+++ b/src/tracing.py
@@ -0,0 +1,36 @@
+from loguru import logger
+from sys import stdout
+import functools
+# Configure local logging to output both to a file and to the console.
+logger.add("local_trace.log", format="{time} {level}: {message}", level="DEBUG")
+logger.add(stdout, format="{time} {level}: {message}", level="DEBUG")
+
+def trace(func):
+    """
+    A decorator for local code tracing that logs the entry, exit, and result of the function.
+
+    Args:
+        func: The callable to wrap.
+
+    Returns:
+        A wrapper that logs the arguments at DEBUG level, calls ``func``, logs
+        the value it returned and passes that value back to the caller.
+    """
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        logger.debug(f"Entering {func.__name__} with args={args} kwargs={kwargs}")
+        # Keep the return value in its own local: rebinding ``logger`` here would
+        # make it local to wrapper and break the logging calls around it.
+        result = func(*args, **kwargs)
+        logger.debug(f"Exiting {func.__name__} with result={result}")
+        return result
+    return wrapper
+
+# Example usage:
+@trace
+def example_function(x, y):
+    return x + y
+
+if __name__ == "__main__":
+    result = example_function(2, 3)
+    print("Result:", result)
diff --git a/ui/widgets.py b/ui/widgets.py
new file mode 100644
index 00000000..d3b92fba
--- /dev/null
+++ b/ui/widgets.py
@@ -0,0 +1,123 @@
+from typing import Any, Union
+from rich.text import Text
+from textual.widgets import ProgressBar, DataTable, Switch
+from textual.reactive import reactive
+
+
+class ThinSpinner(ProgressBar):
+    """Animated thin progress bar spinner"""
+
+    # Animated from 0.0 to 1.0; its watcher mirrors the value into ``progress``.
+    _animation_progress = reactive(0.0)
+
+    # Textual reads widget styles from DEFAULT_CSS; ``CSS`` is only consulted on
+    # App/Screen classes, so it is kept below purely as an alias.
+    DEFAULT_CSS = """
+    ThinSpinner {
+        height: 1;
+        width: 100%;
+        background: $surface;
+        color: red;
+        border: none;
+        margin: 0;
+    }
+
+    ThinSpinner > .progress--bar {
+        background: red;
+        min-height: 1;
+        width: 100%;
+        background: $surface;
+        margin: 0;
+        color: red;
+        /* border: solid yellow; */
+    }
+    """
+    CSS = DEFAULT_CSS
+
+    def on_mount(self) -> None:
+        # Start animation with slight delay to ensure CSS loads
+        self.call_later(self._start_sweep)
+        # Periodic refresh; registered exactly once because the sweep loop
+        # below never re-enters on_mount.
+        self.set_interval(0.05, self.update)
+
+    def _start_sweep(self) -> None:
+        """Rewind to 0 and animate one sweep, re-arming itself on completion.
+
+        Looping back here instead of into on_mount avoids registering another
+        interval timer on every cycle.
+        """
+        self._animation_progress = 0.0
+        self.animate("_animation_progress", 1.0, duration=1.5, on_complete=self._start_sweep)
+
+    def watch__animation_progress(self, progress: float) -> None:
+        """Mirror the animated 0.0-1.0 value into ``progress`` as 0-100."""
+        self.progress = int(progress * 100)
+
+    # Original (misspelled) name, kept so existing callers keep working.
+    updte_animation_progress = watch__animation_progress
+
+
+class TranslationToggle(Switch):
+    """Switch that tracks a human-readable ON/OFF ``label`` for its value."""
+
+    def __init__(self, value: bool = False):
+        super().__init__(value)
+        self.label = self._label_for(value)
+
+    @staticmethod
+    def _label_for(value: bool) -> str:
+        """Return the label text matching a switch state."""
+        return "🌐 TRANSLATE ON" if value else "🌐 TRANSLATE OFF"
+
+    def on_switch_changed(self, event: Switch.Changed) -> None:
+        """Refresh ``label`` when the switch is toggled.
+
+        Textual derives handler names from the message class, so a handler for
+        ``Switch.Changed`` has to be called ``on_switch_changed`` to be invoked.
+        """
+        self.label = self._label_for(event.value)
+
+    # Previous method name, kept as an alias for any direct callers.
+    on_change = on_switch_changed
+
+
+class BilingualTable(DataTable):
+    """Single-column ("Content") table titled with a per-language marker."""
+
+    # lang code -> (style, title marker).
+    # TODO: the style half is not applied to cells yet.
+    STYLES = {
+        "EN": ("bold #1F618D", "▷"),
+        "CN": ("bold #C0392B", "◁")
+    }
+
+    def __init__(self, title: str, lang: str):
+        """Create the table.
+
+        Args:
+            title: Text shown in the border title after the language marker.
+            lang: Key into ``STYLES`` ("EN" or "CN").
+
+        Raises:
+            KeyError: If ``lang`` is not a key of ``STYLES``.
+        """
+        super().__init__(zebra_stripes=True, cursor_type="row")
+        self.border_title = f"{self.STYLES[lang][1]} {title}"
+        self.lang = lang
+        self.add_columns("Content")
+
+    def add_row(
+        self,
+        *cells: Any,
+        height: int = 1,
+        key: Union[str, None] = None,
+        label: Union[str, Text, None] = None,
+    ) -> Any:
+        """Add a row, forwarding every argument to ``DataTable.add_row``.
+
+        ``key`` and ``label`` default to ``None`` (as in ``DataTable``) rather
+        than ``""``: a shared ``""`` key makes the second keyless row collide
+        with the first one.
+        """
+        return super().add_row(*cells, height=height, key=key, label=label)
\ No newline at end of file