Add timeouts for lua scripts

This commit is contained in:
Alex Roitman 2019-04-24 17:00:52 -07:00
parent 878f9dedb2
commit 49111a4dce
2 changed files with 17 additions and 0 deletions

View File

@ -13,6 +13,7 @@ func New(deadline time.Time) *Deadline {
return &Deadline{unixNano: deadline.UnixNano()} return &Deadline{unixNano: deadline.UnixNano()}
} }
// Empty deadline does nothing, just a place holder for future updates
func Empty() *Deadline { func Empty() *Deadline {
return &Deadline{} return &Deadline{}
} }
@ -38,3 +39,8 @@ func (deadline *Deadline) Check() {
func (deadline *Deadline) Hit() bool { func (deadline *Deadline) Hit() bool {
return deadline.hit return deadline.hit
} }
// GetDeadlineTime returns the time object for the deadline, and an "empty" boolean
func (deadline *Deadline) GetDeadlineTime() (time.Time, bool) {
return time.Unix(0, deadline.unixNano), deadline.unixNano == 0
}

View File

@ -2,6 +2,7 @@ package server
import ( import (
"bytes" "bytes"
"context"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
@ -391,6 +392,13 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value,
if err != nil { if err != nil {
return return
} }
deadline, empty := msg.Deadline.GetDeadlineTime()
if !empty {
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
luaState.SetContext(ctx)
defer luaState.RemoveContext()
}
defer c.luapool.Put(luaState) defer c.luapool.Put(luaState)
keysTbl := luaState.CreateTable(int(numkeys), 0) keysTbl := luaState.CreateTable(int(numkeys), 0)
@ -454,6 +462,9 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value,
"EVAL_CMD": lua.LNil, "EVAL_CMD": lua.LNil,
}) })
if err := luaState.PCall(0, 1, nil); err != nil { if err := luaState.PCall(0, 1, nil); err != nil {
if strings.Contains(err.Error(), "context deadline exceeded") {
msg.Deadline.Check()
}
log.Debugf("%v", err.Error()) log.Debugf("%v", err.Error())
return NOMessage, makeSafeErr(err) return NOMessage, makeSafeErr(err)
} }