tile38/internal/server/scripts.go
tidwall 9e68703841 Update expiration logic
This commit changes the logic for managing the expiration of
objects in the database.

Before: There was a server-wide hashmap that stored the
collection key, id, and expiration timestamp for all objects
that had a TTL. The hashmap was occasionally probed at 20
random positions, looking for objects that had expired. Those
expired objects were immediately deleted, and if 5 or more
objects were deleted, the probe happened again with no delay.
If fewer than 5 objects were deleted, there was a 1/10th of a
second delay before the next probe.

Now: Rather than a server-wide hashmap, each collection has
its own ordered priority queue that stores objects with TTLs.
Rather than probing, there is a background routine that
executes every 1/10th of a second, which pops the expired
objects from the collection queues, and deletes them.

The collection/queue method is a more stable approach than
the hashmap/probing method. With probing, we can run into
major cache misses in cases where TTL durations vary widely,
such as spanning hours or days. This may cause the system to
occasionally fall behind, leaving objects that should have
expired in memory. With a queue there are no cache misses:
all objects that should be expired are expired right away,
regardless of their TTL durations.

Fixes #616
2021-07-12 13:37:50 -07:00

847 lines
20 KiB
Go

package server
import (
"bytes"
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"strconv"
"strings"
"sync"
"time"
"github.com/tidwall/geojson/geo"
"github.com/tidwall/resp"
"github.com/tidwall/tile38/internal/log"
lua "github.com/yuin/gopher-lua"
luajson "layeh.com/gopher-json"
)
// Sizing of the lua interpreter pool: newPool pre-creates iniLuaPoolSize
// states, Prune never shrinks the idle set below it, and Get refuses to
// create more states once the pool's running total reaches maxLuaPoolSize.
const (
	iniLuaPoolSize = 5
	maxLuaPoolSize = 1000
)

