diff --git a/src/db.sql b/src/db.sql index f72796b7..f85cbaed 100644 --- a/src/db.sql +++ b/src/db.sql @@ -1,37 +1,36 @@ /* - Navicat Premium Data Transfer - - Source Server : postgres - Source Server Type : PostgreSQL - Source Server Version : 160002 (160002) - Source Host : localhost:5432 - Source Catalog : postgres - Source Schema : public - - Target Server Type : PostgreSQL - Target Server Version : 160002 (160002) - File Encoding : 65001 - - Date: 07/04/2024 16:27:37 +Navicat Premium Data Transfer +Source Server : postgres +Source Server Type : PostgreSQL +Source Server Version : 160002 (160002) +Source Host : localhost:5432 +Source Catalog : postgres +Source Schema : public +Target Server Type : PostgreSQL +Target Server Version : 160002 (160002) +File Encoding : 65001 +Date: 07/04/2024 16:27:37 */ - -- ---------------------------- -- Sequence structure for project_tasks_id_seq -- ---------------------------- DROP SEQUENCE IF EXISTS "public"."project_tasks_id_seq"; + CREATE SEQUENCE "public"."project_tasks_id_seq" INCREMENT 1 MINVALUE 1 MAXVALUE 2147483647 START 1 CACHE 1; + ALTER SEQUENCE "public"."project_tasks_id_seq" OWNER TO "postgres"; -- ---------------------------- -- Table structure for project_tasks -- ---------------------------- DROP TABLE IF EXISTS "public"."project_tasks"; + CREATE TABLE "public"."project_tasks" ( "id" int4 NOT NULL DEFAULT nextval('project_tasks_id_seq'::regclass), "key" varchar COLLATE "pg_catalog"."default", @@ -59,12 +58,14 @@ CREATE TABLE "public"."project_tasks" ( "title" varchar COLLATE "pg_catalog"."default" ) ; + ALTER TABLE "public"."project_tasks" OWNER TO "postgres"; -- ---------------------------- -- Table structure for project_tasks_amazing_prompt -- ---------------------------- DROP TABLE IF EXISTS "public"."project_tasks_amazing_prompt"; + CREATE TABLE "public"."project_tasks_amazing_prompt" ( "id" int4 NOT NULL DEFAULT nextval('project_tasks_id_seq'::regclass), "key" varchar COLLATE "pg_catalog"."default", @@ 
-96,6 +97,7 @@ CREATE TABLE "public"."project_tasks_amazing_prompt" ( "if_business_flow_scan" varchar COLLATE "pg_catalog"."default" ) ; + ALTER TABLE "public"."project_tasks_amazing_prompt" OWNER TO "postgres"; -- ---------------------------- @@ -103,7 +105,8 @@ ALTER TABLE "public"."project_tasks_amazing_prompt" OWNER TO "postgres"; -- ---------------------------- ALTER SEQUENCE "public"."project_tasks_id_seq" OWNED BY "public"."project_tasks"."id"; -SELECT setval('"public"."project_tasks_id_seq"', 98390, true); + +SELECT setval ( '"public"."project_tasks_id_seq"', 98390, true ); -- ---------------------------- -- Indexes structure for table project_tasks @@ -111,6 +114,7 @@ SELECT setval('"public"."project_tasks_id_seq"', 98390, true); CREATE INDEX "ix_project_tasks_key" ON "public"."project_tasks" USING btree ( "key" COLLATE "pg_catalog"."default" "pg_catalog"."text_ops" ASC NULLS LAST ); + CREATE INDEX "ix_project_tasks_project_id" ON "public"."project_tasks" USING btree ( "project_id" COLLATE "pg_catalog"."default" "pg_catalog"."text_ops" ASC NULLS LAST ); @@ -126,6 +130,7 @@ ALTER TABLE "public"."project_tasks" ADD CONSTRAINT "project_tasks_pkey" PRIMARY CREATE INDEX "ix_project_tasks_key_copy1_copy1" ON "public"."project_tasks_amazing_prompt" USING btree ( "key" COLLATE "pg_catalog"."default" "pg_catalog"."text_ops" ASC NULLS LAST ); + CREATE INDEX "ix_project_tasks_project_id_copy1_copy1" ON "public"."project_tasks_amazing_prompt" USING btree ( "project_id" COLLATE "pg_catalog"."default" "pg_catalog"."text_ops" ASC NULLS LAST ); @@ -134,3 +139,16 @@ CREATE INDEX "ix_project_tasks_project_id_copy1_copy1" ON "public"."project_task -- Primary Key structure for table project_tasks_amazing_prompt -- ---------------------------- ALTER TABLE "public"."project_tasks_amazing_prompt" ADD CONSTRAINT "project_tasks_copy1_copy1_pkey" PRIMARY KEY ("id"); + +Drop table if exists cache; + +CREATE TABLE IF NOT EXISTS cache ( + id SERIAL PRIMARY KEY, + model_type 
VARCHAR(50) NOT NULL, + data_hash VARCHAR(64) NOT NULL, + response_data TEXT NOT NULL, + created_at TIMESTAMP + WITH + TIME ZONE DEFAULT CURRENT_TIMESTAMP, + UNIQUE (model_type, data_hash) +); \ No newline at end of file diff --git a/src/openai_api/openai.py b/src/openai_api/openai.py index 6ad43db7..f64e444a 100644 --- a/src/openai_api/openai.py +++ b/src/openai_api/openai.py @@ -2,7 +2,10 @@ import os import numpy as np import requests +from utils.cache_manager import CacheManager +# Initialize cache manager +cache_manager = CacheManager() def azure_openai(prompt): # Azure OpenAI配置 @@ -10,6 +13,20 @@ def azure_openai(prompt): api_base = os.environ.get('AZURE_API_BASE') api_version = os.environ.get('AZURE_API_VERSION') deployment_name = os.environ.get('AZURE_DEPLOYMENT_NAME') + + # Prepare request data + request_data = { + "messages": [ + {"role": "system", "content": "你是一个熟悉智能合约与区块链安全的安全专家。"}, + {"role": "user", "content": prompt} + ] + } + + # Check cache first + cached_response = cache_manager.get_cached_response('azure', request_data) + if cached_response: + return cached_response + # 构建URL url = f"{api_base}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}" # 设置请求头 @@ -17,23 +34,20 @@ def azure_openai(prompt): "Content-Type": "application/json", "api-key": api_key } - # 设置请求体 - data = { - "messages": [ - {"role": "system", "content": "你是一个熟悉智能合约与区块链安全的安全专家。"}, - {"role": "user", "content": prompt} - ], - # "max_tokens": 150 - } + try: # 发送POST请求 - response = requests.post(url, headers=headers, json=data) + response = requests.post(url, headers=headers, json=request_data) # 检查响应状态 response.raise_for_status() # 解析JSON响应 result = response.json() - # 打印响应 - return result['choices'][0]['message']['content'] + response_content = result['choices'][0]['message']['content'] + + # Cache the response + cache_manager.cache_response('azure', request_data, response_content) + + return response_content except 
requests.exceptions.RequestException as e: print("Azure OpenAI测试失败。错误:", str(e)) return None @@ -45,15 +59,9 @@ def azure_openai_json(prompt): api_base = os.environ.get('AZURE_API_BASE') api_version = os.environ.get('AZURE_API_VERSION') deployment_name = os.environ.get('AZURE_DEPLOYMENT_NAME') - # 构建URL - url = f"{api_base}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}" - # 设置请求头 - headers = { - "Content-Type": "application/json", - "api-key": api_key - } - # 设置请求体 - data = { + + # Prepare request data + request_data = { "response_format": { "type": "json_object" }, "messages": [ { @@ -66,52 +74,122 @@ def azure_openai_json(prompt): } ] } + + # Check cache first + cached_response = cache_manager.get_cached_response('azure', request_data) + if cached_response: + return cached_response + + # 构建URL + url = f"{api_base}openai/deployments/{deployment_name}/chat/completions?api-version={api_version}" + # 设置请求头 + headers = { + "Content-Type": "application/json", + "api-key": api_key + } + try: # 发送POST请求 - response = requests.post(url, headers=headers, json=data) + response = requests.post(url, headers=headers, json=request_data) # 检查响应状态 response.raise_for_status() # 解析JSON响应 result = response.json() - # 打印响应 - return result['choices'][0]['message']['content'] + response_content = result['choices'][0]['message']['content'] + + # Cache the response + cache_manager.cache_response('azure', request_data, response_content) + + return response_content except requests.exceptions.RequestException as e: print("Azure OpenAI测试失败。错误:", str(e)) return None def ask_openai_common(prompt): - api_base = os.environ.get('OPENAI_API_BASE', 'api.openai.com') # Replace with your actual OpenAI API base URL - api_key = os.environ.get('OPENAI_API_KEY') # Replace with your actual OpenAI API key - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - data = { - "model": os.environ.get('VUL_MODEL_ID'), # Replace with your 
def ask_deepseek_common(prompt):
    """Send ``prompt`` as a single user message to a chat-completions endpoint.

    The host is read from ``OPENAI_API_BASE`` (default ``api.openai.com``),
    the bearer token from ``OPENAI_API_KEY`` and the model name from
    ``VUL_MODEL_ID``.  The request is non-streaming with ``temperature`` 0.
    Results are looked up in, and stored to, the module-level
    ``cache_manager`` under the ``'deepseek'`` model type.

    Args:
        prompt: Text placed in the ``content`` field of the user message.

    Returns:
        The ``content`` of the first choice's message, or ``None`` when the
        HTTP request fails or the response body does not have the expected
        shape.
    """
    api_base = os.environ.get('OPENAI_API_BASE', 'api.openai.com')
    api_key = os.environ.get('OPENAI_API_KEY')

    # The request body doubles as the cache key, so build it before any I/O.
    request_data = {
        "model": os.environ.get('VUL_MODEL_ID'),
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "stream": False,
        "temperature":0,
    }

    # Serve from the cache when an identical request was answered before.
    cached_response = cache_manager.get_cached_response('deepseek', request_data)
    if cached_response:
        return cached_response

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    # NOTE(review): no timeout is passed, so requests waits indefinitely for
    # the server; pick a project-wide value before relying on this in batch jobs.
    try:
        # The POST itself is inside the try so that connection errors are
        # reported the same way as malformed responses (None), instead of
        # escaping to the caller.
        response = requests.post(
            f'https://{api_base}/v1/chat/completions',
            headers=headers,
            json=request_data,
        )
        # Surface HTTP error statuses with a meaningful message rather than
        # a bare KeyError on the missing 'choices' field.
        response.raise_for_status()
        response_content = response.json()['choices'][0]['message']['content']
    except (requests.exceptions.RequestException,
            KeyError, IndexError, TypeError, ValueError) as e:
        print(f"Error in ask_deepseek_common: {str(e)}")
        return None

    # Caching is best-effort: a failing cache write must not discard a
    # completion that was already obtained from the API.
    try:
        cache_manager.cache_response('deepseek', request_data, response_content)
    except Exception as e:
        print(f"Error caching response in ask_deepseek_common: {str(e)}")

    return response_content
os.environ.get('AZURE_OR_OPENAI') == 'AZURE': return azure_openai_json(prompt) else: return ask_openai_for_json(prompt) + def ask_claude(prompt): model = os.environ.get('CLAUDE_MODEL', 'claude-3-5-sonnet-20240620') api_key = os.environ.get('OPENAI_API_KEY') api_base = os.environ.get('OPENAI_API_BASE', 'https://apix.ai-gaochao.cn') - headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {api_key}' - } - - data = { + # Prepare request data + request_data = { 'model': model, 'messages': [ { @@ -157,17 +248,29 @@ def ask_claude(prompt): } ] } + + # Check cache first + cached_response = cache_manager.get_cached_response('claude', request_data) + if cached_response: + return cached_response + + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } try: response = requests.post(f'https://{api_base}/v1/chat/completions', headers=headers, - json=data) + json=request_data) response.raise_for_status() response_data = response.json() - if 'choices' in response_data and len(response_data['choices']) > 0: - return response_data['choices'][0]['message']['content'] - else: - return "" + response_content = response_data['choices'][0]['message']['content'] + + # Cache the response + cache_manager.cache_response('claude', request_data, response_content) + + return response_content except requests.exceptions.RequestException as e: print(f"Claude API调用失败。错误: {str(e)}") return "" @@ -178,6 +281,8 @@ def common_ask(prompt): return azure_openai(prompt) elif model_type == 'CLAUDE': return ask_claude(prompt) + elif model_type == 'DEEPSEEK': + return ask_deepseek_common(prompt) else: return ask_openai_common(prompt) @@ -192,24 +297,72 @@ def common_get_embedding(text: str): api_base = os.getenv('OPENAI_API_BASE', 'api.openai.com') model = os.getenv("PRE_TRAIN_MODEL", "text-embedding-3-large") + # Prepare request data + request_data = { + "input": clean_text(text), + "model": model, + "encoding_format": "float" + } + + # Check 
cache first + cached_response = cache_manager.get_cached_response('openai_embedding', request_data) + if cached_response: + return json.loads(cached_response) + headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } - cleaned_text = clean_text(text) + try: + response = requests.post(f'https://{api_base}/v1/embeddings', json=request_data, headers=headers) + response.raise_for_status() + embedding_data = response.json() + response_content = embedding_data['data'][0]['embedding'] + + # Cache the response + cache_manager.cache_response('openai_embedding', request_data, json.dumps(response_content)) + + return response_content + except requests.exceptions.RequestException as e: + print(f"Error: {e}") + return list(np.zeros(3072)) # 返回长度为3072的全0数组 + +def common_get_embedding2(text: str): + api_key = os.getenv('OPENAI_EMBEDDING_API_KEY') + if not api_key: + raise ValueError("OPENAI_EMBEDDING_API_KEY environment variable is not set") + + api_base = os.getenv('OPENAI_EMBEDDING_BASE', 'https://ark.cn-beijing.volces.com/api/v3') + model = os.getenv("OPENAI_EMBEDDING_MODEL", "ep-20241218223410-kwbkm") - data = { - "input": cleaned_text, + # Prepare request data + request_data = { + "input": clean_text(text), "model": model, "encoding_format": "float" } + + # Check cache first + cached_response = cache_manager.get_cached_response('custom_embedding', request_data) + if cached_response: + return json.loads(cached_response) + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } try: - response = requests.post(f'https://{api_base}/v1/embeddings', json=data, headers=headers) + response = requests.post(f'https://{api_base}/embeddings', json=request_data, headers=headers) response.raise_for_status() embedding_data = response.json() - return embedding_data['data'][0]['embedding'] + response_content = embedding_data['data'][0]['embedding'] + + # Cache the response + cache_manager.cache_response('custom_embedding', 
class CacheManager:
    """PostgreSQL-backed cache of model responses keyed by request content.

    Rows live in the ``cache`` table and are addressed by the pair
    ``(model_type, data_hash)``, where ``data_hash`` is the SHA-256 hex digest
    of the request dictionary serialized as key-sorted JSON.  One database
    connection is opened per instance and reused for every call.
    """

    def __init__(self):
        """Open the database connection.

        The connection string is read from ``DATABASE_URL``; when it is not
        set, a local ``postgres`` database is used.
        """
        database_url = os.environ.get(
            'DATABASE_URL',
            'postgresql://postgres:postgres@localhost:5432/postgres',
        )
        self.conn = psycopg2.connect(database_url)

    def _generate_hash(self, data: dict) -> str:
        """Return the SHA-256 hex digest identifying ``data``.

        Keys are sorted during serialization so that dictionaries with the
        same content hash identically regardless of insertion order.
        """
        serialized = json.dumps(data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def _execute(self, query: str, params: tuple, fetch_one: bool = False) -> Optional[tuple]:
        """Run one statement in its own transaction.

        Commits on success and rolls back on any failure before re-raising.
        Returns the first result row when ``fetch_one`` is true, else ``None``.
        """
        try:
            with self.conn.cursor() as cur:
                cur.execute(query, params)
                row = cur.fetchone() if fetch_one else None
            # Also ends the implicit transaction opened by a SELECT, so the
            # session is not left idle inside a transaction between calls.
            self.conn.commit()
        except Exception:
            # A failed statement leaves the transaction aborted; without a
            # rollback every later statement on this shared connection would
            # be rejected as well.
            self.conn.rollback()
            raise
        return row

    def get_cached_response(self, model_type: str, request_data: dict) -> Optional[str]:
        """Return the cached response for the request, or ``None`` if absent.

        Args:
            model_type: Namespace of the cache entry (first half of the key).
            request_data: Request dictionary; its hash is the second half.
        """
        data_hash = self._generate_hash(request_data)
        row = self._execute(
            "SELECT response_data FROM cache WHERE model_type = %s AND data_hash = %s",
            (model_type, data_hash),
            fetch_one=True,
        )
        return row[0] if row else None

    def cache_response(self, model_type: str, request_data: dict, response_data: str) -> None:
        """Insert the response for the request, overwriting an existing entry.

        Args:
            model_type: Namespace of the cache entry (first half of the key).
            request_data: Request dictionary; its hash is the second half.
            response_data: Text stored in the ``response_data`` column.
        """
        data_hash = self._generate_hash(request_data)
        self._execute(
            """
            INSERT INTO cache (model_type, data_hash, response_data)
            VALUES (%s, %s, %s)
            ON CONFLICT (model_type, data_hash)
            DO UPDATE SET response_data = EXCLUDED.response_data
            """,
            (model_type, data_hash, response_data),
        )

    def __del__(self):
        """Close the database connection when the object is destroyed."""
        # __init__ may have failed before ``conn`` was assigned.
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()