diff --git a/pkg/controller/aof.go b/pkg/controller/aof.go index 25b08d93..a54ff443 100644 --- a/pkg/controller/aof.go +++ b/pkg/controller/aof.go @@ -133,26 +133,9 @@ func commandErrIsFatal(err error) bool { } func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error { - if d != nil { - if !d.updated { - return nil // just ignore writes if the command did not update - } - if c.config.followHost() == "" { - // process hooks, for leader only - if d.parent { - // process children only - for _, d := range d.children { - if err := c.queueHooks(d); err != nil { - return err - } - } - } else { - // process parent - if err := c.queueHooks(d); err != nil { - return err - } - } - } + if d != nil && !d.updated { + // just ignore writes if the command did not update + return nil } if c.shrinking { var values []string @@ -178,14 +161,34 @@ func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error { c.fcond.Broadcast() c.fcond.L.Unlock() + // process geofences if d != nil { - // write to live connection streams + // webhook geofences + if c.config.followHost() == "" { + // for leader only + if d.parent { + // queue children + for _, d := range d.children { + if err := c.queueHooks(d); err != nil { + return err + } + } + } else { + // queue parent + if err := c.queueHooks(d); err != nil { + return err + } + } + } + // live geofences c.lcond.L.Lock() if d.parent { + // queue children for _, d := range d.children { c.lstack = append(c.lstack, d) } } else { + // queue parent c.lstack = append(c.lstack, d) } c.lcond.Broadcast() @@ -196,7 +199,7 @@ func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error { func (c *Controller) queueHooks(d *commandDetailsT) error { // big list of all of the messages - var hmsgs [][]byte + var hmsgs []string var hooks []*Hook // find the hook by the key if hm, ok := c.hookcols[d.key]; ok { @@ -204,9 +207,13 @@ func (c *Controller) queueHooks(d *commandDetailsT) error { // match the fence msgs := FenceMatch(hook.Name, hook.ScanWriter, hook.Fence, hook.Metas, d) if len(msgs) > 0 { - // append each msg to the big list - hmsgs = append(hmsgs, msgs...) - hooks = append(hooks, hook) + if hook.channel { + c.Publish(hook.Name, msgs...) + } else { + // append each msg to the big list + hmsgs = append(hmsgs, msgs...) + hooks = append(hooks, hook) + } } } } @@ -259,7 +266,7 @@ type liveAOFSwitches struct { } func (s liveAOFSwitches) Error() string { - return "going live" + return goingLive } func (c *Controller) cmdAOFMD5(msg *server.Message) (res resp.Value, err error) { diff --git a/pkg/controller/aofshrink.go b/pkg/controller/aofshrink.go index 2c96ef0c..09e1f13e 100644 --- a/pkg/controller/aofshrink.go +++ b/pkg/controller/aofshrink.go @@ -190,17 +190,23 @@ func (c *Controller) aofshrink() { if hook == nil { return } - hook.mu.Lock() - defer hook.mu.Unlock() + hook.cond.L.Lock() + defer hook.cond.L.Unlock() var values []string - values = append(values, "sethook") - values = append(values, name) - values = append(values, strings.Join(hook.Endpoints, ",")) + if hook.channel { + values = append(values, "setchan", name) + } else { + values = append(values, "sethook", name, + strings.Join(hook.Endpoints, ",")) + values = append(values) + } + for _, meta := range hook.Metas { + values = append(values, "meta", meta.Name, meta.Value) + } for _, value := range hook.Message.Values { values = append(values, value.String()) } - // append the values to the aof buffer aofbuf = append(aofbuf, '*') aofbuf = append(aofbuf, strconv.FormatInt(int64(len(values)), 10)...) diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 181133f5..aebf4370 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -29,6 +29,8 @@ import ( var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'") +const goingLive = "going live" + const hookLogPrefix = "hook:log:" type collectionT struct { @@ -109,6 +111,8 @@ type Controller struct { aofconnM map[net.Conn]bool luascripts *lScriptMap luapool *lStatePool + + pubsub *pubsub } // ListenAndServe starts a new tile38 server @@ -136,10 +140,10 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http expires: make(map[string]map[string]time.Time), started: time.Now(), conns: make(map[*server.Conn]*clientConn), - epc: endpoint.NewManager(), http: http, + pubsub: newPubsub(), } - + c.epc = endpoint.NewManager(c) c.luascripts = c.NewScriptMap() c.luapool = c.NewPool() defer c.luapool.Shutdown() @@ -217,7 +221,7 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http c.statsTotalCommands.add(1) err := c.handleInputCommand(conn, msg, w) if err != nil { - if err.Error() == "going live" { + if err.Error() == goingLive { return c.goLive(err, conn, rd, msg, websocket) } return err @@ -490,7 +494,9 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, default: c.mu.RLock() defer c.mu.RUnlock() - case "set", "del", "drop", "fset", "flushdb", "sethook", "pdelhook", "delhook", + case "set", "del", "drop", "fset", "flushdb", + "setchan", "pdelchan", "delchan", + "sethook", "pdelhook", "delhook", "expire", "persist", "jset", "pdel": // write operations write = true @@ -512,8 +518,9 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, if c.config.readOnly() { return writeErr("read only") } - case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search", - "ttl", "bounds", "server", "info", "type", "jget", "evalro", "evalrosha": + case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", + "chans", "search", "ttl", "bounds", "server", "info", "type", "jget", + "evalro", "evalrosha": // read operations c.mu.RLock() defer c.mu.RUnlock() @@ -548,6 +555,8 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, defer c.mu.Unlock() case "evalna", "evalnasha": // No locking for scripts, otherwise writes cannot happen within scripts + case "subscribe", "psubscribe", "publish": + // No locking for pubsub } res, d, err := c.command(msg, w, conn) @@ -556,7 +565,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message, return writeErr(res.String()) } if err != nil { - if err.Error() == "going live" { + if err.Error() == goingLive { return err } return writeErr(err.Error()) @@ -630,20 +639,31 @@ func (c *Controller) command( res, d, err = c.cmdDrop(msg) case "flushdb": res, d, err = c.cmdFlushDB(msg) + case "sethook": - res, d, err = c.cmdSetHook(msg) + res, d, err = c.cmdSetHook(msg, false) case "delhook": - res, d, err = c.cmdDelHook(msg) + res, d, err = c.cmdDelHook(msg, false) case "pdelhook": - res, d, err = c.cmdPDelHook(msg) + res, d, err = c.cmdPDelHook(msg, false) + case "hooks": + res, err = c.cmdHooks(msg, false) + + case "setchan": + res, d, err = c.cmdSetHook(msg, true) + case "delchan": + res, d, err = c.cmdDelHook(msg, true) + case "pdelchan": + res, d, err = c.cmdPDelHook(msg, true) + case "chans": + res, err = c.cmdHooks(msg, true) + case "expire": res, d, err = c.cmdExpire(msg) case "persist": res, d, err = c.cmdPersist(msg) case "ttl": res, err = c.cmdTTL(msg) - case "hooks": - res, err = c.cmdHooks(msg) case "shutdown": if !core.DevMode { err = fmt.Errorf("unknown command '%s'", msg.Values[0]) @@ -737,6 +757,12 @@ func (c *Controller) command( res, err = c.cmdScriptExists(msg) case "script flush": res, err = c.cmdScriptFlush(msg) + case "subscribe": + res, err = c.cmdSubscribe(msg) + case "psubscribe": + res, err = c.cmdPsubscribe(msg) + case "publish": + res, err = c.cmdPublish(msg) } return } diff --git a/pkg/controller/fence.go b/pkg/controller/fence.go index 311fb45d..31ffa247 100644 --- a/pkg/controller/fence.go +++ b/pkg/controller/fence.go @@ -12,14 +12,14 @@ import ( ) // FenceMatch executes a fence match returns back json messages for fence detection. -func FenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas []FenceMeta, details *commandDetailsT) [][]byte { +func FenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas []FenceMeta, details *commandDetailsT) []string { msgs := fenceMatch(hookName, sw, fence, metas, details) if len(fence.accept) == 0 { return msgs } - nmsgs := make([][]byte, 0, len(msgs)) + nmsgs := make([]string, 0, len(msgs)) for _, msg := range msgs { - if fence.accept[gjson.GetBytes(msg, "command").String()] { + if fence.accept[gjson.Get(msg, "command").String()] { nmsgs = append(nmsgs, msg) } } @@ -47,9 +47,12 @@ func appendHookDetails(b []byte, hookName string, metas []FenceMeta) []byte { func hookJSONString(hookName string, metas []FenceMeta) string { return string(appendHookDetails(nil, hookName, metas)) } -func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas []FenceMeta, details *commandDetailsT) [][]byte { +func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas []FenceMeta, details *commandDetailsT) []string { if details.command == "drop" { - return [][]byte{[]byte(`{"command":"drop"` + hookJSONString(hookName, metas) + `,"time":` + jsonTimeFormat(details.timestamp) + `}`)} + return []string{ + `{"command":"drop"` + hookJSONString(hookName, metas) + + `,"time":` + jsonTimeFormat(details.timestamp) + `}`, + } } if len(fence.glob) > 0 && !(len(fence.glob) == 1 && fence.glob[0] == '*') { match, _ := glob.Match(fence.glob, details.id) @@ -69,7 +72,10 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas } } if details.command == "del" { - return [][]byte{[]byte(`{"command":"del"` + hookJSONString(hookName, metas) + `,"id":` + jsonString(details.id) + `,"time":` + jsonTimeFormat(details.timestamp) + `}`)} + return []string{ + `{"command":"del"` + hookJSONString(hookName, metas) + `,"id":` + jsonString(details.id) + + `,"time":` + jsonTimeFormat(details.timestamp) + `}`, + } } var roamkeys, roamids []string var roammeters []float64 @@ -164,14 +170,13 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas return nil } - res := make([]byte, sw.wr.Len()) - copy(res, sw.wr.Bytes()) + res := sw.wr.String() sw.wr.Reset() if len(res) > 0 && res[0] == ',' { res = res[1:] } if sw.output == outputIDs { - res = []byte(`{"id":` + string(res) + `}`) + res = `{"id":` + string(res) + `}` } sw.mu.Unlock() @@ -195,12 +200,13 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas } } - var msgs [][]byte + var msgs []string if fence.detect == nil || fence.detect[detect] { if len(res) > 0 && res[0] == '{' { - msgs = append(msgs, makemsg(details.command, group, detect, hookName, metas, details.key, details.timestamp, res[1:])) + msgs = append(msgs, makemsg(details.command, group, detect, + hookName, metas, details.key, details.timestamp, res[1:])) } else { - msgs = append(msgs, res) + msgs = append(msgs, string(res)) } } switch detect { @@ -214,11 +220,12 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas } case "roam": if len(msgs) > 0 { - var nmsgs [][]byte + var nmsgs []string msg := msgs[0][:len(msgs[0])-1] for i, id := range roamids { - nmsg := append([]byte(nil), msg...) + var nmsg []byte + nmsg = append(nmsg, msg...) nmsg = append(nmsg, `,"nearby":{"key":`...) nmsg = appendJSONString(nmsg, roamkeys[i]) nmsg = append(nmsg, `,"id":`...) @@ -261,7 +268,7 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas nmsg = append(nmsg, '}') nmsg = append(nmsg, '}') - nmsgs = append(nmsgs, nmsg) + nmsgs = append(nmsgs, string(nmsg)) } msgs = nmsgs } @@ -269,7 +276,10 @@ func fenceMatch(hookName string, sw *scanWriter, fence *liveFenceSwitches, metas return msgs } -func makemsg(command, group, detect, hookName string, metas []FenceMeta, key string, t time.Time, tail []byte) []byte { +func makemsg( + command, group, detect, hookName string, + metas []FenceMeta, key string, t time.Time, tail string, +) string { var buf []byte buf = append(append(buf, `{"command":"`...), command...) buf = append(append(buf, `","group":"`...), group...) @@ -279,7 +289,7 @@ func makemsg(command, group, detect, hookName string, metas []FenceMeta, key str buf = appendJSONString(append(buf, `,"key":`...), key) buf = appendJSONTimeFormat(append(buf, `,"time":`...), t) buf = append(append(buf, ','), tail...) - return buf + return string(buf) } func fenceMatchObject(fence *liveFenceSwitches, obj geojson.Object) bool { diff --git a/pkg/controller/hooks.go b/pkg/controller/hooks.go index 073baabe..8ab4dced 100644 --- a/pkg/controller/hooks.go +++ b/pkg/controller/hooks.go @@ -2,7 +2,6 @@ package controller import ( "bytes" - "encoding/json" "errors" "sort" "strings" @@ -36,31 +35,38 @@ func (a hooksByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandDetailsT, err error) { +func (c *Controller) cmdSetHook(msg *server.Message, chanCmd bool) ( + res resp.Value, d commandDetailsT, err error, +) { start := time.Now() - vs := msg.Values[1:] var name, urls, cmd string var ok bool if vs, name, ok = tokenval(vs); !ok || name == "" { return server.NOMessage, d, errInvalidNumberOfArguments } - if vs, urls, ok = tokenval(vs); !ok || urls == "" { - return server.NOMessage, d, errInvalidNumberOfArguments - } var endpoints []string - for _, url := range strings.Split(urls, ",") { - url = strings.TrimSpace(url) - err := c.epc.Validate(url) - if err != nil { - log.Errorf("sethook: %v", err) - return resp.SimpleStringValue(""), d, errInvalidArgument(url) + if chanCmd { + endpoints = []string{"local://" + name} + } else { + if vs, urls, ok = tokenval(vs); !ok || urls == "" { + return server.NOMessage, d, errInvalidNumberOfArguments + } + for _, url := range strings.Split(urls, ",") { + url = strings.TrimSpace(url) + err := c.epc.Validate(url) + if err != nil { + log.Errorf("sethook: %v", err) + return resp.SimpleStringValue(""), d, errInvalidArgument(url) + } + endpoints = append(endpoints, url) } - endpoints = append(endpoints, url) } var commandvs []resp.Value var cmdlc string var types []string + // var expires float64 + // var expiresSet bool metaMap := make(map[string]string) for { commandvs = vs @@ -82,6 +88,18 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD } metaMap[metakey] = metaval continue + // case "ex": + // var s string + // if vs, s, ok = tokenval(vs); !ok || s == "" { + // return server.NOMessage, d, errInvalidNumberOfArguments + // } + // v, err := strconv.ParseFloat(s, 64) + + // if err != nil { + // return server.NOMessage, d, errInvalidArgument(s) + // } + // expires = v + // expiresSet = true case "nearby": types = nearbyTypes case "within", "intersects": @@ -89,7 +107,7 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD } break } - s, err := c.cmdSearchArgs(cmdlc, vs, types) + s, err := c.cmdSearchArgs(true, cmdlc, vs, types) defer s.Close() if err != nil { return server.NOMessage, d, err @@ -119,12 +137,18 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD Endpoints: endpoints, Fence: &s, Message: cmsg, - db: c.qdb, epm: c.epc, Metas: metas, + channel: chanCmd, + cond: sync.NewCond(&sync.Mutex{}), + } + // if expiresSet { + // hook.expires = + // time.Now().Add(time.Duration(expires * float64(time.Second))) + // } + if !chanCmd { + hook.db = c.qdb } - hook.cond = sync.NewCond(&hook.mu) - var wr bytes.Buffer hook.ScanWriter, err = c.newScanWriter( &wr, cmsg, s.key, s.output, s.precision, s.glob, false, @@ -134,6 +158,10 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD } if h, ok := c.hooks[name]; ok { + if h.channel != chanCmd { + return server.NOMessage, d, + errors.New("hooks and channels cannot share the same name") + } if h.Equals(hook) { // it was a match so we do nothing. But let's signal just // for good measure. @@ -171,7 +199,9 @@ func (c *Controller) cmdSetHook(msg *server.Message) (res resp.Value, d commandD return server.NOMessage, d, nil } -func (c *Controller) cmdDelHook(msg *server.Message) (res resp.Value, d commandDetailsT, err error) { +func (c *Controller) cmdDelHook(msg *server.Message, chanCmd bool) ( + res resp.Value, d commandDetailsT, err error, +) { start := time.Now() vs := msg.Values[1:] @@ -183,7 +213,7 @@ func (c *Controller) cmdDelHook(msg *server.Message) (res resp.Value, d commandD if len(vs) != 0 { return server.NOMessage, d, errInvalidNumberOfArguments } - if h, ok := c.hooks[name]; ok { + if h, ok := c.hooks[name]; ok && h.channel == chanCmd { h.Close() if hm, ok := c.hookcols[h.Key]; ok { delete(hm, h.Name) @@ -205,7 +235,9 @@ func (c *Controller) cmdDelHook(msg *server.Message) (res resp.Value, d commandD return } -func (c *Controller) cmdPDelHook(msg *server.Message) (res resp.Value, d commandDetailsT, err error) { +func (c *Controller) cmdPDelHook(msg *server.Message, channel bool) ( + res resp.Value, d commandDetailsT, err error, +) { start := time.Now() vs := msg.Values[1:] @@ -219,19 +251,21 @@ func (c *Controller) cmdPDelHook(msg *server.Message) (res resp.Value, d command } count := 0 - for name := range c.hooks { - match, _ := glob.Match(pattern, name) - if match { - if h, ok := c.hooks[name]; ok { - h.Close() - if hm, ok := c.hookcols[h.Key]; ok { - delete(hm, h.Name) - } - delete(c.hooks, h.Name) - d.updated = true - count++ - } + for name, h := range c.hooks { + if h.channel != channel { + continue } + match, _ := glob.Match(pattern, name) + if !match { + continue + } + h.Close() + if hm, ok := c.hookcols[h.Key]; ok { + delete(hm, h.Name) + } + delete(c.hooks, h.Name) + d.updated = true + count++ } d.timestamp = time.Now() @@ -244,7 +278,9 @@ func (c *Controller) cmdPDelHook(msg *server.Message) (res resp.Value, d command return } -func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { +func (c *Controller) cmdHooks(msg *server.Message, channel bool) ( + res resp.Value, err error, +) { start := time.Now() vs := msg.Values[1:] @@ -260,6 +296,9 @@ func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { var hooks []*Hook for name, hook := range c.hooks { + if hook.channel != channel { + continue + } match, _ := glob.Match(pattern, name) if match { hooks = append(hooks, hook) @@ -270,7 +309,12 @@ func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { switch msg.OutputType { case server.JSON: buf := &bytes.Buffer{} - buf.WriteString(`{"ok":true,"hooks":[`) + buf.WriteString(`{"ok":true,`) + if channel { + buf.WriteString(`"chans":[`) + } else { + buf.WriteString(`"hooks":[`) + } for i, hook := range hooks { if i > 0 { buf.WriteByte(',') @@ -278,12 +322,14 @@ func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { buf.WriteString(`{`) buf.WriteString(`"name":` + jsonString(hook.Name)) buf.WriteString(`,"key":` + jsonString(hook.Key)) - buf.WriteString(`,"endpoints":[`) - for i, endpoint := range hook.Endpoints { - if i > 0 { - buf.WriteByte(',') + if !channel { + buf.WriteString(`,"endpoints":[`) + for i, endpoint := range hook.Endpoints { + if i > 0 { + buf.WriteByte(',') + } + buf.WriteString(jsonString(endpoint)) } - buf.WriteString(jsonString(endpoint)) } buf.WriteString(`],"command":[`) for i, v := range hook.Message.Values { @@ -303,7 +349,8 @@ func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { } buf.WriteString(`}}`) } - buf.WriteString(`],"elapsed":"` + time.Now().Sub(start).String() + "\"}") + buf.WriteString(`],"elapsed":"` + + time.Now().Sub(start).String() + "\"}") return resp.StringValue(buf.String()), nil case server.RESP: var vals []resp.Value @@ -332,7 +379,6 @@ func (c *Controller) cmdHooks(msg *server.Message) (res resp.Value, err error) { // Hook represents a hook. type Hook struct { - mu sync.Mutex cond *sync.Cond Key string Name string @@ -342,12 +388,15 @@ type Hook struct { ScanWriter *scanWriter Metas []FenceMeta db *buntdb.DB + channel bool closed bool opened bool query string epm *endpoint.Manager + expires time.Time } +// Equals returns true if two hooks are equal func (h *Hook) Equals(hook *Hook) bool { if h.Key != hook.Key || h.Name != hook.Name || @@ -370,6 +419,7 @@ func (h *Hook) Equals(hook *Hook) bool { resp.ArrayValue(hook.Message.Values)) } +// FenceMeta is a meta key/value pair for fences type FenceMeta struct { Name, Value string } @@ -391,21 +441,28 @@ func (arr hookMetaByName) Swap(a, b int) { // Open is called when a hook is first created. It calls the manager // function in a goroutine func (h *Hook) Open() { - h.mu.Lock() - defer h.mu.Unlock() + if h.channel { + // nothing to open for channels + return + } + h.cond.L.Lock() + defer h.cond.L.Unlock() if h.opened { return } h.opened = true - b, _ := json.Marshal(h.Name) - h.query = `{"hook":` + string(b) + `}` + h.query = `{"hook":` + jsonString(h.Name) + `}` go h.manager() } // Close closed the hook and stop the manager function func (h *Hook) Close() { - h.mu.Lock() - defer h.mu.Unlock() + if h.channel { + // nothing to close for channels + return + } + h.cond.L.Lock() + defer h.cond.L.Unlock() if h.closed { return } @@ -416,30 +473,35 @@ func (h *Hook) Close() { // Signal can be called at any point to wake up the hook and // notify the manager that there may be something new in the queue. func (h *Hook) Signal() { - h.mu.Lock() + if h.channel { + // nothing to signal for channels + return + } + h.cond.L.Lock() h.cond.Broadcast() - h.mu.Unlock() + h.cond.L.Unlock() } // the manager is a forever loop that calls proc whenever there's a signal. // it ends when the "closed" flag is set. func (h *Hook) manager() { for { - h.mu.Lock() + h.cond.L.Lock() for { if h.closed { - h.mu.Unlock() + h.cond.L.Unlock() return } if h.proc() { break } - h.mu.Unlock() - time.Sleep(time.Second / 4) - h.mu.Lock() + h.cond.L.Unlock() + // proc failed. wait half a second and try again + time.Sleep(time.Second / 2) + h.cond.L.Lock() } h.cond.Wait() - h.mu.Unlock() + h.cond.L.Unlock() } } @@ -452,13 +514,15 @@ func (h *Hook) proc() (ok bool) { start := time.Now() err := h.db.Update(func(tx *buntdb.Tx) error { // get keys and vals - err := tx.AscendGreaterOrEqual("hooks", h.query, func(key, val string) bool { - if strings.HasPrefix(key, hookLogPrefix) { - keys = append(keys, key) - vals = append(vals, val) - } - return true - }) + err := tx.AscendGreaterOrEqual("hooks", + h.query, func(key, val string) bool { + if strings.HasPrefix(key, hookLogPrefix) { + keys = append(keys, key) + vals = append(vals, val) + } + return true + }, + ) if err != nil { return err } @@ -494,7 +558,8 @@ func (h *Hook) proc() (ok bool) { for _, endpoint := range h.Endpoints { err := h.epm.Send(endpoint, val) if err != nil { - log.Debugf("Endpoint connect/send error: %v: %v: %v", idx, endpoint, err) + log.Debugf("Endpoint connect/send error: %v: %v: %v", + idx, endpoint, err) continue } log.Debugf("Endpoint send ok: %v: %v: %v", idx, endpoint, err) @@ -502,7 +567,8 @@ func (h *Hook) proc() (ok bool) { break } if !sent { - // failed to send. try to reinsert the remaining. if this fails we lose log entries. + // failed to send. try to reinsert the remaining. + // if this fails we lose log entries. keys = keys[i:] vals = vals[i:] ttls = ttls[i:] @@ -528,41 +594,3 @@ func (h *Hook) proc() (ok bool) { } return true } - -/* -// Do performs a hook. -func (hook *Hook) Do(details *commandDetailsT) error { - var lerrs []error - msgs := FenceMatch(hook.Name, hook.ScanWriter, hook.Fence, details) -nextMessage: - for _, msg := range msgs { - nextEndpoint: - for _, endpoint := range hook.Endpoints { - switch endpoint.Protocol { - case HTTP: - if err := sendHTTPMessage(endpoint, []byte(msg)); err != nil { - lerrs = append(lerrs, err) - continue nextEndpoint - } - continue nextMessage // sent - case Disque: - if err := sendDisqueMessage(endpoint, []byte(msg)); err != nil { - lerrs = append(lerrs, err) - continue nextEndpoint - } - continue nextMessage // sent - } - } - } - if len(lerrs) == 0 { - // log.Notice("YAY") - return nil - } - var errmsgs []string - for _, err := range lerrs { - errmsgs = append(errmsgs, err.Error()) - } - err := errors.New("not sent: " + strings.Join(errmsgs, ",")) - log.Error(err) - return err -}*/ diff --git a/pkg/controller/live.go b/pkg/controller/live.go index 8869ad7d..80d08d68 100644 --- a/pkg/controller/live.go +++ b/pkg/controller/live.go @@ -43,7 +43,13 @@ func (c *Controller) processLives() { } } -func writeMessage(conn net.Conn, message []byte, wrapRESP bool, connType server.Type, websocket bool) error { +func writeLiveMessage( + conn net.Conn, + message []byte, + wrapRESP bool, + connType server.Type, + websocket bool, +) error { if len(message) == 0 { return nil } @@ -70,28 +76,34 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade defer func() { log.Info("not live " + addr) }() - if s, ok := inerr.(liveAOFSwitches); ok { + switch s := inerr.(type) { + default: + return errors.New("invalid live type switches") + case liveAOFSwitches: return c.liveAOF(s.pos, conn, rd, msg) + case liveSubscriptionSwitches: + return c.liveSubscription(conn, rd, msg, websocket) + case liveFenceSwitches: + // fallthrough } + + // everything below is for live geofences lb := &liveBuffer{ cond: sync.NewCond(&sync.Mutex{}), } var err error var sw *scanWriter var wr bytes.Buffer - switch s := inerr.(type) { - default: - return errors.New("invalid switch") - case liveFenceSwitches: - lb.glob = s.glob - lb.key = s.key - lb.fence = &s - c.mu.RLock() - sw, err = c.newScanWriter( - &wr, msg, s.key, s.output, s.precision, s.glob, false, - s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) - c.mu.RUnlock() - } + s := inerr.(liveFenceSwitches) + lb.glob = s.glob + lb.key = s.key + lb.fence = &s + c.mu.RLock() + sw, err = c.newScanWriter( + &wr, msg, s.key, s.output, s.precision, s.glob, false, + s.cursor, s.limit, s.wheres, s.whereins, s.whereevals, s.nofields) + c.mu.RUnlock() + // everything below if for live SCAN, NEARBY, WITHIN, INTERSECTS if err != nil { return err @@ -149,7 +161,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade case server.RESP: livemsg = []byte("+OK\r\n") } - if err := writeMessage(conn, livemsg, false, connType, websocket); err != nil { + if err := writeLiveMessage(conn, livemsg, false, connType, websocket); err != nil { return nil // nil return is fine here } for { @@ -168,7 +180,7 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReade lb.cond.L.Unlock() msgs := FenceMatch("", sw, fence, nil, details) for _, msg := range msgs { - if err := writeMessage(conn, []byte(msg), true, connType, websocket); err != nil { + if err := writeLiveMessage(conn, []byte(msg), true, connType, websocket); err != nil { return nil // nil return is fine here } } diff --git a/pkg/controller/pubsub.go b/pkg/controller/pubsub.go new file mode 100644 index 00000000..2199b650 --- /dev/null +++ b/pkg/controller/pubsub.go @@ -0,0 +1,363 @@ +package controller + +import ( + "io" + "net" + "strconv" + "sync" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/match" + "github.com/tidwall/redcon" + "github.com/tidwall/resp" + "github.com/tidwall/tile38/pkg/log" + "github.com/tidwall/tile38/pkg/server" +) + +const ( + pubsubChannel = iota + pubsubPattern +) + +type pubsub struct { + mu sync.RWMutex + hubs [2]map[string]*subhub +} + +func newPubsub() *pubsub { + return &pubsub{ + hubs: [2]map[string]*subhub{ + make(map[string]*subhub), + make(map[string]*subhub), + }, + } +} + +// Publish a message to subscribers +func (c *Controller) Publish(channel string, message ...string) int { + var msgs []submsg + c.pubsub.mu.RLock() + if hub := c.pubsub.hubs[pubsubChannel][channel]; hub != nil { + for target := range hub.targets { + for _, message := range message { + msgs = append(msgs, submsg{ + kind: pubsubChannel, + target: target, + channel: channel, + message: message, + }) + } + } + } + for pattern, hub := range c.pubsub.hubs[pubsubPattern] { + if match.Match(channel, pattern) { + for target := range hub.targets { + for _, message := range message { + msgs = append(msgs, submsg{ + kind: pubsubPattern, + target: target, + channel: channel, + pattern: pattern, + message: message, + }) + } + } + } + } + c.pubsub.mu.RUnlock() + + for _, msg := range msgs { + msg.target.cond.L.Lock() + msg.target.msgs = append(msg.target.msgs, msg) + msg.target.cond.Broadcast() + msg.target.cond.L.Unlock() + } + + return len(msgs) +} + +func (ps *pubsub) register(kind int, channel string, target *subtarget) { + ps.mu.Lock() + hub, ok := ps.hubs[kind][channel] + if !ok { + hub = newSubhub() + ps.hubs[kind][channel] = hub + } + hub.targets[target] = true + ps.mu.Unlock() +} + +func (ps *pubsub) unregister(kind int, channel string, target *subtarget) { + ps.mu.Lock() + hub, ok := ps.hubs[kind][channel] + if ok { + delete(hub.targets, target) + if len(hub.targets) == 0 { + delete(ps.hubs[kind], channel) + } + } + ps.mu.Unlock() +} + +type submsg struct { + kind byte + target *subtarget + pattern string + channel string + message string +} + +type subtarget struct { + cond *sync.Cond + msgs []submsg + closed bool +} + +func newSubtarget() *subtarget { + target := new(subtarget) + target.cond = sync.NewCond(&sync.Mutex{}) + return target +} + +type subhub struct { + targets map[*subtarget]bool +} + +func newSubhub() *subhub { + hub := new(subhub) + hub.targets = make(map[*subtarget]bool) + return hub +} + +type liveSubscriptionSwitches struct { + // no fields. everything is managed through the server.Message +} + +func (sub liveSubscriptionSwitches) Error() string { + return goingLive +} + +func (c *Controller) cmdSubscribe(msg *server.Message) (resp.Value, error) { + if len(msg.Values) < 2 { + return resp.Value{}, errInvalidNumberOfArguments + } + return server.NOMessage, liveSubscriptionSwitches{} +} + +func (c *Controller) cmdPsubscribe(msg *server.Message) (resp.Value, error) { + if len(msg.Values) < 2 { + return resp.Value{}, errInvalidNumberOfArguments + } + return server.NOMessage, liveSubscriptionSwitches{} +} + +func (c *Controller) cmdPublish(msg *server.Message) (resp.Value, error) { + start := time.Now() + if len(msg.Values) != 3 { + return resp.Value{}, errInvalidNumberOfArguments + } + + channel := msg.Values[1].String() + message := msg.Values[2].String() + //geofence := gjson.Valid(message) && gjson.Get(message, "fence").Bool() + n := c.Publish(channel, message) //, geofence) + var res resp.Value + switch msg.OutputType { + case server.JSON: + res = resp.StringValue(`{"ok":true` + + `,"published":` + strconv.FormatInt(int64(n), 10) + + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`) + case server.RESP: + res = resp.IntegerValue(n) + } + return res, nil +} + +func (c *Controller) liveSubscription( + conn net.Conn, + rd *server.PipelineReader, + msg *server.Message, + websocket bool, +) error { + defer conn.Close() // close connection when we are done + + outputType := msg.OutputType + connType := msg.ConnType + if websocket { + outputType = server.JSON + } + + var start time.Time + + // write helpers + var writeLock sync.Mutex + write := func(data []byte) { + writeLock.Lock() + defer writeLock.Unlock() + writeLiveMessage(conn, data, false, connType, websocket) + } + writeOK := func() { + switch outputType { + case server.JSON: + write([]byte(`{"ok":true` + + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) + case server.RESP: + write([]byte(`+OK\r\n`)) + } + } + writeWrongNumberOfArgsErr := func(command string) { + switch outputType { + case server.JSON: + write([]byte(`{"ok":false,"err":"invalid number of arguments"` + + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) + case server.RESP: + write([]byte(`-ERR wrong number of arguments ` + + `for '` + command + `' command\r\n`)) + } + } + writeOnlyPubsubErr := func() { + switch outputType { + case server.JSON: + write([]byte(`{"ok":false` + + `,"err":"only (P)SUBSCRIBE / (P)UNSUBSCRIBE / ` + + `PING / QUIT allowed in this context"` + + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) + case server.RESP: + write([]byte("-ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / " + + "PING / QUIT allowed in this context\r\n")) + } + } + writeSubscribe := func(command, channel string, num int) { + switch outputType { + case server.JSON: + write([]byte(`{"ok":true` + + `,"command":` + jsonString(command) + + `,"channel":` + jsonString(channel) + + `,"num":` + strconv.FormatInt(int64(num), 10) + + `,"elapsed":"` + time.Now().Sub(start).String() + `"}`)) + case server.RESP: + b := redcon.AppendArray(nil, 3) + b = redcon.AppendBulkString(b, command) + b = redcon.AppendBulkString(b, channel) + b = redcon.AppendInt(b, int64(num)) + write(b) + } + } + writeMessage := func(msg submsg) { + if msg.kind == pubsubChannel { + switch outputType { + case server.JSON: + var data []byte + if !gjson.Valid(msg.message) { + data = appendJSONString(nil, msg.message) + } else { + data = []byte(msg.message) + } + write(data) + case server.RESP: + b := redcon.AppendArray(nil, 3) + b = redcon.AppendBulkString(b, "message") + b = redcon.AppendBulkString(b, msg.channel) + b = redcon.AppendBulkString(b, msg.message) + write(b) + } + } else { + switch outputType { + case server.JSON: + var data []byte + if !gjson.Valid(msg.message) { + data = appendJSONString(nil, msg.message) + } else { + data = []byte(msg.message) + } + write(data) + case server.RESP: + b := redcon.AppendArray(nil, 4) + b = redcon.AppendBulkString(b, "pmessage") + b = redcon.AppendBulkString(b, msg.pattern) + b = redcon.AppendBulkString(b, msg.channel) + b = redcon.AppendBulkString(b, msg.message) + write(b) + } + } + } + + m := [2]map[string]bool{ + make(map[string]bool), + make(map[string]bool), + } + + target := newSubtarget() + + defer func() { + for i := 0; i < 2; i++ { + for channel := range m[i] { + c.pubsub.unregister(i, channel, target) + } + } + target.cond.L.Lock() + target.closed = true + target.cond.Broadcast() + target.cond.L.Unlock() + }() + go func() { + log.Debugf("pubsub open") + defer log.Debugf("pubsub closed") + for { + var msgs []submsg + target.cond.L.Lock() + if len(target.msgs) > 0 { + msgs = target.msgs + target.msgs = nil + } + target.cond.L.Unlock() + for _, msg := range msgs { + writeMessage(msg) + } + target.cond.L.Lock() + if target.closed { + target.cond.L.Unlock() + return + } + target.cond.Wait() + target.cond.L.Unlock() + } + }() + + msgs := []*server.Message{msg} + for { + for _, msg := range msgs { + start = time.Now() + var kind int + switch msg.Command { + case "quit": + writeOK() + return nil + case "psubscribe": + kind = pubsubPattern + case "subscribe": + kind = pubsubChannel + default: + writeOnlyPubsubErr() + } + if len(msg.Values) < 2 { + writeWrongNumberOfArgsErr(msg.Command) + } + for i := 1; i < len(msg.Values); i++ { + channel := msg.Values[i].String() + m[kind][channel] = true + c.pubsub.register(kind, channel, target) + writeSubscribe(msg.Command, channel, len(m[0])+len(m[1])) + } + } + var err error + msgs, err = rd.ReadMessages() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} diff --git a/pkg/controller/scan.go b/pkg/controller/scan.go index 08107795..6293e03e 100644 --- a/pkg/controller/scan.go +++ b/pkg/controller/scan.go @@ -11,10 +11,15 @@ import ( "github.com/tidwall/tile38/pkg/server" ) -func (c *Controller) cmdScanArgs(vs []resp.Value) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("scan", vs); err != nil { +func (c *Controller) cmdScanArgs(vs []resp.Value) ( + s liveFenceSwitches, err error, +) { + var t searchScanBaseTokens + vs, t, err = c.parseSearchScanBaseTokens("scan", t, vs) + if err != nil { return } + s.searchScanBaseTokens = t if len(vs) != 0 { err = errInvalidNumberOfArguments return diff --git a/pkg/controller/search.go b/pkg/controller/search.go index 390dbd37..6afce7b9 100644 --- a/pkg/controller/search.go +++ b/pkg/controller/search.go @@ -38,7 +38,7 @@ type roamSwitches struct { } func (s liveFenceSwitches) Error() string { - return "going live" + return goingLive } func (s liveFenceSwitches) Close() { @@ -51,10 +51,18 @@ func (s liveFenceSwitches) usingLua() bool { return len(s.whereevals) > 0 } -func (c *Controller) cmdSearchArgs(cmd string, vs []resp.Value, types []string) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens(cmd, vs); err != nil { +func (c *Controller) cmdSearchArgs( + fromFenceCmd bool, cmd string, vs []resp.Value, types []string, +) (s liveFenceSwitches, err error) { + var t searchScanBaseTokens + if fromFenceCmd { + t.fence = true + } + vs, t, err = c.parseSearchScanBaseTokens(cmd, t, vs) + if err != nil { return } + s.searchScanBaseTokens = t var typ string var ok bool if vs, typ, ok = tokenval(vs); !ok || typ == "" { @@ -350,7 +358,7 @@ func (c *Controller) cmdNearby(msg *server.Message) (res resp.Value, err error) start := time.Now() vs := msg.Values[1:] wr := &bytes.Buffer{} - s, err := c.cmdSearchArgs("nearby", vs, nearbyTypes) + s, err := c.cmdSearchArgs(false, "nearby", vs, nearbyTypes) if s.usingLua() { defer s.Close() defer func() { @@ -473,7 +481,7 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res vs := msg.Values[1:] wr := &bytes.Buffer{} - s, err := c.cmdSearchArgs(cmd, vs, withinOrIntersectsTypes) + s, err := c.cmdSearchArgs(false, cmd, vs, withinOrIntersectsTypes) if s.usingLua() { defer s.Close() defer func() { @@ -533,11 +541,11 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res return true } return sw.writeObject(ScanWriterParams{ - id: id, - o: o, - fields: fields, - noLock: true, - clip: s.clip, + id: id, + o: o, + fields: fields, + noLock: true, + clip: s.clip, clipbox: clipbox, }) }, @@ -552,10 +560,15 @@ func (c *Controller) cmdWithinOrIntersects(cmd string, msg *server.Message) (res return sw.respOut, nil } -func (c *Controller) cmdSeachValuesArgs(vs []resp.Value) (s liveFenceSwitches, err error) { - if vs, s.searchScanBaseTokens, err = c.parseSearchScanBaseTokens("search", vs); err != nil { +func (c *Controller) cmdSeachValuesArgs(vs []resp.Value) ( + s liveFenceSwitches, err error, +) { + var t searchScanBaseTokens + vs, t, err = c.parseSearchScanBaseTokens("search", t, vs) + if err != nil { return } + s.searchScanBaseTokens = t if len(vs) != 0 { err = errInvalidNumberOfArguments return diff --git a/pkg/controller/token.go b/pkg/controller/token.go index e3269606..0549cfb3 100644 --- a/pkg/controller/token.go +++ b/pkg/controller/token.go @@ -168,15 +168,15 @@ func (wherein whereinT) match(value float64) bool { } type whereevalT struct { - c *Controller - luaState *lua.LState - fn *lua.LFunction + c *Controller + luaState *lua.LState + fn *lua.LFunction } func (whereeval whereevalT) Close() { luaSetRawGlobals( whereeval.luaState, map[string]lua.LValue{ - "ARGV": lua.LNil, + "ARGV": lua.LNil, }) whereeval.c.luapool.Put(whereeval.luaState) } @@ -189,11 +189,11 @@ func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { luaSetRawGlobals( whereeval.luaState, map[string]lua.LValue{ - "FIELDS": fieldsTbl, + "FIELDS": fieldsTbl, }) defer luaSetRawGlobals( whereeval.luaState, map[string]lua.LValue{ - "FIELDS": lua.LNil, + "FIELDS": lua.LNil, }) whereeval.luaState.Push(whereeval.fn) @@ -219,41 +219,48 @@ func (whereeval whereevalT) match(fieldsWithNames map[string]float64) bool { return true } var match bool - tbl.ForEach(func(lk lua.LValue, lv lua.LValue) {match = true}) + tbl.ForEach(func(lk lua.LValue, lv lua.LValue) { match = true }) return match } panic(fmt.Sprintf("Script returned value of type %s", ret.Type())) } type searchScanBaseTokens struct { - key string - cursor uint64 - output outputT - precision uint64 - lineout string - fence bool - distance bool - detect map[string]bool - accept map[string]bool - glob string - wheres []whereT - whereins []whereinT - whereevals []whereevalT - nofields bool - ulimit bool - limit uint64 - usparse bool - sparse uint8 - desc bool - clip bool + key string + cursor uint64 + output outputT + precision uint64 + lineout string + fence bool + distance bool + detect map[string]bool + accept map[string]bool + glob string + wheres []whereT + whereins []whereinT + whereevals []whereevalT + nofields bool + ulimit bool + limit uint64 + usparse bool + sparse uint8 + desc bool + clip bool } -func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vsout []resp.Value, t searchScanBaseTokens, err error) { +func (c *Controller) parseSearchScanBaseTokens( + cmd string, t searchScanBaseTokens, vs []resp.Value, +) ( + vsout []resp.Value, tout searchScanBaseTokens, err error, +) { var ok bool if vs, t.key, ok = tokenval(vs); !ok || t.key == "" { err = errInvalidNumberOfArguments return } + + fromFence := t.fence + var slimit string var ssparse string var scursor string @@ -261,7 +268,8 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso for { nvs, wtok, ok := tokenval(vs) if ok && len(wtok) > 0 { - if (wtok[0] == 'C' || wtok[0] == 'c') && strings.ToLower(wtok) == "cursor" { + switch strings.ToLower(wtok) { + case "cursor": vs = nvs if scursor != "" { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -272,7 +280,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso return } continue - } else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.ToLower(wtok) == "where" { + case "where": vs = nvs var field, smin, smax string if vs, field, ok = tokenval(vs); !ok || field == "" { @@ -317,7 +325,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } t.wheres = append(t.wheres, whereT{field, minx, min, maxx, max}) continue - } else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.ToLower(wtok) == "wherein" { + case "wherein": vs = nvs var field, nvalsStr, valStr string if vs, field, ok = tokenval(vs); !ok || field == "" { @@ -349,7 +357,9 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } t.whereins = append(t.whereins, whereinT{field, valMap}) continue - } else if (wtok[0] == 'W' || wtok[0] == 'w') && strings.Contains(strings.ToLower(wtok), "whereeval") { + case "whereevalsha": + fallthrough + case "whereeval": scriptIsSha := strings.ToLower(wtok) == "whereevalsha" vs = nvs var script, nargsStr, arg string @@ -392,7 +402,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso luaSetRawGlobals( luaState, map[string]lua.LValue{ - "ARGV": argsTbl, + "ARGV": argsTbl, }) compiled, ok := c.luascripts.Get(shaSum) @@ -417,9 +427,9 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } c.luascripts.Put(shaSum, fn.Proto) } - t.whereevals = append(t.whereevals, whereevalT{c,luaState, fn}) + t.whereevals = append(t.whereevals, whereevalT{c, luaState, fn}) continue - } else if (wtok[0] == 'N' || wtok[0] == 'n') && strings.ToLower(wtok) == "nofields" { + case "nofields": vs = nvs if t.nofields { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -427,7 +437,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } t.nofields = true continue - } else if (wtok[0] == 'L' || wtok[0] == 'l') && strings.ToLower(wtok) == "limit" { + case "limit": vs = nvs if slimit != "" { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -438,7 +448,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso return } continue - } else if (wtok[0] == 'S' || wtok[0] == 's') && strings.ToLower(wtok) == "sparse" { + case "sparse": vs = nvs if ssparse != "" { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -449,15 +459,15 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso return } continue - } else if (wtok[0] == 'F' || wtok[0] == 'f') && strings.ToLower(wtok) == "fence" { + case "fence": vs = nvs - if t.fence { + if t.fence && !fromFence { err = errDuplicateArgument(strings.ToUpper(wtok)) return } t.fence = true continue - } else if (wtok[0] == 'C' || wtok[0] == 'c') && strings.ToLower(wtok) == "commands" { + case "commands": vs = nvs if t.accept != nil { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -481,7 +491,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso t.accept = nil } continue - } else if (wtok[0] == 'D' || wtok[0] == 'd') && strings.ToLower(wtok) == "distance" { + case "distance": vs = nvs if t.distance { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -489,7 +499,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } t.distance = true continue - } else if (wtok[0] == 'D' || wtok[0] == 'd') && strings.ToLower(wtok) == "detect" { + case "detect": vs = nvs if t.detect != nil { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -525,7 +535,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } } continue - } else if (wtok[0] == 'D' || wtok[0] == 'd') && strings.ToLower(wtok) == "desc" { + case "desc": vs = nvs if t.desc || asc { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -533,7 +543,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } t.desc = true continue - } else if (wtok[0] == 'A' || wtok[0] == 'a') && strings.ToLower(wtok) == "asc" { + case "asc": vs = nvs if t.desc || asc { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -541,7 +551,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso } asc = true continue - } else if (wtok[0] == 'M' || wtok[0] == 'm') && strings.ToLower(wtok) == "match" { + case "match": vs = nvs if t.glob != "" { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -552,7 +562,7 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso return } continue - } else if (wtok[0] == 'C' || wtok[0] == 'c') && strings.ToLower(wtok) == "clip" { + case "clip": vs = nvs if t.clip { err = errDuplicateArgument(strings.ToUpper(wtok)) @@ -666,5 +676,6 @@ func (c *Controller) parseSearchScanBaseTokens(cmd string, vs []resp.Value) (vso t.limit = math.MaxUint64 } vsout = vs + tout = t return } diff --git a/pkg/endpoint/disque.go b/pkg/endpoint/disque.go index f7ce2ae8..131254b3 100644 --- a/pkg/endpoint/disque.go +++ b/pkg/endpoint/disque.go @@ -1,14 +1,12 @@ package endpoint import ( - "bufio" - "errors" "fmt" - "net" - "strconv" - "strings" "sync" "time" + + "github.com/garyburd/redigo/redis" + "github.com/tidwall/tile38/pkg/log" ) const ( @@ -21,8 +19,7 @@ type DisqueConn struct { ep Endpoint ex bool t time.Time - conn net.Conn - rd *bufio.Reader + conn redis.Conn } func newDisqueConn(ep Endpoint) *DisqueConn { @@ -52,7 +49,6 @@ func (conn *DisqueConn) close() { conn.conn.Close() conn.conn = nil } - conn.rd = nil } // Send sends a message @@ -66,60 +62,23 @@ func (conn *DisqueConn) Send(msg string) error { if conn.conn == nil { addr := fmt.Sprintf("%s:%d", conn.ep.Disque.Host, conn.ep.Disque.Port) var err error - conn.conn, err = net.Dial("tcp", addr) + conn.conn, err = redis.Dial("tcp", addr) if err != nil { return err } - conn.rd = bufio.NewReader(conn.conn) } - var args []string - args = append(args, "ADDJOB", conn.ep.Disque.QueueName, msg, "0") + + var args []interface{} + args = append(args, conn.ep.Disque.QueueName, msg, 0) if conn.ep.Disque.Options.Replicate > 0 { - args = append(args, "REPLICATE", strconv.FormatInt(int64(conn.ep.Disque.Options.Replicate), 10)) + args = append(args, "REPLICATE", conn.ep.Disque.Options.Replicate) } - cmd := buildRedisCommand(args) - if _, err := conn.conn.Write(cmd); err != nil { - conn.close() - return err - } - c, err := conn.rd.ReadByte() + + reply, err := redis.String(conn.conn.Do("ADDJOB", args...)) if err != nil { conn.close() return err } - if c != '-' && c != '+' { - conn.close() - return errors.New("invalid disque reply") - } - ln, err := conn.rd.ReadBytes('\n') - if err != nil { - conn.close() - return err - } - if len(ln) < 2 || ln[len(ln)-2] != '\r' { - conn.close() - return errors.New("invalid disque reply") - } - id := string(ln[:len(ln)-2]) - p := strings.Split(id, "-") - if len(p) != 4 { - conn.close() - return errors.New("invalid disque reply") - } + log.Debugf("Disque: ADDJOB '%s'", reply) return nil } - -func buildRedisCommand(args []string) []byte { - var cmd []byte - cmd = append(cmd, '*') - cmd = strconv.AppendInt(cmd, int64(len(args)), 10) - cmd = append(cmd, '\r', '\n') - for _, arg := range args { - cmd = append(cmd, '$') - cmd = strconv.AppendInt(cmd, int64(len(arg)), 10) - cmd = append(cmd, '\r', '\n') - cmd = append(cmd, arg...) - cmd = append(cmd, '\r', '\n') - } - return cmd -} diff --git a/pkg/endpoint/endpoint.go b/pkg/endpoint/endpoint.go index fb0c186c..cacf32e0 100644 --- a/pkg/endpoint/endpoint.go +++ b/pkg/endpoint/endpoint.go @@ -17,6 +17,8 @@ var errExpired = errors.New("expired") type Protocol string const ( + // Local protocol + Local = Protocol("local") // HTTP protocol HTTP = Protocol("http") // Disque protocol @@ -87,7 +89,6 @@ type Endpoint struct { CertFile string KeyFile string } - SQS struct { QueueID string Region string @@ -95,7 +96,6 @@ type Endpoint struct { CredProfile string QueueName string } - NATS struct { Host string Port int @@ -103,6 +103,9 @@ type Endpoint struct { Pass string Topic string } + Local struct { + Channel string + } } // Conn is an endpoint connection @@ -113,14 +116,16 @@ type Conn interface { // Manager manages all endpoints type Manager struct { - mu sync.RWMutex - conns map[string]Conn + mu sync.RWMutex + conns map[string]Conn + publisher LocalPublisher } // NewManager returns a new manager -func NewManager() *Manager { +func NewManager(publisher LocalPublisher) *Manager { epc := &Manager{ - conns: make(map[string]Conn), + conns: make(map[string]Conn), + publisher: publisher, } go epc.Run() return epc @@ -180,6 +185,8 @@ func (epc *Manager) Send(endpoint, msg string) error { conn = newSQSConn(ep) case NATS: conn = newNATSConn(ep) + case Local: + conn = newLocalConn(ep, epc.publisher) } epc.conns[endpoint] = conn } @@ -204,6 +211,8 @@ func parseEndpoint(s string) (Endpoint, error) { switch { default: return endpoint, errors.New("unknown scheme") + case strings.HasPrefix(s, "local:"): + endpoint.Protocol = Local case strings.HasPrefix(s, "http:"): endpoint.Protocol = HTTP case strings.HasPrefix(s, "https:"): @@ -237,9 +246,17 @@ func parseEndpoint(s string) (Endpoint, error) { sp := strings.Split(sqp[0], "/") s = sp[0] if s == "" { + if endpoint.Protocol == Local { + return endpoint, errors.New("missing channel") + } return endpoint, errors.New("missing host") } + // Local PubSub channel + // local:// + if endpoint.Protocol == Local { + endpoint.Local.Channel = s + } if endpoint.Protocol == GRPC { dp := strings.Split(s, ":") switch len(dp) { diff --git a/pkg/endpoint/local.go b/pkg/endpoint/local.go new file mode 100644 index 00000000..3f6c9077 --- /dev/null +++ b/pkg/endpoint/local.go @@ -0,0 +1,38 @@ +package endpoint + +import ( + "time" +) + +const ( + localExpiresAfter = time.Second * 30 +) + +// LocalPublisher is used to publish local notifcations +type LocalPublisher interface { + Publish(channel string, message ...string) int +} + +// LocalConn is an endpoint connection +type LocalConn struct { + ep Endpoint + publisher LocalPublisher +} + +func newLocalConn(ep Endpoint, publisher LocalPublisher) *LocalConn { + return &LocalConn{ + ep: ep, + publisher: publisher, + } +} + +// Expired returns true if the connection has expired +func (conn *LocalConn) Expired() bool { + return false +} + +// Send sends a message +func (conn *LocalConn) Send(msg string) error { + conn.publisher.Publish(conn.ep.Local.Channel, msg) + return nil +} diff --git a/pkg/endpoint/redis.go b/pkg/endpoint/redis.go index 748c9b53..fb4f83d7 100644 --- a/pkg/endpoint/redis.go +++ b/pkg/endpoint/redis.go @@ -1,12 +1,11 @@ package endpoint import ( - "bufio" - "errors" "fmt" - "net" "sync" "time" + + "github.com/garyburd/redigo/redis" ) const ( @@ -19,8 +18,7 @@ type RedisConn struct { ep Endpoint ex bool t time.Time - conn net.Conn - rd *bufio.Reader + conn redis.Conn } func newRedisConn(ep Endpoint) *RedisConn { @@ -50,7 +48,6 @@ func (conn *RedisConn) close() { conn.conn.Close() conn.conn = nil } - conn.rd = nil } // Send sends a message @@ -61,48 +58,20 @@ func (conn *RedisConn) Send(msg string) error { if conn.ex { return errExpired } - conn.t = time.Now() if conn.conn == nil { addr := fmt.Sprintf("%s:%d", conn.ep.Redis.Host, conn.ep.Redis.Port) var err error - conn.conn, err = net.Dial("tcp", addr) + conn.conn, err = redis.Dial("tcp", addr) if err != nil { + conn.close() return err } - conn.rd = bufio.NewReader(conn.conn) } - - var args []string - args = append(args, "PUBLISH", conn.ep.Redis.Channel, msg) - cmd := buildRedisCommand(args) - - if _, err := conn.conn.Write(cmd); err != nil { - conn.close() - return err - } - - c, err := conn.rd.ReadByte() + _, err := redis.Int(conn.conn.Do("PUBLISH", conn.ep.Redis.Channel, msg)) if err != nil { conn.close() return err } - - if c != ':' { - conn.close() - return errors.New("invalid redis reply") - } - - ln, err := conn.rd.ReadBytes('\n') - if err != nil { - conn.close() - return err - } - - if string(ln[0:1]) != "1" { - conn.close() - return errors.New("invalid redis reply") - } - return nil } diff --git a/pkg/server/reader.go b/pkg/server/reader.go index f110450e..c9ba2a3a 100644 --- a/pkg/server/reader.go +++ b/pkg/server/reader.go @@ -145,7 +145,6 @@ func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Wri accept := base64.StdEncoding.EncodeToString(sum[:]) wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" if _, err = wr.Write([]byte(wshead)); err != nil { - println(4) return false, err } } else if contentLength > 0 {