wip config

This commit is contained in:
Josh Baker 2017-09-29 18:11:05 -07:00
parent d6936636c2
commit d8f11354df
7 changed files with 357 additions and 178 deletions

View File

@ -166,7 +166,7 @@ func (c *Controller) writeAOF(value resp.Value, d *commandDetailsT) error {
if !d.updated { if !d.updated {
return nil // just ignore writes if the command did not update return nil // just ignore writes if the command did not update
} }
if c.config.FollowHost == "" { if c.config.followHost() == "" {
// process hooks, for leader only // process hooks, for leader only
if d.parent { if d.parent {
// process children only // process children only

View File

@ -7,8 +7,10 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/tidwall/gjson"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"github.com/tidwall/tile38/controller/glob" "github.com/tidwall/tile38/controller/glob"
"github.com/tidwall/tile38/controller/server" "github.com/tidwall/tile38/controller/server"
@ -20,6 +22,12 @@ const (
) )
const ( const (
FollowHost = "follow_host"
FollowPort = "follow_port"
FollowID = "follow_id"
FollowPos = "follow_pos"
ServerID = "server_id"
ReadOnly = "read_only"
RequirePass = "requirepass" RequirePass = "requirepass"
LeaderAuth = "leaderauth" LeaderAuth = "leaderauth"
ProtectedMode = "protected-mode" ProtectedMode = "protected-mode"
@ -32,82 +40,309 @@ var validProperties = []string{RequirePass, LeaderAuth, ProtectedMode, MaxMemory
// Config is a tile38 config // Config is a tile38 config
type Config struct { type Config struct {
FollowHost string `json:"follow_host,omitempty"` path string
FollowPort int `json:"follow_port,omitempty"`
FollowID string `json:"follow_id,omitempty"`
FollowPos int `json:"follow_pos,omitempty"`
ServerID string `json:"server_id,omitempty"`
ReadOnly bool `json:"read_only,omitempty"`
// Properties mu sync.RWMutex
RequirePassP string `json:"requirepass,omitempty"`
RequirePass string `json:"-"` _followHost string
LeaderAuthP string `json:"leaderauth,omitempty"` _followPort int64
LeaderAuth string `json:"-"` _followID string
ProtectedModeP string `json:"protected-mode,omitempty"` _followPos int64
ProtectedMode string `json:"-"` _serverID string
MaxMemoryP string `json:"maxmemory,omitempty"` _readOnly bool
MaxMemory int `json:"-"`
AutoGCP string `json:"autogc,omitempty"` _requirePassP string
AutoGC uint64 `json:"-"` _requirePass string
KeepAliveP string `json:"keepalive,omitempty"` _leaderAuthP string
KeepAlive int `json:"-"` _leaderAuth string
_protectedModeP string
_protectedMode string
_maxMemoryP string
_maxMemory int64
_autoGCP string
_autoGC uint64
_keepAliveP string
_keepAlive int64
} }
func (c *Controller) loadConfig() error { func loadConfig(path string) (*Config, error) {
data, err := ioutil.ReadFile(c.dir + "/config") data, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return c.initConfig() config := &Config{path: path, _serverID: randomKey(16)}
config.write(true)
return config, nil
} }
return err return nil, err
} }
err = json.Unmarshal(data, &c.config) json := string(data)
if err != nil { config := &Config{
return err path: path,
_followHost: gjson.Get(json, FollowHost).String(),
_followPort: gjson.Get(json, FollowPort).Int(),
_followID: gjson.Get(json, FollowID).String(),
_followPos: gjson.Get(json, FollowPos).Int(),
_serverID: gjson.Get(json, ServerID).String(),
_readOnly: gjson.Get(json, ReadOnly).Bool(),
_requirePassP: gjson.Get(json, RequirePass).String(),
_leaderAuthP: gjson.Get(json, LeaderAuth).String(),
_protectedModeP: gjson.Get(json, ProtectedMode).String(),
_maxMemoryP: gjson.Get(json, MaxMemory).String(),
_autoGCP: gjson.Get(json, AutoGC).String(),
_keepAliveP: gjson.Get(json, KeepAlive).String(),
} }
// load properties // load properties
if err := c.setConfigProperty(RequirePass, c.config.RequirePassP, true); err != nil { if err := config.setProperty(RequirePass, config._requirePassP, true); err != nil {
return err return nil, err
} }
if err := c.setConfigProperty(LeaderAuth, c.config.LeaderAuthP, true); err != nil { if err := config.setProperty(LeaderAuth, config._leaderAuthP, true); err != nil {
return err return nil, err
} }
if err := c.setConfigProperty(ProtectedMode, c.config.ProtectedModeP, true); err != nil { if err := config.setProperty(ProtectedMode, config._protectedModeP, true); err != nil {
return err println(2)
return nil, err
} }
if err := c.setConfigProperty(MaxMemory, c.config.MaxMemoryP, true); err != nil { if err := config.setProperty(MaxMemory, config._maxMemoryP, true); err != nil {
return err return nil, err
} }
if err := c.setConfigProperty(AutoGC, c.config.AutoGCP, true); err != nil { if err := config.setProperty(AutoGC, config._autoGCP, true); err != nil {
return err return nil, err
} }
if err := c.setConfigProperty(KeepAlive, c.config.KeepAliveP, true); err != nil { if err := config.setProperty(KeepAlive, config._keepAliveP, true); err != nil {
return err return nil, err
} }
return nil config.write(false)
return config, nil
} }
func parseMemSize(s string) (bytes int, ok bool) { func (config *Config) write(writeProperties bool) {
config.mu.Lock()
defer config.mu.Unlock()
if writeProperties {
// save properties
config._requirePassP = config._requirePass
config._leaderAuthP = config._leaderAuth
if config._protectedMode == defaultProtectedMode {
config._protectedModeP = ""
} else {
config._protectedModeP = config._protectedMode
}
config._maxMemoryP = formatMemSize(config._maxMemory)
if config._autoGC == 0 {
config._autoGCP = ""
} else {
config._autoGCP = strconv.FormatUint(config._autoGC, 10)
}
if config._keepAlive == defaultKeepAlive {
config._keepAliveP = ""
} else {
config._keepAliveP = strconv.FormatUint(uint64(config._keepAlive), 10)
}
}
m := make(map[string]interface{})
if config._followHost != "" {
m[FollowHost] = config._followHost
}
if config._followPort != 0 {
m[FollowPort] = config._followPort
}
if config._followID != "" {
m[FollowID] = config._followID
}
if config._followPos != 0 {
m[FollowPos] = config._followPos
}
if config._serverID != "" {
m[ServerID] = config._serverID
}
if config._readOnly {
m[ReadOnly] = config._readOnly
}
if config._requirePassP != "" {
m[RequirePass] = config._requirePassP
}
if config._leaderAuthP != "" {
m[LeaderAuth] = config._leaderAuthP
}
if config._protectedModeP != "" {
m[ProtectedMode] = config._protectedModeP
}
if config._maxMemoryP != "" {
m[MaxMemory] = config._maxMemoryP
}
if config._autoGCP != "" {
m[AutoGC] = config._autoGCP
}
if config._keepAliveP != "" {
m[KeepAlive] = config._keepAliveP
}
data, err := json.MarshalIndent(m, "", "\t")
if err != nil {
panic(err)
}
data = append(data, '\n')
err = ioutil.WriteFile(config.path, data, 0600)
if err != nil {
panic(err)
}
}
func (config *Config) followHost() string {
config.mu.RLock()
v := config._followHost
config.mu.RUnlock()
return v
}
func (config *Config) followPort() int {
config.mu.RLock()
v := config._followPort
config.mu.RUnlock()
return int(v)
}
func (config *Config) followID() string {
config.mu.RLock()
v := config._followID
config.mu.RUnlock()
return v
}
func (config *Config) followPos() int64 {
config.mu.RLock()
v := config._followPos
config.mu.RUnlock()
return v
}
func (config *Config) serverID() string {
config.mu.RLock()
v := config._serverID
config.mu.RUnlock()
return v
}
func (config *Config) readOnly() bool {
config.mu.RLock()
v := config._readOnly
config.mu.RUnlock()
return v
}
func (config *Config) requirePass() string {
config.mu.RLock()
v := config._requirePass
config.mu.RUnlock()
return v
}
func (config *Config) leaderAuth() string {
config.mu.RLock()
v := config._leaderAuth
config.mu.RUnlock()
return v
}
func (config *Config) protectedMode() string {
config.mu.RLock()
v := config._protectedMode
config.mu.RUnlock()
return v
}
func (config *Config) maxMemory() int {
config.mu.RLock()
v := config._maxMemory
config.mu.RUnlock()
return int(v)
}
func (config *Config) autoGC() uint64 {
config.mu.RLock()
v := config._autoGC
config.mu.RUnlock()
return v
}
func (config *Config) keepAlive() int64 {
config.mu.RLock()
v := config._keepAlive
config.mu.RUnlock()
return v
}
func (config *Config) setFollowHost(v string) {
config.mu.Lock()
config._followHost = v
config.mu.Unlock()
}
func (config *Config) setFollowPort(v int) {
config.mu.Lock()
config._followPort = int64(v)
config.mu.Unlock()
}
func (config *Config) setFollowID(v string) {
config.mu.Lock()
config._followID = v
config.mu.Unlock()
}
func (config *Config) setFollowPos(v int64) {
config.mu.Lock()
config._followPos = v
config.mu.Unlock()
}
func (config *Config) setServerID(v string) {
config.mu.Lock()
config._serverID = v
config.mu.Unlock()
}
func (config *Config) setReadOnly(v bool) {
config.mu.Lock()
config._readOnly = v
config.mu.Unlock()
}
func (config *Config) setRequirePass(v string) {
config.mu.Lock()
config._requirePass = v
config.mu.Unlock()
}
func (config *Config) setLeaderAuth(v string) {
config.mu.Lock()
config._leaderAuth = v
config.mu.Unlock()
}
func (config *Config) setProtectedMode(v string) {
config.mu.Lock()
config._protectedMode = v
config.mu.Unlock()
}
func (config *Config) setMaxMemory(v int) {
config.mu.Lock()
config._maxMemory = int64(v)
config.mu.Unlock()
}
func (config *Config) setAutoGC(v uint64) {
config.mu.Lock()
config._autoGC = v
config.mu.Unlock()
}
func (config *Config) setKeepAlive(v int64) {
config.mu.Lock()
config._keepAlive = v
config.mu.Unlock()
}
func parseMemSize(s string) (bytes int64, ok bool) {
if s == "" { if s == "" {
return 0, true return 0, true
} }
s = strings.ToLower(s) s = strings.ToLower(s)
var n uint64 var n uint64
var sz int var sz int64
var err error var err error
if strings.HasSuffix(s, "gb") { if strings.HasSuffix(s, "gb") {
n, err = strconv.ParseUint(s[:len(s)-2], 10, 64) n, err = strconv.ParseUint(s[:len(s)-2], 10, 64)
sz = int(n * 1024 * 1024 * 1024) sz = int64(n * 1024 * 1024 * 1024)
} else if strings.HasSuffix(s, "mb") { } else if strings.HasSuffix(s, "mb") {
n, err = strconv.ParseUint(s[:len(s)-2], 10, 64) n, err = strconv.ParseUint(s[:len(s)-2], 10, 64)
sz = int(n * 1024 * 1024) sz = int64(n * 1024 * 1024)
} else if strings.HasSuffix(s, "kb") { } else if strings.HasSuffix(s, "kb") {
n, err = strconv.ParseUint(s[:len(s)-2], 10, 64) n, err = strconv.ParseUint(s[:len(s)-2], 10, 64)
sz = int(n * 1024) sz = int64(n * 1024)
} else { } else {
n, err = strconv.ParseUint(s, 10, 64) n, err = strconv.ParseUint(s, 10, 64)
sz = int(n) sz = int64(n)
} }
if err != nil { if err != nil {
return 0, false return 0, false
@ -115,72 +350,74 @@ func parseMemSize(s string) (bytes int, ok bool) {
return sz, true return sz, true
} }
func formatMemSize(sz int) string { func formatMemSize(sz int64) string {
if sz <= 0 { if sz <= 0 {
return "" return ""
} }
if sz < 1024 { if sz < 1024 {
return strconv.FormatInt(int64(sz), 10) return strconv.FormatInt(sz, 10)
} }
sz /= 1024 sz /= 1024
if sz < 1024 { if sz < 1024 {
return strconv.FormatInt(int64(sz), 10) + "kb" return strconv.FormatInt(sz, 10) + "kb"
} }
sz /= 1024 sz /= 1024
if sz < 1024 { if sz < 1024 {
return strconv.FormatInt(int64(sz), 10) + "mb" return strconv.FormatInt(sz, 10) + "mb"
} }
sz /= 1024 sz /= 1024
return strconv.FormatInt(int64(sz), 10) + "gb" return strconv.FormatInt(sz, 10) + "gb"
} }
func (c *Controller) setConfigProperty(name, value string, fromLoad bool) error { func (config *Config) setProperty(name, value string, fromLoad bool) error {
config.mu.Lock()
defer config.mu.Unlock()
var invalid bool var invalid bool
switch name { switch name {
default: default:
return fmt.Errorf("Unsupported CONFIG parameter: %s", name) return fmt.Errorf("Unsupported CONFIG parameter: %s", name)
case RequirePass: case RequirePass:
c.config.RequirePass = value config._requirePass = value
case LeaderAuth: case LeaderAuth:
c.config.LeaderAuth = value config._leaderAuth = value
case AutoGC: case AutoGC:
if value == "" { if value == "" {
c.config.AutoGC = 0 config._autoGC = 0
} else { } else {
gc, err := strconv.ParseUint(value, 10, 64) gc, err := strconv.ParseUint(value, 10, 64)
if err != nil { if err != nil {
return err return err
} }
c.config.AutoGC = gc config._autoGC = gc
} }
case MaxMemory: case MaxMemory:
sz, ok := parseMemSize(value) sz, ok := parseMemSize(value)
if !ok { if !ok {
return fmt.Errorf("Invalid argument '%s' for CONFIG SET '%s'", value, name) return fmt.Errorf("Invalid argument '%s' for CONFIG SET '%s'", value, name)
} }
c.config.MaxMemory = sz config._maxMemory = sz
case ProtectedMode: case ProtectedMode:
switch strings.ToLower(value) { switch strings.ToLower(value) {
case "": case "":
if fromLoad { if fromLoad {
c.config.ProtectedMode = defaultProtectedMode config._protectedMode = defaultProtectedMode
} else { } else {
invalid = true invalid = true
} }
case "yes", "no": case "yes", "no":
c.config.ProtectedMode = strings.ToLower(value) config._protectedMode = strings.ToLower(value)
default: default:
invalid = true invalid = true
} }
case KeepAlive: case KeepAlive:
if value == "" { if value == "" {
c.config.KeepAlive = defaultKeepAlive config._keepAlive = defaultKeepAlive
} else { } else {
keepalive, err := strconv.ParseUint(value, 10, 64) keepalive, err := strconv.ParseUint(value, 10, 64)
if err != nil { if err != nil {
invalid = true invalid = true
} else { } else {
c.config.KeepAlive = int(keepalive) config._keepAlive = int64(keepalive)
} }
} }
} }
@ -191,83 +428,38 @@ func (c *Controller) setConfigProperty(name, value string, fromLoad bool) error
return nil return nil
} }
func (c *Controller) getConfigProperties(pattern string) map[string]interface{} { func (config *Config) getProperties(pattern string) map[string]interface{} {
m := make(map[string]interface{}) m := make(map[string]interface{})
for _, name := range validProperties { for _, name := range validProperties {
matched, _ := glob.Match(pattern, name) matched, _ := glob.Match(pattern, name)
if matched { if matched {
m[name] = c.getConfigProperty(name) m[name] = config.getProperty(name)
} }
} }
return m return m
} }
func (c *Controller) getConfigProperty(name string) string {
func (config *Config) getProperty(name string) string {
config.mu.RLock()
defer config.mu.RUnlock()
switch name { switch name {
default: default:
return "" return ""
case AutoGC: case AutoGC:
return strconv.FormatUint(c.config.AutoGC, 10) return strconv.FormatUint(config._autoGC, 10)
case RequirePass: case RequirePass:
return c.config.RequirePass return config._requirePass
case LeaderAuth: case LeaderAuth:
return c.config.LeaderAuth return config._leaderAuth
case ProtectedMode: case ProtectedMode:
return c.config.ProtectedMode return config._protectedMode
case MaxMemory: case MaxMemory:
return formatMemSize(c.config.MaxMemory) return formatMemSize(config._maxMemory)
case KeepAlive: case KeepAlive:
return strconv.FormatUint(uint64(c.config.KeepAlive), 10) return strconv.FormatUint(uint64(config._keepAlive), 10)
} }
} }
func (c *Controller) initConfig() error {
c.config = Config{ServerID: randomKey(16)}
return c.writeConfig(true)
}
func (c *Controller) writeConfig(writeProperties bool) error {
var err error
bak := c.config
defer func() {
if err != nil {
// revert changes
c.config = bak
}
}()
if writeProperties {
// save properties
c.config.RequirePassP = c.config.RequirePass
c.config.LeaderAuthP = c.config.LeaderAuth
if c.config.ProtectedMode == defaultProtectedMode {
c.config.ProtectedModeP = ""
} else {
c.config.ProtectedModeP = c.config.ProtectedMode
}
c.config.MaxMemoryP = formatMemSize(c.config.MaxMemory)
if c.config.AutoGC == 0 {
c.config.AutoGCP = ""
} else {
c.config.AutoGCP = strconv.FormatUint(c.config.AutoGC, 10)
}
if c.config.KeepAlive == defaultKeepAlive {
c.config.KeepAliveP = ""
} else {
c.config.KeepAliveP = strconv.FormatUint(uint64(c.config.KeepAlive), 10)
}
}
var data []byte
data, err = json.MarshalIndent(c.config, "", "\t")
if err != nil {
return err
}
data = append(data, '\n')
err = ioutil.WriteFile(c.dir+"/config", data, 0600)
if err != nil {
return err
}
return nil
}
func (c *Controller) cmdConfigGet(msg *server.Message) (res string, err error) { func (c *Controller) cmdConfigGet(msg *server.Message) (res string, err error) {
start := time.Now() start := time.Now()
vs := msg.Values[1:] vs := msg.Values[1:]
@ -279,7 +471,7 @@ func (c *Controller) cmdConfigGet(msg *server.Message) (res string, err error) {
if len(vs) != 0 { if len(vs) != 0 {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
m := c.getConfigProperties(name) m := c.config.getProperties(name)
switch msg.OutputType { switch msg.OutputType {
case server.JSON: case server.JSON:
data, err := json.Marshal(m) data, err := json.Marshal(m)
@ -314,7 +506,7 @@ func (c *Controller) cmdConfigSet(msg *server.Message) (res string, err error) {
if len(vs) != 0 { if len(vs) != 0 {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
if err := c.setConfigProperty(name, value, false); err != nil { if err := c.config.setProperty(name, value, false); err != nil {
return "", err return "", err
} }
return server.OKMessage(msg, start), nil return server.OKMessage(msg, start), nil
@ -325,8 +517,6 @@ func (c *Controller) cmdConfigRewrite(msg *server.Message) (res string, err erro
if len(vs) != 0 { if len(vs) != 0 {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
if err := c.writeConfig(true); err != nil { c.config.write(true)
return "", err
}
return server.OKMessage(msg, start), nil return server.OKMessage(msg, start), nil
} }

View File

@ -9,6 +9,7 @@ import (
"net" "net"
"os" "os"
"path" "path"
"path/filepath"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
@ -76,7 +77,7 @@ type Controller struct {
cols *btree.BTree cols *btree.BTree
aofsz int aofsz int
dir string dir string
config Config config *Config
followc uint64 // counter increases when follow property changes followc uint64 // counter increases when follow property changes
follows map[*bytes.Buffer]bool follows map[*bytes.Buffer]bool
fcond *sync.Cond fcond *sync.Cond
@ -138,7 +139,9 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http
if err := os.MkdirAll(dir, 0700); err != nil { if err := os.MkdirAll(dir, 0700); err != nil {
return err return err
} }
if err := c.loadConfig(); err != nil { var err error
c.config, err = loadConfig(filepath.Join(dir, "config"))
if err != nil {
return err return err
} }
// load the queue before the aof // load the queue before the aof
@ -179,8 +182,8 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http
} }
c.mu.Lock() c.mu.Lock()
c.fillExpiresList() c.fillExpiresList()
if c.config.FollowHost != "" { if c.config.followHost() != "" {
go c.follow(c.config.FollowHost, c.config.FollowPort, c.followc) go c.follow(c.config.followHost(), c.config.followPort(), c.followc)
} }
c.mu.Unlock() c.mu.Unlock()
defer func() { defer func() {
@ -225,17 +228,17 @@ func ListenAndServeEx(host string, port int, dir string, ln *net.Listener, http
return false return false
} }
c.mu.RLock() c.mu.RLock()
is := c.config.ProtectedMode != "no" && c.config.RequirePass == "" is := c.config.protectedMode() != "no" && c.config.requirePass() == ""
c.mu.RUnlock() c.mu.RUnlock()
return is return is
} }
var clientId uint64
var clientId uint64
opened := func(conn *server.Conn) { opened := func(conn *server.Conn) {
c.mu.Lock() c.mu.Lock()
if c.config.KeepAlive > 0 { if c.config.keepAlive() > 0 {
err := conn.SetKeepAlive( err := conn.SetKeepAlive(
time.Duration(c.config.KeepAlive) * time.Second) time.Duration(c.config.keepAlive()) * time.Second)
if err != nil { if err != nil {
log.Warnf("could not set keepalive for connection: %v", log.Warnf("could not set keepalive for connection: %v",
conn.RemoteAddr().String()) conn.RemoteAddr().String())
@ -271,14 +274,13 @@ func (c *Controller) watchGC() {
return return
} }
autoGC := c.config.AutoGC
c.mu.RUnlock() c.mu.RUnlock()
if autoGC == 0 { if c.config.autoGC() == 0 {
continue continue
} }
if time.Now().Sub(s) < time.Second*time.Duration(autoGC) { if time.Now().Sub(s) < time.Second*time.Duration(c.config.autoGC()) {
continue continue
} }
@ -310,10 +312,9 @@ func (c *Controller) watchMemory() {
c.mu.RUnlock() c.mu.RUnlock()
return return
} }
maxmem := c.config.MaxMemory
oom := c.outOfMemory oom := c.outOfMemory
c.mu.RUnlock() c.mu.RUnlock()
if maxmem == 0 { if c.config.maxMemory() == 0 {
if oom { if oom {
c.mu.Lock() c.mu.Lock()
c.outOfMemory = false c.outOfMemory = false
@ -326,7 +327,7 @@ func (c *Controller) watchMemory() {
} }
runtime.ReadMemStats(&mem) runtime.ReadMemStats(&mem)
c.mu.Lock() c.mu.Lock()
c.outOfMemory = int(mem.HeapAlloc) > maxmem c.outOfMemory = int(mem.HeapAlloc) > c.config.maxMemory()
c.mu.Unlock() c.mu.Unlock()
}() }()
} }
@ -443,10 +444,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message,
var write bool var write bool
if !conn.Authenticated || msg.Command == "auth" { if !conn.Authenticated || msg.Command == "auth" {
c.mu.RLock() if c.config.requirePass() != "" {
requirePass := c.config.RequirePass
c.mu.RUnlock()
if requirePass != "" {
password := "" password := ""
// This better be an AUTH command or the Message should contain an Auth // This better be an AUTH command or the Message should contain an Auth
if msg.Command != "auth" && msg.Auth == "" { if msg.Command != "auth" && msg.Auth == "" {
@ -460,7 +458,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message,
password = msg.Values[1].String() password = msg.Values[1].String()
} }
} }
if requirePass != strings.TrimSpace(password) { if c.config.requirePass() != strings.TrimSpace(password) {
return writeErr(errors.New("invalid password")) return writeErr(errors.New("invalid password"))
} }
conn.Authenticated = true conn.Authenticated = true
@ -482,10 +480,10 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message,
write = true write = true
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.config.FollowHost != "" { if c.config.followHost() != "" {
return writeErr(errors.New("not the leader")) return writeErr(errors.New("not the leader"))
} }
if c.config.ReadOnly { if c.config.readOnly() {
return writeErr(errors.New("read only")) return writeErr(errors.New("read only"))
} }
case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search", case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
@ -493,7 +491,7 @@ func (c *Controller) handleInputCommand(conn *server.Conn, msg *server.Message,
// read operations // read operations
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
if c.config.FollowHost != "" && !c.fcuponce { if c.config.followHost() != "" && !c.fcuponce {
return writeErr(errors.New("catching up to leader")) return writeErr(errors.New("catching up to leader"))
} }
case "follow", "readonly", "config": case "follow", "readonly", "config":

View File

@ -716,7 +716,7 @@ func (c *Controller) parseSetArgs(vs []resp.Value) (
} }
func (c *Controller) cmdSet(msg *server.Message) (res string, d commandDetailsT, err error) { func (c *Controller) cmdSet(msg *server.Message) (res string, d commandDetailsT, err error) {
if c.config.MaxMemory > 0 && c.outOfMemory { if c.config.maxMemory() > 0 && c.outOfMemory {
err = errOOM err = errOOM
return return
} }

View File

@ -36,19 +36,18 @@ func (c *Controller) cmdFollow(msg *server.Message) (res string, err error) {
host = strings.ToLower(host) host = strings.ToLower(host)
sport = strings.ToLower(sport) sport = strings.ToLower(sport)
var update bool var update bool
pconfig := c.config
if host == "no" && sport == "one" { if host == "no" && sport == "one" {
update = c.config.FollowHost != "" || c.config.FollowPort != 0 update = c.config.followHost() != "" || c.config.followPort() != 0
c.config.FollowHost = "" c.config.setFollowHost("")
c.config.FollowPort = 0 c.config.setFollowPort(0)
} else { } else {
n, err := strconv.ParseUint(sport, 10, 64) n, err := strconv.ParseUint(sport, 10, 64)
if err != nil { if err != nil {
return "", errInvalidArgument(sport) return "", errInvalidArgument(sport)
} }
port := int(n) port := int(n)
update = c.config.FollowHost != host || c.config.FollowPort != port update = c.config.followHost() != host || c.config.followPort() != port
auth := c.config.LeaderAuth auth := c.config.leaderAuth()
if update { if update {
c.mu.Unlock() c.mu.Unlock()
conn, err := DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2) conn, err := DialTimeout(fmt.Sprintf("%s:%d", host, port), time.Second*2)
@ -71,7 +70,7 @@ func (c *Controller) cmdFollow(msg *server.Message) (res string, err error) {
c.mu.Lock() c.mu.Lock()
return "", fmt.Errorf("cannot follow: invalid id") return "", fmt.Errorf("cannot follow: invalid id")
} }
if m["id"] == c.config.ServerID { if m["id"] == c.config.serverID() {
c.mu.Lock() c.mu.Lock()
return "", fmt.Errorf("cannot follow self") return "", fmt.Errorf("cannot follow self")
} }
@ -81,18 +80,15 @@ func (c *Controller) cmdFollow(msg *server.Message) (res string, err error) {
} }
c.mu.Lock() c.mu.Lock()
} }
c.config.FollowHost = host c.config.setFollowHost(host)
c.config.FollowPort = port c.config.setFollowPort(port)
}
if err := c.writeConfig(false); err != nil {
c.config = pconfig // revert
return "", err
} }
c.config.write(false)
if update { if update {
c.followc++ c.followc++
if c.config.FollowHost != "" { if c.config.followHost() != "" {
log.Infof("following new host '%s' '%s'.", host, sport) log.Infof("following new host '%s' '%s'.", host, sport)
go c.follow(c.config.FollowHost, c.config.FollowPort, c.followc) go c.follow(c.config.followHost(), c.config.followPort(), c.followc)
} else { } else {
log.Infof("following no one") log.Infof("following no one")
} }
@ -159,7 +155,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
return errNoLongerFollowing return errNoLongerFollowing
} }
c.fcup = false c.fcup = false
auth := c.config.LeaderAuth auth := c.config.leaderAuth()
c.mu.Unlock() c.mu.Unlock()
addr := fmt.Sprintf("%s:%d", host, port) addr := fmt.Sprintf("%s:%d", host, port)
@ -182,7 +178,7 @@ func (c *Controller) followStep(host string, port int, followc uint64) error {
if m["id"] == "" { if m["id"] == "" {
return fmt.Errorf("cannot follow: invalid id") return fmt.Errorf("cannot follow: invalid id")
} }
if m["id"] == c.config.ServerID { if m["id"] == c.config.serverID() {
return fmt.Errorf("cannot follow self") return fmt.Errorf("cannot follow self")
} }
if m["following"] != "" { if m["following"] != "" {

View File

@ -20,29 +20,24 @@ func (c *Controller) cmdReadOnly(msg *server.Message) (res string, err error) {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
update := false update := false
backup := c.config
switch strings.ToLower(arg) { switch strings.ToLower(arg) {
default: default:
return "", errInvalidArgument(arg) return "", errInvalidArgument(arg)
case "yes": case "yes":
if !c.config.ReadOnly { if !c.config.readOnly() {
update = true update = true
c.config.ReadOnly = true c.config.setReadOnly(true)
log.Info("read only") log.Info("read only")
} }
case "no": case "no":
if c.config.ReadOnly { if c.config.readOnly() {
update = true update = true
c.config.ReadOnly = false c.config.setReadOnly(false)
log.Info("read write") log.Info("read write")
} }
} }
if update { if update {
err := c.writeConfig(false) c.config.write(false)
if err != nil {
c.config = backup
return "", err
}
} }
return server.OKMessage(msg, start), nil return server.OKMessage(msg, start), nil
} }

View File

@ -76,9 +76,9 @@ func (c *Controller) cmdServer(msg *server.Message) (res string, err error) {
return "", errInvalidNumberOfArguments return "", errInvalidNumberOfArguments
} }
m := make(map[string]interface{}) m := make(map[string]interface{})
m["id"] = c.config.ServerID m["id"] = c.config.serverID()
if c.config.FollowHost != "" { if c.config.followHost() != "" {
m["following"] = fmt.Sprintf("%s:%d", c.config.FollowHost, c.config.FollowPort) m["following"] = fmt.Sprintf("%s:%d", c.config.followHost(), c.config.followPort())
m["caught_up"] = c.fcup m["caught_up"] = c.fcup
m["caught_up_once"] = c.fcuponce m["caught_up_once"] = c.fcuponce
} }
@ -116,10 +116,10 @@ func (c *Controller) cmdServer(msg *server.Message) (res string, err error) {
m["mem_alloc"] = mem.Alloc m["mem_alloc"] = mem.Alloc
m["heap_size"] = mem.HeapAlloc m["heap_size"] = mem.HeapAlloc
m["heap_released"] = mem.HeapReleased m["heap_released"] = mem.HeapReleased
m["max_heap_size"] = c.config.MaxMemory m["max_heap_size"] = c.config.maxMemory()
m["avg_item_size"] = avgsz m["avg_item_size"] = avgsz
m["pointer_size"] = (32 << uintptr(uint64(^uintptr(0))>>63)) / 8 m["pointer_size"] = (32 << uintptr(uint64(^uintptr(0))>>63)) / 8
m["read_only"] = c.config.ReadOnly m["read_only"] = c.config.readOnly()
switch msg.OutputType { switch msg.OutputType {
case server.JSON: case server.JSON: