-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_query.py
More file actions
228 lines (196 loc) · 10.5 KB
/
rag_query.py
File metadata and controls
228 lines (196 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import os
from functools import lru_cache

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DataFrameLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

import vector_store
# Populate os.environ from a local .env file (when one exists) so the
# credential lookups below can see values defined there.
load_dotenv()

# Provider credentials. An empty string means "not configured"; the rest of
# this module checks these for truthiness to pick a provider.
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
huggingface_api_key = os.environ.get("HUGGINGFACE_API_KEY", "")
# Initialize LLM based on available API keys and configuration
def get_llm(provider="OpenAI"):
    """Construct the LangChain language model used by the RAG path.

    OpenAI (``gpt-4o``) is always preferred when an OpenAI key is configured;
    *provider* is only consulted for the Hugging Face fallback, where
    ``"DeepSeek-V3"`` selects Mistral-7B-Instruct and anything else selects
    flan-t5-xl.

    Args:
        provider: provider label chosen in the UI.

    Returns:
        A LangChain LLM/chat-model instance, or ``None`` when no API key is
        configured or construction fails (the error is shown via ``st.error``).
    """
    try:
        if openai_api_key:
            # the newest OpenAI model is "gpt-4o" which was released May 13, 2024.
            # do not change this unless explicitly requested by the user
            return ChatOpenAI(
                model_name="gpt-4o",
                temperature=0.7,
                openai_api_key=openai_api_key,
            )

        if not huggingface_api_key:
            # Raised inside the try on purpose: it is reported through the
            # same st.error path as any other initialization failure.
            raise ValueError("No valid API key found. Please provide either OpenAI or Hugging Face API key.")

        # NOTE: the "DeepSeek-V3" option is actually served by Mistral-7B-Instruct.
        hub_repo = (
            "mistralai/Mistral-7B-Instruct-v0.2"
            if provider == "DeepSeek-V3"
            else "google/flan-t5-xl"
        )
        return HuggingFaceHub(
            repo_id=hub_repo,
            model_kwargs={"temperature": 0.7, "max_length": 512},
            huggingfacehub_api_token=huggingface_api_key,
        )
    except Exception as err:
        st.error(f"Error initializing language model: {str(err)}")
        return None
# Initialize embeddings model
@lru_cache(maxsize=1)
def get_embeddings():
    """Return the sentence-transformers embedding model (all-MiniLM-L6-v2 on CPU).

    The instance is cached for the life of the process. ``get_vector_store``
    calls this on every query, and building ``HuggingFaceEmbeddings`` loads
    the model weights each time. A failed construction is not cached, so the
    next call retries.
    """
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )
# Process dataset for RAG
def process_dataset_for_rag(df):
    """Render a DataFrame as a single text ``Document`` for the vector store.

    The text lists the shape, column names and dtypes, the first five rows,
    and min/max/mean/median for every numeric column.

    Args:
        df: a pandas DataFrame, or ``None``.

    Returns:
        ``None`` if *df* is ``None``. Otherwise a one-element list containing
        a LangChain ``Document`` with metadata ``{"source": "uploaded_dataset"}``.
    """
    if df is None:
        return None

    # Column labels need not be strings (a headerless CSV yields integer
    # labels), and str.join rejects non-str items, so stringify them first.
    column_names = [str(col) for col in df.columns]
    dtype_desc = ", ".join(
        f"{col}: {dtype}" for col, dtype in zip(column_names, df.dtypes.astype(str))
    )

    summary = [
        f"Dataset with {df.shape[0]} rows and {df.shape[1]} columns.",
        "Columns: " + ", ".join(column_names),
        "Data types: " + dtype_desc,
    ]

    # A few example rows, labelled by their index value.
    sample_data = [
        f"Row {idx}: " + ", ".join(f"{col}: {val}" for col, val in row.items())
        for idx, row in df.head(5).iterrows()
    ]

    # Basic descriptive statistics for numeric columns only.
    stats = [
        f"Column {col} - min: {df[col].min()}, max: {df[col].max()}, "
        f"mean: {df[col].mean()}, median: {df[col].median()}"
        for col in df.select_dtypes(include=["number"]).columns
    ]

    text = "\n".join(summary + ["Sample data:"] + sample_data + ["Statistical summary:"] + stats)

    # Imported lazily, as in the original, so callers that pass None never
    # need the langchain docstore.
    from langchain.docstore.document import Document

    return [Document(page_content=text, metadata={"source": "uploaded_dataset"})]
# Create or retrieve vector store
def get_vector_store(df=None):
    """Return a Chroma vector store to retrieve from.

    When *df* is given, its summary text (built by ``process_dataset_for_rag``)
    is split into ~1000-character chunks with 100-character overlap and
    indexed into a newly created Chroma store. Otherwise the store persisted
    at ``./chroma_db`` is opened.

    Args:
        df: optional pandas DataFrame to index.

    Returns:
        A ``Chroma`` instance, or ``None`` if opening the persisted store
        raises. Errors while indexing *df* propagate to the caller.
    """
    if df is not None:
        documents = process_dataset_for_rag(df)
        if documents:
            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            chunks = splitter.split_documents(documents)
            # Named `store`, not `vector_store`, so the imported module of that
            # name is not shadowed.
            return Chroma.from_documents(documents=chunks, embedding=get_embeddings())

    try:
        return Chroma(persist_directory="./chroma_db", embedding_function=get_embeddings())
    except Exception:
        # Best-effort: no usable persisted store. Unlike the former bare
        # `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
        return None
