diff --git a/cmd/tile38-cli/main.go b/cmd/tile38-cli/main.go index 9faeb6ff..2823547e 100644 --- a/cmd/tile38-cli/main.go +++ b/cmd/tile38-cli/main.go @@ -276,6 +276,13 @@ func main() { } }() for { + if conn == nil { + connDial() + if conn == nil { + continue + } + } + var command string var err error if oneCommand == "" { @@ -299,11 +306,14 @@ func main() { if conn != nil { _, err := conn.Do("pInG") if err != nil { - if err != io.EOF { + if err != io.EOF && !strings.Contains(err.Error(), "broken pipe") { fmt.Fprintln(os.Stderr, err.Error()) - return + } else { + fmt.Fprintln(os.Stderr, refusedErrorString(addr)) } - fmt.Fprintln(os.Stderr, refusedErrorString(addr)) + conn.wr.Close() + conn = nil + continue } } } else { @@ -335,11 +345,10 @@ func main() { if err != nil { if err != io.EOF { fmt.Fprintln(os.Stderr, err.Error()) - } else { - conn = nil - goto tryAgain } - return + conn.wr.Close() + conn = nil + goto tryAgain } switch strings.ToLower(command) { case "output resp": @@ -518,7 +527,7 @@ func help(arg string) error { const liveJSON = `{"ok":true,"live":true}` type client struct { - wr io.Writer + wr net.Conn rd *bufio.Reader } @@ -531,7 +540,7 @@ func clientDial(network, addr string) (*client, error) { } func (c *client) Do(command string) ([]byte, error) { - _, err := c.wr.Write([]byte(command + "\r\n")) + _, err := c.wr.Write(plainToCompat(command)) if err != nil { return nil, err } @@ -619,3 +628,51 @@ func (c *client) Reader() io.Reader { func (c *client) readLiveResp() (message []byte, err error) { return c.readResp() } + +// plainToCompat converts a plain message like "SET fleet truck1 ..." into a +// Tile38 compatible blob. +func plainToCompat(message string) []byte { + var args []string + // search for the beginning of the first argument + for i := 0; i < len(message); i++ { + if message[i] != ' ' { + // first argument found + if message[i] == '"' || message[i] == '\'' { + // using a string caps + s := i + cap := message[i] + for ; i < len(message); i++ { + if message[i] == cap { + if message[i-1] == '\\' { + continue + } + if i == len(message)-1 || message[i+1] == ' ' { + args = append(args, message[s:i+1]) + i++ + break + } + } + } + } else { + // using plain string, terminated by a space + s := i + var quotes bool + for ; i < len(message); i++ { + if message[i] == '"' || message[i] == '\'' { + quotes = true + } + if i == len(message)-1 || message[i+1] == ' ' { + arg := message[s : i+1] + if quotes { + arg = strconv.Quote(arg) + } + args = append(args, arg) + i++ + break + } + } + } + } + } + return []byte(strings.Join(args, " ") + "\r\n") +} diff --git a/internal/server/server.go b/internal/server/server.go index c3c097ad..f3939e50 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -348,6 +348,9 @@ func (server *Server) netServe() error { conn.Close() }() + var lastConnType Type + var lastOutputType Type + // check if the connection is protected if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") && !strings.HasPrefix(client.remoteAddr, "[::1]:") { @@ -376,10 +379,6 @@ func (server *Server) netServe() error { pr.wr = client msgs, err := pr.ReadMessages() - if err != nil { - log.Error(err) - return // close connection - } for _, msg := range msgs { // Just closing connection if we have deprecated HTTP or WS connection, // And --http-transport = false @@ -457,6 +456,8 @@ func (server *Server) netServe() error { close = true // close connection break } + lastOutputType = msg.OutputType + lastConnType = msg.ConnType } packet = packet[len(packet)-rdbuf.Len():] @@ -475,11 +476,26 @@ func (server *Server) netServe() error { } conn.Write(client.out) client.out = nil - } if close { break } + if err != nil { + log.Error(err) + if lastConnType == RESP { + var value resp.Value + switch lastOutputType { + case JSON: + value = resp.StringValue(`{"ok":false,"err":` + + jsonString(err.Error()) + "}") + case RESP: + value = resp.ErrorValue(err) + } + bytes, _ := value.MarshalRESP() + conn.Write(bytes) + } + break // close connection + } } }(conn) } @@ -1351,8 +1367,10 @@ moreData: } for len(data) > 0 { msg := &Message{} - complete, args, kind, leftover, err := readNextCommand(data, nil, msg, rd.wr) - if err != nil { + complete, args, kind, leftover, err2 := + readNextCommand(data, nil, msg, rd.wr) + if err2 != nil { + err = err2 break } if !complete { @@ -1387,10 +1405,7 @@ moreData: } else if len(rd.buf) > 0 { rd.buf = rd.buf[:0] } - if err != nil && len(msgs) == 0 { - return nil, err - } - return msgs, nil + return msgs, err } func readNativeMessageLine(line []byte) (*Message, error) {