diff --git a/controller/collection/collection.go b/controller/collection/collection.go index 7b1f9a46..9a1ec577 100644 --- a/controller/collection/collection.go +++ b/controller/collection/collection.go @@ -536,3 +536,17 @@ func (c *Collection) Intersects(cursor uint64, sparse uint8, obj geojson.Object, return true }) } + +func (c *Collection) NearestNeighbors(k int, lat, lon float64, iterator func(id string, obj geojson.Object, fields []float64) bool) { + c.index.NearestNeighbors(k, lat, lon, func(item index.Item) bool { + var iitm *itemT + iitm, ok := item.(*itemT) + if !ok { + return true // just ignore + } + if !iterator(iitm.id, iitm.object, c.getFieldValues(iitm.id)) { + return false + } + return true + }) +} diff --git a/controller/search.go b/controller/search.go index 336bc2a1..3430eeef 100644 --- a/controller/search.go +++ b/controller/search.go @@ -22,6 +22,7 @@ type liveFenceSwitches struct { maxLat, maxLon float64 cmd string roam roamSwitches + knn bool groups map[string]string } @@ -85,10 +86,22 @@ func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) err = errInvalidNumberOfArguments return } + + umeters := true if vs, smeters, ok = tokenval(vs); !ok || smeters == "" { - err = errInvalidNumberOfArguments - return + umeters = false + if cmd == "nearby" { + // possible that this is KNN search + s.knn = s.searchScanBaseTokens.ulimit && // must be true + !s.searchScanBaseTokens.usparse && // must be false + s.searchScanBaseTokens.cursor == 0 // must be zero + } + if !s.knn { + err = errInvalidArgument(slat) + return + } } + if s.lat, err = strconv.ParseFloat(slat, 64); err != nil { err = errInvalidArgument(slat) return @@ -97,9 +110,12 @@ func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) err = errInvalidArgument(slon) return } - if s.meters, err = strconv.ParseFloat(smeters, 64); err != nil { - err = errInvalidArgument(smeters) - return + + if umeters { + if s.meters, err = strconv.ParseFloat(smeters, 64); err != nil { + err = errInvalidArgument(smeters) + return + } } case "object": var obj string @@ -290,7 +306,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) { } sw.writeHead() if sw.col != nil { - s.cursor = sw.col.Nearby(s.cursor, s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, func(id string, o geojson.Object, fields []float64) bool { + iter := func(id string, o geojson.Object, fields []float64) bool { // Calculate distance if we need to distance := 0.0 if s.distance { @@ -303,7 +319,12 @@ func (c *Controller) cmdNearby(msg *server.Message) (res string, err error) { fields: fields, distance: distance, }) - }) + } + if s.knn { + sw.col.NearestNeighbors(int(s.limit), s.lat, s.lon, iter) + } else { + s.cursor = sw.col.Nearby(s.cursor, s.sparse, s.lat, s.lon, s.meters, minZ, maxZ, iter) + } } sw.writeFoot(s.cursor) if msg.OutputType == server.JSON { diff --git a/controller/token.go b/controller/token.go index 0f0c2fb5..aa11cfc9 100644 --- a/controller/token.go +++ b/controller/token.go @@ -169,7 +169,9 @@ type searchScanBaseTokens struct { glob string wheres []whereT nofields bool + ulimit bool limit uint64 + usparse bool sparse uint8 desc bool } @@ -465,12 +467,14 @@ func parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, } } if slimit != "" { + t.ulimit = true if t.limit, err = strconv.ParseUint(slimit, 10, 64); err != nil || t.limit == 0 { err = errInvalidArgument(slimit) return } } if ssparse != "" { + t.usparse = true var sparse uint64 if sparse, err = strconv.ParseUint(ssparse, 10, 8); err != nil || sparse == 0 || sparse > 8 { err = errInvalidArgument(ssparse) diff --git a/index/index.go b/index/index.go index 53c5b4d8..80a49164 100644 --- a/index/index.go +++ b/index/index.go @@ -123,6 +123,20 @@ func (ix *Index) getRTreeItem(item rtree.Item) Item { return nil } +func (ix *Index) NearestNeighbors(k int, lat, lon float64, iterator func(item Item) bool) { + x, y, _ := normPoint(lat, lon) + items := ix.r.NearestNeighbors(k, x, y, 0) + for _, item := range items { + iitm := ix.getRTreeItem(item) + if item == nil { + continue + } + if !iterator(iitm) { + break + } + } +} + // Search returns all items that intersect the bounding box. func (ix *Index) Search(cursor uint64, swLat, swLon, neLat, neLon, minZ, maxZ float64, iterator func(item Item) bool) (ncursor uint64) { var idx uint64 diff --git a/index/rtree/knn.go b/index/rtree/knn.go new file mode 100644 index 00000000..30497b6e --- /dev/null +++ b/index/rtree/knn.go @@ -0,0 +1,205 @@ +// Much of the KNN code has been adapted from the +// github.com/dhconnelly/rtreego project. +// +// Copyright 2012 Daniel Connelly. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package rtree + +import ( + "math" + "sort" +) + +// NearestNeighbors gets the closest Spatials to the Point. +func (tr *RTree) NearestNeighbors(k int, x, y, z float64) []Item { + if tr.tr.root == nil { + return nil + } + dists := make([]float64, k) + objs := make([]Item, k) + for i := 0; i < k; i++ { + dists[i] = math.MaxFloat64 + objs[i] = nil + } + objs, _ = tr.nearestNeighbors(k, x, y, z, tr.tr.root, dists, objs) + //for i := 0; i < len(objs); i++ { + // fmt.Printf("%v\n", objs[i]) + //} + for i := 0; i < len(objs); i++ { + if objs[i] == nil { + return objs[:i] + } + } + return objs +} + +// minDist computes the square of the distance from a point to a rectangle. +// If the point is contained in the rectangle then the distance is zero. +// +// Implemented per Definition 2 of "Nearest Neighbor Queries" by +// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. +func minDist(x, y, z float64, r d3rectT) float64 { + sum := 0.0 + p := [3]float64{x, y, z} + rp := [3]float64{ + float64(r.min[0]), float64(r.min[1]), float64(r.min[2]), + } + rq := [3]float64{ + float64(r.max[0]), float64(r.max[1]), float64(r.max[2]), + } + for i := 0; i < 3; i++ { + if p[i] < float64(rp[i]) { + d := p[i] - float64(rp[i]) + sum += d * d + } else if p[i] > float64(rq[i]) { + d := p[i] - float64(rq[i]) + sum += d * d + } + } + return sum +} + +func (tr *RTree) nearestNeighbors(k int, x, y, z float64, n *d3nodeT, dists []float64, nearest []Item) ([]Item, []float64) { + if n.isLeaf() { + for i := 0; i < n.count; i++ { + e := n.branch[i] + dist := math.Sqrt(minDist(x, y, z, e.rect)) + dists, nearest = insertNearest(k, dists, nearest, dist, e.data.(Item)) + } + } else { + branches, branchDists := sortEntries(x, y, z, n.branch[:n.count]) + branches = pruneEntries(x, y, z, branches, branchDists) + for _, e := range branches { + nearest, dists = tr.nearestNeighbors(k, x, y, z, e.child, dists, nearest) + } + } + return nearest, dists +} + +// insert obj into nearest and return the first k elements in increasing order. +func insertNearest(k int, dists []float64, nearest []Item, dist float64, obj Item) ([]float64, []Item) { + i := 0 + for i < k && dist >= dists[i] { + i++ + } + if i >= k { + return dists, nearest + } + + left, right := dists[:i], dists[i:k-1] + updatedDists := make([]float64, k) + copy(updatedDists, left) + updatedDists[i] = dist + copy(updatedDists[i+1:], right) + + leftObjs, rightObjs := nearest[:i], nearest[i:k-1] + updatedNearest := make([]Item, k) + copy(updatedNearest, leftObjs) + updatedNearest[i] = obj + copy(updatedNearest[i+1:], rightObjs) + + return updatedDists, updatedNearest +} + +type entrySlice struct { + entries []d3branchT + dists []float64 + x, y, z float64 +} + +func (s entrySlice) Len() int { return len(s.entries) } + +func (s entrySlice) Swap(i, j int) { + s.entries[i], s.entries[j] = s.entries[j], s.entries[i] + s.dists[i], s.dists[j] = s.dists[j], s.dists[i] +} +func (s entrySlice) Less(i, j int) bool { + return s.dists[i] < s.dists[j] +} + +func sortEntries(x, y, z float64, entries []d3branchT) ([]d3branchT, []float64) { + sorted := make([]d3branchT, len(entries)) + dists := make([]float64, len(entries)) + for i := 0; i < len(entries); i++ { + sorted[i] = entries[i] + dists[i] = minDist(x, y, z, entries[i].rect) + } + sort.Sort(entrySlice{sorted, dists, x, y, z}) + return sorted, dists +} + +func pruneEntries(x, y, z float64, entries []d3branchT, minDists []float64) []d3branchT { + minMinMaxDist := math.MaxFloat64 + for i := range entries { + minMaxDist := minMaxDist(x, y, z, entries[i].rect) + if minMaxDist < minMinMaxDist { + minMinMaxDist = minMaxDist + } + } + // remove all entries with minDist > minMinMaxDist + pruned := []d3branchT{} + for i := range entries { + if minDists[i] <= minMinMaxDist { + pruned = append(pruned, entries[i]) + } + } + return pruned +} + +// minMaxDist computes the minimum of the maximum distances from p to points +// on r. If r is the bounding box of some geometric objects, then there is +// at least one object contained in r within minMaxDist(p, r) of p. +// +// Implemented per Definition 4 of "Nearest Neighbor Queries" by +// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. +func minMaxDist(x, y, z float64, r d3rectT) float64 { + + p := [3]float64{x, y, z} + rp := [3]float64{ + float64(r.min[0]), float64(r.min[1]), float64(r.min[2]), + } + rq := [3]float64{ + float64(r.max[0]), float64(r.max[1]), float64(r.max[2]), + } + + // by definition, MinMaxDist(p, r) = + // min{1<=k<=n}(|pk - rmk|^2 + sum{1<=i<=n, i != k}(|pi - rMi|^2)) + // where rmk and rMk are defined as follows: + + rm := func(k int) float64 { + if p[k] <= (rp[k]+rq[k])/2 { + return rp[k] + } + return rq[k] + } + + rM := func(k int) float64 { + if p[k] >= (rp[k]+rq[k])/2 { + return rp[k] + } + return rq[k] + } + + // This formula can be computed in linear time by precomputing + // S = sum{1<=i<=n}(|pi - rMi|^2). + + S := 0.0 + for i := range p { + d := p[i] - rM(i) + S += d * d + } + + // Compute MinMaxDist using the precomputed S. + min := math.MaxFloat64 + for k := range p { + d1 := p[k] - rM(k) + d2 := p[k] - rm(k) + d := S - d1*d1 + d2*d2 + if d < min { + min = d + } + } + + return min +} diff --git a/index/rtree/rtree_test.go b/index/rtree/rtree_test.go index c1d89c00..9f3cee21 100644 --- a/index/rtree/rtree_test.go +++ b/index/rtree/rtree_test.go @@ -1,6 +1,7 @@ package rtree import ( + "fmt" "math/rand" "runtime" "testing" @@ -87,6 +88,25 @@ func TestBounds(t *testing.T) { t.Fatalf("expected 10,10,0 30,30,0, got %v,%v %v,%v\n", minX, minY, minZ, maxX, maxY, maxZ) } } +func TestKNN(t *testing.T) { + x, y, z := 20., 20., 0. + tr := New() + tr.Insert(wpp(5, 5, 0)) + tr.Insert(wpp(19, 19, 0)) + tr.Insert(wpp(12, 19, 0)) + tr.Insert(wpp(-5, 5, 0)) + tr.Insert(wpp(33, 21, 0)) + items := tr.NearestNeighbors(10, x, y, z) + var res string + for i, item := range items { + ix, iy, _, _, _, _ := item.Rect() + res += fmt.Sprintf("%d:%v,%v\n", i, ix, iy) + } + if res != "0:19,19\n1:12,19\n2:33,21\n3:5,5\n4:-5,5\n" { + t.Fatal("invalid response") + } +} + func BenchmarkInsert(b *testing.B) { rand.Seed(0) tr := New() diff --git a/tests/keys_search.go b/tests/keys_search.go new file mode 100644 index 00000000..c7c64a82 --- /dev/null +++ b/tests/keys_search.go @@ -0,0 +1,25 @@ +package tests + +import "testing" + +func subTestSearch(t *testing.T, mc *mockServer) { + runStep(t, mc, "KNN", keys_KNN_test) +} + +func keys_KNN_test(mc *mockServer) error { + return mc.DoBatch([][]interface{}{ + {"SET", "mykey", "1", "POINT", 5, 5}, {"OK"}, + {"SET", "mykey", "2", "POINT", 19, 19}, {"OK"}, + {"SET", "mykey", "3", "POINT", 12, 19}, {"OK"}, + {"SET", "mykey", "4", "POINT", -5, 5}, {"OK"}, + {"SET", "mykey", "5", "POINT", 33, 21}, {"OK"}, + {"NEARBY", "mykey", "LIMIT", 10, "DISTANCE", "POINTS", "POINT", 20, 20}, { + "[0 [" + + "[2 [19 19] 152808.67164037024] " + + "[3 [12 19] 895945.1409106688] " + + "[5 [33 21] 1448929.5916252395] " + + "[1 [5 5] 2327116.1069888202] " + + "[4 [-5 5] 3227402.6159841116]" + + "]]"}, + }) +} diff --git a/tests/tests_test.go b/tests/tests_test.go index e38fa37e..0abf1334 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -41,6 +41,7 @@ func TestAll(t *testing.T) { defer mc.Close() runSubTest(t, "keys", mc, subTestKeys) runSubTest(t, "json", mc, subTestJSON) + runSubTest(t, "search", mc, subTestSearch) } func runSubTest(t *testing.T, name string, mc *mockServer, test func(t *testing.T, mc *mockServer)) {