diff --git a/cmd/tile38-server/main.go b/cmd/tile38-server/main.go index 784657e6..20658aea 100644 --- a/cmd/tile38-server/main.go +++ b/cmd/tile38-server/main.go @@ -35,7 +35,7 @@ var ( ) // TODO: Set to false in 2.* -var httpTransport bool = true +var httpTransport = true // Fire up a webhook test server by using the --webhook-http-consumer-port // for example @@ -47,7 +47,7 @@ var httpTransport bool = true type hserver struct{} func (s *hserver) Send(ctx context.Context, in *hservice.MessageRequest) (*hservice.MessageReply, error) { - return &hservice.MessageReply{true}, nil + return &hservice.MessageReply{Ok: true}, nil } func main() { diff --git a/internal/server/server.go b/internal/server/server.go index fdd0cb09..7915b595 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -36,6 +36,8 @@ import ( "github.com/tidwall/tile38/internal/log" ) +const useEvio = false + var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'") const goingLive = "going live" @@ -270,7 +272,10 @@ func Serve(host string, port int, dir string, http bool) error { }() // Start the network server - return server.evioServe() + if useEvio { + return server.evioServe() + } + return server.netServe() } func (server *Server) isProtected() bool { @@ -481,6 +486,188 @@ func (server *Server) evioServe() error { return evio.Serve(events, fmt.Sprintf("%s:%d", server.host, server.port)) } +func (server *Server) netServe() error { + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.host, server.port)) + if err != nil { + return err + } + defer ln.Close() + log.Infof("Ready to accept connections at %s", ln.Addr()) + var clientID int64 + for { + conn, err := ln.Accept() + if err != nil { + return err + } + + go func(conn net.Conn) { + // open connection + // create the client + client := new(Client) + client.id = int(atomic.AddInt64(&clientID, 1)) + client.opened = time.Now() + client.remoteAddr = conn.RemoteAddr().String() + + // add client to server map + server.connsmu.Lock() + server.conns[client.id] = client + server.connsmu.Unlock() + server.statsTotalConns.add(1) + + // set the client keep-alive, if needed + if server.config.keepAlive() > 0 { + if conn, ok := conn.(*net.TCPConn); ok { + conn.SetKeepAlive(true) + conn.SetKeepAlivePeriod( + time.Duration(server.config.keepAlive()) * time.Second, + ) + } + } + log.Debugf("Opened connection: %s", client.remoteAddr) + + defer func() { + // close connection + // delete from server map + server.connsmu.Lock() + delete(server.conns, client.id) + server.connsmu.Unlock() + log.Debugf("Closed connection: %s", client.remoteAddr) + conn.Close() + }() + + // check if the connection is protected + if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") && + !strings.HasPrefix(client.remoteAddr, "[::1]:") { + if server.isProtected() { + // This is a protected server. Only loopback is allowed. + conn.Write(deniedMessage) + return // close connection + } + } + packet := make([]byte, 0xFFFF) + for { + var close bool + n, err := conn.Read(packet) + if err != nil { + return + } + in := packet[:n] + + // read the payload packet from the client input stream. + packet := client.in.Begin(in) + + // load the pipeline reader + pr := &client.pr + rdbuf := bytes.NewBuffer(packet) + pr.rd = rdbuf + pr.wr = client + + msgs, err := pr.ReadMessages() + if err != nil { + log.Error(err) + return // close connection + } + for _, msg := range msgs { + // Just closing connection if we have deprecated HTTP or WS connection, + // And --http-transport = false + if !server.http && (msg.ConnType == WebSocket || + msg.ConnType == HTTP) { + close = true // close connection + break + } + if msg != nil && msg.Command() != "" { + if client.outputType != Null { + msg.OutputType = client.outputType + } + if msg.Command() == "quit" { + if msg.OutputType == RESP { + io.WriteString(client, "+OK\r\n") + } + close = true // close connection + break + } + + // increment last used + client.mu.Lock() + client.last = time.Now() + client.mu.Unlock() + + // update total command count + server.statsTotalCommands.add(1) + + // handle the command + err := server.handleInputCommand(client, msg) + if err != nil { + if err.Error() == goingLive { + client.goLiveErr = err + client.goLiveMsg = msg + // detach + var rwc io.ReadWriteCloser = conn + client.conn = rwc + if len(client.out) > 0 { + client.conn.Write(client.out) + client.out = nil + } + client.in = evio.InputStream{} + client.pr.rd = rwc + client.pr.wr = rwc + log.Debugf("Detached connection: %s", client.remoteAddr) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := server.goLive( + client.goLiveErr, + &liveConn{conn.RemoteAddr(), rwc}, + &client.pr, + client.goLiveMsg, + client.goLiveMsg.ConnType == WebSocket, + ) + if err != nil { + log.Error(err) + } + }() + wg.Wait() + return // close connection + } + log.Error(err) + return // close connection, NOW + } + + client.outputType = msg.OutputType + } else { + client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n")) + break + } + if msg.ConnType == HTTP || msg.ConnType == WebSocket { + close = true // close connection + break + } + } + + packet = packet[len(packet)-rdbuf.Len():] + client.in.End(packet) + + // write to client + if len(client.out) > 0 { + func() { + // prewrite + server.mu.Lock() + defer server.mu.Unlock() + server.flushAOF() + }() + conn.Write(client.out) + client.out = client.out[:0] + } + if close { + break + } + } + }(conn) + } +} + type liveConn struct { remoteAddr net.Addr rwc io.ReadWriteCloser