diff --git a/controller/aof.go b/controller/aof.go index 46eb2308..084a35d6 100644 --- a/controller/aof.go +++ b/controller/aof.go @@ -359,7 +359,7 @@ func (c *Controller) cmdAOF(msg *server.Message) (res string, err error) { return "", s } -func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *server.AnyReaderWriter, msg *server.Message) error { +func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *server.PipelineReader, msg *server.Message) error { c.mu.Lock() c.aofconnM[conn] = true c.mu.Unlock() @@ -394,19 +394,21 @@ func (c *Controller) liveAOF(pos int64, conn net.Conn, rd *server.AnyReaderWrite cond.L.Unlock() }() for { - v, err := rd.ReadMessage() + vs, err := rd.ReadMessages() if err != nil { if err != io.EOF { log.Error(err) } return } - switch v.Command { - default: - log.Error("received a live command that was not QUIT") - return - case "quit", "": - return + for _, v := range vs { + switch v.Command { + default: + log.Error("received a live command that was not QUIT") + return + case "quit", "": + return + } } } }() diff --git a/controller/controller.go b/controller/controller.go index b9c5fc94..2913feb9 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -194,7 +194,7 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http c.stopWatchingMemory.set(true) c.stopWatchingAutoGC.set(true) }() - handler := func(conn *server.Conn, msg *server.Message, rd *server.AnyReaderWriter, w io.Writer, websocket bool) error { + handler := func(conn *server.Conn, msg *server.Message, rd *server.PipelineReader, w io.Writer, websocket bool) error { c.connsmu.RLock() if cc, ok := c.conns[conn]; ok { cc.last.set(time.Now()) diff --git a/controller/live.go b/controller/live.go index 3c2c088d..38c157e9 100644 --- a/controller/live.go +++ b/controller/live.go @@ -64,7 +64,7 @@ func writeMessage(conn net.Conn, message []byte, wrapRESP bool, connType server. return err } -func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWriter, msg *server.Message, websocket bool) error { +func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.PipelineReader, msg *server.Message, websocket bool) error { addr := conn.RemoteAddr().String() log.Info("live " + addr) defer func() { @@ -114,22 +114,24 @@ func (c *Controller) goLive(inerr error, conn net.Conn, rd *server.AnyReaderWrit conn.Close() }() for { - v, err := rd.ReadMessage() + vs, err := rd.ReadMessages() if err != nil { if err != io.EOF && !(websocket && err == io.ErrUnexpectedEOF) { log.Error(err) } return } - if v == nil { - continue - } - switch v.Command { - default: - log.Error("received a live command that was not QUIT") - return - case "quit", "": - return + for _, v := range vs { + if v == nil { + continue + } + switch v.Command { + default: + log.Error("received a live command that was not QUIT") + return + case "quit", "": + return + } } } }() diff --git a/controller/server/anyreader.go b/controller/server/anyreader.go deleted file mode 100644 index d73625e2..00000000 --- a/controller/server/anyreader.go +++ /dev/null @@ -1,334 +0,0 @@ -package server - -import ( - "bufio" - "crypto/sha1" - "encoding/base64" - "errors" - "io" - "net/url" - "strconv" - "strings" - - "github.com/tidwall/resp" -) - -const telnetIsJSON = false - -// Type is resp type -type Type int - -const ( - Null Type = iota - RESP - Telnet - Native - HTTP - WebSocket - JSON -) - -// String return a string for type. -func (t Type) String() string { - switch t { - default: - return "Unknown" - case Null: - return "Null" - case RESP: - return "RESP" - case Telnet: - return "Telnet" - case Native: - return "Native" - case HTTP: - return "HTTP" - case WebSocket: - return "WebSocket" - case JSON: - return "JSON" - } -} - -type errRESPProtocolError struct { - msg string -} - -func (err errRESPProtocolError) Error() string { - return "Protocol error: " + err.msg -} - -// Message is a resp message -type Message struct { - Command string - Values []resp.Value - ConnType Type - OutputType Type - Auth string -} - -// AnyReaderWriter is resp or native reader writer. -type AnyReaderWriter struct { - rd *bufio.Reader - wr io.Writer - ws bool -} - -// NewAnyReaderWriter returns an AnyReaderWriter object. -func NewAnyReaderWriter(rd io.Reader) *AnyReaderWriter { - ar := &AnyReaderWriter{} - if rd2, ok := rd.(*bufio.Reader); ok { - ar.rd = rd2 - } else { - ar.rd = bufio.NewReader(rd) - } - if wr, ok := rd.(io.Writer); ok { - ar.wr = wr - } - return ar -} - -func (ar *AnyReaderWriter) peekcrlfline() (string, error) { - // this is slow operation. - for i := 0; ; i++ { - bb, err := ar.rd.Peek(i) - if err != nil { - return "", err - } - if len(bb) > 2 && bb[len(bb)-2] == '\r' && bb[len(bb)-1] == '\n' { - return string(bb[:len(bb)-2]), nil - } - } -} - -func (ar *AnyReaderWriter) readcrlfline() (string, error) { - var line []byte - for { - bb, err := ar.rd.ReadBytes('\r') - if err != nil { - return "", err - } - if line == nil { - line = bb - } else { - line = append(line, bb...) - } - b, err := ar.rd.ReadByte() - if err != nil { - return "", err - } - if b == '\n' { - return string(line[:len(line)-1]), nil - } - line = append(line, b) - } -} - -// ReadMessage reads the next resp message. -func (ar *AnyReaderWriter) ReadMessage() (*Message, error) { - b, err := ar.rd.ReadByte() - if err != nil { - return nil, err - } - if err := ar.rd.UnreadByte(); err != nil { - return nil, err - } - switch b { - case 'G', 'P': - line, err := ar.peekcrlfline() - if err != nil { - return nil, err - } - if len(line) > 9 && line[len(line)-9:len(line)-3] == " HTTP/" { - return ar.readHTTPMessage() - } - case '$': - return ar.readNativeMessage() - } - // MultiBulk also reads telnet - return ar.readMultiBulkMessage() -} - -func readNativeMessageLine(line []byte) (*Message, error) { - values := make([]resp.Value, 0, 16) -reading: - for len(line) != 0 { - if line[0] == '{' { - // The native protocol cannot understand json boundaries so it assumes that - // a json element must be at the end of the line. - values = append(values, resp.StringValue(string(line))) - break - } - if line[0] == '"' && line[len(line)-1] == '"' { - if len(values) > 0 && - strings.ToLower(values[0].String()) == "set" && - strings.ToLower(values[len(values)-1].String()) == "string" { - // Setting a string value that is contained inside double quotes. - // This is only because of the boundary issues of the native protocol. - values = append(values, resp.StringValue(string(line[1:len(line)-1]))) - break - } - } - i := 0 - for ; i < len(line); i++ { - if line[i] == ' ' { - value := string(line[:i]) - if value != "" { - values = append(values, resp.StringValue(value)) - } - line = line[i+1:] - continue reading - } - } - values = append(values, resp.StringValue(string(line))) - break - } - return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil -} - -func (ar *AnyReaderWriter) readNativeMessage() (*Message, error) { - b, err := ar.rd.ReadBytes(' ') - if err != nil { - return nil, err - } - if len(b) > 0 && b[0] != '$' { - return nil, errors.New("invalid message") - } - n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32) - if err != nil { - return nil, errors.New("invalid size") - } - if n > 0x1FFFFFFF { // 536,870,911 bytes - return nil, errors.New("message too big") - } - b = make([]byte, int(n)+2) - if _, err := io.ReadFull(ar.rd, b); err != nil { - return nil, err - } - if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { - return nil, errors.New("expecting crlf") - } - - return readNativeMessageLine(b[:len(b)-2]) -} - -func commandValues(values []resp.Value) string { - if len(values) == 0 { - return "" - } - return strings.ToLower(values[0].String()) -} - -func (ar *AnyReaderWriter) readMultiBulkMessage() (*Message, error) { - rd := resp.NewReader(ar.rd) - v, telnet, _, err := rd.ReadMultiBulk() - if err != nil { - return nil, err - } - values := v.Array() - if len(values) == 0 { - return nil, nil - } - if telnet && telnetIsJSON { - return &Message{Command: commandValues(values), Values: values, ConnType: Telnet, OutputType: JSON}, nil - } - return &Message{Command: commandValues(values), Values: values, ConnType: RESP, OutputType: RESP}, nil - -} - -func (ar *AnyReaderWriter) readHTTPMessage() (*Message, error) { - msg := &Message{ConnType: HTTP, OutputType: JSON} - line, err := ar.readcrlfline() - if err != nil { - return nil, err - } - parts := strings.Split(line, " ") - if len(parts) != 3 { - return nil, errors.New("invalid HTTP request") - } - method := parts[0] - path := parts[1] - if len(path) == 0 || path[0] != '/' { - return nil, errors.New("invalid HTTP request") - } - path, err = url.QueryUnescape(path[1:]) - if err != nil { - return nil, errors.New("invalid HTTP request") - } - if method != "GET" && method != "POST" { - return nil, errors.New("invalid HTTP method") - } - contentLength := 0 - websocket := false - websocketVersion := 0 - websocketKey := "" - for { - header, err := ar.readcrlfline() - if err != nil { - return nil, err - } - if header == "" { - break // end of headers - } - if header[0] == 'a' || header[0] == 'A' { - if strings.HasPrefix(strings.ToLower(header), "authorization:") { - msg.Auth = strings.TrimSpace(header[len("authorization:"):]) - } - } else if header[0] == 'u' || header[0] == 'U' { - if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { - websocket = true - } - } else if header[0] == 's' || header[0] == 'S' { - if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { - var n uint64 - n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) - if err != nil { - return nil, err - } - websocketVersion = int(n) - } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { - websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):]) - } - } else if header[0] == 'c' || header[0] == 'C' { - if strings.HasPrefix(strings.ToLower(header), "content-length:") { - var n uint64 - n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) - if err != nil { - return nil, err - } - contentLength = int(n) - } - } - } - if websocket && websocketVersion >= 13 && websocketKey != "" { - msg.ConnType = WebSocket - if ar.wr == nil { - return nil, errors.New("connection is nil") - } - sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) - accept := base64.StdEncoding.EncodeToString(sum[:]) - wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" - if _, err = ar.wr.Write([]byte(wshead)); err != nil { - return nil, err - } - ar.ws = true - } else if contentLength > 0 { - msg.ConnType = HTTP - buf := make([]byte, contentLength) - if _, err = io.ReadFull(ar.rd, buf); err != nil { - return nil, err - } - path += string(buf) - } - if path == "" { - return msg, nil - } - nmsg, err := readNativeMessageLine([]byte(path)) - if err != nil { - return nil, err - } - msg.OutputType = JSON - msg.Values = nmsg.Values - msg.Command = commandValues(nmsg.Values) - return msg, nil -} diff --git a/controller/server/reader.go b/controller/server/reader.go new file mode 100644 index 00000000..aac3f998 --- /dev/null +++ b/controller/server/reader.go @@ -0,0 +1,305 @@ +package server + +import ( + "crypto/sha1" + "encoding/base64" + "errors" + "io" + "net/url" + "strconv" + "strings" + + "github.com/tidwall/redcon" + "github.com/tidwall/resp" +) + +var errInvalidHTTP = errors.New("invalid HTTP request") + +// Type is resp type +type Type int + +const ( + Null Type = iota + RESP + Telnet + Native + HTTP + WebSocket + JSON +) + +// Message is a resp message +type Message struct { + Command string + Values []resp.Value + ConnType Type + OutputType Type + Auth string +} + +// PipelineReader ... +type PipelineReader struct { + rd io.Reader + wr io.Writer + pbuf [0xFFFF]byte + rbuf []byte +} + +const kindHTTP redcon.Kind = 9999 + +// NewPipelineReader ... +func NewPipelineReader(rd io.ReadWriter) *PipelineReader { + return &PipelineReader{rd: rd, wr: rd} +} + +func readcrlfline(packet []byte) (line string, leftover []byte, ok bool) { + for i := 1; i < len(packet); i++ { + if packet[i] == '\n' && packet[i-1] == '\r' { + return string(packet[:i-1]), packet[i+1:], true + } + } + return "", packet, false +} + +func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) ( + complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error, +) { + args = argsIn[:0] + msg.ConnType = HTTP + msg.OutputType = JSON + opacket := packet + + ready, err := func() (bool, error) { + var line string + var ok bool + + // read header + var headers []string + for { + line, packet, ok = readcrlfline(packet) + if !ok { + return false, nil + } + if line == "" { + break + } + headers = append(headers, line) + } + parts := strings.Split(headers[0], " ") + if len(parts) != 3 { + return false, errInvalidHTTP + } + method := parts[0] + path := parts[1] + if len(path) == 0 || path[0] != '/' { + return false, errInvalidHTTP + } + path, err = url.QueryUnescape(path[1:]) + if err != nil { + return false, errInvalidHTTP + } + if method != "GET" && method != "POST" { + return false, errInvalidHTTP + } + contentLength := 0 + websocket := false + websocketVersion := 0 + websocketKey := "" + for _, header := range headers[1:] { + if header[0] == 'a' || header[0] == 'A' { + if strings.HasPrefix(strings.ToLower(header), "authorization:") { + msg.Auth = strings.TrimSpace(header[len("authorization:"):]) + } + } else if header[0] == 'u' || header[0] == 'U' { + if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { + websocket = true + } + } else if header[0] == 's' || header[0] == 'S' { + if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) + if err != nil { + return false, err + } + websocketVersion = int(n) + } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { + websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):]) + } + } else if header[0] == 'c' || header[0] == 'C' { + if strings.HasPrefix(strings.ToLower(header), "content-length:") { + var n uint64 + n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) + if err != nil { + return false, err + } + contentLength = int(n) + } + } + } + if websocket && websocketVersion >= 13 && websocketKey != "" { + msg.ConnType = WebSocket + if wr == nil { + return false, errors.New("connection is nil") + } + sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + accept := base64.StdEncoding.EncodeToString(sum[:]) + wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" + if _, err = wr.Write([]byte(wshead)); err != nil { + println(4) + return false, err + } + } else if contentLength > 0 { + msg.ConnType = HTTP + if len(packet) < contentLength { + return false, nil + } + path += string(packet[:contentLength]) + packet = packet[contentLength:] + } + if path == "" { + return true, nil + } + nmsg, err := readNativeMessageLine([]byte(path)) + if err != nil { + return false, err + } + + msg.OutputType = JSON + msg.Values = nmsg.Values + msg.Command = commandValues(nmsg.Values) + return true, nil + }() + if err != nil || !ready { + return false, args[:0], kindHTTP, opacket, err + } + return true, args[:0], kindHTTP, packet, nil +} +func readNextCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) ( + complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error, +) { + if packet[0] == 'G' || packet[0] == 'P' { + // could be an HTTP request + var line []byte + for i := 1; i < len(packet); i++ { + if packet[i] == '\n' { + if packet[i-1] == '\r' { + line = packet[:i+1] + break + } + } + } + if len(line) == 0 { + return false, argsIn[:0], redcon.Redis, packet, nil + } + if len(line) > 11 && string(line[len(line)-11:len(line)-5]) == " HTTP/" { + return readNextHTTPCommand(packet, argsIn, msg, wr) + } + } + return redcon.ReadNextCommand(packet, args) +} + +// ReadMessages ... +func (rd *PipelineReader) ReadMessages() ([]*Message, error) { + var msgs []*Message +moreData: + n, err := rd.rd.Read(rd.pbuf[:]) + if err != nil { + return nil, err + } + if n == 0 { + // need more data + goto moreData + } + var packet []byte + if len(rd.rbuf) == 0 { + packet = rd.pbuf[:n] + } else { + rd.rbuf = append(rd.rbuf, rd.pbuf[:n]...) + packet = rd.rbuf + } + + for len(packet) > 0 { + msg := &Message{} + complete, args, kind, leftover, err := readNextCommand(packet, nil, msg, rd.wr) + if err != nil { + break + } + if !complete { + break + } + if kind != kindHTTP { + msg.Command = strings.ToLower(string(args[0])) + for i := 0; i < len(args); i++ { + msg.Values = append(msg.Values, resp.BytesValue(args[i])) + } + switch kind { + case redcon.Redis: + msg.ConnType = RESP + msg.OutputType = RESP + case redcon.Tile38: + msg.ConnType = Native + msg.OutputType = JSON + case redcon.Telnet: + msg.ConnType = RESP + msg.OutputType = RESP + } + } else if len(msg.Values) == 0 { + return nil, errInvalidHTTP + } + msgs = append(msgs, msg) + packet = leftover + } + if len(packet) > 0 { + rd.rbuf = append(rd.rbuf[:0], packet...) + } else if rd.rbuf != nil { + rd.rbuf = rd.rbuf[:0] + } + if err != nil && len(msgs) == 0 { + return nil, err + } + return msgs, nil +} + +func readNativeMessageLine(line []byte) (*Message, error) { + values := make([]resp.Value, 0, 16) +reading: + for len(line) != 0 { + if line[0] == '{' { + // The native protocol cannot understand json boundaries so it assumes that + // a json element must be at the end of the line. + values = append(values, resp.StringValue(string(line))) + break + } + if line[0] == '"' && line[len(line)-1] == '"' { + if len(values) > 0 && + strings.ToLower(values[0].String()) == "set" && + strings.ToLower(values[len(values)-1].String()) == "string" { + // Setting a string value that is contained inside double quotes. + // This is only because of the boundary issues of the native protocol. + values = append(values, resp.StringValue(string(line[1:len(line)-1]))) + break + } + } + i := 0 + for ; i < len(line); i++ { + if line[i] == ' ' { + value := string(line[:i]) + if value != "" { + values = append(values, resp.StringValue(value)) + } + line = line[i+1:] + continue reading + } + } + values = append(values, resp.StringValue(string(line))) + break + } + return &Message{Command: commandValues(values), Values: values, ConnType: Native, OutputType: JSON}, nil +} + +func commandValues(values []resp.Value) string { + if len(values) == 0 { + return "" + } + return strings.ToLower(values[0].String()) +} diff --git a/controller/server/server.go b/controller/server/server.go index fff1bf6b..4500b45e 100644 --- a/controller/server/server.go +++ b/controller/server/server.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -45,6 +46,7 @@ type Conn struct { Authenticated bool } +// SetKeepAlive sets the connection keepalive func (conn Conn) SetKeepAlive(period time.Duration) error { if tcp, ok := conn.Conn.(*net.TCPConn); ok { if err := tcp.SetKeepAlive(true); err != nil { @@ -61,7 +63,7 @@ var errCloseHTTP = errors.New("close http") func ListenAndServe( host string, port int, protected func() bool, - handler func(conn *Conn, msg *Message, rd *AnyReaderWriter, w io.Writer, websocket bool) error, + handler func(conn *Conn, msg *Message, rd *PipelineReader, w io.Writer, websocket bool) error, opened func(conn *Conn), closed func(conn *Conn), lnp *net.Listener, @@ -85,87 +87,90 @@ func ListenAndServe( } } -// func writeCommandErr(proto client.Proto, conn *Conn, err error) error { -// if proto == client.HTTP || proto == client.WebSocket { -// conn.Write([]byte(`HTTP/1.1 500 ` + err.Error() + "\r\nConnection: close\r\n\r\n")) -// } -// return err -// } - func handleConn( conn *Conn, protected func() bool, - handler func(conn *Conn, msg *Message, rd *AnyReaderWriter, w io.Writer, websocket bool) error, + handler func(conn *Conn, msg *Message, rd *PipelineReader, w io.Writer, websocket bool) error, opened func(conn *Conn), closed func(conn *Conn), http bool, ) { - opened(conn) - defer closed(conn) addr := conn.RemoteAddr().String() + opened(conn) if core.ShowDebugMessages { log.Debugf("opened connection: %s", addr) - defer func() { - log.Debugf("closed connection: %s", addr) - }() } + defer func() { + conn.Close() + closed(conn) + if core.ShowDebugMessages { + log.Debugf("closed connection: %s", addr) + } + }() if !strings.HasPrefix(addr, "127.0.0.1:") && !strings.HasPrefix(addr, "[::1]:") { if protected() { // This is a protected server. Only loopback is allowed. conn.Write(deniedMessage) - conn.Close() return } } - defer conn.Close() + + wr := &bytes.Buffer{} outputType := Null - rd := NewAnyReaderWriter(conn) + rd := NewPipelineReader(conn) for { - msg, err := rd.ReadMessage() - - // Just closing connection if we have deprecated HTTP or WS connection, - // And --http-transport = false - if !http && (msg.ConnType == WebSocket || msg.ConnType == HTTP) { - conn.Close() - return - } - - if err != nil { - if err == io.EOF { - return - } - if err == errCloseHTTP || - strings.Contains(err.Error(), "use of closed network connection") { - return - } - log.Error(err) - return - } - if msg != nil && msg.Command != "" { - if outputType != Null { - msg.OutputType = outputType - } - if msg.Command == "quit" { - if msg.OutputType == RESP { - io.WriteString(conn, "+OK\r\n") - } - return - } - err := handler(conn, msg, rd, conn, msg.ConnType == WebSocket) + wr.Reset() + ok := func() bool { + msgs, err := rd.ReadMessages() if err != nil { + if err == io.EOF { + return false + } + if err == errCloseHTTP || + strings.Contains(err.Error(), "use of closed network connection") { + return false + } log.Error(err) - return + return false } - outputType = msg.OutputType - } else { - conn.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n")) - return + for _, msg := range msgs { + // Just closing connection if we have deprecated HTTP or WS connection, + // And --http-transport = false + if !http && (msg.ConnType == WebSocket || msg.ConnType == HTTP) { + return false + } + if msg != nil && msg.Command != "" { + if outputType != Null { + msg.OutputType = outputType + } + if msg.Command == "quit" { + if msg.OutputType == RESP { + io.WriteString(wr, "+OK\r\n") + } + return false + } + err := handler(conn, msg, rd, wr, msg.ConnType == WebSocket) + if err != nil { + log.Error(err) + return false + } + outputType = msg.OutputType + } else { + wr.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n")) + return false + } + if msg.ConnType == HTTP || msg.ConnType == WebSocket { + return false + } + } + return true + }() + conn.Write(wr.Bytes()) + if !ok { + break } - if msg.ConnType == HTTP || msg.ConnType == WebSocket { - return - } - } + // all done } // WriteWebSocketMessage write a websocket message to an io.Writer. diff --git a/vendor/github.com/tidwall/redcon/LICENSE b/vendor/github.com/tidwall/redcon/LICENSE new file mode 100644 index 00000000..58f5819a --- /dev/null +++ b/vendor/github.com/tidwall/redcon/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2016 Josh Baker + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/tidwall/redcon/README.md b/vendor/github.com/tidwall/redcon/README.md new file mode 100644 index 00000000..fc5d8b3e --- /dev/null +++ b/vendor/github.com/tidwall/redcon/README.md @@ -0,0 +1,182 @@ +

+REDCON +
+Build Status +GoDoc +

+ +

Fast Redis compatible server framework for Go

+ +Redcon is a custom Redis server framework for Go that is fast and simple to use. The reason for this library it to give an efficient server front-end for the [BuntDB](https://github.com/tidwall/buntdb) and [Tile38](https://github.com/tidwall/tile38) projects. + +Features +-------- +- Create a [Fast](#benchmarks) custom Redis compatible server in Go +- Simple interface. One function `ListenAndServe` and two types `Conn` & `Command` +- Support for pipelining and telnet commands +- Works with Redis clients such as [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) +- [TLS Support](#tls-example) + +Installing +---------- + +``` +go get -u github.com/tidwall/redcon +``` + +Example +------- + +Here's a full example of a Redis clone that accepts: + +- SET key value +- GET key +- DEL key +- PING +- QUIT + +You can run this example from a terminal: + +```sh +go run example/clone.go +``` + +```go +package main + +import ( + "log" + "strings" + "sync" + + "github.com/tidwall/redcon" +) + +var addr = ":6380" + +func main() { + var mu sync.RWMutex + var items = make(map[string][]byte) + go log.Printf("started server at %s", addr) + err := redcon.ListenAndServe(addr, + func(conn redcon.Conn, cmd redcon.Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + items[string(cmd.Args[1])] = cmd.Args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.RLock() + val, ok := items[string(cmd.Args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulk(val) + } + case "del": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + _, ok := items[string(cmd.Args[1])] + delete(items, string(cmd.Args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) + } + } + }, + func(conn redcon.Conn) bool { + // use this function to accept or deny the connection. + // log.Printf("accept: %s", conn.RemoteAddr()) + return true + }, + func(conn redcon.Conn, err error) { + // this is called when the connection has been closed + // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) + }, + ) + if err != nil { + log.Fatal(err) + } +} +``` + +TLS Example +----------- + +Redcon has full TLS support through the `ListenAndServeTLS` function. + +The [same example](example/tls/clone.go) is also provided for serving Redcon over TLS. + +```sh +go run example/tls/clone.go +``` + +Benchmarks +---------- + +**Redis**: Single-threaded, no disk persistence. + +``` +$ redis-server --port 6379 --appendonly no +``` +``` +redis-benchmark -p 6379 -t set,get -n 10000000 -q -P 512 -c 512 +SET: 941265.12 requests per second +GET: 1189909.50 requests per second +``` + +**Redcon**: Single-threaded, no disk persistence. + +``` +$ GOMAXPROCS=1 go run example/clone.go +``` +``` +redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 +SET: 2018570.88 requests per second +GET: 2403846.25 requests per second +``` + +**Redcon**: Multi-threaded, no disk persistence. + +``` +$ GOMAXPROCS=0 go run example/clone.go +``` +``` +$ redis-benchmark -p 6380 -t set,get -n 10000000 -q -P 512 -c 512 +SET: 1944390.38 requests per second +GET: 3993610.25 requests per second +``` + +*Running on a MacBook Pro 15" 2.8 GHz Intel Core i7 using Go 1.7* + +Contact +------- +Josh Baker [@tidwall](http://twitter.com/tidwall) + +License +------- +Redcon source code is available under the MIT [License](/LICENSE). diff --git a/vendor/github.com/tidwall/redcon/append.go b/vendor/github.com/tidwall/redcon/append.go new file mode 100644 index 00000000..2c4ea717 --- /dev/null +++ b/vendor/github.com/tidwall/redcon/append.go @@ -0,0 +1,312 @@ +package redcon + +import ( + "strconv" + "strings" +) + +// Kind is the kind of command +type Kind int + +const ( + // Redis is returned for Redis protocol commands + Redis Kind = iota + // Tile38 is returnd for Tile38 native protocol commands + Tile38 + // Telnet is returnd for plain telnet commands + Telnet +) + +var errInvalidMessage = &errProtocol{"invalid message"} + +// ReadNextCommand reads the next command from the provided packet. It's +// possible that the packet contains multiple commands, or zero commands +// when the packet is incomplete. +// 'argsbuf' is an optional reusable buffer and it can be nil. +// 'complete' indicates that a command was read. false means no more commands. +// 'args' are the output arguments for the command. +// 'kind' is the type of command that was read. +// 'leftover' is any remaining unused bytes which belong to the next command. +// 'err' is returned when a protocol error was encountered. +func ReadNextCommand(packet []byte, argsbuf [][]byte) ( + complete bool, args [][]byte, kind Kind, leftover []byte, err error, +) { + args = argsbuf[:0] + if len(packet) > 0 { + if packet[0] != '*' { + if packet[0] == '$' { + return readTile38Command(packet, args) + } + return readTelnetCommand(packet, args) + } + // standard redis command + for s, i := 1, 1; i < len(packet); i++ { + if packet[i] == '\n' { + if packet[i-1] != '\r' { + return false, args[:0], Redis, packet, errInvalidMultiBulkLength + } + count, ok := parseInt(packet[s : i-1]) + if !ok || count < 0 { + return false, args[:0], Redis, packet, errInvalidMultiBulkLength + } + i++ + if count == 0 { + return true, args[:0], Redis, packet[i:], nil + } + nextArg: + for j := 0; j < count; j++ { + if i == len(packet) { + break + } + if packet[i] != '$' { + return false, args[:0], Redis, packet, + &errProtocol{"expected '$', got '" + + string(packet[i]) + "'"} + } + for s := i + 1; i < len(packet); i++ { + if packet[i] == '\n' { + if packet[i-1] != '\r' { + return false, args[:0], Redis, packet, errInvalidBulkLength + } + n, ok := parseInt(packet[s : i-1]) + if !ok || count <= 0 { + return false, args[:0], Redis, packet, errInvalidBulkLength + } + i++ + if len(packet)-i >= n+2 { + if packet[i+n] != '\r' || packet[i+n+1] != '\n' { + return false, args[:0], Redis, packet, errInvalidBulkLength + } + args = append(args, packet[i:i+n]) + i += n + 2 + if j == count-1 { + // done reading + return true, args, Redis, packet[i:], nil + } + continue nextArg + } + break + } + } + break + } + break + } + } + } + return false, args[:0], Redis, packet, nil +} + +func readTile38Command(packet []byte, argsbuf [][]byte) ( + complete bool, args [][]byte, kind Kind, leftover []byte, err error, +) { + for i := 1; i < len(packet); i++ { + if packet[i] == ' ' { + n, ok := parseInt(packet[1:i]) + if !ok || n < 0 { + return false, args[:0], Tile38, packet, errInvalidMessage + } + i++ + if len(packet) >= i+n+2 { + if packet[i+n] != '\r' || packet[i+n+1] != '\n' { + return false, args[:0], Tile38, packet, errInvalidMessage + } + line := packet[i : i+n] + reading: + for len(line) != 0 { + if line[0] == '{' { + // The native protocol cannot understand json boundaries so it assumes that + // a json element must be at the end of the line. + args = append(args, line) + break + } + if line[0] == '"' && line[len(line)-1] == '"' { + if len(args) > 0 && + strings.ToLower(string(args[0])) == "set" && + strings.ToLower(string(args[len(args)-1])) == "string" { + // Setting a string value that is contained inside double quotes. + // This is only because of the boundary issues of the native protocol. + args = append(args, line[1:len(line)-1]) + break + } + } + i := 0 + for ; i < len(line); i++ { + if line[i] == ' ' { + value := line[:i] + if len(value) > 0 { + args = append(args, value) + } + line = line[i+1:] + continue reading + } + } + args = append(args, line) + break + } + return true, args, Tile38, packet[i+n+2:], nil + } + break + } + } + return false, args[:0], Tile38, packet, nil +} +func readTelnetCommand(packet []byte, argsbuf [][]byte) ( + complete bool, args [][]byte, kind Kind, leftover []byte, err error, +) { + // just a plain text command + for i := 0; i < len(packet); i++ { + if packet[i] == '\n' { + var line []byte + if i > 0 && packet[i-1] == '\r' { + line = packet[:i-1] + } else { + line = packet[:i] + } + var quote bool + var quotech byte + var escape bool + outer: + for { + nline := make([]byte, 0, len(line)) + for i := 0; i < len(line); i++ { + c := line[i] + if !quote { + if c == ' ' { + if len(nline) > 0 { + args = append(args, nline) + } + line = line[i+1:] + continue outer + } + if c == '"' || c == '\'' { + if i != 0 { + return false, args[:0], Telnet, packet, errUnbalancedQuotes + } + quotech = c + quote = true + line = line[i+1:] + continue outer + } + } else { + if escape { + escape = false + switch c { + case 'n': + c = '\n' + case 'r': + c = '\r' + case 't': + c = '\t' + } + } else if c == quotech { + quote = false + quotech = 0 + args = append(args, nline) + line = line[i+1:] + if len(line) > 0 && line[0] != ' ' { + return false, args[:0], Telnet, packet, errUnbalancedQuotes + } + continue outer + } else if c == '\\' { + escape = true + continue + } + } + nline = append(nline, c) + } + if quote { + return false, args[:0], Telnet, packet, errUnbalancedQuotes + } + if len(line) > 0 { + args = append(args, line) + } + break + } + return true, args, Telnet, packet[i+1:], nil + } + } + return false, args[:0], Telnet, packet, nil +} + +// AppendUint appends a Redis protocol uint64 to the input bytes. +func AppendUint(b []byte, n uint64) []byte { + b = append(b, ':') + b = strconv.AppendUint(b, n, 10) + return append(b, '\r', '\n') +} + +// AppendInt appends a Redis protocol int64 to the input bytes. +func AppendInt(b []byte, n int64) []byte { + b = append(b, ':') + b = strconv.AppendInt(b, n, 10) + return append(b, '\r', '\n') +} + +// AppendArray appends a Redis protocol array to the input bytes. +func AppendArray(b []byte, n int) []byte { + b = append(b, '*') + b = strconv.AppendInt(b, int64(n), 10) + return append(b, '\r', '\n') +} + +// AppendBulk appends a Redis protocol bulk byte slice to the input bytes. +func AppendBulk(b []byte, bulk []byte) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(bulk)), 10) + b = append(b, '\r', '\n') + b = append(b, bulk...) + return append(b, '\r', '\n') +} + +// AppendBulkString appends a Redis protocol bulk string to the input bytes. +func AppendBulkString(b []byte, bulk string) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(bulk)), 10) + b = append(b, '\r', '\n') + b = append(b, bulk...) + return append(b, '\r', '\n') +} + +// AppendString appends a Redis protocol string to the input bytes. +func AppendString(b []byte, s string) []byte { + b = append(b, '+') + b = append(b, stripNewlines(s)...) + return append(b, '\r', '\n') +} + +// AppendError appends a Redis protocol error to the input bytes. +func AppendError(b []byte, s string) []byte { + b = append(b, '-') + b = append(b, stripNewlines(s)...) + return append(b, '\r', '\n') +} + +// AppendOK appends a Redis protocol OK to the input bytes. +func AppendOK(b []byte) []byte { + return append(b, '+', 'O', 'K', '\r', '\n') +} +func stripNewlines(s string) string { + for i := 0; i < len(s); i++ { + if s[i] == '\r' || s[i] == '\n' { + s = strings.Replace(s, "\r", " ", -1) + s = strings.Replace(s, "\n", " ", -1) + break + } + } + return s +} + +// AppendTile38 appends a Tile38 message to the input bytes. +func AppendTile38(b []byte, data []byte) []byte { + b = append(b, '$') + b = strconv.AppendInt(b, int64(len(data)), 10) + b = append(b, ' ') + b = append(b, data...) + return append(b, '\r', '\n') +} + +// AppendNull appends a Redis protocol null to the input bytes. +func AppendNull(b []byte) []byte { + return append(b, '$', '-', '1', '\r', '\n') +} diff --git a/vendor/github.com/tidwall/redcon/append_test.go b/vendor/github.com/tidwall/redcon/append_test.go new file mode 100644 index 00000000..238b1e48 --- /dev/null +++ b/vendor/github.com/tidwall/redcon/append_test.go @@ -0,0 +1,94 @@ +package redcon + +import ( + "bytes" + "math/rand" + "testing" + "time" +) + +func TestNextCommand(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + start := time.Now() + for time.Since(start) < time.Second { + // keep copy of pipeline args for final compare + var plargs [][][]byte + + // create a pipeline of random number of commands with random data. + N := rand.Int() % 10000 + var data []byte + for i := 0; i < N; i++ { + nargs := rand.Int() % 10 + data = AppendArray(data, nargs) + var args [][]byte + for j := 0; j < nargs; j++ { + arg := make([]byte, rand.Int()%100) + if _, err := rand.Read(arg); err != nil { + t.Fatal(err) + } + data = AppendBulk(data, arg) + args = append(args, arg) + } + plargs = append(plargs, args) + } + + // break data into random number of chunks + chunkn := rand.Int() % 100 + if chunkn == 0 { + chunkn = 1 + } + if len(data) < chunkn { + continue + } + var chunks [][]byte + var chunksz int + for i := 0; i < len(data); i += chunksz { + chunksz = rand.Int() % (len(data) / chunkn) + var chunk []byte + if i+chunksz < len(data) { + chunk = data[i : i+chunksz] + } else { + chunk = data[i:] + } + chunks = append(chunks, chunk) + } + + // process chunks + var rbuf []byte + var fargs [][][]byte + for _, chunk := range chunks { + var data []byte + if len(rbuf) > 0 { + data = append(rbuf, chunk...) + } else { + data = chunk + } + for { + complete, args, _, leftover, err := ReadNextCommand(data, nil) + data = leftover + if err != nil { + t.Fatal(err) + } + if !complete { + break + } + fargs = append(fargs, args) + } + rbuf = append(rbuf[:0], data...) + } + // compare final args to original + if len(plargs) != len(fargs) { + t.Fatalf("not equal size: %v != %v", len(plargs), len(fargs)) + } + for i := 0; i < len(plargs); i++ { + if len(plargs[i]) != len(fargs[i]) { + t.Fatalf("not equal size for item %v: %v != %v", i, len(plargs[i]), len(fargs[i])) + } + for j := 0; j < len(plargs[i]); j++ { + if !bytes.Equal(plargs[i][j], plargs[i][j]) { + t.Fatalf("not equal for item %v:%v: %v != %v", i, j, len(plargs[i][j]), len(fargs[i][j])) + } + } + } + } +} diff --git a/vendor/github.com/tidwall/redcon/example/clone.go b/vendor/github.com/tidwall/redcon/example/clone.go new file mode 100644 index 00000000..f32bc1c0 --- /dev/null +++ b/vendor/github.com/tidwall/redcon/example/clone.go @@ -0,0 +1,87 @@ +package main + +import ( + "log" + "strings" + "sync" + + "github.com/tidwall/redcon" +) + +var addr = ":6380" + +func main() { + var mu sync.RWMutex + var items = make(map[string][]byte) + go log.Printf("started server at %s", addr) + err := redcon.ListenAndServe(addr, + func(conn redcon.Conn, cmd redcon.Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "detach": + hconn := conn.Detach() + log.Printf("connection has been detached") + go func() { + defer hconn.Close() + hconn.WriteString("OK") + hconn.Flush() + }() + return + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + items[string(cmd.Args[1])] = cmd.Args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.RLock() + val, ok := items[string(cmd.Args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulk(val) + } + case "del": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + _, ok := items[string(cmd.Args[1])] + delete(items, string(cmd.Args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) + } + } + }, + func(conn redcon.Conn) bool { + // use this function to accept or deny the connection. + // log.Printf("accept: %s", conn.RemoteAddr()) + return true + }, + func(conn redcon.Conn, err error) { + // this is called when the connection has been closed + // log.Printf("closed: %s, err: %v", conn.RemoteAddr(), err) + }, + ) + if err != nil { + log.Fatal(err) + } +} diff --git a/vendor/github.com/tidwall/redcon/example/tls/clone.go b/vendor/github.com/tidwall/redcon/example/tls/clone.go new file mode 100644 index 00000000..8d1b67cf --- /dev/null +++ b/vendor/github.com/tidwall/redcon/example/tls/clone.go @@ -0,0 +1,118 @@ +package main + +import ( + "crypto/tls" + "log" + "strings" + "sync" + + "github.com/tidwall/redcon" +) + +const serverKey = `-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIHg+g2unjA5BkDtXSN9ShN7kbPlbCcqcYdDu+QeV8XWuoAoGCCqGSM49 +AwEHoUQDQgAEcZpodWh3SEs5Hh3rrEiu1LZOYSaNIWO34MgRxvqwz1FMpLxNlx0G +cSqrxhPubawptX5MSr02ft32kfOlYbaF5Q== +-----END EC PRIVATE KEY----- +` + +const serverCert = `-----BEGIN CERTIFICATE----- +MIIB+TCCAZ+gAwIBAgIJAL05LKXo6PrrMAoGCCqGSM49BAMCMFkxCzAJBgNVBAYT +AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn +aXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNTEyMDgxNDAxMTNa +Fw0yNTEyMDUxNDAxMTNaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 +YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMM +CWxvY2FsaG9zdDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHGaaHVod0hLOR4d +66xIrtS2TmEmjSFjt+DIEcb6sM9RTKS8TZcdBnEqq8YT7m2sKbV+TEq9Nn7d9pHz +pWG2heWjUDBOMB0GA1UdDgQWBBR0fqrecDJ44D/fiYJiOeBzfoqEijAfBgNVHSME +GDAWgBR0fqrecDJ44D/fiYJiOeBzfoqEijAMBgNVHRMEBTADAQH/MAoGCCqGSM49 +BAMCA0gAMEUCIEKzVMF3JqjQjuM2rX7Rx8hancI5KJhwfeKu1xbyR7XaAiEA2UT7 +1xOP035EcraRmWPe7tO0LpXgMxlh2VItpc2uc2w= +-----END CERTIFICATE----- +` + +var addr = ":6380" + +func main() { + cer, err := tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) + if err != nil { + log.Fatal(err) + } + config := &tls.Config{Certificates: []tls.Certificate{cer}} + + var mu sync.RWMutex + var items = make(map[string][]byte) + + go log.Printf("started server at %s", addr) + err = redcon.ListenAndServeTLS(addr, + func(conn redcon.Conn, cmd redcon.Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "detach": + hconn := conn.Detach() + log.Printf("connection has been detached") + go func() { + defer hconn.Close() + hconn.WriteString("OK") + hconn.Flush() + }() + return + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "set": + if len(cmd.Args) != 3 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + items[string(cmd.Args[1])] = cmd.Args[2] + mu.Unlock() + conn.WriteString("OK") + case "get": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.RLock() + val, ok := items[string(cmd.Args[1])] + mu.RUnlock() + if !ok { + conn.WriteNull() + } else { + conn.WriteBulk(val) + } + case "del": + if len(cmd.Args) != 2 { + conn.WriteError("ERR wrong number of arguments for '" + string(cmd.Args[0]) + "' command") + return + } + mu.Lock() + _, ok := items[string(cmd.Args[1])] + delete(items, string(cmd.Args[1])) + mu.Unlock() + if !ok { + conn.WriteInt(0) + } else { + conn.WriteInt(1) + } + } + }, + func(conn redcon.Conn) bool { + return true + }, + func(conn redcon.Conn, err error) { + }, + config, + ) + + if err != nil { + log.Fatal(err) + } +} diff --git a/vendor/github.com/tidwall/redcon/logo.png b/vendor/github.com/tidwall/redcon/logo.png new file mode 100644 index 00000000..ee336156 Binary files /dev/null and b/vendor/github.com/tidwall/redcon/logo.png differ diff --git a/vendor/github.com/tidwall/redcon/redcon.go b/vendor/github.com/tidwall/redcon/redcon.go new file mode 100644 index 00000000..0a9b56da --- /dev/null +++ b/vendor/github.com/tidwall/redcon/redcon.go @@ -0,0 +1,861 @@ +// Package redcon implements a Redis compatible server framework +package redcon + +import ( + "bufio" + "crypto/tls" + "errors" + "io" + "net" + "sync" +) + +var ( + errUnbalancedQuotes = &errProtocol{"unbalanced quotes in request"} + errInvalidBulkLength = &errProtocol{"invalid bulk length"} + errInvalidMultiBulkLength = &errProtocol{"invalid multibulk length"} + errDetached = errors.New("detached") + errIncompleteCommand = errors.New("incomplete command") + errTooMuchData = errors.New("too much data") +) + +type errProtocol struct { + msg string +} + +func (err *errProtocol) Error() string { + return "Protocol error: " + err.msg +} + +// Conn represents a client connection +type Conn interface { + // RemoteAddr returns the remote address of the client connection. + RemoteAddr() string + // Close closes the connection. + Close() error + // WriteError writes an error to the client. + WriteError(msg string) + // WriteString writes a string to the client. + WriteString(str string) + // WriteBulk writes bulk bytes to the client. + WriteBulk(bulk []byte) + // WriteBulkString writes a bulk string to the client. + WriteBulkString(bulk string) + // WriteInt writes an integer to the client. + WriteInt(num int) + // WriteInt64 writes a 64-but signed integer to the client. + WriteInt64(num int64) + // WriteArray writes an array header. You must then write additional + // sub-responses to the client to complete the response. + // For example to write two strings: + // + // c.WriteArray(2) + // c.WriteBulk("item 1") + // c.WriteBulk("item 2") + WriteArray(count int) + // WriteNull writes a null to the client + WriteNull() + // WriteRaw writes raw data to the client. + WriteRaw(data []byte) + // Context returns a user-defined context + Context() interface{} + // SetContext sets a user-defined context + SetContext(v interface{}) + // SetReadBuffer updates the buffer read size for the connection + SetReadBuffer(bytes int) + // Detach return a connection that is detached from the server. + // Useful for operations like PubSub. + // + // dconn := conn.Detach() + // go func(){ + // defer dconn.Close() + // cmd, err := dconn.ReadCommand() + // if err != nil{ + // fmt.Printf("read failed: %v\n", err) + // return + // } + // fmt.Printf("received command: %v", cmd) + // hconn.WriteString("OK") + // if err := dconn.Flush(); err != nil{ + // fmt.Printf("write failed: %v\n", err) + // return + // } + // }() + Detach() DetachedConn + // ReadPipeline returns all commands in current pipeline, if any + // The commands are removed from the pipeline. + ReadPipeline() []Command + // PeekPipeline returns all commands in current pipeline, if any. + // The commands remain in the pipeline. + PeekPipeline() []Command + // NetConn returns the base net.Conn connection + NetConn() net.Conn +} + +// NewServer returns a new Redcon server configured on "tcp" network net. +func NewServer(addr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) *Server { + return NewServerNetwork("tcp", addr, handler, accept, closed) +} + +// NewServerNetwork returns a new Redcon server. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func NewServerNetwork( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) *Server { + if handler == nil { + panic("handler is nil") + } + s := &Server{ + net: net, + laddr: laddr, + handler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), + } + return s +} + +// NewServerNetworkTLS returns a new TLS Redcon server. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func NewServerNetworkTLS( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), + config *tls.Config, +) *TLSServer { + if handler == nil { + panic("handler is nil") + } + s := Server{ + net: net, + laddr: laddr, + handler: handler, + accept: accept, + closed: closed, + conns: make(map[*conn]bool), + } + + tls := &TLSServer{ + config: config, + Server: &s, + } + return tls +} + +// Close stops listening on the TCP address. +// Already Accepted connections will be closed. +func (s *Server) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.ln == nil { + return errors.New("not serving") + } + s.done = true + return s.ln.Close() +} + +// ListenAndServe serves incoming connections. +func (s *Server) ListenAndServe() error { + return s.ListenServeAndSignal(nil) +} + +// Close stops listening on the TCP address. +// Already Accepted connections will be closed. +func (s *TLSServer) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.ln == nil { + return errors.New("not serving") + } + s.done = true + return s.ln.Close() +} + +// ListenAndServe serves incoming connections. +func (s *TLSServer) ListenAndServe() error { + return s.ListenServeAndSignal(nil) +} + +// ListenAndServe creates a new server and binds to addr configured on "tcp" network net. +func ListenAndServe(addr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) error { + return ListenAndServeNetwork("tcp", addr, handler, accept, closed) +} + +// ListenAndServeTLS creates a new TLS server and binds to addr configured on "tcp" network net. +func ListenAndServeTLS(addr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), + config *tls.Config, +) error { + return ListenAndServeNetworkTLS("tcp", addr, handler, accept, closed, config) +} + +// ListenAndServeNetwork creates a new server and binds to addr. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func ListenAndServeNetwork( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), +) error { + return NewServerNetwork(net, laddr, handler, accept, closed).ListenAndServe() +} + +// ListenAndServeNetworkTLS creates a new TLS server and binds to addr. The network net must be +// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket" +func ListenAndServeNetworkTLS( + net, laddr string, + handler func(conn Conn, cmd Command), + accept func(conn Conn) bool, + closed func(conn Conn, err error), + config *tls.Config, +) error { + return NewServerNetworkTLS(net, laddr, handler, accept, closed, config).ListenAndServe() +} + +// ListenServeAndSignal serves incoming connections and passes nil or error +// when listening. signal can be nil. +func (s *Server) ListenServeAndSignal(signal chan error) error { + ln, err := net.Listen(s.net, s.laddr) + if err != nil { + if signal != nil { + signal <- err + } + return err + } + if signal != nil { + signal <- nil + } + return serve(s, ln) +} + +// ListenServeAndSignal serves incoming connections and passes nil or error +// when listening. signal can be nil. +func (s *TLSServer) ListenServeAndSignal(signal chan error) error { + ln, err := tls.Listen(s.net, s.laddr, s.config) + if err != nil { + if signal != nil { + signal <- err + } + return err + } + if signal != nil { + signal <- nil + } + return serve(s.Server, ln) +} + +func serve(s *Server, ln net.Listener) error { + s.mu.Lock() + s.ln = ln + s.mu.Unlock() + defer func() { + ln.Close() + func() { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.conns { + c.Close() + } + s.conns = nil + }() + }() + for { + lnconn, err := ln.Accept() + if err != nil { + s.mu.Lock() + done := s.done + s.mu.Unlock() + if done { + return nil + } + return err + } + c := &conn{ + conn: lnconn, + addr: lnconn.RemoteAddr().String(), + wr: NewWriter(lnconn), + rd: NewReader(lnconn), + } + s.mu.Lock() + s.conns[c] = true + s.mu.Unlock() + if s.accept != nil && !s.accept(c) { + s.mu.Lock() + delete(s.conns, c) + s.mu.Unlock() + c.Close() + continue + } + go handle(s, c) + } +} + +// handle manages the server connection. +func handle(s *Server, c *conn) { + var err error + defer func() { + if err != errDetached { + // do not close the connection when a detach is detected. + c.conn.Close() + } + func() { + // remove the conn from the server + s.mu.Lock() + defer s.mu.Unlock() + delete(s.conns, c) + if s.closed != nil { + if err == io.EOF { + err = nil + } + s.closed(c, err) + } + }() + }() + + err = func() error { + // read commands and feed back to the client + for { + // read pipeline commands + cmds, err := c.rd.readCommands(nil) + if err != nil { + if err, ok := err.(*errProtocol); ok { + // All protocol errors should attempt a response to + // the client. Ignore write errors. + c.wr.WriteError("ERR " + err.Error()) + c.wr.Flush() + } + return err + } + c.cmds = cmds + for len(c.cmds) > 0 { + cmd := c.cmds[0] + if len(c.cmds) == 1 { + c.cmds = nil + } else { + c.cmds = c.cmds[1:] + } + s.handler(c, cmd) + } + if c.detached { + // client has been detached + return errDetached + } + if c.closed { + return nil + } + if err := c.wr.Flush(); err != nil { + return err + } + } + }() +} + +// conn represents a client connection +type conn struct { + conn net.Conn + wr *Writer + rd *Reader + addr string + ctx interface{} + detached bool + closed bool + cmds []Command +} + +func (c *conn) Close() error { + c.wr.Flush() + c.closed = true + return c.conn.Close() +} +func (c *conn) Context() interface{} { return c.ctx } +func (c *conn) SetContext(v interface{}) { c.ctx = v } +func (c *conn) SetReadBuffer(n int) {} +func (c *conn) WriteString(str string) { c.wr.WriteString(str) } +func (c *conn) WriteBulk(bulk []byte) { c.wr.WriteBulk(bulk) } +func (c *conn) WriteBulkString(bulk string) { c.wr.WriteBulkString(bulk) } +func (c *conn) WriteInt(num int) { c.wr.WriteInt(num) } +func (c *conn) WriteInt64(num int64) { c.wr.WriteInt64(num) } +func (c *conn) WriteError(msg string) { c.wr.WriteError(msg) } +func (c *conn) WriteArray(count int) { c.wr.WriteArray(count) } +func (c *conn) WriteNull() { c.wr.WriteNull() } +func (c *conn) WriteRaw(data []byte) { c.wr.WriteRaw(data) } +func (c *conn) RemoteAddr() string { return c.addr } +func (c *conn) ReadPipeline() []Command { + cmds := c.cmds + c.cmds = nil + return cmds +} +func (c *conn) PeekPipeline() []Command { + return c.cmds +} +func (c *conn) NetConn() net.Conn { + return c.conn +} + +// BaseWriter returns the underlying connection writer, if any +func BaseWriter(c Conn) *Writer { + if c, ok := c.(*conn); ok { + return c.wr + } + return nil +} + +// DetachedConn represents a connection that is detached from the server +type DetachedConn interface { + // Conn is the original connection + Conn + // ReadCommand reads the next client command. + ReadCommand() (Command, error) + // Flush flushes any writes to the network. + Flush() error +} + +// Detach removes the current connection from the server loop and returns +// a detached connection. This is useful for operations such as PubSub. +// The detached connection must be closed by calling Close() when done. +// All writes such as WriteString() will not be written to the client +// until Flush() is called. +func (c *conn) Detach() DetachedConn { + c.detached = true + cmds := c.cmds + c.cmds = nil + return &detachedConn{conn: c, cmds: cmds} +} + +type detachedConn struct { + *conn + cmds []Command +} + +// Flush writes and Write* calls to the client. +func (dc *detachedConn) Flush() error { + return dc.conn.wr.Flush() +} + +// ReadCommand read the next command from the client. +func (dc *detachedConn) ReadCommand() (Command, error) { + if dc.closed { + return Command{}, errors.New("closed") + } + if len(dc.cmds) > 0 { + cmd := dc.cmds[0] + if len(dc.cmds) == 1 { + dc.cmds = nil + } else { + dc.cmds = dc.cmds[1:] + } + return cmd, nil + } + cmd, err := dc.rd.ReadCommand() + if err != nil { + return Command{}, err + } + return cmd, nil +} + +// Command represent a command +type Command struct { + // Raw is a encoded RESP message. + Raw []byte + // Args is a series of arguments that make up the command. + Args [][]byte +} + +// Server defines a server for clients for managing client connections. +type Server struct { + mu sync.Mutex + net string + laddr string + handler func(conn Conn, cmd Command) + accept func(conn Conn) bool + closed func(conn Conn, err error) + conns map[*conn]bool + ln net.Listener + done bool +} + +// TLSServer defines a server for clients for managing client connections. +type TLSServer struct { + *Server + config *tls.Config +} + +// Writer allows for writing RESP messages. +type Writer struct { + w io.Writer + b []byte +} + +// NewWriter creates a new RESP writer. +func NewWriter(wr io.Writer) *Writer { + return &Writer{ + w: wr, + } +} + +// WriteNull writes a null to the client +func (w *Writer) WriteNull() { + w.b = AppendNull(w.b) +} + +// WriteArray writes an array header. You must then write additional +// sub-responses to the client to complete the response. +// For example to write two strings: +// +// c.WriteArray(2) +// c.WriteBulk("item 1") +// c.WriteBulk("item 2") +func (w *Writer) WriteArray(count int) { + w.b = AppendArray(w.b, count) +} + +// WriteBulk writes bulk bytes to the client. +func (w *Writer) WriteBulk(bulk []byte) { + w.b = AppendBulk(w.b, bulk) +} + +// WriteBulkString writes a bulk string to the client. +func (w *Writer) WriteBulkString(bulk string) { + w.b = AppendBulkString(w.b, bulk) +} + +// Buffer returns the unflushed buffer. This is a copy so changes +// to the resulting []byte will not affect the writer. +func (w *Writer) Buffer() []byte { + return append([]byte(nil), w.b...) +} + +// SetBuffer replaces the unflushed buffer with new bytes. +func (w *Writer) SetBuffer(raw []byte) { + w.b = w.b[:0] + w.b = append(w.b, raw...) +} + +// Flush writes all unflushed Write* calls to the underlying writer. +func (w *Writer) Flush() error { + if _, err := w.w.Write(w.b); err != nil { + return err + } + w.b = w.b[:0] + return nil +} + +// WriteError writes an error to the client. +func (w *Writer) WriteError(msg string) { + w.b = AppendError(w.b, msg) +} + +// WriteString writes a string to the client. +func (w *Writer) WriteString(msg string) { + w.b = AppendString(w.b, msg) +} + +// WriteInt writes an integer to the client. +func (w *Writer) WriteInt(num int) { + w.WriteInt64(int64(num)) +} + +// WriteInt64 writes a 64-bit signed integer to the client. +func (w *Writer) WriteInt64(num int64) { + w.b = AppendInt(w.b, num) +} + +// WriteRaw writes raw data to the client. +func (w *Writer) WriteRaw(data []byte) { + w.b = append(w.b, data...) +} + +// Reader represent a reader for RESP or telnet commands. +type Reader struct { + rd *bufio.Reader + buf []byte + start int + end int + cmds []Command +} + +// NewReader returns a command reader which will read RESP or telnet commands. +func NewReader(rd io.Reader) *Reader { + return &Reader{ + rd: bufio.NewReader(rd), + buf: make([]byte, 4096), + } +} + +func parseInt(b []byte) (int, bool) { + if len(b) == 1 && b[0] >= '0' && b[0] <= '9' { + return int(b[0] - '0'), true + } + var n int + var sign bool + var i int + if len(b) > 0 && b[0] == '-' { + sign = true + i++ + } + for ; i < len(b); i++ { + if b[i] < '0' || b[i] > '9' { + return 0, false + } + n = n*10 + int(b[i]-'0') + } + if sign { + n *= -1 + } + return n, true +} + +func (rd *Reader) readCommands(leftover *int) ([]Command, error) { + var cmds []Command + b := rd.buf[rd.start:rd.end] + if rd.end-rd.start == 0 && len(rd.buf) > 4096 { + rd.buf = rd.buf[:4096] + rd.start = 0 + rd.end = 0 + } + if len(b) > 0 { + // we have data, yay! + // but is this enough data for a complete command? or multiple? + next: + switch b[0] { + default: + // just a plain text command + for i := 0; i < len(b); i++ { + if b[i] == '\n' { + var line []byte + if i > 0 && b[i-1] == '\r' { + line = b[:i-1] + } else { + line = b[:i] + } + var cmd Command + var quote bool + var quotech byte + var escape bool + outer: + for { + nline := make([]byte, 0, len(line)) + for i := 0; i < len(line); i++ { + c := line[i] + if !quote { + if c == ' ' { + if len(nline) > 0 { + cmd.Args = append(cmd.Args, nline) + } + line = line[i+1:] + continue outer + } + if c == '"' || c == '\'' { + if i != 0 { + return nil, errUnbalancedQuotes + } + quotech = c + quote = true + line = line[i+1:] + continue outer + } + } else { + if escape { + escape = false + switch c { + case 'n': + c = '\n' + case 'r': + c = '\r' + case 't': + c = '\t' + } + } else if c == quotech { + quote = false + quotech = 0 + cmd.Args = append(cmd.Args, nline) + line = line[i+1:] + if len(line) > 0 && line[0] != ' ' { + return nil, errUnbalancedQuotes + } + continue outer + } else if c == '\\' { + escape = true + continue + } + } + nline = append(nline, c) + } + if quote { + return nil, errUnbalancedQuotes + } + if len(line) > 0 { + cmd.Args = append(cmd.Args, line) + } + break + } + if len(cmd.Args) > 0 { + // convert this to resp command syntax + var wr Writer + wr.WriteArray(len(cmd.Args)) + for i := range cmd.Args { + wr.WriteBulk(cmd.Args[i]) + cmd.Args[i] = append([]byte(nil), cmd.Args[i]...) + } + cmd.Raw = wr.b + cmds = append(cmds, cmd) + } + b = b[i+1:] + if len(b) > 0 { + goto next + } else { + goto done + } + } + } + case '*': + // resp formatted command + marks := make([]int, 0, 16) + outer2: + for i := 1; i < len(b); i++ { + if b[i] == '\n' { + if b[i-1] != '\r' { + return nil, errInvalidMultiBulkLength + } + count, ok := parseInt(b[1 : i-1]) + if !ok || count <= 0 { + return nil, errInvalidMultiBulkLength + } + marks = marks[:0] + for j := 0; j < count; j++ { + // read bulk length + i++ + if i < len(b) { + if b[i] != '$' { + return nil, &errProtocol{"expected '$', got '" + + string(b[i]) + "'"} + } + si := i + for ; i < len(b); i++ { + if b[i] == '\n' { + if b[i-1] != '\r' { + return nil, errInvalidBulkLength + } + size, ok := parseInt(b[si+1 : i-1]) + if !ok || size < 0 { + return nil, errInvalidBulkLength + } + if i+size+2 >= len(b) { + // not ready + break outer2 + } + if b[i+size+2] != '\n' || + b[i+size+1] != '\r' { + return nil, errInvalidBulkLength + } + i++ + marks = append(marks, i, i+size) + i += size + 1 + break + } + } + } + } + if len(marks) == count*2 { + var cmd Command + if rd.rd != nil { + // make a raw copy of the entire command when + // there's a underlying reader. + cmd.Raw = append([]byte(nil), b[:i+1]...) + } else { + // just assign the slice + cmd.Raw = b[:i+1] + } + cmd.Args = make([][]byte, len(marks)/2) + // slice up the raw command into the args based on + // the recorded marks. + for h := 0; h < len(marks); h += 2 { + cmd.Args[h/2] = cmd.Raw[marks[h]:marks[h+1]] + } + cmds = append(cmds, cmd) + b = b[i+1:] + if len(b) > 0 { + goto next + } else { + goto done + } + } + } + } + } + done: + rd.start = rd.end - len(b) + } + if leftover != nil { + *leftover = rd.end - rd.start + } + if len(cmds) > 0 { + return cmds, nil + } + if rd.rd == nil { + return nil, errIncompleteCommand + } + if rd.end == len(rd.buf) { + // at the end of the buffer. + if rd.start == rd.end { + // rewind the to the beginning + rd.start, rd.end = 0, 0 + } else { + // must grow the buffer + newbuf := make([]byte, len(rd.buf)*2) + copy(newbuf, rd.buf) + rd.buf = newbuf + } + } + n, err := rd.rd.Read(rd.buf[rd.end:]) + if err != nil { + return nil, err + } + rd.end += n + return rd.readCommands(leftover) +} + +// ReadCommand reads the next command. +func (rd *Reader) ReadCommand() (Command, error) { + if len(rd.cmds) > 0 { + cmd := rd.cmds[0] + rd.cmds = rd.cmds[1:] + return cmd, nil + } + cmds, err := rd.readCommands(nil) + if err != nil { + return Command{}, err + } + rd.cmds = cmds + return rd.ReadCommand() +} + +// Parse parses a raw RESP message and returns a command. +func Parse(raw []byte) (Command, error) { + rd := Reader{buf: raw, end: len(raw)} + var leftover int + cmds, err := rd.readCommands(&leftover) + if err != nil { + return Command{}, err + } + if leftover > 0 { + return Command{}, errTooMuchData + } + return cmds[0], nil + +} diff --git a/vendor/github.com/tidwall/redcon/redcon_test.go b/vendor/github.com/tidwall/redcon/redcon_test.go new file mode 100644 index 00000000..c93d5117 --- /dev/null +++ b/vendor/github.com/tidwall/redcon/redcon_test.go @@ -0,0 +1,556 @@ +package redcon + +import ( + "bytes" + "fmt" + "io" + "log" + "math/rand" + "net" + "os" + "strconv" + "strings" + "testing" + "time" +) + +// TestRandomCommands fills a bunch of random commands and test various +// ways that the reader may receive data. +func TestRandomCommands(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + + // build random commands. + gcmds := make([][]string, 10000) + for i := 0; i < len(gcmds); i++ { + args := make([]string, (rand.Int()%50)+1) // 1-50 args + for j := 0; j < len(args); j++ { + n := rand.Int() % 10 + if j == 0 { + n++ + } + arg := make([]byte, n) + for k := 0; k < len(arg); k++ { + arg[k] = byte(rand.Int() % 0xFF) + } + args[j] = string(arg) + } + gcmds[i] = args + } + // create a list of a buffers + var bufs []string + + // pipe valid RESP commands + for i := 0; i < len(gcmds); i++ { + args := gcmds[i] + msg := fmt.Sprintf("*%d\r\n", len(args)) + for j := 0; j < len(args); j++ { + msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) + } + bufs = append(bufs, msg) + } + bufs = append(bufs, "RESET THE INDEX\r\n") + + // pipe valid plain commands + for i := 0; i < len(gcmds); i++ { + args := gcmds[i] + var msg string + for j := 0; j < len(args); j++ { + quotes := false + var narg []byte + arg := args[j] + if len(arg) == 0 { + quotes = true + } + for k := 0; k < len(arg); k++ { + switch arg[k] { + default: + narg = append(narg, arg[k]) + case ' ': + quotes = true + narg = append(narg, arg[k]) + case '\\', '"', '*': + quotes = true + narg = append(narg, '\\', arg[k]) + case '\r': + quotes = true + narg = append(narg, '\\', 'r') + case '\n': + quotes = true + narg = append(narg, '\\', 'n') + } + } + msg += " " + if quotes { + msg += "\"" + } + msg += string(narg) + if quotes { + msg += "\"" + } + } + if msg != "" { + msg = msg[1:] + } + msg += "\r\n" + bufs = append(bufs, msg) + } + bufs = append(bufs, "RESET THE INDEX\r\n") + + // pipe valid RESP commands in broken chunks + lmsg := "" + for i := 0; i < len(gcmds); i++ { + args := gcmds[i] + msg := fmt.Sprintf("*%d\r\n", len(args)) + for j := 0; j < len(args); j++ { + msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) + } + msg = lmsg + msg + if len(msg) > 0 { + lmsg = msg[len(msg)/2:] + msg = msg[:len(msg)/2] + } + bufs = append(bufs, msg) + } + bufs = append(bufs, lmsg) + bufs = append(bufs, "RESET THE INDEX\r\n") + + // pipe valid RESP commands in large broken chunks + lmsg = "" + for i := 0; i < len(gcmds); i++ { + args := gcmds[i] + msg := fmt.Sprintf("*%d\r\n", len(args)) + for j := 0; j < len(args); j++ { + msg += fmt.Sprintf("$%d\r\n%s\r\n", len(args[j]), args[j]) + } + if len(lmsg) < 1500 { + lmsg += msg + continue + } + msg = lmsg + msg + if len(msg) > 0 { + lmsg = msg[len(msg)/2:] + msg = msg[:len(msg)/2] + } + bufs = append(bufs, msg) + } + bufs = append(bufs, lmsg) + bufs = append(bufs, "RESET THE INDEX\r\n") + + // Pipe the buffers in a background routine + rd, wr := io.Pipe() + go func() { + defer wr.Close() + for _, msg := range bufs { + io.WriteString(wr, msg) + } + }() + defer rd.Close() + cnt := 0 + idx := 0 + start := time.Now() + r := NewReader(rd) + for { + cmd, err := r.ReadCommand() + if err != nil { + if err == io.EOF { + break + } + log.Fatal(err) + } + if len(cmd.Args) == 3 && string(cmd.Args[0]) == "RESET" && + string(cmd.Args[1]) == "THE" && string(cmd.Args[2]) == "INDEX" { + if idx != len(gcmds) { + t.Fatalf("did not process all commands") + } + idx = 0 + break + } + if len(cmd.Args) != len(gcmds[idx]) { + t.Fatalf("len not equal for index %d -- %d != %d", idx, len(cmd.Args), len(gcmds[idx])) + } + for i := 0; i < len(cmd.Args); i++ { + if i == 0 { + if len(cmd.Args[i]) == len(gcmds[idx][i]) { + ok := true + for j := 0; j < len(cmd.Args[i]); j++ { + c1, c2 := cmd.Args[i][j], gcmds[idx][i][j] + if c1 >= 'A' && c1 <= 'Z' { + c1 += 32 + } + if c2 >= 'A' && c2 <= 'Z' { + c2 += 32 + } + if c1 != c2 { + ok = false + break + } + } + if ok { + continue + } + } + } else if string(cmd.Args[i]) == string(gcmds[idx][i]) { + continue + } + t.Fatalf("not equal for index %d/%d", idx, i) + } + idx++ + cnt++ + } + if false { + dur := time.Now().Sub(start) + fmt.Printf("%d commands in %s - %.0f ops/sec\n", cnt, dur, float64(cnt)/(float64(dur)/float64(time.Second))) + } +} +func testDetached(t *testing.T, conn DetachedConn) { + conn.WriteString("DETACHED") + if err := conn.Flush(); err != nil { + t.Fatal(err) + } +} +func TestServerTCP(t *testing.T) { + testServerNetwork(t, "tcp", ":12345") +} +func TestServerUnix(t *testing.T) { + os.RemoveAll("/tmp/redcon-unix.sock") + defer os.RemoveAll("/tmp/redcon-unix.sock") + testServerNetwork(t, "unix", "/tmp/redcon-unix.sock") +} + +func testServerNetwork(t *testing.T, network, laddr string) { + s := NewServerNetwork(network, laddr, + func(conn Conn, cmd Command) { + switch strings.ToLower(string(cmd.Args[0])) { + default: + conn.WriteError("ERR unknown command '" + string(cmd.Args[0]) + "'") + case "ping": + conn.WriteString("PONG") + case "quit": + conn.WriteString("OK") + conn.Close() + case "detach": + go testDetached(t, conn.Detach()) + case "int": + conn.WriteInt(100) + case "bulk": + conn.WriteBulkString("bulk") + case "bulkbytes": + conn.WriteBulk([]byte("bulkbytes")) + case "null": + conn.WriteNull() + case "err": + conn.WriteError("ERR error") + case "array": + conn.WriteArray(2) + conn.WriteInt(99) + conn.WriteString("Hi!") + } + }, + func(conn Conn) bool { + //log.Printf("accept: %s", conn.RemoteAddr()) + return true + }, + func(conn Conn, err error) { + //log.Printf("closed: %s [%v]", conn.RemoteAddr(), err) + }, + ) + if err := s.Close(); err == nil { + t.Fatalf("expected an error, should not be able to close before serving") + } + go func() { + time.Sleep(time.Second / 4) + if err := ListenAndServeNetwork(network, laddr, func(conn Conn, cmd Command) {}, nil, nil); err == nil { + t.Fatalf("expected an error, should not be able to listen on the same port") + } + time.Sleep(time.Second / 4) + + err := s.Close() + if err != nil { + t.Fatal(err) + } + err = s.Close() + if err == nil { + t.Fatalf("expected an error") + } + }() + done := make(chan bool) + signal := make(chan error) + go func() { + defer func() { + done <- true + }() + err := <-signal + if err != nil { + t.Fatal(err) + } + c, err := net.Dial(network, laddr) + if err != nil { + t.Fatal(err) + } + defer c.Close() + do := func(cmd string) (string, error) { + io.WriteString(c, cmd) + buf := make([]byte, 1024) + n, err := c.Read(buf) + if err != nil { + return "", err + } + return string(buf[:n]), nil + } + res, err := do("PING\r\n") + if err != nil { + t.Fatal(err) + } + if res != "+PONG\r\n" { + t.Fatalf("expecting '+PONG\r\n', got '%v'", res) + } + res, err = do("BULK\r\n") + if err != nil { + t.Fatal(err) + } + if res != "$4\r\nbulk\r\n" { + t.Fatalf("expecting bulk, got '%v'", res) + } + res, err = do("BULKBYTES\r\n") + if err != nil { + t.Fatal(err) + } + if res != "$9\r\nbulkbytes\r\n" { + t.Fatalf("expecting bulkbytes, got '%v'", res) + } + res, err = do("INT\r\n") + if err != nil { + t.Fatal(err) + } + if res != ":100\r\n" { + t.Fatalf("expecting int, got '%v'", res) + } + res, err = do("NULL\r\n") + if err != nil { + t.Fatal(err) + } + if res != "$-1\r\n" { + t.Fatalf("expecting nul, got '%v'", res) + } + res, err = do("ARRAY\r\n") + if err != nil { + t.Fatal(err) + } + if res != "*2\r\n:99\r\n+Hi!\r\n" { + t.Fatalf("expecting array, got '%v'", res) + } + res, err = do("ERR\r\n") + if err != nil { + t.Fatal(err) + } + if res != "-ERR error\r\n" { + t.Fatalf("expecting array, got '%v'", res) + } + res, err = do("DETACH\r\n") + if err != nil { + t.Fatal(err) + } + if res != "+DETACHED\r\n" { + t.Fatalf("expecting string, got '%v'", res) + } + }() + go func() { + err := s.ListenServeAndSignal(signal) + if err != nil { + t.Fatal(err) + } + }() + <-done +} + +func TestWriter(t *testing.T) { + buf := &bytes.Buffer{} + wr := NewWriter(buf) + wr.WriteError("ERR bad stuff") + wr.Flush() + if buf.String() != "-ERR bad stuff\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteString("HELLO") + wr.Flush() + if buf.String() != "+HELLO\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteInt(-1234) + wr.Flush() + if buf.String() != ":-1234\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteNull() + wr.Flush() + if buf.String() != "$-1\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteBulk([]byte("HELLO\r\nPLANET")) + wr.Flush() + if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteBulkString("HELLO\r\nPLANET") + wr.Flush() + if buf.String() != "$13\r\nHELLO\r\nPLANET\r\n" { + t.Fatal("failed") + } + buf.Reset() + wr.WriteArray(3) + wr.WriteBulkString("THIS") + wr.WriteBulkString("THAT") + wr.WriteString("THE OTHER THING") + wr.Flush() + if buf.String() != "*3\r\n$4\r\nTHIS\r\n$4\r\nTHAT\r\n+THE OTHER THING\r\n" { + t.Fatal("failed") + } + buf.Reset() +} +func testMakeRawCommands(rawargs [][]string) []string { + var rawcmds []string + for i := 0; i < len(rawargs); i++ { + rawcmd := "*" + strconv.FormatUint(uint64(len(rawargs[i])), 10) + "\r\n" + for j := 0; j < len(rawargs[i]); j++ { + rawcmd += "$" + strconv.FormatUint(uint64(len(rawargs[i][j])), 10) + "\r\n" + rawcmd += rawargs[i][j] + "\r\n" + } + rawcmds = append(rawcmds, rawcmd) + } + return rawcmds +} + +func TestReaderRespRandom(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + for h := 0; h < 10000; h++ { + var rawargs [][]string + for i := 0; i < 100; i++ { + var args []string + n := int(rand.Int() % 16) + for j := 0; j < n; j++ { + arg := make([]byte, rand.Int()%512) + rand.Read(arg) + args = append(args, string(arg)) + } + } + rawcmds := testMakeRawCommands(rawargs) + data := strings.Join(rawcmds, "") + rd := NewReader(bytes.NewBufferString(data)) + for i := 0; i < len(rawcmds); i++ { + if len(rawargs[i]) == 0 { + continue + } + cmd, err := rd.ReadCommand() + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != rawcmds[i] { + t.Fatalf("expected '%v', got '%v'", rawcmds[i], string(cmd.Raw)) + } + if len(cmd.Args) != len(rawargs[i]) { + t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) + } + for j := 0; j < len(rawargs[i]); j++ { + if string(cmd.Args[j]) != rawargs[i][j] { + t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) + } + } + } + } +} + +func TestPlainReader(t *testing.T) { + rawargs := [][]string{ + {"HELLO", "WORLD"}, + {"HELLO", "WORLD"}, + {"HELLO", "PLANET"}, + {"HELLO", "JELLO"}, + {"HELLO ", "JELLO"}, + } + rawcmds := []string{ + "HELLO WORLD\n", + "HELLO WORLD\r\n", + " HELLO PLANET \r\n", + " \"HELLO\" \"JELLO\" \r\n", + " \"HELLO \" JELLO \n", + } + rawres := []string{ + "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", + "*2\r\n$5\r\nHELLO\r\n$5\r\nWORLD\r\n", + "*2\r\n$5\r\nHELLO\r\n$6\r\nPLANET\r\n", + "*2\r\n$5\r\nHELLO\r\n$5\r\nJELLO\r\n", + "*2\r\n$6\r\nHELLO \r\n$5\r\nJELLO\r\n", + } + data := strings.Join(rawcmds, "") + rd := NewReader(bytes.NewBufferString(data)) + for i := 0; i < len(rawcmds); i++ { + if len(rawargs[i]) == 0 { + continue + } + cmd, err := rd.ReadCommand() + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != rawres[i] { + t.Fatalf("expected '%v', got '%v'", rawres[i], string(cmd.Raw)) + } + if len(cmd.Args) != len(rawargs[i]) { + t.Fatalf("expected '%v', got '%v'", len(rawargs[i]), len(cmd.Args)) + } + for j := 0; j < len(rawargs[i]); j++ { + if string(cmd.Args[j]) != rawargs[i][j] { + t.Fatalf("expected '%v', got '%v'", rawargs[i][j], string(cmd.Args[j])) + } + } + } +} + +func TestParse(t *testing.T) { + _, err := Parse(nil) + if err != errIncompleteCommand { + t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) + } + _, err = Parse([]byte("*1\r\n")) + if err != errIncompleteCommand { + t.Fatalf("expected '%v', got '%v'", errIncompleteCommand, err) + } + _, err = Parse([]byte("*-1\r\n")) + if err != errInvalidMultiBulkLength { + t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) + } + _, err = Parse([]byte("*0\r\n")) + if err != errInvalidMultiBulkLength { + t.Fatalf("expected '%v', got '%v'", errInvalidMultiBulkLength, err) + } + cmd, err := Parse([]byte("*1\r\n$1\r\nA\r\n")) + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { + t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) + } + if len(cmd.Args) != 1 { + t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) + } + if string(cmd.Args[0]) != "A" { + t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) + } + cmd, err = Parse([]byte("A\r\n")) + if err != nil { + t.Fatal(err) + } + if string(cmd.Raw) != "*1\r\n$1\r\nA\r\n" { + t.Fatalf("expected '%v', got '%v'", "*1\r\n$1\r\nA\r\n", string(cmd.Raw)) + } + if len(cmd.Args) != 1 { + t.Fatalf("expected '%v', got '%v'", 1, len(cmd.Args)) + } + if string(cmd.Args[0]) != "A" { + t.Fatalf("expected '%v', got '%v'", "A", string(cmd.Args[0])) + } +}