# Upper bound (seconds) on the direct OpenAI request, so a stalled call falls
# through to the LangChain path instead of blocking the app for long.
_OPENAI_TIMEOUT_SECONDS = 60.0

# Persona used as the system prompt for the RAG answer step.
_RAG_SYSTEM_PROMPT = (
    "You are CeCe (Climate Copilot), an AI assistant specializing in climate and weather data analysis. "
    "You help users with climate data visualization, scientific calculations, and understanding weather patterns. "
    "Base your responses on the provided context and your knowledge of climate science."
)


def _ask_openai_directly(query):
    """Send *query* to gpt-4o with the CeCe persona and return the reply text.

    Raises whatever the ``openai`` import or request raises; the caller
    handles it.
    """
    from openai import OpenAI

    client = OpenAI(api_key=openai_api_key, timeout=_OPENAI_TIMEOUT_SECONDS)
    system_message = (
        "You are CeCe (Climate Copilot), an AI assistant specializing in climate and weather data analysis.\n"
        "You help users with climate data visualization, scientific calculations, and understanding weather patterns.\n"
        "Your responses should be friendly, helpful, and focused on climate science."
    )
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ],
        temperature=0.7,
        max_tokens=500,
    )
    return response.choices[0].message.content


def _ask_with_rag(llm, vectorstore, query):
    """Answer *query* with a ConversationalRetrievalChain over *vectorstore*.

    The persona is injected through the "stuff" QA prompt rather than as an
    extra chain input. ConversationBufferMemory without an ``input_key``
    expects exactly one non-memory input, so an extra key such as the
    previous ``"system_prompt"`` makes saving the context raise after the
    answer has been generated.
    """
    from langchain.prompts import PromptTemplate

    qa_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=_RAG_SYSTEM_PROMPT + "\n\nContext:\n{context}\n\nQuestion: {question}\nAnswer:",
    )
    # The memory is rebuilt on every call, so chat_history always starts empty.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        verbose=True,
    )
    return qa({"question": query})["answer"]


# Process a query using direct OpenAI response with fallback options
def process_query(query, df=None):
    """Answer a user query, degrading gracefully through several strategies.

    1. If an OpenAI key is configured, ask gpt-4o directly. No dataset
       context is sent on this path.
    2. If there is no OpenAI key, or that call fails, build an LLM with
       ``get_llm``. When a vector store is available (from *df* or the
       persisted store), answer via retrieval-augmented generation.
    3. If no LLM or vector store is available, or RAG fails, return a canned
       reply from ``get_fallback_response``.

    Args:
        query: the user's question.
        df: optional pandas DataFrame to index for retrieval.

    Returns:
        The answer text. Unexpected errors become an apology string instead
        of being raised.
    """
    try:
        if openai_api_key:
            try:
                return _ask_openai_directly(query)
            except Exception as e:
                st.warning(f"OpenAI query failed: {str(e)}. Trying LangChain approach...")

        llm = get_llm()
        if llm is None:
            return get_fallback_response(query)

        vectorstore = None
        try:
            vectorstore = get_vector_store(df)
        except Exception as e:
            st.warning(f"Could not initialize vector store: {str(e)}")

        if vectorstore is None:
            return get_fallback_response(query)

        try:
            return _ask_with_rag(llm, vectorstore, query)
        except Exception as e:
            st.warning(f"RAG query failed: {str(e)}. Falling back to direct response.")
            return get_fallback_response(query)
    except Exception as e:
        # Last-resort boundary: never let the chat UI crash on a query.
        return f"I apologize, but I encountered an error processing your request: {str(e)}. Please try again or rephrase your question."
# Get a predefined response for common climate questions
def get_fallback_response(query):
    """Return a canned CeCe reply for *query* without calling any model.

    Keywords are tried in the order listed below, using a case-insensitive
    substring test; the first keyword found wins. Queries that match no
    keyword get a generic introduction.

    Args:
        query: the user's message (must be a ``str``).

    Returns:
        The matching canned reply text.
    """
    # Order matters: earlier keywords take precedence when several occur.
    canned_replies = (
        ("temperature", "Temperature is a key climate variable. I can help you analyze temperature trends, calculate anomalies, and visualize temperature data. You can use the preset buttons above to explore temperature-related features."),
        ("precipitation", "Precipitation includes rain, snow, and other forms of water falling from the sky. I can help you analyze precipitation patterns and create visualization maps. Try the 'Generate a precipitation map' button above!"),
        ("climate change", "Climate change refers to significant changes in global temperature, precipitation, wind patterns, and other measures of climate that occur over several decades or longer. I can help you analyze climate data to understand these changes."),
        ("weather", "Weather refers to day-to-day conditions, while climate refers to the average weather patterns in an area over a longer period. I can help you analyze both weather data and climate trends."),
        ("forecast", "While I don't provide real-time weather forecasts, I can help you analyze historical climate data and identify patterns that might inform future conditions."),
        ("hello", "Hello! I'm CeCe, your Climate Copilot. I'm here to help you analyze and visualize climate data. How can I assist you today?"),
        ("help", "I can help you with climate data analysis, visualization, and scientific calculations. Try one of the preset buttons above to get started, or ask me a specific question about climate data."),
    )
    default_reply = "I'm CeCe, your Climate Copilot. I can help you analyze climate data, create visualizations, and perform scientific calculations. Try one of the preset buttons above, or ask me a specific question about climate or weather data!"

    lowered = query.lower()
    return next(
        (reply for keyword, reply in canned_replies if keyword in lowered),
        default_reply,
    )