From 739411843076a9ccf0196c568ae4ceaf43a37daa Mon Sep 17 00:00:00 2001 From: wucm667 Date: Sat, 25 Apr 2026 11:22:27 +0800 Subject: [PATCH] feat(knowledgebase): add local backend --- .../local_knowledge_backend.go | 388 ++++++++++++++++++ .../local_knowledge_backend_test.go | 198 +++++++++ knowledgebase/knowledgebase.go | 11 +- knowledgebase/knowledgebase_test.go | 28 +- 4 files changed, 623 insertions(+), 2 deletions(-) create mode 100644 knowledgebase/backend/local_knowledge_backend/local_knowledge_backend.go create mode 100644 knowledgebase/backend/local_knowledge_backend/local_knowledge_backend_test.go diff --git a/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend.go b/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend.go new file mode 100644 index 0000000..b3bd9ca --- /dev/null +++ b/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend.go @@ -0,0 +1,388 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local_knowledge_backend + +import ( + "context" + "errors" + "fmt" + "math" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "unicode" + + _interface "github.com/volcengine/veadk-go/knowledgebase/interface" + "github.com/volcengine/veadk-go/knowledgebase/ktypes" +) + +const ( + DefaultIndex = "local_knowledge_base" + DefaultTopK = 5 +) + +var ( + ErrLocalKnowledgeBackend = errors.New("local knowledge backend error") + ErrInvalidEmbedding = errors.New("invalid embedding response") +) + +type Config struct { + Index string + TopK int + Embedder Embedder +} + +type Embedder interface { + EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) +} + +type LocalKnowledgeBackend struct { + index string + topK int + embedder Embedder + + mu sync.RWMutex + nextID int + entries []entry +} + +type entry struct { + id int + content string + metadata []map[string]any + vector []float32 +} + +type scoredEntry struct { + entry entry + score float64 +} + +func NewLocalKnowledgeBackend(cfg *Config) (_interface.KnowledgeBackend, error) { + if cfg == nil { + cfg = &Config{} + } + index := strings.TrimSpace(cfg.Index) + if index == "" { + index = DefaultIndex + } + topK := cfg.TopK + if topK <= 0 { + topK = DefaultTopK + } + return &LocalKnowledgeBackend{ + index: index, + topK: topK, + embedder: cfg.Embedder, + }, nil +} + +func (l *LocalKnowledgeBackend) Index() string { + return l.index +} + +func (l *LocalKnowledgeBackend) AddFromText(text []string, opts ...map[string]any) error { + contents := make([]string, 0, len(text)) + metadatas := make([][]map[string]any, 0, len(text)) + for _, t := range text { + if strings.TrimSpace(t) == "" { + continue + } + contents = append(contents, t) + metadatas = append(metadatas, metadataWithSource("text", "", opts...)) + } + return l.addEntries(contents, metadatas) +} + +func (l *LocalKnowledgeBackend) AddFromFiles(files []string, opts ...map[string]any) error { + contents := make([]string, 0, len(files)) + metadatas := make([][]map[string]any, 0, len(files)) + for _, file := range files { + data, err := os.ReadFile(file) + if err != nil { + return fmt.Errorf("%w: read file %q: %w", ErrLocalKnowledgeBackend, file, err) + } + if strings.TrimSpace(string(data)) == "" { + continue + } + contents = append(contents, string(data)) + metadatas = append(metadatas, metadataWithSource("file", file, opts...)) + } + return l.addEntries(contents, metadatas) +} + +func (l *LocalKnowledgeBackend) AddFromDirectory(directory string, opts ...map[string]any) error { + info, err := os.Stat(directory) + if err != nil { + return fmt.Errorf("%w: stat directory %q: %w", ErrLocalKnowledgeBackend, directory, err) + } + if !info.IsDir() { + return fmt.Errorf("%w: %q is not a directory", ErrLocalKnowledgeBackend, directory) + } + + files := make([]string, 0) + err = filepath.WalkDir(directory, func(path string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.Type().IsRegular() { + files = append(files, path) + } + return nil + }) + if err != nil { + return fmt.Errorf("%w: walk directory %q: %w", ErrLocalKnowledgeBackend, directory, err) + } + sort.Strings(files) + return l.AddFromFiles(files, opts...) +} + +func (l *LocalKnowledgeBackend) Search(query string, opts ...map[string]any) ([]ktypes.KnowledgeEntry, error) { + if strings.TrimSpace(query) == "" { + return []ktypes.KnowledgeEntry{}, nil + } + + topK := extractIntOpt("topK", l.topK, opts...) + topK = extractIntOpt("top_k", topK, opts...) + if topK <= 0 { + topK = l.topK + } + + l.mu.RLock() + entries := make([]entry, len(l.entries)) + copy(entries, l.entries) + embedder := l.embedder + l.mu.RUnlock() + + if len(entries) == 0 { + return []ktypes.KnowledgeEntry{}, nil + } + + var scored []scoredEntry + if embedder != nil && hasVectors(entries) { + queryVector, err := embedQuery(context.Background(), embedder, query) + if err != nil { + return nil, err + } + scored = scoreByVector(entries, queryVector) + } else { + scored = scoreByText(entries, query) + } + + sort.SliceStable(scored, func(i, j int) bool { + if scored[i].score == scored[j].score { + return scored[i].entry.id < scored[j].entry.id + } + return scored[i].score > scored[j].score + }) + + if topK > len(scored) { + topK = len(scored) + } + + results := make([]ktypes.KnowledgeEntry, 0, topK) + for _, item := range scored[:topK] { + results = append(results, ktypes.KnowledgeEntry{ + Content: item.entry.content, + Metadata: cloneMetadata(item.entry.metadata), + }) + } + return results, nil +} + +func (l *LocalKnowledgeBackend) addEntries(contents []string, metadatas [][]map[string]any) error { + if len(contents) == 0 { + return nil + } + + vectors := make([][]float32, len(contents)) + if l.embedder != nil { + embedded, err := l.embedder.EmbedTexts(context.Background(), contents) + if err != nil { + return fmt.Errorf("%w: embed documents: %w", ErrLocalKnowledgeBackend, err) + } + if len(embedded) != len(contents) { + return fmt.Errorf("%w: got %d embeddings for %d documents", ErrInvalidEmbedding, len(embedded), len(contents)) + } + vectors = embedded + } + + l.mu.Lock() + defer l.mu.Unlock() + for i, content := range contents { + l.entries = append(l.entries, entry{ + id: l.nextID, + content: content, + metadata: cloneMetadata(metadatas[i]), + vector: append([]float32(nil), vectors[i]...), + }) + l.nextID++ + } + return nil +} + +func embedQuery(ctx context.Context, embedder Embedder, query string) ([]float32, error) { + vectors, err := embedder.EmbedTexts(ctx, []string{query}) + if err != nil { + return nil, fmt.Errorf("%w: embed query: %w", ErrLocalKnowledgeBackend, err) + } + if len(vectors) != 1 { + return nil, fmt.Errorf("%w: got invalid query embedding response", ErrInvalidEmbedding) + } + return vectors[0], nil +} + +func hasVectors(entries []entry) bool { + for _, item := range entries { + if len(item.vector) == 0 { + return false + } + } + return true +} + +func scoreByVector(entries []entry, queryVector []float32) []scoredEntry { + scored := make([]scoredEntry, 0, len(entries)) + for _, item := range entries { + scored = append(scored, scoredEntry{ + entry: item, + score: cosineSimilarity(queryVector, item.vector), + }) + } + return scored +} + +func scoreByText(entries []entry, query string) []scoredEntry { + queryTerms := tokenize(query) + scored := make([]scoredEntry, 0, len(entries)) + for _, item := range entries { + score := lexicalScore(query, queryTerms, item.content) + if score > 0 { + scored = append(scored, scoredEntry{ + entry: item, + score: score, + }) + } + } + return scored +} + +func lexicalScore(query string, queryTerms []string, content string) float64 { + lowerContent := strings.ToLower(content) + score := 0.0 + for _, term := range queryTerms { + score += float64(strings.Count(lowerContent, term)) + } + if strings.Contains(lowerContent, strings.ToLower(strings.TrimSpace(query))) { + score += 0.5 + } + return score +} + +func tokenize(text string) []string { + seen := make(map[string]struct{}) + terms := make([]string, 0) + for _, token := range strings.FieldsFunc(strings.ToLower(text), func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + }) { + if token == "" { + continue + } + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + terms = append(terms, token) + } + return terms +} + +func cosineSimilarity(a, b []float32) float64 { + if len(a) == 0 || len(a) != len(b) { + return math.Inf(-1) + } + var dot, normA, normB float64 + for i := range a { + av := float64(a[i]) + bv := float64(b[i]) + dot += av * bv + normA += av * av + normB += bv * bv + } + if normA == 0 || normB == 0 { + return math.Inf(-1) + } + return dot / (math.Sqrt(normA) * math.Sqrt(normB)) +} + +func metadataWithSource(source, filePath string, opts ...map[string]any) []map[string]any { + metadata := extractMetadata(opts...) + sourceMetadata := map[string]any{"source": source} + if filePath != "" { + sourceMetadata["file_path"] = filePath + } + return append(metadata, sourceMetadata) +} + +func extractMetadata(opts ...map[string]any) []map[string]any { + for _, opt := range opts { + val, ok := opt["metadata"] + if !ok { + continue + } + switch metadata := val.(type) { + case map[string]any: + return []map[string]any{cloneMap(metadata)} + case []map[string]any: + return cloneMetadata(metadata) + } + } + return nil +} + +func extractIntOpt(key string, defaultVal int, opts ...map[string]any) int { + for _, opt := range opts { + if val, ok := opt[key]; ok { + if intVal, ok := val.(int); ok { + return intVal + } + } + } + return defaultVal +} + +func cloneMetadata(metadata []map[string]any) []map[string]any { + if len(metadata) == 0 { + return nil + } + out := make([]map[string]any, 0, len(metadata)) + for _, item := range metadata { + out = append(out, cloneMap(item)) + } + return out +} + +func cloneMap(in map[string]any) map[string]any { + if in == nil { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend_test.go b/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend_test.go new file mode 100644 index 0000000..fb71a4e --- /dev/null +++ b/knowledgebase/backend/local_knowledge_backend/local_knowledge_backend_test.go @@ -0,0 +1,198 @@ +// Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local_knowledge_backend + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewLocalKnowledgeBackendDefaults(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(nil) + assert.Nil(t, err) + assert.NotNil(t, backend) + assert.Equal(t, DefaultIndex, backend.Index()) +} + +func TestLocalKnowledgeBackendAddFromTextAndSearch(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(&Config{Index: "test-index", TopK: 1}) + assert.Nil(t, err) + + err = backend.AddFromText([]string{ + "Banana bread recipe with cinnamon.", + "The Volcengine Agent Development Kit builds agents.", + }, map[string]any{"metadata": map[string]any{"tenant": "test"}}) + assert.Nil(t, err) + + results, err := backend.Search("agent kit") + assert.Nil(t, err) + assert.Len(t, results, 1) + assert.Contains(t, results[0].Content, "Agent Development Kit") + assert.Equal(t, "test", results[0].Metadata[0]["tenant"]) + assert.Equal(t, "text", results[0].Metadata[1]["source"]) +} + +func TestLocalKnowledgeBackendSearchTopK(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(&Config{TopK: 3}) + assert.Nil(t, err) + + err = backend.AddFromText([]string{ + "agent alpha", + "agent beta", + "agent gamma", + }) + assert.Nil(t, err) + + results, err := backend.Search("agent", map[string]any{"top_k": 2}) + assert.Nil(t, err) + assert.Len(t, results, 2) + assert.Equal(t, "agent alpha", results[0].Content) + assert.Equal(t, "agent beta", results[1].Content) +} + +func TestLocalKnowledgeBackendAddFromFiles(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "agent.txt") + err := os.WriteFile(file, []byte("local knowledge file for agent tools"), 0600) + assert.Nil(t, err) + + backend, err := NewLocalKnowledgeBackend(nil) + assert.Nil(t, err) + + err = backend.AddFromFiles([]string{file}) + assert.Nil(t, err) + + results, err := backend.Search("agent tools") + assert.Nil(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "local knowledge file for agent tools", results[0].Content) + assert.Equal(t, "file", results[0].Metadata[0]["source"]) + assert.Equal(t, file, results[0].Metadata[0]["file_path"]) +} + +func TestLocalKnowledgeBackendAddFromDirectory(t *testing.T) { + dir := t.TempDir() + err := os.WriteFile(filepath.Join(dir, "alpha.txt"), []byte("alpha knowledge"), 0600) + assert.Nil(t, err) + nested := filepath.Join(dir, "nested") + err = os.Mkdir(nested, 0700) + assert.Nil(t, err) + err = os.WriteFile(filepath.Join(nested, "beta.txt"), []byte("beta knowledge"), 0600) + assert.Nil(t, err) + + backend, err := NewLocalKnowledgeBackend(nil) + assert.Nil(t, err) + + err = backend.AddFromDirectory(dir) + assert.Nil(t, err) + + results, err := backend.Search("beta") + assert.Nil(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "beta knowledge", results[0].Content) +} + +func TestLocalKnowledgeBackendAddErrors(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(nil) + assert.Nil(t, err) + + err = backend.AddFromFiles([]string{filepath.Join(t.TempDir(), "missing.txt")}) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, ErrLocalKnowledgeBackend)) + + file := filepath.Join(t.TempDir(), "not-dir.txt") + err = os.WriteFile(file, []byte("content"), 0600) + assert.Nil(t, err) + err = backend.AddFromDirectory(file) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, ErrLocalKnowledgeBackend)) +} + +func TestLocalKnowledgeBackendEmptySearch(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(nil) + assert.Nil(t, err) + + results, err := backend.Search(" ") + assert.Nil(t, err) + assert.Empty(t, results) +} + +func TestLocalKnowledgeBackendSearchWithEmbedder(t *testing.T) { + embedder := &mockEmbedder{ + vectors: map[string][]float32{ + "cat document": {1, 0}, + "dog document": {0, 1}, + "bark": {0, 1}, + }, + } + backend, err := NewLocalKnowledgeBackend(&Config{Embedder: embedder}) + assert.Nil(t, err) + + err = backend.AddFromText([]string{"cat document", "dog document"}) + assert.Nil(t, err) + + results, err := backend.Search("bark", map[string]any{"topK": 1}) + assert.Nil(t, err) + assert.Len(t, results, 1) + assert.Equal(t, "dog document", results[0].Content) +} + +func TestLocalKnowledgeBackendEmbedderError(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(&Config{ + Embedder: &mockEmbedder{err: errors.New("embed failed")}, + }) + assert.Nil(t, err) + + err = backend.AddFromText([]string{"agent document"}) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, ErrLocalKnowledgeBackend)) +} + +func TestLocalKnowledgeBackendInvalidEmbeddingResponse(t *testing.T) { + backend, err := NewLocalKnowledgeBackend(&Config{ + Embedder: &mockEmbedder{vectors: map[string][]float32{}}, + }) + assert.Nil(t, err) + + err = backend.AddFromText([]string{"agent document"}) + assert.NotNil(t, err) + assert.True(t, errors.Is(err, ErrInvalidEmbedding)) +} + +type mockEmbedder struct { + vectors map[string][]float32 + err error +} + +func (m *mockEmbedder) EmbedTexts(ctx context.Context, texts []string) ([][]float32, error) { + _ = ctx + if m.err != nil { + return nil, m.err + } + embeddings := make([][]float32, 0, len(texts)) + for _, text := range texts { + vector, ok := m.vectors[text] + if !ok { + continue + } + embeddings = append(embeddings, vector) + } + return embeddings, nil +} diff --git a/knowledgebase/knowledgebase.go b/knowledgebase/knowledgebase.go index 13bf866..d99f984 100644 --- a/knowledgebase/knowledgebase.go +++ b/knowledgebase/knowledgebase.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" + "github.com/volcengine/veadk-go/knowledgebase/backend/local_knowledge_backend" "github.com/volcengine/veadk-go/knowledgebase/backend/viking_knowledge_backend" _interface "github.com/volcengine/veadk-go/knowledgebase/interface" "github.com/volcengine/veadk-go/knowledgebase/ktypes" @@ -42,12 +43,20 @@ type KnowledgeBase struct { func getKnowledgeBackend(backend string, backendConfig any) (_interface.KnowledgeBackend, error) { switch backend { + case ktypes.LocalBackend: + if backendConfig == nil { + return local_knowledge_backend.NewLocalKnowledgeBackend(nil) + } + if config, ok := backendConfig.(*local_knowledge_backend.Config); ok { + return local_knowledge_backend.NewLocalKnowledgeBackend(config) + } + return nil, ErrInvalidKnowledgeBackendConfig case ktypes.VikingBackend: if config, ok := backendConfig.(*viking_knowledge_backend.Config); ok { return viking_knowledge_backend.NewVikingKnowledgeBackend(config) } return nil, ErrInvalidKnowledgeBackendConfig - case ktypes.RedisBackend, ktypes.LocalBackend, ktypes.OpensearchBackend: + case ktypes.RedisBackend, ktypes.OpensearchBackend: return nil, fmt.Errorf("%w: %s", ErrInvalidKnowledgeBackend, backend) default: return nil, fmt.Errorf("%w: %s", ErrInvalidKnowledgeBackend, backend) diff --git a/knowledgebase/knowledgebase_test.go b/knowledgebase/knowledgebase_test.go index eb499f6..7019c38 100644 --- a/knowledgebase/knowledgebase_test.go +++ b/knowledgebase/knowledgebase_test.go @@ -6,6 +6,7 @@ import ( "github.com/bytedance/mockey" "github.com/stretchr/testify/assert" + "github.com/volcengine/veadk-go/knowledgebase/backend/local_knowledge_backend" "github.com/volcengine/veadk-go/knowledgebase/backend/viking_knowledge_backend" _interface "github.com/volcengine/veadk-go/knowledgebase/interface" "github.com/volcengine/veadk-go/knowledgebase/ktypes" @@ -59,6 +60,23 @@ func TestNewKnowledgeBase_WithStringBackendAndValidConfig(t *testing.T) { }) } +func TestNewKnowledgeBase_WithLocalBackend(t *testing.T) { + kb, err := NewKnowledgeBase( + ktypes.LocalBackend, + WithBackendConfig(&local_knowledge_backend.Config{Index: "idx"}), + ) + assert.Nil(t, err) + assert.NotNil(t, kb) + assert.Equal(t, "idx", kb.Backend.Index()) +} + +func TestNewKnowledgeBase_WithLocalBackendDefaultConfig(t *testing.T) { + kb, err := NewKnowledgeBase(ktypes.LocalBackend) + assert.Nil(t, err) + assert.NotNil(t, kb) + assert.Equal(t, local_knowledge_backend.DefaultIndex, kb.Backend.Index()) +} + func TestNewKnowledgeBase_VikingConstructorError(t *testing.T) { mockey.PatchConvey("viking backend constructor returns error", t, func() { mockey.Mock(viking_knowledge_backend.NewVikingKnowledgeBackend).Return(nil, errors.New("ctor error")).Build() @@ -81,11 +99,19 @@ func TestNewKnowledgeBase_InvalidConfigType(t *testing.T) { assert.Nil(t, kb) assert.True(t, errors.Is(err, ErrInvalidKnowledgeBackendConfig)) }) + mockey.PatchConvey("local backend with invalid config type", t, func() { + kb, err := NewKnowledgeBase( + ktypes.LocalBackend, + WithBackendConfig(struct{}{}), + ) + assert.Nil(t, kb) + assert.True(t, errors.Is(err, ErrInvalidKnowledgeBackendConfig)) + }) } func TestGetKnowledgeBackend_Unsupported(t *testing.T) { mockey.PatchConvey("unsupported backend types return wrapped error", t, func() { - for _, b := range []string{ktypes.RedisBackend, ktypes.LocalBackend, ktypes.OpensearchBackend, "unknown"} { + for _, b := range []string{ktypes.RedisBackend, ktypes.OpensearchBackend, "unknown"} { kb, err := getKnowledgeBackend(b, nil) assert.Nil(t, kb) assert.True(t, errors.Is(err, ErrInvalidKnowledgeBackend))