-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodule_plot_SVG.py
More file actions
112 lines (103 loc) · 5.91 KB
/
module_plot_SVG.py
File metadata and controls
112 lines (103 loc) · 5.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import re
import pandas as pd
from pandasai import SmartDataframe
from pandasai import Agent
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
def get_plot_model(model="llama3-groq-70b-8192-tool-use-preview", temperature=0):
    """Return the Groq chat model used to generate plots.

    Args:
        model: Groq model identifier. Defaults to the id this module was
            written against.
        temperature: Sampling temperature passed to ``ChatGroq``; the default
            of 0 requests the least random sampling.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    # NOTE(review): the default id is a Groq "preview" model; confirm it is
    # still served by the Groq API before relying on the default.
    return ChatGroq(model=model, temperature=temperature)
# TODO: this should receive the queries only, not the actual data.
def get_plot_from_all(llm_plot, db, question, PRINT_SETTINGS):
    """Plot an answer to ``question`` using every row of the ``expensesok`` table.

    Args:
        llm_plot: LLM handed to the pandasai agent (see ``get_plot``).
        db: Connection object accepted by ``pandas.read_sql_query``.
        question: Natural-language plotting request from the user.
        PRINT_SETTINGS: Print-settings mapping forwarded to ``get_plot``.

    Returns:
        Whatever ``get_plot`` returns for this request.
    """
    all_data = pd.read_sql_query("SELECT * from expensesok;", db)
    # The full table is expected to include a timestamp column, so time-based
    # plots are allowed here.
    return get_plot(all_data, question, llm_plot, PRINT_SETTINGS, is_there_time=True)
# Expected line shape (from the regex): "<decimal price> <description> <one-word category>",
# e.g. "12.50 Lunch with colleagues food".
# NOTE(review): integer prices such as "12 Lunch food" do not match; confirm the
# search output always prints prices with a decimal part.
_EXPENSE_LINE_PATTERN = re.compile(r'(\d+\.\d+)\s+(.+)\s+(\w+)$')
_EXPENSE_COLUMNS = ['price', 'description', 'category']


def _parse_expense_lines(text):
    """Return one price/description/category dict per matching line of ``text``."""
    rows = []
    # splitlines() also handles "\r\n"; strip() tolerates indentation and
    # trailing spaces that would otherwise defeat the anchored pattern.
    for line in text.splitlines():
        match = _EXPENSE_LINE_PATTERN.match(line.strip())
        if match:
            price, description, category = match.groups()
            rows.append({
                'price': float(price),
                'description': description,
                'category': category,
            })
    return rows


def get_plot_from_RAG(llm_plot, search_output, question, PRINT_SETTINGS):
    """Plot an answer to ``question`` from expenses found in RAG search output.

    Args:
        llm_plot: LLM handed to the pandasai agent (see ``get_plot``).
        search_output: Search text, or a list/tuple of text chunks, containing
            expense lines; ``None`` is treated as no results.
        question: Natural-language plotting request from the user.
        PRINT_SETTINGS: Print-settings mapping forwarded to ``get_plot``.

    Returns:
        Whatever ``get_plot`` returns for this request.
    """
    import logging  # local import keeps this block self-contained
    logger = logging.getLogger(__name__)
    logger.debug("RAG search output: %r", search_output)

    if search_output is None:
        chunks = []
    elif isinstance(search_output, (list, tuple)):
        chunks = search_output
    else:
        chunks = [search_output]

    data = []
    for chunk in chunks:
        data.extend(_parse_expense_lines(chunk))
    logger.debug("Parsed %d expense rows: %r", len(data), data)

    # Explicit columns keep the expected schema even when nothing matched,
    # instead of handing the agent a column-less DataFrame.
    df = pd.DataFrame(data, columns=_EXPENSE_COLUMNS)
    # Parsed RAG rows carry no timestamp column, so time-based plots are disabled.
    return get_plot(df, question, llm_plot, PRINT_SETTINGS, is_there_time=False)
# Standing instructions given to the pandasai agent for every plot request.
_PLOT_AGENT_DESCRIPTION = """You are a data analysis agent.
Your main goal is to help non-technical users to analyze data on their financial expenses.
The charts must have numbers to ease the reading.
Give an SVG file and not a PNG file.
Make sure everything is visible.
Make it very good looking.
If you need to display categories, display everything.
Don't print anything to terminal and don't open newly generated SVG files.
Don't use floats for time.
Only one plot per question.
Always get a complete legend.
"""

# Schema descriptions appended to the agent description, chosen by is_there_time.
_SCHEMA_WITH_TIME = """The database contains the columns price, description, category, timestamp.
Each row is an expense of the user.
You will be called to plot the correct data, given the question."""

_SCHEMA_WITHOUT_TIME = """The database contains the columns price, description, category.
Each row is an expense of the user.
You will be called to plot the correct data, given the question."""

# Per-request instructions sent after the user's question.
_PLOT_INSTRUCTIONS = """Create an appropriate very good looking plotly chart using only the relevant data from the database.
Give an SVG file and not a PNG file.
Make sure everything is visible.
Organize the spaces in a way that everything is readable. You don't have limits.
If you need to display categories, display everything.
Don't use floats for time.
Make the chart very good looking.
Only one plot per question.
Always get a complete legend."""

# System prompt for rewriting agent.explain() output for a non-technical user.
# The description itself is sent once, as the user message.
_EXPLANATION_SYSTEM_TEMPLATE = """You will receive a description of the procedure used to extract a plot from the data.
You will transform this description as if you were talking to a non-technical user, who only cares about which data got used
and about the kind of chart created.
The only thing you will talk about is what data you chose and what kind of chart you created.
Don't talk about databases, SQL, SVG, titles, labels or any technical stuff.
Don't talk about what the title and the labels are.
Be very short and concise."""


def _explain_plot(agent, llm):
    """Return a short, non-technical rewrite of ``agent.explain()``."""
    prompt_template = ChatPromptTemplate.from_messages(
        [("system", _EXPLANATION_SYSTEM_TEMPLATE), ("user", "{description}")]
    )
    chain = prompt_template | llm | StrOutputParser()
    return chain.invoke({"description": agent.explain()})


def get_plot(dataframe, question, llm, PRINT_SETTINGS, is_there_time):
    """Ask a pandasai agent to produce one SVG chart answering ``question``.

    Args:
        dataframe: Expenses to plot, one row per expense.
        question: Natural-language request from the user.
        llm: LLM used both by the pandasai agent and for the optional
            plain-language explanation.
        PRINT_SETTINGS: Mapping of print flags. When
            ``"print_explaination_plot"`` is truthy, a simplified explanation
            of the plot is printed; a missing key is treated as False.
        is_there_time: Whether ``dataframe`` has a ``timestamp`` column; selects
            the schema description given to the agent.

    Returns:
        The value returned by ``Agent.chat`` for the plotting request.
    """
    schema = _SCHEMA_WITH_TIME if is_there_time else _SCHEMA_WITHOUT_TIME
    # The schema goes into the agent description so the model actually sees
    # which columns exist.
    agent = Agent(dataframe, config={"llm": llm},
                  description=_PLOT_AGENT_DESCRIPTION + schema)
    # Keep the user's question visually separate from the fixed instructions.
    response = agent.chat(f"{question}\n\n{_PLOT_INSTRUCTIONS}")
    if PRINT_SETTINGS.get("print_explaination_plot", False):
        print(f"Explaination: {_explain_plot(agent, llm)}")
    return response