Skip to content

Commit 84644ce

Browse files
authored
Fix and add regression tests for issue #2724 related to Vector Index corruption (#2737)
1 parent 68b7e25 commit 84644ce

3 files changed

Lines changed: 163 additions & 5 deletions

File tree

LiteDB.Tests/Engine/DropCollection_Tests.cs

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
using System.IO;
44
using System.Linq;
55
using System.Reflection;
6+
using System.Threading.Tasks;
67
using FluentAssertions;
78
using LiteDB;
89
using LiteDB.Engine;
@@ -212,6 +213,56 @@ public void DropCollection_WithVectorIndex_ReclaimsTrackedPages()
212213
}
213214
}
214215

216+
/// <summary>
/// Regression test for LiteDB issue #2724: dropping a collection that owns a vector index
/// must not throw after its documents have been upserted from parallel workers.
/// </summary>
/// <remarks>
/// https://github.com/litedb-org/LiteDB/issues/2724
/// </remarks>
[Fact]
public void DropCollection_WithVectorIndex_AfterParallelUpserts_DoesNotThrow()
{
    const int documentCount = 64;
    const int updateCount = 256;
    const ushort dimensions = 32;

    // Declared in this order so the database is disposed before the temp file.
    using var file = new TempFile();
    using var db = DatabaseFactory.Create(
        TestDatabaseType.Disk,
        $"Filename={file.Filename};Connection=Shared");

    var collection = db.GetCollection<VectorDocument>("docs");
    var indexOptions = new VectorIndexOptions(dimensions, VectorDistanceMetric.Cosine);

    collection.EnsureIndex(VectorIndexName, x => x.Embedding, indexOptions);

    // Seed documents with ids 1..documentCount.
    var seed = new List<VectorDocument>(documentCount);
    for (var id = 1; id <= documentCount; id++)
    {
        seed.Add(new VectorDocument
        {
            Id = id,
            Embedding = CreateLargeVector(id, dimensions)
        });
    }

    collection.Insert(seed);

    // Every iteration replaces one of the seeded documents, so each id is rewritten
    // several times (updateCount > documentCount).
    Parallel.For(0, updateCount, iteration =>
    {
        var replacement = new VectorDocument
        {
            Id = (iteration % documentCount) + 1,
            Embedding = CreateLargeVector(iteration + 5000, dimensions)
        };

        collection.Upsert(replacement);
    });

    Action dropCollection = () => db.DropCollection("docs");

    dropCollection.Should().NotThrow();
}
265+
215266
private static Dictionary<PageType, int> CountPagesByType(string filename)
216267
{
217268
var counts = new Dictionary<PageType, int>();

LiteDB.Tests/Query/VectorIndex_Tests.cs

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,80 @@ private static float[] CreateVector(Random random, int dimensions)
111111
return vector;
112112
}
113113

