-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathVectorTextIndex.cs
More file actions
155 lines (132 loc) · 6.02 KB
/
VectorTextIndex.cs
File metadata and controls
155 lines (132 loc) · 6.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.TypeChat.Embeddings;
namespace Microsoft.TypeChat;
/// <summary>
/// VectorTextIndex is an in-memory vector text index that automatically vectorizes given items using an embedding model.
/// All embeddings are normalized automatically for performance; nearest-match lookups compare embeddings
/// using <see cref="EmbeddingDistance.Dot"/>.
/// Each item T has an associated text description. It is this description that is indexed using embeddings.
///
/// The VectorTextIndex is also a TextRequestRouter that uses embeddings to route text requests.
/// </summary>
/// <typeparam name="T">type of the items stored in the index</typeparam>
public class VectorTextIndex<T> : ITextRequestRouter<T>
{
    private readonly TextEmbeddingModel _model;
    private readonly VectorizedList<T> _list;

    /// <summary>
    /// Create a new VectorTextIndex backed by a new, empty <see cref="VectorizedList{T}"/>
    /// </summary>
    /// <param name="model">embedding model used to vectorize text</param>
    public VectorTextIndex(TextEmbeddingModel model)
        : this(model, new VectorizedList<T>())
    {
    }

    /// <summary>
    /// Create a new VectorTextIndex over the supplied vector list
    /// </summary>
    /// <param name="model">embedding model used to vectorize text; must not be null</param>
    /// <param name="list">vector list that stores the items and their embeddings; must not be null</param>
    public VectorTextIndex(TextEmbeddingModel model, VectorizedList<T> list)
    {
        ArgumentVerify.ThrowIfNull(model, nameof(model));
        ArgumentVerify.ThrowIfNull(list, nameof(list));

        _model = model;
        _list = list;
    }

    /// <summary>
    /// Items in this index
    /// </summary>
    public VectorizedList<T> Items => _list;

    /// <summary>
    /// Route the given request to the semantically nearest T.
    /// Does so by comparing the embedding of request to that of all registered T.
    /// </summary>
    /// <param name="request">request text; must not be null or empty</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>the item whose embedding is nearest to that of the request</returns>
    public Task<T> RouteRequestAsync(string request, CancellationToken cancelToken = default)
    {
        return NearestAsync(request, cancelToken);
    }

    /// <summary>
    /// Add an item to the collection. Its associated text representation will be vectorized into an embedding
    /// </summary>
    /// <param name="item">item to add to the index</param>
    /// <param name="textRepresentation">The text representation of the item; it is transformed into an embedding. Must not be null or empty</param>
    /// <param name="cancelToken">optional cancel token</param>
    public async Task AddAsync(T item, string textRepresentation, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNullOrEmpty(textRepresentation, nameof(textRepresentation));

        var embedding = await GetNormalizedEmbeddingAsync(textRepresentation, cancelToken).ConfigureAwait(false);
        _list.Add(item, embedding);
    }

    /// <summary>
    /// Add multiple items to the collection.
    /// If the associated embedding model supports batching, this can be much faster.
    /// Nothing is added to the index unless embeddings were obtained for every item.
    /// </summary>
    /// <param name="items">items to add to the collection</param>
    /// <param name="textRepresentations">the text representations of these items, one per item and in the same order</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <exception cref="ArgumentException">
    /// If the two arrays differ in length, or if any text representation is null or empty
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// If the model returned a different number of embeddings than the number of items
    /// </exception>
    public async Task AddAsync(T[] items, string[] textRepresentations, CancellationToken cancelToken = default)
    {
        ArgumentVerify.ThrowIfNull(items, nameof(items));
        ArgumentVerify.ThrowIfNull(textRepresentations, nameof(textRepresentations));
        if (items.Length != textRepresentations.Length)
        {
            throw new ArgumentException("items and their representations must of the same length");
        }

        // Nothing to index: avoid a pointless round trip to the embedding model.
        if (items.Length == 0)
        {
            return;
        }

        // Apply the same rule as the single-item overload, and do it before calling the model
        // so that a bad entry fails fast instead of after the embeddings were generated.
        for (int i = 0; i < textRepresentations.Length; ++i)
        {
            if (string.IsNullOrEmpty(textRepresentations[i]))
            {
                throw new ArgumentException(
                    $"Text representation at index {i} is null or empty",
                    nameof(textRepresentations));
            }
        }

        Embedding[] embeddings = await GetNormalizedEmbeddingAsync(textRepresentations, cancelToken).ConfigureAwait(false);
        if (embeddings.Length != items.Length)
        {
            throw new InvalidOperationException($"Embedding length {embeddings.Length} does not match items length {items.Length}");
        }

        // Only mutate the list once every embedding is available.
        for (int i = 0; i < items.Length; ++i)
        {
            _list.Add(items[i], embeddings[i]);
        }
    }

    /// <summary>
    /// Find nearest match to the given text
    /// </summary>
    /// <param name="text">text to search for; must not be null or empty</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>nearest item</returns>
    public async Task<T> NearestAsync(string text, CancellationToken cancelToken = default)
    {
        // Same guard as AddAsync: reject unusable text before calling the embedding model.
        ArgumentVerify.ThrowIfNullOrEmpty(text, nameof(text));

        var embedding = await GetNormalizedEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        return _list.Nearest(embedding, EmbeddingDistance.Dot);
    }

    /// <summary>
    /// Return the top N items from the collection closest to the given text
    /// </summary>
    /// <param name="text">text to search for; must not be null or empty</param>
    /// <param name="maxMatches">max matches</param>
    /// <param name="cancelToken">optional cancel token</param>
    /// <returns>list of matches</returns>
    public async Task<List<T>> NearestAsync(string text, int maxMatches, CancellationToken cancelToken = default)
    {
        // Same guard as AddAsync: reject unusable text before calling the embedding model.
        ArgumentVerify.ThrowIfNullOrEmpty(text, nameof(text));

        var embedding = await GetNormalizedEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        return _list.Nearest(embedding, maxMatches, EmbeddingDistance.Dot).ToList();
    }

    /// <summary>
    /// Generate an embedding for a single text using the model, then normalize it in place
    /// </summary>
    private async Task<Embedding> GetNormalizedEmbeddingAsync(string text, CancellationToken cancelToken)
    {
        var vector = await _model.GenerateEmbeddingAsync(text, cancelToken).ConfigureAwait(false);
        var embedding = new Embedding(vector);
        embedding.NormalizeInPlace();
        return embedding;
    }

    /// <summary>
    /// Generate embeddings for a batch of texts with a single model call, then normalize each one in place.
    /// The returned array has one entry per vector returned by the model, in the order the model returned them.
    /// </summary>
    private async Task<Embedding[]> GetNormalizedEmbeddingAsync(string[] texts, CancellationToken cancelToken)
    {
        var vectors = await _model.GenerateEmbeddingsAsync(texts, cancelToken).ConfigureAwait(false);
        var embeddings = new Embedding[vectors.Length];
        for (int i = 0; i < vectors.Length; ++i)
        {
            embeddings[i] = new Embedding(vectors[i]);
            // Normalize through the array slot so the stored element itself is modified.
            embeddings[i].NormalizeInPlace();
        }
        return embeddings;
    }
}