// Errors reported by the scripting commands and by tile38.call/pcall.
var (
	errShaNotFound     = errors.New("sha not found")
	errCmdNotSupported = errors.New("command not supported in scripts")
	errNotLeader       = errors.New("not the leader")
	errReadOnly        = errors.New("read only")
	errCatchingUp      = errors.New("catching up to leader")
	errNoLuasAvailable = errors.New("no interpreters available")
	errTimeout         = errors.New("timeout")
)
// Go-routine-safe pool of ready-to-go lua states.
// Get pops an idle state (creating a new one while total < maxLuaPoolSize),
// Put pushes it back, and Prune retires some idle states.
type lStatePool struct {
// m guards saved and total.
m sync.Mutex
// s is the server that tile38.call/tile38.pcall in every pooled state dispatch to.
s *Server
// saved holds the idle states; Get pops from and Put appends to its end.
saved []*lua.LState
// total counts states created by the pool (newPool and Get) minus those
// dropped by Prune, i.e. idle plus checked-out states.
total int
}
// newPool builds the server's lua state pool and pre-warms it with
// iniLuaPoolSize freshly created states, all bound to s.
func (s *Server) newPool() *lStatePool {
	pool := &lStatePool{s: s}
	pool.saved = make([]*lua.LState, 0, iniLuaPoolSize)
	for len(pool.saved) < iniLuaPoolSize {
		pool.saved = append(pool.saved, pool.New())
		pool.total++
	}
	return pool
}
// Get takes an idle lua state from the pool (most recently returned first).
// When none is idle it creates a new one, unless the pool's total has already
// reached maxLuaPoolSize, in which case errNoLuasAvailable is returned.
func (pl *lStatePool) Get() (*lua.LState, error) {
	pl.m.Lock()
	defer pl.m.Unlock()
	if last := len(pl.saved) - 1; last >= 0 {
		L := pl.saved[last]
		pl.saved = pl.saved[:last]
		return L, nil
	}
	if pl.total >= maxLuaPoolSize {
		return nil, errNoLuasAvailable
	}
	pl.total++
	return pl.New(), nil
}
// Prune removes some of the idle lua states from the pool.
//
// When more than iniLuaPoolSize states are idle, half of the surplus (at
// least one) is dropped. The dropped states come from the front of saved:
// Get and Put both work at the end, so the front holds the states that have
// been idle the longest. Each dropped state is closed so its resources are
// released instead of being silently discarded.
func (pl *lStatePool) Prune() {
	pl.m.Lock()
	defer pl.m.Unlock()
	n := len(pl.saved)
	if n <= iniLuaPoolSize {
		return
	}
	// drop half of the idle states that is above the minimum
	dropNum := (n - iniLuaPoolSize) / 2
	if dropNum < 1 {
		dropNum = 1
	}
	// These states are idle (not checked out by any caller), so they can be
	// closed safely before being forgotten.
	for _, L := range pl.saved[:dropNum] {
		L.Close()
	}
	// Copy the survivors into a fresh slice so the old backing array, which
	// still references the closed states, can be collected.
	newSaved := make([]*lua.LState, n-dropNum)
	copy(newSaved, pl.saved[dropNum:])
	pl.saved = newSaved
	pl.total -= dropNum
}
// New creates a fresh lua state for the pool. The state is given:
//   - a "tile38" table with call, pcall, error_reply, status_reply, sha1hex
//     and distance_to;
//   - a "json" global loaded from gopher-json;
//   - a metatable on the globals table whose __newindex raises an error, so
//     scripts cannot create new global variables.
func (pl *lStatePool) New() *lua.LState {
	L := lua.NewState()

	// Reported by tile38.call (raised) or tile38.pcall (returned as an error
	// table) when the script passes no command at all; without this guard
	// args[0] below would panic with an index out of range.
	errNoCommand := errors.New("invalid number of arguments")

	// getArgs returns the EVAL_CMD global and the arguments of the current
	// Go function call, converted with ls.ToString. The number of arguments
	// is not checked up front: collection stops at the first argument for
	// which ToString returns an empty string.
	getArgs := func(ls *lua.LState) (evalCmd string, args []string) {
		evalCmd = ls.GetGlobal("EVAL_CMD").String()
		for i := 1; ; i++ {
			arg := ls.ToString(i)
			if arg == "" {
				break
			}
			args = append(args, arg)
		}
		return evalCmd, args
	}

	// tile38.call: run a command and raise a lua error if it fails.
	call := func(ls *lua.LState) int {
		evalCmd, args := getArgs(ls)
		if len(args) == 0 {
			ls.RaiseError("ERR %s", errNoCommand.Error())
			return 0
		}
		res, err := pl.s.luaTile38Call(evalCmd, args[0], args[1:]...)
		if err != nil {
			// RaiseError does not return; it unwinds to the enclosing PCall.
			ls.RaiseError("ERR %s", err.Error())
			return 0
		}
		ls.Push(ConvertToLua(ls, res))
		return 1
	}

	// tile38.pcall: run a command and return failures as an {err=...} table
	// instead of raising.
	pcall := func(ls *lua.LState) int {
		evalCmd, args := getArgs(ls)
		var res resp.Value
		var err error
		if len(args) == 0 {
			err = errNoCommand
		} else {
			res, err = pl.s.luaTile38Call(evalCmd, args[0], args[1:]...)
		}
		if err != nil {
			res = resp.ErrorValue(err)
		}
		ls.Push(ConvertToLua(ls, res))
		return 1
	}

	// tile38.error_reply(msg) -> {err=msg}
	errorReply := func(ls *lua.LState) int {
		tbl := ls.CreateTable(0, 1)
		tbl.RawSetString("err", lua.LString(ls.ToString(1)))
		ls.Push(tbl)
		return 1
	}

	// tile38.status_reply(msg) -> {ok=msg}
	statusReply := func(ls *lua.LState) int {
		tbl := ls.CreateTable(0, 1)
		tbl.RawSetString("ok", lua.LString(ls.ToString(1)))
		ls.Push(tbl)
		return 1
	}

	// tile38.sha1hex(s) -> hex SHA1 digest of s
	sha1hex := func(ls *lua.LState) int {
		ls.Push(lua.LString(Sha1Sum(ls.ToString(1))))
		return 1
	}

	// tile38.distance_to(a, b, c, d) -> geo.DistanceTo(a, b, c, d)
	distanceTo := func(ls *lua.LState) int {
		dt := geo.DistanceTo(
			float64(ls.ToNumber(1)),
			float64(ls.ToNumber(2)),
			float64(ls.ToNumber(3)),
			float64(ls.ToNumber(4)))
		ls.Push(lua.LNumber(dt))
		return 1
	}

	exports := map[string]lua.LGFunction{
		"call":         call,
		"pcall":        pcall,
		"error_reply":  errorReply,
		"status_reply": statusReply,
		"sha1hex":      sha1hex,
		"distance_to":  distanceTo,
	}
	L.SetGlobal("tile38", L.SetFuncs(L.NewTable(), exports))

	// Load json
	L.SetGlobal("json", L.Get(luajson.Loader(L)))

	// Prohibit creating new globals in this state
	lockNewGlobals := func(ls *lua.LState) int {
		ls.RaiseError("attempt to create global variable '%s'", ls.ToString(2))
		return 0
	}
	mt := L.CreateTable(0, 1)
	mt.RawSetString("__newindex", L.NewFunction(lockNewGlobals))
	L.SetMetatable(L.Get(lua.GlobalsIndex), mt)
	return L
}
// Put returns a lua state to the idle set, making it the next one Get
// hands out.
func (pl *lStatePool) Put(L *lua.LState) {
	pl.m.Lock()
	defer pl.m.Unlock()
	pl.saved = append(pl.saved, L)
}
// Shutdown closes every idle lua state and empties the idle set.
//
// The idle set is cleared and total adjusted so that calling Shutdown again
// does not close the same states twice, and a later Get never hands out a
// closed state. States checked out at the time of the call are not touched.
func (pl *lStatePool) Shutdown() {
	pl.m.Lock()
	defer pl.m.Unlock()
	for _, L := range pl.saved {
		L.Close()
	}
	pl.total -= len(pl.saved)
	pl.saved = nil
}
// Go-routine-safe map of compiled scripts
type lScriptMap struct {
// m guards scripts.
m sync.Mutex
// scripts maps a script's SHA1 hex digest (as produced by Sha1Sum or given
// to EVALSHA) to its compiled lua prototype.
scripts map[string]*lua.FunctionProto
}
// Get looks up the compiled script cached under key and reports whether
// one was found.
func (sm *lScriptMap) Get(key string) (script *lua.FunctionProto, ok bool) {
	sm.m.Lock()
	defer sm.m.Unlock()
	script, ok = sm.scripts[key]
	return script, ok
}
// Put caches script under key, replacing any previous entry for that key.
func (sm *lScriptMap) Put(key string, script *lua.FunctionProto) {
	sm.m.Lock()
	defer sm.m.Unlock()
	sm.scripts[key] = script
}
// Flush forgets every cached script by swapping in an empty map.
func (sm *lScriptMap) Flush() {
	empty := make(map[string]*lua.FunctionProto)
	sm.m.Lock()
	sm.scripts = empty
	sm.m.Unlock()
}
// newScriptMap returns an empty, ready-to-use cache of compiled lua scripts.
func (s *Server) newScriptMap() *lScriptMap {
	sm := new(lScriptMap)
	sm.scripts = make(map[string]*lua.FunctionProto)
	return sm
}
// ConvertToLua converts a RESP value into a lua value:
//   - null -> false;
//   - integer -> number;
//   - bulk string -> string;
//   - error -> {err=...} and simple string -> {ok=...};
//   - array -> a sequence table, converting each element recursively.
//
// Any other RESP type yields the string "ERR: unknown RESP type: <type>".
func ConvertToLua(L *lua.LState, val resp.Value) lua.LValue {
	if val.IsNull() {
		return lua.LFalse
	}
	typ := val.Type()
	switch typ {
	case resp.Integer:
		return lua.LNumber(val.Integer())
	case resp.BulkString:
		return lua.LString(val.String())
	case resp.Error, resp.SimpleString:
		field := "ok"
		if typ == resp.Error {
			field = "err"
		}
		tbl := L.CreateTable(0, 1)
		tbl.RawSetString(field, lua.LString(val.String()))
		return tbl
	case resp.Array:
		items := val.Array()
		tbl := L.CreateTable(len(items), 0)
		for i := range items {
			tbl.Append(ConvertToLua(L, items[i]))
		}
		return tbl
	default:
		return lua.LString("ERR: unknown RESP type: " + typ.String())
	}
}
// ConvertToRESP converts a lua value into a RESP value:
//   - nil and false -> null; true -> integer 1;
//   - finite numbers -> integer (floored); NaN/Inf -> float value;
//   - strings -> string value;
//   - tables with a non-zero length -> array of every value ForEach visits;
//   - other tables -> array of [key, value] pairs. A table whose only entry
//     is the string key "ok" or "err" becomes a simple-string or error reply
//     instead.
//
// Any other lua type yields an "Unsupported lua type" error value.
func ConvertToRESP(val lua.LValue) resp.Value {
	switch val.Type() {
	case lua.LTNil:
		return resp.NullValue()
	case lua.LTBool:
		if val != lua.LTrue {
			return resp.NullValue()
		}
		return resp.IntegerValue(1)
	case lua.LTNumber:
		num := float64(val.(lua.LNumber))
		if !math.IsNaN(num) && !math.IsInf(num, 0) {
			return resp.IntegerValue(int(math.Floor(num)))
		}
		return resp.FloatValue(num)
	case lua.LTString:
		return resp.StringValue(val.String())
	case lua.LTTable:
		tbl := val.(*lua.LTable)
		isList := tbl.Len() != 0
		var out, special []resp.Value
		tbl.ForEach(func(k, v lua.LValue) {
			if isList {
				out = append(out, ConvertToRESP(v))
				return
			}
			if k.Type() == lua.LTString {
				switch k.String() {
				case "ok":
					special = append(special, resp.SimpleStringValue(v.String()))
				case "err":
					special = append(special, resp.ErrorValue(errors.New(v.String())))
				}
			}
			pair := []resp.Value{ConvertToRESP(k), ConvertToRESP(v)}
			out = append(out, resp.ArrayValue(pair))
		})
		// A lone {ok=...} or {err=...} table is a status/error reply.
		if len(out) == 1 && len(special) == 1 {
			return special[0]
		}
		return resp.ArrayValue(out)
	}
	return resp.ErrorValue(errors.New("Unsupported lua type: " + val.Type().String()))
}
// ConvertToJSON converts a lua value into JSON text.
//
// Conversions:
//   - nil -> null; booleans -> true/false.
//   - numbers -> lua's own formatting (LNumber.String). NaN and ±Inf have
//     no JSON representation and become null.
//   - strings -> JSON string literals.
//   - tables with a non-zero length (tbl.Len() != 0) -> a JSON array of
//     every value visited by ForEach.
//   - all other tables -> a JSON object. Object keys are always emitted as
//     JSON strings, because JSON only allows string keys (a numeric key 5
//     becomes "5").
//   - anything else -> the JSON string "Unsupported lua type: <type>".
func ConvertToJSON(val lua.LValue) string {
	// quote renders s as a JSON string literal. json.Marshal does not fail
	// for a Go string; the panic only guards against that ever changing.
	quote := func(s string) string {
		b, err := json.Marshal(s)
		if err != nil {
			panic(err)
		}
		return string(b)
	}
	switch val.Type() {
	case lua.LTNil:
		return "null"
	case lua.LTBool:
		if val == lua.LTrue {
			return "true"
		}
		return "false"
	case lua.LTNumber:
		f := float64(val.(lua.LNumber))
		if math.IsNaN(f) || math.IsInf(f, 0) {
			return "null"
		}
		return val.String()
	case lua.LTString:
		return quote(val.String())
	case lua.LTTable:
		tbl := val.(*lua.LTable)
		var values []string
		if tbl.Len() != 0 { // list
			tbl.ForEach(func(_, lv lua.LValue) {
				values = append(values, ConvertToJSON(lv))
			})
			return `[` + strings.Join(values, `,`) + `]`
		}
		// map
		tbl.ForEach(func(lk, lv lua.LValue) {
			values = append(values, quote(lk.String())+`:`+ConvertToJSON(lv))
		})
		return `{` + strings.Join(values, `,`) + `}`
	}
	return quote("Unsupported lua type: " + val.Type().String())
}
// luaSetRawGlobals assigns each name/value pair of tbl directly on the
// globals table of ls. RawSetString is used so that the globals metatable
// (New installs a __newindex handler that rejects new globals) does not
// get in the way.
func luaSetRawGlobals(ls *lua.LState, tbl map[string]lua.LValue) {
	globals := ls.Get(lua.GlobalsIndex).(*lua.LTable)
	for name := range tbl {
		globals.RawSetString(name, tbl[name])
	}
}
// Sha1Sum returns the lowercase hexadecimal SHA1 digest of s.
func Sha1Sum(s string) string {
	digest := sha1.Sum([]byte(s))
	return hex.EncodeToString(digest[:])
}
// makeSafeErr returns a new error whose message is safe to send as a RESP
// error line. RESP error lines cannot contain CR or LF, so every carriage
// return and newline is replaced by the two-character escapes `\r` and `\n`.
// Only the message is carried over; err itself is not wrapped.
func makeSafeErr(err error) error {
	return errors.New(strings.NewReplacer("\r", `\r`, "\n", `\n`).Replace(err.Error()))
}
// cmdEvalUnified runs an EVAL, EVALRO or EVALNA command, or its -SHA variant.
//
// Expected arguments after the command name:
//
//	script|sha numkeys key [key ...] arg [arg ...]
//
// The keys and remaining args are exposed to the script as the KEYS and ARGV
// tables. Two more globals are set for the duration of the call and cleared
// afterwards:
//   - DEADLINE: the TIMEOUT deadline in seconds since the Unix epoch, or nil;
//   - EVAL_CMD: the eval command name, which tile38.call/pcall use to choose
//     how commands are executed.
//
// When scriptIsSha is true the script argument must be the SHA1 of a cached
// script, otherwise errShaNotFound is returned. When scriptIsSha is false the
// source is compiled and cached under its SHA1, unless it is already cached.
func (s *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, err error) {
	start := time.Now()
	vs := msg.Args[1:]
	var ok bool
	var script, numkeysStr, key, arg string
	if vs, script, ok = tokenval(vs); !ok || script == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	if vs, numkeysStr, ok = tokenval(vs); !ok || numkeysStr == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	var numkeys uint64
	if numkeys, err = strconv.ParseUint(numkeysStr, 10, 64); err != nil {
		err = errInvalidArgument(numkeysStr)
		return
	}
	// Check numkeys against the arguments actually supplied before it is
	// used as the capacity hint of the KEYS table. Otherwise a request such
	// as `EVAL script 10000000000` would ask for an enormous allocation and
	// could take down the whole process.
	if numkeys > uint64(len(vs)) {
		err = errInvalidNumberOfArguments
		return
	}

	luaState, err := s.luapool.Get()
	if err != nil {
		return
	}
	// Deferred first so that it runs last. Once the state is back in the
	// pool, another goroutine may Get it and install its own context. Every
	// deferred cleanup that touches this state (context removal,
	// cancellation, clearing the globals) must therefore finish before Put.
	defer s.luapool.Put(luaState)

	luaDeadline := lua.LNil
	if msg.Deadline != nil {
		dlTime := msg.Deadline.GetDeadlineTime()
		ctx, cancel := context.WithDeadline(context.Background(), dlTime)
		defer cancel()
		luaState.SetContext(ctx)
		defer luaState.RemoveContext()
		luaDeadline = lua.LNumber(float64(dlTime.UnixNano()) / 1e9)
	}

	keysTbl := luaState.CreateTable(int(numkeys), 0)
	for i := uint64(0); i < numkeys; i++ {
		if vs, key, ok = tokenval(vs); !ok || key == "" {
			err = errInvalidNumberOfArguments
			return
		}
		keysTbl.Append(lua.LString(key))
	}
	argsTbl := luaState.CreateTable(len(vs), 0)
	for len(vs) > 0 {
		if vs, arg, ok = tokenval(vs); !ok || arg == "" {
			err = errInvalidNumberOfArguments
			return
		}
		argsTbl.Append(lua.LString(arg))
	}

	var shaSum string
	if scriptIsSha {
		shaSum = script
	} else {
		shaSum = Sha1Sum(script)
	}

	luaSetRawGlobals(
		luaState, map[string]lua.LValue{
			"KEYS":     keysTbl,
			"ARGV":     argsTbl,
			"DEADLINE": luaDeadline,
			"EVAL_CMD": lua.LString(msg.Command()),
		})
	// Clear the per-call globals on every exit path from here on, including
	// the sha-miss and compile-error returns below. Otherwise the pooled
	// state keeps referencing this request's tables.
	defer luaSetRawGlobals(
		luaState, map[string]lua.LValue{
			"KEYS":     lua.LNil,
			"ARGV":     lua.LNil,
			"DEADLINE": lua.LNil,
			"EVAL_CMD": lua.LNil,
		})

	compiled, ok := s.luascripts.Get(shaSum)
	var fn *lua.LFunction
	if ok {
		// Reuse the cached prototype, bound to this state's environment.
		fn = &lua.LFunction{
			IsG:       false,
			Env:       luaState.Env,
			Proto:     compiled,
			GFunction: nil,
			Upvalues:  make([]*lua.Upvalue, 0),
		}
	} else if scriptIsSha {
		err = errShaNotFound
		return
	} else {
		fn, err = luaState.Load(strings.NewReader(script), "f_"+shaSum)
		if err != nil {
			return NOMessage, makeSafeErr(err)
		}
		s.luascripts.Put(shaSum, fn.Proto)
	}

	luaState.Push(fn)
	if perr := luaState.PCall(0, 1, nil); perr != nil {
		// The TIMEOUT context fired inside the script. Let the deadline object
		// deal with it (presumably Check signals the timeout once the deadline
		// has passed — confirm in the deadline package). The nil guard covers
		// scripts that raise this text themselves without a TIMEOUT.
		if msg.Deadline != nil && strings.Contains(perr.Error(), "context deadline exceeded") {
			msg.Deadline.Check()
		}
		log.Debugf("%v", perr.Error())
		return NOMessage, makeSafeErr(perr)
	}
	ret := luaState.Get(-1) // returned value
	luaState.Pop(1)

	switch msg.OutputType {
	case JSON:
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true`)
		buf.WriteString(`,"result":` + ConvertToJSON(ret))
		buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
		return resp.StringValue(buf.String()), nil
	case RESP:
		return ConvertToRESP(ret), nil
	}
	return NOMessage, nil
}
// cmdScriptLoad handles SCRIPT LOAD script. It compiles the script on a
// pooled lua state, caches the compiled prototype under the script's SHA1,
// and replies with that SHA1. Compile errors are returned with CR/LF made
// RESP-safe by makeSafeErr.
func (s *Server) cmdScriptLoad(msg *Message) (resp.Value, error) {
	start := time.Now()
	_, script, ok := tokenval(msg.Args[1:])
	if !ok || script == "" {
		return NOMessage, errInvalidNumberOfArguments
	}
	shaSum := Sha1Sum(script)

	luaState, err := s.luapool.Get()
	if err != nil {
		return NOMessage, err
	}
	defer s.luapool.Put(luaState)

	fn, err := luaState.Load(strings.NewReader(script), "f_"+shaSum)
	if err != nil {
		return NOMessage, makeSafeErr(err)
	}
	s.luascripts.Put(shaSum, fn.Proto)

	if msg.OutputType == JSON {
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true`)
		buf.WriteString(`,"result":"` + shaSum + `"`)
		buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
		return resp.StringValue(buf.String()), nil
	}
	if msg.OutputType == RESP {
		return resp.StringValue(shaSum), nil
	}
	return NOMessage, nil
}
// cmdScriptExists handles SCRIPT EXISTS sha [sha ...]. For each SHA1, in the
// order given, it reports 1 if a compiled script is cached under it and 0
// otherwise.
func (s *Server) cmdScriptExists(msg *Message) (resp.Value, error) {
	start := time.Now()
	var results []int
	for vs := msg.Args[1:]; len(vs) > 0; {
		var shaSum string
		var ok bool
		if vs, shaSum, ok = tokenval(vs); !ok || shaSum == "" {
			return NOMessage, errInvalidNumberOfArguments
		}
		found := 0
		if _, cached := s.luascripts.Get(shaSum); cached {
			found = 1
		}
		results = append(results, found)
	}

	if msg.OutputType == JSON {
		var parts []string
		for _, r := range results {
			parts = append(parts, strconv.Itoa(r))
		}
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true`)
		buf.WriteString(`,"result":[` + strings.Join(parts, ",") + `]`)
		buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
		return resp.StringValue(buf.String()), nil
	}
	if msg.OutputType == RESP {
		var values []resp.Value
		for _, r := range results {
			values = append(values, resp.IntegerValue(r))
		}
		return resp.ArrayValue(values), nil
	}
	return resp.SimpleStringValue(""), nil
}
// cmdScriptFlush handles SCRIPT FLUSH by dropping every cached compiled
// script.
func (s *Server) cmdScriptFlush(msg *Message) (resp.Value, error) {
	start := time.Now()
	s.luascripts.Flush()
	if msg.OutputType == JSON {
		var buf bytes.Buffer
		buf.WriteString(`{"ok":true`)
		buf.WriteString(`,"elapsed":"` + time.Since(start).String() + "\"}")
		return resp.StringValue(buf.String()), nil
	}
	if msg.OutputType == RESP {
		return resp.StringValue("OK"), nil
	}
	return resp.SimpleStringValue(""), nil
}
// commandInScript dispatches a command issued from inside a lua script to
// its handler, then passes the error (if any) and msg to s.sendMonitor.
// Handlers for data-modifying commands also return commandDetails (d), which
// the callers use when writing the AOF. Read-only handlers leave d as its
// zero value. Commands without a case here produce an "unknown command"
// error.
//
// NOTE(review): "flushdb", "hooks" and "info" pass the allow-lists in the
// luaTile38* callers but have no case here, so they end up as "unknown
// command". Confirm whether that is intended.
func (s *Server) commandInScript(msg *Message) (
	res resp.Value, d commandDetails, err error,
) {
	switch msg.Command() {
	// commands that modify data and report commandDetails
	case "set":
		res, d, err = s.cmdSet(msg)
	case "fset":
		res, d, err = s.cmdFset(msg)
	case "del":
		res, d, err = s.cmdDel(msg)
	case "pdel":
		res, d, err = s.cmdPdel(msg)
	case "drop":
		res, d, err = s.cmdDrop(msg)
	case "expire":
		res, d, err = s.cmdExpire(msg)
	case "rename":
		res, d, err = s.cmdRename(msg, false)
	case "renamenx":
		res, d, err = s.cmdRename(msg, true)
	case "persist":
		res, d, err = s.cmdPersist(msg)
	case "jset":
		res, d, err = s.cmdJset(msg)
	case "jdel":
		res, d, err = s.cmdJdel(msg)

	// commands that only produce a result
	case "ttl":
		res, err = s.cmdTTL(msg)
	case "stats":
		res, err = s.cmdStats(msg)
	case "scan":
		res, err = s.cmdScan(msg)
	case "nearby":
		res, err = s.cmdNearby(msg)
	case "within":
		res, err = s.cmdWithin(msg)
	case "intersects":
		res, err = s.cmdIntersects(msg)
	case "search":
		res, err = s.cmdSearch(msg)
	case "bounds":
		res, err = s.cmdBounds(msg)
	case "get":
		res, err = s.cmdGet(msg)
	case "jget":
		res, err = s.cmdJget(msg)
	case "type":
		res, err = s.cmdType(msg)
	case "keys":
		res, err = s.cmdKeys(msg)
	case "test":
		res, err = s.cmdTest(msg)
	case "server":
		res, err = s.cmdServer(msg)

	default:
		err = fmt.Errorf("unknown command '%s'", msg.Args[0])
	}
	s.sendMonitor(err, msg, nil, true)
	return res, d, err
}
// luaTile38Call is the bridge behind tile38.call and tile38.pcall.
//
// It builds a RESP-output Message from cmd and args. A "timeout" command is
// first rewritten by rewriteTimeoutMsg (presumably unwrapping
// `TIMEOUT secs cmd ...` into cmd plus a deadline; confirm there). Commands
// that may not run inside a script are refused. The rest are routed by
// evalcmd, the eval variant that started the script:
//   - EVAL/EVALSHA -> luaTile38AtomicRW;
//   - EVALRO/EVALROSHA -> luaTile38AtomicRO;
//   - EVALNA/EVALNASHA -> luaTile38NonAtomic.
func (s *Server) luaTile38Call(evalcmd string, cmd string, args ...string) (resp.Value, error) {
	msg := &Message{
		OutputType: RESP,
		Args:       append([]string{cmd}, args...),
	}
	if msg.Command() == "timeout" {
		if err := rewriteTimeoutMsg(msg); err != nil {
			return resp.NullValue(), err
		}
	}

	switch msg.Command() {
	case "ping", "echo", "auth", "massinsert", "shutdown", "gc",
		"sethook", "pdelhook", "delhook",
		"follow", "readonly", "config", "output", "client",
		"aofshrink",
		"script load", "script exists", "script flush",
		"eval", "evalsha", "evalro", "evalrosha", "evalna", "evalnasha":
		return resp.NullValue(), errCmdNotSupported
	}

	var run func(*Message) (resp.Value, error)
	switch evalcmd {
	case "eval", "evalsha":
		run = s.luaTile38AtomicRW
	case "evalro", "evalrosha":
		run = s.luaTile38AtomicRO
	case "evalna", "evalnasha":
		run = s.luaTile38NonAtomic
	default:
		return resp.NullValue(), errCmdNotSupported
	}
	return run(msg)
}
// luaTile38AtomicRW executes a command issued from an EVAL/EVALSHA script.
// The eval command has already got the lock. No locking on the call from
// within the script.
//
// Write commands:
//   - refused with errNotLeader on a follower and errReadOnly in read-only
//     mode;
//   - refused with errTimeoutOnCmd when the script runs under a deadline;
//   - appended to the AOF after they succeed.
//
// Read commands are refused with errCatchingUp while a follower has not yet
// caught up. They run under the deadline, if any: a "deadline" panic raised
// after the deadline is hit is turned into errTimeout. Any other command
// yields errCmdNotSupported.
func (s *Server) luaTile38AtomicRW(msg *Message) (resp.Value, error) {
	var write bool
	switch msg.Command() {
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
		"rename", "renamenx":
		write = true
		if s.config.followHost() != "" {
			return resp.NullValue(), errNotLeader
		}
		if s.config.readOnly() {
			return resp.NullValue(), errReadOnly
		}
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	default:
		return resp.NullValue(), errCmdNotSupported
	}

	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline == nil {
			return s.commandInScript(msg)
		}
		if write {
			// write commands are not allowed under a deadline
			return NOMessage, d, errTimeoutOnCmd(msg.Command())
		}
		defer func() {
			if !msg.Deadline.Hit() {
				return
			}
			// Swallow only the "deadline" panic; re-raise anything else.
			if v := recover(); v != nil {
				if reason, isStr := v.(string); !isStr || reason != "deadline" {
					panic(v)
				}
			}
			res, err = NOMessage, errTimeout
		}()
		return s.commandInScript(msg)
	}

	res, d, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	if write {
		if err := s.writeAOF(msg.Args, &d); err != nil {
			return resp.NullValue(), err
		}
	}
	return res, nil
}
// luaTile38AtomicRO executes a command issued from an EVALRO/EVALROSHA
// script.
//
// Write commands are always refused with errReadOnly. Read commands are
// refused with errCatchingUp while a follower has not yet caught up. They
// run under the deadline, if any: a "deadline" panic raised after the
// deadline is hit is turned into errTimeout. Any other command yields
// errCmdNotSupported. No locking happens here.
func (s *Server) luaTile38AtomicRO(msg *Message) (resp.Value, error) {
	switch msg.Command() {
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
		"rename", "renamenx":
		return resp.NullValue(), errReadOnly
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	default:
		return resp.NullValue(), errCmdNotSupported
	}

	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline != nil {
			defer func() {
				if !msg.Deadline.Hit() {
					return
				}
				// Swallow only the "deadline" panic; re-raise anything else.
				if v := recover(); v != nil {
					if reason, isStr := v.(string); !isStr || reason != "deadline" {
						panic(v)
					}
				}
				res, err = NOMessage, errTimeout
			}()
		}
		return s.commandInScript(msg)
	}

	res, _, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	return res, nil
}
// luaTile38NonAtomic executes a command issued from an EVALNA/EVALNASHA
// script. Unlike the atomic variants it takes s.mu itself for each command:
// the write lock for write commands and the read lock for reads. The lock is
// held until this function returns, which includes the AOF write.
//
// The same rules as luaTile38AtomicRW then apply:
//   - writes are refused on a follower, in read-only mode, or under a
//     deadline, and are written to the AOF on success;
//   - reads are refused while a follower is catching up, and a "deadline"
//     panic after the deadline is hit becomes errTimeout;
//   - any other command yields errCmdNotSupported.
func (s *Server) luaTile38NonAtomic(msg *Message) (resp.Value, error) {
	var write bool
	// choose the locking strategy
	switch msg.Command() {
	case "set", "del", "drop", "fset", "flushdb", "expire", "persist", "jset", "pdel",
		"rename", "renamenx":
		write = true
		s.mu.Lock()
		defer s.mu.Unlock()
		if s.config.followHost() != "" {
			return resp.NullValue(), errNotLeader
		}
		if s.config.readOnly() {
			return resp.NullValue(), errReadOnly
		}
	case "get", "keys", "scan", "nearby", "within", "intersects", "hooks", "search",
		"ttl", "bounds", "server", "info", "type", "jget", "test":
		s.mu.RLock()
		defer s.mu.RUnlock()
		if s.config.followHost() != "" && !s.fcuponce {
			return resp.NullValue(), errCatchingUp
		}
	default:
		return resp.NullValue(), errCmdNotSupported
	}

	exec := func() (res resp.Value, d commandDetails, err error) {
		if msg.Deadline == nil {
			return s.commandInScript(msg)
		}
		if write {
			// write commands are not allowed under a deadline
			return NOMessage, d, errTimeoutOnCmd(msg.Command())
		}
		defer func() {
			if !msg.Deadline.Hit() {
				return
			}
			// Swallow only the "deadline" panic; re-raise anything else.
			if v := recover(); v != nil {
				if reason, isStr := v.(string); !isStr || reason != "deadline" {
					panic(v)
				}
			}
			res, err = NOMessage, errTimeout
		}()
		return s.commandInScript(msg)
	}

	res, d, err := exec()
	if err != nil {
		return resp.NullValue(), err
	}
	if write {
		if err := s.writeAOF(msg.Args, &d); err != nil {
			return resp.NullValue(), err
		}
	}
	return res, nil
}