Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 190 additions & 41 deletions src/stores/chat.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import { defineStore } from 'pinia';
import { computed, ref, type ComputedRef, type Ref } from 'vue';
import type { Document, ChatMessage } from '@/types';
import type { Document, ErrorDetails, ChatMessage, ReformulateResponse } from '@/types';
import { basePostAxios } from '@/utils/fetch';
import { getQueryParamValue } from '@/utils/urlsUtils';
import { RELEVANCE_FACTOR } from '@/utils/constants';
import { getFromStorage, saveToStorage, clearFromStorage } from '@/utils/storage';
import i18n from '@/localisation/i18n';
import type { AxiosResponse } from 'axios';
import { useFiltersStore } from '@/stores/filters';

// CHAT STATUSES
export const CHAT_STATUS = {
ERROR: 'ERROR',
EMPTY: 'EMPTY',
REFORMULATED: 'REFORMULATED',
REFORMULATING: 'REFORMULATING',
Expand All @@ -28,13 +28,14 @@ export const useChatStore = defineStore('chat', () => {
const chatInput: Ref<string> = ref('');
const chatMessagesList: Ref<ChatMessage[]> = ref(getFromStorage('chat') || []);
const questionQueues: Ref<string[] | null> = ref(getFromStorage('questionQueues'));
const sourcesList: Ref<Document[]> = ref(getFromStorage('chatSources') || []);
const sourcesList: Ref<Document[] | null> = ref(getFromStorage('chatSources') || []);
const reformulatedQuery: Ref<string | null> = ref(getFromStorage('reformulatedQuery'));
const queryToSearch: Ref<string | null> = ref(null);
const storedConversationId: Ref<string | null> = ref(localStorage.getItem('chatConversationId'));
const storedMessageId: Ref<string | null> = ref(localStorage.getItem('chatMessageId'));

const isChatEmpty: ComputedRef<Boolean> = computed(() => chatMessagesList.value.length === 0);
const chatStatus: Ref<(typeof CHAT_STATUS)[CHAT_STATUSES_TYPE]> = ref(
const chatStatus: Ref<CHAT_STATUSES_TYPE> = ref(
isChatEmpty.value ? CHAT_STATUS.EMPTY : CHAT_STATUS.DONE
);

Expand All @@ -43,6 +44,8 @@ export const useChatStore = defineStore('chat', () => {

const storedSubject: Ref<string | undefined> = ref(getFromStorage('chatSubject') || undefined);

const corpora: Ref<ReducedCorpus[]> = ref([]);

const clearSubject = (): void => {
storedSubject.value = undefined;
clearFromStorage('chatSubject');
Expand Down Expand Up @@ -80,6 +83,10 @@ export const useChatStore = defineStore('chat', () => {
saveToStorage('reformulatedQuery', query);
};

const setQueryToSearch = (queries: string): void => {
queryToSearch.value = queries;
};

function setQuestionQueues(messages: string[]): void {
questionQueues.value = messages;
saveToStorage('questionQueues', messages);
Expand All @@ -89,6 +96,120 @@ export const useChatStore = defineStore('chat', () => {
chatInput.value = '';
}

async function reformulateQuestion(query: string) {
try {
chatStatus.value = CHAT_STATUS.REFORMULATING;
const bodyContent = {
history: getMessageHistory.value,
query: query
};

const reformulate: AxiosResponse<ReformulateResponse> = await basePostAxios(
'/qna/reformulate/query',
bodyContent
);

if (!reformulate || !reformulate.data) {
chatStatus.value = CHAT_STATUS.ERROR;
return;
}

const { data } = reformulate;

if (data.QUERY_STATUS === 'REF_TO_PAST') {
shouldFetchNewDocuments.value = false;
chatStatus.value = CHAT_STATUS.SEARCHED;
return;
}

if (data.QUERY_STATUS === 'INVALID') {
chatStatus.value = CHAT_STATUS.ERROR;
return;
}

shouldFetchNewDocuments.value = true;

const reformulatedQuery = reformulate.data.STANDALONE_QUESTION;

if (!reformulatedQuery) {
// add error message
return;
}

setReformulatedQuery(reformulatedQuery);
setQueryToSearch(reformulatedQuery);
chatStatus.value = CHAT_STATUS.REFORMULATED;
} catch (error: unknown) {
chatStatus.value = CHAT_STATUS.ERROR;
const { message, code } = (error as ErrorDetails).response.data.detail;

if (code === 'LANG_NOT_SUPPORTED') {
chatMessagesList.value.push({
role: 'assistant',
content: `${i18n.global.t('error.LANG_NOT_SUPPORTED.title')} ${i18n.global.t(
'error.LANG_NOT_SUPPORTED.description'
)}`
});
}

if (code === 'INVALID_QUESTION') {
chatMessagesList.value.push({
role: 'assistant',
content: message
});
}
}
}

async function fetchSources(): Promise<void> {
if (
(storedSubject.value && !subjectHasChanged.value) ||
!reformulatedQuery.value ||
!shouldFetchNewDocuments.value
) {
return;
}

const { sdgFilters, sourcesFilters: selectedCorpus } = useFiltersStore();

try {
chatStatus.value = CHAT_STATUS.SEARCHING;
const sourcesResp: AxiosResponse<Document[]> = await basePostAxios(
`/search/by_slices?nb_results=10${
storedSubject.value ? `&subject=${storedSubject.value}` : ''
}`,
{
query: queryToSearch.value,
relevance_factor: RELEVANCE_FACTOR,
sdg_filter: sdgFilters,
corpora: selectedCorpus
}
);

subjectHasChanged.value = false;

const { data: sources, status: sourcesStatus } = sourcesResp;

if (sourcesStatus === 204) {
chatStatus.value = CHAT_STATUS.NO_RESULTS;
// no result
return;
}

if (sourcesStatus === 200 && sources === null) {
chatStatus.value = CHAT_STATUS.NO_RESULTS;
// no result
return;
}
chatStatus.value = CHAT_STATUS.SEARCHED;

sourcesList.value = sources;
saveToStorage('chatSources', sources);
} catch (error) {
chatStatus.value = CHAT_STATUS.ERROR;
}
}

function storeConversationId(conversationId: string) {
if (conversationId !== storedConversationId.value) {
localStorage.setItem('chatConversationId', conversationId);
Expand All @@ -104,6 +225,37 @@ export const useChatStore = defineStore('chat', () => {
}
}

async function getNoStreamAnswer(userMsg: string) {
chatStatus.value = CHAT_STATUS.FORMULATING_ANSWER;
const bodyContent = {
conversation_id: storedConversationId.value,
sources: sourcesList.value || [],
history: getMessageHistory.value,
query: userMsg,
...(storedSubject.value && { subject: storedSubject.value })
};

const respBody = await basePostAxios('/qna/chat/answer', bodyContent);

chatStatus.value = CHAT_STATUS.FORMULATED_ANSWER;

chatMessagesList.value.push({ role: 'assistant', content: respBody.data.answer });
saveToStorage('chat', chatMessagesList.value);
storeConversationId(respBody.data.conversation_id);
storeMessageId(respBody.data.message_id);

const newQuestions: AxiosResponse<{ NEW_QUESTIONS: string[] }> = await basePostAxios(
'/qna/reformulate/questions',
{
history: getMessageHistory.value,
query: reformulatedQuery.value
}
);
chatStatus.value = CHAT_STATUS.FORMULATED_ANSWER;

setQuestionQueues(newQuestions?.data['NEW_QUESTIONS']);
}

async function fetchRephrase() {
chatStatus.value = CHAT_STATUS.FORMULATING_ANSWER;
// get the content of the message which the role is assistant
Expand All @@ -127,39 +279,12 @@ export const useChatStore = defineStore('chat', () => {
chatStatus.value = CHAT_STATUS.DONE;
}

async function getAgentAnswer(userMsg: string) {
const { sdgFilters, sourcesFilters: selectedCorpus } = useFiltersStore();
const body = {
query: userMsg,
threadId: storedConversationId.value,
corpora: selectedCorpus,
sdg_filter: sdgFilters
};

const { data } = await basePostAxios('/qna/chat/agent', body);

chatMessagesList.value.push({ role: 'assistant', content: data.content });
if (data.docs) {
sourcesList.value = data.docs;
}

storeConversationId(data.conversation_id);
storeMessageId(data.message_id);

chatStatus.value = CHAT_STATUS.FORMULATED_ANSWER;
}

async function getNewQuestions(userMsg: string) {
const newQuestions: AxiosResponse<{ NEW_QUESTIONS: string[] }> = await basePostAxios(
'/qna/reformulate/questions',
{
history: getMessageHistory.value,
query: userMsg
}
);

setQuestionQueues(newQuestions?.data['NEW_QUESTIONS']);
}
const noResultsAnswer = () => {
chatMessagesList.value.push({
role: 'assistant',
content: i18n.global.t('chatNoResults')
});
};

async function onSendMessage(message: string): Promise<void> {
// checks if can be sent
Expand All @@ -170,18 +295,41 @@ export const useChatStore = defineStore('chat', () => {
return;
}

// adds message to history
addToMessageList({ role: 'user', content: message });

try {
chatStatus.value = CHAT_STATUS.FORMULATING_ANSWER;
await getAgentAnswer(message);
await getNewQuestions(message);
// rephrases question
await reformulateQuestion(message);
if (chatStatus.value === CHAT_STATUS.ERROR) {
chatMessagesList.value.push({
role: 'assistant',
content: i18n.global.t('chatProvideValidQuestion')
});
}

// gets documents
await fetchSources();
} catch (error) {
chatStatus.value = CHAT_STATUS.ERROR;
console.error(error);
chatStatus.value = CHAT_STATUS.ERROR;
}
if (chatStatus.value === CHAT_STATUS.NO_RESULTS) {
noResultsAnswer();
return;
}

if (sourcesList.value?.length) {
try {
// gets new questions & answer
// await fetchChatAnswer(message);
await getNoStreamAnswer(message);
} catch (error) {
chatStatus.value = CHAT_STATUS.ERROR;
console.error(error);
return;
}
}
chatStatus.value = CHAT_STATUS.DONE;
}

Expand Down Expand Up @@ -211,6 +359,7 @@ export const useChatStore = defineStore('chat', () => {
chatMessagesList,
questionQueues,
sourcesList,
corpora,
reformulatedQuery,
onSendMessage,
fetchRephrase,
Expand Down
Loading