diff --git a/controller/endpoint/disque.go b/controller/endpoint/disque.go index 169177a9..404a6f6f 100644 --- a/controller/endpoint/disque.go +++ b/controller/endpoint/disque.go @@ -57,7 +57,7 @@ func (conn *DisqueEndpointConn) Send(msg string) error { conn.mu.Lock() defer conn.mu.Unlock() if conn.ex { - return errors.New("expired") + return errExpired } conn.t = time.Now() if conn.conn == nil { diff --git a/controller/endpoint/endpoint.go b/controller/endpoint/endpoint.go index 3b2ed27d..b0075bc1 100644 --- a/controller/endpoint/endpoint.go +++ b/controller/endpoint/endpoint.go @@ -9,6 +9,8 @@ import ( "time" ) +var errExpired = errors.New("expired") + // EndpointProtocol is the type of protocol that the endpoint represents. type EndpointProtocol string @@ -36,9 +38,9 @@ type Endpoint struct { } } Redis struct { - Host string - Port int - Channel string + Host string + Port int + Channel string } } @@ -82,30 +84,42 @@ func (epc *EndpointManager) Validate(url string) error { } func (epc *EndpointManager) Send(endpoint, val string) error { - epc.mu.Lock() - conn, ok := epc.conns[endpoint] - if !ok || conn.Expired() { - ep, err := parseEndpoint(endpoint) + for { + epc.mu.Lock() + conn, ok := epc.conns[endpoint] + if !ok || conn.Expired() { + ep, err := parseEndpoint(endpoint) + if err != nil { + epc.mu.Unlock() + return err + } + switch ep.Protocol { + default: + return errors.New("invalid protocol") + case HTTP: + conn = newHTTPEndpointConn(ep) + case Disque: + conn = newDisqueEndpointConn(ep) + case GRPC: + conn = newGRPCEndpointConn(ep) + case Redis: + conn = newRedisEndpointConn(ep) + } + epc.conns[endpoint] = conn + } + epc.mu.Unlock() + err := conn.Send(val) if err != nil { - epc.mu.Unlock() + if err == errExpired { + // it's possible that the connection has expired in-between + // the last conn.Expired() check and now. If so, we should + // just try the send again. + continue + } return err } - switch ep.Protocol { - default: - return errors.New("invalid protocol") - case HTTP: - conn = newHTTPEndpointConn(ep) - case Disque: - conn = newDisqueEndpointConn(ep) - case GRPC: - conn = newGRPCEndpointConn(ep) - case Redis: - conn = newRedisEndpointConn(ep) - } - epc.conns[endpoint] = conn + return nil } - epc.mu.Unlock() - return conn.Send(val) } func parseEndpoint(s string) (Endpoint, error) { diff --git a/controller/endpoint/grpc.go b/controller/endpoint/grpc.go index a2d19d53..821d586c 100644 --- a/controller/endpoint/grpc.go +++ b/controller/endpoint/grpc.go @@ -57,7 +57,7 @@ func (conn *GRPCEndpointConn) Send(msg string) error { conn.mu.Lock() defer conn.mu.Unlock() if conn.ex { - return errors.New("expired") + return errExpired } conn.t = time.Now() if conn.conn == nil { diff --git a/controller/endpoint/http.go b/controller/endpoint/http.go index 34c3cd53..07cdc677 100644 --- a/controller/endpoint/http.go +++ b/controller/endpoint/http.go @@ -2,7 +2,6 @@ package endpoint import ( "bytes" - "errors" "fmt" "io" "io/ioutil" @@ -48,7 +47,7 @@ func (conn *HTTPEndpointConn) Send(msg string) error { conn.mu.Lock() defer conn.mu.Unlock() if conn.ex { - return errors.New("expired") + return errExpired } conn.t = time.Now() if conn.client == nil { diff --git a/controller/endpoint/redis.go b/controller/endpoint/redis.go index 1343ca1d..b1440f99 100644 --- a/controller/endpoint/redis.go +++ b/controller/endpoint/redis.go @@ -56,7 +56,7 @@ func (conn *RedisEndpointConn) Send(msg string) error { defer conn.mu.Unlock() if conn.ex { - return errors.New("expired") + return errExpired } conn.t = time.Now()