-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmemory_server.py
More file actions
1573 lines (1296 loc) · 59.8 KB
/
memory_server.py
File metadata and controls
1573 lines (1296 loc) · 59.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Minimal Viable Memory System (v0)
Cloud-based vector memory with FastAPI + SentenceTransformers + Chroma
"""
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from typing import List, Optional
import json
import os
import hashlib
from datetime import datetime, timedelta, timezone
import numpy as np
from numpy.linalg import norm
import psycopg2
import psycopg2.extras
import secrets
from contextlib import asynccontextmanager
import logging
# --- Logging setup ---------------------------------------------------------
# Write to /app/logs when running inside the container image (detected via
# the /app directory), otherwise to ./logs for local development.
log_dir = '/app/logs' if os.path.exists('/app') else './logs'
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Duplicate every record to a log file and to stdout/stderr.
        logging.FileHandler(f'{log_dir}/memory_server.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
# Global variables
# Bearer-token scheme; auto_error=False because a token may also arrive as a
# "token" URL query parameter (handled in get_current_user).
security = HTTPBearer(auto_error=False)
def _env_bool(name: str, default: bool) -> bool:
    """Read an environment variable as a boolean flag.

    Accepts the usual truthy spellings ("1", "true", "yes", "on"),
    case-insensitively and with surrounding whitespace ignored; any other
    value is False. Returns *default* when the variable is unset.
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    return normalized in {"1", "true", "yes", "on"}
def _required_env(name: str) -> str:
    """Return the value of a mandatory environment variable.

    Used for the individual DB_* settings when DATABASE_URL is not set.
    Raises RuntimeError for a missing or empty value so misconfiguration
    fails fast at connect time instead of on the first query.
    """
    raw = os.getenv(name)
    if not raw:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return raw
def _csv_env(name: str, default_values: List[str]) -> List[str]:
    """Parse a comma-separated environment variable into a list of strings.

    Surrounding whitespace is stripped from each entry and blank entries are
    dropped. Falls back to *default_values* when the variable is unset or
    empty.
    """
    raw = os.getenv(name)
    if not raw:
        return default_values
    stripped = (piece.strip() for piece in raw.split(","))
    return [entry for entry in stripped if entry]
# Preferred configuration is a single DSN; when absent, the discrete DB_*
# variables below are required instead.
DATABASE_URL = os.getenv("DATABASE_URL")
def get_db_connection():
    """
    Get a PostgreSQL connection from DATABASE_URL or explicit DB_* variables.
    Keeping this in one place prevents accidental credential hardcoding.
    """
    if DATABASE_URL:
        return psycopg2.connect(DATABASE_URL)
    # No DSN: assemble the connection from individual settings.
    # _required_env raises RuntimeError if any mandatory value is missing.
    return psycopg2.connect(
        host=_required_env("DB_HOST"),
        database=_required_env("DB_NAME"),
        user=_required_env("DB_USER"),
        password=_required_env("DB_PASSWORD"),
        port=int(os.getenv("DB_PORT", "5432")),
        # TLS required by default; override DB_SSLMODE for local setups.
        sslmode=os.getenv("DB_SSLMODE", "require"),
    )
