From f9ea3f1e13d95f2125549e97997c5cb4946a04be Mon Sep 17 00:00:00 2001 From: tidwall Date: Mon, 28 Oct 2019 13:22:52 -0700 Subject: [PATCH] Fixed tile38-cli not propertly handling quotes closes #500 --- cmd/tile38-cli/internal/client/README.md | 16 - cmd/tile38-cli/internal/client/conn.go | 322 -------------------- cmd/tile38-cli/internal/client/conn_test.go | 57 ---- cmd/tile38-cli/internal/client/helper.go | 49 --- cmd/tile38-cli/internal/client/pool.go | 94 ------ cmd/tile38-cli/internal/client/pool_test.go | 67 ---- cmd/tile38-cli/main.go | 185 +++++++++-- 7 files changed, 153 insertions(+), 637 deletions(-) delete mode 100644 cmd/tile38-cli/internal/client/README.md delete mode 100644 cmd/tile38-cli/internal/client/conn.go delete mode 100644 cmd/tile38-cli/internal/client/conn_test.go delete mode 100644 cmd/tile38-cli/internal/client/helper.go delete mode 100644 cmd/tile38-cli/internal/client/pool.go delete mode 100644 cmd/tile38-cli/internal/client/pool_test.go diff --git a/cmd/tile38-cli/internal/client/README.md b/cmd/tile38-cli/internal/client/README.md deleted file mode 100644 index fdf55772..00000000 --- a/cmd/tile38-cli/internal/client/README.md +++ /dev/null @@ -1,16 +0,0 @@ -Tile38 Client -============= - -[![Build Status](https://travis-ci.org/tidwall/tile38.svg?branch=master)](https://travis-ci.org/tidwall/tile38) -[![GoDoc](https://godoc.org/github.com/tidwall/tile38/client?status.svg)](https://godoc.org/github.com/tidwall/tile38/client) - -Tile38 Client is a [Go](http://golang.org/) client for [Tile38](http://tile38.com/). - -THIS LIBRARY IS DEPRECATED -========================== - -Please use the [redigo](https://github.com/garyburd/redigo) client library instead. -If you need JSON output with Redigo then call: -``` -conn.Do("OUTPUT", "JSON") -``` diff --git a/cmd/tile38-cli/internal/client/conn.go b/cmd/tile38-cli/internal/client/conn.go deleted file mode 100644 index bf2a447e..00000000 --- a/cmd/tile38-cli/internal/client/conn.go +++ /dev/null @@ -1,322 +0,0 @@ -package client - -import ( - "bufio" - "bytes" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "errors" - "io" - "net" - "net/url" - "strconv" - "strings" - "time" -) - -// LiveJSON is the value returned when a connection goes "live". -const LiveJSON = `{"ok":true,"live":true}` - -// MaxMessageSize is maximum accepted message size -const MaxMessageSize = 0x1FFFFFFF // 536,870,911 bytes - -// Proto is the protocol value. -type Proto int - -const ( - Native Proto = 0 // native protocol - Telnet Proto = 1 // telnet protocol - HTTP Proto = 2 // http protocol - WebSocket Proto = 3 // websocket protocol -) - -// Conn represents a connection to a tile38 server. -type Conn struct { - c net.Conn - rd *bufio.Reader - pool *Pool - detached bool -} - -// Dial connects to a tile38 server. -func Dial(addr string) (*Conn, error) { - c, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - return &Conn{c: c, rd: bufio.NewReader(c)}, nil -} - -// DialTimeout connects to a tile38 server with a timeout. -func DialTimeout(addr string, timeout time.Duration) (*Conn, error) { - c, err := net.DialTimeout("tcp", addr, timeout) - if err != nil { - return nil, err - } - return &Conn{c: c, rd: bufio.NewReader(c)}, nil -} - -// Close will close a connection. -func (conn *Conn) Close() error { - if conn.pool == nil { - if !conn.detached { - conn.Do("QUIT") - } - return conn.c.Close() - } - return conn.pool.put(conn) -} - -// SetDeadline sets the connection deadline for reads and writes. -func (conn *Conn) SetDeadline(t time.Time) error { - return conn.c.SetDeadline(t) -} - -// SetDeadline sets the connection deadline for reads. -func (conn *Conn) SetReadDeadline(t time.Time) error { - return conn.c.SetReadDeadline(t) -} - -// SetDeadline sets the connection deadline for writes. -func (conn *Conn) SetWriteDeadline(t time.Time) error { - return conn.c.SetWriteDeadline(t) -} - -// Do sends a command to the server and returns the received reply. -func (conn *Conn) Do(command string) ([]byte, error) { - if err := WriteMessage(conn.c, []byte(command)); err != nil { - conn.pool = nil - return nil, err - } - message, _, _, err := ReadMessage(conn.rd, nil) - if err != nil { - conn.pool = nil - return nil, err - } - if string(message) == LiveJSON { - conn.pool = nil // detach from pool - } - return message, nil -} - -// ReadMessage returns the next message. Used when reading live connections -func (conn *Conn) ReadMessage() (message []byte, err error) { - message, _, _, err = readMessage(conn.c, conn.rd) - if err != nil { - conn.pool = nil - return message, err - } - return message, nil -} - -// Reader returns the underlying reader. -func (conn *Conn) Reader() io.Reader { - conn.pool = nil // Remove from the pool because once the reader is called - conn.detached = true // we will assume that this connection is detached. - return conn.rd -} - -// WriteMessage write a message to an io.Writer -func WriteMessage(w io.Writer, message []byte) error { - h := []byte("$" + strconv.FormatUint(uint64(len(message)), 10) + " ") - b := make([]byte, len(h)+len(message)+2) - copy(b, h) - copy(b[len(h):], message) - b[len(b)-2] = '\r' - b[len(b)-1] = '\n' - _, err := w.Write(b) - return err -} - -// WriteHTTP writes an http message to the connection and closes the connection. -func WriteHTTP(conn net.Conn, data []byte) error { - var buf bytes.Buffer - buf.WriteString("HTTP/1.1 200 OK\r\n") - buf.WriteString("Content-Length: " + strconv.FormatInt(int64(len(data))+1, 10) + "\r\n") - buf.WriteString("Content-Type: application/json\r\n") - buf.WriteString("Connection: close\r\n") - buf.WriteString("\r\n") - buf.Write(data) - buf.WriteByte('\n') - _, err := conn.Write(buf.Bytes()) - return err -} - -// WriteWebSocket writes a websocket message. -func WriteWebSocket(conn net.Conn, data []byte) error { - var msg []byte - buf := make([]byte, 10+len(data)) - buf[0] = 129 // FIN + TEXT - if len(data) <= 125 { - buf[1] = byte(len(data)) - copy(buf[2:], data) - msg = buf[:2+len(data)] - } else if len(data) <= 0xFFFF { - buf[1] = 126 - binary.BigEndian.PutUint16(buf[2:], uint16(len(data))) - copy(buf[4:], data) - msg = buf[:4+len(data)] - } else { - buf[1] = 127 - binary.BigEndian.PutUint64(buf[2:], uint64(len(data))) - copy(buf[10:], data) - msg = buf[:10+len(data)] - } - _, err := conn.Write(msg) - return err -} - -// ReadMessage reads the next message from a bufio.Reader. -func readMessage(wr io.Writer, rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) { - h, err := rd.Peek(1) - if err != nil { - return nil, proto, auth, err - } - switch h[0] { - case '$': - return readProtoMessage(rd) - } - message, proto, err = readTelnetMessage(rd) - if err != nil { - return nil, proto, auth, err - } - if len(message) > 6 && string(message[len(message)-9:len(message)-2]) == " HTTP/1" { - return readHTTPMessage(string(message), wr, rd) - } - return message, proto, auth, nil - -} - -// ReadMessage read the next message from a bufio Reader. -func ReadMessage(rd *bufio.Reader, wr io.Writer) (message []byte, proto Proto, auth string, err error) { - return readMessage(wr, rd) -} - -func readProtoMessage(rd *bufio.Reader) (message []byte, proto Proto, auth string, err error) { - b, err := rd.ReadBytes(' ') - if err != nil { - return nil, Native, auth, err - } - if len(b) > 0 && b[0] != '$' { - return nil, Native, auth, errors.New("not a proto message") - } - n, err := strconv.ParseUint(string(b[1:len(b)-1]), 10, 32) - if err != nil { - return nil, Native, auth, errors.New("invalid size") - } - if n > MaxMessageSize { - return nil, Native, auth, errors.New("message too big") - } - b = make([]byte, int(n)+2) - if _, err := io.ReadFull(rd, b); err != nil { - return nil, Native, auth, err - } - if b[len(b)-2] != '\r' || b[len(b)-1] != '\n' { - return nil, Native, auth, errors.New("expecting crlf suffix") - } - return b[:len(b)-2], Native, auth, nil -} - -func readTelnetMessage(rd *bufio.Reader) (command []byte, proto Proto, err error) { - line, err := rd.ReadBytes('\n') - if err != nil { - return nil, Telnet, err - } - if len(line) > 1 && line[len(line)-2] == '\r' { - line = line[:len(line)-2] - } else { - line = line[:len(line)-1] - } - return line, Telnet, nil -} - -func readHTTPMessage(line string, wr io.Writer, rd *bufio.Reader) (command []byte, proto Proto, auth string, err error) { - proto = HTTP - parts := strings.Split(line, " ") - if len(parts) != 3 { - err = errors.New("invalid HTTP request") - return - } - method := parts[0] - path := parts[1] - if len(path) == 0 || path[0] != '/' { - err = errors.New("invalid HTTP request") - return - } - path, err = url.QueryUnescape(path[1:]) - if err != nil { - err = errors.New("invalid HTTP request") - return - } - if method != "GET" && method != "POST" { - err = errors.New("invalid HTTP method") - return - } - contentLength := 0 - websocket := false - websocketVersion := 0 - websocketKey := "" - for { - var b []byte - b, _, err = readTelnetMessage(rd) // read a header line - if err != nil { - return - } - header := string(b) - if header == "" { - break // end of headers - } - if header[0] == 'a' || header[0] == 'A' { - if strings.HasPrefix(strings.ToLower(header), "authorization:") { - auth = strings.TrimSpace(header[len("authorization:"):]) - } - } else if header[0] == 'u' || header[0] == 'U' { - if strings.HasPrefix(strings.ToLower(header), "upgrade:") && strings.ToLower(strings.TrimSpace(header[len("upgrade:"):])) == "websocket" { - websocket = true - } - } else if header[0] == 's' || header[0] == 'S' { - if strings.HasPrefix(strings.ToLower(header), "sec-websocket-version:") { - var n uint64 - n, err = strconv.ParseUint(strings.TrimSpace(header[len("sec-websocket-version:"):]), 10, 64) - if err != nil { - return - } - websocketVersion = int(n) - } else if strings.HasPrefix(strings.ToLower(header), "sec-websocket-key:") { - websocketKey = strings.TrimSpace(header[len("sec-websocket-key:"):]) - } - } else if header[0] == 'c' || header[0] == 'C' { - if strings.HasPrefix(strings.ToLower(header), "content-length:") { - var n uint64 - n, err = strconv.ParseUint(strings.TrimSpace(header[len("content-length:"):]), 10, 64) - if err != nil { - return - } - contentLength = int(n) - } - } - } - if websocket && websocketVersion >= 13 && websocketKey != "" { - proto = WebSocket - if wr == nil { - err = errors.New("connection is nil") - return - } - sum := sha1.Sum([]byte(websocketKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) - accept := base64.StdEncoding.EncodeToString(sum[:]) - wshead := "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + "\r\n\r\n" - if _, err = wr.Write([]byte(wshead)); err != nil { - return - } - } else if contentLength > 0 { - proto = HTTP - buf := make([]byte, contentLength) - if _, err = io.ReadFull(rd, buf); err != nil { - return - } - path += string(buf) - } - command = []byte(path) - return -} diff --git a/cmd/tile38-cli/internal/client/conn_test.go b/cmd/tile38-cli/internal/client/conn_test.go deleted file mode 100644 index 2e5521eb..00000000 --- a/cmd/tile38-cli/internal/client/conn_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package client - -import ( - "fmt" - "log" - "time" -) - -func ExampleDial() { - conn, err := Dial("localhost:9851") - if err != nil { - log.Fatal(err) - } - defer conn.Close() - resp, err := conn.Do("set fleet truck1 point 33.5123 -112.2693") - if err != nil { - log.Fatal(err) - } - fmt.Println(string(resp)) -} - -func ExampleDialPool() { - pool, err := DialPool("localhost:9851") - if err != nil { - log.Fatal(err) - } - defer pool.Close() - - // We'll set a point in a background routine - go func() { - conn, err := pool.Get() // get a conn from the pool - if err != nil { - log.Fatal(err) - } - defer conn.Close() // return the conn to the pool - _, err = conn.Do("set fleet truck1 point 33.5123 -112.2693") - if err != nil { - log.Fatal(err) - } - }() - time.Sleep(time.Second / 2) // wait a moment - - // Retrieve the point we just set. - go func() { - conn, err := pool.Get() // get a conn from the pool - if err != nil { - log.Fatal(err) - } - defer conn.Close() // return the conn to the pool - resp, err := conn.Do("get fleet truck1 point") - if err != nil { - log.Fatal(err) - } - fmt.Println(string(resp)) - }() - time.Sleep(time.Second / 2) // wait a moment -} diff --git a/cmd/tile38-cli/internal/client/helper.go b/cmd/tile38-cli/internal/client/helper.go deleted file mode 100644 index c60ca465..00000000 --- a/cmd/tile38-cli/internal/client/helper.go +++ /dev/null @@ -1,49 +0,0 @@ -package client - -import ( - "encoding/json" - "errors" -) - -// Standard represents a standard tile38 message. -type Standard struct { - OK bool `json:"ok"` - Err string `json:"err"` - Elapsed string `json:"elapsed"` -} - -// ServerStats represents tile38 server statistics. -type ServerStats struct { - Standard - Stats struct { - ServerID string `json:"id"` - Following string `json:"following"` - AOFSize int `json:"aof_size"` - NumCollections int `json:"num_collections"` - InMemorySize int `json:"in_memory_size"` - NumPoints int `json:"num_points"` - NumObjects int `json:"num_objects"` - HeapSize int `json:"heap_size"` - AvgItemSize int `json:"avg_item_size"` - PointerSize int `json:"pointer_size"` - } `json:"stats"` -} - -// Server returns tile38 server statistics. -func (conn *Conn) Server() (ServerStats, error) { - var stats ServerStats - msg, err := conn.Do("server") - if err != nil { - return stats, err - } - if err := json.Unmarshal(msg, &stats); err != nil { - return stats, err - } - if !stats.OK { - if stats.Err != "" { - return stats, errors.New(stats.Err) - } - return stats, errors.New("not ok") - } - return stats, nil -} diff --git a/cmd/tile38-cli/internal/client/pool.go b/cmd/tile38-cli/internal/client/pool.go deleted file mode 100644 index 70ecc152..00000000 --- a/cmd/tile38-cli/internal/client/pool.go +++ /dev/null @@ -1,94 +0,0 @@ -package client - -import ( - "errors" - "fmt" - "math/rand" - "sync" - "time" -) - -const dialTimeout = time.Second * 3 -const pingTimeout = time.Second - -// Pool represents a pool of tile38 connections. -type Pool struct { - mu sync.Mutex - conns []*Conn - addr string - closed bool -} - -// DialPool creates a new pool with 5 initial connections to the specified tile38 server. -func DialPool(addr string) (*Pool, error) { - pool := &Pool{ - addr: addr, - } - // create some connections. 5 is a good start - var tconns []*Conn - for i := 0; i < 5; i++ { - conn, err := pool.Get() - if err != nil { - pool.Close() - return nil, fmt.Errorf("unable to fill pool: %s", err) - } - tconns = append(tconns, conn) - } - pool.conns = tconns - return pool, nil -} - -// Close releases the resources used by the pool. -func (pool *Pool) Close() error { - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.closed { - return errors.New("pool closed") - } - pool.closed = true - for _, conn := range pool.conns { - conn.pool = nil - conn.Close() - } - pool.conns = nil - return nil -} - -// Get borrows a connection. When the connection closes, the application returns it to the pool. -func (pool *Pool) Get() (*Conn, error) { - pool.mu.Lock() - defer pool.mu.Unlock() - for len(pool.conns) != 0 { - i := rand.Int() % len(pool.conns) - conn := pool.conns[i] - pool.conns = append(pool.conns[:i], pool.conns[i+1:]...) - // Ping to test on borrow. - conn.SetDeadline(time.Now().Add(pingTimeout)) - if _, err := conn.Do("PING"); err != nil { - conn.pool = nil - conn.Close() - continue - } - conn.SetDeadline(time.Time{}) - return conn, nil - } - conn, err := DialTimeout(pool.addr, dialTimeout) - if err != nil { - return nil, err - } - conn.pool = pool - return conn, nil -} - -func (pool *Pool) put(conn *Conn) error { - pool.mu.Lock() - defer pool.mu.Unlock() - if pool.closed { - return errors.New("pool closed") - } - conn.SetDeadline(time.Time{}) - conn.SetReadDeadline(time.Time{}) - conn.SetWriteDeadline(time.Time{}) - pool.conns = append(pool.conns, conn) - return nil -} diff --git a/cmd/tile38-cli/internal/client/pool_test.go b/cmd/tile38-cli/internal/client/pool_test.go deleted file mode 100644 index c7bc98e6..00000000 --- a/cmd/tile38-cli/internal/client/pool_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package client - -import ( - "encoding/json" - "fmt" - "math/rand" - "strings" - "sync" - "testing" - "time" -) - -func TestPool(t *testing.T) { - rand.Seed(time.Now().UnixNano()) - pool, err := DialPool("localhost:9876") - if err != nil { - t.Fatal(err) - } - defer pool.Close() - var wg sync.WaitGroup - wg.Add(25) - for i := 0; i < 25; i++ { - go func(i int) { - defer func() { - wg.Done() - }() - conn, err := pool.Get() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - msg, err := conn.Do("PING") - if err != nil { - t.Fatal(err) - } - var m map[string]interface{} - if err := json.Unmarshal([]byte(msg), &m); err != nil { - t.Fatal(err) - } - if ok1, ok2 := m["ok"].(bool); !ok1 || !ok2 { - t.Fatal("not ok") - } - if pong, ok := m["ping"].(string); !ok || pong != "pong" { - t.Fatal("not pong") - } - defer conn.Do(fmt.Sprintf("drop test:%d", i)) - msg, err = conn.Do(fmt.Sprintf("drop test:%d", i)) - if err != nil { - t.Fatal(err) - } - if !strings.HasPrefix(string(msg), `{"ok":true`) { - t.Fatal("expecting OK:TRUE response") - } - for j := 0; j < 100; j++ { - lat, lon := rand.Float64()*180-90, rand.Float64()*360-180 - msg, err = conn.Do(fmt.Sprintf("set test:%d %d point %f %f", i, j, lat, lon)) - if err != nil { - t.Fatal(err) - } - if !strings.HasPrefix(string(msg), `{"ok":true`) { - t.Fatal("expecting OK:TRUE response") - } - } - }(i) - } - wg.Wait() -} diff --git a/cmd/tile38-cli/main.go b/cmd/tile38-cli/main.go index 1d13d141..52b6111a 100644 --- a/cmd/tile38-cli/main.go +++ b/cmd/tile38-cli/main.go @@ -1,8 +1,10 @@ package main import ( + "bufio" "bytes" "encoding/json" + "errors" "fmt" "io" "net" @@ -16,7 +18,6 @@ import ( "github.com/peterh/liner" "github.com/tidwall/gjson" "github.com/tidwall/resp" - "github.com/tidwall/tile38/cmd/tile38-cli/internal/client" "github.com/tidwall/tile38/core" ) @@ -64,6 +65,7 @@ func showHelp() bool { fmt.Fprintf(os.Stdout, " --noprompt Do not display a prompt\n") fmt.Fprintf(os.Stdout, " --tty Force TTY\n") fmt.Fprintf(os.Stdout, " --resp Use RESP output formatting (default is JSON output)\n") + fmt.Fprintf(os.Stdout, " --json Use JSON output formatting (default is JSON output)\n") fmt.Fprintf(os.Stdout, " -h Server hostname (default: %s)\n", hostname) fmt.Fprintf(os.Stdout, " -p Server port (default: %d)\n", port) fmt.Fprintf(os.Stdout, "\n") @@ -113,6 +115,8 @@ func parseArgs() bool { noprompt = true case "--resp": output = "resp" + case "--json": + output = "json" case "-h": hostname = readArg(arg) case "-p": @@ -156,10 +160,10 @@ func main() { } addr := fmt.Sprintf("%s:%d", hostname, port) - var conn *client.Conn + var conn *client connDial := func() { var err error - conn, err = client.Dial(addr) + conn, err = clientDial("tcp", addr) if err != nil { if _, ok := err.(net.Error); ok { fmt.Fprintln(os.Stderr, refusedErrorString(addr)) @@ -167,15 +171,9 @@ func main() { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) } - } - if conn != nil { - if output == "resp" { - _, err := conn.Do("output resp") - if err != nil { - fmt.Fprintln(os.Stderr, err.Error()) - os.Exit(1) - } - } + } else if _, err := conn.Do("output " + output); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) } } connDial() @@ -190,10 +188,17 @@ func main() { } else { var msg []byte for { - msg, err = conn.ReadMessage() + msg, err = conn.readLiveResp() if err != nil { break } + if !raw { + if output == "resp" { + msg = convert2termresp(msg) + } else { + msg = convert2termjson(msg) + } + } fmt.Fprintln(os.Stderr, string(msg)) } } @@ -343,36 +348,40 @@ func main() { output = "json" } } + if output == "resp" && + (strings.HasPrefix(string(msg), "*3\r\n$10\r\npsubscribe\r\n") || + strings.HasPrefix(string(msg), "*3\r\n$9\r\nsubscribe\r\n")) { + livemode = true + } + if !raw { + if output == "resp" { + msg = convert2termresp(msg) + } else { + msg = convert2termjson(msg) + } + } + + if !livemode && output == "json" { + if gjson.GetBytes(msg, "command").String() == "psubscribe" || + gjson.GetBytes(msg, "command").String() == "subscribe" || + string(msg) == liveJSON { + livemode = true + } + } mustOutput := true - - if oneCommand == "" && !jsonOK(msg) { + if oneCommand == "" && output == "json" && !jsonOK(msg) { var cerr connError if err := json.Unmarshal(msg, &cerr); err == nil { fmt.Fprintln(os.Stderr, "(error) "+cerr.Err) mustOutput = false } - } else if gjson.GetBytes(msg, "command").String() == "psubscribe" || - gjson.GetBytes(msg, "command").String() == "subscribe" || - string(msg) == client.LiveJSON { + } else if livemode { fmt.Fprintln(os.Stderr, string(msg)) - livemode = true break // break out of prompt and just feed data to screen } if mustOutput { - if output == "resp" { - if !raw { - msg = convert2termresp(msg) - } - fmt.Fprintln(os.Stdout, string(msg)) - } else { - msg = bytes.TrimSpace(msg) - if raw { - fmt.Fprintln(os.Stdout, string(msg)) - } else { - fmt.Fprintln(os.Stdout, string(msg)) - } - } + fmt.Fprintln(os.Stdout, string(msg)) } } } else if err == liner.ErrPromptAborted { @@ -401,6 +410,13 @@ func convert2termresp(msg []byte) []byte { return []byte(strings.TrimSpace(out)) } +func convert2termjson(msg []byte) []byte { + if msg[0] == '{' { + return msg + } + return bytes.TrimSpace(msg[bytes.IndexByte(msg, '\n')+1:]) +} + func convert2termrespval(v resp.Value, spaces int) string { switch v.Type() { default: @@ -495,3 +511,108 @@ func help(arg string) error { } return nil } + +const liveJSON = `{"ok":true,"live":true}` + +type client struct { + wr io.Writer + rd *bufio.Reader +} + +func clientDial(network, addr string) (*client, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return &client{wr: conn, rd: bufio.NewReader(conn)}, nil +} + +func (c *client) Do(command string) ([]byte, error) { + _, err := c.wr.Write([]byte(command + "\r\n")) + if err != nil { + return nil, err + } + return c.readResp() +} + +func (c *client) readResp() ([]byte, error) { + ch, err := c.rd.Peek(1) + if err != nil { + return nil, err + } + switch ch[0] { + case ':', '+', '-', '{': + return c.readLine() + case '$': + return c.readBulk() + case '*': + return c.readArray() + default: + return nil, fmt.Errorf("invalid response character '%c", ch[0]) + } +} + +func (c *client) readArray() ([]byte, error) { + out, err := c.readLine() + if err != nil { + return nil, err + } + n, err := strconv.ParseUint(string(bytes.TrimSpace(out[1:])), 10, 64) + if err != nil { + return nil, err + } + for i := 0; i < int(n); i++ { + resp, err := c.readResp() + if err != nil { + return nil, err + } + out = append(out, resp...) + } + return out, nil +} + +func (c *client) readBulk() ([]byte, error) { + line, err := c.readLine() + if err != nil { + return nil, err + } + x, err := strconv.ParseInt(string(bytes.TrimSpace(line[1:])), 10, 64) + if err != nil { + return nil, err + } + if x < 0 { + return line, nil + } + out := make([]byte, len(line)+int(x)+2) + if _, err := io.ReadFull(c.rd, out[len(line):]); err != nil { + return nil, err + } + if !bytes.HasSuffix(out, []byte{'\r', '\n'}) { + return nil, errors.New("invalid response") + } + copy(out, line) + return out, nil +} + +func (c *client) readLine() ([]byte, error) { + line, err := c.rd.ReadBytes('\r') + if err != nil { + return nil, err + } + ch, err := c.rd.ReadByte() + if err != nil { + return nil, err + } + if ch != '\n' { + return nil, errors.New("invalid response") + } + return append(line, '\n'), nil +} + +func (c *client) Reader() io.Reader { + return c.rd +} + +func (c *client) readLiveResp() (message []byte, err error) { + return c.readResp() +}