114+
/// <summary>
/// Regression test for LiteDB issue #2724: for every neighbor link stored on a vector index
/// node, the neighbor must have that level and must list the node among its own neighbors.
/// </summary>
/// <remarks>
/// https://github.com/litedb-org/LiteDB/issues/2724
/// </remarks>
[Fact]
public void VectorIndex_ShouldMaintainBidirectionalEdges()
{
    const int count = 256;
    const int dimensions = 8;

    using var db = DatabaseFactory.Create();
    var collection = db.GetCollection<VectorDocument>("docs");

    var embeddingExpression = BsonExpression.Create("$.Embedding");
    var indexOptions = new VectorIndexOptions(dimensions, VectorDistanceMetric.Cosine);
    collection.EnsureIndex("embedding_idx", embeddingExpression, indexOptions);

    // Fixed seed so the same vectors are generated on every run.
    var random = new Random(123);

    for (var id = 1; id <= count; id++)
    {
        var document = new VectorDocument
        {
            Id = id,
            Embedding = CreateVector(random, dimensions)
        };

        collection.Insert(document);
    }

    db.Checkpoint();

    var inspected = InspectVectorIndex(db, "docs", (snapshot, collation, metadata) =>
    {
        CountNodes(snapshot, metadata.Root).Should().Be(count);

        // Breadth-first walk over the graph starting at the index root.
        var pending = new Queue<PageAddress>();
        var seen = new HashSet<PageAddress>();
        pending.Enqueue(metadata.Root);

        while (pending.Count > 0)
        {
            var current = pending.Dequeue();
            if (!seen.Add(current))
            {
                // Already checked this node via another edge.
                continue;
            }

            var currentNode = snapshot.GetPage<VectorIndexPage>(current.PageID).GetNode(current.Index);

            for (var level = 0; level < currentNode.LevelCount; level++)
            {
                foreach (var neighbor in currentNode.GetNeighbors(level))
                {
                    if (!neighbor.IsEmpty)
                    {
                        pending.Enqueue(neighbor);

                        var neighborNode = snapshot.GetPage<VectorIndexPage>(neighbor.PageID).GetNode(neighbor.Index);

                        // The neighbor must exist on this level and point back at the current node.
                        (level < neighborNode.LevelCount).Should().BeTrue();
                        neighborNode.GetNeighbors(level).Should().Contain(current);
                    }
                }
            }
        }

        return true;
    });

    inspected.Should().BeTrue();
}
187+
114188
private static float[] ReadExternalVector(DataService dataService, PageAddress start, int dimensions, out int blocksRead)
115189
{
116190
var totalBytes = dimensions * sizeof(float);

LiteDB/Engine/Services/VectorIndexService.cs

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -293,11 +293,16 @@ private void Insert(VectorIndexMetadata metadata, PageAddress dataBlock, float[]
293293
candidates.Where(x => x.Address != newAddress).ToList(),
294294
VectorIndexNode.MaxNeighborsPerLevel);
295295

296-
node.SetNeighbors(level, selected.Select(x => x.Address).ToList());
296+
var selectedAddresses = selected.Select(x => x.Address).ToList();
297297

298-
foreach (var neighbor in selected)
298+
node.SetNeighbors(level, selectedAddresses);
299+
300+
foreach (var neighbor in selectedAddresses)
299301
{
300-
this.EnsureBidirectional(metadata, neighbor.Address, newAddress, level, vectorCache);
302+
if (!this.EnsureBidirectional(metadata, neighbor, newAddress, level, vectorCache))
303+
{
304+
node.RemoveNeighbor(level, neighbor);
305+
}
301306
}
302307

303308
if (selected.Count > 0)
@@ -420,10 +425,11 @@ private List<NodeDistance> SearchLayer(
420425
return this.SelectNeighbors(results, Math.Max(1, maxResults));
421426
}
422427

423-
private void EnsureBidirectional(VectorIndexMetadata metadata, PageAddress source, PageAddress target, int level, Dictionary<PageAddress, float[]> vectorCache)
428+
private bool EnsureBidirectional(VectorIndexMetadata metadata, PageAddress source, PageAddress target, int level, Dictionary<PageAddress, float[]> vectorCache)
424429
{
425430
var node = this.GetNode(source);
426-
var neighbors = node.GetNeighbors(level).ToList();
431+
var neighbors = node.GetNeighbors(level).Where(x => !x.IsEmpty).ToList();
432+
var before = neighbors.ToList();
427433

428434
if (!neighbors.Contains(target))
429435
{
@@ -432,6 +438,33 @@ private void EnsureBidirectional(VectorIndexMetadata metadata, PageAddress sourc
432438

433439
var pruned = this.PruneNeighbors(metadata, source, neighbors, vectorCache);
434440
node.SetNeighbors(level, pruned);
441+
442+
foreach (var removed in before)
443+
{
444+
if (!pruned.Contains(removed))
445+
{
446+
this.RemoveBackLink(removed, source, level);
447+
}
448+
}
449+
450+
return pruned.Contains(target);
451+
}
452+
453+
/// <summary>
/// Removes <paramref name="target"/> from the neighbors that the node at
/// <paramref name="source"/> holds on <paramref name="level"/>.
/// Does nothing when either address is empty or the node has no such level.
/// </summary>
private void RemoveBackLink(PageAddress source, PageAddress target, int level)
{
    // Only load the node when both ends of the link are real addresses.
    if (!source.IsEmpty && !target.IsEmpty)
    {
        var owner = this.GetNode(source);
        var levelExists = level >= 0 && level < owner.LevelCount;

        if (levelExists)
        {
            owner.RemoveNeighbor(level, target);
        }
    }
}
436469

437470
private IReadOnlyList<PageAddress> PruneNeighbors(VectorIndexMetadata metadata, PageAddress source, List<PageAddress> neighbors, Dictionary<PageAddress, float[]> vectorCache)

0 commit comments

Comments
 (0)