KNN results for NEARBY command
This commit adds the ability to search for the k nearest neighbors using the NEARBY command. When the LIMIT keyword is included and the 'meters' param is excluded, the KNN algorithm is used instead of the standard overlap+haversine algorithm. For example, `NEARBY fleet LIMIT 10 POINT 33.5 -115.8` finds the 10 closest points to 33.5,-115.8. Closes #136, #130, and #138. Ping @tomquas, @joernroeder, and @m1ome
This commit is contained in:
parent
49e1fcce7a
commit
04290ec535
@ -536,3 +536,17 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object,
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Collection) NearestNeighbors(k int, lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) {
|
||||||
|
c.index.NearestNeighbors(k, lat, lon, func(item index.Item) bool {
|
||||||
|
var iitm *itemT
|
||||||
|
iitm, ok := item.(*itemT)
|
||||||
|
if !ok {
|
||||||
|
return true // just ignore
|
||||||
|
}
|
||||||
|
if !iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -22,6 +22,7 @@ type liveFenceSwitches struct {
|
|||||||
maxLat, maxLon float64
|
maxLat, maxLon float64
|
||||||
cmd string
|
cmd string
|
||||||
roam roamSwitches
|
roam roamSwitches
|
||||||
|
knn bool
|
||||||
groups map[string]string
|
groups map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,10 +86,22 @@ func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string)
|
|||||||
err = errInvalidNumberOfArguments
|
err = errInvalidNumberOfArguments
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
umeters := true
|
||||||
if vs, smeters, ok = tokenval(vs); !ok || smeters == "" {
|
if vs, smeters, ok = tokenval(vs); !ok || smeters == "" {
|
||||||
err = errInvalidNumberOfArguments
|
umeters = false
|
||||||
return
|
if cmd == "nearby" {
|
||||||
|
// possible that this is KNN search
|
||||||
|
s.knn = s.searchScanBaseTokens.ulimit && // must be true
|
||||||
|
!s.searchScanBaseTokens.usparse && // must be false
|
||||||
|
s.searchScanBaseTokens.cursor == 0 // must be zero
|
||||||
|
}
|
||||||
|
if !s.knn {
|
||||||
|
err = errInvalidArgument(slat)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.lat, err = strconv.ParseFloat(slat, 64); err != nil {
|
if s.lat, err = strconv.ParseFloat(slat, 64); err != nil {
|
||||||
err = errInvalidArgument(slat)
|
err = errInvalidArgument(slat)
|
||||||
return
|
return
|
||||||
@ -97,9 +110,12 @@ func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string)
|
|||||||
err = errInvalidArgument(slon)
|
err = errInvalidArgument(slon)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.meters, err = strconv.ParseFloat(smeters, 64); err != nil {
|
|
||||||
err = errInvalidArgument(smeters)
|
if umeters {
|
||||||
return
|
if s.meters, err = strconv.ParseFloat(smeters, 64); err != nil {
|
||||||
|
err = errInvalidArgument(smeters)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case "object":
|
case "object":
|
||||||
var obj string
|
var obj string
|
||||||
@ -290,7 +306,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) {
|
|||||||
}
|
}
|
||||||
sw.writeHead()
|
sw.writeHead()
|
||||||
if sw.col != nil {
|
if sw.col != nil {
|
||||||
s.cursor = sw.col.Nearby(s.cursor, s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, func(id string, o geojson.Object, fields []float64) bool {
|
iter := func(id string, o geojson.Object, fields []float64) bool {
|
||||||
// Calculate distance if we need to
|
// Calculate distance if we need to
|
||||||
distance := 0.0
|
distance := 0.0
|
||||||
if s.distance {
|
if s.distance {
|
||||||
@ -303,7 +319,12 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) {
|
|||||||
fields: fields,
|
fields: fields,
|
||||||
distance: distance,
|
distance: distance,
|
||||||
})
|
})
|
||||||
})
|
}
|
||||||
|
if s.knn {
|
||||||
|
sw.col.NearestNeighbors(int(s.limit), s.lat, s.lon, iter)
|
||||||
|
} else {
|
||||||
|
s.cursor = sw.col.Nearby(s.cursor, s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, iter)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sw.writeFoot(s.cursor)
|
sw.writeFoot(s.cursor)
|
||||||
if msg.OutputType == server.JSON {
|
if msg.OutputType == server.JSON {
|
||||||
|
@ -169,7 +169,9 @@ type searchScanBaseTokens struct {
|
|||||||
glob string
|
glob string
|
||||||
wheres []whereT
|
wheres []whereT
|
||||||
nofields bool
|
nofields bool
|
||||||
|
ulimit bool
|
||||||
limit uint64
|
limit uint64
|
||||||
|
usparse bool
|
||||||
sparse uint8
|
sparse uint8
|
||||||
desc bool
|
desc bool
|
||||||
}
|
}
|
||||||
@ -465,12 +467,14 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if slimit != "" {
|
if slimit != "" {
|
||||||
|
t.ulimit = true
|
||||||
if t.limit, err = strconv.ParseUint(slimit, 10, 64); err != nil || t.limit == 0 {
|
if t.limit, err = strconv.ParseUint(slimit, 10, 64); err != nil || t.limit == 0 {
|
||||||
err = errInvalidArgument(slimit)
|
err = errInvalidArgument(slimit)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ssparse != "" {
|
if ssparse != "" {
|
||||||
|
t.usparse = true
|
||||||
var sparse uint64
|
var sparse uint64
|
||||||
if sparse, err = strconv.ParseUint(ssparse, 10, 8); err != nil || sparse == 0 || sparse > 8 {
|
if sparse, err = strconv.ParseUint(ssparse, 10, 8); err != nil || sparse == 0 || sparse > 8 {
|
||||||
err = errInvalidArgument(ssparse)
|
err = errInvalidArgument(ssparse)
|
||||||
|
@ -123,6 +123,20 @@ func (ix *Index) getRTreeItem(item rtree.Item) Item {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ix *Index) NearestNeighbors(k int, lat, lon float64, iterator func(item Item) bool) {
|
||||||
|
x, y, _ := normPoint(lat, lon)
|
||||||
|
items := ix.r.NearestNeighbors(k, x, y, 0)
|
||||||
|
for _, item := range items {
|
||||||
|
iitm := ix.getRTreeItem(item)
|
||||||
|
if item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !iterator(iitm) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Search returns all items that intersect the bounding box.
|
// Search returns all items that intersect the bounding box.
|
||||||
func (ix *Index) Search(cursor uint64, swLat, swLon, neLat, neLon, minZ, maxZ float64, iterator func(item Item) bool) (ncursor uint64) {
|
func (ix *Index) Search(cursor uint64, swLat, swLon, neLat, neLon, minZ, maxZ float64, iterator func(item Item) bool) (ncursor uint64) {
|
||||||
var idx uint64
|
var idx uint64
|
||||||
|
205
index/rtree/knn.go
Normal file
205
index/rtree/knn.go
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
// Much of the KNN code has been adapted from the
|
||||||
|
// github.com/dhconnelly/rtreego project.
|
||||||
|
//
|
||||||
|
// Copyright 2012 Daniel Connelly. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
package rtree
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NearestNeighbors gets the closest Spatials to the Point.
|
||||||
|
func (tr *RTree) NearestNeighbors(k int, x, y, z float64) []Item {
|
||||||
|
if tr.tr.root == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dists := make([]float64, k)
|
||||||
|
objs := make([]Item, k)
|
||||||
|
for i := 0; i < k; i++ {
|
||||||
|
dists[i] = math.MaxFloat64
|
||||||
|
objs[i] = nil
|
||||||
|
}
|
||||||
|
objs, _ = tr.nearestNeighbors(k, x, y, z, tr.tr.root, dists, objs)
|
||||||
|
//for i := 0; i < len(objs); i++ {
|
||||||
|
// fmt.Printf("%v\n", objs[i])
|
||||||
|
//}
|
||||||
|
for i := 0; i < len(objs); i++ {
|
||||||
|
if objs[i] == nil {
|
||||||
|
return objs[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return objs
|
||||||
|
}
|
||||||
|
|
||||||
|
// minDist computes the square of the distance from a point to a rectangle.
|
||||||
|
// If the point is contained in the rectangle then the distance is zero.
|
||||||
|
//
|
||||||
|
// Implemented per Definition 2 of "Nearest Neighbor Queries" by
|
||||||
|
// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995.
|
||||||
|
func minDist(x, y, z float64, r d3rectT) float64 {
|
||||||
|
sum := 0.0
|
||||||
|
p := [3]float64{x, y, z}
|
||||||
|
rp := [3]float64{
|
||||||
|
float64(r.min[0]), float64(r.min[1]), float64(r.min[2]),
|
||||||
|
}
|
||||||
|
rq := [3]float64{
|
||||||
|
float64(r.max[0]), float64(r.max[1]), float64(r.max[2]),
|
||||||
|
}
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if p[i] < float64(rp[i]) {
|
||||||
|
d := p[i] - float64(rp[i])
|
||||||
|
sum += d * d
|
||||||
|
} else if p[i] > float64(rq[i]) {
|
||||||
|
d := p[i] - float64(rq[i])
|
||||||
|
sum += d * d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *RTree) nearestNeighbors(k int, x, y, z float64, n *d3nodeT, dists []float64, nearest []Item) ([]Item, []float64) {
|
||||||
|
if n.isLeaf() {
|
||||||
|
for i := 0; i < n.count; i++ {
|
||||||
|
e := n.branch[i]
|
||||||
|
dist := math.Sqrt(minDist(x, y, z, e.rect))
|
||||||
|
dists, nearest = insertNearest(k, dists, nearest, dist, e.data.(Item))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
branches, branchDists := sortEntries(x, y, z, n.branch[:n.count])
|
||||||
|
branches = pruneEntries(x, y, z, branches, branchDists)
|
||||||
|
for _, e := range branches {
|
||||||
|
nearest, dists = tr.nearestNeighbors(k, x, y, z, e.child, dists, nearest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nearest, dists
|
||||||
|
}
|
||||||
|
|
||||||
|
// insert obj into nearest and return the first k elements in increasing order.
|
||||||
|
func insertNearest(k int, dists []float64, nearest []Item, dist float64, obj Item) ([]float64, []Item) {
|
||||||
|
i := 0
|
||||||
|
for i < k && dist >= dists[i] {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i >= k {
|
||||||
|
return dists, nearest
|
||||||
|
}
|
||||||
|
|
||||||
|
left, right := dists[:i], dists[i:k-1]
|
||||||
|
updatedDists := make([]float64, k)
|
||||||
|
copy(updatedDists, left)
|
||||||
|
updatedDists[i] = dist
|
||||||
|
copy(updatedDists[i+1:], right)
|
||||||
|
|
||||||
|
leftObjs, rightObjs := nearest[:i], nearest[i:k-1]
|
||||||
|
updatedNearest := make([]Item, k)
|
||||||
|
copy(updatedNearest, leftObjs)
|
||||||
|
updatedNearest[i] = obj
|
||||||
|
copy(updatedNearest[i+1:], rightObjs)
|
||||||
|
|
||||||
|
return updatedDists, updatedNearest
|
||||||
|
}
|
||||||
|
|
||||||
|
type entrySlice struct {
|
||||||
|
entries []d3branchT
|
||||||
|
dists []float64
|
||||||
|
x, y, z float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s entrySlice) Len() int { return len(s.entries) }
|
||||||
|
|
||||||
|
func (s entrySlice) Swap(i, j int) {
|
||||||
|
s.entries[i], s.entries[j] = s.entries[j], s.entries[i]
|
||||||
|
s.dists[i], s.dists[j] = s.dists[j], s.dists[i]
|
||||||
|
}
|
||||||
|
func (s entrySlice) Less(i, j int) bool {
|
||||||
|
return s.dists[i] < s.dists[j]
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortEntries(x, y, z float64, entries []d3branchT) ([]d3branchT, []float64) {
|
||||||
|
sorted := make([]d3branchT, len(entries))
|
||||||
|
dists := make([]float64, len(entries))
|
||||||
|
for i := 0; i < len(entries); i++ {
|
||||||
|
sorted[i] = entries[i]
|
||||||
|
dists[i] = minDist(x, y, z, entries[i].rect)
|
||||||
|
}
|
||||||
|
sort.Sort(entrySlice{sorted, dists, x, y, z})
|
||||||
|
return sorted, dists
|
||||||
|
}
|
||||||
|
|
||||||
|
func pruneEntries(x, y, z float64, entries []d3branchT, minDists []float64) []d3branchT {
|
||||||
|
minMinMaxDist := math.MaxFloat64
|
||||||
|
for i := range entries {
|
||||||
|
minMaxDist := minMaxDist(x, y, z, entries[i].rect)
|
||||||
|
if minMaxDist < minMinMaxDist {
|
||||||
|
minMinMaxDist = minMaxDist
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// remove all entries with minDist > minMinMaxDist
|
||||||
|
pruned := []d3branchT{}
|
||||||
|
for i := range entries {
|
||||||
|
if minDists[i] <= minMinMaxDist {
|
||||||
|
pruned = append(pruned, entries[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pruned
|
||||||
|
}
|
||||||
|
|
||||||
|
// minMaxDist computes the minimum of the maximum distances from p to points
|
||||||
|
// on r. If r is the bounding box of some geometric objects, then there is
|
||||||
|
// at least one object contained in r within minMaxDist(p, r) of p.
|
||||||
|
//
|
||||||
|
// Implemented per Definition 4 of "Nearest Neighbor Queries" by
|
||||||
|
// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995.
|
||||||
|
func minMaxDist(x, y, z float64, r d3rectT) float64 {
|
||||||
|
|
||||||
|
p := [3]float64{x, y, z}
|
||||||
|
rp := [3]float64{
|
||||||
|
float64(r.min[0]), float64(r.min[1]), float64(r.min[2]),
|
||||||
|
}
|
||||||
|
rq := [3]float64{
|
||||||
|
float64(r.max[0]), float64(r.max[1]), float64(r.max[2]),
|
||||||
|
}
|
||||||
|
|
||||||
|
// by definition, MinMaxDist(p, r) =
|
||||||
|
// min{1<=k<=n}(|pk - rmk|^2 + sum{1<=i<=n, i != k}(|pi - rMi|^2))
|
||||||
|
// where rmk and rMk are defined as follows:
|
||||||
|
|
||||||
|
rm := func(k int) float64 {
|
||||||
|
if p[k] <= (rp[k]+rq[k])/2 {
|
||||||
|
return rp[k]
|
||||||
|
}
|
||||||
|
return rq[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
rM := func(k int) float64 {
|
||||||
|
if p[k] >= (rp[k]+rq[k])/2 {
|
||||||
|
return rp[k]
|
||||||
|
}
|
||||||
|
return rq[k]
|
||||||
|
}
|
||||||
|
|
||||||
|
// This formula can be computed in linear time by precomputing
|
||||||
|
// S = sum{1<=i<=n}(|pi - rMi|^2).
|
||||||
|
|
||||||
|
S := 0.0
|
||||||
|
for i := range p {
|
||||||
|
d := p[i] - rM(i)
|
||||||
|
S += d * d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute MinMaxDist using the precomputed S.
|
||||||
|
min := math.MaxFloat64
|
||||||
|
for k := range p {
|
||||||
|
d1 := p[k] - rM(k)
|
||||||
|
d2 := p[k] - rm(k)
|
||||||
|
d := S - d1*d1 + d2*d2
|
||||||
|
if d < min {
|
||||||
|
min = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return min
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
package rtree
|
package rtree
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
@ -87,6 +88,25 @@ func TestBounds(t *testing.T) {
|
|||||||
t.Fatalf("expected 10,10,0 30,30,0, got %v,%v %v,%v\n", minX, minY, minZ, maxX, maxY, maxZ)
|
t.Fatalf("expected 10,10,0 30,30,0, got %v,%v %v,%v\n", minX, minY, minZ, maxX, maxY, maxZ)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestKNN(t *testing.T) {
|
||||||
|
x, y, z := 20., 20., 0.
|
||||||
|
tr := New()
|
||||||
|
tr.Insert(wpp(5, 5, 0))
|
||||||
|
tr.Insert(wpp(19, 19, 0))
|
||||||
|
tr.Insert(wpp(12, 19, 0))
|
||||||
|
tr.Insert(wpp(-5, 5, 0))
|
||||||
|
tr.Insert(wpp(33, 21, 0))
|
||||||
|
items := tr.NearestNeighbors(10, x, y, z)
|
||||||
|
var res string
|
||||||
|
for i, item := range items {
|
||||||
|
ix, iy, _, _, _, _ := item.Rect()
|
||||||
|
res += fmt.Sprintf("%d:%v,%v\n", i, ix, iy)
|
||||||
|
}
|
||||||
|
if res != "0:19,19\n1:12,19\n2:33,21\n3:5,5\n4:-5,5\n" {
|
||||||
|
t.Fatal("invalid response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkInsert(b *testing.B) {
|
func BenchmarkInsert(b *testing.B) {
|
||||||
rand.Seed(0)
|
rand.Seed(0)
|
||||||
tr := New()
|
tr := New()
|
||||||
|
25
tests/keys_search.go
Normal file
25
tests/keys_search.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func subTestSearch(t *testing.T, mc *mockServer) {
|
||||||
|
runStep(t, mc, "KNN", keys_KNN_test)
|
||||||
|
}
|
||||||
|
|
||||||
|
func keys_KNN_test(mc *mockServer) error {
|
||||||
|
return mc.DoBatch([][]interface{}{
|
||||||
|
{"SET", "mykey", "1", "POINT", 5, 5}, {"OK"},
|
||||||
|
{"SET", "mykey", "2", "POINT", 19, 19}, {"OK"},
|
||||||
|
{"SET", "mykey", "3", "POINT", 12, 19}, {"OK"},
|
||||||
|
{"SET", "mykey", "4", "POINT", -5, 5}, {"OK"},
|
||||||
|
{"SET", "mykey", "5", "POINT", 33, 21}, {"OK"},
|
||||||
|
{"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "POINTS", "POINT", 20, 20}, {
|
||||||
|
"[0 [" +
|
||||||
|
"[2 [19 19] 152808.67164037024] " +
|
||||||
|
"[3 [12 19] 895945.1409106688] " +
|
||||||
|
"[5 [33 21] 1448929.5916252395] " +
|
||||||
|
"[1 [5 5] 2327116.1069888202] " +
|
||||||
|
"[4 [-5 5] 3227402.6159841116]" +
|
||||||
|
"]]"},
|
||||||
|
})
|
||||||
|
}
|
@ -41,6 +41,7 @@ func TestAll(t *testing.T) {
|
|||||||
defer mc.Close()
|
defer mc.Close()
|
||||||
runSubTest(t, "keys", mc, subTestKeys)
|
runSubTest(t, "keys", mc, subTestKeys)
|
||||||
runSubTest(t, "json", mc, subTestJSON)
|
runSubTest(t, "json", mc, subTestJSON)
|
||||||
|
runSubTest(t, "search", mc, subTestSearch)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) {
|
func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user