Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 12 additions & 21 deletions quadtree/maxheap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,18 @@ import "github.com/paulmach/orb"
// the furthest point from the query point in the list, hence maxHeap.
// When we find a point closer than the furthest away, we remove
// furthest and add the new point to the heap.
type maxHeap []*heapItem
type maxHeap []heapItem

type heapItem struct {
point orb.Pointer
distance float64
}

func (h *maxHeap) Push(point orb.Pointer, distance float64) {
// Common usage is Push followed by a Pop if we have > k points.
// We're reusing the k+1 heapItem object to reduce memory allocations.
// First we manaully lengthen the slice,
// then we see if the last item has been allocated already.

prevLen := len(*h)
*h = (*h)[:prevLen+1]
if (*h)[prevLen] == nil {
(*h)[prevLen] = &heapItem{point: point, distance: distance}
} else {
(*h)[prevLen].point = point
(*h)[prevLen].distance = distance
}
(*h)[prevLen].point = point
(*h)[prevLen].distance = distance

i := len(*h) - 1
for i > 0 {
Expand All @@ -53,21 +44,20 @@ func (h *maxHeap) Push(point orb.Pointer, distance float64) {

// Pop returns the "greatest" item in the list.
// The returned item should not be saved across push/pop operations.
func (h *maxHeap) Pop() *heapItem {
removed := (*h)[0]
func (h *maxHeap) Pop() {
lastItem := (*h)[len(*h)-1]
(*h) = (*h)[:len(*h)-1]

mh := (*h)
if len(mh) == 0 {
return removed
return
}

// move the last item to the top and reset the heap
mh[0] = lastItem
mh[0].point = lastItem.point
mh[0].distance = lastItem.distance

i := 0
current := mh[i]
for {
right := (i + 1) << 1
left := right - 1
Expand All @@ -92,11 +82,12 @@ func (h *maxHeap) Pop() *heapItem {
}

// swap the nodes
mh[i] = child
mh[childIndex] = current
mh[i].point = child.point
mh[i].distance = child.distance

mh[childIndex].point = lastItem.point
mh[childIndex].distance = lastItem.distance

i = childIndex
}

return removed
}
6 changes: 4 additions & 2 deletions quadtree/maxheap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ func TestMaxHeap(t *testing.T) {
h.Push(nil, r.Float64())
}

current := h.Pop().distance
current := h[0].distance
h.Pop()
for len(h) > 0 {
next := h.Pop().distance
next := h[0].distance
h.Pop()
if next > current {
t.Errorf("incorrect")
}
Expand Down
4 changes: 3 additions & 1 deletion quadtree/quadtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func (q *Quadtree) add(n *node, p orb.Pointer, point orb.Point, left, right, bot
// Remove will remove the pointer from the quadtree. By default it'll match
// using the points, but a FilterFunc can be provided for a more specific test
// if there are elements with the same point value in the tree. For example:
//
// func(pointer orb.Pointer) {
// return pointer.(*MyType).ID == lookingFor.ID
// }
Expand Down Expand Up @@ -273,7 +274,8 @@ func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f Fil
}

for i := len(v.maxHeap) - 1; i >= 0; i-- {
buf[i] = v.maxHeap.Pop().point
buf[i] = v.maxHeap[0].point
v.maxHeap.Pop()
}

return buf
Expand Down
21 changes: 21 additions & 0 deletions quadtree/quadtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,27 @@ func TestQuadtreeKNearest_sorted(t *testing.T) {
}
}

func TestQuadtreeKNearest_sorted2(t *testing.T) {
q := New(orb.Bound{Max: orb.Point{8, 8}})
q.Add(orb.Point{0, 0})
q.Add(orb.Point{1, 1})
q.Add(orb.Point{2, 2})
q.Add(orb.Point{3, 3})
q.Add(orb.Point{4, 4})
q.Add(orb.Point{5, 5})
q.Add(orb.Point{6, 6})
q.Add(orb.Point{7, 7})

nearest := q.KNearest(nil, orb.Point{5.25, 5.25}, 3)

expected := []orb.Point{{5, 5}, {6, 6}, {4, 4}}
for i, p := range expected {
if n := nearest[i].Point(); !n.Equal(p) {
t.Errorf("incorrect point %d: %v", i, n)
}
}
}

func TestQuadtreeKNearest_DistanceLimit(t *testing.T) {
type dataPointer struct {
orb.Pointer
Expand Down