diff --git a/internal/server/follow.go b/internal/server/follow.go index 21c39575..a70c2204 100644 --- a/internal/server/follow.go +++ b/internal/server/follow.go @@ -192,6 +192,7 @@ func (s *Server) followStep(host string, port int, followc int) error { return errNoLongerFollowing } s.mu.Lock() + s.faofsz = 0 s.fcup = false auth := s.config.leaderAuth() s.mu.Unlock() @@ -263,6 +264,10 @@ func (s *Server) followStep(host string, port int, followc int) error { return err } + s.mu.Lock() + s.faofsz = int(aofSize) + s.mu.Unlock() + caughtUp := pos >= aofSize if caughtUp { s.mu.Lock() @@ -271,6 +276,7 @@ func (s *Server) followStep(host string, port int, followc int) error { s.mu.Unlock() log.Info("caught up") } + nullw := io.Discard for { v, telnet, _, err := conn.rd.ReadMultiBulk() @@ -290,10 +296,12 @@ func (s *Server) followStep(host string, port int, followc int) error { if err != nil { return err } + s.mu.Lock() + s.faofsz = aofsz + s.mu.Unlock() if !caughtUp { if aofsz >= int(aofSize) { caughtUp = true - s.mu.Lock() s.flushAOF(false) s.fcup = true s.fcuponce = true diff --git a/internal/server/server.go b/internal/server/server.go index 95b37112..24a3866f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -136,6 +136,7 @@ type Server struct { lstack []*commandDetails lives map[*liveBuffer]bool lcond *sync.Cond // live geofence signal + faofsz int // last reported aofsize fcup bool // follow caught up fcuponce bool // follow caught up once aofconnM map[net.Conn]io.Closer @@ -1023,7 +1024,7 @@ func (s *Server) handleInputCommand(client *Client, msg *Message) error { } case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "chans", "search", "ttl", "bounds", "server", "info", "type", "jget", - "evalro", "evalrosha", "healthz": + "evalro", "evalrosha", "healthz", "role": // read operations s.mu.RLock() @@ -1210,6 +1211,8 @@ func (s *Server) command(msg *Message, client *Client) ( res, err = s.cmdHEALTHZ(msg) case "info": res, err = s.cmdINFO(msg) + case "role": + res, err = s.cmdROLE(msg) case "scan": res, err = s.cmdScan(msg) case "nearby": diff --git a/internal/server/stats.go b/internal/server/stats.go index 0d324f02..9a0ae287 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -428,6 +428,19 @@ func (s *Server) writeInfoStats(w *bytes.Buffer) { fmt.Fprintf(w, "expired_keys:%d\r\n", s.statsExpired.Load()) // Total number of key expiration events } +func replicaIPAndPort(cc *Client) (ip string, port int) { + ip = cc.remoteAddr + i := strings.LastIndex(ip, ":") + if i != -1 { + ip = ip[:i] + if ip == "[::1]" { + ip = "localhost" + } + } + port = cc.replPort + return ip, port +} + // writeInfoReplication writes all replication data to the 'info' response func (s *Server) writeInfoReplication(w *bytes.Buffer) { if s.config.followHost() != "" { @@ -443,8 +456,9 @@ func (s *Server) writeInfoReplication(w *bytes.Buffer) { s.connsmu.RLock() for _, cc := range s.conns { if cc.replPort != 0 { + ip, port := replicaIPAndPort(cc) fmt.Fprintf(w, "slave%v:ip=%s,port=%v,state=online\r\n", i, - strings.Split(cc.remoteAddr, ":")[0], cc.replPort) + ip, port) i++ } } @@ -588,3 +602,95 @@ func respValuesSimpleMap(m map[string]interface{}) []resp.Value { } return vals } + +// ROLE +func (s *Server) cmdROLE(msg *Message) (res resp.Value, err error) { + start := time.Now() + var role string + var offset int + var ips []string + var ports []int + var offsets []int + var host string + var port int + var state string + if s.config.followHost() == "" { + role = "master" + offset = s.aofsz + s.connsmu.RLock() + for _, cc := range s.conns { + if cc.replPort != 0 { + ip, port := replicaIPAndPort(cc) + ips = append(ips, ip) + ports = append(ports, port) + offsets = append(offsets, s.aofsz) + } + } + s.connsmu.RUnlock() + } else { + role = "slave" + host = s.config.followHost() + port = s.config.followPort() + offset = int(s.faofsz) + state = "connected" + } + if msg.OutputType == JSON { + var json []byte + json = append(json, `{"ok":true,"role":{"`...) + json = append(json, `"role":`...) + json = appendJSONString(json, role) + if role == "master" { + json = append(json, `,"offset":`...) + json = strconv.AppendInt(json, int64(offset), 10) + json = append(json, `,"slaves":[`...) + for i := range ips { + if i > 0 { + json = append(json, ',') + } + json = append(json, '{') + json = append(json, `"ip":`...) + json = appendJSONString(json, ips[i]) + json = append(json, `,"port":`...) + json = appendJSONString(json, fmt.Sprint(ports[i])) + json = append(json, `,"offset":`...) + json = appendJSONString(json, fmt.Sprint(offsets[i])) + json = append(json, '}') + } + json = append(json, `]`...) + } else if role == "slave" { + json = append(json, `,"host":`...) + json = appendJSONString(json, host) + json = append(json, `,"port":`...) + json = strconv.AppendInt(json, int64(port), 10) + json = append(json, `,"state":`...) + json = appendJSONString(json, state) + json = append(json, `,"offset":`...) + json = strconv.AppendInt(json, int64(offset), 10) + } + json = append(json, `},"elapsed":`...) + json = appendJSONString(json, time.Since(start).String()) + json = append(json, '}') + return resp.StringValue(string(json)), nil + } else { + var vals []resp.Value + vals = append(vals, resp.StringValue(role)) + if role == "master" { + vals = append(vals, resp.IntegerValue(offset)) + var replicaVals []resp.Value + for i := range ips { + var vals []resp.Value + vals = append(vals, resp.StringValue(ips[i])) + vals = append(vals, resp.StringValue(fmt.Sprint(ports[i]))) + vals = append(vals, resp.StringValue(fmt.Sprint(offsets[i]))) + replicaVals = append(replicaVals, resp.ArrayValue(vals)) + } + vals = append(vals, resp.ArrayValue(replicaVals)) + } else if role == "slave" { + vals = append(vals, resp.StringValue(host)) + vals = append(vals, resp.IntegerValue(port)) + vals = append(vals, resp.StringValue(state)) + vals = append(vals, resp.IntegerValue(offset)) + } + return resp.ArrayValue(vals), nil + } +}