# Database initialization
def init_auth_db():
    """Initialize authentication database.

    Idempotent startup routine: creates the auth, waitlist and memory tables
    (plus indexes) if missing, back-fills newer columns on older schemas, and
    optionally bootstraps a first token for local development when
    BOOTSTRAP_DEFAULT_TOKEN is set and no active token exists.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    # One row per issued API token.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS auth_tokens (
            id SERIAL PRIMARY KEY,
            token TEXT UNIQUE NOT NULL,
            user_id TEXT NOT NULL,
            user_name TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            last_used TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            is_active BOOLEAN DEFAULT true,
            is_admin BOOLEAN DEFAULT false
        )
    """)
    # Create indexes for better performance
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_auth_tokens_user_id ON auth_tokens(user_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_auth_tokens_active ON auth_tokens(is_active)")
    # Add is_admin column if it doesn't exist (for existing databases)
    cursor.execute("""
        ALTER TABLE auth_tokens
        ADD COLUMN IF NOT EXISTS is_admin BOOLEAN DEFAULT false
    """)
    # Keep schema in sync for token generation endpoint.
    cursor.execute("""
        ALTER TABLE auth_tokens
        ADD COLUMN IF NOT EXISTS email TEXT
    """)
    # Create waitlist table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS waitlist_emails (
            id SERIAL PRIMARY KEY,
            email TEXT UNIQUE NOT NULL,
            token_used TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (token_used) REFERENCES auth_tokens(token)
        )
    """)
    # Core memory storage tables used by all API endpoints.
    # Requires the pgvector extension for the 384-dim embedding column.
    cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS memories (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            text TEXT NOT NULL,
            tag TEXT NOT NULL,
            embedding vector(384) NOT NULL,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_memories_user_id ON memories(user_id)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_memories_tag ON memories(tag)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_memories_created_at ON memories(created_at)")
    # One short-term "checkpoint" row per user (user_id is UNIQUE).
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS short_term_memories (
            id SERIAL PRIMARY KEY,
            user_id TEXT UNIQUE NOT NULL,
            title TEXT,
            content TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    # Create default token if none exists
    cursor.execute("SELECT COUNT(*) FROM auth_tokens WHERE is_active = true")
    active_token_count = cursor.fetchone()[0]
    should_bootstrap_token = _env_bool("BOOTSTRAP_DEFAULT_TOKEN", False)
    if active_token_count == 0 and should_bootstrap_token:
        default_token = secrets.token_urlsafe(32)
        cursor.execute("""
            INSERT INTO auth_tokens (token, user_id, user_name)
            VALUES (%s, %s, %s)
        """, (default_token, "default_user", "Default User"))
        logger.info("Bootstrapped default token for local setup (written to /tmp/auth_token.txt)")
        # The token is only ever written to a local file, never logged.
        with open("/tmp/auth_token.txt", "w") as f:
            f.write(default_token)
    elif active_token_count == 0:
        logger.info("No active tokens exist yet. Use /api/generate-token to create one.")
    conn.commit()
    cursor.close()
    conn.close()
# Initialize auth database
# Runs at import time, before the FastAPI app is created.
init_auth_db()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan events.

    Code before the yield runs once at startup, code after it once at
    shutdown. Currently only logs; heavy resources (embedding model, DB
    bootstrap) are initialized at import time instead.
    """
    # Startup
    logger.info("🚀 Memory Server starting up...")
    yield
    # Shutdown
    logger.info("🛑 Memory Server shutting down...")
app = FastAPI(title="Memory System v0", lifespan=lifespan)
# Configure CORS
# Defaults cover local dev servers and the production domains; override with
# a comma-separated CORS_ALLOW_ORIGINS environment variable.
origins = _csv_env("CORS_ALLOW_ORIGINS", [
    "http://localhost:5173",  # Vite dev server
    "http://localhost:5174",  # Alternative dev port
    "http://localhost:8081",  # Frontend dev server
    "https://usemindmirror.com",  # Production domain
    "https://www.usemindmirror.com",  # Production with www
    "https://memory.usemindmirror.com",  # Memory subdomain
])
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize embedding model (pre-downloaded during build)
# Embeddings from this model are stored in the memories.embedding
# vector(384) column created by init_auth_db.
logger.info("🔄 Loading pre-downloaded embedding model...")
model = SentenceTransformer('all-MiniLM-L6-v2')
# Create memory hash cache for deduplication
# In-process set of md5(text+tag) fingerprints used for exact-duplicate
# rejection in store_memory.
memory_hashes = set()
# Load existing memory hashes on startup
try:
    conn = get_db_connection()
    cursor = conn.cursor()
    cursor.execute("SELECT metadata->>'hash' as hash FROM memories WHERE metadata->>'hash' IS NOT NULL")
    hashes = cursor.fetchall()
    for (hash_val,) in hashes:
        memory_hashes.add(hash_val)
    cursor.close()
    conn.close()
    logger.info(f"Loaded {len(memory_hashes)} memory hashes for deduplication")
except Exception as e:
    # Best-effort: an empty cache only weakens dedup; startup continues.
    logger.error(f"Error loading memory hashes: {e}")
# Fixed tag set
VALID_TAGS = [
    "goal", "routine", "preference", "constraint",
    "habit", "project", "tool", "identity", "value"
]
# Constants for pruning logic
ARCHIVE_AGE_DAYS = 90  # Archive memories older than 90 days
ACCESS_THRESHOLD_DAYS = 30  # Archive if not accessed in last 30 days
CORE_TAGS = ["identity", "value"]  # Never prune these core memories
SIMILARITY_THRESHOLD = 0.65  # Threshold for conflict detection
# Authentication functions
def get_user_from_token(token: str) -> Optional[str]:
    """Resolve an API token to its owning user_id.

    Looks up an active token in auth_tokens and, on a hit, touches its
    last_used timestamp. Returns the user_id, or None when the token is
    unknown/inactive or a database error occurs (errors are logged, never
    raised, so callers can map None to a 401).
    """
    logger.info("Validating token")
    conn = None
    try:
        conn = get_db_connection()
        cursor = conn.cursor()
        cursor.execute("""
            SELECT user_id FROM auth_tokens
            WHERE token = %s AND is_active = true
        """, (token,))
        result = cursor.fetchone()
        if result:
            # Update last_used timestamp
            cursor.execute("""
                UPDATE auth_tokens
                SET last_used = CURRENT_TIMESTAMP
                WHERE token = %s
            """, (token,))
            conn.commit()
            logger.info(f"Token validated successfully for user: {result[0]}")
        else:
            logger.warning("Token validation failed")
        cursor.close()
        return result[0] if result else None
    except Exception as e:
        logger.error(f"Database error during token validation: {e}")
        return None
    finally:
        # Fix: the previous version leaked the connection whenever an
        # exception fired after connect; close it on every path.
        if conn is not None:
            conn.close()
