Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 57 additions & 24 deletions src/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type { AIMessage } from '../types'
import { v4 as uuidv4 } from 'uuid'
import { summarizeMessages } from './llm'

const WINDOW_SIZE = 5

export type MessageWithMetadata = AIMessage & {
id: string
createdAt: string
Expand All @@ -13,6 +15,11 @@ type Data = {
summary: string
}

const defaultData: Data = {
messages: [],
summary: '',
}

export const addMetadata = (message: AIMessage) => {
return {
...message,
Expand All @@ -26,24 +33,39 @@ export const removeMetadata = (message: MessageWithMetadata) => {
return rest
}

const defaultData: Data = {
messages: [],
summary: '',
}

export const getDb = async () => {
const db = await JSONFilePreset<Data>('db.json', defaultData)
return db
return await JSONFilePreset<Data>('db.json', defaultData)
}

export const addMessages = async (messages: AIMessage[]) => {
export const addMessages = async (newMessages: AIMessage[]) => {
const db = await getDb()
db.data.messages.push(...messages.map(addMetadata))
db.data.messages.push(...newMessages.map(addMetadata))

const messages = db.data.messages
const len = messages.length

// We only have a "previous window of size N" to summarize once we have at least 2N messages.
if (len >= 2 * WINDOW_SIZE) {
// Tail = raw messages we will return as-is (N, or N+1 if tool-boundary adjustment kicks in)
const tailStart = computeTailStartIndex(messages, WINDOW_SIZE)

// Summary window is the N messages immediately before the raw tail.
let summaryStart = tailStart - WINDOW_SIZE
let summaryEndExclusive = tailStart

if (db.data.messages.length >= 10) {
const oldestMessages = db.data.messages.slice(0, 5).map(removeMetadata)
const summary = await summarizeMessages(oldestMessages)
db.data.summary = summary
// If the summary window starts with a tool response, shift start left by 1
// so we don't begin a summarized chunk with a tool message detached from its context.
if (summaryStart > 0 && messages[summaryStart]?.role === 'tool') {
summaryStart -= 1
}

summaryStart = Math.max(0, summaryStart)

const messagesToSummarize = messages
.slice(summaryStart, summaryEndExclusive)
.map(removeMetadata)

db.data.summary = await summarizeMessages(messagesToSummarize)
}

await db.write()
Expand All @@ -52,17 +74,9 @@ export const addMessages = async (messages: AIMessage[]) => {
export const getMessages = async () => {
const db = await getDb()
const messages = db.data.messages.map(removeMetadata)
const lastFive = messages.slice(-5)

// If first message is a tool response, get one more message before it
if (lastFive[0]?.role === 'tool') {
const sixthMessage = messages[messages.length - 6]
if (sixthMessage) {
return [...[sixthMessage], ...lastFive]
}
}

return lastFive
const tailStart = computeTailStartIndex(messages, WINDOW_SIZE)
return messages.slice(tailStart)
}

export const getSummary = async () => {
Expand All @@ -72,7 +86,7 @@ export const getSummary = async () => {

export const saveToolResponse = async (
toolCallId: string,
toolResponse: string
toolResponse: string,
) => {
return addMessages([
{
Expand All @@ -82,3 +96,22 @@ export const saveToolResponse = async (
},
])
}

function computeTailStartIndex(
messages: AIMessage[],
keepLastN: number,
): number {
const len = messages.length
if (len <= keepLastN) return 0

// Nominally keep the last N raw messages
let tailStart = len - keepLastN

// If the kept raw tail starts with a tool response, shift tailStart left by 1
// so the tool response is not the first raw item (we include the message before it).
if (messages[tailStart]?.role === 'tool') {
tailStart = Math.max(0, tailStart - 1)
}

return tailStart
}