Code cleanup

This commit is contained in:
tidwall 2018-11-05 15:24:45 -07:00
parent 0cd6d164d6
commit 933f243c6c
2 changed files with 190 additions and 3 deletions

View File

@ -35,7 +35,7 @@ var (
) )
// TODO: Set to false in 2.* // TODO: Set to false in 2.*
var httpTransport bool = true var httpTransport = true
// Fire up a webhook test server by using the --webhook-http-consumer-port // Fire up a webhook test server by using the --webhook-http-consumer-port
// for example // for example
@ -47,7 +47,7 @@ var httpTransport bool = true
type hserver struct{} type hserver struct{}
func (s *hserver) Send(ctx context.Context, in *hservice.MessageRequest) (*hservice.MessageReply, error) { func (s *hserver) Send(ctx context.Context, in *hservice.MessageRequest) (*hservice.MessageReply, error) {
return &hservice.MessageReply{true}, nil return &hservice.MessageReply{Ok: true}, nil
} }
func main() { func main() {

View File

@ -36,6 +36,8 @@ import (
"github.com/tidwall/tile38/internal/log" "github.com/tidwall/tile38/internal/log"
) )
const useEvio = false
var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'") var errOOM = errors.New("OOM command not allowed when used memory > 'maxmemory'")
const goingLive = "going live" const goingLive = "going live"
@ -270,7 +272,10 @@ func Serve(host string, port int, dir string, http bool) error {
}() }()
// Start the network server // Start the network server
return server.evioServe() if useEvio {
return server.evioServe()
}
return server.netServe()
} }
func (server *Server) isProtected() bool { func (server *Server) isProtected() bool {
@ -481,6 +486,188 @@ func (server *Server) evioServe() error {
return evio.Serve(events, fmt.Sprintf("%s:%d", server.host, server.port)) return evio.Serve(events, fmt.Sprintf("%s:%d", server.host, server.port))
} }
func (server *Server) netServe() error {
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.host, server.port))
if err != nil {
return err
}
defer ln.Close()
log.Infof("Ready to accept connections at %s", ln.Addr())
var clientID int64
for {
conn, err := ln.Accept()
if err != nil {
return err
}
go func(conn net.Conn) {
// open connection
// create the client
client := new(Client)
client.id = int(atomic.AddInt64(&clientID, 1))
client.opened = time.Now()
client.remoteAddr = conn.RemoteAddr().String()
// add client to server map
server.connsmu.Lock()
server.conns[client.id] = client
server.connsmu.Unlock()
server.statsTotalConns.add(1)
// set the client keep-alive, if needed
if server.config.keepAlive() > 0 {
if conn, ok := conn.(*net.TCPConn); ok {
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(
time.Duration(server.config.keepAlive()) * time.Second,
)
}
}
log.Debugf("Opened connection: %s", client.remoteAddr)
defer func() {
// close connection
// delete from server map
server.connsmu.Lock()
delete(server.conns, client.id)
server.connsmu.Unlock()
log.Debugf("Closed connection: %s", client.remoteAddr)
conn.Close()
}()
// check if the connection is protected
if !strings.HasPrefix(client.remoteAddr, "127.0.0.1:") &&
!strings.HasPrefix(client.remoteAddr, "[::1]:") {
if server.isProtected() {
// This is a protected server. Only loopback is allowed.
conn.Write(deniedMessage)
return // close connection
}
}
packet := make([]byte, 0xFFFF)
for {
var close bool
n, err := conn.Read(packet)
if err != nil {
return
}
in := packet[:n]
// read the payload packet from the client input stream.
packet := client.in.Begin(in)
// load the pipeline reader
pr := &client.pr
rdbuf := bytes.NewBuffer(packet)
pr.rd = rdbuf
pr.wr = client
msgs, err := pr.ReadMessages()
if err != nil {
log.Error(err)
return // close connection
}
for _, msg := range msgs {
// Just closing connection if we have deprecated HTTP or WS connection,
// And --http-transport = false
if !server.http && (msg.ConnType == WebSocket ||
msg.ConnType == HTTP) {
close = true // close connection
break
}
if msg != nil && msg.Command() != "" {
if client.outputType != Null {
msg.OutputType = client.outputType
}
if msg.Command() == "quit" {
if msg.OutputType == RESP {
io.WriteString(client, "+OK\r\n")
}
close = true // close connection
break
}
// increment last used
client.mu.Lock()
client.last = time.Now()
client.mu.Unlock()
// update total command count
server.statsTotalCommands.add(1)
// handle the command
err := server.handleInputCommand(client, msg)
if err != nil {
if err.Error() == goingLive {
client.goLiveErr = err
client.goLiveMsg = msg
// detach
var rwc io.ReadWriteCloser = conn
client.conn = rwc
if len(client.out) > 0 {
client.conn.Write(client.out)
client.out = nil
}
client.in = evio.InputStream{}
client.pr.rd = rwc
client.pr.wr = rwc
log.Debugf("Detached connection: %s", client.remoteAddr)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err := server.goLive(
client.goLiveErr,
&liveConn{conn.RemoteAddr(), rwc},
&client.pr,
client.goLiveMsg,
client.goLiveMsg.ConnType == WebSocket,
)
if err != nil {
log.Error(err)
}
}()
wg.Wait()
return // close connection
}
log.Error(err)
return // close connection, NOW
}
client.outputType = msg.OutputType
} else {
client.Write([]byte("HTTP/1.1 500 Bad Request\r\nConnection: close\r\n\r\n"))
break
}
if msg.ConnType == HTTP || msg.ConnType == WebSocket {
close = true // close connection
break
}
}
packet = packet[len(packet)-rdbuf.Len():]
client.in.End(packet)
// write to client
if len(client.out) > 0 {
func() {
// prewrite
server.mu.Lock()
defer server.mu.Unlock()
server.flushAOF()
}()
conn.Write(client.out)
client.out = client.out[:0]
}
if close {
break
}
}
}(conn)
}
}
type liveConn struct { type liveConn struct {
remoteAddr net.Addr remoteAddr net.Addr
rwc io.ReadWriteCloser rwc io.ReadWriteCloser