def get_current_user(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Authenticate a request and return its user_id.

    Two-step gate: first the Host header must be in the allow-list (so the
    memory limits cannot be bypassed via an alternate hostname, unless
    ENFORCE_HOST_CHECK is disabled); then a token is read from the "token"
    query parameter (Zapier-style) or, failing that, from the Authorization
    bearer header, and resolved to a user via get_user_from_token.

    Raises HTTPException 403 for a bad host and 401 for a missing or
    invalid token.
    """
    host = request.headers.get("host", "")
    default_hosts = [
        "memory.usemindmirror.com",
        "localhost:8001",
        "localhost:8000",
        "127.0.0.1:8001",
        "127.0.0.1:8000"
    ]
    allowed_hosts = set(_csv_env("ALLOWED_API_HOSTS", default_hosts))
    if _env_bool("ENFORCE_HOST_CHECK", True) and host not in allowed_hosts:
        logger.warning(f"Memory access denied from unauthorized host: {host}")
        raise HTTPException(
            status_code=403,
            detail=f"Memory access restricted. Please use https://memory.usemindmirror.com"
        )
    # URL parameter takes precedence over the Authorization header.
    token = request.query_params.get("token")
    if token is not None:
        logger.info("Token provided via URL parameter")
    elif credentials:
        token = credentials.credentials
        logger.info("Token provided via Authorization header")
    if not token:
        logger.warning("No authentication token provided")
        raise HTTPException(status_code=401, detail="Authentication token required")
    user_id = get_user_from_token(token)
    if not user_id:
        logger.warning("Authentication failed for provided token")
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    logger.info(f"Request authenticated for user: {user_id}")
    return user_id
class MemoryItem(BaseModel):
    """Incoming payload for POST /memories."""
    text: str
    tag: str  # must be one of VALID_TAGS
    timestamp: Optional[str] = None  # filled in server-side when omitted
    last_accessed: Optional[str] = None  # defaults to storage time when omitted

class TokenGenerationRequest(BaseModel):
    """Request model for generating a new auth token"""
    email: str
    user_name: Optional[str] = None

class TokenGenerationResponse(BaseModel):
    """Response model for token generation"""
    token: str
    user_id: str
    url: str
    memory_limit: int = 25  # free-tier cap (see store_memory)
    memories_used: int = 0

class WaitlistRequest(BaseModel):
    """Request model for joining premium waitlist"""
    email: str

class WaitlistResponse(BaseModel):
    """Response model for waitlist signup"""
    message: str
    email: str

class MemoryResponse(BaseModel):
    """A single memory as returned by search endpoints."""
    id: str
    text: str
    tag: str
    timestamp: str
    similarity: Optional[float] = None  # absent when not ranked by similarity
    last_accessed: Optional[str] = None

class SearchRequest(BaseModel):
    """Request body for POST /memories/search."""
    query: str
    limit: int = 10
    tag_filter: Optional[str] = None  # must be in VALID_TAGS when given

class CheckpointRequest(BaseModel):
    """Request body for saving a short-term memory checkpoint."""
    content: str
    title: Optional[str] = None

class CheckpointResponse(BaseModel):
    """Response after saving a short-term memory checkpoint."""
    status: str
    overwrote: bool  # True when an earlier checkpoint was replaced
    previous_checkpoint_time: Optional[str] = None
    id: int

class ResumeResponse(BaseModel):
    """Response when resuming from a short-term memory checkpoint."""
    exists: bool  # False when the user has no stored checkpoint
    content: Optional[str] = None
    title: Optional[str] = None
    created_at: Optional[str] = None
    id: Optional[int] = None
def _utc_now_iso() -> str:
    """Current UTC time as an ISO-8601 string ending in 'Z'.

    Fix: isoformat() on an aware UTC datetime already ends in '+00:00';
    the previous code appended 'Z' on top, producing malformed stamps
    like '2025-07-01T14:40:19+00:00Z'.
    """
    return datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
@app.post("/memories")
async def store_memory(memory: MemoryItem, user_id: str = Depends(get_current_user)):
    """Store a new memory item.

    Pipeline: validate tag -> exact-hash dedup -> free-tier limit check
    (admins exempt) -> semantic dedup (same tag, cosine > 0.95) ->
    conflict detection (same tag, cosine >= SIMILARITY_THRESHOLD) ->
    insert row with metadata and cross-link conflicting memories.
    """
    logger.info(f"Storing memory for user {user_id}: '{memory.text[:50]}...' (tag: {memory.tag})")
    # Validate tag
    if memory.tag not in VALID_TAGS:
        logger.error(f"Invalid tag '{memory.tag}' provided by user {user_id}")
        raise HTTPException(
            status_code=400,
            detail=f"Invalid tag. Must be one of: {VALID_TAGS}"
        )
    # Generate timestamp if not provided
    if not memory.timestamp:
        memory.timestamp = _utc_now_iso()
    # Exact-duplicate check: md5 of normalized text + tag. MD5 is a dedup
    # fingerprint here, not a security boundary.
    memory_text_normalized = memory.text.strip().lower()
    memory_hash = hashlib.md5(f"{memory_text_normalized}:{memory.tag}".encode()).hexdigest()
    if memory_hash in memory_hashes:
        # Memory already exists, return without storing
        logger.info(f"Duplicate memory detected for user {user_id}: hash {memory_hash[:10]}...")
        return {
            "status": "skipped",
            "reason": "duplicate",
            "text": memory.text,
            "tag": memory.tag
        }
    # Fix: the hash is registered only AFTER a successful insert. The old
    # code added it up front, so a memory rejected by the limit or semantic
    # dedup could never be stored again by this process.
    MEMORY_LIMIT = 25
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        # Check if user is admin (admins are exempt from the memory limit)
        cursor.execute("""
            SELECT is_admin FROM auth_tokens
            WHERE user_id = %s AND is_active = true
            LIMIT 1
        """, (user_id,))
        admin_result = cursor.fetchone()
        is_admin = admin_result[0] if admin_result and admin_result[0] else False
        if not is_admin:
            # Check current memory count
            cursor.execute("""
                SELECT COUNT(*) FROM memories WHERE user_id = %s
            """, (user_id,))
            memory_count = cursor.fetchone()[0]
            if memory_count >= MEMORY_LIMIT:
                logger.info(f"Memory limit reached for user {user_id}: {memory_count}/{MEMORY_LIMIT}")
                return {
                    "error": "Memory limit reached. Upgrade to premium to store more.",
                    "premium_link": "https://usemindmirror.com/premium",
                    "memories_used": memory_count,
                    "memory_limit": MEMORY_LIMIT
                }
        # Generate embedding
        embedding = model.encode(memory.text).tolist()
        # Check for semantic duplicates (similarity > 0.95) before storing
        DUPLICATE_THRESHOLD = 0.95
        cursor.execute("""
            SELECT id, text, COALESCE(1 - (embedding <=> %s::vector), 0.0) as similarity
            FROM memories
            WHERE user_id = %s AND tag = %s
            ORDER BY embedding <=> %s::vector
            LIMIT 3
        """, (embedding, user_id, memory.tag, embedding))
        similar_memories = cursor.fetchall()
        for mem_id, mem_text, similarity in similar_memories:
            if similarity is not None and similarity > DUPLICATE_THRESHOLD:
                logger.info(f"Semantic duplicate detected for user {user_id}: '{memory.text}' too similar to '{mem_text}' (similarity: {similarity:.4f})")
                return {"status": "duplicate", "message": f"Memory too similar to existing memory {mem_id}", "similarity": similarity}
        # Create unique ID
        memory_id = f"mem_{int(datetime.now(timezone.utc).timestamp() * 1000)}"
        # Conflict detection uses the module-level SIMILARITY_THRESHOLD
        # (the old code re-declared a shadowing local with the same value).
        conflicts = []
        logger.info(f"Checking for conflicts for user {user_id} with similarity threshold {SIMILARITY_THRESHOLD}...")
        # Search for memories with same tag for conflict detection
        cursor.execute("""
            SELECT id, text, COALESCE(1 - (embedding <=> %s::vector), 0.0) as similarity
            FROM memories
            WHERE user_id = %s AND tag = %s
            ORDER BY embedding <=> %s::vector
            LIMIT 5
        """, (embedding, user_id, memory.tag, embedding))
        conflict_candidates = cursor.fetchall()
        for conflict_id, conflict_text, similarity in conflict_candidates:
            similarity_str = f"{similarity:.4f}" if similarity is not None else "None"
            logger.info(f"Similarity check for user {user_id} with '{conflict_text}': {similarity_str}")
            if similarity is not None and similarity >= SIMILARITY_THRESHOLD:
                logger.info(f"Conflict detected for user {user_id}: {conflict_id} (similarity: {similarity:.4f})")
                conflicts.append(conflict_id)
        # Set last_accessed timestamp
        current_time = _utc_now_iso()
        last_accessed = memory.last_accessed or current_time
        # Prepare metadata with conflict info if any
        metadata = {
            "tag": memory.tag,
            "timestamp": memory.timestamp,
            "hash": memory_hash,  # Store hash for future reference
            "last_accessed": last_accessed
        }
        if conflicts:
            metadata["has_conflicts"] = True
            metadata["conflict_ids"] = json.dumps(conflicts)
            # Cross-link: each conflicting memory also records this one.
            for conflict_id in conflicts:
                cursor.execute("SELECT metadata FROM memories WHERE id = %s", (conflict_id,))
                result = cursor.fetchone()
                if result:
                    conflict_metadata = result[0]
                    conflict_metadata["has_conflicts"] = True
                    if "conflict_ids" in conflict_metadata:
                        existing_conflicts = json.loads(conflict_metadata["conflict_ids"])
                        if memory_id not in existing_conflicts:
                            existing_conflicts.append(memory_id)
                            conflict_metadata["conflict_ids"] = json.dumps(existing_conflicts)
                    else:
                        conflict_metadata["conflict_ids"] = json.dumps([memory_id])
                    cursor.execute("UPDATE memories SET metadata = %s WHERE id = %s", (json.dumps(conflict_metadata), conflict_id))
        # Store in PostgreSQL
        cursor.execute("""
            INSERT INTO memories (id, user_id, text, tag, embedding, metadata)
            VALUES (%s, %s, %s, %s, %s, %s)
        """, (memory_id, user_id, memory.text, memory.tag, embedding, json.dumps(metadata)))
        conn.commit()
    finally:
        # Fix: release the connection on every path; the old code leaked it
        # whenever an exception fired mid-pipeline.
        conn.close()
    memory_hashes.add(memory_hash)
    logger.info(f"Memory stored successfully for user {user_id}: ID {memory_id}, {len(conflicts)} conflicts detected")
    return {
        "id": memory_id,
        "text": memory.text,
        "tag": memory.tag,
        "timestamp": memory.timestamp,
        "status": "stored"
    }
def keyword_search(query: str, user_id: str, limit: int, tag_filter: Optional[str] = None, exclude_ids: Optional[set] = None):
    """
    Fallback keyword search using PostgreSQL ILIKE for text matching.

    Splits the query into keywords (stop words and tokens of <= 2 chars are
    dropped) and matches memories containing ANY keyword, newest first.
    Results get artificial, decreasing similarity scores (0.7, 0.67, ...) so
    they merge into the semantic ranking below strong semantic matches.

    Args:
        query: free-text search string.
        user_id: restrict results to this user's memories.
        limit: maximum number of rows to return.
        tag_filter: optional tag to restrict results to.
        exclude_ids: memory ids already found (e.g. semantically) to skip.

    Returns:
        List of MemoryResponse objects (possibly empty).
    """
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
    keywords = [word.strip().lower() for word in query.split() if word.strip().lower() not in stop_words and len(word.strip()) > 2]
    if not keywords:
        return []
    # Build ILIKE conditions for each keyword; every user-supplied value is
    # passed through a placeholder — only placeholder scaffolding is
    # interpolated into the SQL string.
    like_conditions = []
    params = [user_id]
    for keyword in keywords:
        like_conditions.append("text ILIKE %s")
        params.append(f"%{keyword}%")
    tag_filter_sql = ""
    if tag_filter:
        tag_filter_sql = "AND tag = %s"
        params.append(tag_filter)
    exclude_filter_sql = ""
    if exclude_ids:
        exclude_placeholders = ','.join(['%s'] * len(exclude_ids))
        exclude_filter_sql = f"AND id NOT IN ({exclude_placeholders})"
        params.extend(exclude_ids)
    params.append(limit)
    conn = get_db_connection()
    try:
        cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        where_clause = f"WHERE user_id = %s AND ({' OR '.join(like_conditions)}) {tag_filter_sql} {exclude_filter_sql}"
        cursor.execute(f"""
            SELECT id, text, tag, metadata, created_at
            FROM memories
            {where_clause}
            ORDER BY created_at DESC
            LIMIT %s
        """, params)
        results = cursor.fetchall()
        cursor.close()
    finally:
        # Fix: close the connection even when the query raises (the old
        # code leaked it on errors).
        conn.close()
    # Convert rows to MemoryResponse objects with artificial scores.
    memories = []
    for i, row in enumerate(results):
        metadata = row['metadata']
        artificial_similarity = 0.7 - (i * 0.03)  # 0.7, 0.67, 0.64, 0.61, etc.
        memories.append(MemoryResponse(
            id=row['id'],
            text=row['text'],
            tag=row['tag'],
            timestamp=metadata['timestamp'],
            last_accessed=metadata.get('last_accessed', metadata['timestamp']),
            similarity=artificial_similarity
        ))
    return memories
@app.post("/memories/search")
async def search_memories(request: SearchRequest, user_id: str = Depends(get_current_user)):
"""Hybrid search: semantic search with keyword fallback"""
logger.info(f"Search request from user {user_id}: query='{request.query}', limit={request.limit}, tag_filter={request.tag_filter}")
# Generate query embedding
query_embedding = model.encode(request.query).tolist()
# Build where clause for tag filtering
tag_filter_sql = ""
params = [query_embedding, user_id, query_embedding, request.limit]
if request.tag_filter:
if request.tag_filter not in VALID_TAGS:
raise HTTPException(
status_code=400,
detail=f"Invalid tag filter. Must be one of: {VALID_TAGS}"
)
tag_filter_sql = "AND tag = %s"
params.insert(2, request.tag_filter)
# Log search parameters
logger.info(f"Search SQL params - user_id: {user_id}, limit: {request.limit}, tag_filter: {request.tag_filter}")
logger.info(f"SQL will use LIMIT: {params[-1]}")
# Search in PostgreSQL
conn = get_db_connection()
cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
cursor.execute(f"""
SELECT id, text, tag, metadata, created_at,
COALESCE(1 - (embedding <=> %s::vector), 0.0) as similarity
FROM memories
WHERE user_id = %s {tag_filter_sql}
ORDER BY embedding <=> %s::vector
LIMIT %s
""", params)
results = cursor.fetchall()
# Log raw search results
total_results = len(results)
logger.info(f"Semantic search results for user {user_id}: found {total_results} memories from database")
logger.info(f"Query was: '{request.query}' with limit {request.limit}")
if total_results > 0:
similarities = [row['similarity'] for row in results[:3]]
logger.info(f"Top 3 similarities: {similarities}")
# Log first result for debugging
logger.info(f"First result: {results[0]['text'][:50]}... (similarity: {results[0]['similarity']})")
# Hybrid search: add keyword fallback if semantic search returned fewer than requested results
if total_results < request.limit:
remaining_slots = request.limit - total_results
semantic_ids = {row['id'] for row in results}
logger.info(f"Semantic search returned {total_results} < {request.limit} requested. Running keyword fallback for {remaining_slots} more results.")
# Run keyword search for remaining slots, excluding already found IDs
keyword_results = keyword_search(
query=request.query,
user_id=user_id,
limit=remaining_slots,
tag_filter=request.tag_filter,
exclude_ids=semantic_ids
)
if keyword_results:
logger.info(f"Keyword fallback found {len(keyword_results)} additional memories")
# Convert MemoryResponse objects back to dict format to match semantic results
for memory in keyword_results:
result_dict = {
'id': memory.id,
'text': memory.text,
'tag': memory.tag,
'similarity': memory.similarity,
'metadata': {
'timestamp': memory.timestamp,
'last_accessed': memory.last_accessed,
'tag': memory.tag
},
'created_at': memory.timestamp # Use timestamp as created_at
}
results.append(result_dict)
else:
logger.info("Keyword fallback found no additional memories")
# Update total count after hybrid search
total_hybrid_results = len(results)
logger.info(f"Total hybrid search results: {total_hybrid_results} memories (semantic: {total_results}, keyword: {total_hybrid_results - total_results})")
# Sort all results by similarity and timestamp (composite sort for recency tiebreaking)
if results:
results.sort(key=lambda x: (x.get('similarity', 0.0), x.get('created_at', '')), reverse=True)
top_created = results[0].get('created_at', 'unknown')
if hasattr(top_created, 'strftime'):
top_created_str = top_created.strftime('%Y-%m-%d')
elif isinstance(top_created, str):
top_created_str = top_created[:10]
else:
top_created_str = str(top_created)[:10]
logger.info(f"Sorted hybrid results by similarity+timestamp - top result: similarity={results[0].get('similarity', 0.0):.3f}, created={top_created_str}")
# Format response
memories = []
conflict_sets = {}
if results:
# First pass - create all memories
for row in results:
metadata = row['metadata']
memory_id = row['id']
# Update last_accessed timestamp
current_time = datetime.now(timezone.utc).isoformat() + "Z"
metadata["last_accessed"] = current_time
cursor.execute("UPDATE memories SET metadata = %s WHERE id = %s", (json.dumps(metadata), memory_id))
memory = MemoryResponse(
id=memory_id,
text=row['text'],
tag=row['tag'],
timestamp=metadata['timestamp'],
last_accessed=current_time,
similarity=row['similarity']
)
memories.append(memory)
# Check for conflicts and build conflict sets
if "has_conflicts" in metadata and metadata["has_conflicts"]:
# This memory has conflicts, retrieve the conflict IDs
conflict_ids = json.loads(metadata.get("conflict_ids", "[]"))
# Create or update conflict set for this memory
if memory_id not in conflict_sets:
conflict_sets[memory_id] = [dict(memory)]
# Fetch and add conflicting memories if not already in results
for conflict_id in conflict_ids:
if conflict_id not in [m['id'] if isinstance(m, dict) else m.id for m in memories]:
cursor.execute("SELECT id, text, tag, metadata FROM memories WHERE id = %s", (conflict_id,))
conflict_result = cursor.fetchone()
if conflict_result:
conflict_metadata = conflict_result['metadata']
conflict_memory = MemoryResponse(
id=conflict_id,
text=conflict_result['text'],
tag=conflict_result['tag'],
timestamp=conflict_metadata['timestamp'],
last_accessed=conflict_metadata.get('last_accessed', conflict_metadata['timestamp'])
)
# Add to conflict set
conflict_sets[memory_id].append(dict(conflict_memory))
# Apply semantic deduplication to conflict sets
for memory_id, conflicts in conflict_sets.items():
if len(conflicts) > 1:
# Extract embeddings for similarity calculation
unique_conflicts = []
for conflict in conflicts:
# Check if this conflict is semantically similar to any existing unique conflict
is_duplicate = False
conflict_embedding = model.encode([conflict['text']])
for unique_conflict in unique_conflicts:
unique_embedding = model.encode([unique_conflict['text']])
# Calculate cosine similarity between embeddings
# Flatten embeddings for calculation
emb1 = conflict_embedding[0]
emb2 = unique_embedding[0]
# Calculate cosine similarity
try:
similarity = np.dot(emb1, emb2) / (norm(emb1) * norm(emb2))
except (ZeroDivisionError, ValueError):
similarity = 0.0
# If similarity > 0.95, it's a duplicate
if similarity is not None and similarity > 0.95:
is_duplicate = True
# Keep the more recent one
if conflict['timestamp'] > unique_conflict['timestamp']:
unique_conflicts.remove(unique_conflict)
unique_conflicts.append(conflict)
break
# If not a duplicate, add to unique conflicts
if not is_duplicate:
unique_conflicts.append(conflict)
# Sort unique conflicts by timestamp (most recent first)
unique_conflicts.sort(key=lambda x: x['timestamp'], reverse=True)
# Update conflict set with deduplicated conflicts
conflict_sets[memory_id] = unique_conflicts
# Close database connection
conn.commit()
cursor.close()
conn.close()
# Build initial response
response = {
"query": request.query,
"results": memories,
"count": len(memories)
}
# Server-side conflict message formatting will be added after conflict groups are built
# Group overlapping conflict sets using Union-Find for transitive merging
class UnionFind:
def __init__(self):
self.parent = {}
self.rank = {}
def find(self, x):
if x not in self.parent:
self.parent[x] = x
self.rank[x] = 0
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x]) # Path compression
return self.parent[x]
def union(self, x, y):
px, py = self.find(x), self.find(y)
if px == py:
return
# Union by rank for efficiency
if self.rank[px] < self.rank[py]:
px, py = py, px
self.parent[py] = px
if self.rank[px] == self.rank[py]:
self.rank[px] += 1
def get_groups(self):
groups = {}
for x in self.parent:
root = self.find(x)
if root not in groups:
groups[root] = []
groups[root].append(x)
return list(groups.values())
# Build conflict groups using Union-Find
if conflict_sets:
uf = UnionFind()
# Union memories that appear in the same conflict set
for memory_id, conflicts in conflict_sets.items():
memory_ids_in_set = {memory_id}
for conflict in conflicts:
memory_ids_in_set.add(conflict['id'])
# Union all memories in this conflict set
memory_list = list(memory_ids_in_set)
for i in range(len(memory_list)):
for j in range(i + 1, len(memory_list)):
uf.union(memory_list[i], memory_list[j])
# Get unified conflict groups
memory_groups = uf.get_groups()
# Filter out singleton groups (groups with only 1 memory = no real conflicts)
meaningful_groups = [group for group in memory_groups if len(group) >= 2]
# Build final conflict groups with memory objects
conflict_groups = []
for group in meaningful_groups:
group_memories = []
for memory_id in group:
# Find memory object from original results or conflict sets
memory_obj = None
# First check if it's in the main results
for memory in memories:
if (isinstance(memory, dict) and memory['id'] == memory_id) or \
(hasattr(memory, 'id') and memory.id == memory_id):
memory_obj = dict(memory) if not isinstance(memory, dict) else memory
break
# If not found, look in conflict sets
if memory_obj is None:
for conflicts in conflict_sets.values():
for conflict in conflicts:
if conflict['id'] == memory_id:
memory_obj = conflict
break
if memory_obj:
break
if memory_obj:
group_memories.append(memory_obj)
# Sort group by timestamp (most recent first) and add to conflict groups
group_memories.sort(key=lambda x: x['timestamp'], reverse=True)
conflict_groups.append(group_memories)
# Replace individual conflict_sets with unified conflict_groups
if conflict_groups:
response["conflict_groups"] = conflict_groups
logger.info(f"Unified conflict groups for user {user_id}: {len(conflict_groups)} groups created from {len(conflict_sets)} individual sets")
for i, group in enumerate(conflict_groups):
logger.info(f"- Group {i+1}: {len(group)} memories ({', '.join([m['text'][:30] + '...' for m in group])})")
else:
# Keep original conflict_sets if no meaningful groups found
response["conflict_sets"] = conflict_sets
else:
# No conflicts detected, keep original structure
response["conflict_sets"] = conflict_sets
# Debug: log conflict detection information
logger.info(f"Debug for user {user_id}: ChromaDB returned {total_results} memories, built {len(memories)} responses")
# Log memory content details
if memories:
logger.info(f"Memory details for user {user_id}:")
for memory in memories:
memory_dict = memory if isinstance(memory, dict) else dict(memory)
text_snippet = memory_dict['text'][:50] + "..." if len(memory_dict['text']) > 50 else memory_dict['text']
similarity = memory_dict.get('similarity', 'N/A')
# Extract short dates from timestamps
created_date = "unknown"
accessed_date = "unknown"
if 'timestamp' in memory_dict and memory_dict['timestamp']:
try:
# Convert "2025-07-01T14:40:19.175601Z" to "07-01"
created_date = memory_dict['timestamp'].split('T')[0][5:] # Take YYYY-MM-DD, extract MM-DD
except:
created_date = "unknown"
if 'last_accessed' in memory_dict and memory_dict['last_accessed']:
try:
# Convert "2025-07-03T14:37:18.063049Z" to "07-03"
accessed_date = memory_dict['last_accessed'].split('T')[0][5:] # Take YYYY-MM-DD, extract MM-DD
except:
accessed_date = "unknown"
logger.info(f"- {memory_dict['id']}: \"{text_snippet}\" ({memory_dict['tag']}, sim: {similarity:.3f}, ts: {created_date}, accessed: {accessed_date})")
else:
logger.info(f"No memories returned for user {user_id}")
logger.info(f"Debug for user {user_id}: Identified {len(conflict_sets)} conflict sets")
# Add conflict sets if any exist
if conflict_sets:
# Clean conflict summary logging
conflict_summaries = []
for memory_id, conflicts in conflict_sets.items():
conflict_count = len(conflicts) - 1 # Subtract 1 because it includes the original memory
if conflicts:
main_memory = conflicts[0]
text_snippet = main_memory['text'][:50] + "..." if len(main_memory['text']) > 50 else main_memory['text']
conflict_summaries.append(f"{memory_id}: '{text_snippet}' ({conflict_count} conflicts)")
logger.info(f"Conflict sets for user {user_id}: {len(conflict_sets)} sets detected")
for summary in conflict_summaries:
logger.info(f"- {summary}")
else:
# Manually check for conflicts among the returned results
logger.info(f"Debug for user {user_id}: No conflict sets detected, performing manual conflict check")
manual_conflict_sets = {}
# Check each memory for conflicts
conn2 = get_db_connection()
cursor2 = conn2.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
for memory in memories:
memory_id = memory['id'] if isinstance(memory, dict) else memory.id