From 49111a4dce6e1cbb9c3c6efe6d60dc4fec00f5b9 Mon Sep 17 00:00:00 2001 From: Alex Roitman Date: Wed, 24 Apr 2019 17:00:52 -0700 Subject: [PATCH] Add timeouts for lua scripts --- internal/deadline/deadline.go | 6 ++++++ internal/server/scripts.go | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/internal/deadline/deadline.go b/internal/deadline/deadline.go index 6be76543..697492a3 100644 --- a/internal/deadline/deadline.go +++ b/internal/deadline/deadline.go @@ -13,6 +13,7 @@ func New(deadline time.Time) *Deadline { return &Deadline{unixNano: deadline.UnixNano()} } +// Empty deadline does nothing, just a place holder for future updates func Empty() *Deadline { return &Deadline{} } @@ -38,3 +39,8 @@ func (deadline *Deadline) Check() { func (deadline *Deadline) Hit() bool { return deadline.hit } + +// GetDeadlineTime returns the time object for the deadline, and an "empty" boolean +func (deadline *Deadline) GetDeadlineTime() (time.Time, bool) { + return time.Unix(0, deadline.unixNano), deadline.unixNano == 0 +} diff --git a/internal/server/scripts.go b/internal/server/scripts.go index 42735db5..9f87a1ac 100644 --- a/internal/server/scripts.go +++ b/internal/server/scripts.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "crypto/sha1" "encoding/hex" "encoding/json" @@ -391,6 +392,13 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, if err != nil { return } + deadline, empty := msg.Deadline.GetDeadlineTime() + if !empty { + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + luaState.SetContext(ctx) + defer luaState.RemoveContext() + } defer c.luapool.Put(luaState) keysTbl := luaState.CreateTable(int(numkeys), 0) @@ -454,6 +462,9 @@ func (c *Server) cmdEvalUnified(scriptIsSha bool, msg *Message) (res resp.Value, "EVAL_CMD": lua.LNil, }) if err := luaState.PCall(0, 1, nil); err != nil { + if strings.Contains(err.Error(), "context deadline exceeded") { + msg.Deadline.Check() + } log.Debugf("%v", err.Error()) return NOMessage, makeSafeErr(err) }