From 53271ebad6788fe34f55d49444b822c83292664a Mon Sep 17 00:00:00 2001 From: Lenny-Campino Hartmann Date: Tue, 7 Aug 2018 21:04:04 +0200 Subject: [PATCH 1/3] Added NATS endpoint --- pkg/endpoint/endpoint.go | 87 +++++++++++++++++++++++++++++++++++----- pkg/endpoint/nats.go | 80 ++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 10 deletions(-) create mode 100644 pkg/endpoint/nats.go diff --git a/pkg/endpoint/endpoint.go b/pkg/endpoint/endpoint.go index 6c60ec65..0b5ace8e 100644 --- a/pkg/endpoint/endpoint.go +++ b/pkg/endpoint/endpoint.go @@ -33,6 +33,8 @@ const ( AMQP = Protocol("amqp") // SQS protocol SQS = Protocol("sqs") + // NATS protocol + NATS = Protocol("nats") ) // Endpoint represents an endpoint. @@ -90,6 +92,14 @@ type Endpoint struct { CredProfile string QueueName string } + + NATS struct { + Host string + Port int + User string + Pass string + Topic string + } } // Conn is an endpoint connection @@ -165,6 +175,8 @@ func (epc *Manager) Send(endpoint, msg string) error { conn = newAMQPConn(ep) case SQS: conn = newSQSConn(ep) + case NATS: + conn = newNATSConn(ep) } epc.conns[endpoint] = conn } @@ -209,6 +221,8 @@ func parseEndpoint(s string) (Endpoint, error) { endpoint.Protocol = MQTT case strings.HasPrefix(s, "sqs:"): endpoint.Protocol = SQS + case strings.HasPrefix(s, "nats:"): + endpoint.Protocol = NATS } s = s[strings.Index(s, ":")+1:] @@ -469,7 +483,7 @@ func parseEndpoint(s string) (Endpoint, error) { // Basic AMQP connection strings in HOOKS interface // amqp://guest:guest@localhost:5672//?params=value - // or amqp://guest:guest@localhost:5672///?params=value + // or amqp://guest:guest@localhost:5672///?params=value // // Default params are: // @@ -487,15 +501,15 @@ func parseEndpoint(s string) (Endpoint, error) { endpoint.AMQP.Durable = true endpoint.AMQP.DeliveryMode = amqp.Transient - // Fix incase of namespace, e.g. example.com/namespace/queue - // but not example.com/queue/ - with an endslash. - if len(sp) > 2 && len(sp[2]) > 0 { - endpoint.AMQP.URI = endpoint.AMQP.URI + "/" + sp[1] - sp = append([]string{endpoint.AMQP.URI}, sp[2:]...) - } - - // Bind queue name with no namespace - if len(sp) > 1 { + // Fix incase of namespace, e.g. example.com/namespace/queue + // but not example.com/queue/ - with an endslash. + if len(sp) > 2 && len(sp[2]) > 0 { + endpoint.AMQP.URI = endpoint.AMQP.URI + "/" + sp[1] + sp = append([]string{endpoint.AMQP.URI}, sp[2:]...) + } + + // Bind queue name with no namespace + if len(sp) > 1 { var err error endpoint.AMQP.QueueName, err = url.QueryUnescape(sp[1]) if err != nil { @@ -549,6 +563,59 @@ func parseEndpoint(s string) (Endpoint, error) { } } + // Basic NATS connection strings in HOOKS interface + // nats://://?params=value + // + // params are: + // + // user - username + // pass - password + // when user or pass is not set then login without password is used + if endpoint.Protocol == NATS { + // Parsing connection from URL string + hp := strings.Split(s, ":") + switch len(hp) { + default: + return endpoint, errors.New("invalid SQS url") + case 2: + endpoint.NATS.Host = hp[0] + port, err := strconv.Atoi(hp[1]) + if err != nil { + endpoint.NATS.Port = 4222 // default nats port + } else { + endpoint.NATS.Port = port + } + } + + // Parsing NATS topic name + if len(sp) > 1 { + var err error + endpoint.NATS.Topic, err = url.QueryUnescape(sp[1]) + if err != nil { + return endpoint, errors.New("invalid NATS topic name") + } + } + + // Parsing additional params + if len(sqp) > 1 { + m, err := url.ParseQuery(sqp[1]) + if err != nil { + return endpoint, errors.New("invalid NATS url") + } + for key, val := range m { + if len(val) == 0 { + continue + } + switch key { + case "user": + endpoint.NATS.User = val[0] + case "pass": + endpoint.NATS.Pass = val[0] + } + } + } + } + return endpoint, nil } diff --git a/pkg/endpoint/nats.go b/pkg/endpoint/nats.go new file mode 100644 index 00000000..f63e4d05 --- /dev/null +++ b/pkg/endpoint/nats.go @@ -0,0 +1,80 @@ +package endpoint + +import ( + "fmt" + "sync" + "time" + + "github.com/nats-io/go-nats" +) + +const ( + natsExpiresAfter = time.Second * 30 +) + +// NATSConn is an endpoint connection +type NATSConn struct { + mu sync.Mutex + ep Endpoint + ex bool + t time.Time + conn *nats.Conn +} + +func newNATSConn(ep Endpoint) *NATSConn { + return &NATSConn{ + ep: ep, + t: time.Now(), + } +} + +// Expired returns true if the connection has expired +func (conn *NATSConn) Expired() bool { + conn.mu.Lock() + defer conn.mu.Unlock() + if !conn.ex { + if time.Now().Sub(conn.t) > natsExpiresAfter { + if conn.conn != nil { + conn.close() + } + conn.ex = true + } + } + return conn.ex +} + +func (conn *NATSConn) close() { + if conn.conn != nil { + conn.conn.Close() + conn.conn = nil + } +} + +// Send sends a message +func (conn *NATSConn) Send(msg string) error { + conn.mu.Lock() + defer conn.mu.Unlock() + if conn.ex { + return errExpired + } + conn.t = time.Now() + if conn.conn == nil { + addr := fmt.Sprintf("nats://%s:%d", conn.ep.NATS.Host, conn.ep.NATS.Port) + var err error + if conn.ep.NATS.User != "" && conn.ep.NATS.Pass != "" { + conn.conn, err = nats.Connect(addr, nats.UserInfo(conn.ep.NATS.User, conn.ep.NATS.Pass)) + } + conn.conn, err = nats.Connect(addr) + if err != nil { + conn.close() + return err + } + } + err := conn.conn.Publish(conn.ep.NATS.Topic, []byte(msg)) + if err != nil { + conn.close() + return err + } + + return nil +} From 4cae0404703ab9a79b2cc3209c255155530a2b71 Mon Sep 17 00:00:00 2001 From: Lenny-Campino Hartmann Date: Tue, 7 Aug 2018 21:21:35 +0200 Subject: [PATCH 2/3] Added go-nats to vendor --- vendor/github.com/nats-io/go-nats/LICENSE | 20 + vendor/github.com/nats-io/go-nats/README.md | 350 ++ vendor/github.com/nats-io/go-nats/TODO.md | 26 + vendor/github.com/nats-io/go-nats/context.go | 166 + vendor/github.com/nats-io/go-nats/enc.go | 249 ++ vendor/github.com/nats-io/go-nats/enc_test.go | 257 ++ .../nats-io/go-nats/example_test.go | 266 ++ vendor/github.com/nats-io/go-nats/nats.go | 2980 +++++++++++++++++ .../github.com/nats-io/go-nats/nats_test.go | 1177 +++++++ vendor/github.com/nats-io/go-nats/netchan.go | 100 + vendor/github.com/nats-io/go-nats/parser.go | 470 +++ .../nats-io/go-nats/staticcheck.ignore | 4 + vendor/github.com/nats-io/go-nats/timer.go | 43 + .../github.com/nats-io/go-nats/timer_test.go | 29 + vendor/vendor.json | 13 + 15 files changed, 6150 insertions(+) create mode 100644 vendor/github.com/nats-io/go-nats/LICENSE create mode 100644 vendor/github.com/nats-io/go-nats/README.md create mode 100644 vendor/github.com/nats-io/go-nats/TODO.md create mode 100644 vendor/github.com/nats-io/go-nats/context.go create mode 100644 vendor/github.com/nats-io/go-nats/enc.go create mode 100644 vendor/github.com/nats-io/go-nats/enc_test.go create mode 100644 vendor/github.com/nats-io/go-nats/example_test.go create mode 100644 vendor/github.com/nats-io/go-nats/nats.go create mode 100644 vendor/github.com/nats-io/go-nats/nats_test.go create mode 100644 vendor/github.com/nats-io/go-nats/netchan.go create mode 100644 vendor/github.com/nats-io/go-nats/parser.go create mode 100644 vendor/github.com/nats-io/go-nats/staticcheck.ignore create mode 100644 vendor/github.com/nats-io/go-nats/timer.go create mode 100644 vendor/github.com/nats-io/go-nats/timer_test.go create mode 100644 vendor/vendor.json diff --git a/vendor/github.com/nats-io/go-nats/LICENSE b/vendor/github.com/nats-io/go-nats/LICENSE new file mode 100644 index 00000000..9798d4ef --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2012-2017 Apcera Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/nats-io/go-nats/README.md b/vendor/github.com/nats-io/go-nats/README.md new file mode 100644 index 00000000..ae6868c5 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/README.md @@ -0,0 +1,350 @@ +# NATS - Go Client +A [Go](http://golang.org) client for the [NATS messaging system](https://nats.io). + +[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](http://opensource.org/licenses/MIT) +[![Go Report Card](https://goreportcard.com/badge/github.com/nats-io/go-nats)](https://goreportcard.com/report/github.com/nats-io/go-nats) [![Build Status](https://travis-ci.org/nats-io/go-nats.svg?branch=master)](http://travis-ci.org/nats-io/go-nats) [![GoDoc](https://godoc.org/github.com/nats-io/go-nats?status.svg)](http://godoc.org/github.com/nats-io/go-nats) [![Coverage Status](https://coveralls.io/repos/nats-io/go-nats/badge.svg?branch=master)](https://coveralls.io/r/nats-io/go-nats?branch=master) + +## Installation + +```bash +# Go client +go get github.com/nats-io/go-nats + +# Server +go get github.com/nats-io/gnatsd +``` + +## Basic Usage + +```go + +nc, _ := nats.Connect(nats.DefaultURL) + +// Simple Publisher +nc.Publish("foo", []byte("Hello World")) + +// Simple Async Subscriber +nc.Subscribe("foo", func(m *nats.Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) +}) + +// Simple Sync Subscriber +sub, err := nc.SubscribeSync("foo") +m, err := sub.NextMsg(timeout) + +// Channel Subscriber +ch := make(chan *nats.Msg, 64) +sub, err := nc.ChanSubscribe("foo", ch) +msg := <- ch + +// Unsubscribe +sub.Unsubscribe() + +// Requests +msg, err := nc.Request("help", []byte("help me"), 10*time.Millisecond) + +// Replies +nc.Subscribe("help", func(m *Msg) { + nc.Publish(m.Reply, []byte("I can help!")) +}) + +// Close connection +nc, _ := nats.Connect("nats://localhost:4222") +nc.Close(); +``` + +## Encoded Connections + +```go + +nc, _ := nats.Connect(nats.DefaultURL) +c, _ := nats.NewEncodedConn(nc, nats.JSON_ENCODER) +defer c.Close() + +// Simple Publisher +c.Publish("foo", "Hello World") + +// Simple Async Subscriber +c.Subscribe("foo", func(s string) { + fmt.Printf("Received a message: %s\n", s) +}) + +// EncodedConn can Publish any raw Go type using the registered Encoder +type person struct { + Name string + Address string + Age int +} + +// Go type Subscriber +c.Subscribe("hello", func(p *person) { + fmt.Printf("Received a person: %+v\n", p) +}) + +me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street, San Francisco, CA"} + +// Go type Publisher +c.Publish("hello", me) + +// Unsubscribe +sub, err := c.Subscribe("foo", nil) +... +sub.Unsubscribe() + +// Requests +var response string +err := c.Request("help", "help me", &response, 10*time.Millisecond) +if err != nil { + fmt.Printf("Request failed: %v\n", err) +} + +// Replying +c.Subscribe("help", func(subj, reply string, msg string) { + c.Publish(reply, "I can help!") +}) + +// Close connection +c.Close(); +``` + +## TLS + +```go +// tls as a scheme will enable secure connections by default. This will also verify the server name. +nc, err := nats.Connect("tls://nats.demo.io:4443") + +// If you are using a self-signed certificate, you need to have a tls.Config with RootCAs setup. +// We provide a helper method to make this case easier. +nc, err = nats.Connect("tls://localhost:4443", nats.RootCAs("./configs/certs/ca.pem")) + +// If the server requires client certificate, there is an helper function for that too: +cert := nats.ClientCert("./configs/certs/client-cert.pem", "./configs/certs/client-key.pem") +nc, err = nats.Connect("tls://localhost:4443", cert) + +// You can also supply a complete tls.Config + +certFile := "./configs/certs/client-cert.pem" +keyFile := "./configs/certs/client-key.pem" +cert, err := tls.LoadX509KeyPair(certFile, keyFile) +if err != nil { + t.Fatalf("error parsing X509 certificate/key pair: %v", err) +} + +config := &tls.Config{ + ServerName: opts.Host, + Certificates: []tls.Certificate{cert}, + RootCAs: pool, + MinVersion: tls.VersionTLS12, +} + +nc, err = nats.Connect("nats://localhost:4443", nats.Secure(config)) +if err != nil { + t.Fatalf("Got an error on Connect with Secure Options: %+v\n", err) +} + +``` + +## Using Go Channels (netchan) + +```go +nc, _ := nats.Connect(nats.DefaultURL) +ec, _ := nats.NewEncodedConn(nc, nats.JSON_ENCODER) +defer ec.Close() + +type person struct { + Name string + Address string + Age int +} + +recvCh := make(chan *person) +ec.BindRecvChan("hello", recvCh) + +sendCh := make(chan *person) +ec.BindSendChan("hello", sendCh) + +me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street"} + +// Send via Go channels +sendCh <- me + +// Receive via Go channels +who := <- recvCh +``` + +## Wildcard Subscriptions + +```go + +// "*" matches any token, at any level of the subject. +nc.Subscribe("foo.*.baz", func(m *Msg) { + fmt.Printf("Msg received on [%s] : %s\n", m.Subject, string(m.Data)); +}) + +nc.Subscribe("foo.bar.*", func(m *Msg) { + fmt.Printf("Msg received on [%s] : %s\n", m.Subject, string(m.Data)); +}) + +// ">" matches any length of the tail of a subject, and can only be the last token +// E.g. 'foo.>' will match 'foo.bar', 'foo.bar.baz', 'foo.foo.bar.bax.22' +nc.Subscribe("foo.>", func(m *Msg) { + fmt.Printf("Msg received on [%s] : %s\n", m.Subject, string(m.Data)); +}) + +// Matches all of the above +nc.Publish("foo.bar.baz", []byte("Hello World")) + +``` + +## Queue Groups + +```go +// All subscriptions with the same queue name will form a queue group. +// Each message will be delivered to only one subscriber per queue group, +// using queuing semantics. You can have as many queue groups as you wish. +// Normal subscribers will continue to work as expected. + +nc.QueueSubscribe("foo", "job_workers", func(_ *Msg) { + received += 1; +}) + +``` + +## Advanced Usage + +```go + +// Flush connection to server, returns when all messages have been processed. +nc.Flush() +fmt.Println("All clear!") + +// FlushTimeout specifies a timeout value as well. +err := nc.FlushTimeout(1*time.Second) +if err != nil { + fmt.Println("All clear!") +} else { + fmt.Println("Flushed timed out!") +} + +// Auto-unsubscribe after MAX_WANTED messages received +const MAX_WANTED = 10 +sub, err := nc.Subscribe("foo") +sub.AutoUnsubscribe(MAX_WANTED) + +// Multiple connections +nc1 := nats.Connect("nats://host1:4222") +nc2 := nats.Connect("nats://host2:4222") + +nc1.Subscribe("foo", func(m *Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) +}) + +nc2.Publish("foo", []byte("Hello World!")); + +``` + +## Clustered Usage + +```go + +var servers = "nats://localhost:1222, nats://localhost:1223, nats://localhost:1224" + +nc, err := nats.Connect(servers) + +// Optionally set ReconnectWait and MaxReconnect attempts. +// This example means 10 seconds total per backend. +nc, err = nats.Connect(servers, nats.MaxReconnects(5), nats.ReconnectWait(2 * time.Second)) + +// Optionally disable randomization of the server pool +nc, err = nats.Connect(servers, nats.DontRandomize()) + +// Setup callbacks to be notified on disconnects, reconnects and connection closed. +nc, err = nats.Connect(servers, + nats.DisconnectHandler(func(nc *nats.Conn) { + fmt.Printf("Got disconnected!\n") + }), + nats.ReconnectHandler(func(_ *nats.Conn) { + fmt.Printf("Got reconnected to %v!\n", nc.ConnectedUrl()) + }), + nats.ClosedHandler(func(nc *nats.Conn) { + fmt.Printf("Connection closed. Reason: %q\n", nc.LastError()) + }) +) + +// When connecting to a mesh of servers with auto-discovery capabilities, +// you may need to provide a username/password or token in order to connect +// to any server in that mesh when authentication is required. +// Instead of providing the credentials in the initial URL, you will use +// new option setters: +nc, err = nats.Connect("nats://localhost:4222", nats.UserInfo("foo", "bar")) + +// For token based authentication: +nc, err = nats.Connect("nats://localhost:4222", nats.Token("S3cretT0ken")) + +// You can even pass the two at the same time in case one of the server +// in the mesh requires token instead of user name and password. +nc, err = nats.Connect("nats://localhost:4222", + nats.UserInfo("foo", "bar"), + nats.Token("S3cretT0ken")) + +// Note that if credentials are specified in the initial URLs, they take +// precedence on the credentials specfied through the options. +// For instance, in the connect call below, the client library will use +// the user "my" and password "pwd" to connect to locahost:4222, however, +// it will use username "foo" and password "bar" when (re)connecting to +// a different server URL that it got as part of the auto-discovery. +nc, err = nats.Connect("nats://my:pwd@localhost:4222", nats.UserInfo("foo", "bar")) + +``` + +## Context support (+Go 1.7) + +```go +ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) +defer cancel() + +nc, err := nats.Connect(nats.DefaultURL) + +// Request with context +msg, err := nc.RequestWithContext(ctx, "foo", []byte("bar")) + +// Synchronous subscriber with context +sub, err := nc.SubscribeSync("foo") +msg, err := sub.NextMsgWithContext(ctx) + +// Encoded Request with context +c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) +type request struct { + Message string `json:"message"` +} +type response struct { + Code int `json:"code"` +} +req := &request{Message: "Hello"} +resp := &response{} +err := c.RequestWithContext(ctx, "foo", req, resp) +``` + +## License + +(The MIT License) + +Copyright (c) 2012-2017 Apcera Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. diff --git a/vendor/github.com/nats-io/go-nats/TODO.md b/vendor/github.com/nats-io/go-nats/TODO.md new file mode 100644 index 00000000..213aaeca --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/TODO.md @@ -0,0 +1,26 @@ + +- [ ] Better constructors, options handling +- [ ] Functions for callback settings after connection created. +- [ ] Better options for subscriptions. Slow Consumer state settable, Go routines vs Inline. +- [ ] Move off of channels for subscribers, use syncPool linkedLists, etc with highwater. +- [ ] Test for valid subjects on publish and subscribe? +- [ ] SyncSubscriber and Next for EncodedConn +- [ ] Fast Publisher? +- [ ] pooling for structs used? leaky bucket? +- [ ] Timeout 0 should work as no timeout +- [x] Ping timer +- [x] Name in Connect for gnatsd +- [x] Asynchronous error handling +- [x] Parser rewrite +- [x] Reconnect +- [x] Hide Lock +- [x] Easier encoder interface +- [x] QueueSubscribeSync +- [x] Make nats specific errors prefixed with 'nats:' +- [x] API test for closed connection +- [x] TLS/SSL +- [x] Stats collection +- [x] Disconnect detection +- [x] Optimized Publish (coalescing) +- [x] Do Examples via Go style +- [x] Standardized Errors diff --git a/vendor/github.com/nats-io/go-nats/context.go b/vendor/github.com/nats-io/go-nats/context.go new file mode 100644 index 00000000..be6ada4a --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/context.go @@ -0,0 +1,166 @@ +// Copyright 2012-2017 Apcera Inc. All rights reserved. + +// +build go1.7 + +// A Go client for the NATS messaging system (https://nats.io). +package nats + +import ( + "context" + "fmt" + "reflect" +) + +// RequestWithContext takes a context, a subject and payload +// in bytes and request expecting a single response. +func (nc *Conn) RequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { + if ctx == nil { + return nil, ErrInvalidContext + } + if nc == nil { + return nil, ErrInvalidConnection + } + + nc.mu.Lock() + // If user wants the old style. + if nc.Opts.UseOldRequestStyle { + nc.mu.Unlock() + return nc.oldRequestWithContext(ctx, subj, data) + } + + // Do setup for the new style. + if nc.respMap == nil { + // _INBOX wildcard + nc.respSub = fmt.Sprintf("%s.*", NewInbox()) + nc.respMap = make(map[string]chan *Msg) + } + // Create literal Inbox and map to a chan msg. + mch := make(chan *Msg, RequestChanLen) + respInbox := nc.newRespInbox() + token := respToken(respInbox) + nc.respMap[token] = mch + createSub := nc.respMux == nil + ginbox := nc.respSub + nc.mu.Unlock() + + if createSub { + // Make sure scoped subscription is setup only once. + var err error + nc.respSetup.Do(func() { err = nc.createRespMux(ginbox) }) + if err != nil { + return nil, err + } + } + + err := nc.PublishRequest(subj, respInbox, data) + if err != nil { + return nil, err + } + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + case <-ctx.Done(): + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() + return nil, ctx.Err() + } + + return msg, nil +} + +// oldRequestWithContext utilizes inbox and subscription per request. +func (nc *Conn) oldRequestWithContext(ctx context.Context, subj string, data []byte) (*Msg, error) { + inbox := NewInbox() + ch := make(chan *Msg, RequestChanLen) + + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch) + if err != nil { + return nil, err + } + s.AutoUnsubscribe(1) + defer s.Unsubscribe() + + err = nc.PublishRequest(subj, inbox, data) + if err != nil { + return nil, err + } + + return s.NextMsgWithContext(ctx) +} + +// NextMsgWithContext takes a context and returns the next message +// available to a synchronous subscriber, blocking until it is delivered +// or context gets canceled. +func (s *Subscription) NextMsgWithContext(ctx context.Context) (*Msg, error) { + if ctx == nil { + return nil, ErrInvalidContext + } + if s == nil { + return nil, ErrBadSubscription + } + + s.mu.Lock() + err := s.validateNextMsgState() + if err != nil { + s.mu.Unlock() + return nil, err + } + + // snapshot + mch := s.mch + s.mu.Unlock() + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + err := s.processNextMsgDelivered(msg) + if err != nil { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + + return msg, nil +} + +// RequestWithContext will create an Inbox and perform a Request +// using the provided cancellation context with the Inbox reply +// for the data v. A response will be decoded into the vPtrResponse. +func (c *EncodedConn) RequestWithContext(ctx context.Context, subject string, v interface{}, vPtr interface{}) error { + if ctx == nil { + return ErrInvalidContext + } + + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + m, err := c.Conn.RequestWithContext(ctx, subject, b) + if err != nil { + return err + } + if reflect.TypeOf(vPtr) == emptyMsgType { + mPtr := vPtr.(*Msg) + *mPtr = *m + } else { + err := c.Enc.Decode(m.Subject, m.Data, vPtr) + if err != nil { + return err + } + } + + return nil +} diff --git a/vendor/github.com/nats-io/go-nats/enc.go b/vendor/github.com/nats-io/go-nats/enc.go new file mode 100644 index 00000000..291b7826 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/enc.go @@ -0,0 +1,249 @@ +// Copyright 2012-2015 Apcera Inc. All rights reserved. + +package nats + +import ( + "errors" + "fmt" + "reflect" + "sync" + "time" + + // Default Encoders + . "github.com/nats-io/go-nats/encoders/builtin" +) + +// Encoder interface is for all register encoders +type Encoder interface { + Encode(subject string, v interface{}) ([]byte, error) + Decode(subject string, data []byte, vPtr interface{}) error +} + +var encMap map[string]Encoder +var encLock sync.Mutex + +// Indexe names into the Registered Encoders. +const ( + JSON_ENCODER = "json" + GOB_ENCODER = "gob" + DEFAULT_ENCODER = "default" +) + +func init() { + encMap = make(map[string]Encoder) + // Register json, gob and default encoder + RegisterEncoder(JSON_ENCODER, &JsonEncoder{}) + RegisterEncoder(GOB_ENCODER, &GobEncoder{}) + RegisterEncoder(DEFAULT_ENCODER, &DefaultEncoder{}) +} + +// EncodedConn are the preferred way to interface with NATS. They wrap a bare connection to +// a nats server and have an extendable encoder system that will encode and decode messages +// from raw Go types. +type EncodedConn struct { + Conn *Conn + Enc Encoder +} + +// NewEncodedConn will wrap an existing Connection and utilize the appropriate registered +// encoder. +func NewEncodedConn(c *Conn, encType string) (*EncodedConn, error) { + if c == nil { + return nil, errors.New("nats: Nil Connection") + } + if c.IsClosed() { + return nil, ErrConnectionClosed + } + ec := &EncodedConn{Conn: c, Enc: EncoderForType(encType)} + if ec.Enc == nil { + return nil, fmt.Errorf("No encoder registered for '%s'", encType) + } + return ec, nil +} + +// RegisterEncoder will register the encType with the given Encoder. Useful for customization. +func RegisterEncoder(encType string, enc Encoder) { + encLock.Lock() + defer encLock.Unlock() + encMap[encType] = enc +} + +// EncoderForType will return the registered Encoder for the encType. +func EncoderForType(encType string) Encoder { + encLock.Lock() + defer encLock.Unlock() + return encMap[encType] +} + +// Publish publishes the data argument to the given subject. The data argument +// will be encoded using the associated encoder. +func (c *EncodedConn) Publish(subject string, v interface{}) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + return c.Conn.publish(subject, _EMPTY_, b) +} + +// PublishRequest will perform a Publish() expecting a response on the +// reply subject. Use Request() for automatically waiting for a response +// inline. +func (c *EncodedConn) PublishRequest(subject, reply string, v interface{}) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + return c.Conn.publish(subject, reply, b) +} + +// Request will create an Inbox and perform a Request() call +// with the Inbox reply for the data v. A response will be +// decoded into the vPtrResponse. +func (c *EncodedConn) Request(subject string, v interface{}, vPtr interface{}, timeout time.Duration) error { + b, err := c.Enc.Encode(subject, v) + if err != nil { + return err + } + m, err := c.Conn.Request(subject, b, timeout) + if err != nil { + return err + } + if reflect.TypeOf(vPtr) == emptyMsgType { + mPtr := vPtr.(*Msg) + *mPtr = *m + } else { + err = c.Enc.Decode(m.Subject, m.Data, vPtr) + } + return err +} + +// Handler is a specific callback used for Subscribe. It is generalized to +// an interface{}, but we will discover its format and arguments at runtime +// and perform the correct callback, including de-marshaling JSON strings +// back into the appropriate struct based on the signature of the Handler. +// +// Handlers are expected to have one of four signatures. +// +// type person struct { +// Name string `json:"name,omitempty"` +// Age uint `json:"age,omitempty"` +// } +// +// handler := func(m *Msg) +// handler := func(p *person) +// handler := func(subject string, o *obj) +// handler := func(subject, reply string, o *obj) +// +// These forms allow a callback to request a raw Msg ptr, where the processing +// of the message from the wire is untouched. Process a JSON representation +// and demarshal it into the given struct, e.g. person. +// There are also variants where the callback wants either the subject, or the +// subject and the reply subject. +type Handler interface{} + +// Dissect the cb Handler's signature +func argInfo(cb Handler) (reflect.Type, int) { + cbType := reflect.TypeOf(cb) + if cbType.Kind() != reflect.Func { + panic("nats: Handler needs to be a func") + } + numArgs := cbType.NumIn() + if numArgs == 0 { + return nil, numArgs + } + return cbType.In(numArgs - 1), numArgs +} + +var emptyMsgType = reflect.TypeOf(&Msg{}) + +// Subscribe will create a subscription on the given subject and process incoming +// messages using the specified Handler. The Handler should be a func that matches +// a signature from the description of Handler from above. +func (c *EncodedConn) Subscribe(subject string, cb Handler) (*Subscription, error) { + return c.subscribe(subject, _EMPTY_, cb) +} + +// QueueSubscribe will create a queue subscription on the given subject and process +// incoming messages using the specified Handler. The Handler should be a func that +// matches a signature from the description of Handler from above. +func (c *EncodedConn) QueueSubscribe(subject, queue string, cb Handler) (*Subscription, error) { + return c.subscribe(subject, queue, cb) +} + +// Internal implementation that all public functions will use. +func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscription, error) { + if cb == nil { + return nil, errors.New("nats: Handler required for EncodedConn Subscription") + } + argType, numArgs := argInfo(cb) + if argType == nil { + return nil, errors.New("nats: Handler requires at least one argument") + } + + cbValue := reflect.ValueOf(cb) + wantsRaw := (argType == emptyMsgType) + + natsCB := func(m *Msg) { + var oV []reflect.Value + if wantsRaw { + oV = []reflect.Value{reflect.ValueOf(m)} + } else { + var oPtr reflect.Value + if argType.Kind() != reflect.Ptr { + oPtr = reflect.New(argType) + } else { + oPtr = reflect.New(argType.Elem()) + } + if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { + if c.Conn.Opts.AsyncErrorCB != nil { + c.Conn.ach <- func() { + c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, errors.New("nats: Got an error trying to unmarshal: "+err.Error())) + } + } + return + } + if argType.Kind() != reflect.Ptr { + oPtr = reflect.Indirect(oPtr) + } + + // Callback Arity + switch numArgs { + case 1: + oV = []reflect.Value{oPtr} + case 2: + subV := reflect.ValueOf(m.Subject) + oV = []reflect.Value{subV, oPtr} + case 3: + subV := reflect.ValueOf(m.Subject) + replyV := reflect.ValueOf(m.Reply) + oV = []reflect.Value{subV, replyV, oPtr} + } + + } + cbValue.Call(oV) + } + + return c.Conn.subscribe(subject, queue, natsCB, nil) +} + +// FlushTimeout allows a Flush operation to have an associated timeout. +func (c *EncodedConn) FlushTimeout(timeout time.Duration) (err error) { + return c.Conn.FlushTimeout(timeout) +} + +// Flush will perform a round trip to the server and return when it +// receives the internal reply. +func (c *EncodedConn) Flush() error { + return c.Conn.Flush() +} + +// Close will close the connection to the server. This call will release +// all blocking calls, such as Flush(), etc. +func (c *EncodedConn) Close() { + c.Conn.Close() +} + +// LastError reports the last error encountered via the Connection. +func (c *EncodedConn) LastError() error { + return c.Conn.err +} diff --git a/vendor/github.com/nats-io/go-nats/enc_test.go b/vendor/github.com/nats-io/go-nats/enc_test.go new file mode 100644 index 00000000..ada5b024 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/enc_test.go @@ -0,0 +1,257 @@ +package nats_test + +import ( + "fmt" + "testing" + "time" + + . "github.com/nats-io/go-nats" + "github.com/nats-io/go-nats/encoders/protobuf" + "github.com/nats-io/go-nats/encoders/protobuf/testdata" +) + +// Since we import above nats packages, we need to have a different +// const name than TEST_PORT that we used on the other packages. +const ENC_TEST_PORT = 8268 + +var options = Options{ + Url: fmt.Sprintf("nats://localhost:%d", ENC_TEST_PORT), + AllowReconnect: true, + MaxReconnect: 10, + ReconnectWait: 100 * time.Millisecond, + Timeout: DefaultTimeout, +} + +//////////////////////////////////////////////////////////////////////////////// +// Encoded connection tests +//////////////////////////////////////////////////////////////////////////////// + +func TestPublishErrorAfterSubscribeDecodeError(t *testing.T) { + ts := RunServerOnPort(ENC_TEST_PORT) + defer ts.Shutdown() + opts := options + nc, _ := opts.Connect() + defer nc.Close() + c, _ := NewEncodedConn(nc, JSON_ENCODER) + + //Test message type + type Message struct { + Message string + } + const testSubj = "test" + + c.Subscribe(testSubj, func(msg *Message) {}) + + //Publish invalid json to catch decode error in subscription callback + c.Publish(testSubj, `foo`) + c.Flush() + + //Next publish should be successful + if err := c.Publish(testSubj, Message{"2"}); err != nil { + t.Error("Fail to send correct json message after decode error in subscription") + } +} + +func TestPublishErrorAfterInvalidPublishMessage(t *testing.T) { + ts := RunServerOnPort(ENC_TEST_PORT) + defer ts.Shutdown() + opts := options + nc, _ := opts.Connect() + defer nc.Close() + c, _ := NewEncodedConn(nc, protobuf.PROTOBUF_ENCODER) + const testSubj = "test" + + c.Publish(testSubj, &testdata.Person{Name: "Anatolii"}) + + //Publish invalid protobuff message to catch decode error + c.Publish(testSubj, "foo") + + //Next publish with valid protobuf message should be successful + if err := c.Publish(testSubj, &testdata.Person{Name: "Anatolii"}); err != nil { + t.Error("Fail to send correct protobuf message after invalid message publishing", err) + } +} + +func TestVariousFailureConditions(t *testing.T) { + ts := RunServerOnPort(ENC_TEST_PORT) + defer ts.Shutdown() + + dch := make(chan bool) + + opts := options + opts.AsyncErrorCB = func(_ *Conn, _ *Subscription, e error) { + dch <- true + } + nc, _ := opts.Connect() + nc.Close() + + if _, err := NewEncodedConn(nil, protobuf.PROTOBUF_ENCODER); err == nil { + t.Fatal("Expected an error") + } + + if _, err := NewEncodedConn(nc, protobuf.PROTOBUF_ENCODER); err == nil || err != ErrConnectionClosed { + t.Fatalf("Wrong error: %v instead of %v", err, ErrConnectionClosed) + } + + nc, _ = opts.Connect() + defer nc.Close() + + if _, err := NewEncodedConn(nc, "foo"); err == nil { + t.Fatal("Expected an error") + } + + c, err := NewEncodedConn(nc, protobuf.PROTOBUF_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + if _, err := c.Subscribe("bar", func(subj, obj string) {}); err != nil { + t.Fatalf("Unable to create subscription: %v", err) + } + + if err := c.Publish("bar", &testdata.Person{Name: "Ivan"}); err != nil { + t.Fatalf("Unable to publish: %v", err) + } + + if err := Wait(dch); err != nil { + t.Fatal("Did not get the async error callback") + } + + if err := c.PublishRequest("foo", "bar", "foo"); err == nil { + t.Fatal("Expected an error") + } + + if err := c.Request("foo", "foo", nil, 2*time.Second); err == nil { + t.Fatal("Expected an error") + } + + nc.Close() + + if err := c.PublishRequest("foo", "bar", &testdata.Person{Name: "Ivan"}); err == nil { + t.Fatal("Expected an error") + } + + resp := &testdata.Person{} + if err := c.Request("foo", &testdata.Person{Name: "Ivan"}, resp, 2*time.Second); err == nil { + t.Fatal("Expected an error") + } + + if _, err := c.Subscribe("foo", nil); err == nil { + t.Fatal("Expected an error") + } + + if _, err := c.Subscribe("foo", func() {}); err == nil { + t.Fatal("Expected an error") + } + + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected an error") + } + }() + if _, err := c.Subscribe("foo", "bar"); err == nil { + t.Fatal("Expected an error") + } + }() +} + +func TestRequest(t *testing.T) { + ts := RunServerOnPort(ENC_TEST_PORT) + defer ts.Shutdown() + + dch := make(chan bool) + + opts := options + nc, _ := opts.Connect() + defer nc.Close() + + c, err := NewEncodedConn(nc, protobuf.PROTOBUF_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + sentName := "Ivan" + recvName := "Kozlovic" + + if _, err := c.Subscribe("foo", func(_, reply string, p *testdata.Person) { + if p.Name != sentName { + t.Fatalf("Got wrong name: %v instead of %v", p.Name, sentName) + } + c.Publish(reply, &testdata.Person{Name: recvName}) + dch <- true + }); err != nil { + t.Fatalf("Unable to create subscription: %v", err) + } + if _, err := c.Subscribe("foo", func(_ string, p *testdata.Person) { + if p.Name != sentName { + t.Fatalf("Got wrong name: %v instead of %v", p.Name, sentName) + } + dch <- true + }); err != nil { + t.Fatalf("Unable to create subscription: %v", err) + } + + if err := c.Publish("foo", &testdata.Person{Name: sentName}); err != nil { + t.Fatalf("Unable to publish: %v", err) + } + + if err := Wait(dch); err != nil { + t.Fatal("Did not get message") + } + if err := Wait(dch); err != nil { + t.Fatal("Did not get message") + } + + response := &testdata.Person{} + if err := c.Request("foo", &testdata.Person{Name: sentName}, response, 2*time.Second); err != nil { + t.Fatalf("Unable to publish: %v", err) + } + if response == nil { + t.Fatal("No response received") + } else if response.Name != recvName { + t.Fatalf("Wrong response: %v instead of %v", response.Name, recvName) + } + + if err := Wait(dch); err != nil { + t.Fatal("Did not get message") + } + if err := Wait(dch); err != nil { + t.Fatal("Did not get message") + } + + c2, err := NewEncodedConn(nc, GOB_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c2.Close() + + if _, err := c2.QueueSubscribe("bar", "baz", func(m *Msg) { + response := &Msg{Subject: m.Reply, Data: []byte(recvName)} + c2.Conn.PublishMsg(response) + dch <- true + }); err != nil { + t.Fatalf("Unable to create subscription: %v", err) + } + + mReply := Msg{} + if err := c2.Request("bar", &Msg{Data: []byte(sentName)}, &mReply, 2*time.Second); err != nil { + t.Fatalf("Unable to send request: %v", err) + } + if string(mReply.Data) != recvName { + t.Fatalf("Wrong reply: %v instead of %v", string(mReply.Data), recvName) + } + + if err := Wait(dch); err != nil { + t.Fatal("Did not get message") + } + + if c.LastError() != nil { + t.Fatalf("Unexpected connection error: %v", c.LastError()) + } + if c2.LastError() != nil { + t.Fatalf("Unexpected connection error: %v", c2.LastError()) + } +} diff --git a/vendor/github.com/nats-io/go-nats/example_test.go b/vendor/github.com/nats-io/go-nats/example_test.go new file mode 100644 index 00000000..64a65867 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/example_test.go @@ -0,0 +1,266 @@ +package nats_test + +import ( + "fmt" + "time" + + "github.com/nats-io/go-nats" +) + +// Shows different ways to create a Conn +func ExampleConnect() { + + nc, _ := nats.Connect(nats.DefaultURL) + nc.Close() + + nc, _ = nats.Connect("nats://derek:secretpassword@demo.nats.io:4222") + nc.Close() + + nc, _ = nats.Connect("tls://derek:secretpassword@demo.nats.io:4443") + nc.Close() + + opts := nats.Options{ + AllowReconnect: true, + MaxReconnect: 10, + ReconnectWait: 5 * time.Second, + Timeout: 1 * time.Second, + } + + nc, _ = opts.Connect() + nc.Close() +} + +// This Example shows an asynchronous subscriber. +func ExampleConn_Subscribe() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Subscribe("foo", func(m *nats.Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) + }) +} + +// This Example shows a synchronous subscriber. +func ExampleConn_SubscribeSync() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + sub, _ := nc.SubscribeSync("foo") + m, err := sub.NextMsg(1 * time.Second) + if err == nil { + fmt.Printf("Received a message: %s\n", string(m.Data)) + } else { + fmt.Println("NextMsg timed out.") + } +} + +func ExampleSubscription_NextMsg() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + sub, _ := nc.SubscribeSync("foo") + m, err := sub.NextMsg(1 * time.Second) + if err == nil { + fmt.Printf("Received a message: %s\n", string(m.Data)) + } else { + fmt.Println("NextMsg timed out.") + } +} + +func ExampleSubscription_Unsubscribe() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + sub, _ := nc.SubscribeSync("foo") + // ... + sub.Unsubscribe() +} + +func ExampleConn_Publish() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Publish("foo", []byte("Hello World!")) +} + +func ExampleConn_PublishMsg() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + msg := &nats.Msg{Subject: "foo", Reply: "bar", Data: []byte("Hello World!")} + nc.PublishMsg(msg) +} + +func ExampleConn_Flush() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + msg := &nats.Msg{Subject: "foo", Reply: "bar", Data: []byte("Hello World!")} + for i := 0; i < 1000; i++ { + nc.PublishMsg(msg) + } + err := nc.Flush() + if err == nil { + // Everything has been processed by the server for nc *Conn. + } +} + +func ExampleConn_FlushTimeout() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + msg := &nats.Msg{Subject: "foo", Reply: "bar", Data: []byte("Hello World!")} + for i := 0; i < 1000; i++ { + nc.PublishMsg(msg) + } + // Only wait for up to 1 second for Flush + err := nc.FlushTimeout(1 * time.Second) + if err == nil { + // Everything has been processed by the server for nc *Conn. + } +} + +func ExampleConn_Request() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Subscribe("foo", func(m *nats.Msg) { + nc.Publish(m.Reply, []byte("I will help you")) + }) + nc.Request("foo", []byte("help"), 50*time.Millisecond) +} + +func ExampleConn_QueueSubscribe() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + received := 0 + + nc.QueueSubscribe("foo", "worker_group", func(_ *nats.Msg) { + received++ + }) +} + +func ExampleSubscription_AutoUnsubscribe() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + received, wanted, total := 0, 10, 100 + + sub, _ := nc.Subscribe("foo", func(_ *nats.Msg) { + received++ + }) + sub.AutoUnsubscribe(wanted) + + for i := 0; i < total; i++ { + nc.Publish("foo", []byte("Hello")) + } + nc.Flush() + + fmt.Printf("Received = %d", received) +} + +func ExampleConn_Close() { + nc, _ := nats.Connect(nats.DefaultURL) + nc.Close() +} + +// Shows how to wrap a Conn into an EncodedConn +func ExampleNewEncodedConn() { + nc, _ := nats.Connect(nats.DefaultURL) + c, _ := nats.NewEncodedConn(nc, "json") + c.Close() +} + +// EncodedConn can publish virtually anything just +// by passing it in. The encoder will be used to properly +// encode the raw Go type +func ExampleEncodedConn_Publish() { + nc, _ := nats.Connect(nats.DefaultURL) + c, _ := nats.NewEncodedConn(nc, "json") + defer c.Close() + + type person struct { + Name string + Address string + Age int + } + + me := &person{Name: "derek", Age: 22, Address: "85 Second St"} + c.Publish("hello", me) +} + +// EncodedConn's subscribers will automatically decode the +// wire data into the requested Go type using the Decode() +// method of the registered Encoder. The callback signature +// can also vary to include additional data, such as subject +// and reply subjects. +func ExampleEncodedConn_Subscribe() { + nc, _ := nats.Connect(nats.DefaultURL) + c, _ := nats.NewEncodedConn(nc, "json") + defer c.Close() + + type person struct { + Name string + Address string + Age int + } + + c.Subscribe("hello", func(p *person) { + fmt.Printf("Received a person! %+v\n", p) + }) + + c.Subscribe("hello", func(subj, reply string, p *person) { + fmt.Printf("Received a person on subject %s! %+v\n", subj, p) + }) + + me := &person{Name: "derek", Age: 22, Address: "85 Second St"} + c.Publish("hello", me) +} + +// BindSendChan() allows binding of a Go channel to a nats +// subject for publish operations. The Encoder attached to the +// EncodedConn will be used for marshaling. +func ExampleEncodedConn_BindSendChan() { + nc, _ := nats.Connect(nats.DefaultURL) + c, _ := nats.NewEncodedConn(nc, "json") + defer c.Close() + + type person struct { + Name string + Address string + Age int + } + + ch := make(chan *person) + c.BindSendChan("hello", ch) + + me := &person{Name: "derek", Age: 22, Address: "85 Second St"} + ch <- me +} + +// BindRecvChan() allows binding of a Go channel to a nats +// subject for subscribe operations. The Encoder attached to the +// EncodedConn will be used for un-marshaling. +func ExampleEncodedConn_BindRecvChan() { + nc, _ := nats.Connect(nats.DefaultURL) + c, _ := nats.NewEncodedConn(nc, "json") + defer c.Close() + + type person struct { + Name string + Address string + Age int + } + + ch := make(chan *person) + c.BindRecvChan("hello", ch) + + me := &person{Name: "derek", Age: 22, Address: "85 Second St"} + c.Publish("hello", me) + + // Receive the publish directly on a channel + who := <-ch + + fmt.Printf("%v says hello!\n", who) +} diff --git a/vendor/github.com/nats-io/go-nats/nats.go b/vendor/github.com/nats-io/go-nats/nats.go new file mode 100644 index 00000000..fbb86c03 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/nats.go @@ -0,0 +1,2980 @@ +// Copyright 2012-2017 Apcera Inc. All rights reserved. + +// A Go client for the NATS messaging system (https://nats.io). +package nats + +import ( + "bufio" + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net" + "net/url" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/nats-io/go-nats/util" + "github.com/nats-io/nuid" +) + +// Default Constants +const ( + Version = "1.3.1" + DefaultURL = "nats://localhost:4222" + DefaultPort = 4222 + DefaultMaxReconnect = 60 + DefaultReconnectWait = 2 * time.Second + DefaultTimeout = 2 * time.Second + DefaultPingInterval = 2 * time.Minute + DefaultMaxPingOut = 2 + DefaultMaxChanLen = 8192 // 8k + DefaultReconnectBufSize = 8 * 1024 * 1024 // 8MB + RequestChanLen = 8 + LangString = "go" +) + +// STALE_CONNECTION is for detection and proper handling of stale connections. +const STALE_CONNECTION = "stale connection" + +// PERMISSIONS_ERR is for when nats server subject authorization has failed. +const PERMISSIONS_ERR = "permissions violation" + +// AUTHORIZATION_ERR is for when nats server user authorization has failed. +const AUTHORIZATION_ERR = "authorization violation" + +// Errors +var ( + ErrConnectionClosed = errors.New("nats: connection closed") + ErrSecureConnRequired = errors.New("nats: secure connection required") + ErrSecureConnWanted = errors.New("nats: secure connection not available") + ErrBadSubscription = errors.New("nats: invalid subscription") + ErrTypeSubscription = errors.New("nats: invalid subscription type") + ErrBadSubject = errors.New("nats: invalid subject") + ErrSlowConsumer = errors.New("nats: slow consumer, messages dropped") + ErrTimeout = errors.New("nats: timeout") + ErrBadTimeout = errors.New("nats: timeout invalid") + ErrAuthorization = errors.New("nats: authorization violation") + ErrNoServers = errors.New("nats: no servers available for connection") + ErrJsonParse = errors.New("nats: connect message, json parse error") + ErrChanArg = errors.New("nats: argument needs to be a channel type") + ErrMaxPayload = errors.New("nats: maximum payload exceeded") + ErrMaxMessages = errors.New("nats: maximum messages delivered") + ErrSyncSubRequired = errors.New("nats: illegal call on an async subscription") + ErrMultipleTLSConfigs = errors.New("nats: multiple tls.Configs not allowed") + ErrNoInfoReceived = errors.New("nats: protocol exception, INFO not received") + ErrReconnectBufExceeded = errors.New("nats: outbound buffer limit exceeded") + ErrInvalidConnection = errors.New("nats: invalid connection") + ErrInvalidMsg = errors.New("nats: invalid message or message nil") + ErrInvalidArg = errors.New("nats: invalid argument") + ErrInvalidContext = errors.New("nats: invalid context") + ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) +) + +// GetDefaultOptions returns default configuration options for the client. +func GetDefaultOptions() Options { + return Options{ + AllowReconnect: true, + MaxReconnect: DefaultMaxReconnect, + ReconnectWait: DefaultReconnectWait, + Timeout: DefaultTimeout, + PingInterval: DefaultPingInterval, + MaxPingsOut: DefaultMaxPingOut, + SubChanLen: DefaultMaxChanLen, + ReconnectBufSize: DefaultReconnectBufSize, + } +} + +// DEPRECATED: Use GetDefaultOptions() instead. +// DefaultOptions is not safe for use by multiple clients. +// For details see #308. +var DefaultOptions = GetDefaultOptions() + +// Status represents the state of the connection. +type Status int + +const ( + DISCONNECTED = Status(iota) + CONNECTED + CLOSED + RECONNECTING + CONNECTING +) + +// ConnHandler is used for asynchronous events such as +// disconnected and closed connections. +type ConnHandler func(*Conn) + +// ErrHandler is used to process asynchronous errors encountered +// while processing inbound messages. +type ErrHandler func(*Conn, *Subscription, error) + +// asyncCB is used to preserve order for async callbacks. +type asyncCB func() + +// Option is a function on the options for a connection. +type Option func(*Options) error + +// Options can be used to create a customized connection. +type Options struct { + + // Url represents a single NATS server url to which the client + // will be connecting. If the Servers option is also set, it + // then becomes the first server in the Servers array. + Url string + + // Servers is a configured set of servers which this client + // will use when attempting to connect. + Servers []string + + // NoRandomize configures whether we will randomize the + // server pool. + NoRandomize bool + + // Name is an optional name label which will be sent to the server + // on CONNECT to identify the client. + Name string + + // Verbose signals the server to send an OK ack for commands + // successfully processed by the server. + Verbose bool + + // Pedantic signals the server whether it should be doing further + // validation of subjects. + Pedantic bool + + // Secure enables TLS secure connections that skip server + // verification by default. NOT RECOMMENDED. + Secure bool + + // TLSConfig is a custom TLS configuration to use for secure + // transports. + TLSConfig *tls.Config + + // AllowReconnect enables reconnection logic to be used when we + // encounter a disconnect from the current server. + AllowReconnect bool + + // MaxReconnect sets the number of reconnect attempts that will be + // tried before giving up. If negative, then it will never give up + // trying to reconnect. + MaxReconnect int + + // ReconnectWait sets the time to backoff after attempting a reconnect + // to a server that we were already connected to previously. + ReconnectWait time.Duration + + // Timeout sets the timeout for a Dial operation on a connection. + Timeout time.Duration + + // FlusherTimeout is the maximum time to wait for the flusher loop + // to be able to finish writing to the underlying connection. + FlusherTimeout time.Duration + + // PingInterval is the period at which the client will be sending ping + // commands to the server, disabled if 0 or negative. + PingInterval time.Duration + + // MaxPingsOut is the maximum number of pending ping commands that can + // be awaiting a response before raising an ErrStaleConnection error. + MaxPingsOut int + + // ClosedCB sets the closed handler that is called when a client will + // no longer be connected. + ClosedCB ConnHandler + + // DisconnectedCB sets the disconnected handler that is called + // whenever the connection is disconnected. + DisconnectedCB ConnHandler + + // ReconnectedCB sets the reconnected handler called whenever + // the connection is successfully reconnected. + ReconnectedCB ConnHandler + + // DiscoveredServersCB sets the callback that is invoked whenever a new + // server has joined the cluster. + DiscoveredServersCB ConnHandler + + // AsyncErrorCB sets the async error handler (e.g. slow consumer errors) + AsyncErrorCB ErrHandler + + // ReconnectBufSize is the size of the backing bufio during reconnect. + // Once this has been exhausted publish operations will return an error. + ReconnectBufSize int + + // SubChanLen is the size of the buffered channel used between the socket + // Go routine and the message delivery for SyncSubscriptions. + // NOTE: This does not affect AsyncSubscriptions which are + // dictated by PendingLimits() + SubChanLen int + + // User sets the username to be used when connecting to the server. + User string + + // Password sets the password to be used when connecting to a server. + Password string + + // Token sets the token to be used when connecting to a server. + Token string + + // Dialer allows a custom Dialer when forming connections. + Dialer *net.Dialer + + // UseOldRequestStyle forces the old method of Requests that utilize + // a new Inbox and a new Subscription for each request. + UseOldRequestStyle bool +} + +const ( + // Scratch storage for assembling protocol headers + scratchSize = 512 + + // The size of the bufio reader/writer on top of the socket. + defaultBufSize = 32768 + + // The buffered size of the flush "kick" channel + flushChanSize = 1024 + + // Default server pool size + srvPoolSize = 4 + + // Channel size for the async callback handler. + asyncCBChanSize = 32 + + // NUID size + nuidSize = 22 +) + +// A Conn represents a bare connection to a nats-server. +// It can send and receive []byte payloads. +type Conn struct { + // Keep all members for which we use atomic at the beginning of the + // struct and make sure they are all 64bits (or use padding if necessary). + // atomic.* functions crash on 32bit machines if operand is not aligned + // at 64bit. See https://github.com/golang/go/issues/599 + Statistics + mu sync.Mutex + Opts Options + wg *sync.WaitGroup + url *url.URL + conn net.Conn + srvPool []*srv + urls map[string]struct{} // Keep track of all known URLs (used by processInfo) + bw *bufio.Writer + pending *bytes.Buffer + fch chan struct{} + info serverInfo + ssid int64 + subsMu sync.RWMutex + subs map[int64]*Subscription + ach chan asyncCB + pongs []chan struct{} + scratch [scratchSize]byte + status Status + initc bool // true if the connection is performing the initial connect + err error + ps *parseState + ptmr *time.Timer + pout int + + // New style response handler + respSub string // The wildcard subject + respMux *Subscription // A single response subscription + respMap map[string]chan *Msg // Request map for the response msg channels + respSetup sync.Once // Ensures response subscription occurs once +} + +// A Subscription represents interest in a given subject. +type Subscription struct { + mu sync.Mutex + sid int64 + + // Subject that represents this subscription. This can be different + // than the received subject inside a Msg if this is a wildcard. + Subject string + + // Optional queue group name. If present, all subscriptions with the + // same name will form a distributed queue, and each message will + // only be processed by one member of the group. + Queue string + + delivered uint64 + max uint64 + conn *Conn + mcb MsgHandler + mch chan *Msg + closed bool + sc bool + connClosed bool + + // Type of Subscription + typ SubscriptionType + + // Async linked list + pHead *Msg + pTail *Msg + pCond *sync.Cond + + // Pending stats, async subscriptions, high-speed etc. + pMsgs int + pBytes int + pMsgsMax int + pBytesMax int + pMsgsLimit int + pBytesLimit int + dropped int +} + +// Msg is a structure used by Subscribers and PublishMsg(). +type Msg struct { + Subject string + Reply string + Data []byte + Sub *Subscription + next *Msg +} + +// Tracks various stats received and sent on this connection, +// including counts for messages and bytes. +type Statistics struct { + InMsgs uint64 + OutMsgs uint64 + InBytes uint64 + OutBytes uint64 + Reconnects uint64 +} + +// Tracks individual backend servers. +type srv struct { + url *url.URL + didConnect bool + reconnects int + lastAttempt time.Time + isImplicit bool +} + +type serverInfo struct { + Id string `json:"server_id"` + Host string `json:"host"` + Port uint `json:"port"` + Version string `json:"version"` + AuthRequired bool `json:"auth_required"` + TLSRequired bool `json:"tls_required"` + MaxPayload int64 `json:"max_payload"` + ConnectURLs []string `json:"connect_urls,omitempty"` +} + +const ( + // clientProtoZero is the original client protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + /* clientProtoZero */ _ = iota + // clientProtoInfo signals a client can receive more then the original INFO block. + // This can be used to update clients on other cluster members, etc. + clientProtoInfo +) + +type connectInfo struct { + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + User string `json:"user,omitempty"` + Pass string `json:"pass,omitempty"` + Token string `json:"auth_token,omitempty"` + TLS bool `json:"tls_required"` + Name string `json:"name"` + Lang string `json:"lang"` + Version string `json:"version"` + Protocol int `json:"protocol"` +} + +// MsgHandler is a callback function that processes messages delivered to +// asynchronous subscribers. +type MsgHandler func(msg *Msg) + +// Connect will attempt to connect to the NATS system. +// The url can contain username/password semantics. e.g. nats://derek:pass@localhost:4222 +// Comma separated arrays are also supported, e.g. urlA, urlB. +// Options start with the defaults but can be overridden. +func Connect(url string, options ...Option) (*Conn, error) { + opts := GetDefaultOptions() + opts.Servers = processUrlString(url) + for _, opt := range options { + if err := opt(&opts); err != nil { + return nil, err + } + } + return opts.Connect() +} + +// Options that can be passed to Connect. + +// Name is an Option to set the client name. +func Name(name string) Option { + return func(o *Options) error { + o.Name = name + return nil + } +} + +// Secure is an Option to enable TLS secure connections that skip server verification by default. +// Pass a TLS Configuration for proper TLS. +func Secure(tls ...*tls.Config) Option { + return func(o *Options) error { + o.Secure = true + // Use of variadic just simplifies testing scenarios. We only take the first one. + // fixme(DLC) - Could panic if more than one. Could also do TLS option. + if len(tls) > 1 { + return ErrMultipleTLSConfigs + } + if len(tls) == 1 { + o.TLSConfig = tls[0] + } + return nil + } +} + +// RootCAs is a helper option to provide the RootCAs pool from a list of filenames. If Secure is +// not already set this will set it as well. +func RootCAs(file ...string) Option { + return func(o *Options) error { + pool := x509.NewCertPool() + for _, f := range file { + rootPEM, err := ioutil.ReadFile(f) + if err != nil || rootPEM == nil { + return fmt.Errorf("nats: error loading or parsing rootCA file: %v", err) + } + ok := pool.AppendCertsFromPEM(rootPEM) + if !ok { + return fmt.Errorf("nats: failed to parse root certificate from %q", f) + } + } + if o.TLSConfig == nil { + o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + o.TLSConfig.RootCAs = pool + o.Secure = true + return nil + } +} + +// ClientCert is a helper option to provide the client certificate from a file. If Secure is +// not already set this will set it as well +func ClientCert(certFile, keyFile string) Option { + return func(o *Options) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("nats: error loading client certificate: %v", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("nats: error parsing client certificate: %v", err) + } + if o.TLSConfig == nil { + o.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + o.TLSConfig.Certificates = []tls.Certificate{cert} + o.Secure = true + return nil + } +} + +// NoReconnect is an Option to turn off reconnect behavior. +func NoReconnect() Option { + return func(o *Options) error { + o.AllowReconnect = false + return nil + } +} + +// DontRandomize is an Option to turn off randomizing the server pool. +func DontRandomize() Option { + return func(o *Options) error { + o.NoRandomize = true + return nil + } +} + +// ReconnectWait is an Option to set the wait time between reconnect attempts. +func ReconnectWait(t time.Duration) Option { + return func(o *Options) error { + o.ReconnectWait = t + return nil + } +} + +// MaxReconnects is an Option to set the maximum number of reconnect attempts. +func MaxReconnects(max int) Option { + return func(o *Options) error { + o.MaxReconnect = max + return nil + } +} + +// Timeout is an Option to set the timeout for Dial on a connection. +func Timeout(t time.Duration) Option { + return func(o *Options) error { + o.Timeout = t + return nil + } +} + +// DisconnectHandler is an Option to set the disconnected handler. +func DisconnectHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.DisconnectedCB = cb + return nil + } +} + +// ReconnectHandler is an Option to set the reconnected handler. +func ReconnectHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.ReconnectedCB = cb + return nil + } +} + +// ClosedHandler is an Option to set the closed handler. +func ClosedHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.ClosedCB = cb + return nil + } +} + +// DiscoveredServersHandler is an Option to set the new servers handler. +func DiscoveredServersHandler(cb ConnHandler) Option { + return func(o *Options) error { + o.DiscoveredServersCB = cb + return nil + } +} + +// ErrHandler is an Option to set the async error handler. +func ErrorHandler(cb ErrHandler) Option { + return func(o *Options) error { + o.AsyncErrorCB = cb + return nil + } +} + +// UserInfo is an Option to set the username and password to +// use when not included directly in the URLs. +func UserInfo(user, password string) Option { + return func(o *Options) error { + o.User = user + o.Password = password + return nil + } +} + +// Token is an Option to set the token to use when not included +// directly in the URLs. +func Token(token string) Option { + return func(o *Options) error { + o.Token = token + return nil + } +} + +// Dialer is an Option to set the dialer which will be used when +// attempting to establish a connection. +func Dialer(dialer *net.Dialer) Option { + return func(o *Options) error { + o.Dialer = dialer + return nil + } +} + +// UseOldRequestyStyle is an Option to force usage of the old Request style. +func UseOldRequestStyle() Option { + return func(o *Options) error { + o.UseOldRequestStyle = true + return nil + } +} + +// Handler processing + +// SetDisconnectHandler will set the disconnect event handler. +func (nc *Conn) SetDisconnectHandler(dcb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.DisconnectedCB = dcb +} + +// SetReconnectHandler will set the reconnect event handler. +func (nc *Conn) SetReconnectHandler(rcb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.ReconnectedCB = rcb +} + +// SetDiscoveredServersHandler will set the discovered servers handler. +func (nc *Conn) SetDiscoveredServersHandler(dscb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.DiscoveredServersCB = dscb +} + +// SetClosedHandler will set the reconnect event handler. +func (nc *Conn) SetClosedHandler(cb ConnHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.ClosedCB = cb +} + +// SetErrHandler will set the async error handler. +func (nc *Conn) SetErrorHandler(cb ErrHandler) { + if nc == nil { + return + } + nc.mu.Lock() + defer nc.mu.Unlock() + nc.Opts.AsyncErrorCB = cb +} + +// Process the url string argument to Connect. Return an array of +// urls, even if only one. +func processUrlString(url string) []string { + urls := strings.Split(url, ",") + for i, s := range urls { + urls[i] = strings.TrimSpace(s) + } + return urls +} + +// Connect will attempt to connect to a NATS server with multiple options. +func (o Options) Connect() (*Conn, error) { + nc := &Conn{Opts: o} + + // Some default options processing. + if nc.Opts.MaxPingsOut == 0 { + nc.Opts.MaxPingsOut = DefaultMaxPingOut + } + // Allow old default for channel length to work correctly. + if nc.Opts.SubChanLen == 0 { + nc.Opts.SubChanLen = DefaultMaxChanLen + } + // Default ReconnectBufSize + if nc.Opts.ReconnectBufSize == 0 { + nc.Opts.ReconnectBufSize = DefaultReconnectBufSize + } + // Ensure that Timeout is not 0 + if nc.Opts.Timeout == 0 { + nc.Opts.Timeout = DefaultTimeout + } + + // Allow custom Dialer for connecting using DialTimeout by default + if nc.Opts.Dialer == nil { + nc.Opts.Dialer = &net.Dialer{ + Timeout: nc.Opts.Timeout, + } + } + + if err := nc.setupServerPool(); err != nil { + return nil, err + } + + // Create the async callback channel. + nc.ach = make(chan asyncCB, asyncCBChanSize) + + if err := nc.connect(); err != nil { + return nil, err + } + + // Spin up the async cb dispatcher on success + go nc.asyncDispatch() + + return nc, nil +} + +const ( + _CRLF_ = "\r\n" + _EMPTY_ = "" + _SPC_ = " " + _PUB_P_ = "PUB " +) + +const ( + _OK_OP_ = "+OK" + _ERR_OP_ = "-ERR" + _PONG_OP_ = "PONG" + _INFO_OP_ = "INFO" +) + +const ( + conProto = "CONNECT %s" + _CRLF_ + pingProto = "PING" + _CRLF_ + pongProto = "PONG" + _CRLF_ + subProto = "SUB %s %s %d" + _CRLF_ + unsubProto = "UNSUB %d %s" + _CRLF_ + okProto = _OK_OP_ + _CRLF_ +) + +// Return the currently selected server +func (nc *Conn) currentServer() (int, *srv) { + for i, s := range nc.srvPool { + if s == nil { + continue + } + if s.url == nc.url { + return i, s + } + } + return -1, nil +} + +// Pop the current server and put onto the end of the list. Select head of list as long +// as number of reconnect attempts under MaxReconnect. +func (nc *Conn) selectNextServer() (*srv, error) { + i, s := nc.currentServer() + if i < 0 { + return nil, ErrNoServers + } + sp := nc.srvPool + num := len(sp) + copy(sp[i:num-1], sp[i+1:num]) + maxReconnect := nc.Opts.MaxReconnect + if maxReconnect < 0 || s.reconnects < maxReconnect { + nc.srvPool[num-1] = s + } else { + nc.srvPool = sp[0 : num-1] + } + if len(nc.srvPool) <= 0 { + nc.url = nil + return nil, ErrNoServers + } + nc.url = nc.srvPool[0].url + return nc.srvPool[0], nil +} + +// Will assign the correct server to the nc.Url +func (nc *Conn) pickServer() error { + nc.url = nil + if len(nc.srvPool) <= 0 { + return ErrNoServers + } + for _, s := range nc.srvPool { + if s != nil { + nc.url = s.url + return nil + } + } + return ErrNoServers +} + +const tlsScheme = "tls" + +// Create the server pool using the options given. +// We will place a Url option first, followed by any +// Server Options. We will randomize the server pool unless +// the NoRandomize flag is set. +func (nc *Conn) setupServerPool() error { + nc.srvPool = make([]*srv, 0, srvPoolSize) + nc.urls = make(map[string]struct{}, srvPoolSize) + + // Create srv objects from each url string in nc.Opts.Servers + // and add them to the pool + for _, urlString := range nc.Opts.Servers { + if err := nc.addURLToPool(urlString, false); err != nil { + return err + } + } + + // Randomize if allowed to + if !nc.Opts.NoRandomize { + nc.shufflePool() + } + + // Normally, if this one is set, Options.Servers should not be, + // but we always allowed that, so continue to do so. + if nc.Opts.Url != _EMPTY_ { + // Add to the end of the array + if err := nc.addURLToPool(nc.Opts.Url, false); err != nil { + return err + } + // Then swap it with first to guarantee that Options.Url is tried first. + last := len(nc.srvPool) - 1 + if last > 0 { + nc.srvPool[0], nc.srvPool[last] = nc.srvPool[last], nc.srvPool[0] + } + } else if len(nc.srvPool) <= 0 { + // Place default URL if pool is empty. + if err := nc.addURLToPool(DefaultURL, false); err != nil { + return err + } + } + + // Check for Scheme hint to move to TLS mode. + for _, srv := range nc.srvPool { + if srv.url.Scheme == tlsScheme { + // FIXME(dlc), this is for all in the pool, should be case by case. + nc.Opts.Secure = true + if nc.Opts.TLSConfig == nil { + nc.Opts.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + } + } + + return nc.pickServer() +} + +// addURLToPool adds an entry to the server pool +func (nc *Conn) addURLToPool(sURL string, implicit bool) error { + u, err := url.Parse(sURL) + if err != nil { + return err + } + s := &srv{url: u, isImplicit: implicit} + nc.srvPool = append(nc.srvPool, s) + nc.urls[u.Host] = struct{}{} + return nil +} + +// shufflePool swaps randomly elements in the server pool +func (nc *Conn) shufflePool() { + if len(nc.srvPool) <= 1 { + return + } + source := rand.NewSource(time.Now().UnixNano()) + r := rand.New(source) + for i := range nc.srvPool { + j := r.Intn(i + 1) + nc.srvPool[i], nc.srvPool[j] = nc.srvPool[j], nc.srvPool[i] + } +} + +// createConn will connect to the server and wrap the appropriate +// bufio structures. It will do the right thing when an existing +// connection is in place. +func (nc *Conn) createConn() (err error) { + if nc.Opts.Timeout < 0 { + return ErrBadTimeout + } + if _, cur := nc.currentServer(); cur == nil { + return ErrNoServers + } else { + cur.lastAttempt = time.Now() + } + + dialer := nc.Opts.Dialer + nc.conn, err = dialer.Dial("tcp", nc.url.Host) + if err != nil { + return err + } + + // No clue why, but this stalls and kills performance on Mac (Mavericks). + // https://code.google.com/p/go/issues/detail?id=6930 + //if ip, ok := nc.conn.(*net.TCPConn); ok { + // ip.SetReadBuffer(defaultBufSize) + //} + + if nc.pending != nil && nc.bw != nil { + // Move to pending buffer. + nc.bw.Flush() + } + nc.bw = bufio.NewWriterSize(nc.conn, defaultBufSize) + return nil +} + +// makeTLSConn will wrap an existing Conn using TLS +func (nc *Conn) makeTLSConn() { + // Allow the user to configure their own tls.Config structure, otherwise + // default to InsecureSkipVerify. + // TODO(dlc) - We should make the more secure version the default. + if nc.Opts.TLSConfig != nil { + tlsCopy := util.CloneTLSConfig(nc.Opts.TLSConfig) + // If its blank we will override it with the current host + if tlsCopy.ServerName == _EMPTY_ { + h, _, _ := net.SplitHostPort(nc.url.Host) + tlsCopy.ServerName = h + } + nc.conn = tls.Client(nc.conn, tlsCopy) + } else { + nc.conn = tls.Client(nc.conn, &tls.Config{InsecureSkipVerify: true}) + } + conn := nc.conn.(*tls.Conn) + conn.Handshake() + nc.bw = bufio.NewWriterSize(nc.conn, defaultBufSize) +} + +// waitForExits will wait for all socket watcher Go routines to +// be shutdown before proceeding. +func (nc *Conn) waitForExits(wg *sync.WaitGroup) { + // Kick old flusher forcefully. + select { + case nc.fch <- struct{}{}: + default: + } + + // Wait for any previous go routines. + if wg != nil { + wg.Wait() + } +} + +// spinUpGoRoutines will launch the Go routines responsible for +// reading and writing to the socket. This will be launched via a +// go routine itself to release any locks that may be held. +// We also use a WaitGroup to make sure we only start them on a +// reconnect when the previous ones have exited. +func (nc *Conn) spinUpGoRoutines() { + // Make sure everything has exited. + nc.waitForExits(nc.wg) + + // Create a new waitGroup instance for this run. + nc.wg = &sync.WaitGroup{} + // We will wait on both. + nc.wg.Add(2) + + // Spin up the readLoop and the socket flusher. + go nc.readLoop(nc.wg) + go nc.flusher(nc.wg) + + nc.mu.Lock() + if nc.Opts.PingInterval > 0 { + if nc.ptmr == nil { + nc.ptmr = time.AfterFunc(nc.Opts.PingInterval, nc.processPingTimer) + } else { + nc.ptmr.Reset(nc.Opts.PingInterval) + } + } + nc.mu.Unlock() +} + +// Report the connected server's Url +func (nc *Conn) ConnectedUrl() string { + if nc == nil { + return _EMPTY_ + } + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.status != CONNECTED { + return _EMPTY_ + } + return nc.url.String() +} + +// Report the connected server's Id +func (nc *Conn) ConnectedServerId() string { + if nc == nil { + return _EMPTY_ + } + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.status != CONNECTED { + return _EMPTY_ + } + return nc.info.Id +} + +// Low level setup for structs, etc +func (nc *Conn) setup() { + nc.subs = make(map[int64]*Subscription) + nc.pongs = make([]chan struct{}, 0, 8) + + nc.fch = make(chan struct{}, flushChanSize) + + // Setup scratch outbound buffer for PUB + pub := nc.scratch[:len(_PUB_P_)] + copy(pub, _PUB_P_) +} + +// Process a connected connection and initialize properly. +func (nc *Conn) processConnectInit() error { + + // Set out deadline for the whole connect process + nc.conn.SetDeadline(time.Now().Add(nc.Opts.Timeout)) + defer nc.conn.SetDeadline(time.Time{}) + + // Set our status to connecting. + nc.status = CONNECTING + + // Process the INFO protocol received from the server + err := nc.processExpectedInfo() + if err != nil { + return err + } + + // Send the CONNECT protocol along with the initial PING protocol. + // Wait for the PONG response (or any error that we get from the server). + err = nc.sendConnect() + if err != nil { + return err + } + + // Reset the number of PING sent out + nc.pout = 0 + + go nc.spinUpGoRoutines() + + return nil +} + +// Main connect function. Will connect to the nats-server +func (nc *Conn) connect() error { + var returnedErr error + + // Create actual socket connection + // For first connect we walk all servers in the pool and try + // to connect immediately. + nc.mu.Lock() + nc.initc = true + // The pool may change inside the loop iteration due to INFO protocol. + for i := 0; i < len(nc.srvPool); i++ { + nc.url = nc.srvPool[i].url + + if err := nc.createConn(); err == nil { + // This was moved out of processConnectInit() because + // that function is now invoked from doReconnect() too. + nc.setup() + + err = nc.processConnectInit() + + if err == nil { + nc.srvPool[i].didConnect = true + nc.srvPool[i].reconnects = 0 + returnedErr = nil + break + } else { + returnedErr = err + nc.mu.Unlock() + nc.close(DISCONNECTED, false) + nc.mu.Lock() + nc.url = nil + } + } else { + // Cancel out default connection refused, will trigger the + // No servers error conditional + if matched, _ := regexp.Match(`connection refused`, []byte(err.Error())); matched { + returnedErr = nil + } + } + } + nc.initc = false + defer nc.mu.Unlock() + + if returnedErr == nil && nc.status != CONNECTED { + returnedErr = ErrNoServers + } + return returnedErr +} + +// This will check to see if the connection should be +// secure. This can be dictated from either end and should +// only be called after the INIT protocol has been received. +func (nc *Conn) checkForSecure() error { + // Check to see if we need to engage TLS + o := nc.Opts + + // Check for mismatch in setups + if o.Secure && !nc.info.TLSRequired { + return ErrSecureConnWanted + } else if nc.info.TLSRequired && !o.Secure { + return ErrSecureConnRequired + } + + // Need to rewrap with bufio + if o.Secure { + nc.makeTLSConn() + } + return nil +} + +// processExpectedInfo will look for the expected first INFO message +// sent when a connection is established. The lock should be held entering. +func (nc *Conn) processExpectedInfo() error { + + c := &control{} + + // Read the protocol + err := nc.readOp(c) + if err != nil { + return err + } + + // The nats protocol should send INFO first always. + if c.op != _INFO_OP_ { + return ErrNoInfoReceived + } + + // Parse the protocol + if err := nc.processInfo(c.args); err != nil { + return err + } + + return nc.checkForSecure() +} + +// Sends a protocol control message by queuing into the bufio writer +// and kicking the flush Go routine. These writes are protected. +func (nc *Conn) sendProto(proto string) { + nc.mu.Lock() + nc.bw.WriteString(proto) + nc.kickFlusher() + nc.mu.Unlock() +} + +// Generate a connect protocol message, issuing user/password if +// applicable. The lock is assumed to be held upon entering. +func (nc *Conn) connectProto() (string, error) { + o := nc.Opts + var user, pass, token string + u := nc.url.User + if u != nil { + // if no password, assume username is authToken + if _, ok := u.Password(); !ok { + token = u.Username() + } else { + user = u.Username() + pass, _ = u.Password() + } + } else { + // Take from options (pssibly all empty strings) + user = nc.Opts.User + pass = nc.Opts.Password + token = nc.Opts.Token + } + cinfo := connectInfo{o.Verbose, o.Pedantic, + user, pass, token, + o.Secure, o.Name, LangString, Version, clientProtoInfo} + b, err := json.Marshal(cinfo) + if err != nil { + return _EMPTY_, ErrJsonParse + } + return fmt.Sprintf(conProto, b), nil +} + +// normalizeErr removes the prefix -ERR, trim spaces and remove the quotes. +func normalizeErr(line string) string { + s := strings.ToLower(strings.TrimSpace(strings.TrimPrefix(line, _ERR_OP_))) + s = strings.TrimLeft(strings.TrimRight(s, "'"), "'") + return s +} + +// Send a connect protocol message to the server, issue user/password if +// applicable. Will wait for a flush to return from the server for error +// processing. +func (nc *Conn) sendConnect() error { + + // Construct the CONNECT protocol string + cProto, err := nc.connectProto() + if err != nil { + return err + } + + // Write the protocol into the buffer + _, err = nc.bw.WriteString(cProto) + if err != nil { + return err + } + + // Add to the buffer the PING protocol + _, err = nc.bw.WriteString(pingProto) + if err != nil { + return err + } + + // Flush the buffer + err = nc.bw.Flush() + if err != nil { + return err + } + + // Now read the response from the server. + br := bufio.NewReaderSize(nc.conn, defaultBufSize) + line, err := br.ReadString('\n') + if err != nil { + return err + } + + // If opts.Verbose is set, handle +OK + if nc.Opts.Verbose && line == okProto { + // Read the rest now... + line, err = br.ReadString('\n') + if err != nil { + return err + } + } + + // We expect a PONG + if line != pongProto { + // But it could be something else, like -ERR + + // Since we no longer use ReadLine(), trim the trailing "\r\n" + line = strings.TrimRight(line, "\r\n") + + // If it's a server error... + if strings.HasPrefix(line, _ERR_OP_) { + // Remove -ERR, trim spaces and quotes, and convert to lower case. + line = normalizeErr(line) + return errors.New("nats: " + line) + } + + // Notify that we got an unexpected protocol. + return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, line) + } + + // This is where we are truly connected. + nc.status = CONNECTED + + return nil +} + +// A control protocol line. +type control struct { + op, args string +} + +// Read a control line and process the intended op. +func (nc *Conn) readOp(c *control) error { + br := bufio.NewReaderSize(nc.conn, defaultBufSize) + line, err := br.ReadString('\n') + if err != nil { + return err + } + parseControl(line, c) + return nil +} + +// Parse a control line from the server. +func parseControl(line string, c *control) { + toks := strings.SplitN(line, _SPC_, 2) + if len(toks) == 1 { + c.op = strings.TrimSpace(toks[0]) + c.args = _EMPTY_ + } else if len(toks) == 2 { + c.op, c.args = strings.TrimSpace(toks[0]), strings.TrimSpace(toks[1]) + } else { + c.op = _EMPTY_ + } +} + +// flushReconnectPending will push the pending items that were +// gathered while we were in a RECONNECTING state to the socket. +func (nc *Conn) flushReconnectPendingItems() { + if nc.pending == nil { + return + } + if nc.pending.Len() > 0 { + nc.bw.Write(nc.pending.Bytes()) + } +} + +// Try to reconnect using the option parameters. +// This function assumes we are allowed to reconnect. +func (nc *Conn) doReconnect() { + // We want to make sure we have the other watchers shutdown properly + // here before we proceed past this point. + nc.mu.Lock() + wg := nc.wg + nc.mu.Unlock() + nc.waitForExits(wg) + + // FIXME(dlc) - We have an issue here if we have + // outstanding flush points (pongs) and they were not + // sent out, but are still in the pipe. + + // Hold the lock manually and release where needed below, + // can't do defer here. + nc.mu.Lock() + + // Clear any queued pongs, e.g. pending flush calls. + nc.clearPendingFlushCalls() + + // Clear any errors. + nc.err = nil + + // Perform appropriate callback if needed for a disconnect. + if nc.Opts.DisconnectedCB != nil { + nc.ach <- func() { nc.Opts.DisconnectedCB(nc) } + } + + for len(nc.srvPool) > 0 { + cur, err := nc.selectNextServer() + if err != nil { + nc.err = err + break + } + + sleepTime := int64(0) + + // Sleep appropriate amount of time before the + // connection attempt if connecting to same server + // we just got disconnected from.. + if time.Since(cur.lastAttempt) < nc.Opts.ReconnectWait { + sleepTime = int64(nc.Opts.ReconnectWait - time.Since(cur.lastAttempt)) + } + + // On Windows, createConn() will take more than a second when no + // server is running at that address. So it could be that the + // time elapsed between reconnect attempts is always > than + // the set option. Release the lock to give a chance to a parallel + // nc.Close() to break the loop. + nc.mu.Unlock() + if sleepTime <= 0 { + runtime.Gosched() + } else { + time.Sleep(time.Duration(sleepTime)) + } + nc.mu.Lock() + + // Check if we have been closed first. + if nc.isClosed() { + break + } + + // Mark that we tried a reconnect + cur.reconnects++ + + // Try to create a new connection + err = nc.createConn() + + // Not yet connected, retry... + // Continue to hold the lock + if err != nil { + nc.err = nil + continue + } + + // We are reconnected + nc.Reconnects++ + + // Process connect logic + if nc.err = nc.processConnectInit(); nc.err != nil { + nc.status = RECONNECTING + continue + } + + // Clear out server stats for the server we connected to.. + cur.didConnect = true + cur.reconnects = 0 + + // Send existing subscription state + nc.resendSubscriptions() + + // Now send off and clear pending buffer + nc.flushReconnectPendingItems() + + // Flush the buffer + nc.err = nc.bw.Flush() + if nc.err != nil { + nc.status = RECONNECTING + continue + } + + // Done with the pending buffer + nc.pending = nil + + // This is where we are truly connected. + nc.status = CONNECTED + + // Queue up the reconnect callback. + if nc.Opts.ReconnectedCB != nil { + nc.ach <- func() { nc.Opts.ReconnectedCB(nc) } + } + + // Release lock here, we will return below. + nc.mu.Unlock() + + // Make sure to flush everything + nc.Flush() + + return + } + + // Call into close.. We have no servers left.. + if nc.err == nil { + nc.err = ErrNoServers + } + nc.mu.Unlock() + nc.Close() +} + +// processOpErr handles errors from reading or parsing the protocol. +// The lock should not be held entering this function. +func (nc *Conn) processOpErr(err error) { + nc.mu.Lock() + if nc.isConnecting() || nc.isClosed() || nc.isReconnecting() { + nc.mu.Unlock() + return + } + + if nc.Opts.AllowReconnect && nc.status == CONNECTED { + // Set our new status + nc.status = RECONNECTING + if nc.ptmr != nil { + nc.ptmr.Stop() + } + if nc.conn != nil { + nc.bw.Flush() + nc.conn.Close() + nc.conn = nil + } + + // Create a new pending buffer to underpin the bufio Writer while + // we are reconnecting. + nc.pending = &bytes.Buffer{} + nc.bw = bufio.NewWriterSize(nc.pending, nc.Opts.ReconnectBufSize) + + go nc.doReconnect() + nc.mu.Unlock() + return + } + + nc.status = DISCONNECTED + nc.err = err + nc.mu.Unlock() + nc.Close() +} + +// Marker to close the channel to kick out the Go routine. +func (nc *Conn) closeAsyncFunc() asyncCB { + return func() { + nc.mu.Lock() + if nc.ach != nil { + close(nc.ach) + nc.ach = nil + } + nc.mu.Unlock() + } +} + +// asyncDispatch is responsible for calling any async callbacks +func (nc *Conn) asyncDispatch() { + // snapshot since they can change from underneath of us. + nc.mu.Lock() + ach := nc.ach + nc.mu.Unlock() + + // Loop on the channel and process async callbacks. + for { + if f, ok := <-ach; !ok { + return + } else { + f() + } + } +} + +// readLoop() will sit on the socket reading and processing the +// protocol from the server. It will dispatch appropriately based +// on the op type. +func (nc *Conn) readLoop(wg *sync.WaitGroup) { + // Release the wait group on exit + defer wg.Done() + + // Create a parseState if needed. + nc.mu.Lock() + if nc.ps == nil { + nc.ps = &parseState{} + } + nc.mu.Unlock() + + // Stack based buffer. + b := make([]byte, defaultBufSize) + + for { + // FIXME(dlc): RWLock here? + nc.mu.Lock() + sb := nc.isClosed() || nc.isReconnecting() + if sb { + nc.ps = &parseState{} + } + conn := nc.conn + nc.mu.Unlock() + + if sb || conn == nil { + break + } + + n, err := conn.Read(b) + if err != nil { + nc.processOpErr(err) + break + } + + if err := nc.parse(b[:n]); err != nil { + nc.processOpErr(err) + break + } + } + // Clear the parseState here.. + nc.mu.Lock() + nc.ps = nil + nc.mu.Unlock() +} + +// waitForMsgs waits on the conditional shared with readLoop and processMsg. +// It is used to deliver messages to asynchronous subscribers. +func (nc *Conn) waitForMsgs(s *Subscription) { + var closed bool + var delivered, max uint64 + + for { + s.mu.Lock() + if s.pHead == nil && !s.closed { + s.pCond.Wait() + } + // Pop the msg off the list + m := s.pHead + if m != nil { + s.pHead = m.next + if s.pHead == nil { + s.pTail = nil + } + s.pMsgs-- + s.pBytes -= len(m.Data) + } + mcb := s.mcb + max = s.max + closed = s.closed + if !s.closed { + s.delivered++ + delivered = s.delivered + } + s.mu.Unlock() + + if closed { + break + } + + // Deliver the message. + if m != nil && (max == 0 || delivered <= max) { + mcb(m) + } + // If we have hit the max for delivered msgs, remove sub. + if max > 0 && delivered >= max { + nc.mu.Lock() + nc.removeSub(s) + nc.mu.Unlock() + break + } + } +} + +// processMsg is called by parse and will place the msg on the +// appropriate channel/pending queue for processing. If the channel is full, +// or the pending queue is over the pending limits, the connection is +// considered a slow consumer. +func (nc *Conn) processMsg(data []byte) { + // Don't lock the connection to avoid server cutting us off if the + // flusher is holding the connection lock, trying to send to the server + // that is itself trying to send data to us. + nc.subsMu.RLock() + + // Stats + nc.InMsgs++ + nc.InBytes += uint64(len(data)) + + sub := nc.subs[nc.ps.ma.sid] + if sub == nil { + nc.subsMu.RUnlock() + return + } + + // Copy them into string + subj := string(nc.ps.ma.subject) + reply := string(nc.ps.ma.reply) + + // Doing message create outside of the sub's lock to reduce contention. + // It's possible that we end-up not using the message, but that's ok. + + // FIXME(dlc): Need to copy, should/can do COW? + msgPayload := make([]byte, len(data)) + copy(msgPayload, data) + + // FIXME(dlc): Should we recycle these containers? + m := &Msg{Data: msgPayload, Subject: subj, Reply: reply, Sub: sub} + + sub.mu.Lock() + + // Subscription internal stats (applicable only for non ChanSubscription's) + if sub.typ != ChanSubscription { + sub.pMsgs++ + if sub.pMsgs > sub.pMsgsMax { + sub.pMsgsMax = sub.pMsgs + } + sub.pBytes += len(m.Data) + if sub.pBytes > sub.pBytesMax { + sub.pBytesMax = sub.pBytes + } + + // Check for a Slow Consumer + if (sub.pMsgsLimit > 0 && sub.pMsgs > sub.pMsgsLimit) || + (sub.pBytesLimit > 0 && sub.pBytes > sub.pBytesLimit) { + goto slowConsumer + } + } + + // We have two modes of delivery. One is the channel, used by channel + // subscribers and syncSubscribers, the other is a linked list for async. + if sub.mch != nil { + select { + case sub.mch <- m: + default: + goto slowConsumer + } + } else { + // Push onto the async pList + if sub.pHead == nil { + sub.pHead = m + sub.pTail = m + sub.pCond.Signal() + } else { + sub.pTail.next = m + sub.pTail = m + } + } + + // Clear SlowConsumer status. + sub.sc = false + + sub.mu.Unlock() + nc.subsMu.RUnlock() + return + +slowConsumer: + sub.dropped++ + sc := !sub.sc + sub.sc = true + // Undo stats from above + if sub.typ != ChanSubscription { + sub.pMsgs-- + sub.pBytes -= len(m.Data) + } + sub.mu.Unlock() + nc.subsMu.RUnlock() + if sc { + // Now we need connection's lock and we may end-up in the situation + // that we were trying to avoid, except that in this case, the client + // is already experiencing client-side slow consumer situation. + nc.mu.Lock() + nc.err = ErrSlowConsumer + if nc.Opts.AsyncErrorCB != nil { + nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) } + } + nc.mu.Unlock() + } +} + +// processPermissionsViolation is called when the server signals a subject +// permissions violation on either publish or subscribe. +func (nc *Conn) processPermissionsViolation(err string) { + nc.mu.Lock() + // create error here so we can pass it as a closure to the async cb dispatcher. + e := errors.New("nats: " + err) + nc.err = e + if nc.Opts.AsyncErrorCB != nil { + nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, nil, e) } + } + nc.mu.Unlock() +} + +// processAuthorizationViolation is called when the server signals a user +// authorization violation. +func (nc *Conn) processAuthorizationViolation(err string) { + nc.mu.Lock() + nc.err = ErrAuthorization + if nc.Opts.AsyncErrorCB != nil { + nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, nil, ErrAuthorization) } + } + nc.mu.Unlock() +} + +// flusher is a separate Go routine that will process flush requests for the write +// bufio. This allows coalescing of writes to the underlying socket. +func (nc *Conn) flusher(wg *sync.WaitGroup) { + // Release the wait group + defer wg.Done() + + // snapshot the bw and conn since they can change from underneath of us. + nc.mu.Lock() + bw := nc.bw + conn := nc.conn + fch := nc.fch + flusherTimeout := nc.Opts.FlusherTimeout + nc.mu.Unlock() + + if conn == nil || bw == nil { + return + } + + for { + if _, ok := <-fch; !ok { + return + } + nc.mu.Lock() + + // Check to see if we should bail out. + if !nc.isConnected() || nc.isConnecting() || bw != nc.bw || conn != nc.conn { + nc.mu.Unlock() + return + } + if bw.Buffered() > 0 { + // Allow customizing how long we should wait for a flush to be done + // to prevent unhealthy connections blocking the client for too long. + if flusherTimeout > 0 { + conn.SetWriteDeadline(time.Now().Add(flusherTimeout)) + } + + if err := bw.Flush(); err != nil { + if nc.err == nil { + nc.err = err + } + } + conn.SetWriteDeadline(time.Time{}) + } + nc.mu.Unlock() + } +} + +// processPing will send an immediate pong protocol response to the +// server. The server uses this mechanism to detect dead clients. +func (nc *Conn) processPing() { + nc.sendProto(pongProto) +} + +// processPong is used to process responses to the client's ping +// messages. We use pings for the flush mechanism as well. +func (nc *Conn) processPong() { + var ch chan struct{} + + nc.mu.Lock() + if len(nc.pongs) > 0 { + ch = nc.pongs[0] + nc.pongs = nc.pongs[1:] + } + nc.pout = 0 + nc.mu.Unlock() + if ch != nil { + ch <- struct{}{} + } +} + +// processOK is a placeholder for processing OK messages. +func (nc *Conn) processOK() { + // do nothing +} + +// processInfo is used to parse the info messages sent +// from the server. +// This function may update the server pool. +func (nc *Conn) processInfo(info string) error { + if info == _EMPTY_ { + return nil + } + if err := json.Unmarshal([]byte(info), &nc.info); err != nil { + return err + } + urls := nc.info.ConnectURLs + if len(urls) > 0 { + added := false + // If randomization is allowed, shuffle the received array, not the + // entire pool. We want to preserve the pool's order up to this point + // (this would otherwise be problematic for the (re)connect loop). + if !nc.Opts.NoRandomize { + for i := range urls { + j := rand.Intn(i + 1) + urls[i], urls[j] = urls[j], urls[i] + } + } + for _, curl := range urls { + if _, present := nc.urls[curl]; !present { + if err := nc.addURLToPool(fmt.Sprintf("nats://%s", curl), true); err != nil { + continue + } + added = true + } + } + if added && !nc.initc && nc.Opts.DiscoveredServersCB != nil { + nc.ach <- func() { nc.Opts.DiscoveredServersCB(nc) } + } + } + return nil +} + +// processAsyncInfo does the same than processInfo, but is called +// from the parser. Calls processInfo under connection's lock +// protection. +func (nc *Conn) processAsyncInfo(info []byte) { + nc.mu.Lock() + // Ignore errors, we will simply not update the server pool... + nc.processInfo(string(info)) + nc.mu.Unlock() +} + +// LastError reports the last error encountered via the connection. +// It can be used reliably within ClosedCB in order to find out reason +// why connection was closed for example. +func (nc *Conn) LastError() error { + if nc == nil { + return ErrInvalidConnection + } + nc.mu.Lock() + err := nc.err + nc.mu.Unlock() + return err +} + +// processErr processes any error messages from the server and +// sets the connection's lastError. +func (nc *Conn) processErr(e string) { + // Trim, remove quotes, convert to lower case. + e = normalizeErr(e) + + // FIXME(dlc) - process Slow Consumer signals special. + if e == STALE_CONNECTION { + nc.processOpErr(ErrStaleConnection) + } else if strings.HasPrefix(e, PERMISSIONS_ERR) { + nc.processPermissionsViolation(e) + } else if strings.HasPrefix(e, AUTHORIZATION_ERR) { + nc.processAuthorizationViolation(e) + } else { + nc.mu.Lock() + nc.err = errors.New("nats: " + e) + nc.mu.Unlock() + nc.Close() + } +} + +// kickFlusher will send a bool on a channel to kick the +// flush Go routine to flush data to the server. +func (nc *Conn) kickFlusher() { + if nc.bw != nil { + select { + case nc.fch <- struct{}{}: + default: + } + } +} + +// Publish publishes the data argument to the given subject. The data +// argument is left untouched and needs to be correctly interpreted on +// the receiver. +func (nc *Conn) Publish(subj string, data []byte) error { + return nc.publish(subj, _EMPTY_, data) +} + +// PublishMsg publishes the Msg structure, which includes the +// Subject, an optional Reply and an optional Data field. +func (nc *Conn) PublishMsg(m *Msg) error { + if m == nil { + return ErrInvalidMsg + } + return nc.publish(m.Subject, m.Reply, m.Data) +} + +// PublishRequest will perform a Publish() excpecting a response on the +// reply subject. Use Request() for automatically waiting for a response +// inline. +func (nc *Conn) PublishRequest(subj, reply string, data []byte) error { + return nc.publish(subj, reply, data) +} + +// Used for handrolled itoa +const digits = "0123456789" + +// publish is the internal function to publish messages to a nats-server. +// Sends a protocol data message by queuing into the bufio writer +// and kicking the flush go routine. These writes should be protected. +func (nc *Conn) publish(subj, reply string, data []byte) error { + if nc == nil { + return ErrInvalidConnection + } + if subj == "" { + return ErrBadSubject + } + nc.mu.Lock() + + // Proactively reject payloads over the threshold set by server. + msgSize := int64(len(data)) + if msgSize > nc.info.MaxPayload { + nc.mu.Unlock() + return ErrMaxPayload + } + + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + + // Check if we are reconnecting, and if so check if + // we have exceeded our reconnect outbound buffer limits. + if nc.isReconnecting() { + // Flush to underlying buffer. + nc.bw.Flush() + // Check if we are over + if nc.pending.Len() >= nc.Opts.ReconnectBufSize { + nc.mu.Unlock() + return ErrReconnectBufExceeded + } + } + + msgh := nc.scratch[:len(_PUB_P_)] + msgh = append(msgh, subj...) + msgh = append(msgh, ' ') + if reply != "" { + msgh = append(msgh, reply...) + msgh = append(msgh, ' ') + } + + // We could be smarter here, but simple loop is ok, + // just avoid strconv in fast path + // FIXME(dlc) - Find a better way here. + // msgh = strconv.AppendInt(msgh, int64(len(data)), 10) + + var b [12]byte + var i = len(b) + if len(data) > 0 { + for l := len(data); l > 0; l /= 10 { + i -= 1 + b[i] = digits[l%10] + } + } else { + i -= 1 + b[i] = digits[0] + } + + msgh = append(msgh, b[i:]...) + msgh = append(msgh, _CRLF_...) + + _, err := nc.bw.Write(msgh) + if err == nil { + _, err = nc.bw.Write(data) + } + if err == nil { + _, err = nc.bw.WriteString(_CRLF_) + } + if err != nil { + nc.mu.Unlock() + return err + } + + nc.OutMsgs++ + nc.OutBytes += uint64(len(data)) + + if len(nc.fch) == 0 { + nc.kickFlusher() + } + nc.mu.Unlock() + return nil +} + +// respHandler is the global response handler. It will look up +// the appropriate channel based on the last token and place +// the message on the channel if possible. +func (nc *Conn) respHandler(m *Msg) { + rt := respToken(m.Subject) + + nc.mu.Lock() + // Just return if closed. + if nc.isClosed() { + nc.mu.Unlock() + return + } + + // Grab mch + mch := nc.respMap[rt] + // Delete the key regardless, one response only. + // FIXME(dlc) - should we track responses past 1 + // just statistics wise? + delete(nc.respMap, rt) + nc.mu.Unlock() + + // Don't block, let Request timeout instead, mch is + // buffered and we should delete the key before a + // second response is processed. + select { + case mch <- m: + default: + return + } +} + +// Create the response subscription we will use for all +// new style responses. This will be on an _INBOX with an +// additional terminal token. The subscription will be on +// a wildcard. Caller is responsible for ensuring this is +// only called once. +func (nc *Conn) createRespMux(respSub string) error { + s, err := nc.Subscribe(respSub, nc.respHandler) + if err != nil { + return err + } + nc.mu.Lock() + nc.respMux = s + nc.mu.Unlock() + return nil +} + +// Request will send a request payload and deliver the response message, +// or an error, including a timeout if no message was received properly. +func (nc *Conn) Request(subj string, data []byte, timeout time.Duration) (*Msg, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + + nc.mu.Lock() + // If user wants the old style. + if nc.Opts.UseOldRequestStyle { + nc.mu.Unlock() + return nc.oldRequest(subj, data, timeout) + } + + // Do setup for the new style. + if nc.respMap == nil { + // _INBOX wildcard + nc.respSub = fmt.Sprintf("%s.*", NewInbox()) + nc.respMap = make(map[string]chan *Msg) + } + // Create literal Inbox and map to a chan msg. + mch := make(chan *Msg, RequestChanLen) + respInbox := nc.newRespInbox() + token := respToken(respInbox) + nc.respMap[token] = mch + createSub := nc.respMux == nil + ginbox := nc.respSub + nc.mu.Unlock() + + if createSub { + // Make sure scoped subscription is setup only once. + var err error + nc.respSetup.Do(func() { err = nc.createRespMux(ginbox) }) + if err != nil { + return nil, err + } + } + + if err := nc.PublishRequest(subj, respInbox, data); err != nil { + return nil, err + } + + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + var ok bool + var msg *Msg + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + case <-t.C: + nc.mu.Lock() + delete(nc.respMap, token) + nc.mu.Unlock() + return nil, ErrTimeout + } + + return msg, nil +} + +// oldRequest will create an Inbox and perform a Request() call +// with the Inbox reply and return the first reply received. +// This is optimized for the case of multiple responses. +func (nc *Conn) oldRequest(subj string, data []byte, timeout time.Duration) (*Msg, error) { + inbox := NewInbox() + ch := make(chan *Msg, RequestChanLen) + + s, err := nc.subscribe(inbox, _EMPTY_, nil, ch) + if err != nil { + return nil, err + } + s.AutoUnsubscribe(1) + defer s.Unsubscribe() + + err = nc.PublishRequest(subj, inbox, data) + if err != nil { + return nil, err + } + return s.NextMsg(timeout) +} + +// InboxPrefix is the prefix for all inbox subjects. +const InboxPrefix = "_INBOX." +const inboxPrefixLen = len(InboxPrefix) +const respInboxPrefixLen = inboxPrefixLen + nuidSize + 1 + +// NewInbox will return an inbox string which can be used for directed replies from +// subscribers. These are guaranteed to be unique, but can be shared and subscribed +// to by others. +func NewInbox() string { + var b [inboxPrefixLen + nuidSize]byte + pres := b[:inboxPrefixLen] + copy(pres, InboxPrefix) + ns := b[inboxPrefixLen:] + copy(ns, nuid.Next()) + return string(b[:]) +} + +// Creates a new literal response subject that will trigger +// the global subscription handler. +func (nc *Conn) newRespInbox() string { + var b [inboxPrefixLen + (2 * nuidSize) + 1]byte + pres := b[:respInboxPrefixLen] + copy(pres, nc.respSub) + ns := b[respInboxPrefixLen:] + copy(ns, nuid.Next()) + return string(b[:]) +} + +// respToken will return the last token of a literal response inbox +// which we use for the message channel lookup. +func respToken(respInbox string) string { + return respInbox[respInboxPrefixLen:] +} + +// Subscribe will express interest in the given subject. The subject +// can have wildcards (partial:*, full:>). Messages will be delivered +// to the associated MsgHandler. If no MsgHandler is given, the +// subscription is a synchronous subscription and can be polled via +// Subscription.NextMsg(). +func (nc *Conn) Subscribe(subj string, cb MsgHandler) (*Subscription, error) { + return nc.subscribe(subj, _EMPTY_, cb, nil) +} + +// ChanSubscribe will place all messages received on the channel. +// You should not close the channel until sub.Unsubscribe() has been called. +func (nc *Conn) ChanSubscribe(subj string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, _EMPTY_, nil, ch) +} + +// ChanQueueSubscribe will place all messages received on the channel. +// You should not close the channel until sub.Unsubscribe() has been called. +func (nc *Conn) ChanQueueSubscribe(subj, group string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, group, nil, ch) +} + +// SubscribeSync is syntactic sugar for Subscribe(subject, nil). +func (nc *Conn) SubscribeSync(subj string) (*Subscription, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + mch := make(chan *Msg, nc.Opts.SubChanLen) + s, e := nc.subscribe(subj, _EMPTY_, nil, mch) + if s != nil { + s.typ = SyncSubscription + } + return s, e +} + +// QueueSubscribe creates an asynchronous queue subscriber on the given subject. +// All subscribers with the same queue name will form the queue group and +// only one member of the group will be selected to receive any given +// message asynchronously. +func (nc *Conn) QueueSubscribe(subj, queue string, cb MsgHandler) (*Subscription, error) { + return nc.subscribe(subj, queue, cb, nil) +} + +// QueueSubscribeSync creates a synchronous queue subscriber on the given +// subject. All subscribers with the same queue name will form the queue +// group and only one member of the group will be selected to receive any +// given message synchronously. +func (nc *Conn) QueueSubscribeSync(subj, queue string) (*Subscription, error) { + mch := make(chan *Msg, nc.Opts.SubChanLen) + s, e := nc.subscribe(subj, queue, nil, mch) + if s != nil { + s.typ = SyncSubscription + } + return s, e +} + +// QueueSubscribeSyncWithChan is syntactic sugar for ChanQueueSubscribe(subject, group, ch). +func (nc *Conn) QueueSubscribeSyncWithChan(subj, queue string, ch chan *Msg) (*Subscription, error) { + return nc.subscribe(subj, queue, nil, ch) +} + +// subscribe is the internal subscribe function that indicates interest in a subject. +func (nc *Conn) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg) (*Subscription, error) { + if nc == nil { + return nil, ErrInvalidConnection + } + nc.mu.Lock() + // ok here, but defer is generally expensive + defer nc.mu.Unlock() + defer nc.kickFlusher() + + // Check for some error conditions. + if nc.isClosed() { + return nil, ErrConnectionClosed + } + + if cb == nil && ch == nil { + return nil, ErrBadSubscription + } + + sub := &Subscription{Subject: subj, Queue: queue, mcb: cb, conn: nc} + // Set pending limits. + sub.pMsgsLimit = DefaultSubPendingMsgsLimit + sub.pBytesLimit = DefaultSubPendingBytesLimit + + // If we have an async callback, start up a sub specific + // Go routine to deliver the messages. + if cb != nil { + sub.typ = AsyncSubscription + sub.pCond = sync.NewCond(&sub.mu) + go nc.waitForMsgs(sub) + } else { + sub.typ = ChanSubscription + sub.mch = ch + } + + nc.subsMu.Lock() + nc.ssid++ + sub.sid = nc.ssid + nc.subs[sub.sid] = sub + nc.subsMu.Unlock() + + // We will send these for all subs when we reconnect + // so that we can suppress here. + if !nc.isReconnecting() { + fmt.Fprintf(nc.bw, subProto, subj, queue, sub.sid) + } + return sub, nil +} + +// Lock for nc should be held here upon entry +func (nc *Conn) removeSub(s *Subscription) { + nc.subsMu.Lock() + delete(nc.subs, s.sid) + nc.subsMu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() + // Release callers on NextMsg for SyncSubscription only + if s.mch != nil && s.typ == SyncSubscription { + close(s.mch) + } + s.mch = nil + + // Mark as invalid + s.conn = nil + s.closed = true + if s.pCond != nil { + s.pCond.Broadcast() + } +} + +// SubscriptionType is the type of the Subscription. +type SubscriptionType int + +// The different types of subscription types. +const ( + AsyncSubscription = SubscriptionType(iota) + SyncSubscription + ChanSubscription + NilSubscription +) + +// Type returns the type of Subscription. +func (s *Subscription) Type() SubscriptionType { + if s == nil { + return NilSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + return s.typ +} + +// IsValid returns a boolean indicating whether the subscription +// is still active. This will return false if the subscription has +// already been closed. +func (s *Subscription) IsValid() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return s.conn != nil +} + +// Unsubscribe will remove interest in the given subject. +func (s *Subscription) Unsubscribe() error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + conn := s.conn + s.mu.Unlock() + if conn == nil { + return ErrBadSubscription + } + return conn.unsubscribe(s, 0) +} + +// AutoUnsubscribe will issue an automatic Unsubscribe that is +// processed by the server when max messages have been received. +// This can be useful when sending a request to an unknown number +// of subscribers. Request() uses this functionality. +func (s *Subscription) AutoUnsubscribe(max int) error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + conn := s.conn + s.mu.Unlock() + if conn == nil { + return ErrBadSubscription + } + return conn.unsubscribe(s, max) +} + +// unsubscribe performs the low level unsubscribe to the server. +// Use Subscription.Unsubscribe() +func (nc *Conn) unsubscribe(sub *Subscription, max int) error { + nc.mu.Lock() + // ok here, but defer is expensive + defer nc.mu.Unlock() + defer nc.kickFlusher() + + if nc.isClosed() { + return ErrConnectionClosed + } + + nc.subsMu.RLock() + s := nc.subs[sub.sid] + nc.subsMu.RUnlock() + // Already unsubscribed + if s == nil { + return nil + } + + maxStr := _EMPTY_ + if max > 0 { + s.max = uint64(max) + maxStr = strconv.Itoa(max) + } else { + nc.removeSub(s) + } + // We will send these for all subs when we reconnect + // so that we can suppress here. + if !nc.isReconnecting() { + fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + } + return nil +} + +// NextMsg will return the next message available to a synchronous subscriber +// or block until one is available. A timeout can be used to return when no +// message has been delivered. +func (s *Subscription) NextMsg(timeout time.Duration) (*Msg, error) { + if s == nil { + return nil, ErrBadSubscription + } + + s.mu.Lock() + err := s.validateNextMsgState() + if err != nil { + s.mu.Unlock() + return nil, err + } + + // snapshot + mch := s.mch + s.mu.Unlock() + + var ok bool + var msg *Msg + + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + select { + case msg, ok = <-mch: + if !ok { + return nil, ErrConnectionClosed + } + err := s.processNextMsgDelivered(msg) + if err != nil { + return nil, err + } + case <-t.C: + return nil, ErrTimeout + } + + return msg, nil +} + +// validateNextMsgState checks whether the subscription is in a valid +// state to call NextMsg and be delivered another message synchronously. +// This should be called while holding the lock. +func (s *Subscription) validateNextMsgState() error { + if s.connClosed { + return ErrConnectionClosed + } + if s.mch == nil { + if s.max > 0 && s.delivered >= s.max { + return ErrMaxMessages + } else if s.closed { + return ErrBadSubscription + } + } + if s.mcb != nil { + return ErrSyncSubRequired + } + if s.sc { + s.sc = false + return ErrSlowConsumer + } + + return nil +} + +// processNextMsgDelivered takes a message and applies the needed +// accounting to the stats from the subscription, returning an +// error in case we have the maximum number of messages have been +// delivered already. It should not be called while holding the lock. +func (s *Subscription) processNextMsgDelivered(msg *Msg) error { + s.mu.Lock() + nc := s.conn + max := s.max + + // Update some stats. + s.delivered++ + delivered := s.delivered + if s.typ == SyncSubscription { + s.pMsgs-- + s.pBytes -= len(msg.Data) + } + s.mu.Unlock() + + if max > 0 { + if delivered > max { + return ErrMaxMessages + } + // Remove subscription if we have reached max. + if delivered == max { + nc.mu.Lock() + nc.removeSub(s) + nc.mu.Unlock() + } + } + + return nil +} + +// Queued returns the number of queued messages in the client for this subscription. +// DEPRECATED: Use Pending() +func (s *Subscription) QueuedMsgs() (int, error) { + m, _, err := s.Pending() + return int(m), err +} + +// Pending returns the number of queued messages and queued bytes in the client for this subscription. +func (s *Subscription) Pending() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgs, s.pBytes, nil +} + +// MaxPending returns the maximum number of queued messages and queued bytes seen so far. +func (s *Subscription) MaxPending() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgsMax, s.pBytesMax, nil +} + +// ClearMaxPending resets the maximums seen so far. +func (s *Subscription) ClearMaxPending() error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return ErrBadSubscription + } + if s.typ == ChanSubscription { + return ErrTypeSubscription + } + s.pMsgsMax, s.pBytesMax = 0, 0 + return nil +} + +// Pending Limits +const ( + DefaultSubPendingMsgsLimit = 65536 + DefaultSubPendingBytesLimit = 65536 * 1024 +) + +// PendingLimits returns the current limits for this subscription. +// If no error is returned, a negative value indicates that the +// given metric is not limited. +func (s *Subscription) PendingLimits() (int, int, error) { + if s == nil { + return -1, -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, -1, ErrBadSubscription + } + if s.typ == ChanSubscription { + return -1, -1, ErrTypeSubscription + } + return s.pMsgsLimit, s.pBytesLimit, nil +} + +// SetPendingLimits sets the limits for pending msgs and bytes for this subscription. +// Zero is not allowed. Any negative value means that the given metric is not limited. +func (s *Subscription) SetPendingLimits(msgLimit, bytesLimit int) error { + if s == nil { + return ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return ErrBadSubscription + } + if s.typ == ChanSubscription { + return ErrTypeSubscription + } + if msgLimit == 0 || bytesLimit == 0 { + return ErrInvalidArg + } + s.pMsgsLimit, s.pBytesLimit = msgLimit, bytesLimit + return nil +} + +// Delivered returns the number of delivered messages for this subscription. +func (s *Subscription) Delivered() (int64, error) { + if s == nil { + return -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, ErrBadSubscription + } + return int64(s.delivered), nil +} + +// Dropped returns the number of known dropped messages for this subscription. +// This will correspond to messages dropped by violations of PendingLimits. If +// the server declares the connection a SlowConsumer, this number may not be +// valid. +func (s *Subscription) Dropped() (int, error) { + if s == nil { + return -1, ErrBadSubscription + } + s.mu.Lock() + defer s.mu.Unlock() + if s.conn == nil { + return -1, ErrBadSubscription + } + return s.dropped, nil +} + +// FIXME: This is a hack +// removeFlushEntry is needed when we need to discard queued up responses +// for our pings as part of a flush call. This happens when we have a flush +// call outstanding and we call close. +func (nc *Conn) removeFlushEntry(ch chan struct{}) bool { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.pongs == nil { + return false + } + for i, c := range nc.pongs { + if c == ch { + nc.pongs[i] = nil + return true + } + } + return false +} + +// The lock must be held entering this function. +func (nc *Conn) sendPing(ch chan struct{}) { + nc.pongs = append(nc.pongs, ch) + nc.bw.WriteString(pingProto) + // Flush in place. + nc.bw.Flush() +} + +// This will fire periodically and send a client origin +// ping to the server. Will also check that we have received +// responses from the server. +func (nc *Conn) processPingTimer() { + nc.mu.Lock() + + if nc.status != CONNECTED { + nc.mu.Unlock() + return + } + + // Check for violation + nc.pout++ + if nc.pout > nc.Opts.MaxPingsOut { + nc.mu.Unlock() + nc.processOpErr(ErrStaleConnection) + return + } + + nc.sendPing(nil) + nc.ptmr.Reset(nc.Opts.PingInterval) + nc.mu.Unlock() +} + +// FlushTimeout allows a Flush operation to have an associated timeout. +func (nc *Conn) FlushTimeout(timeout time.Duration) (err error) { + if nc == nil { + return ErrInvalidConnection + } + if timeout <= 0 { + return ErrBadTimeout + } + + nc.mu.Lock() + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + t := globalTimerPool.Get(timeout) + defer globalTimerPool.Put(t) + + // Create a buffered channel to prevent chan send to block + // in processPong() if this code here times out just when + // PONG was received. + ch := make(chan struct{}, 1) + nc.sendPing(ch) + nc.mu.Unlock() + + select { + case _, ok := <-ch: + if !ok { + err = ErrConnectionClosed + } else { + close(ch) + } + case <-t.C: + err = ErrTimeout + } + + if err != nil { + nc.removeFlushEntry(ch) + } + return +} + +// Flush will perform a round trip to the server and return when it +// receives the internal reply. +func (nc *Conn) Flush() error { + return nc.FlushTimeout(60 * time.Second) +} + +// Buffered will return the number of bytes buffered to be sent to the server. +// FIXME(dlc) take into account disconnected state. +func (nc *Conn) Buffered() (int, error) { + nc.mu.Lock() + defer nc.mu.Unlock() + if nc.isClosed() || nc.bw == nil { + return -1, ErrConnectionClosed + } + return nc.bw.Buffered(), nil +} + +// resendSubscriptions will send our subscription state back to the +// server. Used in reconnects +func (nc *Conn) resendSubscriptions() { + // Since we are going to send protocols to the server, we don't want to + // be holding the subsMu lock (which is used in processMsg). So copy + // the subscriptions in a temporary array. + nc.subsMu.RLock() + subs := make([]*Subscription, 0, len(nc.subs)) + for _, s := range nc.subs { + subs = append(subs, s) + } + nc.subsMu.RUnlock() + for _, s := range subs { + adjustedMax := uint64(0) + s.mu.Lock() + if s.max > 0 { + if s.delivered < s.max { + adjustedMax = s.max - s.delivered + } + + // adjustedMax could be 0 here if the number of delivered msgs + // reached the max, if so unsubscribe. + if adjustedMax == 0 { + s.mu.Unlock() + fmt.Fprintf(nc.bw, unsubProto, s.sid, _EMPTY_) + continue + } + } + s.mu.Unlock() + + fmt.Fprintf(nc.bw, subProto, s.Subject, s.Queue, s.sid) + if adjustedMax > 0 { + maxStr := strconv.Itoa(int(adjustedMax)) + fmt.Fprintf(nc.bw, unsubProto, s.sid, maxStr) + } + } +} + +// This will clear any pending flush calls and release pending calls. +// Lock is assumed to be held by the caller. +func (nc *Conn) clearPendingFlushCalls() { + // Clear any queued pongs, e.g. pending flush calls. + for _, ch := range nc.pongs { + if ch != nil { + close(ch) + } + } + nc.pongs = nil +} + +// This will clear any pending Request calls. +// Lock is assumed to be held by the caller. +func (nc *Conn) clearPendingRequestCalls() { + if nc.respMap == nil { + return + } + for key, ch := range nc.respMap { + if ch != nil { + close(ch) + delete(nc.respMap, key) + } + } +} + +// Low level close call that will do correct cleanup and set +// desired status. Also controls whether user defined callbacks +// will be triggered. The lock should not be held entering this +// function. This function will handle the locking manually. +func (nc *Conn) close(status Status, doCBs bool) { + nc.mu.Lock() + if nc.isClosed() { + nc.status = status + nc.mu.Unlock() + return + } + nc.status = CLOSED + + // Kick the Go routines so they fall out. + nc.kickFlusher() + nc.mu.Unlock() + + nc.mu.Lock() + + // Clear any queued pongs, e.g. pending flush calls. + nc.clearPendingFlushCalls() + + // Clear any queued and blocking Requests. + nc.clearPendingRequestCalls() + + if nc.ptmr != nil { + nc.ptmr.Stop() + } + + // Go ahead and make sure we have flushed the outbound + if nc.conn != nil { + nc.bw.Flush() + defer nc.conn.Close() + } + + // Close sync subscriber channels and release any + // pending NextMsg() calls. + nc.subsMu.Lock() + for _, s := range nc.subs { + s.mu.Lock() + + // Release callers on NextMsg for SyncSubscription only + if s.mch != nil && s.typ == SyncSubscription { + close(s.mch) + } + s.mch = nil + // Mark as invalid, for signaling to deliverMsgs + s.closed = true + // Mark connection closed in subscription + s.connClosed = true + // If we have an async subscription, signals it to exit + if s.typ == AsyncSubscription && s.pCond != nil { + s.pCond.Signal() + } + + s.mu.Unlock() + } + nc.subs = nil + nc.subsMu.Unlock() + + // Perform appropriate callback if needed for a disconnect. + if doCBs { + if nc.Opts.DisconnectedCB != nil && nc.conn != nil { + nc.ach <- func() { nc.Opts.DisconnectedCB(nc) } + } + if nc.Opts.ClosedCB != nil { + nc.ach <- func() { nc.Opts.ClosedCB(nc) } + } + nc.ach <- nc.closeAsyncFunc() + } + nc.status = status + nc.mu.Unlock() +} + +// Close will close the connection to the server. This call will release +// all blocking calls, such as Flush() and NextMsg() +func (nc *Conn) Close() { + nc.close(CLOSED, true) +} + +// IsClosed tests if a Conn has been closed. +func (nc *Conn) IsClosed() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isClosed() +} + +// IsReconnecting tests if a Conn is reconnecting. +func (nc *Conn) IsReconnecting() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isReconnecting() +} + +// IsConnected tests if a Conn is connected. +func (nc *Conn) IsConnected() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.isConnected() +} + +// caller must lock +func (nc *Conn) getServers(implicitOnly bool) []string { + poolSize := len(nc.srvPool) + var servers = make([]string, 0) + for i := 0; i < poolSize; i++ { + if implicitOnly && !nc.srvPool[i].isImplicit { + continue + } + url := nc.srvPool[i].url + servers = append(servers, fmt.Sprintf("%s://%s", url.Scheme, url.Host)) + } + return servers +} + +// Servers returns the list of known server urls, including additional +// servers discovered after a connection has been established. If +// authentication is enabled, use UserInfo or Token when connecting with +// these urls. +func (nc *Conn) Servers() []string { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.getServers(false) +} + +// DiscoveredServers returns only the server urls that have been discovered +// after a connection has been established. If authentication is enabled, +// use UserInfo or Token when connecting with these urls. +func (nc *Conn) DiscoveredServers() []string { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.getServers(true) +} + +// Status returns the current state of the connection. +func (nc *Conn) Status() Status { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.status +} + +// Test if Conn has been closed Lock is assumed held. +func (nc *Conn) isClosed() bool { + return nc.status == CLOSED +} + +// Test if Conn is in the process of connecting +func (nc *Conn) isConnecting() bool { + return nc.status == CONNECTING +} + +// Test if Conn is being reconnected. +func (nc *Conn) isReconnecting() bool { + return nc.status == RECONNECTING +} + +// Test if Conn is connected or connecting. +func (nc *Conn) isConnected() bool { + return nc.status == CONNECTED +} + +// Stats will return a race safe copy of the Statistics section for the connection. +func (nc *Conn) Stats() Statistics { + // Stats are updated either under connection's mu or subsMu mutexes. + // Lock both to safely get them. + nc.mu.Lock() + nc.subsMu.RLock() + stats := Statistics{ + InMsgs: nc.InMsgs, + InBytes: nc.InBytes, + OutMsgs: nc.OutMsgs, + OutBytes: nc.OutBytes, + Reconnects: nc.Reconnects, + } + nc.subsMu.RUnlock() + nc.mu.Unlock() + return stats +} + +// MaxPayload returns the size limit that a message payload can have. +// This is set by the server configuration and delivered to the client +// upon connect. +func (nc *Conn) MaxPayload() int64 { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.MaxPayload +} + +// AuthRequired will return if the connected server requires authorization. +func (nc *Conn) AuthRequired() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.AuthRequired +} + +// TLSRequired will return if the connected server requires TLS connections. +func (nc *Conn) TLSRequired() bool { + nc.mu.Lock() + defer nc.mu.Unlock() + return nc.info.TLSRequired +} diff --git a/vendor/github.com/nats-io/go-nats/nats_test.go b/vendor/github.com/nats-io/go-nats/nats_test.go new file mode 100644 index 00000000..cbd95632 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/nats_test.go @@ -0,0 +1,1177 @@ +package nats + +//////////////////////////////////////////////////////////////////////////////// +// Package scoped specific tests here.. +//////////////////////////////////////////////////////////////////////////////// + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" + gnatsd "github.com/nats-io/gnatsd/test" +) + +// Dumb wait program to sync on callbacks, etc... Will timeout +func Wait(ch chan bool) error { + return WaitTime(ch, 5*time.Second) +} + +func WaitTime(ch chan bool, timeout time.Duration) error { + select { + case <-ch: + return nil + case <-time.After(timeout): + } + return errors.New("timeout") +} + +func stackFatalf(t *testing.T, f string, args ...interface{}) { + lines := make([]string, 0, 32) + msg := fmt.Sprintf(f, args...) + lines = append(lines, msg) + + // Generate the Stack of callers: Skip us and verify* frames. + for i := 2; true; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + msg := fmt.Sprintf("%d - %s:%d", i, file, line) + lines = append(lines, msg) + } + t.Fatalf("%s", strings.Join(lines, "\n")) +} + +//////////////////////////////////////////////////////////////////////////////// +// Reconnect tests +//////////////////////////////////////////////////////////////////////////////// + +const TEST_PORT = 8368 + +var reconnectOpts = Options{ + Url: fmt.Sprintf("nats://localhost:%d", TEST_PORT), + AllowReconnect: true, + MaxReconnect: 10, + ReconnectWait: 100 * time.Millisecond, + Timeout: DefaultTimeout, +} + +func RunServerOnPort(port int) *server.Server { + opts := gnatsd.DefaultTestOptions + opts.Port = port + return RunServerWithOptions(opts) +} + +func RunServerWithOptions(opts server.Options) *server.Server { + return gnatsd.RunServer(&opts) +} + +func TestReconnectServerStats(t *testing.T) { + ts := RunServerOnPort(TEST_PORT) + + opts := reconnectOpts + nc, _ := opts.Connect() + defer nc.Close() + nc.Flush() + + ts.Shutdown() + // server is stopped here... + + ts = RunServerOnPort(TEST_PORT) + defer ts.Shutdown() + + if err := nc.FlushTimeout(5 * time.Second); err != nil { + t.Fatalf("Error on Flush: %v", err) + } + + // Make sure the server who is reconnected has the reconnects stats reset. + nc.mu.Lock() + _, cur := nc.currentServer() + nc.mu.Unlock() + + if cur.reconnects != 0 { + t.Fatalf("Current Server's reconnects should be 0 vs %d\n", cur.reconnects) + } +} + +func TestParseStateReconnectFunctionality(t *testing.T) { + ts := RunServerOnPort(TEST_PORT) + ch := make(chan bool) + + opts := reconnectOpts + dch := make(chan bool) + opts.DisconnectedCB = func(_ *Conn) { + dch <- true + } + + nc, errc := opts.Connect() + if errc != nil { + t.Fatalf("Failed to create a connection: %v\n", errc) + } + ec, errec := NewEncodedConn(nc, DEFAULT_ENCODER) + if errec != nil { + nc.Close() + t.Fatalf("Failed to create an encoded connection: %v\n", errec) + } + defer ec.Close() + + testString := "bar" + ec.Subscribe("foo", func(s string) { + if s != testString { + t.Fatal("String doesn't match") + } + ch <- true + }) + ec.Flush() + + // Got a RACE condition with Travis build. The locking below does not + // really help because the parser running in the readLoop accesses + // nc.ps without the connection lock. Sleeping may help better since + // it would make the memory write in parse.go (when processing the + // pong) further away from the modification below. + time.Sleep(1 * time.Second) + + // Simulate partialState, this needs to be cleared + nc.mu.Lock() + nc.ps.state = OP_PON + nc.mu.Unlock() + + ts.Shutdown() + // server is stopped here... + + if err := Wait(dch); err != nil { + t.Fatal("Did not get the DisconnectedCB") + } + + if err := ec.Publish("foo", testString); err != nil { + t.Fatalf("Failed to publish message: %v\n", err) + } + + ts = RunServerOnPort(TEST_PORT) + defer ts.Shutdown() + + if err := ec.FlushTimeout(5 * time.Second); err != nil { + t.Fatalf("Error on Flush: %v", err) + } + + if err := Wait(ch); err != nil { + t.Fatal("Did not receive our message") + } + + expectedReconnectCount := uint64(1) + reconnectedCount := ec.Conn.Stats().Reconnects + + if reconnectedCount != expectedReconnectCount { + t.Fatalf("Reconnect count incorrect: %d vs %d\n", + reconnectedCount, expectedReconnectCount) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// ServerPool tests +//////////////////////////////////////////////////////////////////////////////// + +var testServers = []string{ + "nats://localhost:1222", + "nats://localhost:1223", + "nats://localhost:1224", + "nats://localhost:1225", + "nats://localhost:1226", + "nats://localhost:1227", + "nats://localhost:1228", +} + +func TestServersRandomize(t *testing.T) { + opts := GetDefaultOptions() + opts.Servers = testServers + nc := &Conn{Opts: opts} + if err := nc.setupServerPool(); err != nil { + t.Fatalf("Problem setting up Server Pool: %v\n", err) + } + // Build []string from srvPool + clientServers := []string{} + for _, s := range nc.srvPool { + clientServers = append(clientServers, s.url.String()) + } + // In theory this could happen.. + if reflect.DeepEqual(testServers, clientServers) { + t.Fatalf("ServerPool list not randomized\n") + } + + // Now test that we do not randomize if proper flag is set. + opts = GetDefaultOptions() + opts.Servers = testServers + opts.NoRandomize = true + nc = &Conn{Opts: opts} + if err := nc.setupServerPool(); err != nil { + t.Fatalf("Problem setting up Server Pool: %v\n", err) + } + // Build []string from srvPool + clientServers = []string{} + for _, s := range nc.srvPool { + clientServers = append(clientServers, s.url.String()) + } + if !reflect.DeepEqual(testServers, clientServers) { + t.Fatalf("ServerPool list should not be randomized\n") + } + + // Although the original intent was that if Opts.Url is + // set, Opts.Servers is not (and vice versa), the behavior + // is that Opts.Url is always first, even when randomization + // is enabled. So make sure that this is still the case. + opts = GetDefaultOptions() + opts.Url = DefaultURL + opts.Servers = testServers + nc = &Conn{Opts: opts} + if err := nc.setupServerPool(); err != nil { + t.Fatalf("Problem setting up Server Pool: %v\n", err) + } + // Build []string from srvPool + clientServers = []string{} + for _, s := range nc.srvPool { + clientServers = append(clientServers, s.url.String()) + } + // In theory this could happen.. + if reflect.DeepEqual(testServers, clientServers) { + t.Fatalf("ServerPool list not randomized\n") + } + if clientServers[0] != DefaultURL { + t.Fatalf("Options.Url should be first in the array, got %v", clientServers[0]) + } +} + +func TestSelectNextServer(t *testing.T) { + opts := GetDefaultOptions() + opts.Servers = testServers + opts.NoRandomize = true + nc := &Conn{Opts: opts} + if err := nc.setupServerPool(); err != nil { + t.Fatalf("Problem setting up Server Pool: %v\n", err) + } + if nc.url != nc.srvPool[0].url { + t.Fatalf("Wrong default selection: %v\n", nc.url) + } + + sel, err := nc.selectNextServer() + if err != nil { + t.Fatalf("Got an err: %v\n", err) + } + // Check that we are now looking at #2, and current is now last. + if len(nc.srvPool) != len(testServers) { + t.Fatalf("List is incorrect size: %d vs %d\n", len(nc.srvPool), len(testServers)) + } + if nc.url.String() != testServers[1] { + t.Fatalf("Selection incorrect: %v vs %v\n", nc.url, testServers[1]) + } + if nc.srvPool[len(nc.srvPool)-1].url.String() != testServers[0] { + t.Fatalf("Did not push old to last position\n") + } + if sel != nc.srvPool[0] { + t.Fatalf("Did not return correct server: %v vs %v\n", sel.url, nc.srvPool[0].url) + } + + // Test that we do not keep servers where we have tried to reconnect past our limit. + nc.srvPool[0].reconnects = int(opts.MaxReconnect) + if _, err := nc.selectNextServer(); err != nil { + t.Fatalf("Got an err: %v\n", err) + } + // Check that we are now looking at #3, and current is not in the list. + if len(nc.srvPool) != len(testServers)-1 { + t.Fatalf("List is incorrect size: %d vs %d\n", len(nc.srvPool), len(testServers)-1) + } + if nc.url.String() != testServers[2] { + t.Fatalf("Selection incorrect: %v vs %v\n", nc.url, testServers[2]) + } + if nc.srvPool[len(nc.srvPool)-1].url.String() == testServers[1] { + t.Fatalf("Did not throw away the last server correctly\n") + } +} + +// This will test that comma separated url strings work properly for +// the Connect() command. +func TestUrlArgument(t *testing.T) { + check := func(url string, expected []string) { + if !reflect.DeepEqual(processUrlString(url), expected) { + t.Fatalf("Got wrong response processing URL: %q, RES: %#v\n", url, processUrlString(url)) + } + } + // This is normal case + oneExpected := []string{"nats://localhost:1222"} + + check("nats://localhost:1222", oneExpected) + check("nats://localhost:1222 ", oneExpected) + check(" nats://localhost:1222", oneExpected) + check(" nats://localhost:1222 ", oneExpected) + + var multiExpected = []string{ + "nats://localhost:1222", + "nats://localhost:1223", + "nats://localhost:1224", + } + + check("nats://localhost:1222,nats://localhost:1223,nats://localhost:1224", multiExpected) + check("nats://localhost:1222, nats://localhost:1223, nats://localhost:1224", multiExpected) + check(" nats://localhost:1222, nats://localhost:1223, nats://localhost:1224 ", multiExpected) + check("nats://localhost:1222, nats://localhost:1223 ,nats://localhost:1224", multiExpected) +} + +func TestParserPing(t *testing.T) { + c := &Conn{} + fake := &bytes.Buffer{} + c.bw = bufio.NewWriterSize(fake, c.Opts.ReconnectBufSize) + + c.ps = &parseState{} + + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + ping := []byte("PING\r\n") + err := c.parse(ping[:1]) + if err != nil || c.ps.state != OP_P { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping[1:2]) + if err != nil || c.ps.state != OP_PI { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping[2:3]) + if err != nil || c.ps.state != OP_PIN { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping[3:4]) + if err != nil || c.ps.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping[4:5]) + if err != nil || c.ps.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping[5:6]) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(ping) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Should tolerate spaces + ping = []byte("PING \r") + err = c.parse(ping) + if err != nil || c.ps.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + c.ps.state = OP_START + ping = []byte("PING \r \n") + err = c.parse(ping) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } +} + +func TestParserErr(t *testing.T) { + c := &Conn{} + c.status = CLOSED + fake := &bytes.Buffer{} + c.bw = bufio.NewWriterSize(fake, c.Opts.ReconnectBufSize) + + c.ps = &parseState{} + + // This test focuses on the parser only, not how the error is + // actually processed by the upper layer. + + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + + expectedError := "'Any kind of error'" + errProto := []byte("-ERR " + expectedError + "\r\n") + err := c.parse(errProto[:1]) + if err != nil || c.ps.state != OP_MINUS { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[1:2]) + if err != nil || c.ps.state != OP_MINUS_E { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[2:3]) + if err != nil || c.ps.state != OP_MINUS_ER { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[3:4]) + if err != nil || c.ps.state != OP_MINUS_ERR { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[4:5]) + if err != nil || c.ps.state != OP_MINUS_ERR_SPC { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[5:6]) + if err != nil || c.ps.state != OP_MINUS_ERR_SPC { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + + // Check with split arg buffer + err = c.parse(errProto[6:7]) + if err != nil || c.ps.state != MINUS_ERR_ARG { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[7:10]) + if err != nil || c.ps.state != MINUS_ERR_ARG { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[10 : len(errProto)-2]) + if err != nil || c.ps.state != MINUS_ERR_ARG { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + if c.ps.argBuf == nil { + t.Fatal("ArgBuf should not be nil") + } + s := string(c.ps.argBuf) + if s != expectedError { + t.Fatalf("Expected %v, got %v", expectedError, s) + } + err = c.parse(errProto[len(errProto)-2:]) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + + // Check without split arg buffer + errProto = []byte("-ERR 'Any error'\r\n") + err = c.parse(errProto) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } +} + +func TestParserOK(t *testing.T) { + c := &Conn{} + c.ps = &parseState{} + + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + errProto := []byte("+OKay\r\n") + err := c.parse(errProto[:1]) + if err != nil || c.ps.state != OP_PLUS { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[1:2]) + if err != nil || c.ps.state != OP_PLUS_O { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[2:3]) + if err != nil || c.ps.state != OP_PLUS_OK { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(errProto[3:]) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } +} + +func TestParserShouldFail(t *testing.T) { + c := &Conn{} + c.ps = &parseState{} + + if err := c.parse([]byte(" PING")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("POO")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("Px")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("PIx")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("PINx")); err == nil { + t.Fatal("Should have received a parse error") + } + // Stop here because 'PING' protos are tolerant for anything between PING and \n + + c.ps.state = OP_START + if err := c.parse([]byte("POx")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("PONx")); err == nil { + t.Fatal("Should have received a parse error") + } + // Stop here because 'PONG' protos are tolerant for anything between PONG and \n + + c.ps.state = OP_START + if err := c.parse([]byte("ZOO")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("Mx\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSx\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSGx\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG foo\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG \r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG foo 1\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG foo bar 1\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG foo bar 1 baz\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("MSG foo 1 bar baz\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("+x\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("+Ox\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("-x\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("-Ex\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("-ERx\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.ps.state = OP_START + if err := c.parse([]byte("-ERRx\r\n")); err == nil { + t.Fatal("Should have received a parse error") + } +} + +func TestParserSplitMsg(t *testing.T) { + + nc := &Conn{} + nc.ps = &parseState{} + + buf := []byte("MSG a\r\n") + err := nc.parse(buf) + if err == nil { + t.Fatal("Expected an error") + } + nc.ps = &parseState{} + + buf = []byte("MSG a b c\r\n") + err = nc.parse(buf) + if err == nil { + t.Fatal("Expected an error") + } + nc.ps = &parseState{} + + expectedCount := uint64(1) + expectedSize := uint64(3) + + buf = []byte("MSG a") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.argBuf == nil { + t.Fatal("Arg buffer should have been created") + } + + buf = []byte(" 1 3\r\nf") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.ma.size != 3 { + t.Fatalf("Wrong msg size: %d instead of 3", nc.ps.ma.size) + } + if nc.ps.ma.sid != 1 { + t.Fatalf("Wrong sid: %d instead of 1", nc.ps.ma.sid) + } + if string(nc.ps.ma.subject) != "a" { + t.Fatalf("Wrong subject: '%s' instead of 'a'", string(nc.ps.ma.subject)) + } + if nc.ps.msgBuf == nil { + t.Fatal("Msg buffer should have been created") + } + + buf = []byte("oo\r\n") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if (nc.Statistics.InMsgs != expectedCount) || (nc.Statistics.InBytes != expectedSize) { + t.Fatalf("Wrong stats: %d - %d instead of %d - %d", nc.Statistics.InMsgs, nc.Statistics.InBytes, expectedCount, expectedSize) + } + if (nc.ps.argBuf != nil) || (nc.ps.msgBuf != nil) { + t.Fatal("Buffers should be nil now") + } + + buf = []byte("MSG a 1 3\r\nfo") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.ma.size != 3 { + t.Fatalf("Wrong msg size: %d instead of 3", nc.ps.ma.size) + } + if nc.ps.ma.sid != 1 { + t.Fatalf("Wrong sid: %d instead of 1", nc.ps.ma.sid) + } + if string(nc.ps.ma.subject) != "a" { + t.Fatalf("Wrong subject: '%s' instead of 'a'", string(nc.ps.ma.subject)) + } + if nc.ps.argBuf == nil { + t.Fatal("Arg buffer should have been created") + } + if nc.ps.msgBuf == nil { + t.Fatal("Msg buffer should have been created") + } + + expectedCount++ + expectedSize += 3 + + buf = []byte("o\r\n") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if (nc.Statistics.InMsgs != expectedCount) || (nc.Statistics.InBytes != expectedSize) { + t.Fatalf("Wrong stats: %d - %d instead of %d - %d", nc.Statistics.InMsgs, nc.Statistics.InBytes, expectedCount, expectedSize) + } + if (nc.ps.argBuf != nil) || (nc.ps.msgBuf != nil) { + t.Fatal("Buffers should be nil now") + } + + buf = []byte("MSG a 1 6\r\nfo") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.ma.size != 6 { + t.Fatalf("Wrong msg size: %d instead of 3", nc.ps.ma.size) + } + if nc.ps.ma.sid != 1 { + t.Fatalf("Wrong sid: %d instead of 1", nc.ps.ma.sid) + } + if string(nc.ps.ma.subject) != "a" { + t.Fatalf("Wrong subject: '%s' instead of 'a'", string(nc.ps.ma.subject)) + } + if nc.ps.argBuf == nil { + t.Fatal("Arg buffer should have been created") + } + if nc.ps.msgBuf == nil { + t.Fatal("Msg buffer should have been created") + } + + buf = []byte("ob") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + + expectedCount++ + expectedSize += 6 + + buf = []byte("ar\r\n") + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if (nc.Statistics.InMsgs != expectedCount) || (nc.Statistics.InBytes != expectedSize) { + t.Fatalf("Wrong stats: %d - %d instead of %d - %d", nc.Statistics.InMsgs, nc.Statistics.InBytes, expectedCount, expectedSize) + } + if (nc.ps.argBuf != nil) || (nc.ps.msgBuf != nil) { + t.Fatal("Buffers should be nil now") + } + + // Let's have a msg that is bigger than the parser's scratch size. + // Since we prepopulate the msg with 'foo', adding 3 to the size. + msgSize := cap(nc.ps.scratch) + 100 + 3 + buf = []byte(fmt.Sprintf("MSG a 1 b %d\r\nfoo", msgSize)) + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.ma.size != msgSize { + t.Fatalf("Wrong msg size: %d instead of %d", nc.ps.ma.size, msgSize) + } + if nc.ps.ma.sid != 1 { + t.Fatalf("Wrong sid: %d instead of 1", nc.ps.ma.sid) + } + if string(nc.ps.ma.subject) != "a" { + t.Fatalf("Wrong subject: '%s' instead of 'a'", string(nc.ps.ma.subject)) + } + if string(nc.ps.ma.reply) != "b" { + t.Fatalf("Wrong reply: '%s' instead of 'b'", string(nc.ps.ma.reply)) + } + if nc.ps.argBuf == nil { + t.Fatal("Arg buffer should have been created") + } + if nc.ps.msgBuf == nil { + t.Fatal("Msg buffer should have been created") + } + + expectedCount++ + expectedSize += uint64(msgSize) + + bufSize := msgSize - 3 + + buf = make([]byte, bufSize) + for i := 0; i < bufSize; i++ { + buf[i] = byte('a' + (i % 26)) + } + + err = nc.parse(buf) + if err != nil { + t.Fatalf("Parser error: %v", err) + } + if nc.ps.state != MSG_PAYLOAD { + t.Fatalf("Wrong state: %v instead of %v", nc.ps.state, MSG_PAYLOAD) + } + if nc.ps.ma.size != msgSize { + t.Fatalf("Wrong (ma) msg size: %d instead of %d", nc.ps.ma.size, msgSize) + } + if len(nc.ps.msgBuf) != msgSize { + t.Fatalf("Wrong msg size: %d instead of %d", len(nc.ps.msgBuf), msgSize) + } + // Check content: + if string(nc.ps.msgBuf[0:3]) != "foo" { + t.Fatalf("Wrong msg content: %s", string(nc.ps.msgBuf)) + } + for k := 3; k < nc.ps.ma.size; k++ { + if nc.ps.msgBuf[k] != byte('a'+((k-3)%26)) { + t.Fatalf("Wrong msg content: %s", string(nc.ps.msgBuf)) + } + } + + buf = []byte("\r\n") + if err := nc.parse(buf); err != nil { + t.Fatalf("Unexpected error during parsing: %v", err) + } + if (nc.Statistics.InMsgs != expectedCount) || (nc.Statistics.InBytes != expectedSize) { + t.Fatalf("Wrong stats: %d - %d instead of %d - %d", nc.Statistics.InMsgs, nc.Statistics.InBytes, expectedCount, expectedSize) + } + if (nc.ps.argBuf != nil) || (nc.ps.msgBuf != nil) { + t.Fatal("Buffers should be nil now") + } + if nc.ps.state != OP_START { + t.Fatalf("Wrong state: %v", nc.ps.state) + } +} + +func TestNormalizeError(t *testing.T) { + received := "Typical Error" + expected := strings.ToLower(received) + if s := normalizeErr("-ERR '" + received + "'"); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } + + received = "Trim Surrounding Spaces" + expected = strings.ToLower(received) + if s := normalizeErr("-ERR '" + received + "' "); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } + + received = "Trim Surrounding Spaces Without Quotes" + expected = strings.ToLower(received) + if s := normalizeErr("-ERR " + received + " "); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } + + received = "Error Without Quotes" + expected = strings.ToLower(received) + if s := normalizeErr("-ERR " + received); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } + + received = "Error With Quote Only On Left" + expected = strings.ToLower(received) + if s := normalizeErr("-ERR '" + received); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } + + received = "Error With Quote Only On Right" + expected = strings.ToLower(received) + if s := normalizeErr("-ERR " + received + "'"); s != expected { + t.Fatalf("Expected '%s', got '%s'", expected, s) + } +} + +func TestAsyncINFO(t *testing.T) { + opts := GetDefaultOptions() + c := &Conn{Opts: opts} + + c.ps = &parseState{} + + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + + info := []byte("INFO {}\r\n") + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + err := c.parse(info[:1]) + if err != nil || c.ps.state != OP_I { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(info[1:2]) + if err != nil || c.ps.state != OP_IN { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(info[2:3]) + if err != nil || c.ps.state != OP_INF { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(info[3:4]) + if err != nil || c.ps.state != OP_INFO { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(info[4:5]) + if err != nil || c.ps.state != OP_INFO_SPC { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + err = c.parse(info[5:]) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + + // All at once + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + + // Server pool needs to be setup + c.setupServerPool() + + // Partials requiring argBuf + expectedServer := serverInfo{ + Id: "test", + Host: "localhost", + Port: 4222, + Version: "1.2.3", + AuthRequired: true, + TLSRequired: true, + MaxPayload: 2 * 1024 * 1024, + ConnectURLs: []string{"localhost:5222", "localhost:6222"}, + } + // Set NoRandomize so that the check with expectedServer info + // matches. + c.Opts.NoRandomize = true + + b, _ := json.Marshal(expectedServer) + info = []byte(fmt.Sprintf("INFO %s\r\n", b)) + if c.ps.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.ps.state) + } + err = c.parse(info[:9]) + if err != nil || c.ps.state != INFO_ARG || c.ps.argBuf == nil { + t.Fatalf("Unexpected: %d err: %v argBuf: %v\n", c.ps.state, err, c.ps.argBuf) + } + err = c.parse(info[9:11]) + if err != nil || c.ps.state != INFO_ARG || c.ps.argBuf == nil { + t.Fatalf("Unexpected: %d err: %v argBuf: %v\n", c.ps.state, err, c.ps.argBuf) + } + err = c.parse(info[11:]) + if err != nil || c.ps.state != OP_START || c.ps.argBuf != nil { + t.Fatalf("Unexpected: %d err: %v argBuf: %v\n", c.ps.state, err, c.ps.argBuf) + } + if !reflect.DeepEqual(c.info, expectedServer) { + t.Fatalf("Expected server info to be: %v, got: %v", expectedServer, c.info) + } + + // Good INFOs + good := []string{"INFO {}\r\n", "INFO {}\r\n", "INFO {} \r\n", "INFO { \"server_id\": \"test\" } \r\n", "INFO {\"connect_urls\":[]}\r\n"} + for _, gi := range good { + c.ps = &parseState{} + err = c.parse([]byte(gi)) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Protocol %q should be fine. Err=%v state=%v", gi, err, c.ps.state) + } + } + + // Wrong INFOs + wrong := []string{"IxNFO {}\r\n", "INxFO {}\r\n", "INFxO {}\r\n", "INFOx {}\r\n", "INFO{}\r\n", "INFO {}"} + for _, wi := range wrong { + c.ps = &parseState{} + err = c.parse([]byte(wi)) + if err == nil && c.ps.state == OP_START { + t.Fatalf("Protocol %q should have failed", wi) + } + } + + checkPool := func(inThatOrder bool, urls ...string) { + // Check both pool and urls map + if len(c.srvPool) != len(urls) { + stackFatalf(t, "Pool should have %d elements, has %d", len(urls), len(c.srvPool)) + } + if len(c.urls) != len(urls) { + stackFatalf(t, "Map should have %d elements, has %d", len(urls), len(c.urls)) + } + for i, url := range urls { + if inThatOrder { + if c.srvPool[i].url.Host != url { + stackFatalf(t, "Pool should have %q at index %q, has %q", url, i, c.srvPool[i].url.Host) + } + } else { + if _, present := c.urls[url]; !present { + stackFatalf(t, "Pool should have %q", url) + } + } + } + } + + // Now test the decoding of "connect_urls" + + // No randomize for now + c.Opts.NoRandomize = true + // Reset the pool + c.setupServerPool() + // Reinitialize the parser + c.ps = &parseState{} + + info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) and localhost:5222 + checkPool(true, "localhost:4222", "localhost:5222") + + // Make sure that if client receives the same, it is not added again. + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool should still contain localhost:4222 (the default URL) and localhost:5222 + checkPool(true, "localhost:4222", "localhost:5222") + + // Receive a new URL + info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222 + checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222") + + // Receive more than 1 URL at once + info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222 + // localhost:7222 and localhost:8222 + checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222") + + // Test with pool randomization now. Note that with randominzation, + // the initial pool is randomize, then each array of urls that the + // client gets from the INFO protocol is randomized, but added to + // the end of the pool. + c.Opts.NoRandomize = false + c.setupServerPool() + + info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) and localhost:5222 + checkPool(true, "localhost:4222", "localhost:5222") + + // Make sure that if client receives the same, it is not added again. + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool should still contain localhost:4222 (the default URL) and localhost:5222 + checkPool(true, "localhost:4222", "localhost:5222") + + // Receive a new URL + info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222 + checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222") + + // Receive more than 1 URL at once. Add more than 2 to increase the chance of + // the array being shuffled. + info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\", " + + "\"localhost:9222\", \"localhost:10222\", \"localhost:11222\"]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222 + // localhost:7222, localhost:8222, localhost:9222, localhost:10222 and localhost:11222 + checkPool(false, "localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222", + "localhost:9222", "localhost:10222", "localhost:11222") + + // Finally, check that (part of) the pool should be randomized. + allUrls := []string{"localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222", + "localhost:9222", "localhost:10222", "localhost:11222"} + same := 0 + for i, url := range c.srvPool { + if url.url.Host == allUrls[i] { + same++ + } + } + if same == len(allUrls) { + t.Fatal("Pool does not seem to be randomized") + } + + // Check that pool may be randomized on setup, but new URLs are always + // added at end of pool. + c.Opts.NoRandomize = false + c.Opts.Servers = testServers + // Reset the pool + c.setupServerPool() + // Reinitialize the parser + c.ps = &parseState{} + // Capture the pool sequence after randomization + urlsAfterPoolSetup := make([]string, 0, len(c.srvPool)) + for _, srv := range c.srvPool { + urlsAfterPoolSetup = append(urlsAfterPoolSetup, srv.url.Host) + } + checkPoolOrderDidNotChange := func() { + for i := 0; i < len(urlsAfterPoolSetup); i++ { + if c.srvPool[i].url.Host != urlsAfterPoolSetup[i] { + stackFatalf(t, "Pool should have %q at index %q, has %q", urlsAfterPoolSetup[i], i, c.srvPool[i].url.Host) + } + } + } + // Add new urls + newURLs := []string{ + "localhost:6222", + "localhost:7222", + "localhost:8222\", \"localhost:9222", + "localhost:10222\", \"localhost:11222\", \"localhost:12222,", + } + for _, newURL := range newURLs { + info = []byte("INFO {\"connect_urls\":[\"" + newURL + "]}\r\n") + err = c.parse(info) + if err != nil || c.ps.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Check that pool order does not change up to the new addition(s). + checkPoolOrderDidNotChange() + } +} + +func TestConnServers(t *testing.T) { + opts := GetDefaultOptions() + c := &Conn{Opts: opts} + c.ps = &parseState{} + c.setupServerPool() + + validateURLs := func(serverUrls []string, expectedUrls ...string) { + var found bool + if len(serverUrls) != len(expectedUrls) { + stackFatalf(t, "Array should have %d elements, has %d", len(expectedUrls), len(serverUrls)) + } + + for _, ev := range expectedUrls { + found = false + for _, av := range serverUrls { + if ev == av { + found = true + break + } + } + if !found { + stackFatalf(t, "array is missing %q in %v", ev, serverUrls) + } + } + } + + // check the default url + validateURLs(c.Servers(), "nats://localhost:4222") + if len(c.DiscoveredServers()) != 0 { + t.Fatalf("Expected no discovered servers") + } + + // Add a new URL + err := c.parse([]byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n")) + if err != nil { + t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) + } + // Server list should now contain both the default and the new url. + validateURLs(c.Servers(), "nats://localhost:4222", "nats://localhost:5222") + // Discovered servers should only contain the new url. + validateURLs(c.DiscoveredServers(), "nats://localhost:5222") + + // verify user credentials are stripped out. + opts.Servers = []string{"nats://user:pass@localhost:4333", "nats://token@localhost:4444"} + c = &Conn{Opts: opts} + c.ps = &parseState{} + c.setupServerPool() + + validateURLs(c.Servers(), "nats://localhost:4333", "nats://localhost:4444") +} + +func TestProcessErrAuthorizationError(t *testing.T) { + ach := make(chan asyncCB, 1) + called := make(chan error, 1) + c := &Conn{ + ach: ach, + Opts: Options{ + AsyncErrorCB: func(nc *Conn, sub *Subscription, err error) { + called <- err + }, + }, + } + c.processErr("Authorization Violation") + select { + case cb := <-ach: + cb() + default: + t.Fatal("Expected callback on channel") + } + + select { + case err := <-called: + if err != ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got: %v", err) + } + default: + t.Fatal("Expected error on channel") + } +} diff --git a/vendor/github.com/nats-io/go-nats/netchan.go b/vendor/github.com/nats-io/go-nats/netchan.go new file mode 100644 index 00000000..0608fd7a --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/netchan.go @@ -0,0 +1,100 @@ +// Copyright 2013-2017 Apcera Inc. All rights reserved. + +package nats + +import ( + "errors" + "reflect" +) + +// This allows the functionality for network channels by binding send and receive Go chans +// to subjects and optionally queue groups. +// Data will be encoded and decoded via the EncodedConn and its associated encoders. + +// BindSendChan binds a channel for send operations to NATS. +func (c *EncodedConn) BindSendChan(subject string, channel interface{}) error { + chVal := reflect.ValueOf(channel) + if chVal.Kind() != reflect.Chan { + return ErrChanArg + } + go chPublish(c, chVal, subject) + return nil +} + +// Publish all values that arrive on the channel until it is closed or we +// encounter an error. +func chPublish(c *EncodedConn, chVal reflect.Value, subject string) { + for { + val, ok := chVal.Recv() + if !ok { + // Channel has most likely been closed. + return + } + if e := c.Publish(subject, val.Interface()); e != nil { + // Do this under lock. + c.Conn.mu.Lock() + defer c.Conn.mu.Unlock() + + if c.Conn.Opts.AsyncErrorCB != nil { + // FIXME(dlc) - Not sure this is the right thing to do. + // FIXME(ivan) - If the connection is not yet closed, try to schedule the callback + if c.Conn.isClosed() { + go c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) + } else { + c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) } + } + } + return + } + } +} + +// BindRecvChan binds a channel for receive operations from NATS. +func (c *EncodedConn) BindRecvChan(subject string, channel interface{}) (*Subscription, error) { + return c.bindRecvChan(subject, _EMPTY_, channel) +} + +// BindRecvQueueChan binds a channel for queue-based receive operations from NATS. +func (c *EncodedConn) BindRecvQueueChan(subject, queue string, channel interface{}) (*Subscription, error) { + return c.bindRecvChan(subject, queue, channel) +} + +// Internal function to bind receive operations for a channel. +func (c *EncodedConn) bindRecvChan(subject, queue string, channel interface{}) (*Subscription, error) { + chVal := reflect.ValueOf(channel) + if chVal.Kind() != reflect.Chan { + return nil, ErrChanArg + } + argType := chVal.Type().Elem() + + cb := func(m *Msg) { + var oPtr reflect.Value + if argType.Kind() != reflect.Ptr { + oPtr = reflect.New(argType) + } else { + oPtr = reflect.New(argType.Elem()) + } + if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { + c.Conn.err = errors.New("nats: Got an error trying to unmarshal: " + err.Error()) + if c.Conn.Opts.AsyncErrorCB != nil { + c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) } + } + return + } + if argType.Kind() != reflect.Ptr { + oPtr = reflect.Indirect(oPtr) + } + // This is a bit hacky, but in this instance we may be trying to send to a closed channel. + // and the user does not know when it is safe to close the channel. + defer func() { + // If we have panicked, recover and close the subscription. + if r := recover(); r != nil { + m.Sub.Unsubscribe() + } + }() + // Actually do the send to the channel. + chVal.Send(oPtr) + } + + return c.Conn.subscribe(subject, queue, cb, nil) +} diff --git a/vendor/github.com/nats-io/go-nats/parser.go b/vendor/github.com/nats-io/go-nats/parser.go new file mode 100644 index 00000000..8359b8bc --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/parser.go @@ -0,0 +1,470 @@ +// Copyright 2012-2017 Apcera Inc. All rights reserved. + +package nats + +import ( + "fmt" +) + +type msgArg struct { + subject []byte + reply []byte + sid int64 + size int +} + +const MAX_CONTROL_LINE_SIZE = 1024 + +type parseState struct { + state int + as int + drop int + ma msgArg + argBuf []byte + msgBuf []byte + scratch [MAX_CONTROL_LINE_SIZE]byte +} + +const ( + OP_START = iota + OP_PLUS + OP_PLUS_O + OP_PLUS_OK + OP_MINUS + OP_MINUS_E + OP_MINUS_ER + OP_MINUS_ERR + OP_MINUS_ERR_SPC + MINUS_ERR_ARG + OP_M + OP_MS + OP_MSG + OP_MSG_SPC + MSG_ARG + MSG_PAYLOAD + MSG_END + OP_P + OP_PI + OP_PIN + OP_PING + OP_PO + OP_PON + OP_PONG + OP_I + OP_IN + OP_INF + OP_INFO + OP_INFO_SPC + INFO_ARG +) + +// parse is the fast protocol parser engine. +func (nc *Conn) parse(buf []byte) error { + var i int + var b byte + + // Move to loop instead of range syntax to allow jumping of i + for i = 0; i < len(buf); i++ { + b = buf[i] + + switch nc.ps.state { + case OP_START: + switch b { + case 'M', 'm': + nc.ps.state = OP_M + case 'P', 'p': + nc.ps.state = OP_P + case '+': + nc.ps.state = OP_PLUS + case '-': + nc.ps.state = OP_MINUS + case 'I', 'i': + nc.ps.state = OP_I + default: + goto parseErr + } + case OP_M: + switch b { + case 'S', 's': + nc.ps.state = OP_MS + default: + goto parseErr + } + case OP_MS: + switch b { + case 'G', 'g': + nc.ps.state = OP_MSG + default: + goto parseErr + } + case OP_MSG: + switch b { + case ' ', '\t': + nc.ps.state = OP_MSG_SPC + default: + goto parseErr + } + case OP_MSG_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = MSG_ARG + nc.ps.as = i + } + case MSG_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + if err := nc.processMsgArgs(arg); err != nil { + return err + } + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, MSG_PAYLOAD + + // jump ahead with the index. If this overruns + // what is left we fall out and process split + // buffer. + i = nc.ps.as + nc.ps.ma.size - 1 + default: + if nc.ps.argBuf != nil { + nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + case MSG_PAYLOAD: + if nc.ps.msgBuf != nil { + if len(nc.ps.msgBuf) >= nc.ps.ma.size { + nc.processMsg(nc.ps.msgBuf) + nc.ps.argBuf, nc.ps.msgBuf, nc.ps.state = nil, nil, MSG_END + } else { + // copy as much as we can to the buffer and skip ahead. + toCopy := nc.ps.ma.size - len(nc.ps.msgBuf) + avail := len(buf) - i + + if avail < toCopy { + toCopy = avail + } + + if toCopy > 0 { + start := len(nc.ps.msgBuf) + // This is needed for copy to work. + nc.ps.msgBuf = nc.ps.msgBuf[:start+toCopy] + copy(nc.ps.msgBuf[start:], buf[i:i+toCopy]) + // Update our index + i = (i + toCopy) - 1 + } else { + nc.ps.msgBuf = append(nc.ps.msgBuf, b) + } + } + } else if i-nc.ps.as >= nc.ps.ma.size { + nc.processMsg(buf[nc.ps.as:i]) + nc.ps.argBuf, nc.ps.msgBuf, nc.ps.state = nil, nil, MSG_END + } + case MSG_END: + switch b { + case '\n': + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + continue + } + case OP_PLUS: + switch b { + case 'O', 'o': + nc.ps.state = OP_PLUS_O + default: + goto parseErr + } + case OP_PLUS_O: + switch b { + case 'K', 'k': + nc.ps.state = OP_PLUS_OK + default: + goto parseErr + } + case OP_PLUS_OK: + switch b { + case '\n': + nc.processOK() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_MINUS: + switch b { + case 'E', 'e': + nc.ps.state = OP_MINUS_E + default: + goto parseErr + } + case OP_MINUS_E: + switch b { + case 'R', 'r': + nc.ps.state = OP_MINUS_ER + default: + goto parseErr + } + case OP_MINUS_ER: + switch b { + case 'R', 'r': + nc.ps.state = OP_MINUS_ERR + default: + goto parseErr + } + case OP_MINUS_ERR: + switch b { + case ' ', '\t': + nc.ps.state = OP_MINUS_ERR_SPC + default: + goto parseErr + } + case OP_MINUS_ERR_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = MINUS_ERR_ARG + nc.ps.as = i + } + case MINUS_ERR_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + nc.ps.argBuf = nil + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + nc.processErr(string(arg)) + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + if nc.ps.argBuf != nil { + nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + case OP_P: + switch b { + case 'I', 'i': + nc.ps.state = OP_PI + case 'O', 'o': + nc.ps.state = OP_PO + default: + goto parseErr + } + case OP_PO: + switch b { + case 'N', 'n': + nc.ps.state = OP_PON + default: + goto parseErr + } + case OP_PON: + switch b { + case 'G', 'g': + nc.ps.state = OP_PONG + default: + goto parseErr + } + case OP_PONG: + switch b { + case '\n': + nc.processPong() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_PI: + switch b { + case 'N', 'n': + nc.ps.state = OP_PIN + default: + goto parseErr + } + case OP_PIN: + switch b { + case 'G', 'g': + nc.ps.state = OP_PING + default: + goto parseErr + } + case OP_PING: + switch b { + case '\n': + nc.processPing() + nc.ps.drop, nc.ps.state = 0, OP_START + } + case OP_I: + switch b { + case 'N', 'n': + nc.ps.state = OP_IN + default: + goto parseErr + } + case OP_IN: + switch b { + case 'F', 'f': + nc.ps.state = OP_INF + default: + goto parseErr + } + case OP_INF: + switch b { + case 'O', 'o': + nc.ps.state = OP_INFO + default: + goto parseErr + } + case OP_INFO: + switch b { + case ' ', '\t': + nc.ps.state = OP_INFO_SPC + default: + goto parseErr + } + case OP_INFO_SPC: + switch b { + case ' ', '\t': + continue + default: + nc.ps.state = INFO_ARG + nc.ps.as = i + } + case INFO_ARG: + switch b { + case '\r': + nc.ps.drop = 1 + case '\n': + var arg []byte + if nc.ps.argBuf != nil { + arg = nc.ps.argBuf + nc.ps.argBuf = nil + } else { + arg = buf[nc.ps.as : i-nc.ps.drop] + } + nc.processAsyncInfo(arg) + nc.ps.drop, nc.ps.as, nc.ps.state = 0, i+1, OP_START + default: + if nc.ps.argBuf != nil { + nc.ps.argBuf = append(nc.ps.argBuf, b) + } + } + default: + goto parseErr + } + } + // Check for split buffer scenarios + if (nc.ps.state == MSG_ARG || nc.ps.state == MINUS_ERR_ARG || nc.ps.state == INFO_ARG) && nc.ps.argBuf == nil { + nc.ps.argBuf = nc.ps.scratch[:0] + nc.ps.argBuf = append(nc.ps.argBuf, buf[nc.ps.as:i-nc.ps.drop]...) + // FIXME, check max len + } + // Check for split msg + if nc.ps.state == MSG_PAYLOAD && nc.ps.msgBuf == nil { + // We need to clone the msgArg if it is still referencing the + // read buffer and we are not able to process the msg. + if nc.ps.argBuf == nil { + nc.cloneMsgArg() + } + + // If we will overflow the scratch buffer, just create a + // new buffer to hold the split message. + if nc.ps.ma.size > cap(nc.ps.scratch)-len(nc.ps.argBuf) { + lrem := len(buf[nc.ps.as:]) + + nc.ps.msgBuf = make([]byte, lrem, nc.ps.ma.size) + copy(nc.ps.msgBuf, buf[nc.ps.as:]) + } else { + nc.ps.msgBuf = nc.ps.scratch[len(nc.ps.argBuf):len(nc.ps.argBuf)] + nc.ps.msgBuf = append(nc.ps.msgBuf, (buf[nc.ps.as:])...) + } + } + + return nil + +parseErr: + return fmt.Errorf("nats: Parse Error [%d]: '%s'", nc.ps.state, buf[i:]) +} + +// cloneMsgArg is used when the split buffer scenario has the pubArg in the existing read buffer, but +// we need to hold onto it into the next read. +func (nc *Conn) cloneMsgArg() { + nc.ps.argBuf = nc.ps.scratch[:0] + nc.ps.argBuf = append(nc.ps.argBuf, nc.ps.ma.subject...) + nc.ps.argBuf = append(nc.ps.argBuf, nc.ps.ma.reply...) + nc.ps.ma.subject = nc.ps.argBuf[:len(nc.ps.ma.subject)] + if nc.ps.ma.reply != nil { + nc.ps.ma.reply = nc.ps.argBuf[len(nc.ps.ma.subject):] + } +} + +const argsLenMax = 4 + +func (nc *Conn) processMsgArgs(arg []byte) error { + // Unroll splitArgs to avoid runtime/heap issues + a := [argsLenMax][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 3: + nc.ps.ma.subject = args[0] + nc.ps.ma.sid = parseInt64(args[1]) + nc.ps.ma.reply = nil + nc.ps.ma.size = int(parseInt64(args[2])) + case 4: + nc.ps.ma.subject = args[0] + nc.ps.ma.sid = parseInt64(args[1]) + nc.ps.ma.reply = args[2] + nc.ps.ma.size = int(parseInt64(args[3])) + default: + return fmt.Errorf("nats: processMsgArgs Parse Error: '%s'", arg) + } + if nc.ps.ma.sid < 0 { + return fmt.Errorf("nats: processMsgArgs Bad or Missing Sid: '%s'", arg) + } + if nc.ps.ma.size < 0 { + return fmt.Errorf("nats: processMsgArgs Bad or Missing Size: '%s'", arg) + } + return nil +} + +// Ascii numbers 0-9 +const ( + ascii_0 = 48 + ascii_9 = 57 +) + +// parseInt64 expects decimal positive numbers. We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < ascii_0 || dec > ascii_9 { + return -1 + } + n = n*10 + (int64(dec) - ascii_0) + } + return n +} diff --git a/vendor/github.com/nats-io/go-nats/staticcheck.ignore b/vendor/github.com/nats-io/go-nats/staticcheck.ignore new file mode 100644 index 00000000..25bbf020 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/staticcheck.ignore @@ -0,0 +1,4 @@ +github.com/nats-io/go-nats/*_test.go:SA2002 +github.com/nats-io/go-nats/*/*_test.go:SA2002 +github.com/nats-io/go-nats/test/context_test.go:SA1012 +github.com/nats-io/go-nats/nats.go:SA6000 diff --git a/vendor/github.com/nats-io/go-nats/timer.go b/vendor/github.com/nats-io/go-nats/timer.go new file mode 100644 index 00000000..1b96fd52 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/timer.go @@ -0,0 +1,43 @@ +package nats + +import ( + "sync" + "time" +) + +// global pool of *time.Timer's. can be used by multiple goroutines concurrently. +var globalTimerPool timerPool + +// timerPool provides GC-able pooling of *time.Timer's. +// can be used by multiple goroutines concurrently. +type timerPool struct { + p sync.Pool +} + +// Get returns a timer that completes after the given duration. +func (tp *timerPool) Get(d time.Duration) *time.Timer { + if t, _ := tp.p.Get().(*time.Timer); t != nil { + t.Reset(d) + return t + } + + return time.NewTimer(d) +} + +// Put pools the given timer. +// +// There is no need to call t.Stop() before calling Put. +// +// Put will try to stop the timer before pooling. If the +// given timer already expired, Put will read the unreceived +// value if there is one. +func (tp *timerPool) Put(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + + tp.p.Put(t) +} diff --git a/vendor/github.com/nats-io/go-nats/timer_test.go b/vendor/github.com/nats-io/go-nats/timer_test.go new file mode 100644 index 00000000..fb02a769 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/timer_test.go @@ -0,0 +1,29 @@ +package nats + +import ( + "testing" + "time" +) + +func TestTimerPool(t *testing.T) { + var tp timerPool + + for i := 0; i < 10; i++ { + tm := tp.Get(time.Millisecond * 20) + + select { + case <-tm.C: + t.Errorf("Timer already expired") + continue + default: + } + + select { + case <-tm.C: + case <-time.After(time.Millisecond * 100): + t.Errorf("Timer didn't expire in time") + } + + tp.Put(tm) + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json new file mode 100644 index 00000000..84f657a3 --- /dev/null +++ b/vendor/vendor.json @@ -0,0 +1,13 @@ +{ + "comment": "", + "ignore": "", + "package": [ + { + "checksumSHA1": "nWIa0L7ux21Cb8kzB4rJHXMblpI=", + "path": "github.com/nats-io/go-nats", + "revision": "f0d9c5988d4c2a17ad466fcdffe010165c46434e", + "revisionTime": "2017-11-14T23:23:38Z" + } + ], + "rootPath": "github.com/tidwall/tile38" +} From 2f076524d4029b2c3472485fef3718dcfc2322c1 Mon Sep 17 00:00:00 2001 From: Lenny-Campino Hartmann Date: Tue, 7 Aug 2018 21:24:46 +0200 Subject: [PATCH 3/3] fixed vendor fetch --- vendor/github.com/nats-io/gnatsd/LICENSE | 201 ++ vendor/github.com/nats-io/gnatsd/conf/lex.go | 1141 ++++++++++ .../nats-io/gnatsd/conf/lex_test.go | 1224 +++++++++++ .../github.com/nats-io/gnatsd/conf/parse.go | 295 +++ .../nats-io/gnatsd/conf/parse_test.go | 275 +++ .../nats-io/gnatsd/conf/simple.conf | 6 + .../github.com/nats-io/gnatsd/logger/log.go | 152 ++ .../nats-io/gnatsd/logger/log_test.go | 185 ++ .../nats-io/gnatsd/logger/syslog.go | 123 ++ .../nats-io/gnatsd/logger/syslog_test.go | 235 ++ .../nats-io/gnatsd/logger/syslog_windows.go | 104 + .../gnatsd/logger/syslog_windows_test.go | 139 ++ .../github.com/nats-io/gnatsd/server/auth.go | 249 +++ .../nats-io/gnatsd/server/auth_test.go | 99 + .../nats-io/gnatsd/server/ciphersuites.go | 97 + .../nats-io/gnatsd/server/client.go | 1810 +++++++++++++++ .../nats-io/gnatsd/server/client_test.go | 1096 ++++++++++ .../gnatsd/server/closed_conns_test.go | 361 +++ .../github.com/nats-io/gnatsd/server/const.go | 124 ++ .../nats-io/gnatsd/server/errors.go | 51 + .../github.com/nats-io/gnatsd/server/log.go | 184 ++ .../nats-io/gnatsd/server/log_test.go | 175 ++ .../nats-io/gnatsd/server/monitor.go | 1029 +++++++++ .../gnatsd/server/monitor_sort_opts.go | 147 ++ .../nats-io/gnatsd/server/monitor_test.go | 1933 +++++++++++++++++ .../nats-io/gnatsd/server/norace_test.go | 87 + .../github.com/nats-io/gnatsd/server/opts.go | 1259 +++++++++++ .../nats-io/gnatsd/server/opts_test.go | 1037 +++++++++ .../nats-io/gnatsd/server/parser.go | 749 +++++++ .../nats-io/gnatsd/server/parser_test.go | 546 +++++ .../nats-io/gnatsd/server/ping_test.go | 44 + .../nats-io/gnatsd/server/pse/pse_darwin.go | 34 + .../nats-io/gnatsd/server/pse/pse_freebsd.go | 83 + .../nats-io/gnatsd/server/pse/pse_linux.go | 126 ++ .../nats-io/gnatsd/server/pse/pse_openbsd.go | 36 + .../nats-io/gnatsd/server/pse/pse_rumprun.go | 25 + .../nats-io/gnatsd/server/pse/pse_solaris.go | 23 + .../nats-io/gnatsd/server/pse/pse_test.go | 67 + .../nats-io/gnatsd/server/pse/pse_windows.go | 280 +++ .../gnatsd/server/pse/pse_windows_test.go | 97 + .../nats-io/gnatsd/server/reload.go | 714 ++++++ .../nats-io/gnatsd/server/reload_test.go | 1901 ++++++++++++++++ .../github.com/nats-io/gnatsd/server/ring.go | 75 + .../github.com/nats-io/gnatsd/server/route.go | 1103 ++++++++++ .../nats-io/gnatsd/server/routes_test.go | 1000 +++++++++ .../nats-io/gnatsd/server/server.go | 1420 ++++++++++++ .../nats-io/gnatsd/server/server_test.go | 700 ++++++ .../nats-io/gnatsd/server/service.go | 28 + .../nats-io/gnatsd/server/service_test.go | 53 + .../nats-io/gnatsd/server/service_windows.go | 121 ++ .../nats-io/gnatsd/server/signal.go | 158 ++ .../nats-io/gnatsd/server/signal_test.go | 314 +++ .../nats-io/gnatsd/server/signal_windows.go | 101 + .../nats-io/gnatsd/server/split_test.go | 517 +++++ .../nats-io/gnatsd/server/sublist.go | 775 +++++++ .../nats-io/gnatsd/server/sublist_test.go | 1019 +++++++++ .../github.com/nats-io/gnatsd/server/util.go | 111 + .../nats-io/gnatsd/server/util_test.go | 168 ++ .../nats-io/gnatsd/test/auth_test.go | 236 ++ .../nats-io/gnatsd/test/bench_results.txt | 79 + .../nats-io/gnatsd/test/bench_test.go | 674 ++++++ .../nats-io/gnatsd/test/client_auth_test.go | 92 + .../gnatsd/test/client_cluster_test.go | 377 ++++ .../nats-io/gnatsd/test/cluster_test.go | 491 +++++ .../nats-io/gnatsd/test/cluster_tls_test.go | 64 + .../nats-io/gnatsd/test/fanout_test.go | 128 ++ .../nats-io/gnatsd/test/gosrv_test.go | 62 + .../nats-io/gnatsd/test/maxpayload_test.go | 99 + .../nats-io/gnatsd/test/monitor_test.go | 739 +++++++ .../nats-io/gnatsd/test/opts_test.go | 43 + .../nats-io/gnatsd/test/pedantic_test.go | 109 + .../nats-io/gnatsd/test/pid_test.go | 55 + .../nats-io/gnatsd/test/ping_test.go | 189 ++ .../nats-io/gnatsd/test/port_test.go | 52 + .../nats-io/gnatsd/test/ports_test.go | 195 ++ .../nats-io/gnatsd/test/proto_test.go | 317 +++ .../gnatsd/test/route_discovery_test.go | 665 ++++++ .../nats-io/gnatsd/test/routes_test.go | 1052 +++++++++ vendor/github.com/nats-io/gnatsd/test/test.go | 378 ++++ .../nats-io/gnatsd/test/test_test.go | 72 + .../nats-io/gnatsd/test/tls_test.go | 334 +++ .../gnatsd/test/user_authorization_test.go | 101 + .../nats-io/gnatsd/test/verbose_test.go | 82 + .../nats-io/gnatsd/util/gnatsd.service | 16 + vendor/github.com/nats-io/gnatsd/util/tls.go | 25 + .../nats-io/gnatsd/util/tls_pre17.go | 47 + .../nats-io/gnatsd/util/tls_pre18.go | 49 + .../github.com/nats-io/go-nats/GOVERNANCE.md | 3 + vendor/github.com/nats-io/go-nats/LICENSE | 213 +- .../github.com/nats-io/go-nats/MAINTAINERS.md | 10 + vendor/github.com/nats-io/go-nats/README.md | 26 +- vendor/github.com/nats-io/go-nats/context.go | 13 +- vendor/github.com/nats-io/go-nats/enc.go | 17 +- vendor/github.com/nats-io/go-nats/enc_test.go | 13 + .../go-nats/encoders/builtin/default_enc.go | 117 + .../go-nats/encoders/builtin/gob_enc.go | 45 + .../go-nats/encoders/builtin/json_enc.go | 56 + .../go-nats/encoders/protobuf/protobuf_enc.go | 73 + .../encoders/protobuf/testdata/pbtest.pb.go | 40 + .../encoders/protobuf/testdata/pbtest.proto | 11 + .../nats-io/go-nats/example_test.go | 13 + vendor/github.com/nats-io/go-nats/nats.go | 537 +++-- .../github.com/nats-io/go-nats/nats_test.go | 265 ++- vendor/github.com/nats-io/go-nats/netchan.go | 17 +- vendor/github.com/nats-io/go-nats/parser.go | 13 +- vendor/github.com/nats-io/go-nats/timer.go | 13 + .../github.com/nats-io/go-nats/timer_test.go | 13 + vendor/github.com/nats-io/go-nats/util/tls.go | 27 + .../nats-io/go-nats/util/tls_go17.go | 49 + vendor/github.com/nats-io/nuid/GOVERNANCE.md | 3 + vendor/github.com/nats-io/nuid/LICENSE | 201 ++ vendor/github.com/nats-io/nuid/MAINTAINERS.md | 6 + vendor/github.com/nats-io/nuid/README.md | 47 + vendor/github.com/nats-io/nuid/nuid.go | 135 ++ vendor/github.com/nats-io/nuid/nuid_test.go | 92 + vendor/github.com/nats-io/nuid/unique_test.go | 32 + vendor/vendor.json | 72 +- 117 files changed, 34904 insertions(+), 296 deletions(-) create mode 100644 vendor/github.com/nats-io/gnatsd/LICENSE create mode 100644 vendor/github.com/nats-io/gnatsd/conf/lex.go create mode 100644 vendor/github.com/nats-io/gnatsd/conf/lex_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/conf/parse.go create mode 100644 vendor/github.com/nats-io/gnatsd/conf/parse_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/conf/simple.conf create mode 100644 vendor/github.com/nats-io/gnatsd/logger/log.go create mode 100644 vendor/github.com/nats-io/gnatsd/logger/log_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/logger/syslog.go create mode 100644 vendor/github.com/nats-io/gnatsd/logger/syslog_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go create mode 100755 vendor/github.com/nats-io/gnatsd/logger/syslog_windows_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/auth.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/auth_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/ciphersuites.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/client.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/client_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/closed_conns_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/const.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/errors.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/log.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/log_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/monitor.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/monitor_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/norace_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/opts.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/opts_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/parser.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/parser_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/ping_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/pse/pse_windows_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/reload.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/reload_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/ring.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/route.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/routes_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/server.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/server_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/service.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/service_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/service_windows.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/signal.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/signal_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/signal_windows.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/split_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/sublist.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/sublist_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/util.go create mode 100644 vendor/github.com/nats-io/gnatsd/server/util_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/auth_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/bench_results.txt create mode 100644 vendor/github.com/nats-io/gnatsd/test/bench_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/client_auth_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/client_cluster_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/cluster_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/cluster_tls_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/fanout_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/gosrv_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/maxpayload_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/monitor_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/opts_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/pedantic_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/pid_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/ping_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/port_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/ports_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/proto_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/route_discovery_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/routes_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/test_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/tls_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/user_authorization_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/test/verbose_test.go create mode 100644 vendor/github.com/nats-io/gnatsd/util/gnatsd.service create mode 100644 vendor/github.com/nats-io/gnatsd/util/tls.go create mode 100644 vendor/github.com/nats-io/gnatsd/util/tls_pre17.go create mode 100644 vendor/github.com/nats-io/gnatsd/util/tls_pre18.go create mode 100644 vendor/github.com/nats-io/go-nats/GOVERNANCE.md create mode 100644 vendor/github.com/nats-io/go-nats/MAINTAINERS.md create mode 100644 vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go create mode 100644 vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go create mode 100644 vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go create mode 100644 vendor/github.com/nats-io/go-nats/encoders/protobuf/protobuf_enc.go create mode 100644 vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.pb.go create mode 100644 vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.proto create mode 100644 vendor/github.com/nats-io/go-nats/util/tls.go create mode 100644 vendor/github.com/nats-io/go-nats/util/tls_go17.go create mode 100644 vendor/github.com/nats-io/nuid/GOVERNANCE.md create mode 100644 vendor/github.com/nats-io/nuid/LICENSE create mode 100644 vendor/github.com/nats-io/nuid/MAINTAINERS.md create mode 100644 vendor/github.com/nats-io/nuid/README.md create mode 100644 vendor/github.com/nats-io/nuid/nuid.go create mode 100644 vendor/github.com/nats-io/nuid/nuid_test.go create mode 100644 vendor/github.com/nats-io/nuid/unique_test.go diff --git a/vendor/github.com/nats-io/gnatsd/LICENSE b/vendor/github.com/nats-io/gnatsd/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/nats-io/gnatsd/conf/lex.go b/vendor/github.com/nats-io/gnatsd/conf/lex.go new file mode 100644 index 00000000..f9603a99 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/lex.go @@ -0,0 +1,1141 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Customized heavily from +// https://github.com/BurntSushi/toml/blob/master/lex.go, which is based on +// Rob Pike's talk: http://cuddle.googlecode.com/hg/talk/lex.html + +// The format supported is less restrictive than today's formats. +// Supports mixed Arrays [], nested Maps {}, multiple comment types (# and //) +// Also supports key value assigments using '=' or ':' or whiteSpace() +// e.g. foo = 2, foo : 2, foo 2 +// maps can be assigned with no key separator as well +// semicolons as value terminators in key/value assignments are optional +// +// see lex_test.go for more examples. + +package conf + +import ( + "encoding/hex" + "fmt" + "strings" + "unicode" + "unicode/utf8" +) + +type itemType int + +const ( + itemError itemType = iota + itemNIL // used in the parser to indicate no type + itemEOF + itemKey + itemText + itemString + itemBool + itemInteger + itemFloat + itemDatetime + itemArrayStart + itemArrayEnd + itemMapStart + itemMapEnd + itemCommentStart + itemVariable + itemInclude +) + +const ( + eof = 0 + mapStart = '{' + mapEnd = '}' + keySepEqual = '=' + keySepColon = ':' + arrayStart = '[' + arrayEnd = ']' + arrayValTerm = ',' + mapValTerm = ',' + commentHashStart = '#' + commentSlashStart = '/' + dqStringStart = '"' + dqStringEnd = '"' + sqStringStart = '\'' + sqStringEnd = '\'' + optValTerm = ';' + topOptStart = '{' + topOptValTerm = ',' + topOptTerm = '}' + blockStart = '(' + blockEnd = ')' +) + +type stateFn func(lx *lexer) stateFn + +type lexer struct { + input string + start int + pos int + width int + line int + state stateFn + items chan item + + // A stack of state functions used to maintain context. + // The idea is to reuse parts of the state machine in various places. + // For example, values can appear at the top level or within arbitrarily + // nested arrays. The last state on the stack is used after a value has + // been lexed. Similarly for comments. + stack []stateFn + + // Used for processing escapable substrings in double-quoted and raw strings + stringParts []string + stringStateFn stateFn +} + +type item struct { + typ itemType + val string + line int +} + +func (lx *lexer) nextItem() item { + for { + select { + case item := <-lx.items: + return item + default: + lx.state = lx.state(lx) + } + } +} + +func lex(input string) *lexer { + lx := &lexer{ + input: input, + state: lexTop, + line: 1, + items: make(chan item, 10), + stack: make([]stateFn, 0, 10), + stringParts: []string{}, + } + return lx +} + +func (lx *lexer) push(state stateFn) { + lx.stack = append(lx.stack, state) +} + +func (lx *lexer) pop() stateFn { + if len(lx.stack) == 0 { + return lx.errorf("BUG in lexer: no states to pop.") + } + li := len(lx.stack) - 1 + last := lx.stack[li] + lx.stack = lx.stack[0:li] + return last +} + +func (lx *lexer) emit(typ itemType) { + lx.items <- item{typ, strings.Join(lx.stringParts, "") + lx.input[lx.start:lx.pos], lx.line} + lx.start = lx.pos +} + +func (lx *lexer) emitString() { + var finalString string + if len(lx.stringParts) > 0 { + finalString = strings.Join(lx.stringParts, "") + lx.input[lx.start:lx.pos] + lx.stringParts = []string{} + } else { + finalString = lx.input[lx.start:lx.pos] + } + lx.items <- item{itemString, finalString, lx.line} + lx.start = lx.pos +} + +func (lx *lexer) addCurrentStringPart(offset int) { + lx.stringParts = append(lx.stringParts, lx.input[lx.start:lx.pos-offset]) + lx.start = lx.pos +} + +func (lx *lexer) addStringPart(s string) stateFn { + lx.stringParts = append(lx.stringParts, s) + lx.start = lx.pos + return lx.stringStateFn +} + +func (lx *lexer) hasEscapedParts() bool { + return len(lx.stringParts) > 0 +} + +func (lx *lexer) next() (r rune) { + if lx.pos >= len(lx.input) { + lx.width = 0 + return eof + } + + if lx.input[lx.pos] == '\n' { + lx.line++ + } + r, lx.width = utf8.DecodeRuneInString(lx.input[lx.pos:]) + lx.pos += lx.width + return r +} + +// ignore skips over the pending input before this point. +func (lx *lexer) ignore() { + lx.start = lx.pos +} + +// backup steps back one rune. Can be called only once per call of next. +func (lx *lexer) backup() { + lx.pos -= lx.width + if lx.pos < len(lx.input) && lx.input[lx.pos] == '\n' { + lx.line-- + } +} + +// peek returns but does not consume the next rune in the input. +func (lx *lexer) peek() rune { + r := lx.next() + lx.backup() + return r +} + +// errorf stops all lexing by emitting an error and returning `nil`. +// Note that any value that is a character is escaped if it's a special +// character (new lines, tabs, etc.). +func (lx *lexer) errorf(format string, values ...interface{}) stateFn { + for i, value := range values { + if v, ok := value.(rune); ok { + values[i] = escapeSpecial(v) + } + } + lx.items <- item{ + itemError, + fmt.Sprintf(format, values...), + lx.line, + } + return nil +} + +// lexTop consumes elements at the top level of data structure. +func lexTop(lx *lexer) stateFn { + r := lx.next() + if unicode.IsSpace(r) { + return lexSkip(lx, lexTop) + } + + switch r { + case topOptStart: + return lexSkip(lx, lexTop) + case commentHashStart: + lx.push(lexTop) + return lexCommentStart + case commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexTop) + return lexCommentStart + } + lx.backup() + fallthrough + case eof: + if lx.pos > lx.start { + return lx.errorf("Unexpected EOF.") + } + lx.emit(itemEOF) + return nil + } + + // At this point, the only valid item can be a key, so we back up + // and let the key lexer do the rest. + lx.backup() + lx.push(lexTopValueEnd) + return lexKeyStart +} + +// lexTopValueEnd is entered whenever a top-level value has been consumed. +// It must see only whitespace, and will turn back to lexTop upon a new line. +// If it sees EOF, it will quit the lexer successfully. +func lexTopValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case r == commentHashStart: + // a comment will read to a new line for us. + lx.push(lexTop) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexTop) + return lexCommentStart + } + lx.backup() + fallthrough + case isWhitespace(r): + return lexTopValueEnd + case isNL(r) || r == eof || r == optValTerm || r == topOptValTerm || r == topOptTerm: + lx.ignore() + return lexTop + } + return lx.errorf("Expected a top-level value to end with a new line, "+ + "comment or EOF, but got '%v' instead.", r) +} + +// lexKeyStart consumes a key name up until the first non-whitespace character. +// lexKeyStart will ignore whitespace. It will also eat enclosing quotes. +func lexKeyStart(lx *lexer) stateFn { + r := lx.peek() + switch { + case isKeySeparator(r): + return lx.errorf("Unexpected key separator '%v'", r) + case unicode.IsSpace(r): + lx.next() + return lexSkip(lx, lexKeyStart) + case r == dqStringStart: + lx.next() + return lexSkip(lx, lexDubQuotedKey) + case r == sqStringStart: + lx.next() + return lexSkip(lx, lexQuotedKey) + } + lx.ignore() + lx.next() + return lexKey +} + +// lexDubQuotedKey consumes the text of a key between quotes. +func lexDubQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == dqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexKeyEnd) + } + lx.next() + return lexDubQuotedKey +} + +// lexQuotedKey consumes the text of a key between quotes. +func lexQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == sqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexKeyEnd) + } + lx.next() + return lexQuotedKey +} + +// keyCheckKeyword will check for reserved keywords as the key value when the key is +// separated with a space. +func (lx *lexer) keyCheckKeyword(fallThrough, push stateFn) stateFn { + key := strings.ToLower(lx.input[lx.start:lx.pos]) + switch key { + case "include": + lx.ignore() + if push != nil { + lx.push(push) + } + return lexIncludeStart + } + lx.emit(itemKey) + return fallThrough +} + +// lexIncludeStart will consume the whitespace til the start of the value. +func lexIncludeStart(lx *lexer) stateFn { + r := lx.next() + if isWhitespace(r) { + return lexSkip(lx, lexIncludeStart) + } + lx.backup() + return lexInclude +} + +// lexIncludeQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexIncludeQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeQuotedString +} + +// lexIncludeDubQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexIncludeDubQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == dqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeDubQuotedString +} + +// lexIncludeString consumes the inner contents of a raw string. +func lexIncludeString(lx *lexer) stateFn { + r := lx.next() + switch { + case isNL(r) || r == eof || r == optValTerm || r == mapEnd || isWhitespace(r): + lx.backup() + lx.emit(itemInclude) + return lx.pop() + case r == sqStringEnd: + lx.backup() + lx.emit(itemInclude) + lx.next() + lx.ignore() + return lx.pop() + } + return lexIncludeString +} + +// lexInclude will consume the include value. +func lexInclude(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringStart: + lx.ignore() // ignore the " or ' + return lexIncludeQuotedString + case r == dqStringStart: + lx.ignore() // ignore the " or ' + return lexIncludeDubQuotedString + case r == arrayStart: + return lx.errorf("Expected include value but found start of an array") + case r == mapStart: + return lx.errorf("Expected include value but found start of a map") + case r == blockStart: + return lx.errorf("Expected include value but found start of a block") + case unicode.IsDigit(r), r == '-': + return lx.errorf("Expected include value but found start of a number") + case r == '\\': + return lx.errorf("Expected include value but found escape sequence") + case isNL(r): + return lx.errorf("Expected include value but found new line") + } + lx.backup() + return lexIncludeString +} + +// lexKey consumes the text of a key. Assumes that the first character (which +// is not whitespace) has already been consumed. +func lexKey(lx *lexer) stateFn { + r := lx.peek() + if unicode.IsSpace(r) { + // Spaces signal we could be looking at a keyword, e.g. include. + // Keywords will eat the keyword and set the appropriate return stateFn. + return lx.keyCheckKeyword(lexKeyEnd, nil) + } else if isKeySeparator(r) || r == eof { + lx.emit(itemKey) + return lexKeyEnd + } + lx.next() + return lexKey +} + +// lexKeyEnd consumes the end of a key (up to the key separator). +// Assumes that the first whitespace character after a key (or the '=' or ':' +// separator) has NOT been consumed. +func lexKeyEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexKeyEnd) + case isKeySeparator(r): + return lexSkip(lx, lexValue) + case r == eof: + lx.emit(itemEOF) + return nil + } + // We start the value here + lx.backup() + return lexValue +} + +// lexValue starts the consumption of a value anywhere a value is expected. +// lexValue will ignore whitespace. +// After a value is lexed, the last state on the next is popped and returned. +func lexValue(lx *lexer) stateFn { + // We allow whitespace to precede a value, but NOT new lines. + // In array syntax, the array states are responsible for ignoring new lines. + r := lx.next() + if isWhitespace(r) { + return lexSkip(lx, lexValue) + } + + switch { + case r == arrayStart: + lx.ignore() + lx.emit(itemArrayStart) + return lexArrayValue + case r == mapStart: + lx.ignore() + lx.emit(itemMapStart) + return lexMapKeyStart + case r == sqStringStart: + lx.ignore() // ignore the " or ' + return lexQuotedString + case r == dqStringStart: + lx.ignore() // ignore the " or ' + lx.stringStateFn = lexDubQuotedString + return lexDubQuotedString + case r == '-': + return lexNegNumberStart + case r == blockStart: + lx.ignore() + return lexBlock + case unicode.IsDigit(r): + lx.backup() // avoid an extra state and use the same as above + return lexNumberOrDateOrIPStart + case r == '.': // special error case, be kind to users + return lx.errorf("Floats must start with a digit") + case isNL(r): + return lx.errorf("Expected value but found new line") + } + lx.backup() + lx.stringStateFn = lexString + return lexString +} + +// lexArrayValue consumes one value in an array. It assumes that '[' or ',' +// have already been consumed. All whitespace and new lines are ignored. +func lexArrayValue(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexArrayValue) + case r == commentHashStart: + lx.push(lexArrayValue) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexArrayValue) + return lexCommentStart + } + lx.backup() + fallthrough + case r == arrayValTerm: + return lx.errorf("Unexpected array value terminator '%v'.", arrayValTerm) + case r == arrayEnd: + return lexArrayEnd + } + + lx.backup() + lx.push(lexArrayValueEnd) + return lexValue +} + +// lexArrayValueEnd consumes the cruft between values of an array. Namely, +// it ignores whitespace and expects either a ',' or a ']'. +func lexArrayValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexArrayValueEnd) + case r == commentHashStart: + lx.push(lexArrayValueEnd) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexArrayValueEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case r == arrayValTerm || isNL(r): + return lexSkip(lx, lexArrayValue) // Move onto next + case r == arrayEnd: + return lexArrayEnd + } + return lx.errorf("Expected an array value terminator %q or an array "+ + "terminator %q, but got '%v' instead.", arrayValTerm, arrayEnd, r) +} + +// lexArrayEnd finishes the lexing of an array. It assumes that a ']' has +// just been consumed. +func lexArrayEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemArrayEnd) + return lx.pop() +} + +// lexMapKeyStart consumes a key name up until the first non-whitespace +// character. +// lexMapKeyStart will ignore whitespace. +func lexMapKeyStart(lx *lexer) stateFn { + r := lx.peek() + switch { + case isKeySeparator(r): + return lx.errorf("Unexpected key separator '%v'.", r) + case unicode.IsSpace(r): + lx.next() + return lexSkip(lx, lexMapKeyStart) + case r == mapEnd: + lx.next() + return lexSkip(lx, lexMapEnd) + case r == commentHashStart: + lx.next() + lx.push(lexMapKeyStart) + return lexCommentStart + case r == commentSlashStart: + lx.next() + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexMapKeyStart) + return lexCommentStart + } + lx.backup() + case r == sqStringStart: + lx.next() + return lexSkip(lx, lexMapQuotedKey) + case r == dqStringStart: + lx.next() + return lexSkip(lx, lexMapDubQuotedKey) + } + lx.ignore() + lx.next() + return lexMapKey +} + +// lexMapQuotedKey consumes the text of a key between quotes. +func lexMapQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == sqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexMapKeyEnd) + } + lx.next() + return lexMapQuotedKey +} + +// lexMapQuotedKey consumes the text of a key between quotes. +func lexMapDubQuotedKey(lx *lexer) stateFn { + r := lx.peek() + if r == dqStringEnd { + lx.emit(itemKey) + lx.next() + return lexSkip(lx, lexMapKeyEnd) + } + lx.next() + return lexMapDubQuotedKey +} + +// lexMapKey consumes the text of a key. Assumes that the first character (which +// is not whitespace) has already been consumed. +func lexMapKey(lx *lexer) stateFn { + r := lx.peek() + if unicode.IsSpace(r) { + // Spaces signal we could be looking at a keyword, e.g. include. + // Keywords will eat the keyword and set the appropriate return stateFn. + return lx.keyCheckKeyword(lexMapKeyEnd, lexMapValueEnd) + } else if isKeySeparator(r) { + lx.emit(itemKey) + return lexMapKeyEnd + } + lx.next() + return lexMapKey +} + +// lexMapKeyEnd consumes the end of a key (up to the key separator). +// Assumes that the first whitespace character after a key (or the '=' +// separator) has NOT been consumed. +func lexMapKeyEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexMapKeyEnd) + case isKeySeparator(r): + return lexSkip(lx, lexMapValue) + } + // We start the value here + lx.backup() + return lexMapValue +} + +// lexMapValue consumes one value in a map. It assumes that '{' or ',' +// have already been consumed. All whitespace and new lines are ignored. +// Map values can be separated by ',' or simple NLs. +func lexMapValue(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsSpace(r): + return lexSkip(lx, lexMapValue) + case r == mapValTerm: + return lx.errorf("Unexpected map value terminator %q.", mapValTerm) + case r == mapEnd: + return lexSkip(lx, lexMapEnd) + } + lx.backup() + lx.push(lexMapValueEnd) + return lexValue +} + +// lexMapValueEnd consumes the cruft between values of a map. Namely, +// it ignores whitespace and expects either a ',' or a '}'. +func lexMapValueEnd(lx *lexer) stateFn { + r := lx.next() + switch { + case isWhitespace(r): + return lexSkip(lx, lexMapValueEnd) + case r == commentHashStart: + lx.push(lexMapValueEnd) + return lexCommentStart + case r == commentSlashStart: + rn := lx.next() + if rn == commentSlashStart { + lx.push(lexMapValueEnd) + return lexCommentStart + } + lx.backup() + fallthrough + case r == optValTerm || r == mapValTerm || isNL(r): + return lexSkip(lx, lexMapKeyStart) // Move onto next + case r == mapEnd: + return lexSkip(lx, lexMapEnd) + } + return lx.errorf("Expected a map value terminator %q or a map "+ + "terminator %q, but got '%v' instead.", mapValTerm, mapEnd, r) +} + +// lexMapEnd finishes the lexing of a map. It assumes that a '}' has +// just been consumed. +func lexMapEnd(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemMapEnd) + return lx.pop() +} + +// Checks if the unquoted string was actually a boolean +func (lx *lexer) isBool() bool { + str := strings.ToLower(lx.input[lx.start:lx.pos]) + return str == "true" || str == "false" || + str == "on" || str == "off" || + str == "yes" || str == "no" +} + +// Check if the unquoted string is a variable reference, starting with $. +func (lx *lexer) isVariable() bool { + if lx.input[lx.start] == '$' { + lx.start += 1 + return true + } + return false +} + +// lexQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == sqStringEnd: + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexQuotedString +} + +// lexDubQuotedString consumes the inner contents of a string. It assumes that the +// beginning '"' has already been consumed and ignored. It will not interpret any +// internal contents. +func lexDubQuotedString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + lx.addCurrentStringPart(1) + return lexStringEscape + case r == dqStringEnd: + lx.backup() + lx.emitString() + lx.next() + lx.ignore() + return lx.pop() + } + return lexDubQuotedString +} + +// lexString consumes the inner contents of a raw string. +func lexString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + lx.addCurrentStringPart(1) + return lexStringEscape + // Termination of non-quoted strings + case isNL(r) || r == eof || r == optValTerm || + r == arrayValTerm || r == arrayEnd || r == mapEnd || + isWhitespace(r): + + lx.backup() + if lx.hasEscapedParts() { + lx.emitString() + } else if lx.isBool() { + lx.emit(itemBool) + } else if lx.isVariable() { + lx.emit(itemVariable) + } else { + lx.emitString() + } + return lx.pop() + case r == sqStringEnd: + lx.backup() + lx.emitString() + lx.next() + lx.ignore() + return lx.pop() + } + return lexString +} + +// lexBlock consumes the inner contents as a string. It assumes that the +// beginning '(' has already been consumed and ignored. It will continue +// processing until it finds a ')' on a new line by itself. +func lexBlock(lx *lexer) stateFn { + r := lx.next() + switch { + case r == blockEnd: + lx.backup() + lx.backup() + + // Looking for a ')' character on a line by itself, if the previous + // character isn't a new line, then break so we keep processing the block. + if lx.next() != '\n' { + lx.next() + break + } + lx.next() + + // Make sure the next character is a new line or an eof. We want a ')' on a + // bare line by itself. + switch lx.next() { + case '\n', eof: + lx.backup() + lx.backup() + lx.emit(itemString) + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + return lexBlock +} + +// lexStringEscape consumes an escaped character. It assumes that the preceding +// '\\' has already been consumed. +func lexStringEscape(lx *lexer) stateFn { + r := lx.next() + switch r { + case 'x': + return lexStringBinary + case 't': + return lx.addStringPart("\t") + case 'n': + return lx.addStringPart("\n") + case 'r': + return lx.addStringPart("\r") + case '"': + return lx.addStringPart("\"") + case '\\': + return lx.addStringPart("\\") + } + return lx.errorf("Invalid escape character '%v'. Only the following "+ + "escape characters are allowed: \\xXX, \\t, \\n, \\r, \\\", \\\\.", r) +} + +// lexStringBinary consumes two hexadecimal digits following '\x'. It assumes +// that the '\x' has already been consumed. +func lexStringBinary(lx *lexer) stateFn { + r := lx.next() + if isNL(r) { + return lx.errorf("Expected two hexadecimal digits after '\\x', but hit end of line") + } + r = lx.next() + if isNL(r) { + return lx.errorf("Expected two hexadecimal digits after '\\x', but hit end of line") + } + offset := lx.pos - 2 + byteString, err := hex.DecodeString(lx.input[offset:lx.pos]) + if err != nil { + return lx.errorf("Expected two hexadecimal digits after '\\x', but got '%s'", lx.input[offset:lx.pos]) + } + lx.addStringPart(string(byteString)) + return lx.stringStateFn +} + +// lexNumberOrDateStart consumes either a (positive) integer, a float, a datetime, or IP. +// It assumes that NO negative sign has been consumed, that is triggered above. +func lexNumberOrDateOrIPStart(lx *lexer) stateFn { + r := lx.next() + if !unicode.IsDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } + return lx.errorf("Expected a digit but got '%v'.", r) + } + return lexNumberOrDateOrIP +} + +// lexNumberOrDateOrIP consumes either a (positive) integer, float, datetime or IP. +func lexNumberOrDateOrIP(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '-': + if lx.pos-lx.start != 5 { + return lx.errorf("All ISO8601 dates must be in full Zulu form.") + } + return lexDateAfterYear + case unicode.IsDigit(r): + return lexNumberOrDateOrIP + case r == '.': + return lexFloatStart // Assume float at first, but could be IP + case isNumberSuffix(r): + return lexConvenientNumber + } + + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexConvenientNumber is when we have a suffix, e.g. 1k or 1Mb +func lexConvenientNumber(lx *lexer) stateFn { + r := lx.next() + switch { + case r == 'b' || r == 'B': + return lexConvenientNumber + } + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexDateAfterYear consumes a full Zulu Datetime in ISO8601 format. +// It assumes that "YYYY-" has already been consumed. +func lexDateAfterYear(lx *lexer) stateFn { + formats := []rune{ + // digits are '0'. + // everything else is direct equality. + '0', '0', '-', '0', '0', + 'T', + '0', '0', ':', '0', '0', ':', '0', '0', + 'Z', + } + for _, f := range formats { + r := lx.next() + if f == '0' { + if !unicode.IsDigit(r) { + return lx.errorf("Expected digit in ISO8601 datetime, "+ + "but found '%v' instead.", r) + } + } else if f != r { + return lx.errorf("Expected '%v' in ISO8601 datetime, "+ + "but found '%v' instead.", f, r) + } + } + lx.emit(itemDatetime) + return lx.pop() +} + +// lexNegNumberStart consumes either an integer or a float. It assumes that a +// negative sign has already been read, but that *no* digits have been consumed. +// lexNegNumberStart will move to the appropriate integer or float states. +func lexNegNumberStart(lx *lexer) stateFn { + // we MUST see a digit. Even floats have to start with a digit. + r := lx.next() + if !unicode.IsDigit(r) { + if r == '.' { + return lx.errorf("Floats must start with a digit, not '.'.") + } + return lx.errorf("Expected a digit but got '%v'.", r) + } + return lexNegNumber +} + +// lexNumber consumes a negative integer or a float after seeing the first digit. +func lexNegNumber(lx *lexer) stateFn { + r := lx.next() + switch { + case unicode.IsDigit(r): + return lexNegNumber + case r == '.': + return lexFloatStart + case isNumberSuffix(r): + return lexConvenientNumber + } + lx.backup() + lx.emit(itemInteger) + return lx.pop() +} + +// lexFloatStart starts the consumption of digits of a float after a '.'. +// Namely, at least one digit is required. +func lexFloatStart(lx *lexer) stateFn { + r := lx.next() + if !unicode.IsDigit(r) { + return lx.errorf("Floats must have a digit after the '.', but got "+ + "'%v' instead.", r) + } + return lexFloat +} + +// lexFloat consumes the digits of a float after a '.'. +// Assumes that one digit has been consumed after a '.' already. +func lexFloat(lx *lexer) stateFn { + r := lx.next() + if unicode.IsDigit(r) { + return lexFloat + } + + // Not a digit, if its another '.', need to see if we falsely assumed a float. + if r == '.' { + return lexIPAddr + } + + lx.backup() + lx.emit(itemFloat) + return lx.pop() +} + +// lexIPAddr consumes IP addrs, like 127.0.0.1:4222 +func lexIPAddr(lx *lexer) stateFn { + r := lx.next() + if unicode.IsDigit(r) || r == '.' || r == ':' || r == '-' { + return lexIPAddr + } + lx.backup() + lx.emit(itemString) + return lx.pop() +} + +// lexCommentStart begins the lexing of a comment. It will emit +// itemCommentStart and consume no characters, passing control to lexComment. +func lexCommentStart(lx *lexer) stateFn { + lx.ignore() + lx.emit(itemCommentStart) + return lexComment +} + +// lexComment lexes an entire comment. It assumes that '#' has been consumed. +// It will consume *up to* the first new line character, and pass control +// back to the last state on the stack. +func lexComment(lx *lexer) stateFn { + r := lx.peek() + if isNL(r) || r == eof { + lx.emit(itemText) + return lx.pop() + } + lx.next() + return lexComment +} + +// lexSkip ignores all slurped input and moves on to the next state. +func lexSkip(lx *lexer, nextState stateFn) stateFn { + return func(lx *lexer) stateFn { + lx.ignore() + return nextState + } +} + +// Tests to see if we have a number suffix +func isNumberSuffix(r rune) bool { + return r == 'k' || r == 'K' || r == 'm' || r == 'M' || r == 'g' || r == 'G' +} + +// Tests for both key separators +func isKeySeparator(r rune) bool { + return r == keySepEqual || r == keySepColon +} + +// isWhitespace returns true if `r` is a whitespace character according +// to the spec. +func isWhitespace(r rune) bool { + return r == '\t' || r == ' ' +} + +func isNL(r rune) bool { + return r == '\n' || r == '\r' +} + +func (itype itemType) String() string { + switch itype { + case itemError: + return "Error" + case itemNIL: + return "NIL" + case itemEOF: + return "EOF" + case itemText: + return "Text" + case itemString: + return "String" + case itemBool: + return "Bool" + case itemInteger: + return "Integer" + case itemFloat: + return "Float" + case itemDatetime: + return "DateTime" + case itemKey: + return "Key" + case itemArrayStart: + return "ArrayStart" + case itemArrayEnd: + return "ArrayEnd" + case itemMapStart: + return "MapStart" + case itemMapEnd: + return "MapEnd" + case itemCommentStart: + return "CommentStart" + case itemVariable: + return "Variable" + case itemInclude: + return "Include" + } + panic(fmt.Sprintf("BUG: Unknown type '%s'.", itype.String())) +} + +func (item item) String() string { + return fmt.Sprintf("(%s, '%s', %d)", item.typ.String(), item.val, item.line) +} + +func escapeSpecial(c rune) string { + switch c { + case '\n': + return "\\n" + } + return string(c) +} diff --git a/vendor/github.com/nats-io/gnatsd/conf/lex_test.go b/vendor/github.com/nats-io/gnatsd/conf/lex_test.go new file mode 100644 index 00000000..28cce804 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/lex_test.go @@ -0,0 +1,1224 @@ +package conf + +import "testing" + +// Test to make sure we get what we expect. +func expect(t *testing.T, lx *lexer, items []item) { + for i := 0; i < len(items); i++ { + item := lx.nextItem() + _ = item.String() + if item.typ == itemEOF { + break + } + if item != items[i] { + t.Fatalf("Testing: '%s'\nExpected %q, received %q\n", + lx.input, items[i], item) + } + if item.typ == itemError { + break + } + } +} + +func TestPlainValue(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo") + expect(t, lx, expectedItems) +} + +func TestSimpleKeyStringValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "bar", 1}, + {itemEOF, "", 1}, + } + // Double quotes + lx := lex("foo = \"bar\"") + expect(t, lx, expectedItems) + // Single quotes + lx = lex("foo = 'bar'") + expect(t, lx, expectedItems) + // No spaces + lx = lex("foo='bar'") + expect(t, lx, expectedItems) + // NL + lx = lex("foo='bar'\r\n") + expect(t, lx, expectedItems) + lx = lex("foo=\t'bar'\t") + expect(t, lx, expectedItems) +} + +func TestComplexStringValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "bar\\r\\n \\t", 1}, + {itemEOF, "", 2}, + } + + lx := lex("foo = 'bar\\r\\n \\t'") + expect(t, lx, expectedItems) +} + +func TestBinaryString(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "e", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = \\x65") + expect(t, lx, expectedItems) +} + +func TestBinaryStringLatin1(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "\xe9", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = \\xe9") + expect(t, lx, expectedItems) +} + +func TestSimpleKeyIntegerValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "123", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = 123") + expect(t, lx, expectedItems) + lx = lex("foo=123") + expect(t, lx, expectedItems) + lx = lex("foo=123\r\n") + expect(t, lx, expectedItems) +} + +func TestSimpleKeyNegativeIntegerValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "-123", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = -123") + expect(t, lx, expectedItems) + lx = lex("foo=-123") + expect(t, lx, expectedItems) + lx = lex("foo=-123\r\n") + expect(t, lx, expectedItems) +} + +func TestConvenientIntegerValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "1k", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = 1k") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1K", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1K") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1m", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1m") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1M", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1M") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1g", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1g") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1G", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1G") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1MB", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1MB") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "1Gb", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = 1Gb") + expect(t, lx, expectedItems) + + // Negative versions + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "-1m", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = -1m") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "foo", 1}, + {itemInteger, "-1GB", 1}, + {itemEOF, "", 1}, + } + lx = lex("foo = -1GB ") + expect(t, lx, expectedItems) +} + +func TestSimpleKeyFloatValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemFloat, "22.2", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = 22.2") + expect(t, lx, expectedItems) + lx = lex("foo=22.2") + expect(t, lx, expectedItems) + lx = lex("foo=22.2\r\n") + expect(t, lx, expectedItems) +} + +func TestBadBinaryStringEndingAfterZeroHexChars(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Expected two hexadecimal digits after '\\x', but hit end of line", 2}, + {itemEOF, "", 1}, + } + lx := lex("foo = xyz\\x\n") + expect(t, lx, expectedItems) +} + +func TestBadBinaryStringEndingAfterOneHexChar(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Expected two hexadecimal digits after '\\x', but hit end of line", 2}, + {itemEOF, "", 1}, + } + lx := lex("foo = xyz\\xF\n") + expect(t, lx, expectedItems) +} + +func TestBadBinaryStringWithZeroHexChars(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Expected two hexadecimal digits after '\\x', but got ']\"'", 1}, + {itemEOF, "", 1}, + } + lx := lex(`foo = "[\x]"`) + expect(t, lx, expectedItems) +} + +func TestBadBinaryStringWithOneHexChar(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Expected two hexadecimal digits after '\\x', but got 'e]'", 1}, + {itemEOF, "", 1}, + } + lx := lex(`foo = "[\xe]"`) + expect(t, lx, expectedItems) +} + +func TestBadFloatValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Floats must start with a digit", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = .2") + expect(t, lx, expectedItems) +} + +func TestBadKey(t *testing.T) { + expectedItems := []item{ + {itemError, "Unexpected key separator ':'", 1}, + {itemEOF, "", 1}, + } + lx := lex(" :foo = 22") + expect(t, lx, expectedItems) +} + +func TestSimpleKeyBoolValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemBool, "true", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = true") + expect(t, lx, expectedItems) + lx = lex("foo=true") + expect(t, lx, expectedItems) + lx = lex("foo=true\r\n") + expect(t, lx, expectedItems) +} + +func TestComments(t *testing.T) { + expectedItems := []item{ + {itemCommentStart, "", 1}, + {itemText, " This is a comment", 1}, + {itemEOF, "", 1}, + } + lx := lex("# This is a comment") + expect(t, lx, expectedItems) + lx = lex("# This is a comment\r\n") + expect(t, lx, expectedItems) + lx = lex("// This is a comment\r\n") + expect(t, lx, expectedItems) +} + +func TestTopValuesWithComments(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "123", 1}, + {itemCommentStart, "", 1}, + {itemText, " This is a comment", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo = 123 // This is a comment") + expect(t, lx, expectedItems) + lx = lex("foo=123 # This is a comment") + expect(t, lx, expectedItems) +} + +func TestRawString(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "bar", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo = bar") + expect(t, lx, expectedItems) + + lx = lex(`foo = bar' `) //'single-quote for emacs TODO: Remove me + expect(t, lx, expectedItems) +} + +func TestDateValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemDatetime, "2016-05-04T18:53:41Z", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo = 2016-05-04T18:53:41Z") + expect(t, lx, expectedItems) +} + +func TestVariableValues(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemVariable, "bar", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = $bar") + expect(t, lx, expectedItems) + lx = lex("foo =$bar") + expect(t, lx, expectedItems) + lx = lex("foo $bar") + expect(t, lx, expectedItems) +} + +func TestArrays(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemArrayStart, "", 1}, + {itemInteger, "1", 1}, + {itemInteger, "2", 1}, + {itemInteger, "3", 1}, + {itemString, "bar", 1}, + {itemArrayEnd, "", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = [1, 2, 3, 'bar']") + expect(t, lx, expectedItems) + lx = lex("foo = [1,2,3,'bar']") + expect(t, lx, expectedItems) + lx = lex("foo = [1, 2,3,'bar']") + expect(t, lx, expectedItems) +} + +var mlArray = ` +# top level comment +foo = [ + 1, # One + 2, // Two + 3 # Three + 'bar' , + "bar" +] +` + +func TestMultilineArrays(t *testing.T) { + expectedItems := []item{ + {itemCommentStart, "", 2}, + {itemText, " top level comment", 2}, + {itemKey, "foo", 3}, + {itemArrayStart, "", 3}, + {itemInteger, "1", 4}, + {itemCommentStart, "", 4}, + {itemText, " One", 4}, + {itemInteger, "2", 5}, + {itemCommentStart, "", 5}, + {itemText, " Two", 5}, + {itemInteger, "3", 6}, + {itemCommentStart, "", 6}, + {itemText, " Three", 6}, + {itemString, "bar", 7}, + {itemString, "bar", 8}, + {itemArrayEnd, "", 9}, + {itemEOF, "", 9}, + } + lx := lex(mlArray) + expect(t, lx, expectedItems) +} + +var mlArrayNoSep = ` +# top level comment +foo = [ + 1 // foo + 2 + 3 + 'bar' + "bar" +] +` + +func TestMultilineArraysNoSep(t *testing.T) { + expectedItems := []item{ + {itemCommentStart, "", 2}, + {itemText, " top level comment", 2}, + {itemKey, "foo", 3}, + {itemArrayStart, "", 3}, + {itemInteger, "1", 4}, + {itemCommentStart, "", 4}, + {itemText, " foo", 4}, + {itemInteger, "2", 5}, + {itemInteger, "3", 6}, + {itemString, "bar", 7}, + {itemString, "bar", 8}, + {itemArrayEnd, "", 9}, + {itemEOF, "", 9}, + } + lx := lex(mlArrayNoSep) + expect(t, lx, expectedItems) +} + +func TestSimpleMap(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemMapStart, "", 1}, + {itemKey, "ip", 1}, + {itemString, "127.0.0.1", 1}, + {itemKey, "port", 1}, + {itemInteger, "4242", 1}, + {itemMapEnd, "", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo = {ip='127.0.0.1', port = 4242}") + expect(t, lx, expectedItems) +} + +var mlMap = ` +foo = { + ip = '127.0.0.1' # the IP + port= 4242 // the port +} +` + +func TestMultilineMap(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemMapStart, "", 2}, + {itemKey, "ip", 3}, + {itemString, "127.0.0.1", 3}, + {itemCommentStart, "", 3}, + {itemText, " the IP", 3}, + {itemKey, "port", 4}, + {itemInteger, "4242", 4}, + {itemCommentStart, "", 4}, + {itemText, " the port", 4}, + {itemMapEnd, "", 5}, + {itemEOF, "", 5}, + } + + lx := lex(mlMap) + expect(t, lx, expectedItems) +} + +var nestedMap = ` +foo = { + host = { + ip = '127.0.0.1' + port= 4242 + } +} +` + +func TestNestedMaps(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemMapStart, "", 2}, + {itemKey, "host", 3}, + {itemMapStart, "", 3}, + {itemKey, "ip", 4}, + {itemString, "127.0.0.1", 4}, + {itemKey, "port", 5}, + {itemInteger, "4242", 5}, + {itemMapEnd, "", 6}, + {itemMapEnd, "", 7}, + {itemEOF, "", 5}, + } + + lx := lex(nestedMap) + expect(t, lx, expectedItems) +} + +func TestQuotedKeys(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "123", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo : 123") + expect(t, lx, expectedItems) + lx = lex("'foo' : 123") + expect(t, lx, expectedItems) + lx = lex("\"foo\" : 123") + expect(t, lx, expectedItems) +} + +func TestQuotedKeysWithSpace(t *testing.T) { + expectedItems := []item{ + {itemKey, " foo", 1}, + {itemInteger, "123", 1}, + {itemEOF, "", 1}, + } + lx := lex("' foo' : 123") + expect(t, lx, expectedItems) + lx = lex("\" foo\" : 123") + expect(t, lx, expectedItems) +} + +func TestColonKeySep(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "123", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo : 123") + expect(t, lx, expectedItems) + lx = lex("foo:123") + expect(t, lx, expectedItems) + lx = lex("foo: 123") + expect(t, lx, expectedItems) + lx = lex("foo: 123\r\n") + expect(t, lx, expectedItems) +} + +func TestWhitespaceKeySep(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemInteger, "123", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo 123") + expect(t, lx, expectedItems) + lx = lex("foo 123") + expect(t, lx, expectedItems) + lx = lex("foo\t123") + expect(t, lx, expectedItems) + lx = lex("foo\t\t123\r\n") + expect(t, lx, expectedItems) +} + +var escString = ` +foo = \t +bar = \r +baz = \n +q = \" +bs = \\ +` + +func TestEscapedString(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemString, "\t", 2}, + {itemKey, "bar", 3}, + {itemString, "\r", 3}, + {itemKey, "baz", 4}, + {itemString, "\n", 4}, + {itemKey, "q", 5}, + {itemString, "\"", 5}, + {itemKey, "bs", 6}, + {itemString, "\\", 6}, + {itemEOF, "", 6}, + } + lx := lex(escString) + expect(t, lx, expectedItems) +} + +func TestCompoundStringES(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "\\end", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = "\\end"`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringSE(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "start\\", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = "start\\"`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringEE(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "Eq", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \x45\x71`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringSEE(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "startEq", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = start\x45\x71`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringSES(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "start|end", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = start\x7Cend`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringEES(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "<>end", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \x3c\x3eend`) + expect(t, lx, expectedItems) +} + +func TestCompoundStringESE(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \x3cmiddle\x3E`) + expect(t, lx, expectedItems) +} + +func TestBadStringEscape(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemError, "Invalid escape character 'y'. Only the following escape characters are allowed: \\xXX, \\t, \\n, \\r, \\\", \\\\.", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \y`) + expect(t, lx, expectedItems) +} + +func TestNonBool(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "\\true", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \\true`) + expect(t, lx, expectedItems) +} + +func TestNonVariable(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "\\$var", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = \\$var`) + expect(t, lx, expectedItems) +} + +func TestEmptyStringDQ(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = ""`) + expect(t, lx, expectedItems) +} + +func TestEmptyStringSQ(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "", 1}, + {itemEOF, "", 2}, + } + lx := lex(`foo = ''`) + expect(t, lx, expectedItems) +} + +var nestedWhitespaceMap = ` +foo { + host { + ip = '127.0.0.1' + port= 4242 + } +} +` + +func TestNestedWhitespaceMaps(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemMapStart, "", 2}, + {itemKey, "host", 3}, + {itemMapStart, "", 3}, + {itemKey, "ip", 4}, + {itemString, "127.0.0.1", 4}, + {itemKey, "port", 5}, + {itemInteger, "4242", 5}, + {itemMapEnd, "", 6}, + {itemMapEnd, "", 7}, + {itemEOF, "", 5}, + } + + lx := lex(nestedWhitespaceMap) + expect(t, lx, expectedItems) +} + +var semicolons = ` +foo = 123; +bar = 'baz'; +baz = 'boo' +map { + id = 1; +} +` + +func TestOptionalSemicolons(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemInteger, "123", 2}, + {itemKey, "bar", 3}, + {itemString, "baz", 3}, + {itemKey, "baz", 4}, + {itemString, "boo", 4}, + {itemKey, "map", 5}, + {itemMapStart, "", 5}, + {itemKey, "id", 6}, + {itemInteger, "1", 6}, + {itemMapEnd, "", 7}, + {itemEOF, "", 5}, + } + + lx := lex(semicolons) + expect(t, lx, expectedItems) +} + +func TestSemicolonChaining(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemString, "1", 1}, + {itemKey, "bar", 1}, + {itemFloat, "2.2", 1}, + {itemKey, "baz", 1}, + {itemBool, "true", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo='1'; bar=2.2; baz=true;") + expect(t, lx, expectedItems) +} + +var noquotes = ` +foo = 123 +bar = baz +baz=boo +map { + id:one + id2 : onetwo +} +t true +f false +tstr "true" +tkey = two +fkey = five # This should be a string +` + +func TestNonQuotedStrings(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 2}, + {itemInteger, "123", 2}, + {itemKey, "bar", 3}, + {itemString, "baz", 3}, + {itemKey, "baz", 4}, + {itemString, "boo", 4}, + {itemKey, "map", 5}, + {itemMapStart, "", 5}, + {itemKey, "id", 6}, + {itemString, "one", 6}, + {itemKey, "id2", 7}, + {itemString, "onetwo", 7}, + {itemMapEnd, "", 8}, + {itemKey, "t", 9}, + {itemBool, "true", 9}, + {itemKey, "f", 10}, + {itemBool, "false", 10}, + {itemKey, "tstr", 11}, + {itemString, "true", 11}, + {itemKey, "tkey", 12}, + {itemString, "two", 12}, + {itemKey, "fkey", 13}, + {itemString, "five", 13}, + {itemCommentStart, "", 13}, + {itemText, " This should be a string", 13}, + + {itemEOF, "", 14}, + } + lx := lex(noquotes) + expect(t, lx, expectedItems) +} + +func TestMapQuotedKeys(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemMapStart, "", 1}, + {itemKey, "bar", 1}, + {itemInteger, "4242", 1}, + {itemMapEnd, "", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = {'bar' = 4242}") + expect(t, lx, expectedItems) + lx = lex("foo = {\"bar\" = 4242}") + expect(t, lx, expectedItems) +} + +func TestSpecialCharsMapQuotedKeys(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemMapStart, "", 1}, + {itemKey, "bar-1.2.3", 1}, + {itemMapStart, "", 1}, + {itemKey, "port", 1}, + {itemInteger, "4242", 1}, + {itemMapEnd, "", 1}, + {itemMapEnd, "", 1}, + {itemEOF, "", 1}, + } + lx := lex("foo = {'bar-1.2.3' = { port:4242 }}") + expect(t, lx, expectedItems) + lx = lex("foo = {\"bar-1.2.3\" = { port:4242 }}") + expect(t, lx, expectedItems) +} + +var mlnestedmap = ` +systems { + allinone { + description: "This is a description." + } +} +` + +func TestDoubleNestedMapsNewLines(t *testing.T) { + expectedItems := []item{ + {itemKey, "systems", 2}, + {itemMapStart, "", 2}, + {itemKey, "allinone", 3}, + {itemMapStart, "", 3}, + {itemKey, "description", 4}, + {itemString, "This is a description.", 4}, + {itemMapEnd, "", 5}, + {itemMapEnd, "", 6}, + {itemEOF, "", 7}, + } + lx := lex(mlnestedmap) + expect(t, lx, expectedItems) +} + +var blockexample = ` +numbers ( +1234567890 +) +` + +func TestBlockString(t *testing.T) { + expectedItems := []item{ + {itemKey, "numbers", 2}, + {itemString, "\n1234567890\n", 4}, + } + lx := lex(blockexample) + expect(t, lx, expectedItems) +} + +func TestBlockStringEOF(t *testing.T) { + expectedItems := []item{ + {itemKey, "numbers", 2}, + {itemString, "\n1234567890\n", 4}, + } + blockbytes := []byte(blockexample[0 : len(blockexample)-1]) + blockbytes = append(blockbytes, 0) + lx := lex(string(blockbytes)) + expect(t, lx, expectedItems) +} + +var mlblockexample = ` +numbers ( + 12(34)56 + ( + 7890 + ) +) +` + +func TestBlockStringMultiLine(t *testing.T) { + expectedItems := []item{ + {itemKey, "numbers", 2}, + {itemString, "\n 12(34)56\n (\n 7890\n )\n", 7}, + } + lx := lex(mlblockexample) + expect(t, lx, expectedItems) +} + +func TestUnquotedIPAddr(t *testing.T) { + expectedItems := []item{ + {itemKey, "listen", 1}, + {itemString, "127.0.0.1:4222", 1}, + {itemEOF, "", 1}, + } + lx := lex("listen: 127.0.0.1:4222") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemString, "127.0.0.1", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen: 127.0.0.1") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemString, "apcera.me:80", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen: apcera.me:80") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemString, "nats.io:-1", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen: nats.io:-1") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemInteger, "-1", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen: -1") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemString, ":-1", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen: :-1") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemString, ":80", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen = :80") + expect(t, lx, expectedItems) + + expectedItems = []item{ + {itemKey, "listen", 1}, + {itemArrayStart, "", 1}, + {itemString, "localhost:4222", 1}, + {itemString, "localhost:4333", 1}, + {itemArrayEnd, "", 1}, + {itemEOF, "", 1}, + } + lx = lex("listen = [localhost:4222, localhost:4333]") + expect(t, lx, expectedItems) +} + +var arrayOfMaps = ` +authorization { + users = [ + {user: alice, password: foo} + {user: bob, password: bar} + ] + timeout: 0.5 +} +` + +func TestArrayOfMaps(t *testing.T) { + expectedItems := []item{ + {itemKey, "authorization", 2}, + {itemMapStart, "", 2}, + {itemKey, "users", 3}, + {itemArrayStart, "", 3}, + {itemMapStart, "", 4}, + {itemKey, "user", 4}, + {itemString, "alice", 4}, + {itemKey, "password", 4}, + {itemString, "foo", 4}, + {itemMapEnd, "", 4}, + {itemMapStart, "", 5}, + {itemKey, "user", 5}, + {itemString, "bob", 5}, + {itemKey, "password", 5}, + {itemString, "bar", 5}, + {itemMapEnd, "", 5}, + {itemArrayEnd, "", 6}, + {itemKey, "timeout", 7}, + {itemFloat, "0.5", 7}, + {itemMapEnd, "", 8}, + {itemEOF, "", 9}, + } + lx := lex(arrayOfMaps) + expect(t, lx, expectedItems) +} + +func TestInclude(t *testing.T) { + expectedItems := []item{ + {itemInclude, "users.conf", 1}, + {itemEOF, "", 1}, + } + lx := lex("include \"users.conf\"") + expect(t, lx, expectedItems) + + lx = lex("include 'users.conf'") + expect(t, lx, expectedItems) + + lx = lex("include users.conf") + expect(t, lx, expectedItems) +} + +func TestMapInclude(t *testing.T) { + expectedItems := []item{ + {itemKey, "foo", 1}, + {itemMapStart, "", 1}, + {itemInclude, "users.conf", 1}, + {itemMapEnd, "", 1}, + {itemEOF, "", 1}, + } + + lx := lex("foo { include users.conf }") + expect(t, lx, expectedItems) + + lx = lex("foo {include users.conf}") + expect(t, lx, expectedItems) + + lx = lex("foo { include 'users.conf' }") + expect(t, lx, expectedItems) + + lx = lex("foo { include \"users.conf\"}") + expect(t, lx, expectedItems) +} + +func TestJSONCompat(t *testing.T) { + for _, test := range []struct { + name string + input string + expected []item + }{ + { + name: "should omit initial and final brackets at top level with a single item", + input: ` + { + "http_port": 8223 + } + `, + expected: []item{ + {itemKey, "http_port", 3}, + {itemInteger, "8223", 3}, + }, + }, + { + name: "should omit trailing commas at top level with two items", + input: ` + { + "http_port": 8223, + "port": 4223 + } + `, + expected: []item{ + {itemKey, "http_port", 3}, + {itemInteger, "8223", 3}, + {itemKey, "port", 4}, + {itemInteger, "4223", 4}, + }, + }, + { + name: "should omit trailing commas at top level with multiple items", + input: ` + { + "http_port": 8223, + "port": 4223, + "max_payload": "5MB", + "debug": true, + "max_control_line": 1024 + } + `, + expected: []item{ + {itemKey, "http_port", 3}, + {itemInteger, "8223", 3}, + {itemKey, "port", 4}, + {itemInteger, "4223", 4}, + {itemKey, "max_payload", 5}, + {itemString, "5MB", 5}, + {itemKey, "debug", 6}, + {itemBool, "true", 6}, + {itemKey, "max_control_line", 7}, + {itemInteger, "1024", 7}, + }, + }, + { + name: "should support JSON not prettified", + input: `{"http_port": 8224,"port": 4224} + `, + expected: []item{ + {itemKey, "http_port", 1}, + {itemInteger, "8224", 1}, + {itemKey, "port", 1}, + {itemInteger, "4224", 1}, + }, + }, + { + name: "should support JSON not prettified with final bracket after newline", + input: `{"http_port": 8225,"port": 4225 + } + `, + expected: []item{ + {itemKey, "http_port", 1}, + {itemInteger, "8225", 1}, + {itemKey, "port", 1}, + {itemInteger, "4225", 1}, + }, + }, + { + name: "should support uglified JSON with inner blocks", + input: `{"http_port": 8227,"port": 4227,"write_deadline": "1h","cluster": {"port": 6222,"routes": ["nats://127.0.0.1:4222","nats://127.0.0.1:4223","nats://127.0.0.1:4224"]}} + `, + expected: []item{ + {itemKey, "http_port", 1}, + {itemInteger, "8227", 1}, + {itemKey, "port", 1}, + {itemInteger, "4227", 1}, + {itemKey, "write_deadline", 1}, + {itemString, "1h", 1}, + {itemKey, "cluster", 1}, + {itemMapStart, "", 1}, + {itemKey, "port", 1}, + {itemInteger, "6222", 1}, + {itemKey, "routes", 1}, + {itemArrayStart, "", 1}, + {itemString, "nats://127.0.0.1:4222", 1}, + {itemString, "nats://127.0.0.1:4223", 1}, + {itemString, "nats://127.0.0.1:4224", 1}, + {itemArrayEnd, "", 1}, + {itemMapEnd, "", 1}, + }, + }, + { + name: "should support prettified JSON with inner blocks", + input: ` + { + "http_port": 8227, + "port": 4227, + "write_deadline": "1h", + "cluster": { + "port": 6222, + "routes": [ + "nats://127.0.0.1:4222", + "nats://127.0.0.1:4223", + "nats://127.0.0.1:4224" + ] + } + } + `, + expected: []item{ + {itemKey, "http_port", 3}, + {itemInteger, "8227", 3}, + {itemKey, "port", 4}, + {itemInteger, "4227", 4}, + {itemKey, "write_deadline", 5}, + {itemString, "1h", 5}, + {itemKey, "cluster", 6}, + {itemMapStart, "", 6}, + {itemKey, "port", 7}, + {itemInteger, "6222", 7}, + {itemKey, "routes", 8}, + {itemArrayStart, "", 8}, + {itemString, "nats://127.0.0.1:4222", 9}, + {itemString, "nats://127.0.0.1:4223", 10}, + {itemString, "nats://127.0.0.1:4224", 11}, + {itemArrayEnd, "", 12}, + {itemMapEnd, "", 13}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + lx := lex(test.input) + expect(t, lx, test.expected) + }) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/conf/parse.go b/vendor/github.com/nats-io/gnatsd/conf/parse.go new file mode 100644 index 00000000..09205ae0 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/parse.go @@ -0,0 +1,295 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package conf supports a configuration file format used by gnatsd. It is +// a flexible format that combines the best of traditional +// configuration formats and newer styles such as JSON and YAML. +package conf + +// The format supported is less restrictive than today's formats. +// Supports mixed Arrays [], nested Maps {}, multiple comment types (# and //) +// Also supports key value assigments using '=' or ':' or whiteSpace() +// e.g. foo = 2, foo : 2, foo 2 +// maps can be assigned with no key separator as well +// semicolons as value terminators in key/value assignments are optional +// +// see parse_test.go for more examples. + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + "time" + "unicode" +) + +type parser struct { + mapping map[string]interface{} + lx *lexer + + // The current scoped context, can be array or map + ctx interface{} + + // stack of contexts, either map or array/slice stack + ctxs []interface{} + + // Keys stack + keys []string + + // The config file path, empty by default. + fp string +} + +// Parse will return a map of keys to interface{}, although concrete types +// underly them. The values supported are string, bool, int64, float64, DateTime. +// Arrays and nested Maps are also supported. +func Parse(data string) (map[string]interface{}, error) { + p, err := parse(data, "") + if err != nil { + return nil, err + } + return p.mapping, nil +} + +// ParseFile is a helper to open file, etc. and parse the contents. +func ParseFile(fp string) (map[string]interface{}, error) { + data, err := ioutil.ReadFile(fp) + if err != nil { + return nil, fmt.Errorf("error opening config file: %v", err) + } + + p, err := parse(string(data), filepath.Dir(fp)) + if err != nil { + return nil, err + } + return p.mapping, nil +} + +func parse(data, fp string) (p *parser, err error) { + p = &parser{ + mapping: make(map[string]interface{}), + lx: lex(data), + ctxs: make([]interface{}, 0, 4), + keys: make([]string, 0, 4), + fp: fp, + } + p.pushContext(p.mapping) + + for { + it := p.next() + if it.typ == itemEOF { + break + } + if err := p.processItem(it); err != nil { + return nil, err + } + } + + return p, nil +} + +func (p *parser) next() item { + return p.lx.nextItem() +} + +func (p *parser) pushContext(ctx interface{}) { + p.ctxs = append(p.ctxs, ctx) + p.ctx = ctx +} + +func (p *parser) popContext() interface{} { + if len(p.ctxs) == 0 { + panic("BUG in parser, context stack empty") + } + li := len(p.ctxs) - 1 + last := p.ctxs[li] + p.ctxs = p.ctxs[0:li] + p.ctx = p.ctxs[len(p.ctxs)-1] + return last +} + +func (p *parser) pushKey(key string) { + p.keys = append(p.keys, key) +} + +func (p *parser) popKey() string { + if len(p.keys) == 0 { + panic("BUG in parser, keys stack empty") + } + li := len(p.keys) - 1 + last := p.keys[li] + p.keys = p.keys[0:li] + return last +} + +func (p *parser) processItem(it item) error { + switch it.typ { + case itemError: + return fmt.Errorf("Parse error on line %d: '%s'", it.line, it.val) + case itemKey: + p.pushKey(it.val) + case itemMapStart: + newCtx := make(map[string]interface{}) + p.pushContext(newCtx) + case itemMapEnd: + p.setValue(p.popContext()) + case itemString: + p.setValue(it.val) // FIXME(dlc) sanitize string? + case itemInteger: + lastDigit := 0 + for _, r := range it.val { + if !unicode.IsDigit(r) && r != '-' { + break + } + lastDigit++ + } + numStr := it.val[:lastDigit] + num, err := strconv.ParseInt(numStr, 10, 64) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + return fmt.Errorf("Integer '%s' is out of the range.", it.val) + } + return fmt.Errorf("Expected integer, but got '%s'.", it.val) + } + // Process a suffix + suffix := strings.ToLower(strings.TrimSpace(it.val[lastDigit:])) + switch suffix { + case "": + p.setValue(num) + case "k": + p.setValue(num * 1000) + case "kb": + p.setValue(num * 1024) + case "m": + p.setValue(num * 1000 * 1000) + case "mb": + p.setValue(num * 1024 * 1024) + case "g": + p.setValue(num * 1000 * 1000 * 1000) + case "gb": + p.setValue(num * 1024 * 1024 * 1024) + } + case itemFloat: + num, err := strconv.ParseFloat(it.val, 64) + if err != nil { + if e, ok := err.(*strconv.NumError); ok && + e.Err == strconv.ErrRange { + return fmt.Errorf("Float '%s' is out of the range.", it.val) + } + return fmt.Errorf("Expected float, but got '%s'.", it.val) + } + p.setValue(num) + case itemBool: + switch strings.ToLower(it.val) { + case "true", "yes", "on": + p.setValue(true) + case "false", "no", "off": + p.setValue(false) + default: + return fmt.Errorf("Expected boolean value, but got '%s'.", it.val) + } + case itemDatetime: + dt, err := time.Parse("2006-01-02T15:04:05Z", it.val) + if err != nil { + return fmt.Errorf( + "Expected Zulu formatted DateTime, but got '%s'.", it.val) + } + p.setValue(dt) + case itemArrayStart: + var array = make([]interface{}, 0) + p.pushContext(array) + case itemArrayEnd: + array := p.ctx + p.popContext() + p.setValue(array) + case itemVariable: + if value, ok := p.lookupVariable(it.val); ok { + p.setValue(value) + } else { + return fmt.Errorf("Variable reference for '%s' on line %d can not be found.", + it.val, it.line) + } + case itemInclude: + m, err := ParseFile(filepath.Join(p.fp, it.val)) + if err != nil { + return fmt.Errorf("Error parsing include file '%s', %v.", it.val, err) + } + for k, v := range m { + p.pushKey(k) + p.setValue(v) + } + } + + return nil +} + +// Used to map an environment value into a temporary map to pass to secondary Parse call. +const pkey = "pk" + +// We special case raw strings here that are bcrypt'd. This allows us not to force quoting the strings +const bcryptPrefix = "2a$" + +// lookupVariable will lookup a variable reference. It will use block scoping on keys +// it has seen before, with the top level scoping being the environment variables. We +// ignore array contexts and only process the map contexts.. +// +// Returns true for ok if it finds something, similar to map. +func (p *parser) lookupVariable(varReference string) (interface{}, bool) { + // Do special check to see if it is a raw bcrypt string. + if strings.HasPrefix(varReference, bcryptPrefix) { + return "$" + varReference, true + } + + // Loop through contexts currently on the stack. + for i := len(p.ctxs) - 1; i >= 0; i -= 1 { + ctx := p.ctxs[i] + // Process if it is a map context + if m, ok := ctx.(map[string]interface{}); ok { + if v, ok := m[varReference]; ok { + return v, ok + } + } + } + + // If we are here, we have exhausted our context maps and still not found anything. + // Parse from the environment. + if vStr, ok := os.LookupEnv(varReference); ok { + // Everything we get here will be a string value, so we need to process as a parser would. + if vmap, err := Parse(fmt.Sprintf("%s=%s", pkey, vStr)); err == nil { + v, ok := vmap[pkey] + return v, ok + } + } + return nil, false +} + +func (p *parser) setValue(val interface{}) { + // Test to see if we are on an array or a map + + // Array processing + if ctx, ok := p.ctx.([]interface{}); ok { + p.ctx = append(ctx, val) + p.ctxs[len(p.ctxs)-1] = p.ctx + } + + // Map processing + if ctx, ok := p.ctx.(map[string]interface{}); ok { + key := p.popKey() + // FIXME(dlc), make sure to error if redefining same key? + ctx[key] = val + } +} diff --git a/vendor/github.com/nats-io/gnatsd/conf/parse_test.go b/vendor/github.com/nats-io/gnatsd/conf/parse_test.go new file mode 100644 index 00000000..6992c113 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/parse_test.go @@ -0,0 +1,275 @@ +package conf + +import ( + "fmt" + "os" + "reflect" + "strings" + "testing" + "time" +) + +// Test to make sure we get what we expect. + +func test(t *testing.T, data string, ex map[string]interface{}) { + m, err := Parse(data) + if err != nil { + t.Fatalf("Received err: %v\n", err) + } + if m == nil { + t.Fatal("Received nil map") + } + + if !reflect.DeepEqual(m, ex) { + t.Fatalf("Not Equal:\nReceived: '%+v'\nExpected: '%+v'\n", m, ex) + } +} + +func TestSimpleTopLevel(t *testing.T) { + ex := map[string]interface{}{ + "foo": "1", + "bar": float64(2.2), + "baz": true, + "boo": int64(22), + } + test(t, "foo='1'; bar=2.2; baz=true; boo=22", ex) +} + +func TestBools(t *testing.T) { + ex := map[string]interface{}{ + "foo": true, + } + test(t, "foo=true", ex) + test(t, "foo=TRUE", ex) + test(t, "foo=true", ex) + test(t, "foo=yes", ex) + test(t, "foo=on", ex) +} + +var varSample = ` + index = 22 + foo = $index +` + +func TestSimpleVariable(t *testing.T) { + ex := map[string]interface{}{ + "index": int64(22), + "foo": int64(22), + } + test(t, varSample, ex) +} + +var varNestedSample = ` + index = 22 + nest { + index = 11 + foo = $index + } + bar = $index +` + +func TestNestedVariable(t *testing.T) { + ex := map[string]interface{}{ + "index": int64(22), + "nest": map[string]interface{}{ + "index": int64(11), + "foo": int64(11), + }, + "bar": int64(22), + } + test(t, varNestedSample, ex) +} + +func TestMissingVariable(t *testing.T) { + _, err := Parse("foo=$index") + if err == nil { + t.Fatalf("Expected an error for a missing variable, got none") + } + if !strings.HasPrefix(err.Error(), "Variable reference") { + t.Fatalf("Wanted a variable reference err, got %q\n", err) + } +} + +func TestEnvVariable(t *testing.T) { + ex := map[string]interface{}{ + "foo": int64(22), + } + evar := "__UNIQ22__" + os.Setenv(evar, "22") + defer os.Unsetenv(evar) + test(t, fmt.Sprintf("foo = $%s", evar), ex) +} + +func TestBcryptVariable(t *testing.T) { + ex := map[string]interface{}{ + "password": "$2a$11$ooo", + } + test(t, "password: $2a$11$ooo", ex) +} + +var easynum = ` +k = 8k +kb = 4kb +m = 1m +mb = 2MB +g = 2g +gb = 22GB +` + +func TestConvenientNumbers(t *testing.T) { + ex := map[string]interface{}{ + "k": int64(8 * 1000), + "kb": int64(4 * 1024), + "m": int64(1000 * 1000), + "mb": int64(2 * 1024 * 1024), + "g": int64(2 * 1000 * 1000 * 1000), + "gb": int64(22 * 1024 * 1024 * 1024), + } + test(t, easynum, ex) +} + +var sample1 = ` +foo { + host { + ip = '127.0.0.1' + port = 4242 + } + servers = [ "a.com", "b.com", "c.com"] +} +` + +func TestSample1(t *testing.T) { + ex := map[string]interface{}{ + "foo": map[string]interface{}{ + "host": map[string]interface{}{ + "ip": "127.0.0.1", + "port": int64(4242), + }, + "servers": []interface{}{"a.com", "b.com", "c.com"}, + }, + } + test(t, sample1, ex) +} + +var cluster = ` +cluster { + port: 4244 + + authorization { + user: route_user + password: top_secret + timeout: 1 + } + + # Routes are actively solicited and connected to from this server. + # Other servers can connect to us if they supply the correct credentials + # in their routes definitions from above. + + // Test both styles of comments + + routes = [ + nats-route://foo:bar@apcera.me:4245 + nats-route://foo:bar@apcera.me:4246 + ] +} +` + +func TestSample2(t *testing.T) { + ex := map[string]interface{}{ + "cluster": map[string]interface{}{ + "port": int64(4244), + "authorization": map[string]interface{}{ + "user": "route_user", + "password": "top_secret", + "timeout": int64(1), + }, + "routes": []interface{}{ + "nats-route://foo:bar@apcera.me:4245", + "nats-route://foo:bar@apcera.me:4246", + }, + }, + } + + test(t, cluster, ex) +} + +var sample3 = ` +foo { + expr = '(true == "false")' + text = 'This is a multi-line +text block.' +} +` + +func TestSample3(t *testing.T) { + ex := map[string]interface{}{ + "foo": map[string]interface{}{ + "expr": "(true == \"false\")", + "text": "This is a multi-line\ntext block.", + }, + } + test(t, sample3, ex) +} + +var sample4 = ` + array [ + { abc: 123 } + { xyz: "word" } + ] +` + +func TestSample4(t *testing.T) { + ex := map[string]interface{}{ + "array": []interface{}{ + map[string]interface{}{"abc": int64(123)}, + map[string]interface{}{"xyz": "word"}, + }, + } + test(t, sample4, ex) +} + +var sample5 = ` + now = 2016-05-04T18:53:41Z + gmt = false + +` + +func TestSample5(t *testing.T) { + dt, _ := time.Parse("2006-01-02T15:04:05Z", "2016-05-04T18:53:41Z") + ex := map[string]interface{}{ + "now": dt, + "gmt": false, + } + test(t, sample5, ex) +} + +func TestIncludes(t *testing.T) { + ex := map[string]interface{}{ + "listen": "127.0.0.1:4222", + "authorization": map[string]interface{}{ + "ALICE_PASS": "$2a$10$UHR6GhotWhpLsKtVP0/i6.Nh9.fuY73cWjLoJjb2sKT8KISBcUW5q", + "BOB_PASS": "$2a$11$dZM98SpGeI7dCFFGSpt.JObQcix8YHml4TBUZoge9R1uxnMIln5ly", + "users": []interface{}{ + map[string]interface{}{ + "user": "alice", + "password": "$2a$10$UHR6GhotWhpLsKtVP0/i6.Nh9.fuY73cWjLoJjb2sKT8KISBcUW5q"}, + map[string]interface{}{ + "user": "bob", + "password": "$2a$11$dZM98SpGeI7dCFFGSpt.JObQcix8YHml4TBUZoge9R1uxnMIln5ly"}, + }, + "timeout": float64(0.5), + }, + } + + m, err := ParseFile("simple.conf") + if err != nil { + t.Fatalf("Received err: %v\n", err) + } + if m == nil { + t.Fatal("Received nil map") + } + + if !reflect.DeepEqual(m, ex) { + t.Fatalf("Not Equal:\nReceived: '%+v'\nExpected: '%+v'\n", m, ex) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/conf/simple.conf b/vendor/github.com/nats-io/gnatsd/conf/simple.conf new file mode 100644 index 00000000..8f75d73a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/conf/simple.conf @@ -0,0 +1,6 @@ +listen: 127.0.0.1:4222 + +authorization { + include 'includes/users.conf' # Pull in from file + timeout: 0.5 +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/log.go b/vendor/github.com/nats-io/gnatsd/logger/log.go new file mode 100644 index 00000000..132cb42a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/log.go @@ -0,0 +1,152 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//Package logger provides logging facilities for the NATS server +package logger + +import ( + "fmt" + "log" + "os" +) + +// Logger is the server logger +type Logger struct { + logger *log.Logger + debug bool + trace bool + infoLabel string + errorLabel string + fatalLabel string + debugLabel string + traceLabel string + logFile *os.File // file pointer for the file logger. +} + +// NewStdLogger creates a logger with output directed to Stderr +func NewStdLogger(time, debug, trace, colors, pid bool) *Logger { + flags := 0 + if time { + flags = log.LstdFlags | log.Lmicroseconds + } + + pre := "" + if pid { + pre = pidPrefix() + } + + l := &Logger{ + logger: log.New(os.Stderr, pre, flags), + debug: debug, + trace: trace, + } + + if colors { + setColoredLabelFormats(l) + } else { + setPlainLabelFormats(l) + } + + return l +} + +// NewFileLogger creates a logger with output directed to a file +func NewFileLogger(filename string, time, debug, trace, pid bool) *Logger { + fileflags := os.O_WRONLY | os.O_APPEND | os.O_CREATE + f, err := os.OpenFile(filename, fileflags, 0660) + if err != nil { + log.Fatalf("error opening file: %v", err) + } + + flags := 0 + if time { + flags = log.LstdFlags | log.Lmicroseconds + } + + pre := "" + if pid { + pre = pidPrefix() + } + + l := &Logger{ + logger: log.New(f, pre, flags), + debug: debug, + trace: trace, + logFile: f, + } + + setPlainLabelFormats(l) + return l +} + +// Close implements the io.Closer interface to clean up +// resources in the server's logger implementation. +// Caller must ensure threadsafety. +func (l *Logger) Close() error { + if f := l.logFile; f != nil { + l.logFile = nil + return f.Close() + } + return nil +} + +// Generate the pid prefix string +func pidPrefix() string { + return fmt.Sprintf("[%d] ", os.Getpid()) +} + +func setPlainLabelFormats(l *Logger) { + l.infoLabel = "[INF] " + l.debugLabel = "[DBG] " + l.errorLabel = "[ERR] " + l.fatalLabel = "[FTL] " + l.traceLabel = "[TRC] " +} + +func setColoredLabelFormats(l *Logger) { + colorFormat := "[\x1b[%dm%s\x1b[0m] " + l.infoLabel = fmt.Sprintf(colorFormat, 32, "INF") + l.debugLabel = fmt.Sprintf(colorFormat, 36, "DBG") + l.errorLabel = fmt.Sprintf(colorFormat, 31, "ERR") + l.fatalLabel = fmt.Sprintf(colorFormat, 31, "FTL") + l.traceLabel = fmt.Sprintf(colorFormat, 33, "TRC") +} + +// Noticef logs a notice statement +func (l *Logger) Noticef(format string, v ...interface{}) { + l.logger.Printf(l.infoLabel+format, v...) +} + +// Errorf logs an error statement +func (l *Logger) Errorf(format string, v ...interface{}) { + l.logger.Printf(l.errorLabel+format, v...) +} + +// Fatalf logs a fatal error +func (l *Logger) Fatalf(format string, v ...interface{}) { + l.logger.Fatalf(l.fatalLabel+format, v...) +} + +// Debugf logs a debug statement +func (l *Logger) Debugf(format string, v ...interface{}) { + if l.debug { + l.logger.Printf(l.debugLabel+format, v...) + } +} + +// Tracef logs a trace statement +func (l *Logger) Tracef(format string, v ...interface{}) { + if l.trace { + l.logger.Printf(l.traceLabel+format, v...) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/log_test.go b/vendor/github.com/nats-io/gnatsd/logger/log_test.go new file mode 100644 index 00000000..d53728e7 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/log_test.go @@ -0,0 +1,185 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logger + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "strings" + "testing" +) + +func TestStdLogger(t *testing.T) { + logger := NewStdLogger(false, false, false, false, false) + + flags := logger.logger.Flags() + if flags != 0 { + t.Fatalf("Expected %q, received %q\n", 0, flags) + } + + if logger.debug { + t.Fatalf("Expected %t, received %t\n", false, logger.debug) + } + + if logger.trace { + t.Fatalf("Expected %t, received %t\n", false, logger.trace) + } +} + +func TestStdLoggerWithDebugTraceAndTime(t *testing.T) { + logger := NewStdLogger(true, true, true, false, false) + + flags := logger.logger.Flags() + if flags != log.LstdFlags|log.Lmicroseconds { + t.Fatalf("Expected %d, received %d\n", log.LstdFlags, flags) + } + + if !logger.debug { + t.Fatalf("Expected %t, received %t\n", true, logger.debug) + } + + if !logger.trace { + t.Fatalf("Expected %t, received %t\n", true, logger.trace) + } +} + +func TestStdLoggerNotice(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, false, false, false, false) + logger.Noticef("foo") + }, "[INF] foo\n") +} + +func TestStdLoggerNoticeWithColor(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, false, false, true, false) + logger.Noticef("foo") + }, "[\x1b[32mINF\x1b[0m] foo\n") +} + +func TestStdLoggerDebug(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, true, false, false, false) + logger.Debugf("foo %s", "bar") + }, "[DBG] foo bar\n") +} + +func TestStdLoggerDebugWithOutDebug(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, false, false, false, false) + logger.Debugf("foo") + }, "") +} + +func TestStdLoggerTrace(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, false, true, false, false) + logger.Tracef("foo") + }, "[TRC] foo\n") +} + +func TestStdLoggerTraceWithOutDebug(t *testing.T) { + expectOutput(t, func() { + logger := NewStdLogger(false, false, false, false, false) + logger.Tracef("foo") + }, "") +} + +func TestFileLogger(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "_gnatsd") + if err != nil { + t.Fatal("Could not create tmp dir") + } + defer os.RemoveAll(tmpDir) + + file, err := ioutil.TempFile(tmpDir, "gnatsd:log_") + if err != nil { + t.Fatalf("Could not create the temp file: %v", err) + } + file.Close() + + logger := NewFileLogger(file.Name(), false, false, false, false) + logger.Noticef("foo") + + buf, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatalf("Could not read logfile: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length logfile") + } + + if string(buf) != "[INF] foo\n" { + t.Fatalf("Expected '%s', received '%s'\n", "[INFO] foo", string(buf)) + } + + file, err = ioutil.TempFile(tmpDir, "gnatsd:log_") + if err != nil { + t.Fatalf("Could not create the temp file: %v", err) + } + file.Close() + + logger = NewFileLogger(file.Name(), true, true, true, true) + logger.Errorf("foo") + + buf, err = ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatalf("Could not read logfile: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length logfile") + } + str := string(buf) + errMsg := fmt.Sprintf("Expected '%s', received '%s'\n", "[pid] [ERR] foo", str) + pidEnd := strings.Index(str, " ") + infoStart := strings.LastIndex(str, "[ERR]") + if pidEnd == -1 || infoStart == -1 { + t.Fatalf("%v", errMsg) + } + pid := str[0:pidEnd] + if pid[0] != '[' || pid[len(pid)-1] != ']' { + t.Fatalf("%v", errMsg) + } + //TODO: Parse date. + if !strings.HasSuffix(str, "[ERR] foo\n") { + t.Fatalf("%v", errMsg) + } +} + +func expectOutput(t *testing.T, f func(), expected string) { + old := os.Stderr // keep backup of the real stdout + r, w, _ := os.Pipe() + os.Stderr = w + + f() + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + os.Stderr.Close() + os.Stderr = old // restoring the real stdout + out := <-outC + if out != expected { + t.Fatalf("Expected '%s', received '%s'\n", expected, out) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog.go b/vendor/github.com/nats-io/gnatsd/logger/syslog.go new file mode 100644 index 00000000..96d65ca6 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog.go @@ -0,0 +1,123 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package logger + +import ( + "fmt" + "log" + "log/syslog" + "net/url" + "os" + "strings" +) + +// SysLogger provides a system logger facility +type SysLogger struct { + writer *syslog.Writer + debug bool + trace bool +} + +// GetSysLoggerTag generates the tag name for use in syslog statements. If +// the executable is linked, the name of the link will be used as the tag, +// otherwise, the name of the executable is used. "gnatsd" is the default +// for the NATS server. +func GetSysLoggerTag() string { + procName := os.Args[0] + if strings.ContainsRune(procName, os.PathSeparator) { + parts := strings.FieldsFunc(procName, func(c rune) bool { + return c == os.PathSeparator + }) + procName = parts[len(parts)-1] + } + return procName +} + +// NewSysLogger creates a new system logger +func NewSysLogger(debug, trace bool) *SysLogger { + w, err := syslog.New(syslog.LOG_DAEMON|syslog.LOG_NOTICE, GetSysLoggerTag()) + if err != nil { + log.Fatalf("error connecting to syslog: %q", err.Error()) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +// NewRemoteSysLogger creates a new remote system logger +func NewRemoteSysLogger(fqn string, debug, trace bool) *SysLogger { + network, addr := getNetworkAndAddr(fqn) + w, err := syslog.Dial(network, addr, syslog.LOG_DEBUG, GetSysLoggerTag()) + if err != nil { + log.Fatalf("error connecting to syslog: %q", err.Error()) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +func getNetworkAndAddr(fqn string) (network, addr string) { + u, err := url.Parse(fqn) + if err != nil { + log.Fatal(err) + } + + network = u.Scheme + if network == "udp" || network == "tcp" { + addr = u.Host + } else if network == "unix" { + addr = u.Path + } else { + log.Fatalf("error invalid network type: %q", u.Scheme) + } + + return +} + +// Noticef logs a notice statement +func (l *SysLogger) Noticef(format string, v ...interface{}) { + l.writer.Notice(fmt.Sprintf(format, v...)) +} + +// Fatalf logs a fatal error +func (l *SysLogger) Fatalf(format string, v ...interface{}) { + l.writer.Crit(fmt.Sprintf(format, v...)) +} + +// Errorf logs an error statement +func (l *SysLogger) Errorf(format string, v ...interface{}) { + l.writer.Err(fmt.Sprintf(format, v...)) +} + +// Debugf logs a debug statement +func (l *SysLogger) Debugf(format string, v ...interface{}) { + if l.debug { + l.writer.Debug(fmt.Sprintf(format, v...)) + } +} + +// Tracef logs a trace statement +func (l *SysLogger) Tracef(format string, v ...interface{}) { + if l.trace { + l.writer.Notice(fmt.Sprintf(format, v...)) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog_test.go b/vendor/github.com/nats-io/gnatsd/logger/syslog_test.go new file mode 100644 index 00000000..604ccd16 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog_test.go @@ -0,0 +1,235 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package logger + +import ( + "fmt" + "log" + "net" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +var serverFQN string + +func TestSysLogger(t *testing.T) { + logger := NewSysLogger(false, false) + + if logger.debug { + t.Fatalf("Expected %t, received %t\n", false, logger.debug) + } + + if logger.trace { + t.Fatalf("Expected %t, received %t\n", false, logger.trace) + } +} + +func TestSysLoggerWithDebugAndTrace(t *testing.T) { + logger := NewSysLogger(true, true) + + if !logger.debug { + t.Fatalf("Expected %t, received %t\n", true, logger.debug) + } + + if !logger.trace { + t.Fatalf("Expected %t, received %t\n", true, logger.trace) + } +} + +func testTag(t *testing.T, exePath, expected string) { + os.Args[0] = exePath + if result := GetSysLoggerTag(); result != expected { + t.Fatalf("Expected %s, received %s", expected, result) + } +} + +func restoreArg(orig string) { + os.Args[0] = orig +} + +func TestSysLoggerTagGen(t *testing.T) { + origArg := os.Args[0] + defer restoreArg(origArg) + + testTag(t, "gnatsd", "gnatsd") + testTag(t, filepath.Join(".", "gnatsd"), "gnatsd") + testTag(t, filepath.Join("home", "bin", "gnatsd"), "gnatsd") + testTag(t, filepath.Join("..", "..", "gnatsd"), "gnatsd") + testTag(t, "gnatsd.service1", "gnatsd.service1") + testTag(t, "gnatsd_service1", "gnatsd_service1") + testTag(t, "gnatsd-service1", "gnatsd-service1") + testTag(t, "gnatsd service1", "gnatsd service1") +} + +func TestSysLoggerTag(t *testing.T) { + origArg := os.Args[0] + defer restoreArg(origArg) + + os.Args[0] = "ServerLoggerTag" + + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, true) + logger.Noticef("foo") + + line := <-done + data := strings.Split(line, "[") + if len(data) != 2 { + t.Fatalf("Unexpected syslog line %s\n", line) + } + + if !strings.Contains(data[0], os.Args[0]) { + t.Fatalf("Expected '%s', received '%s'\n", os.Args[0], data[0]) + } +} + +func TestRemoteSysLogger(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, true) + + if !logger.debug { + t.Fatalf("Expected %t, received %t\n", true, logger.debug) + } + + if !logger.trace { + t.Fatalf("Expected %t, received %t\n", true, logger.trace) + } +} + +func TestRemoteSysLoggerNotice(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, true) + + logger.Noticef("foo %s", "bar") + expectSyslogOutput(t, <-done, "foo bar\n") +} + +func TestRemoteSysLoggerDebug(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, true) + + logger.Debugf("foo %s", "qux") + expectSyslogOutput(t, <-done, "foo qux\n") +} + +func TestRemoteSysLoggerDebugDisabled(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, false, false) + + logger.Debugf("foo %s", "qux") + rcvd := <-done + if rcvd != "" { + t.Fatalf("Unexpected syslog response %s\n", rcvd) + } +} + +func TestRemoteSysLoggerTrace(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, true) + + logger.Tracef("foo %s", "qux") + expectSyslogOutput(t, <-done, "foo qux\n") +} + +func TestRemoteSysLoggerTraceDisabled(t *testing.T) { + done := make(chan string) + startServer(done) + logger := NewRemoteSysLogger(serverFQN, true, false) + + logger.Tracef("foo %s", "qux") + rcvd := <-done + if rcvd != "" { + t.Fatalf("Unexpected syslog response %s\n", rcvd) + } +} + +func TestGetNetworkAndAddrUDP(t *testing.T) { + n, a := getNetworkAndAddr("udp://foo.com:1000") + + if n != "udp" { + t.Fatalf("Unexpected network %s\n", n) + } + + if a != "foo.com:1000" { + t.Fatalf("Unexpected addr %s\n", a) + } +} + +func TestGetNetworkAndAddrTCP(t *testing.T) { + n, a := getNetworkAndAddr("tcp://foo.com:1000") + + if n != "tcp" { + t.Fatalf("Unexpected network %s\n", n) + } + + if a != "foo.com:1000" { + t.Fatalf("Unexpected addr %s\n", a) + } +} + +func TestGetNetworkAndAddrUnix(t *testing.T) { + n, a := getNetworkAndAddr("unix:///foo.sock") + + if n != "unix" { + t.Fatalf("Unexpected network %s\n", n) + } + + if a != "/foo.sock" { + t.Fatalf("Unexpected addr %s\n", a) + } +} +func expectSyslogOutput(t *testing.T, line string, expected string) { + data := strings.Split(line, "]: ") + if len(data) != 2 { + t.Fatalf("Unexpected syslog line %s\n", line) + } + + if data[1] != expected { + t.Fatalf("Expected '%s', received '%s'\n", expected, data[1]) + } +} + +func runSyslog(c net.PacketConn, done chan<- string) { + var buf [4096]byte + var rcvd string + for { + n, _, err := c.ReadFrom(buf[:]) + if err != nil || n == 0 { + break + } + rcvd += string(buf[:n]) + } + done <- rcvd +} + +func startServer(done chan<- string) { + c, e := net.ListenPacket("udp", "127.0.0.1:0") + if e != nil { + log.Fatalf("net.ListenPacket failed udp :0 %v", e) + } + + serverFQN = fmt.Sprintf("udp://%s", c.LocalAddr().String()) + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + go runSyslog(c, done) +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go new file mode 100644 index 00000000..7e780812 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows.go @@ -0,0 +1,104 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package logger logs to the windows event log +package logger + +import ( + "fmt" + "os" + "strings" + + "golang.org/x/sys/windows/svc/eventlog" +) + +const ( + natsEventSource = "NATS-Server" +) + +// SysLogger logs to the windows event logger +type SysLogger struct { + writer *eventlog.Log + debug bool + trace bool +} + +// NewSysLogger creates a log using the windows event logger +func NewSysLogger(debug, trace bool) *SysLogger { + if err := eventlog.InstallAsEventCreate(natsEventSource, eventlog.Info|eventlog.Error|eventlog.Warning); err != nil { + if !strings.Contains(err.Error(), "registry key already exists") { + panic(fmt.Sprintf("could not access event log: %v", err)) + } + } + + w, err := eventlog.Open(natsEventSource) + if err != nil { + panic(fmt.Sprintf("could not open event log: %v", err)) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +// NewRemoteSysLogger creates a remote event logger +func NewRemoteSysLogger(fqn string, debug, trace bool) *SysLogger { + w, err := eventlog.OpenRemote(fqn, natsEventSource) + if err != nil { + panic(fmt.Sprintf("could not open event log: %v", err)) + } + + return &SysLogger{ + writer: w, + debug: debug, + trace: trace, + } +} + +func formatMsg(tag, format string, v ...interface{}) string { + orig := fmt.Sprintf(format, v...) + return fmt.Sprintf("pid[%d][%s]: %s", os.Getpid(), tag, orig) +} + +// Noticef logs a notice statement +func (l *SysLogger) Noticef(format string, v ...interface{}) { + l.writer.Info(1, formatMsg("NOTICE", format, v...)) +} + +// Fatalf logs a fatal error +func (l *SysLogger) Fatalf(format string, v ...interface{}) { + msg := formatMsg("FATAL", format, v...) + l.writer.Error(5, msg) + panic(msg) +} + +// Errorf logs an error statement +func (l *SysLogger) Errorf(format string, v ...interface{}) { + l.writer.Error(2, formatMsg("ERROR", format, v...)) +} + +// Debugf logs a debug statement +func (l *SysLogger) Debugf(format string, v ...interface{}) { + if l.debug { + l.writer.Info(3, formatMsg("DEBUG", format, v...)) + } +} + +// Tracef logs a trace statement +func (l *SysLogger) Tracef(format string, v ...interface{}) { + if l.trace { + l.writer.Info(4, formatMsg("TRACE", format, v...)) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/logger/syslog_windows_test.go b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows_test.go new file mode 100755 index 00000000..8dc06464 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/logger/syslog_windows_test.go @@ -0,0 +1,139 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package logger + +import ( + "os/exec" + "strings" + "testing" + + "golang.org/x/sys/windows/svc/eventlog" +) + +// Skips testing if we do not have privledges to run this test. +// This lets us skip the tests for general (non admin/system) users. +func checkPrivledges(t *testing.T) { + src := "NATS-eventlog-testsource" + defer eventlog.Remove(src) + if err := eventlog.InstallAsEventCreate(src, eventlog.Info|eventlog.Error|eventlog.Warning); err != nil { + if strings.Contains(err.Error(), "Access is denied") { + t.Skip("skipping: elevated privledges are required.") + } + // let the tests report other types of errors + } +} + +// lastLogEntryContains reads the last entry (/c:1 /rd:true) written +// to the event log by the NATS-Server source, returning true if the +// passed text was found, false otherwise. +func lastLogEntryContains(t *testing.T, text string) bool { + var output []byte + var err error + + cmd := exec.Command("wevtutil.exe", "qe", "Application", "/q:*[System[Provider[@Name='NATS-Server']]]", + "/rd:true", "/c:1") + if output, err = cmd.Output(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + return strings.Contains(string(output), text) +} + +// TestSysLogger tests event logging on windows +func TestSysLogger(t *testing.T) { + checkPrivledges(t) + logger := NewSysLogger(false, false) + if logger.debug { + t.Fatalf("Expected %t, received %t\n", false, logger.debug) + } + + if logger.trace { + t.Fatalf("Expected %t, received %t\n", false, logger.trace) + } + logger.Noticef("%s", "Noticef") + if !lastLogEntryContains(t, "[NOTICE]: Noticef") { + t.Fatalf("missing log entry") + } + + logger.Errorf("%s", "Errorf") + if !lastLogEntryContains(t, "[ERROR]: Errorf") { + t.Fatalf("missing log entry") + } + + logger.Tracef("%s", "Tracef") + if lastLogEntryContains(t, "Tracef") { + t.Fatalf("should not contain log entry") + } + + logger.Debugf("%s", "Debugf") + if lastLogEntryContains(t, "Debugf") { + t.Fatalf("should not contain log entry") + } +} + +// TestSysLoggerWithDebugAndTrace tests event logging +func TestSysLoggerWithDebugAndTrace(t *testing.T) { + checkPrivledges(t) + logger := NewSysLogger(true, true) + if !logger.debug { + t.Fatalf("Expected %t, received %t\n", true, logger.debug) + } + + if !logger.trace { + t.Fatalf("Expected %t, received %t\n", true, logger.trace) + } + + logger.Tracef("%s", "Tracef") + if !lastLogEntryContains(t, "[TRACE]: Tracef") { + t.Fatalf("missing log entry") + } + + logger.Debugf("%s", "Debugf") + if !lastLogEntryContains(t, "[DEBUG]: Debugf") { + t.Fatalf("missing log entry") + } +} + +// TestSysLoggerWithDebugAndTrace tests remote event logging +func TestRemoteSysLoggerWithDebugAndTrace(t *testing.T) { + checkPrivledges(t) + logger := NewRemoteSysLogger("", true, true) + if !logger.debug { + t.Fatalf("Expected %t, received %t\n", true, logger.debug) + } + + if !logger.trace { + t.Fatalf("Expected %t, received %t\n", true, logger.trace) + } + logger.Tracef("NATS %s", "[TRACE]: Remote Noticef") + if !lastLogEntryContains(t, "Remote Noticef") { + t.Fatalf("missing log entry") + } +} + +func TestSysLoggerFatalf(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !lastLogEntryContains(t, "[FATAL]: Fatalf") { + t.Fatalf("missing log entry") + } + } + }() + + checkPrivledges(t) + logger := NewSysLogger(true, true) + logger.Fatalf("%s", "Fatalf") + t.Fatalf("did not panic when expected to") +} diff --git a/vendor/github.com/nats-io/gnatsd/server/auth.go b/vendor/github.com/nats-io/gnatsd/server/auth.go new file mode 100644 index 00000000..25b5ab7a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/auth.go @@ -0,0 +1,249 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "fmt" + "strings" + + "golang.org/x/crypto/bcrypt" +) + +// Authentication is an interface for implementing authentication +type Authentication interface { + // Check if a client is authorized to connect + Check(c ClientAuthentication) bool +} + +// ClientAuthentication is an interface for client authentication +type ClientAuthentication interface { + // Get options associated with a client + GetOpts() *clientOpts + // If TLS is enabled, TLS ConnectionState, nil otherwise + GetTLSConnectionState() *tls.ConnectionState + // Optionally map a user after auth. + RegisterUser(*User) +} + +// User is for multiple accounts/users. +type User struct { + Username string `json:"user"` + Password string `json:"password"` + Permissions *Permissions `json:"permissions"` +} + +// clone performs a deep copy of the User struct, returning a new clone with +// all values copied. +func (u *User) clone() *User { + if u == nil { + return nil + } + clone := &User{} + *clone = *u + clone.Permissions = u.Permissions.clone() + return clone +} + +// Permissions are the allowed subjects on a per +// publish or subscribe basis. +type Permissions struct { + Publish []string `json:"publish"` + Subscribe []string `json:"subscribe"` +} + +// RoutePermissions are similar to user permissions +// but describe what a server can import/export from and to +// another server. +type RoutePermissions struct { + Import []string `json:"import"` + Export []string `json:"export"` +} + +// clone performs a deep copy of the Permissions struct, returning a new clone +// with all values copied. +func (p *Permissions) clone() *Permissions { + if p == nil { + return nil + } + clone := &Permissions{} + if p.Publish != nil { + clone.Publish = make([]string, len(p.Publish)) + copy(clone.Publish, p.Publish) + } + if p.Subscribe != nil { + clone.Subscribe = make([]string, len(p.Subscribe)) + copy(clone.Subscribe, p.Subscribe) + } + return clone +} + +// configureAuthorization will do any setup needed for authorization. +// Lock is assumed held. +func (s *Server) configureAuthorization() { + if s.opts == nil { + return + } + + // Snapshot server options. + opts := s.getOpts() + + // Check for multiple users first + // This just checks and sets up the user map if we have multiple users. + if opts.CustomClientAuthentication != nil { + s.info.AuthRequired = true + } else if opts.Users != nil { + s.users = make(map[string]*User) + for _, u := range opts.Users { + s.users[u.Username] = u + } + s.info.AuthRequired = true + } else if opts.Username != "" || opts.Authorization != "" { + s.info.AuthRequired = true + } else { + s.users = nil + s.info.AuthRequired = false + } +} + +// checkAuthorization will check authorization based on client type and +// return boolean indicating if client is authorized. +func (s *Server) checkAuthorization(c *client) bool { + switch c.typ { + case CLIENT: + return s.isClientAuthorized(c) + case ROUTER: + return s.isRouterAuthorized(c) + default: + return false + } +} + +// hasUsers leyt's us know if we have a users array. +func (s *Server) hasUsers() bool { + s.mu.Lock() + hu := s.users != nil + s.mu.Unlock() + return hu +} + +// isClientAuthorized will check the client against the proper authorization method and data. +// This could be token or username/password based. +func (s *Server) isClientAuthorized(c *client) bool { + // Snapshot server options. + opts := s.getOpts() + + // Check custom auth first, then multiple users, then token, then single user/pass. + if opts.CustomClientAuthentication != nil { + return opts.CustomClientAuthentication.Check(c) + } else if s.hasUsers() { + s.mu.Lock() + user, ok := s.users[c.opts.Username] + s.mu.Unlock() + + if !ok { + return false + } + ok = comparePasswords(user.Password, c.opts.Password) + // If we are authorized, register the user which will properly setup any permissions + // for pub/sub authorizations. + if ok { + c.RegisterUser(user) + } + return ok + + } else if opts.Authorization != "" { + return comparePasswords(opts.Authorization, c.opts.Authorization) + + } else if opts.Username != "" { + if opts.Username != c.opts.Username { + return false + } + return comparePasswords(opts.Password, c.opts.Password) + } + + return true +} + +// checkRouterAuth checks optional router authorization which can be nil or username/password. +func (s *Server) isRouterAuthorized(c *client) bool { + // Snapshot server options. + opts := s.getOpts() + + if s.opts.CustomRouterAuthentication != nil { + return s.opts.CustomRouterAuthentication.Check(c) + } + + if opts.Cluster.Username == "" { + return true + } + + if opts.Cluster.Username != c.opts.Username { + return false + } + if !comparePasswords(opts.Cluster.Password, c.opts.Password) { + return false + } + c.setRoutePermissions(opts.Cluster.Permissions) + return true +} + +// removeUnauthorizedSubs removes any subscriptions the client has that are no +// longer authorized, e.g. due to a config reload. +func (s *Server) removeUnauthorizedSubs(c *client) { + c.mu.Lock() + if c.perms == nil { + c.mu.Unlock() + return + } + + subs := make(map[string]*subscription, len(c.subs)) + for sid, sub := range c.subs { + subs[sid] = sub + } + c.mu.Unlock() + + for sid, sub := range subs { + if !c.canSubscribe(sub.subject) { + _ = s.sl.Remove(sub) + c.mu.Lock() + delete(c.subs, sid) + c.mu.Unlock() + c.sendErr(fmt.Sprintf("Permissions Violation for Subscription to %q (sid %s)", + sub.subject, sub.sid)) + s.Noticef("Removed sub %q for user %q - not authorized", + string(sub.subject), c.opts.Username) + } + } +} + +// Support for bcrypt stored passwords and tokens. +const bcryptPrefix = "$2a$" + +// isBcrypt checks whether the given password or token is bcrypted. +func isBcrypt(password string) bool { + return strings.HasPrefix(password, bcryptPrefix) +} + +func comparePasswords(serverPassword, clientPassword string) bool { + // Check to see if the server password is a bcrypt hash + if isBcrypt(serverPassword) { + if err := bcrypt.CompareHashAndPassword([]byte(serverPassword), []byte(clientPassword)); err != nil { + return false + } + } else if serverPassword != clientPassword { + return false + } + return true +} diff --git a/vendor/github.com/nats-io/gnatsd/server/auth_test.go b/vendor/github.com/nats-io/gnatsd/server/auth_test.go new file mode 100644 index 00000000..787f5993 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/auth_test.go @@ -0,0 +1,99 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "reflect" + "testing" +) + +func TestUserCloneNilPermissions(t *testing.T) { + user := &User{ + Username: "foo", + Password: "bar", + } + + clone := user.clone() + + if !reflect.DeepEqual(user, clone) { + t.Fatalf("Cloned Users are incorrect.\nexpected: %+v\ngot: %+v", + user, clone) + } + + clone.Password = "baz" + if reflect.DeepEqual(user, clone) { + t.Fatal("Expected Users to be different") + } +} + +func TestUserClone(t *testing.T) { + user := &User{ + Username: "foo", + Password: "bar", + Permissions: &Permissions{ + Publish: []string{"foo"}, + Subscribe: []string{"bar"}, + }, + } + + clone := user.clone() + + if !reflect.DeepEqual(user, clone) { + t.Fatalf("Cloned Users are incorrect.\nexpected: %+v\ngot: %+v", + user, clone) + } + + clone.Permissions.Subscribe = []string{"baz"} + if reflect.DeepEqual(user, clone) { + t.Fatal("Expected Users to be different") + } +} + +func TestUserClonePermissionsNoLists(t *testing.T) { + user := &User{ + Username: "foo", + Password: "bar", + Permissions: &Permissions{}, + } + + clone := user.clone() + + if clone.Permissions.Publish != nil { + t.Fatalf("Expected Publish to be nil, got: %v", clone.Permissions.Publish) + } + if clone.Permissions.Subscribe != nil { + t.Fatalf("Expected Subscribe to be nil, got: %v", clone.Permissions.Subscribe) + } +} + +func TestUserCloneNoPermissions(t *testing.T) { + user := &User{ + Username: "foo", + Password: "bar", + } + + clone := user.clone() + + if clone.Permissions != nil { + t.Fatalf("Expected Permissions to be nil, got: %v", clone.Permissions) + } +} + +func TestUserCloneNil(t *testing.T) { + user := (*User)(nil) + clone := user.clone() + if clone != nil { + t.Fatalf("Expected nil, got: %+v", clone) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go b/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go new file mode 100644 index 00000000..cbc5a2ff --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/ciphersuites.go @@ -0,0 +1,97 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" +) + +// Where we maintain all of the available ciphers +var cipherMap = map[string]uint16{ + "TLS_RSA_WITH_RC4_128_SHA": tls.TLS_RSA_WITH_RC4_128_SHA, + "TLS_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_RC4_128_SHA": tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, +} + +var cipherMapByID = map[uint16]string{ + tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA", + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA: "TLS_RSA_WITH_AES_128_CBC_SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "TLS_RSA_WITH_AES_128_CBC_SHA256", + tls.TLS_RSA_WITH_AES_256_CBC_SHA: "TLS_RSA_WITH_AES_256_CBC_SHA", + tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "TLS_RSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "TLS_ECDHE_RSA_WITH_RC4_128_SHA", + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", +} + +func defaultCipherSuites() []uint16 { + return []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +// Where we maintain available curve preferences +var curvePreferenceMap = map[string]tls.CurveID{ + "CurveP256": tls.CurveP256, + "CurveP384": tls.CurveP384, + "CurveP521": tls.CurveP521, + "X25519": tls.X25519, +} + +// reorder to default to the highest level of security. See: +// https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go +func defaultCurvePreferences() []tls.CurveID { + return []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.X25519, // faster than P256, arguably more secure + tls.CurveP256, + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/client.go b/vendor/github.com/nats-io/gnatsd/server/client.go new file mode 100644 index 00000000..1f2cdb38 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/client.go @@ -0,0 +1,1810 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" +) + +// Type of client connection. +const ( + // CLIENT is an end user. + CLIENT = iota + // ROUTER is another router in the cluster. + ROUTER +) + +const ( + // Original Client protocol from 2009. + // http://nats.io/documentation/internals/nats-protocol/ + ClientProtoZero = iota + // This signals a client can receive more then the original INFO block. + // This can be used to update clients on other cluster members, etc. + ClientProtoInfo +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +const ( + // Scratch buffer size for the processMsg() calls. + msgScratchSize = 512 + msgHeadProto = "MSG " + msgHeadProtoLen = len(msgHeadProto) +) + +// For controlling dynamic buffer sizes. +const ( + startBufSize = 512 // For INFO/CONNECT block + minBufSize = 64 // Smallest to shrink to for PING/PONG + maxBufSize = 65536 // 64k + shortsToShrink = 2 +) + +// Represent client booleans with a bitmask +type clientFlag byte + +// Some client state represented as flags +const ( + connectReceived clientFlag = 1 << iota // The CONNECT proto has been received + firstPongSent // The first PONG has been sent + handshakeComplete // For TLS clients, indicate that the handshake is complete + clearConnection // Marks that clearConnection has already been called. + flushOutbound // Marks client as having a flushOutbound call in progress. +) + +// set the flag (would be equivalent to set the boolean to true) +func (cf *clientFlag) set(c clientFlag) { + *cf |= c +} + +// clear the flag (would be equivalent to set the boolean to false) +func (cf *clientFlag) clear(c clientFlag) { + *cf &= ^c +} + +// isSet returns true if the flag is set, false otherwise +func (cf clientFlag) isSet(c clientFlag) bool { + return cf&c != 0 +} + +// setIfNotSet will set the flag `c` only if that flag was not already +// set and return true to indicate that the flag has been set. Returns +// false otherwise. +func (cf *clientFlag) setIfNotSet(c clientFlag) bool { + if *cf&c == 0 { + *cf |= c + return true + } + return false +} + +// Reason client was closed. This will be passed into +// calls to clearConnection, but will only be stored +// in ConnInfo for monitoring. +type ClosedState int + +const ( + ClientClosed = ClosedState(iota + 1) + AuthenticationTimeout + AuthenticationViolation + TLSHandshakeError + SlowConsumerPendingBytes + SlowConsumerWriteDeadline + WriteError + ReadError + ParseError + StaleConnection + ProtocolViolation + BadClientProtocolVersion + WrongPort + MaxConnectionsExceeded + MaxPayloadExceeded + MaxControlLineExceeded + DuplicateRoute + RouteRemoved + ServerShutdown +) + +type client struct { + // Here first because of use of atomics, and memory alignment. + stats + mpay int64 + msubs int + mu sync.Mutex + typ int + cid uint64 + opts clientOpts + start time.Time + nc net.Conn + ncs string + out outbound + srv *Server + subs map[string]*subscription + perms *permissions + in readCache + pcd map[*client]struct{} + atmr *time.Timer + ping pinfo + msgb [msgScratchSize]byte + last time.Time + parseState + + rtt time.Duration + rttStart time.Time + + route *route + + debug bool + trace bool + echo bool + + flags clientFlag // Compact booleans into a single field. Size will be increased when needed. +} + +// Struct for PING initiation from the server. +type pinfo struct { + tmr *time.Timer + out int +} + +// outbound holds pending data for a socket. +type outbound struct { + p []byte // Primary write buffer + s []byte // Secondary for use post flush + nb net.Buffers // net.Buffers for writev IO + sz int // limit size per []byte, uses variable BufSize constants, start, min, max. + sws int // Number of short writes, used for dyanmic resizing. + pb int64 // Total pending/queued bytes. + pm int64 // Total pending/queued messages. + sg *sync.Cond // Flusher conditional for signaling. + fsp int // Flush signals that are pending from readLoop's pcd. + mp int64 // snapshot of max pending. + wdl time.Duration // Snapshot fo write deadline. + lft time.Duration // Last flush time. +} + +type permissions struct { + sub *Sublist + pub *Sublist + pcache map[string]bool +} + +const ( + maxResultCacheSize = 512 + maxPermCacheSize = 32 + pruneSize = 16 +) + +// Used in readloop to cache hot subject lookups and group statistics. +type readCache struct { + genid uint64 + results map[string]*SublistResult + prand *rand.Rand + msgs int + bytes int + subs int + rsz int // Read buffer size + srs int // Short reads, used for dynamic buffer resizing. +} + +func (c *client) String() (id string) { + return c.ncs +} + +func (c *client) GetOpts() *clientOpts { + return &c.opts +} + +// GetTLSConnectionState returns the TLS ConnectionState if TLS is enabled, nil +// otherwise. Implements the ClientAuth interface. +func (c *client) GetTLSConnectionState() *tls.ConnectionState { + tc, ok := c.nc.(*tls.Conn) + if !ok { + return nil + } + state := tc.ConnectionState() + return &state +} + +type subscription struct { + client *client + subject []byte + queue []byte + sid []byte + nm int64 + max int64 +} + +type clientOpts struct { + Echo bool `json:"echo"` + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + TLSRequired bool `json:"tls_required"` + Authorization string `json:"auth_token"` + Username string `json:"user"` + Password string `json:"pass"` + Name string `json:"name"` + Lang string `json:"lang"` + Version string `json:"version"` + Protocol int `json:"protocol"` +} + +var defaultOpts = clientOpts{Verbose: true, Pedantic: true, Echo: true} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Lock should be held +func (c *client) initClient() { + s := c.srv + c.cid = atomic.AddUint64(&s.gcid, 1) + + // Outbound data structure setup + c.out.sz = startBufSize + c.out.sg = sync.NewCond(&c.mu) + opts := s.getOpts() + // Snapshots to avoid mutex access in fast paths. + c.out.wdl = opts.WriteDeadline + c.out.mp = opts.MaxPending + + c.subs = make(map[string]*subscription) + c.echo = true + + c.debug = (atomic.LoadInt32(&c.srv.logging.debug) != 0) + c.trace = (atomic.LoadInt32(&c.srv.logging.trace) != 0) + + // This is a scratch buffer used for processMsg() + // The msg header starts with "MSG ", + // in bytes that is [77 83 71 32]. + c.msgb = [msgScratchSize]byte{77, 83, 71, 32} + + // This is to track pending clients that have data to be flushed + // after we process inbound msgs from our own connection. + c.pcd = make(map[*client]struct{}) + + // snapshot the string version of the connection + conn := "-" + if ip, ok := c.nc.(*net.TCPConn); ok { + addr := ip.RemoteAddr().(*net.TCPAddr) + conn = fmt.Sprintf("%s:%d", addr.IP, addr.Port) + } + + switch c.typ { + case CLIENT: + c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid) + case ROUTER: + c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid) + } +} + +// RegisterUser allows auth to call back into a new client +// with the authenticated user. This is used to map any permissions +// into the client. +func (c *client) RegisterUser(user *User) { + if user.Permissions == nil { + // Reset perms to nil in case client previously had them. + c.mu.Lock() + c.perms = nil + c.mu.Unlock() + return + } + + // Process Permissions and map into client connection structures. + c.mu.Lock() + defer c.mu.Unlock() + + c.setPermissions(user.Permissions) +} + +// Initializes client.perms structure. +// Lock is held on entry. +func (c *client) setPermissions(perms *Permissions) { + // Pre-allocate all to simplify checks later. + c.perms = &permissions{} + c.perms.sub = NewSublist() + c.perms.pub = NewSublist() + c.perms.pcache = make(map[string]bool) + + // Loop over publish permissions + for _, pubSubject := range perms.Publish { + sub := &subscription{subject: []byte(pubSubject)} + c.perms.pub.Insert(sub) + } + + // Loop over subscribe permissions + for _, subSubject := range perms.Subscribe { + sub := &subscription{subject: []byte(subSubject)} + c.perms.sub.Insert(sub) + } +} + +// writeLoop is the main socket write functionality. +// Runs in its own Go routine. +func (c *client) writeLoop() { + defer c.srv.grWG.Done() + + // Used to check that we did flush from last wake up. + waitOk := true + + // Main loop. Will wait to be signaled and then will use + // buffered outbound structure for efficient writev to the underlying socket. + for { + c.mu.Lock() + if waitOk && (c.out.pb == 0 || c.out.fsp > 0) && len(c.out.nb) == 0 && !c.flags.isSet(clearConnection) { + // Wait on pending data. + c.out.sg.Wait() + } + // Flush data + waitOk = c.flushOutbound() + isClosed := c.flags.isSet(clearConnection) + c.mu.Unlock() + + if isClosed { + return + } + } +} + +// readLoop is the main socket read functionality. +// Runs in its own Go routine. +func (c *client) readLoop() { + // Grab the connection off the client, it will be cleared on a close. + // We check for that after the loop, but want to avoid a nil dereference + c.mu.Lock() + nc := c.nc + s := c.srv + c.in.rsz = startBufSize + defer s.grWG.Done() + c.mu.Unlock() + + if nc == nil { + return + } + + // Start read buffer. + + b := make([]byte, c.in.rsz) + + for { + n, err := nc.Read(b) + if err != nil { + if err == io.EOF { + c.closeConnection(ClientClosed) + } else { + c.closeConnection(ReadError) + } + return + } + + // Grab for updates for last activity. + last := time.Now() + + // Clear inbound stats cache + c.in.msgs = 0 + c.in.bytes = 0 + c.in.subs = 0 + + // Main call into parser for inbound data. This will generate callouts + // to process messages, etc. + if err := c.parse(b[:n]); err != nil { + // handled inline + if err != ErrMaxPayload && err != ErrAuthorization { + c.Errorf("%s", err.Error()) + c.closeConnection(ProtocolViolation) + } + return + } + + // Updates stats for client and server that were collected + // from parsing through the buffer. + if c.in.msgs > 0 { + atomic.AddInt64(&c.inMsgs, int64(c.in.msgs)) + atomic.AddInt64(&c.inBytes, int64(c.in.bytes)) + atomic.AddInt64(&s.inMsgs, int64(c.in.msgs)) + atomic.AddInt64(&s.inBytes, int64(c.in.bytes)) + } + + // Budget to spend in place flushing outbound data. + // Client will be checked on several fronts to see + // if applicable. Routes will never wait in place. + budget := 500 * time.Microsecond + if c.typ == ROUTER { + budget = 0 + } + + // Check pending clients for flush. + for cp := range c.pcd { + // Queue up a flush for those in the set + cp.mu.Lock() + // Update last activity for message delivery + cp.last = last + cp.out.fsp-- + if budget > 0 && cp.flushOutbound() { + budget -= cp.out.lft + } else { + cp.flushSignal() + } + cp.mu.Unlock() + delete(c.pcd, cp) + } + + // Update activity, check read buffer size. + c.mu.Lock() + nc := c.nc + + // Activity based on interest changes or data/msgs. + if c.in.msgs > 0 || c.in.subs > 0 { + c.last = last + } + + if n >= cap(b) { + c.in.srs = 0 + } else if n < cap(b)/2 { // divide by 2 b/c we want less than what we would shrink to. + c.in.srs++ + } + + // Update read buffer size as/if needed. + if n >= cap(b) && cap(b) < maxBufSize { + // Grow + c.in.rsz = cap(b) * 2 + b = make([]byte, c.in.rsz) + } else if n < cap(b) && cap(b) > minBufSize && c.in.srs > shortsToShrink { + // Shrink, for now don't accelerate, ping/pong will eventually sort it out. + c.in.rsz = cap(b) / 2 + b = make([]byte, c.in.rsz) + } + c.mu.Unlock() + + // Check to see if we got closed, e.g. slow consumer + if nc == nil { + return + } + } +} + +// collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo. +// This will return a copy on purpose. +func (c *client) collapsePtoNB() net.Buffers { + if c.out.p != nil { + p := c.out.p + c.out.p = nil + return append(c.out.nb, p) + } + return c.out.nb +} + +// This will handle the fixup needed on a partial write. +// Assume pending has been already calculated correctly. +func (c *client) handlePartialWrite(pnb net.Buffers) { + nb := c.collapsePtoNB() + // The partial needs to be first, so append nb to pnb + c.out.nb = append(pnb, nb...) +} + +// flushOutbound will flush outbound buffer to a client. +// Will return if data was attempted to be written. +// Lock must be held +func (c *client) flushOutbound() bool { + if c.flags.isSet(flushOutbound) { + return false + } + c.flags.set(flushOutbound) + defer c.flags.clear(flushOutbound) + + // Check for nothing to do. + if c.nc == nil || c.srv == nil || c.out.pb == 0 { + return true // true because no need to queue a signal. + } + + // Snapshot opts + srv := c.srv + + // Place primary on nb, assign primary to secondary, nil out nb and secondary. + nb := c.collapsePtoNB() + c.out.p, c.out.nb, c.out.s = c.out.s, nil, nil + + // For selecting primary replacement. + cnb := nb + + // In case it goes away after releasing the lock. + nc := c.nc + attempted := c.out.pb + apm := c.out.pm + + // Do NOT hold lock during actual IO + c.mu.Unlock() + + // flush here + now := time.Now() + // FIXME(dlc) - writev will do multiple IOs past 1024 on + // most platforms, need to account for that with deadline? + nc.SetWriteDeadline(now.Add(c.out.wdl)) + // Actual write to the socket. + n, err := nb.WriteTo(nc) + nc.SetWriteDeadline(time.Time{}) + lft := time.Since(now) + + // Re-acquire client lock + c.mu.Lock() + + // Update flush time statistics + c.out.lft = lft + + // Subtract from pending bytes and messages. + c.out.pb -= n + c.out.pm -= apm // FIXME(dlc) - this will not be accurate. + + // Check for partial writes + if n != attempted && n > 0 { + c.handlePartialWrite(nb) + } else if n >= int64(c.out.sz) { + c.out.sws = 0 + } + + if err != nil { + if n == 0 { + c.out.pb -= attempted + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + atomic.AddInt64(&srv.slowConsumers, 1) + c.clearConnection(SlowConsumerWriteDeadline) + c.Noticef("Slow Consumer Detected: WriteDeadline of %v Exceeded", c.out.wdl) + } else { + c.clearConnection(WriteError) + c.Debugf("Error flushing: %v", err) + } + return true + } + + // Adjust based on what we wrote plus any pending. + pt := int(n + c.out.pb) + + // Adjust sz as needed downward, keeping power of 2. + // We do this at a slower rate, hence the pt*4. + if pt < c.out.sz && c.out.sz > minBufSize { + c.out.sws++ + if c.out.sws > shortsToShrink { + c.out.sz >>= 1 + } + } + // Adjust sz as needed upward, keeping power of 2. + if pt > c.out.sz && c.out.sz < maxBufSize { + c.out.sz <<= 1 + } + + // Check to see if we can reuse buffers. + if len(cnb) > 0 { + oldp := cnb[0][:0] + if cap(oldp) >= c.out.sz { + // Replace primary or secondary if they are nil, reusing same buffer. + if c.out.p == nil { + c.out.p = oldp + } else if c.out.s == nil || cap(c.out.s) < c.out.sz { + c.out.s = oldp + } + } + } + return true +} + +// flushSignal will use server to queue the flush IO operation to a pool of flushers. +// Lock must be held. +func (c *client) flushSignal() { + c.out.sg.Signal() +} + +func (c *client) traceMsg(msg []byte) { + if !c.trace { + return + } + // FIXME(dlc), allow limits to printable payload + c.Tracef("->> MSG_PAYLOAD: [%s]", string(msg[:len(msg)-LEN_CR_LF])) +} + +func (c *client) traceInOp(op string, arg []byte) { + c.traceOp("->> %s", op, arg) +} + +func (c *client) traceOutOp(op string, arg []byte) { + c.traceOp("<<- %s", op, arg) +} + +func (c *client) traceOp(format, op string, arg []byte) { + if !c.trace { + return + } + + opa := []interface{}{} + if op != "" { + opa = append(opa, op) + } + if arg != nil { + opa = append(opa, string(arg)) + } + c.Tracef(format, opa) +} + +// Process the information messages from Clients and other Routes. +func (c *client) processInfo(arg []byte) error { + info := Info{} + if err := json.Unmarshal(arg, &info); err != nil { + return err + } + if c.typ == ROUTER { + c.processRouteInfo(&info) + } + return nil +} + +func (c *client) processErr(errStr string) { + switch c.typ { + case CLIENT: + c.Errorf("Client Error %s", errStr) + case ROUTER: + c.Errorf("Route Error %s", errStr) + } + c.closeConnection(ParseError) +} + +func (c *client) processConnect(arg []byte) error { + c.traceInOp("CONNECT", arg) + + c.mu.Lock() + // If we can't stop the timer because the callback is in progress... + if !c.clearAuthTimer() { + // wait for it to finish and handle sending the failure back to + // the client. + for c.nc != nil { + c.mu.Unlock() + time.Sleep(25 * time.Millisecond) + c.mu.Lock() + } + c.mu.Unlock() + return nil + } + c.last = time.Now() + typ := c.typ + r := c.route + srv := c.srv + // Moved unmarshalling of clients' Options under the lock. + // The client has already been added to the server map, so it is possible + // that other routines lookup the client, and access its options under + // the client's lock, so unmarshalling the options outside of the lock + // would cause data RACEs. + if err := json.Unmarshal(arg, &c.opts); err != nil { + c.mu.Unlock() + return err + } + // Indicate that the CONNECT protocol has been received, and that the + // server now knows which protocol this client supports. + c.flags.set(connectReceived) + // Capture these under lock + c.echo = c.opts.Echo + proto := c.opts.Protocol + verbose := c.opts.Verbose + lang := c.opts.Lang + c.mu.Unlock() + + if srv != nil { + // As soon as c.opts is unmarshalled and if the proto is at + // least ClientProtoInfo, we need to increment the following counter. + // This is decremented when client is removed from the server's + // clients map. + if proto >= ClientProtoInfo { + srv.mu.Lock() + srv.cproto++ + srv.mu.Unlock() + } + + // Check for Auth + if ok := srv.checkAuthorization(c); !ok { + c.authViolation() + return ErrAuthorization + } + } + + // Check client protocol request if it exists. + if typ == CLIENT && (proto < ClientProtoZero || proto > ClientProtoInfo) { + c.sendErr(ErrBadClientProtocol.Error()) + c.closeConnection(BadClientProtocolVersion) + return ErrBadClientProtocol + } else if typ == ROUTER && lang != "" { + // Way to detect clients that incorrectly connect to the route listen + // port. Client provide Lang in the CONNECT protocol while ROUTEs don't. + c.sendErr(ErrClientConnectedToRoutePort.Error()) + c.closeConnection(WrongPort) + return ErrClientConnectedToRoutePort + } + + // Grab connection name of remote route. + if typ == ROUTER && r != nil { + c.mu.Lock() + c.route.remoteID = c.opts.Name + c.mu.Unlock() + } + + if verbose { + c.sendOK() + } + return nil +} + +func (c *client) authTimeout() { + c.sendErr(ErrAuthTimeout.Error()) + c.Debugf("Authorization Timeout") + c.closeConnection(AuthenticationTimeout) +} + +func (c *client) authViolation() { + if c.srv != nil && c.srv.getOpts().Users != nil { + c.Errorf("%s - User %q", + ErrAuthorization.Error(), + c.opts.Username) + } else { + c.Errorf(ErrAuthorization.Error()) + } + c.sendErr("Authorization Violation") + c.closeConnection(AuthenticationViolation) +} + +func (c *client) maxConnExceeded() { + c.Errorf(ErrTooManyConnections.Error()) + c.sendErr(ErrTooManyConnections.Error()) + c.closeConnection(MaxConnectionsExceeded) +} + +func (c *client) maxSubsExceeded() { + c.Errorf(ErrTooManySubs.Error()) + c.sendErr(ErrTooManySubs.Error()) +} + +func (c *client) maxPayloadViolation(sz int, max int64) { + c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, max) + c.sendErr("Maximum Payload Violation") + c.closeConnection(MaxPayloadExceeded) +} + +// queueOutbound queues data for client/route connections. +// Return pending length. +// Lock should be held. +func (c *client) queueOutbound(data []byte) { + // Add to pending bytes total. + c.out.pb += int64(len(data)) + + // Check for slow consumer via pending bytes limit. + // ok to return here, client is going away. + if c.out.pb > c.out.mp { + c.clearConnection(SlowConsumerPendingBytes) + atomic.AddInt64(&c.srv.slowConsumers, 1) + c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp) + return + } + + if c.out.p == nil && len(data) < maxBufSize { + if c.out.sz == 0 { + c.out.sz = startBufSize + } + if c.out.s != nil && cap(c.out.s) >= c.out.sz { + c.out.p = c.out.s + c.out.s = nil + } else { + // FIXME(dlc) - make power of 2 if less than maxBufSize? + c.out.p = make([]byte, 0, c.out.sz) + } + } + // Determine if we copy or reference + available := cap(c.out.p) - len(c.out.p) + if len(data) > available { + // We can fit into existing primary, but message will fit in next one + // we allocate or utilize from the secondary. So copy what we can. + if available > 0 && len(data) < c.out.sz { + c.out.p = append(c.out.p, data[:available]...) + data = data[available:] + } + // Put the primary on the nb if it has a payload + if len(c.out.p) > 0 { + c.out.nb = append(c.out.nb, c.out.p) + c.out.p = nil + } + // Check for a big message, and if found place directly on nb + // FIXME(dlc) - do we need signaling of ownership here if we want len(data) < + if len(data) > maxBufSize { + c.out.nb = append(c.out.nb, data) + } else { + // We will copy to primary. + if c.out.p == nil { + // Grow here + if (c.out.sz << 1) <= maxBufSize { + c.out.sz <<= 1 + } + if len(data) > c.out.sz { + c.out.p = make([]byte, 0, len(data)) + } else { + if c.out.s != nil && cap(c.out.s) >= c.out.sz { // TODO(dlc) - Size mismatch? + c.out.p = c.out.s + c.out.s = nil + } else { + c.out.p = make([]byte, 0, c.out.sz) + } + } + } + c.out.p = append(c.out.p, data...) + } + } else { + c.out.p = append(c.out.p, data...) + } +} + +// Assume the lock is held upon entry. +func (c *client) sendProto(info []byte, doFlush bool) { + if c.nc == nil { + return + } + c.queueOutbound(info) + if !(doFlush && c.flushOutbound()) { + c.flushSignal() + } +} + +// Assume the lock is held upon entry. +func (c *client) sendPong() { + c.traceOutOp("PONG", nil) + c.sendProto([]byte("PONG\r\n"), true) +} + +// Assume the lock is held upon entry. +func (c *client) sendPing() { + c.rttStart = time.Now() + c.ping.out++ + c.traceOutOp("PING", nil) + c.sendProto([]byte("PING\r\n"), true) +} + +// Generates the INFO to be sent to the client with the client ID included. +// info arg will be copied since passed by value. +// Assume lock is held. +func (c *client) generateClientInfoJSON(info Info) []byte { + info.CID = c.cid + // Generate the info json + b, _ := json.Marshal(info) + pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} + return bytes.Join(pcs, []byte(" ")) +} + +// Assume the lock is held upon entry. +func (c *client) sendInfo(info []byte) { + c.sendProto(info, true) +} + +func (c *client) sendErr(err string) { + c.mu.Lock() + c.traceOutOp("-ERR", []byte(err)) + c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", err)), true) + c.mu.Unlock() +} + +func (c *client) sendOK() { + c.mu.Lock() + c.traceOutOp("OK", nil) + // Can not autoflush this one, needs to be async. + c.sendProto([]byte("+OK\r\n"), false) + // FIXME(dlc) - ?? + c.pcd[c] = needFlush + c.mu.Unlock() +} + +func (c *client) processPing() { + c.mu.Lock() + c.traceInOp("PING", nil) + if c.nc == nil { + c.mu.Unlock() + return + } + c.sendPong() + + // The CONNECT should have been received, but make sure it + // is so before proceeding + if !c.flags.isSet(connectReceived) { + c.mu.Unlock() + return + } + // If we are here, the CONNECT has been received so we know + // if this client supports async INFO or not. + var ( + checkClusterChange bool + srv = c.srv + ) + // For older clients, just flip the firstPongSent flag if not already + // set and we are done. + if c.opts.Protocol < ClientProtoInfo || srv == nil { + c.flags.setIfNotSet(firstPongSent) + } else { + // This is a client that supports async INFO protocols. + // If this is the first PING (so firstPongSent is not set yet), + // we will need to check if there was a change in cluster topology. + checkClusterChange = !c.flags.isSet(firstPongSent) + } + c.mu.Unlock() + + if checkClusterChange { + srv.mu.Lock() + c.mu.Lock() + // Now that we are under both locks, we can flip the flag. + // This prevents sendAsyncInfoToClients() and and code here + // to send a double INFO protocol. + c.flags.set(firstPongSent) + // If there was a cluster update since this client was created, + // send an updated INFO protocol now. + if srv.lastCURLsUpdate >= c.start.UnixNano() { + c.sendInfo(c.generateClientInfoJSON(srv.copyInfo())) + } + c.mu.Unlock() + srv.mu.Unlock() + } +} + +func (c *client) processPong() { + c.traceInOp("PONG", nil) + c.mu.Lock() + c.ping.out = 0 + c.rtt = time.Since(c.rttStart) + c.mu.Unlock() +} + +func (c *client) processMsgArgs(arg []byte) error { + if c.trace { + c.traceInOp("MSG", arg) + } + + // Unroll splitArgs to avoid runtime/heap issues + a := [MAX_MSG_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 3: + c.pa.reply = nil + c.pa.szb = args[2] + c.pa.size = parseSize(args[2]) + case 4: + c.pa.reply = args[2] + c.pa.szb = args[3] + c.pa.size = parseSize(args[3]) + default: + return fmt.Errorf("processMsgArgs Parse Error: '%s'", arg) + } + if c.pa.size < 0 { + return fmt.Errorf("processMsgArgs Bad or Missing Size: '%s'", arg) + } + + // Common ones processed after check for arg length + c.pa.subject = args[0] + c.pa.sid = args[1] + + return nil +} + +func (c *client) processPub(arg []byte) error { + if c.trace { + c.traceInOp("PUB", arg) + } + + // Unroll splitArgs to avoid runtime/heap issues + a := [MAX_PUB_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + + switch len(args) { + case 2: + c.pa.subject = args[0] + c.pa.reply = nil + c.pa.size = parseSize(args[1]) + c.pa.szb = args[1] + case 3: + c.pa.subject = args[0] + c.pa.reply = args[1] + c.pa.size = parseSize(args[2]) + c.pa.szb = args[2] + default: + return fmt.Errorf("processPub Parse Error: '%s'", arg) + } + if c.pa.size < 0 { + return fmt.Errorf("processPub Bad or Missing Size: '%s'", arg) + } + maxPayload := atomic.LoadInt64(&c.mpay) + if maxPayload > 0 && int64(c.pa.size) > maxPayload { + c.maxPayloadViolation(c.pa.size, maxPayload) + return ErrMaxPayload + } + + if c.opts.Pedantic && !IsValidLiteralSubject(string(c.pa.subject)) { + c.sendErr("Invalid Publish Subject") + } + return nil +} + +func splitArg(arg []byte) [][]byte { + a := [MAX_MSG_ARGS][]byte{} + args := a[:0] + start := -1 + for i, b := range arg { + switch b { + case ' ', '\t', '\r', '\n': + if start >= 0 { + args = append(args, arg[start:i]) + start = -1 + } + default: + if start < 0 { + start = i + } + } + } + if start >= 0 { + args = append(args, arg[start:]) + } + return args +} + +func (c *client) processSub(argo []byte) (err error) { + c.traceInOp("SUB", argo) + + // Indicate activity. + c.in.subs++ + + // Copy so we do not reference a potentially large buffer + arg := make([]byte, len(argo)) + copy(arg, argo) + args := splitArg(arg) + sub := &subscription{client: c} + switch len(args) { + case 2: + sub.subject = args[0] + sub.queue = nil + sub.sid = args[1] + case 3: + sub.subject = args[0] + sub.queue = args[1] + sub.sid = args[2] + default: + return fmt.Errorf("processSub Parse Error: '%s'", arg) + } + + shouldForward := false + + c.mu.Lock() + if c.nc == nil { + c.mu.Unlock() + return nil + } + + // Check permissions if applicable. + if c.typ == ROUTER { + if !c.canExport(sub.subject) { + c.mu.Unlock() + return nil + } + } else if !c.canSubscribe(sub.subject) { + c.mu.Unlock() + c.sendErr(fmt.Sprintf("Permissions Violation for Subscription to %q", sub.subject)) + c.Errorf("Subscription Violation - User %q, Subject %q, SID %s", + c.opts.Username, sub.subject, sub.sid) + return nil + } + + if c.msubs > 0 && len(c.subs) >= c.msubs { + c.mu.Unlock() + c.maxSubsExceeded() + return nil + } + + // Check if we have a maximum on the number of subscriptions. + // We can have two SUB protocols coming from a route due to some + // race conditions. We should make sure that we process only one. + sid := string(sub.sid) + if c.subs[sid] == nil { + c.subs[sid] = sub + if c.srv != nil { + err = c.srv.sl.Insert(sub) + if err != nil { + delete(c.subs, sid) + } else { + shouldForward = c.typ != ROUTER + } + } + } + c.mu.Unlock() + if err != nil { + c.sendErr("Invalid Subject") + return nil + } else if c.opts.Verbose { + c.sendOK() + } + if shouldForward { + c.srv.broadcastSubscribe(sub) + } + + return nil +} + +// canSubscribe determines if the client is authorized to subscribe to the +// given subject. Assumes caller is holding lock. +func (c *client) canSubscribe(sub []byte) bool { + if c.perms == nil { + return true + } + return len(c.perms.sub.Match(string(sub)).psubs) > 0 +} + +// Low level unsubscribe for a given client. +func (c *client) unsubscribe(sub *subscription) { + c.mu.Lock() + defer c.mu.Unlock() + if sub.max > 0 && sub.nm < sub.max { + c.Debugf( + "Deferring actual UNSUB(%s): %d max, %d received\n", + string(sub.subject), sub.max, sub.nm) + return + } + c.traceOp("<-> %s", "DELSUB", sub.sid) + + delete(c.subs, string(sub.sid)) + if c.srv != nil { + c.srv.sl.Remove(sub) + } + + // If we are a queue subscriber on a client connection and we have routes, + // we will remember the remote sid and the queue group in case a route + // tries to deliver us a message. Remote queue subscribers are directed + // so we need to know what to do to avoid unnecessary message drops + // from [auto-]unsubscribe. + if c.typ == CLIENT && c.srv != nil && len(sub.queue) > 0 { + c.srv.holdRemoteQSub(sub) + } +} + +func (c *client) processUnsub(arg []byte) error { + c.traceInOp("UNSUB", arg) + args := splitArg(arg) + var sid []byte + max := -1 + + switch len(args) { + case 1: + sid = args[0] + case 2: + sid = args[0] + max = parseSize(args[1]) + default: + return fmt.Errorf("processUnsub Parse Error: '%s'", arg) + } + + // Indicate activity. + c.in.subs += 1 + + var sub *subscription + + unsub := false + shouldForward := false + ok := false + + c.mu.Lock() + if sub, ok = c.subs[string(sid)]; ok { + if max > 0 { + sub.max = int64(max) + } else { + // Clear it here to override + sub.max = 0 + } + unsub = true + shouldForward = c.typ != ROUTER && c.srv != nil + } + c.mu.Unlock() + + if unsub { + c.unsubscribe(sub) + } + if shouldForward { + c.srv.broadcastUnSubscribe(sub) + } + if c.opts.Verbose { + c.sendOK() + } + + return nil +} + +func (c *client) msgHeader(mh []byte, sub *subscription) []byte { + mh = append(mh, sub.sid...) + mh = append(mh, ' ') + if c.pa.reply != nil { + mh = append(mh, c.pa.reply...) + mh = append(mh, ' ') + } + mh = append(mh, c.pa.szb...) + mh = append(mh, "\r\n"...) + return mh +} + +// Used to treat maps as efficient set +var needFlush = struct{}{} +var routeSeen = struct{}{} + +func (c *client) deliverMsg(sub *subscription, mh, msg []byte) bool { + if sub.client == nil { + return false + } + client := sub.client + client.mu.Lock() + + // Check echo + if c == client && !client.echo { + client.mu.Unlock() + return false + } + + srv := client.srv + + sub.nm++ + // Check if we should auto-unsubscribe. + if sub.max > 0 { + // For routing.. + shouldForward := client.typ != ROUTER && client.srv != nil + // If we are at the exact number, unsubscribe but + // still process the message in hand, otherwise + // unsubscribe and drop message on the floor. + if sub.nm == sub.max { + c.Debugf("Auto-unsubscribe limit of %d reached for sid '%s'\n", sub.max, string(sub.sid)) + // Due to defer, reverse the code order so that execution + // is consistent with other cases where we unsubscribe. + if shouldForward { + defer srv.broadcastUnSubscribe(sub) + } + defer client.unsubscribe(sub) + } else if sub.nm > sub.max { + c.Debugf("Auto-unsubscribe limit [%d] exceeded\n", sub.max) + client.mu.Unlock() + client.unsubscribe(sub) + if shouldForward { + srv.broadcastUnSubscribe(sub) + } + return false + } + } + + // Check for closed connection + if client.nc == nil { + client.mu.Unlock() + return false + } + + // Update statistics + + // The msg includes the CR_LF, so pull back out for accounting. + msgSize := int64(len(msg) - LEN_CR_LF) + + // No atomic needed since accessed under client lock. + // Monitor is reading those also under client's lock. + client.outMsgs++ + client.outBytes += msgSize + + atomic.AddInt64(&srv.outMsgs, 1) + atomic.AddInt64(&srv.outBytes, msgSize) + + // Queue to outbound buffer + client.queueOutbound(mh) + client.queueOutbound(msg) + + client.out.pm++ + + // Check outbound threshold and queue IO flush if needed. + if client.out.pm > 1 && client.out.pb > maxBufSize*2 { + client.flushSignal() + } + + if c.trace { + client.traceOutOp(string(mh[:len(mh)-LEN_CR_LF]), nil) + } + + // Increment the flush pending signals if we are setting for the first time. + if _, ok := c.pcd[client]; !ok { + client.out.fsp++ + } + client.mu.Unlock() + + // Remember for when we return to the top of the loop. + c.pcd[client] = needFlush + + return true +} + +// pruneCache will prune the cache via randomly +// deleting items. Doing so pruneSize items at a time. +func (c *client) prunePubPermsCache() { + r := 0 + for subject := range c.perms.pcache { + delete(c.perms.pcache, subject) + if r++; r > pruneSize { + break + } + } +} + +// pubAllowed checks on publish permissioning. +func (c *client) pubAllowed(subject []byte) bool { + // Disallow publish to _SYS.>, these are reserved for internals. + if len(subject) > 4 && string(subject[:5]) == "_SYS." { + return false + } + if c.perms == nil { + return true + } + + // Check if published subject is allowed if we have permissions in place. + allowed, ok := c.perms.pcache[string(subject)] + if ok { + return allowed + } + // Cache miss + r := c.perms.pub.Match(string(subject)) + allowed = len(r.psubs) != 0 + c.perms.pcache[string(subject)] = allowed + // Prune if needed. + if len(c.perms.pcache) > maxPermCacheSize { + c.prunePubPermsCache() + } + return allowed +} + +// prepMsgHeader will prepare the message header prefix +func (c *client) prepMsgHeader() []byte { + // Use the scratch buffer.. + msgh := c.msgb[:msgHeadProtoLen] + + // msg header + msgh = append(msgh, c.pa.subject...) + return append(msgh, ' ') +} + +// processMsg is called to process an inbound msg from a client. +func (c *client) processMsg(msg []byte) { + // Snapshot server. + srv := c.srv + + // Update statistics + // The msg includes the CR_LF, so pull back out for accounting. + c.in.msgs += 1 + c.in.bytes += len(msg) - LEN_CR_LF + + if c.trace { + c.traceMsg(msg) + } + + // Check pub permissions (don't do this for routes) + if c.typ == CLIENT && !c.pubAllowed(c.pa.subject) { + c.pubPermissionViolation(c.pa.subject) + return + } + + if c.opts.Verbose { + c.sendOK() + } + + // Mostly under testing scenarios. + if srv == nil { + return + } + + // Match the subscriptions. We will use our own L1 map if + // it's still valid, avoiding contention on the shared sublist. + var r *SublistResult + var ok bool + + genid := atomic.LoadUint64(&srv.sl.genid) + + if genid == c.in.genid && c.in.results != nil { + r, ok = c.in.results[string(c.pa.subject)] + } else { + // reset our L1 completely. + c.in.results = make(map[string]*SublistResult) + c.in.genid = genid + } + + if !ok { + subject := string(c.pa.subject) + r = srv.sl.Match(subject) + c.in.results[subject] = r + // Prune the results cache. Keeps us from unbounded growth. + if len(c.in.results) > maxResultCacheSize { + n := 0 + for subject := range c.in.results { + delete(c.in.results, subject) + if n++; n > pruneSize { + break + } + } + } + } + + // This is the fanout scale. + fanout := len(r.psubs) + len(r.qsubs) + + // Check for no interest, short circuit if so. + if fanout == 0 { + return + } + + if c.typ == ROUTER { + c.processRoutedMsg(r, msg) + return + } + + // Client connection processing here. + msgh := c.prepMsgHeader() + si := len(msgh) + + // Used to only send messages once across any given route. + var rmap map[string]struct{} + + // Loop over all normal subscriptions that match. + for _, sub := range r.psubs { + // Check if this is a send to a ROUTER, make sure we only send it + // once. The other side will handle the appropriate re-processing + // and fan-out. Also enforce 1-Hop semantics, so no routing to another. + if sub.client.typ == ROUTER { + // Check to see if we have already sent it here. + if rmap == nil { + rmap = make(map[string]struct{}, srv.numRoutes()) + } + sub.client.mu.Lock() + if sub.client.nc == nil || + sub.client.route == nil || + sub.client.route.remoteID == "" { + c.Debugf("Bad or Missing ROUTER Identity, not processing msg") + sub.client.mu.Unlock() + continue + } + if _, ok := rmap[sub.client.route.remoteID]; ok { + c.Debugf("Ignoring route, already processed and sent msg") + sub.client.mu.Unlock() + continue + } + rmap[sub.client.route.remoteID] = routeSeen + sub.client.mu.Unlock() + } + // Normal delivery + mh := c.msgHeader(msgh[:si], sub) + c.deliverMsg(sub, mh, msg) + } + + // Check to see if we have our own rand yet. Global rand + // has contention with lots of clients, etc. + if c.in.prand == nil { + c.in.prand = rand.New(rand.NewSource(time.Now().UnixNano())) + } + // Process queue subs + for i := 0; i < len(r.qsubs); i++ { + qsubs := r.qsubs[i] + // Find a subscription that is able to deliver this message + // starting at a random index. + startIndex := c.in.prand.Intn(len(qsubs)) + for i := 0; i < len(qsubs); i++ { + index := (startIndex + i) % len(qsubs) + sub := qsubs[index] + if sub != nil { + mh := c.msgHeader(msgh[:si], sub) + if c.deliverMsg(sub, mh, msg) { + break + } + } + } + } +} + +func (c *client) pubPermissionViolation(subject []byte) { + c.sendErr(fmt.Sprintf("Permissions Violation for Publish to %q", subject)) + c.Errorf("Publish Violation - User %q, Subject %q", c.opts.Username, subject) +} + +func (c *client) processPingTimer() { + c.mu.Lock() + defer c.mu.Unlock() + c.ping.tmr = nil + // Check if connection is still opened + if c.nc == nil { + return + } + + c.Debugf("%s Ping Timer", c.typeString()) + + // Check for violation + if c.ping.out+1 > c.srv.getOpts().MaxPingsOut { + c.Debugf("Stale Client Connection - Closing") + c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", "Stale Connection")), true) + c.clearConnection(StaleConnection) + return + } + + // If we have had activity within the PingInterval no + // need to send a ping. + if delta := time.Since(c.last); delta < c.srv.getOpts().PingInterval { + c.Debugf("Delaying PING due to activity %v ago", delta.Round(time.Second)) + } else { + // Send PING + c.sendPing() + } + + // Reset to fire again. + c.setPingTimer() +} + +// Lock should be held +func (c *client) setPingTimer() { + if c.srv == nil { + return + } + d := c.srv.getOpts().PingInterval + c.ping.tmr = time.AfterFunc(d, c.processPingTimer) +} + +// Lock should be held +func (c *client) clearPingTimer() { + if c.ping.tmr == nil { + return + } + c.ping.tmr.Stop() + c.ping.tmr = nil +} + +// Lock should be held +func (c *client) setAuthTimer(d time.Duration) { + c.atmr = time.AfterFunc(d, func() { c.authTimeout() }) +} + +// Lock should be held +func (c *client) clearAuthTimer() bool { + if c.atmr == nil { + return true + } + stopped := c.atmr.Stop() + c.atmr = nil + return stopped +} + +func (c *client) isAuthTimerSet() bool { + c.mu.Lock() + isSet := c.atmr != nil + c.mu.Unlock() + return isSet +} + +// Lock should be held +func (c *client) clearConnection(reason ClosedState) { + if c.flags.isSet(clearConnection) { + return + } + c.flags.set(clearConnection) + + nc := c.nc + if nc == nil || c.srv == nil { + return + } + // Flush any pending. + c.flushOutbound() + + // Clear outbound here. + c.out.sg.Broadcast() + + // With TLS, Close() is sending an alert (that is doing a write). + // Need to set a deadline otherwise the server could block there + // if the peer is not reading from socket. + if c.flags.isSet(handshakeComplete) { + nc.SetWriteDeadline(time.Now().Add(c.out.wdl)) + } + nc.Close() + // Do this always to also kick out any IO writes. + nc.SetWriteDeadline(time.Time{}) + + // Save off the connection if its a client. + if c.typ == CLIENT && c.srv != nil { + go c.srv.saveClosedClient(c, nc, reason) + } +} + +func (c *client) typeString() string { + switch c.typ { + case CLIENT: + return "Client" + case ROUTER: + return "Router" + } + return "Unknown Type" +} + +func (c *client) closeConnection(reason ClosedState) { + c.mu.Lock() + if c.nc == nil { + c.mu.Unlock() + return + } + + c.Debugf("%s connection closed", c.typeString()) + + c.clearAuthTimer() + c.clearPingTimer() + c.clearConnection(reason) + c.nc = nil + + // Snapshot for use. + subs := make([]*subscription, 0, len(c.subs)) + for _, sub := range c.subs { + // Auto-unsubscribe subscriptions must be unsubscribed forcibly. + sub.max = 0 + subs = append(subs, sub) + } + srv := c.srv + + var ( + routeClosed bool + retryImplicit bool + connectURLs []string + ) + if c.route != nil { + routeClosed = c.route.closed + if !routeClosed { + retryImplicit = c.route.retry + } + connectURLs = c.route.connectURLs + } + + c.mu.Unlock() + + if srv != nil { + // This is a route that disconnected... + if len(connectURLs) > 0 { + // Unless disabled, possibly update the server's INFO protcol + // and send to clients that know how to handle async INFOs. + if !srv.getOpts().Cluster.NoAdvertise { + srv.removeClientConnectURLsAndSendINFOToClients(connectURLs) + } + } + + // Unregister + srv.removeClient(c) + + // Remove clients subscriptions. + srv.sl.RemoveBatch(subs) + if c.typ != ROUTER { + for _, sub := range subs { + // Forward on unsubscribes if we are not + // a router ourselves. + srv.broadcastUnSubscribe(sub) + } + } + } + + // Don't reconnect routes that are being closed. + if routeClosed { + return + } + + // Check for a solicited route. If it was, start up a reconnect unless + // we are already connected to the other end. + if c.isSolicitedRoute() || retryImplicit { + // Capture these under lock + c.mu.Lock() + rid := c.route.remoteID + rtype := c.route.routeType + rurl := c.route.url + c.mu.Unlock() + + srv.mu.Lock() + defer srv.mu.Unlock() + + // It is possible that the server is being shutdown. + // If so, don't try to reconnect + if !srv.running { + return + } + + if rid != "" && srv.remotes[rid] != nil { + c.srv.Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid) + return + } else if rid == srv.info.ID { + c.srv.Debugf("Detected route to self, ignoring \"%s\"", rurl) + return + } else if rtype != Implicit || retryImplicit { + c.srv.Debugf("Attempting reconnect for solicited route \"%s\"", rurl) + // Keep track of this go-routine so we can wait for it on + // server shutdown. + srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) }) + } + } +} + +// If the client is a route connection, sets the `closed` flag to true +// to prevent any reconnecting attempt when c.closeConnection() is called. +func (c *client) setRouteNoReconnectOnClose() { + c.mu.Lock() + if c.route != nil { + c.route.closed = true + } + c.mu.Unlock() +} + +// Logging functionality scoped to a client or route. + +func (c *client) Errorf(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Errorf(format, v...) +} + +func (c *client) Debugf(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Debugf(format, v...) +} + +func (c *client) Noticef(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Noticef(format, v...) +} + +func (c *client) Tracef(format string, v ...interface{}) { + format = fmt.Sprintf("%s - %s", c, format) + c.srv.Tracef(format, v...) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/client_test.go b/vendor/github.com/nats-io/gnatsd/server/client_test.go new file mode 100644 index 00000000..29271a34 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/client_test.go @@ -0,0 +1,1096 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "net" + "reflect" + "regexp" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "crypto/rand" + "crypto/tls" + + "github.com/nats-io/go-nats" +) + +type serverInfo struct { + Id string `json:"server_id"` + Host string `json:"host"` + Port uint `json:"port"` + Version string `json:"version"` + AuthRequired bool `json:"auth_required"` + TLSRequired bool `json:"tls_required"` + MaxPayload int64 `json:"max_payload"` +} + +func createClientAsync(ch chan *client, s *Server, cli net.Conn) { + go func() { + c := s.createClient(cli) + // Must be here to suppress +OK + c.opts.Verbose = false + ch <- c + }() +} + +var defaultServerOptions = Options{ + Trace: false, + Debug: false, + NoLog: true, + NoSigs: true, +} + +func rawSetup(serverOptions Options) (*Server, *client, *bufio.Reader, string) { + cli, srv := net.Pipe() + cr := bufio.NewReaderSize(cli, maxBufSize) + s := New(&serverOptions) + + ch := make(chan *client) + createClientAsync(ch, s, srv) + + l, _ := cr.ReadString('\n') + + // Grab client + c := <-ch + return s, c, cr, l +} + +func setUpClientWithResponse() (*client, string) { + _, c, _, l := rawSetup(defaultServerOptions) + return c, l +} + +func setupClient() (*Server, *client, *bufio.Reader) { + s, c, cr, _ := rawSetup(defaultServerOptions) + return s, c, cr +} + +func checkClientsCount(t *testing.T, s *Server, expected int) { + t.Helper() + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + if nc := s.NumClients(); nc != expected { + return fmt.Errorf("The number of expected connections was %v, got %v", expected, nc) + } + return nil + }) +} + +func TestClientCreateAndInfo(t *testing.T) { + c, l := setUpClientWithResponse() + + if c.cid != 1 { + t.Fatalf("Expected cid of 1 vs %d\n", c.cid) + } + if c.state != OP_START { + t.Fatal("Expected state to be OP_START") + } + + if !strings.HasPrefix(l, "INFO ") { + t.Fatalf("INFO response incorrect: %s\n", l) + } + // Make sure payload is proper json + var info serverInfo + err := json.Unmarshal([]byte(l[5:]), &info) + if err != nil { + t.Fatalf("Could not parse INFO json: %v\n", err) + } + // Sanity checks + if info.MaxPayload != MAX_PAYLOAD_SIZE || + info.AuthRequired || info.TLSRequired || + info.Port != DEFAULT_PORT { + t.Fatalf("INFO inconsistent: %+v\n", info) + } +} + +func TestNonTLSConnectionState(t *testing.T) { + _, c, _ := setupClient() + state := c.GetTLSConnectionState() + if state != nil { + t.Error("GetTLSConnectionState() returned non-nil") + } +} + +func TestClientConnect(t *testing.T) { + _, c, _ := setupClient() + + // Basic Connect setting flags + connectOp := []byte("CONNECT {\"verbose\":true,\"pedantic\":true,\"tls_required\":false,\"echo\":false}\r\n") + err := c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + if !reflect.DeepEqual(c.opts, clientOpts{Verbose: true, Pedantic: true, Echo: false}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } + + // Test that we can capture user/pass + connectOp = []byte("CONNECT {\"user\":\"derek\",\"pass\":\"foo\"}\r\n") + c.opts = defaultOpts + err = c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + if !reflect.DeepEqual(c.opts, clientOpts{Echo: true, Verbose: true, Pedantic: true, Username: "derek", Password: "foo"}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } + + // Test that we can capture client name + connectOp = []byte("CONNECT {\"user\":\"derek\",\"pass\":\"foo\",\"name\":\"router\"}\r\n") + c.opts = defaultOpts + err = c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + + if !reflect.DeepEqual(c.opts, clientOpts{Echo: true, Verbose: true, Pedantic: true, Username: "derek", Password: "foo", Name: "router"}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } + + // Test that we correctly capture auth tokens + connectOp = []byte("CONNECT {\"auth_token\":\"YZZ222\",\"name\":\"router\"}\r\n") + c.opts = defaultOpts + err = c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + + if !reflect.DeepEqual(c.opts, clientOpts{Echo: true, Verbose: true, Pedantic: true, Authorization: "YZZ222", Name: "router"}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } +} + +func TestClientConnectProto(t *testing.T) { + _, c, r := setupClient() + + // Basic Connect setting flags, proto should be zero (original proto) + connectOp := []byte("CONNECT {\"verbose\":true,\"pedantic\":true,\"tls_required\":false}\r\n") + err := c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + if !reflect.DeepEqual(c.opts, clientOpts{Echo: true, Verbose: true, Pedantic: true, Protocol: ClientProtoZero}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } + + // ProtoInfo + connectOp = []byte(fmt.Sprintf("CONNECT {\"verbose\":true,\"pedantic\":true,\"tls_required\":false,\"protocol\":%d}\r\n", ClientProtoInfo)) + err = c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected state of OP_START vs %d\n", c.state) + } + if !reflect.DeepEqual(c.opts, clientOpts{Echo: true, Verbose: true, Pedantic: true, Protocol: ClientProtoInfo}) { + t.Fatalf("Did not parse connect options correctly: %+v\n", c.opts) + } + if c.opts.Protocol != ClientProtoInfo { + t.Fatalf("Protocol should have been set to %v, but is set to %v", ClientProtoInfo, c.opts.Protocol) + } + + // Illegal Option + connectOp = []byte("CONNECT {\"protocol\":22}\r\n") + wg := sync.WaitGroup{} + wg.Add(1) + // The client here is using a pipe, we need to be dequeuing + // data otherwise the server would be blocked trying to send + // the error back to it. + go func() { + defer wg.Done() + for { + if _, _, err := r.ReadLine(); err != nil { + return + } + } + }() + err = c.parse(connectOp) + if err == nil { + t.Fatalf("Expected to receive an error\n") + } + if err != ErrBadClientProtocol { + t.Fatalf("Expected err of %q, got %q\n", ErrBadClientProtocol, err) + } + wg.Wait() +} + +func TestClientPing(t *testing.T) { + _, c, cr := setupClient() + + // PING + pingOp := []byte("PING\r\n") + go c.parse(pingOp) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving info from server: %v\n", err) + } + if !strings.HasPrefix(l, "PONG\r\n") { + t.Fatalf("PONG response incorrect: %s\n", l) + } +} + +var msgPat = regexp.MustCompile(`\AMSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n`) + +const ( + SUB_INDEX = 1 + SID_INDEX = 2 + REPLY_INDEX = 4 + LEN_INDEX = 5 +) + +func checkPayload(cr *bufio.Reader, expected []byte, t *testing.T) { + // Read in payload + d := make([]byte, len(expected)) + n, err := cr.Read(d) + if err != nil { + t.Fatalf("Error receiving msg payload from server: %v\n", err) + } + if n != len(expected) { + t.Fatalf("Did not read correct amount of bytes: %d vs %d\n", n, len(expected)) + } + if !bytes.Equal(d, expected) { + t.Fatalf("Did not read correct payload:: <%s>\n", d) + } +} + +func TestClientSimplePubSub(t *testing.T) { + _, c, cr := setupClient() + // SUB/PUB + go c.parse([]byte("SUB foo 1\r\nPUB foo 5\r\nhello\r\nPING\r\n")) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + if len(matches) != 6 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 6) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[LEN_INDEX] != "5" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX]) + } + checkPayload(cr, []byte("hello\r\n"), t) +} + +func TestClientPubSubNoEcho(t *testing.T) { + _, c, cr := setupClient() + // Specify no echo + connectOp := []byte("CONNECT {\"echo\":false}\r\n") + err := c.parse(connectOp) + if err != nil { + t.Fatalf("Received error: %v\n", err) + } + // SUB/PUB + go c.parse([]byte("SUB foo 1\r\nPUB foo 5\r\nhello\r\nPING\r\n")) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + // We should not receive anything but a PONG since we specified no echo. + if !strings.HasPrefix(l, "PONG\r\n") { + t.Fatalf("PONG response incorrect: %q\n", l) + } +} + +func TestClientSimplePubSubWithReply(t *testing.T) { + _, c, cr := setupClient() + + // SUB/PUB + go c.parse([]byte("SUB foo 1\r\nPUB foo bar 5\r\nhello\r\nPING\r\n")) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + if len(matches) != 6 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 6) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[REPLY_INDEX] != "bar" { + t.Fatalf("Did not get correct reply subject: '%s'\n", matches[REPLY_INDEX]) + } + if matches[LEN_INDEX] != "5" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX]) + } +} + +func TestClientNoBodyPubSubWithReply(t *testing.T) { + _, c, cr := setupClient() + + // SUB/PUB + go c.parse([]byte("SUB foo 1\r\nPUB foo bar 0\r\n\r\nPING\r\n")) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving msg from server: %v\n", err) + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + if len(matches) != 6 { + t.Fatalf("Did not get correct # matches: %d vs %d\n", len(matches), 6) + } + if matches[SUB_INDEX] != "foo" { + t.Fatalf("Did not get correct subject: '%s'\n", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'\n", matches[SID_INDEX]) + } + if matches[REPLY_INDEX] != "bar" { + t.Fatalf("Did not get correct reply subject: '%s'\n", matches[REPLY_INDEX]) + } + if matches[LEN_INDEX] != "0" { + t.Fatalf("Did not get correct msg length: '%s'\n", matches[LEN_INDEX]) + } +} + +func (c *client) parseFlushAndClose(op []byte) { + c.parse(op) + for cp := range c.pcd { + cp.mu.Lock() + cp.flushOutbound() + cp.mu.Unlock() + } + c.nc.Close() +} + +func TestClientPubWithQueueSub(t *testing.T) { + _, c, cr := setupClient() + + num := 100 + + // Queue SUB/PUB + subs := []byte("SUB foo g1 1\r\nSUB foo g1 2\r\n") + pubs := []byte("PUB foo bar 5\r\nhello\r\n") + op := []byte{} + op = append(op, subs...) + for i := 0; i < num; i++ { + op = append(op, pubs...) + } + + go c.parseFlushAndClose(op) + + var n1, n2, received int + for ; ; received++ { + l, err := cr.ReadString('\n') + if err != nil { + break + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + + // Count which sub + switch matches[SID_INDEX] { + case "1": + n1++ + case "2": + n2++ + } + checkPayload(cr, []byte("hello\r\n"), t) + } + if received != num { + t.Fatalf("Received wrong # of msgs: %d vs %d\n", received, num) + } + // Threshold for randomness for now + if n1 < 20 || n2 < 20 { + t.Fatalf("Received wrong # of msgs per subscriber: %d - %d\n", n1, n2) + } +} + +func TestClientPubWithQueueSubNoEcho(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + nc1, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc1.Close() + + // Grab the client from server and set no echo by hand. + s.mu.Lock() + lc := len(s.clients) + c := s.clients[s.gcid] + s.mu.Unlock() + + if lc != 1 { + t.Fatalf("Expected only 1 client but got %d\n", lc) + } + if c == nil { + t.Fatal("Expected to retrieve client\n") + } + c.mu.Lock() + c.echo = false + c.mu.Unlock() + + // Queue sub on nc1. + _, err = nc1.QueueSubscribe("foo", "bar", func(*nats.Msg) {}) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + nc1.Flush() + + nc2, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc2.Close() + + n := int32(0) + cb := func(m *nats.Msg) { + atomic.AddInt32(&n, 1) + } + + _, err = nc2.QueueSubscribe("foo", "bar", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + nc2.Flush() + + // Now publish 100 messages on nc1 which does not allow echo. + for i := 0; i < 100; i++ { + nc1.Publish("foo", []byte("Hello")) + } + nc1.Flush() + nc2.Flush() + + checkFor(t, 5*time.Second, 10*time.Millisecond, func() error { + num := atomic.LoadInt32(&n) + if num != int32(100) { + return fmt.Errorf("Expected all the msgs to be received by nc2, got %d\n", num) + } + return nil + }) +} + +func TestClientUnSub(t *testing.T) { + _, c, cr := setupClient() + + num := 1 + + // SUB/PUB + subs := []byte("SUB foo 1\r\nSUB foo 2\r\n") + unsub := []byte("UNSUB 1\r\n") + pub := []byte("PUB foo bar 5\r\nhello\r\n") + + op := []byte{} + op = append(op, subs...) + op = append(op, unsub...) + op = append(op, pub...) + + go c.parseFlushAndClose(op) + + var received int + for ; ; received++ { + l, err := cr.ReadString('\n') + if err != nil { + break + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + if matches[SID_INDEX] != "2" { + t.Fatalf("Received msg on unsubscribed subscription!\n") + } + checkPayload(cr, []byte("hello\r\n"), t) + } + if received != num { + t.Fatalf("Received wrong # of msgs: %d vs %d\n", received, num) + } +} + +func TestClientUnSubMax(t *testing.T) { + _, c, cr := setupClient() + + num := 10 + exp := 5 + + // SUB/PUB + subs := []byte("SUB foo 1\r\n") + unsub := []byte("UNSUB 1 5\r\n") + pub := []byte("PUB foo bar 5\r\nhello\r\n") + + op := []byte{} + op = append(op, subs...) + op = append(op, unsub...) + for i := 0; i < num; i++ { + op = append(op, pub...) + } + + go c.parseFlushAndClose(op) + + var received int + for ; ; received++ { + l, err := cr.ReadString('\n') + if err != nil { + break + } + matches := msgPat.FindAllStringSubmatch(l, -1)[0] + if matches[SID_INDEX] != "1" { + t.Fatalf("Received msg on unsubscribed subscription!\n") + } + checkPayload(cr, []byte("hello\r\n"), t) + } + if received != exp { + t.Fatalf("Received wrong # of msgs: %d vs %d\n", received, exp) + } +} + +func TestClientAutoUnsubExactReceived(t *testing.T) { + _, c, _ := setupClient() + defer c.nc.Close() + + // SUB/PUB + subs := []byte("SUB foo 1\r\n") + unsub := []byte("UNSUB 1 1\r\n") + pub := []byte("PUB foo bar 2\r\nok\r\n") + + op := []byte{} + op = append(op, subs...) + op = append(op, unsub...) + op = append(op, pub...) + + ch := make(chan bool) + go func() { + c.parse(op) + ch <- true + }() + + // Wait for processing + <-ch + + // We should not have any subscriptions in place here. + if len(c.subs) != 0 { + t.Fatalf("Wrong number of subscriptions: expected 0, got %d\n", len(c.subs)) + } +} + +func TestClientUnsubAfterAutoUnsub(t *testing.T) { + _, c, _ := setupClient() + defer c.nc.Close() + + // SUB/UNSUB/UNSUB + subs := []byte("SUB foo 1\r\n") + asub := []byte("UNSUB 1 1\r\n") + unsub := []byte("UNSUB 1\r\n") + + op := []byte{} + op = append(op, subs...) + op = append(op, asub...) + op = append(op, unsub...) + + ch := make(chan bool) + go func() { + c.parse(op) + ch <- true + }() + + // Wait for processing + <-ch + + // We should not have any subscriptions in place here. + if len(c.subs) != 0 { + t.Fatalf("Wrong number of subscriptions: expected 0, got %d\n", len(c.subs)) + } +} + +func TestClientRemoveSubsOnDisconnect(t *testing.T) { + s, c, _ := setupClient() + subs := []byte("SUB foo 1\r\nSUB bar 2\r\n") + + ch := make(chan bool) + go func() { + c.parse(subs) + ch <- true + }() + <-ch + + if s.sl.Count() != 2 { + t.Fatalf("Should have 2 subscriptions, got %d\n", s.sl.Count()) + } + c.closeConnection(ClientClosed) + if s.sl.Count() != 0 { + t.Fatalf("Should have no subscriptions after close, got %d\n", s.sl.Count()) + } +} + +func TestClientDoesNotAddSubscriptionsWhenConnectionClosed(t *testing.T) { + s, c, _ := setupClient() + c.closeConnection(ClientClosed) + subs := []byte("SUB foo 1\r\nSUB bar 2\r\n") + + ch := make(chan bool) + go func() { + c.parse(subs) + ch <- true + }() + <-ch + + if s.sl.Count() != 0 { + t.Fatalf("Should have no subscriptions after close, got %d\n", s.sl.Count()) + } +} + +func TestClientMapRemoval(t *testing.T) { + s, c, _ := setupClient() + c.nc.Close() + + checkClientsCount(t, s, 0) +} + +func TestAuthorizationTimeout(t *testing.T) { + serverOptions := DefaultOptions() + serverOptions.Authorization = "my_token" + serverOptions.AuthTimeout = 0.4 + s := RunServer(serverOptions) + defer s.Shutdown() + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", serverOptions.Host, serverOptions.Port)) + if err != nil { + t.Fatalf("Error dialing server: %v\n", err) + } + defer conn.Close() + client := bufio.NewReaderSize(conn, maxBufSize) + if _, err := client.ReadString('\n'); err != nil { + t.Fatalf("Error receiving info from server: %v\n", err) + } + time.Sleep(3 * secondsToDuration(serverOptions.AuthTimeout)) + l, err := client.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving info from server: %v\n", err) + } + if !strings.Contains(l, "Authorization Timeout") { + t.Fatalf("Authorization Timeout response incorrect: %q\n", l) + } +} + +// This is from bug report #18 +func TestTwoTokenPubMatchSingleTokenSub(t *testing.T) { + _, c, cr := setupClient() + test := []byte("PUB foo.bar 5\r\nhello\r\nSUB foo 1\r\nPING\r\nPUB foo.bar 5\r\nhello\r\nPING\r\n") + go c.parse(test) + l, err := cr.ReadString('\n') + if err != nil { + t.Fatalf("Error receiving info from server: %v\n", err) + } + if !strings.HasPrefix(l, "PONG\r\n") { + t.Fatalf("PONG response incorrect: %q\n", l) + } + // Expect just a pong, no match should exist here.. + l, _ = cr.ReadString('\n') + if !strings.HasPrefix(l, "PONG\r\n") { + t.Fatalf("PONG response was expected, got: %q\n", l) + } +} + +func TestUnsubRace(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + url := fmt.Sprintf("nats://%s:%d", + s.getOpts().Host, + s.Addr().(*net.TCPAddr).Port, + ) + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error creating client to %s: %v\n", url, err) + } + defer nc.Close() + + ncp, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer ncp.Close() + + sub, _ := nc.Subscribe("foo", func(m *nats.Msg) { + // Just eat it.. + }) + nc.Flush() + + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + for i := 0; i < 10000; i++ { + ncp.Publish("foo", []byte("hello")) + } + wg.Done() + }() + + time.Sleep(5 * time.Millisecond) + + sub.Unsubscribe() + + wg.Wait() +} + +func TestTLSCloseClientConnection(t *testing.T) { + opts, err := ProcessConfigFile("./configs/tls.conf") + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.TLSTimeout = 100 + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + conn, err := net.DialTimeout("tcp", endpoint, 2*time.Second) + if err != nil { + t.Fatalf("Unexpected error on dial: %v", err) + } + defer conn.Close() + br := bufio.NewReaderSize(conn, 100) + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("Unexpected error reading INFO: %v", err) + } + + tlsConn := tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + defer tlsConn.Close() + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("Unexpected error during handshake: %v", err) + } + br = bufio.NewReaderSize(tlsConn, 100) + connectOp := []byte("CONNECT {\"user\":\"derek\",\"pass\":\"foo\",\"verbose\":false,\"pedantic\":false,\"tls_required\":true}\r\n") + if _, err := tlsConn.Write(connectOp); err != nil { + t.Fatalf("Unexpected error writing CONNECT: %v", err) + } + if _, err := tlsConn.Write([]byte("PING\r\n")); err != nil { + t.Fatalf("Unexpected error writing PING: %v", err) + } + if _, err := br.ReadString('\n'); err != nil { + t.Fatalf("Unexpected error reading PONG: %v", err) + } + + // Check that client is registered. + checkClientsCount(t, s, 1) + var cli *client + s.mu.Lock() + for _, c := range s.clients { + cli = c + break + } + s.mu.Unlock() + if cli == nil { + t.Fatal("Did not register client on time") + } + // Test GetTLSConnectionState + state := cli.GetTLSConnectionState() + if state == nil { + t.Error("GetTLSConnectionState() returned nil") + } + // Fill the buffer. Need to send 1 byte at a time so that we timeout here + // the nc.Close() would block due to a write that can not complete. + done := false + for !done { + cli.nc.SetWriteDeadline(time.Now().Add(time.Second)) + if _, err := cli.nc.Write([]byte("a")); err != nil { + done = true + } + cli.nc.SetWriteDeadline(time.Time{}) + } + ch := make(chan bool) + go func() { + select { + case <-ch: + return + case <-time.After(3 * time.Second): + fmt.Println("!!!! closeConnection is blocked, test will hang !!!") + return + } + }() + // Close the client + cli.closeConnection(ClientClosed) + ch <- true +} + +// This tests issue #558 +func TestWildcardCharsInLiteralSubjectWorks(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + ch := make(chan bool, 1) + // This subject is a literal even though it contains `*` and `>`, + // they are not treated as wildcards. + subj := "foo.bar,*,>,baz" + cb := func(_ *nats.Msg) { + ch <- true + } + for i := 0; i < 2; i++ { + sub, err := nc.Subscribe(subj, cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + if err := nc.LastError(); err != nil { + t.Fatalf("Server reported error: %v", err) + } + if err := nc.Publish(subj, []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("Should have received the message") + } + if err := sub.Unsubscribe(); err != nil { + t.Fatalf("Error on unsubscribe: %v", err) + } + } +} + +func TestDynamicBuffers(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // Grab the client from server. + s.mu.Lock() + lc := len(s.clients) + c := s.clients[s.gcid] + s.mu.Unlock() + + if lc != 1 { + t.Fatalf("Expected only 1 client but got %d\n", lc) + } + if c == nil { + t.Fatal("Expected to retrieve client\n") + } + + // Create some helper functions and data structures. + done := make(chan bool) // Used to stop recording. + type maxv struct{ rsz, wsz int } // Used to hold max values. + results := make(chan maxv) + + // stopRecording stops the recording ticker and releases go routine. + stopRecording := func() maxv { + done <- true + return <-results + } + // max just grabs max values. + max := func(a, b int) int { + if a > b { + return a + } + return b + } + // Returns current value of the buffer sizes. + getBufferSizes := func() (int, int) { + c.mu.Lock() + defer c.mu.Unlock() + return c.in.rsz, c.out.sz + } + // Record the max values seen. + recordMaxBufferSizes := func() { + ticker := time.NewTicker(10 * time.Microsecond) + defer ticker.Stop() + + var m maxv + + recordMax := func() { + rsz, wsz := getBufferSizes() + m.rsz = max(m.rsz, rsz) + m.wsz = max(m.wsz, wsz) + } + + for { + select { + case <-done: + recordMax() + results <- m + return + case <-ticker.C: + recordMax() + } + } + } + // Check that the current value is what we expected. + checkBuffers := func(ers, ews int) { + t.Helper() + rsz, wsz := getBufferSizes() + if rsz != ers { + t.Fatalf("Expected read buffer of %d, but got %d\n", ers, rsz) + } + if wsz != ews { + t.Fatalf("Expected write buffer of %d, but got %d\n", ews, wsz) + } + } + + // Check that the max was as expected. + checkResults := func(m maxv, rsz, wsz int) { + t.Helper() + if rsz != m.rsz { + t.Fatalf("Expected read buffer of %d, but got %d\n", rsz, m.rsz) + } + if wsz != m.wsz { + t.Fatalf("Expected write buffer of %d, but got %d\n", wsz, m.wsz) + } + } + + // Here is where testing begins.. + + // Should be at or below the startBufSize for both. + rsz, wsz := getBufferSizes() + if rsz > startBufSize { + t.Fatalf("Expected read buffer of <= %d, but got %d\n", startBufSize, rsz) + } + if wsz > startBufSize { + t.Fatalf("Expected write buffer of <= %d, but got %d\n", startBufSize, wsz) + } + + // Send some data. + data := make([]byte, 2048) + rand.Read(data) + + go recordMaxBufferSizes() + for i := 0; i < 200; i++ { + nc.Publish("foo", data) + } + nc.Flush() + m := stopRecording() + + if m.rsz != maxBufSize && m.rsz != maxBufSize/2 { + t.Fatalf("Expected read buffer of %d or %d, but got %d\n", maxBufSize, maxBufSize/2, m.rsz) + } + if m.wsz > startBufSize { + t.Fatalf("Expected write buffer of <= %d, but got %d\n", startBufSize, m.wsz) + } + + // Create Subscription to test outbound buffer from server. + nc.Subscribe("foo", func(m *nats.Msg) { + // Just eat it.. + }) + go recordMaxBufferSizes() + + for i := 0; i < 200; i++ { + nc.Publish("foo", data) + } + nc.Flush() + + m = stopRecording() + checkResults(m, maxBufSize, maxBufSize) + + // Now test that we shrink correctly. + + // Should go to minimum for both.. + for i := 0; i < 20; i++ { + nc.Flush() + } + checkBuffers(minBufSize, minBufSize) +} + +// Similar to the routed version. Make sure we receive all of the +// messages with auto-unsubscribe enabled. +func TestQueueAutoUnsubscribe(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + rbar := int32(0) + barCb := func(m *nats.Msg) { + atomic.AddInt32(&rbar, 1) + } + rbaz := int32(0) + bazCb := func(m *nats.Msg) { + atomic.AddInt32(&rbaz, 1) + } + + // Create 1000 subscriptions with auto-unsubscribe of 1. + // Do two groups, one bar and one baz. + for i := 0; i < 1000; i++ { + qsub, err := nc.QueueSubscribe("foo", "bar", barCb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := qsub.AutoUnsubscribe(1); err != nil { + t.Fatalf("Error on auto-unsubscribe: %v", err) + } + qsub, err = nc.QueueSubscribe("foo", "baz", bazCb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := qsub.AutoUnsubscribe(1); err != nil { + t.Fatalf("Error on auto-unsubscribe: %v", err) + } + } + nc.Flush() + + expected := int32(1000) + for i := int32(0); i < expected; i++ { + nc.Publish("foo", []byte("Don't Drop Me!")) + } + nc.Flush() + + checkFor(t, 5*time.Second, 10*time.Millisecond, func() error { + nbar := atomic.LoadInt32(&rbar) + nbaz := atomic.LoadInt32(&rbaz) + if nbar == expected && nbaz == expected { + return nil + } + return fmt.Errorf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'", + expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + }) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/closed_conns_test.go b/vendor/github.com/nats-io/gnatsd/server/closed_conns_test.go new file mode 100644 index 00000000..e3a86a64 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/closed_conns_test.go @@ -0,0 +1,361 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + nats "github.com/nats-io/go-nats" +) + +func checkClosedConns(t *testing.T, s *Server, num int, wait time.Duration) { + t.Helper() + checkFor(t, wait, 5*time.Millisecond, func() error { + if nc := s.numClosedConns(); nc != num { + return fmt.Errorf("Closed conns expected to be %v, got %v", num, nc) + } + return nil + }) +} + +func checkTotalClosedConns(t *testing.T, s *Server, num uint64, wait time.Duration) { + t.Helper() + checkFor(t, wait, 5*time.Millisecond, func() error { + if nc := s.totalClosedConns(); nc != num { + return fmt.Errorf("Total closed conns expected to be %v, got %v", num, nc) + } + return nil + }) +} + +func TestClosedConnsAccounting(t *testing.T) { + opts := DefaultOptions() + opts.MaxClosedClients = 10 + + s := RunServer(opts) + defer s.Shutdown() + + wait := 20 * time.Millisecond + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + nc.Close() + + checkClosedConns(t, s, 1, wait) + + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + if conns[0].Cid != 1 { + t.Fatalf("Expected CID to be 1, got %d\n", conns[0].Cid) + } + + // Now create 21 more + for i := 0; i < 21; i++ { + nc, err = nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + nc.Close() + checkTotalClosedConns(t, s, uint64(i+2), wait) + } + + checkClosedConns(t, s, opts.MaxClosedClients, wait) + checkTotalClosedConns(t, s, 22, wait) + + conns = s.closedClients() + if lc := len(conns); lc != opts.MaxClosedClients { + t.Fatalf("len(conns) expected to be %d, got %d\n", + opts.MaxClosedClients, lc) + } + + // Set it to the start after overflow. + cid := uint64(22 - opts.MaxClosedClients) + for _, ci := range conns { + cid++ + if ci.Cid != cid { + t.Fatalf("Expected cid of %d, got %d\n", cid, ci.Cid) + } + } +} + +func TestClosedConnsSubsAccounting(t *testing.T) { + opts := DefaultOptions() + s := RunServer(opts) + defer s.Shutdown() + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + // Now create some subscriptions + numSubs := 10 + for i := 0; i < numSubs; i++ { + subj := fmt.Sprintf("foo.%d", i) + nc.Subscribe(subj, func(m *nats.Msg) {}) + } + nc.Flush() + nc.Close() + + checkClosedConns(t, s, 1, 20*time.Millisecond) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be 1, got %d\n", lc) + } + ci := conns[0] + + if len(ci.subs) != numSubs { + t.Fatalf("Expected number of Subs to be %d, got %d\n", numSubs, len(ci.subs)) + } +} + +func checkReason(t *testing.T, reason string, expected ClosedState) { + if !strings.Contains(reason, expected.String()) { + t.Fatalf("Expected closed connection with `%s` state, got `%s`\n", + expected, reason) + } +} + +func TestClosedAuthorizationTimeout(t *testing.T) { + serverOptions := DefaultOptions() + serverOptions.Authorization = "my_token" + serverOptions.AuthTimeout = 0.4 + s := RunServer(serverOptions) + defer s.Shutdown() + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", serverOptions.Host, serverOptions.Port)) + if err != nil { + t.Fatalf("Error dialing server: %v\n", err) + } + defer conn.Close() + + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, AuthenticationTimeout) +} + +func TestClosedAuthorizationViolation(t *testing.T) { + serverOptions := DefaultOptions() + serverOptions.Authorization = "my_token" + s := RunServer(serverOptions) + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(url) + if err == nil { + nc.Close() + t.Fatal("Expected failure for connection") + } + + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, AuthenticationViolation) +} + +func TestClosedUPAuthorizationViolation(t *testing.T) { + serverOptions := DefaultOptions() + serverOptions.Username = "my_user" + serverOptions.Password = "my_secret" + s := RunServer(serverOptions) + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(url) + if err == nil { + nc.Close() + t.Fatal("Expected failure for connection") + } + + url2 := fmt.Sprintf("nats://my_user:wrong_pass@%s:%d", opts.Host, opts.Port) + nc, err = nats.Connect(url2) + if err == nil { + nc.Close() + t.Fatal("Expected failure for connection") + } + + checkClosedConns(t, s, 2, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 2 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 2, lc) + } + checkReason(t, conns[0].Reason, AuthenticationViolation) + checkReason(t, conns[1].Reason, AuthenticationViolation) +} + +func TestClosedMaxPayload(t *testing.T) { + serverOptions := DefaultOptions() + serverOptions.MaxPayload = 100 + + s := RunServer(serverOptions) + defer s.Shutdown() + + opts := s.getOpts() + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + + conn, err := net.DialTimeout("tcp", endpoint, time.Second) + if err != nil { + t.Fatalf("Could not make a raw connection to the server: %v", err) + } + defer conn.Close() + + // This should trigger it. + pub := fmt.Sprintf("PUB foo.bar 1024\r\n") + conn.Write([]byte(pub)) + + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, MaxPayloadExceeded) +} + +func TestClosedSlowConsumerWriteDeadline(t *testing.T) { + opts := DefaultOptions() + opts.WriteDeadline = 10 * time.Millisecond // Make very small to trip. + opts.MaxPending = 500 * 1024 * 1024 // Set high so it will not trip here. + s := RunServer(opts) + defer s.Shutdown() + + c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port), 3*time.Second) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer c.Close() + if _, err := c.Write([]byte("CONNECT {}\r\nPING\r\nSUB foo 1\r\n")); err != nil { + t.Fatalf("Error sending protocols to server: %v", err) + } + // Reduce socket buffer to increase reliability of data backing up in the server destined + // for our subscribed client. + c.(*net.TCPConn).SetReadBuffer(128) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + sender, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer sender.Close() + + payload := make([]byte, 1024*1024) + for i := 0; i < 100; i++ { + if err := sender.Publish("foo", payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + // Flush sender connection to ensure that all data has been sent. + if err := sender.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + + // At this point server should have closed connection c. + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, SlowConsumerWriteDeadline) +} + +func TestClosedSlowConsumerPendingBytes(t *testing.T) { + opts := DefaultOptions() + opts.WriteDeadline = 30 * time.Second // Wait for long time so write deadline does not trigger slow consumer. + opts.MaxPending = 1 * 1024 * 1024 // Set to low value (1MB) to allow SC to trip. + s := RunServer(opts) + defer s.Shutdown() + + c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port), 3*time.Second) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer c.Close() + if _, err := c.Write([]byte("CONNECT {}\r\nPING\r\nSUB foo 1\r\n")); err != nil { + t.Fatalf("Error sending protocols to server: %v", err) + } + // Reduce socket buffer to increase reliability of data backing up in the server destined + // for our subscribed client. + c.(*net.TCPConn).SetReadBuffer(128) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + sender, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer sender.Close() + + payload := make([]byte, 1024*1024) + for i := 0; i < 100; i++ { + if err := sender.Publish("foo", payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + // Flush sender connection to ensure that all data has been sent. + if err := sender.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + + // At this point server should have closed connection c. + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, SlowConsumerPendingBytes) +} + +func TestClosedTLSHandshake(t *testing.T) { + opts, err := ProcessConfigFile("./configs/tls.conf") + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.TLSVerify = true + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port)) + if err == nil { + nc.Close() + t.Fatal("Expected failure for connection") + } + + checkClosedConns(t, s, 1, 2*time.Second) + conns := s.closedClients() + if lc := len(conns); lc != 1 { + t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) + } + checkReason(t, conns[0].Reason, TLSHandshakeError) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/const.go b/vendor/github.com/nats-io/gnatsd/server/const.go new file mode 100644 index 00000000..b27ae5d9 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/const.go @@ -0,0 +1,124 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "time" +) + +// Command is a signal used to control a running gnatsd process. +type Command string + +// Valid Command values. +const ( + CommandStop = Command("stop") + CommandQuit = Command("quit") + CommandReopen = Command("reopen") + CommandReload = Command("reload") +) + +var ( + // gitCommit injected at build + gitCommit string +) + +const ( + // VERSION is the current version for the server. + VERSION = "1.2.0" + + // PROTO is the currently supported protocol. + // 0 was the original + // 1 maintains proto 0, adds echo abilities for CONNECT from the client. Clients + // should not send echo unless proto in INFO is >= 1. + PROTO = 1 + + // DEFAULT_PORT is the default port for client connections. + DEFAULT_PORT = 4222 + + // RANDOM_PORT is the value for port that, when supplied, will cause the + // server to listen on a randomly-chosen available port. The resolved port + // is available via the Addr() method. + RANDOM_PORT = -1 + + // DEFAULT_HOST defaults to all interfaces. + DEFAULT_HOST = "0.0.0.0" + + // MAX_CONTROL_LINE_SIZE is the maximum allowed protocol control line size. + // 1k should be plenty since payloads sans connect string are separate + MAX_CONTROL_LINE_SIZE = 1024 + + // MAX_PAYLOAD_SIZE is the maximum allowed payload size. Should be using + // something different if > 1MB payloads are needed. + MAX_PAYLOAD_SIZE = (1024 * 1024) + + // MAX_PENDING_SIZE is the maximum outbound pending bytes per client. + MAX_PENDING_SIZE = (256 * 1024 * 1024) + + // DEFAULT_MAX_CONNECTIONS is the default maximum connections allowed. + DEFAULT_MAX_CONNECTIONS = (64 * 1024) + + // TLS_TIMEOUT is the TLS wait time. + TLS_TIMEOUT = 500 * time.Millisecond + + // AUTH_TIMEOUT is the authorization wait time. + AUTH_TIMEOUT = 2 * TLS_TIMEOUT + + // DEFAULT_PING_INTERVAL is how often pings are sent to clients and routes. + DEFAULT_PING_INTERVAL = 2 * time.Minute + + // DEFAULT_PING_MAX_OUT is maximum allowed pings outstanding before disconnect. + DEFAULT_PING_MAX_OUT = 2 + + // CR_LF string + CR_LF = "\r\n" + + // LEN_CR_LF hold onto the computed size. + LEN_CR_LF = len(CR_LF) + + // DEFAULT_FLUSH_DEADLINE is the write/flush deadlines. + DEFAULT_FLUSH_DEADLINE = 2 * time.Second + + // DEFAULT_HTTP_PORT is the default monitoring port. + DEFAULT_HTTP_PORT = 8222 + + // ACCEPT_MIN_SLEEP is the minimum acceptable sleep times on temporary errors. + ACCEPT_MIN_SLEEP = 10 * time.Millisecond + + // ACCEPT_MAX_SLEEP is the maximum acceptable sleep times on temporary errors + ACCEPT_MAX_SLEEP = 1 * time.Second + + // DEFAULT_ROUTE_CONNECT Route solicitation intervals. + DEFAULT_ROUTE_CONNECT = 1 * time.Second + + // DEFAULT_ROUTE_RECONNECT Route reconnect intervals. + DEFAULT_ROUTE_RECONNECT = 1 * time.Second + + // DEFAULT_ROUTE_DIAL Route dial timeout. + DEFAULT_ROUTE_DIAL = 1 * time.Second + + // PROTO_SNIPPET_SIZE is the default size of proto to print on parse errors. + PROTO_SNIPPET_SIZE = 32 + + // MAX_MSG_ARGS Maximum possible number of arguments from MSG proto. + MAX_MSG_ARGS = 4 + + // MAX_PUB_ARGS Maximum possible number of arguments from PUB proto. + MAX_PUB_ARGS = 3 + + // DEFAULT_REMOTE_QSUBS_SWEEPER + DEFAULT_REMOTE_QSUBS_SWEEPER = 30 * time.Second + + // DEFAULT_MAX_CLOSED_CLIENTS + DEFAULT_MAX_CLOSED_CLIENTS = 10000 +) diff --git a/vendor/github.com/nats-io/gnatsd/server/errors.go b/vendor/github.com/nats-io/gnatsd/server/errors.go new file mode 100644 index 00000000..c722bfc4 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/errors.go @@ -0,0 +1,51 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import "errors" + +var ( + // ErrConnectionClosed represents an error condition on a closed connection. + ErrConnectionClosed = errors.New("Connection Closed") + + // ErrAuthorization represents an error condition on failed authorization. + ErrAuthorization = errors.New("Authorization Error") + + // ErrAuthTimeout represents an error condition on failed authorization due to timeout. + ErrAuthTimeout = errors.New("Authorization Timeout") + + // ErrMaxPayload represents an error condition when the payload is too big. + ErrMaxPayload = errors.New("Maximum Payload Exceeded") + + // ErrMaxControlLine represents an error condition when the control line is too big. + ErrMaxControlLine = errors.New("Maximum Control Line Exceeded") + + // ErrReservedPublishSubject represents an error condition when sending to a reserved subject, e.g. _SYS.> + ErrReservedPublishSubject = errors.New("Reserved Internal Subject") + + // ErrBadClientProtocol signals a client requested an invalud client protocol. + ErrBadClientProtocol = errors.New("Invalid Client Protocol") + + // ErrTooManyConnections signals a client that the maximum number of connections supported by the + // server has been reached. + ErrTooManyConnections = errors.New("Maximum Connections Exceeded") + + // ErrTooManySubs signals a client that the maximum number of subscriptions per connection + // has been reached. + ErrTooManySubs = errors.New("Maximum Subscriptions Exceeded") + + // ErrClientConnectedToRoutePort represents an error condition when a client + // attempted to connect to the route listen port. + ErrClientConnectedToRoutePort = errors.New("Attempted To Connect To Route Port") +) diff --git a/vendor/github.com/nats-io/gnatsd/server/log.go b/vendor/github.com/nats-io/gnatsd/server/log.go new file mode 100644 index 00000000..8c2be370 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/log.go @@ -0,0 +1,184 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "io" + "os" + "sync/atomic" + + srvlog "github.com/nats-io/gnatsd/logger" +) + +// Logger interface of the NATS Server +type Logger interface { + + // Log a notice statement + Noticef(format string, v ...interface{}) + + // Log a fatal error + Fatalf(format string, v ...interface{}) + + // Log an error + Errorf(format string, v ...interface{}) + + // Log a debug statement + Debugf(format string, v ...interface{}) + + // Log a trace statement + Tracef(format string, v ...interface{}) +} + +// ConfigureLogger configures and sets the logger for the server. +func (s *Server) ConfigureLogger() { + var ( + log Logger + + // Snapshot server options. + opts = s.getOpts() + ) + + syslog := opts.Syslog + if isWindowsService() && opts.LogFile == "" { + // Enable syslog if no log file is specified and we're running as a + // Windows service so that logs are written to the Windows event log. + syslog = true + } + + if opts.LogFile != "" { + log = srvlog.NewFileLogger(opts.LogFile, opts.Logtime, opts.Debug, opts.Trace, true) + } else if opts.RemoteSyslog != "" { + log = srvlog.NewRemoteSysLogger(opts.RemoteSyslog, opts.Debug, opts.Trace) + } else if syslog { + log = srvlog.NewSysLogger(opts.Debug, opts.Trace) + } else { + colors := true + // Check to see if stderr is being redirected and if so turn off color + // Also turn off colors if we're running on Windows where os.Stderr.Stat() returns an invalid handle-error + stat, err := os.Stderr.Stat() + if err != nil || (stat.Mode()&os.ModeCharDevice) == 0 { + colors = false + } + log = srvlog.NewStdLogger(opts.Logtime, opts.Debug, opts.Trace, colors, true) + } + + s.SetLogger(log, opts.Debug, opts.Trace) +} + +// SetLogger sets the logger of the server +func (s *Server) SetLogger(logger Logger, debugFlag, traceFlag bool) { + if debugFlag { + atomic.StoreInt32(&s.logging.debug, 1) + } else { + atomic.StoreInt32(&s.logging.debug, 0) + } + if traceFlag { + atomic.StoreInt32(&s.logging.trace, 1) + } else { + atomic.StoreInt32(&s.logging.trace, 0) + } + s.logging.Lock() + if s.logging.logger != nil { + // Check to see if the logger implements io.Closer. This could be a + // logger from another process embedding the NATS server or a dummy + // test logger that may not implement that interface. + if l, ok := s.logging.logger.(io.Closer); ok { + if err := l.Close(); err != nil { + s.Errorf("Error closing logger: %v", err) + } + } + } + s.logging.logger = logger + s.logging.Unlock() +} + +// If the logger is a file based logger, close and re-open the file. +// This allows for file rotation by 'mv'ing the file then signaling +// the process to trigger this function. +func (s *Server) ReOpenLogFile() { + // Check to make sure this is a file logger. + s.logging.RLock() + ll := s.logging.logger + s.logging.RUnlock() + + if ll == nil { + s.Noticef("File log re-open ignored, no logger") + return + } + + // Snapshot server options. + opts := s.getOpts() + + if opts.LogFile == "" { + s.Noticef("File log re-open ignored, not a file logger") + } else { + fileLog := srvlog.NewFileLogger(opts.LogFile, + opts.Logtime, opts.Debug, opts.Trace, true) + s.SetLogger(fileLog, opts.Debug, opts.Trace) + s.Noticef("File log re-opened") + } +} + +// Noticef logs a notice statement +func (s *Server) Noticef(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Noticef(format, v...) + }, format, v...) +} + +// Errorf logs an error +func (s *Server) Errorf(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Errorf(format, v...) + }, format, v...) +} + +// Fatalf logs a fatal error +func (s *Server) Fatalf(format string, v ...interface{}) { + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Fatalf(format, v...) + }, format, v...) +} + +// Debugf logs a debug statement +func (s *Server) Debugf(format string, v ...interface{}) { + if atomic.LoadInt32(&s.logging.debug) == 0 { + return + } + + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Debugf(format, v...) + }, format, v...) +} + +// Tracef logs a trace statement +func (s *Server) Tracef(format string, v ...interface{}) { + if atomic.LoadInt32(&s.logging.trace) == 0 { + return + } + + s.executeLogCall(func(logger Logger, format string, v ...interface{}) { + logger.Tracef(format, v...) + }, format, v...) +} + +func (s *Server) executeLogCall(f func(logger Logger, format string, v ...interface{}), format string, args ...interface{}) { + s.logging.RLock() + defer s.logging.RUnlock() + if s.logging.logger == nil { + return + } + + f(s.logging.logger, format, args...) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/log_test.go b/vendor/github.com/nats-io/gnatsd/server/log_test.go new file mode 100644 index 00000000..78aea10f --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/log_test.go @@ -0,0 +1,175 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "io/ioutil" + "os" + "runtime" + "strings" + "sync" + "testing" + + "github.com/nats-io/gnatsd/logger" +) + +func TestSetLogger(t *testing.T) { + server := &Server{} + defer server.SetLogger(nil, false, false) + dl := &DummyLogger{} + server.SetLogger(dl, true, true) + + // We assert that the logger has change to the DummyLogger + _ = server.logging.logger.(*DummyLogger) + + if server.logging.debug != 1 { + t.Fatalf("Expected debug 1, received value %d\n", server.logging.debug) + } + + if server.logging.trace != 1 { + t.Fatalf("Expected trace 1, received value %d\n", server.logging.trace) + } + + // Check traces + expectedStr := "This is a Notice" + server.Noticef(expectedStr) + dl.checkContent(t, expectedStr) + expectedStr = "This is an Error" + server.Errorf(expectedStr) + dl.checkContent(t, expectedStr) + expectedStr = "This is a Fatal" + server.Fatalf(expectedStr) + dl.checkContent(t, expectedStr) + expectedStr = "This is a Debug" + server.Debugf(expectedStr) + dl.checkContent(t, expectedStr) + expectedStr = "This is a Trace" + server.Tracef(expectedStr) + dl.checkContent(t, expectedStr) + + // Make sure that we can reset to fal + server.SetLogger(dl, false, false) + if server.logging.debug != 0 { + t.Fatalf("Expected debug 0, got %v", server.logging.debug) + } + if server.logging.trace != 0 { + t.Fatalf("Expected trace 0, got %v", server.logging.trace) + } + // Now, Debug and Trace should not produce anything + dl.msg = "" + server.Debugf("This Debug should not be traced") + dl.checkContent(t, "") + server.Tracef("This Trace should not be traced") + dl.checkContent(t, "") +} + +type DummyLogger struct { + sync.Mutex + msg string +} + +func (l *DummyLogger) checkContent(t *testing.T, expectedStr string) { + l.Lock() + defer l.Unlock() + if l.msg != expectedStr { + stackFatalf(t, "Expected log to be: %v, got %v", expectedStr, l.msg) + } +} + +func (l *DummyLogger) Noticef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.msg = fmt.Sprintf(format, v...) +} +func (l *DummyLogger) Errorf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.msg = fmt.Sprintf(format, v...) +} +func (l *DummyLogger) Fatalf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.msg = fmt.Sprintf(format, v...) +} +func (l *DummyLogger) Debugf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.msg = fmt.Sprintf(format, v...) +} +func (l *DummyLogger) Tracef(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + l.msg = fmt.Sprintf(format, v...) +} + +func TestReOpenLogFile(t *testing.T) { + // We can't rename the file log when still opened on Windows, so skip + if runtime.GOOS == "windows" { + t.SkipNow() + } + s := &Server{opts: &Options{}} + defer s.SetLogger(nil, false, false) + + // First check with no logger + s.SetLogger(nil, false, false) + s.ReOpenLogFile() + + // Then when LogFile is not provided. + dl := &DummyLogger{} + s.SetLogger(dl, false, false) + s.ReOpenLogFile() + dl.checkContent(t, "File log re-open ignored, not a file logger") + + // Set a File log + s.opts.LogFile = "test.log" + defer os.Remove(s.opts.LogFile) + defer os.Remove(s.opts.LogFile + ".bak") + fileLog := logger.NewFileLogger(s.opts.LogFile, s.opts.Logtime, s.opts.Debug, s.opts.Trace, true) + s.SetLogger(fileLog, false, false) + // Add some log + expectedStr := "This is a Notice" + s.Noticef(expectedStr) + // Check content of log + buf, err := ioutil.ReadFile(s.opts.LogFile) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + if !strings.Contains(string(buf), expectedStr) { + t.Fatalf("Expected log to contain: %q, got %q", expectedStr, string(buf)) + } + // Close the file and rename it + if err := os.Rename(s.opts.LogFile, s.opts.LogFile+".bak"); err != nil { + t.Fatalf("Unable to rename log file: %v", err) + } + // Now re-open LogFile + s.ReOpenLogFile() + // Content should indicate that we have re-opened the log + buf, err = ioutil.ReadFile(s.opts.LogFile) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + if strings.HasSuffix(string(buf), "File log-reopened") { + t.Fatalf("File should indicate that file log was re-opened, got: %v", string(buf)) + } + // Make sure we can append to the log + s.Noticef("New message") + buf, err = ioutil.ReadFile(s.opts.LogFile) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + if strings.HasSuffix(string(buf), "New message") { + t.Fatalf("New message was not appended after file was re-opened, got: %v", string(buf)) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/monitor.go b/vendor/github.com/nats-io/gnatsd/server/monitor.go new file mode 100644 index 00000000..550b8083 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/monitor.go @@ -0,0 +1,1029 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net" + "net/http" + "runtime" + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/nats-io/gnatsd/server/pse" +) + +// Snapshot this +var numCores int + +func init() { + numCores = runtime.NumCPU() +} + +// Connz represents detailed information on current client connections. +type Connz struct { + ID string `json:"server_id"` + Now time.Time `json:"now"` + NumConns int `json:"num_connections"` + Total int `json:"total"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Conns []*ConnInfo `json:"connections"` +} + +// ConnzOptions are the options passed to Connz() +type ConnzOptions struct { + // Sort indicates how the results will be sorted. Check SortOpt for possible values. + // Only the sort by connection ID (ByCid) is ascending, all others are descending. + Sort SortOpt `json:"sort"` + + // Username indicates if user names should be included in the results. + Username bool `json:"auth"` + + // Subscriptions indicates if subscriptions should be included in the results. + Subscriptions bool `json:"subscriptions"` + + // Offset is used for pagination. Connz() only returns connections starting at this + // offset from the global results. + Offset int `json:"offset"` + + // Limit is the maximum number of connections that should be returned by Connz(). + Limit int `json:"limit"` + + // Filter for this explicit client connection. + CID uint64 `json:"cid"` + + // Filter by connection state. + State ConnState `json:"state"` +} + +// For filtering states of connections. We will only have two, open and closed. +type ConnState int + +const ( + ConnOpen = ConnState(iota) + ConnClosed + ConnAll +) + +// ConnInfo has detailed information on a per connection basis. +type ConnInfo struct { + Cid uint64 `json:"cid"` + IP string `json:"ip"` + Port int `json:"port"` + Start time.Time `json:"start"` + LastActivity time.Time `json:"last_activity"` + Stop *time.Time `json:"stop,omitempty"` + Reason string `json:"reason,omitempty"` + RTT string `json:"rtt,omitempty"` + Uptime string `json:"uptime"` + Idle string `json:"idle"` + Pending int `json:"pending_bytes"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + NumSubs uint32 `json:"subscriptions"` + Name string `json:"name,omitempty"` + Lang string `json:"lang,omitempty"` + Version string `json:"version,omitempty"` + TLSVersion string `json:"tls_version,omitempty"` + TLSCipher string `json:"tls_cipher_suite,omitempty"` + AuthorizedUser string `json:"authorized_user,omitempty"` + Subs []string `json:"subscriptions_list,omitempty"` +} + +// DefaultConnListSize is the default size of the connection list. +const DefaultConnListSize = 1024 + +// DefaultSubListSize is the default size of the subscriptions list. +const DefaultSubListSize = 1024 + +const defaultStackBufSize = 10000 + +// Connz returns a Connz struct containing inormation about connections. +func (s *Server) Connz(opts *ConnzOptions) (*Connz, error) { + var ( + sortOpt = ByCid + auth bool + subs bool + offset int + limit = DefaultConnListSize + cid = uint64(0) + state = ConnOpen + ) + + if opts != nil { + // If no sort option given or sort is by uptime, then sort by cid + if opts.Sort == "" { + sortOpt = ByCid + } else { + sortOpt = opts.Sort + if !sortOpt.IsValid() { + return nil, fmt.Errorf("Invalid sorting option: %s", sortOpt) + } + } + auth = opts.Username + subs = opts.Subscriptions + offset = opts.Offset + if offset < 0 { + offset = 0 + } + limit = opts.Limit + if limit <= 0 { + limit = DefaultConnListSize + } + // state + state = opts.State + + // ByStop only makes sense on closed connections + if sortOpt == ByStop && state != ConnClosed { + return nil, fmt.Errorf("Sort by stop only valid on closed connections") + } + // ByReason is the same. + if sortOpt == ByReason && state != ConnClosed { + return nil, fmt.Errorf("Sort by reason only valid on closed connections") + } + + // If searching by CID + if opts.CID > 0 { + cid = opts.CID + limit = 1 + } + } + + c := &Connz{ + Offset: offset, + Limit: limit, + Now: time.Now(), + } + + // Open clients + var openClients []*client + // Hold for closed clients if requested. + var closedClients []*closedClient + + // Walk the open client list with server lock held. + s.mu.Lock() + + // copy the server id for monitoring + c.ID = s.info.ID + + // Number of total clients. The resulting ConnInfo array + // may be smaller if pagination is used. + switch state { + case ConnOpen: + c.Total = len(s.clients) + case ConnClosed: + c.Total = s.closed.len() + closedClients = s.closed.closedClients() + case ConnAll: + c.Total = len(s.clients) + s.closed.len() + closedClients = s.closed.closedClients() + } + + totalClients := c.Total + if cid > 0 { // Meaning we only want 1. + totalClients = 1 + } + if state == ConnOpen || state == ConnAll { + openClients = make([]*client, 0, totalClients) + } + + // Data structures for results. + var conns []ConnInfo // Limits allocs for actual ConnInfos. + var pconns ConnInfos + + switch state { + case ConnOpen: + conns = make([]ConnInfo, totalClients) + pconns = make(ConnInfos, totalClients) + case ConnClosed: + pconns = make(ConnInfos, totalClients) + case ConnAll: + conns = make([]ConnInfo, cap(openClients)) + pconns = make(ConnInfos, totalClients) + } + + // Search by individual CID. + if cid > 0 { + if state == ConnClosed || state == ConnAll { + copyClosed := closedClients + closedClients = nil + for _, cc := range copyClosed { + if cc.Cid == cid { + closedClients = []*closedClient{cc} + break + } + } + } else if state == ConnOpen || state == ConnAll { + client := s.clients[cid] + if client != nil { + openClients = append(openClients, client) + } + } + } else { + // Gather all open clients. + if state == ConnOpen || state == ConnAll { + for _, client := range s.clients { + openClients = append(openClients, client) + } + } + } + s.mu.Unlock() + + // Just return with empty array if nothing here. + if len(openClients) == 0 && len(closedClients) == 0 { + c.Conns = ConnInfos{} + return c, nil + } + + // Now whip through and generate ConnInfo entries + + // Open Clients + i := 0 + for _, client := range openClients { + client.mu.Lock() + ci := &conns[i] + ci.fill(client, client.nc, c.Now) + // Fill in subscription data if requested. + if subs && len(client.subs) > 0 { + ci.Subs = make([]string, 0, len(client.subs)) + for _, sub := range client.subs { + ci.Subs = append(ci.Subs, string(sub.subject)) + } + } + // Fill in user if auth requested. + if auth { + ci.AuthorizedUser = client.opts.Username + } + client.mu.Unlock() + pconns[i] = ci + i++ + } + // Closed Clients + var needCopy bool + if subs || auth { + needCopy = true + } + for _, cc := range closedClients { + // Copy if needed for any changes to the ConnInfo + if needCopy { + cx := *cc + cc = &cx + } + // Fill in subscription data if requested. + if subs && len(cc.subs) > 0 { + cc.Subs = cc.subs + } + // Fill in user if auth requested. + if auth { + cc.AuthorizedUser = cc.user + } + pconns[i] = &cc.ConnInfo + i++ + } + + switch sortOpt { + case ByCid, ByStart: + sort.Sort(byCid{pconns}) + case BySubs: + sort.Sort(sort.Reverse(bySubs{pconns})) + case ByPending: + sort.Sort(sort.Reverse(byPending{pconns})) + case ByOutMsgs: + sort.Sort(sort.Reverse(byOutMsgs{pconns})) + case ByInMsgs: + sort.Sort(sort.Reverse(byInMsgs{pconns})) + case ByOutBytes: + sort.Sort(sort.Reverse(byOutBytes{pconns})) + case ByInBytes: + sort.Sort(sort.Reverse(byInBytes{pconns})) + case ByLast: + sort.Sort(sort.Reverse(byLast{pconns})) + case ByIdle: + sort.Sort(sort.Reverse(byIdle{pconns})) + case ByUptime: + sort.Sort(byUptime{pconns, time.Now()}) + case ByStop: + sort.Sort(sort.Reverse(byStop{pconns})) + case ByReason: + sort.Sort(byReason{pconns}) + } + + minoff := c.Offset + maxoff := c.Offset + c.Limit + + maxIndex := totalClients + + // Make sure these are sane. + if minoff > maxIndex { + minoff = maxIndex + } + if maxoff > maxIndex { + maxoff = maxIndex + } + + // Now pare down to the requested size. + // TODO(dlc) - for very large number of connections we + // could save the whole list in a hash, send hash on first + // request and allow users to use has for subsequent pages. + // Low TTL, say < 1sec. + c.Conns = pconns[minoff:maxoff] + c.NumConns = len(c.Conns) + + return c, nil +} + +// Fills in the ConnInfo from the client. +// client should be locked. +func (ci *ConnInfo) fill(client *client, nc net.Conn, now time.Time) { + ci.Cid = client.cid + ci.Start = client.start + ci.LastActivity = client.last + ci.Uptime = myUptime(now.Sub(client.start)) + ci.Idle = myUptime(now.Sub(client.last)) + ci.RTT = client.getRTT() + ci.OutMsgs = client.outMsgs + ci.OutBytes = client.outBytes + ci.NumSubs = uint32(len(client.subs)) + ci.Pending = int(client.out.pb) + ci.Name = client.opts.Name + ci.Lang = client.opts.Lang + ci.Version = client.opts.Version + // inMsgs and inBytes are updated outside of the client's lock, so + // we need to use atomic here. + ci.InMsgs = atomic.LoadInt64(&client.inMsgs) + ci.InBytes = atomic.LoadInt64(&client.inBytes) + + // If the connection is gone, too bad, we won't set TLSVersion and TLSCipher. + // Exclude clients that are still doing handshake so we don't block in + // ConnectionState(). + if client.flags.isSet(handshakeComplete) && nc != nil { + conn := nc.(*tls.Conn) + cs := conn.ConnectionState() + ci.TLSVersion = tlsVersion(cs.Version) + ci.TLSCipher = tlsCipher(cs.CipherSuite) + } + + switch conn := nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + ci.Port = addr.Port + ci.IP = addr.IP.String() + } +} + +// Assume lock is held +func (c *client) getRTT() string { + if c.rtt == 0 { + // If a real client, go ahead and send ping now to get a value + // for RTT. For tests and telnet, etc skip. + if c.flags.isSet(connectReceived) && c.opts.Lang != "" { + c.sendPing() + } + return "" + } + var rtt time.Duration + if c.rtt > time.Microsecond && c.rtt < time.Millisecond { + rtt = c.rtt.Truncate(time.Microsecond) + } else { + rtt = c.rtt.Truncate(time.Millisecond) + } + return rtt.String() +} + +func decodeBool(w http.ResponseWriter, r *http.Request, param string) (bool, error) { + str := r.URL.Query().Get(param) + if str == "" { + return false, nil + } + val, err := strconv.ParseBool(str) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding boolean for '%s': %v", param, err))) + return false, err + } + return val, nil +} + +func decodeUint64(w http.ResponseWriter, r *http.Request, param string) (uint64, error) { + str := r.URL.Query().Get(param) + if str == "" { + return 0, nil + } + val, err := strconv.ParseUint(str, 10, 64) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding uint64 for '%s': %v", param, err))) + return 0, err + } + return val, nil +} + +func decodeInt(w http.ResponseWriter, r *http.Request, param string) (int, error) { + str := r.URL.Query().Get(param) + if str == "" { + return 0, nil + } + val, err := strconv.Atoi(str) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf("Error decoding int for '%s': %v", param, err))) + return 0, err + } + return val, nil +} + +func decodeState(w http.ResponseWriter, r *http.Request) (ConnState, error) { + str := r.URL.Query().Get("state") + if str == "" { + return ConnOpen, nil + } + switch strings.ToLower(str) { + case "open": + return ConnOpen, nil + case "closed": + return ConnClosed, nil + case "any", "all": + return ConnAll, nil + } + // We do not understand intended state here. + w.WriteHeader(http.StatusBadRequest) + err := fmt.Errorf("Error decoding state for %s", str) + w.Write([]byte(err.Error())) + return 0, err +} + +// HandleConnz process HTTP requests for connection information. +func (s *Server) HandleConnz(w http.ResponseWriter, r *http.Request) { + sortOpt := SortOpt(r.URL.Query().Get("sort")) + auth, err := decodeBool(w, r, "auth") + if err != nil { + return + } + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + offset, err := decodeInt(w, r, "offset") + if err != nil { + return + } + limit, err := decodeInt(w, r, "limit") + if err != nil { + return + } + cid, err := decodeUint64(w, r, "cid") + if err != nil { + return + } + state, err := decodeState(w, r) + if err != nil { + return + } + + connzOpts := &ConnzOptions{ + Sort: sortOpt, + Username: auth, + Subscriptions: subs, + Offset: offset, + Limit: limit, + CID: cid, + State: state, + } + + s.mu.Lock() + s.httpReqStats[ConnzPath]++ + s.mu.Unlock() + + c, err := s.Connz(connzOpts) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + b, err := json.MarshalIndent(c, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /connz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Routez represents detailed information on current client connections. +type Routez struct { + ID string `json:"server_id"` + Now time.Time `json:"now"` + NumRoutes int `json:"num_routes"` + Routes []*RouteInfo `json:"routes"` +} + +// RoutezOptions are options passed to Routez +type RoutezOptions struct { + // Subscriptions indicates that Routez will return a route's subscriptions + Subscriptions bool `json:"subscriptions"` +} + +// RouteInfo has detailed information on a per connection basis. +type RouteInfo struct { + Rid uint64 `json:"rid"` + RemoteID string `json:"remote_id"` + DidSolicit bool `json:"did_solicit"` + IsConfigured bool `json:"is_configured"` + IP string `json:"ip"` + Port int `json:"port"` + Pending int `json:"pending_size"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + NumSubs uint32 `json:"subscriptions"` + Subs []string `json:"subscriptions_list,omitempty"` +} + +// Routez returns a Routez struct containing inormation about routes. +func (s *Server) Routez(routezOpts *RoutezOptions) (*Routez, error) { + rs := &Routez{Routes: []*RouteInfo{}} + rs.Now = time.Now() + + subs := routezOpts != nil && routezOpts.Subscriptions + + // Walk the list + s.mu.Lock() + rs.NumRoutes = len(s.routes) + + // copy the server id for monitoring + rs.ID = s.info.ID + + for _, r := range s.routes { + r.mu.Lock() + ri := &RouteInfo{ + Rid: r.cid, + RemoteID: r.route.remoteID, + DidSolicit: r.route.didSolicit, + IsConfigured: r.route.routeType == Explicit, + InMsgs: atomic.LoadInt64(&r.inMsgs), + OutMsgs: r.outMsgs, + InBytes: atomic.LoadInt64(&r.inBytes), + OutBytes: r.outBytes, + NumSubs: uint32(len(r.subs)), + } + + if subs && len(r.subs) > 0 { + ri.Subs = make([]string, 0, len(r.subs)) + for _, sub := range r.subs { + ri.Subs = append(ri.Subs, string(sub.subject)) + } + } + switch conn := r.nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + ri.Port = addr.Port + ri.IP = addr.IP.String() + } + r.mu.Unlock() + rs.Routes = append(rs.Routes, ri) + } + s.mu.Unlock() + return rs, nil +} + +// HandleRoutez process HTTP requests for route information. +func (s *Server) HandleRoutez(w http.ResponseWriter, r *http.Request) { + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + var opts *RoutezOptions + if subs { + opts = &RoutezOptions{Subscriptions: true} + } + + s.mu.Lock() + s.httpReqStats[RoutezPath]++ + s.mu.Unlock() + + // As of now, no error is ever returned. + rs, _ := s.Routez(opts) + b, err := json.MarshalIndent(rs, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /routez request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Subsz represents detail information on current connections. +type Subsz struct { + *SublistStats + Total int `json:"total"` + Offset int `json:"offset"` + Limit int `json:"limit"` + Subs []SubDetail `json:"subscriptions_list,omitempty"` +} + +// SubszOptions are the options passed to Subsz. +// As of now, there are no options defined. +type SubszOptions struct { + // Offset is used for pagination. Subsz() only returns connections starting at this + // offset from the global results. + Offset int `json:"offset"` + + // Limit is the maximum number of subscriptions that should be returned by Subsz(). + Limit int `json:"limit"` + + // Subscriptions indicates if subscriptions should be included in the results. + Subscriptions bool `json:"subscriptions"` + + // Test the list against this subject. Needs to be literal since it signifies a publish subject. + // We will only return subscriptions that would match if a message was sent to this subject. + Test string `json:"test,omitempty"` +} + +type SubDetail struct { + Subject string `json:"subject"` + Queue string `json:"qgroup,omitempty"` + Sid string `json:"sid"` + Msgs int64 `json:"msgs"` + Max int64 `json:"max,omitempty"` + Cid uint64 `json:"cid"` +} + +// Subsz returns a Subsz struct containing subjects statistics +func (s *Server) Subsz(opts *SubszOptions) (*Subsz, error) { + var ( + subdetail bool + test bool + offset int + limit = DefaultSubListSize + testSub = "" + ) + + if opts != nil { + subdetail = opts.Subscriptions + offset = opts.Offset + if offset < 0 { + offset = 0 + } + limit = opts.Limit + if limit <= 0 { + limit = DefaultSubListSize + } + if opts.Test != "" { + testSub = opts.Test + test = true + if !IsValidLiteralSubject(testSub) { + return nil, fmt.Errorf("Invalid test subject, must be valid publish subject: %s", testSub) + } + } + } + + sz := &Subsz{s.sl.Stats(), 0, offset, limit, nil} + + if subdetail { + // Now add in subscription's details + var raw [4096]*subscription + subs := raw[:0] + + s.sl.localSubs(&subs) + details := make([]SubDetail, len(subs)) + i := 0 + // TODO(dlc) - may be inefficient and could just do normal match when total subs is large and filtering. + for _, sub := range subs { + // Check for filter + if test && !matchLiteral(testSub, string(sub.subject)) { + continue + } + if sub.client == nil { + continue + } + sub.client.mu.Lock() + details[i] = SubDetail{ + Subject: string(sub.subject), + Queue: string(sub.queue), + Sid: string(sub.sid), + Msgs: sub.nm, + Max: sub.max, + Cid: sub.client.cid, + } + sub.client.mu.Unlock() + i++ + } + minoff := sz.Offset + maxoff := sz.Offset + sz.Limit + + maxIndex := i + + // Make sure these are sane. + if minoff > maxIndex { + minoff = maxIndex + } + if maxoff > maxIndex { + maxoff = maxIndex + } + sz.Subs = details[minoff:maxoff] + sz.Total = len(sz.Subs) + } + + return sz, nil +} + +// HandleSubsz processes HTTP requests for subjects stats. +func (s *Server) HandleSubsz(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + s.httpReqStats[SubszPath]++ + s.mu.Unlock() + + subs, err := decodeBool(w, r, "subs") + if err != nil { + return + } + offset, err := decodeInt(w, r, "offset") + if err != nil { + return + } + limit, err := decodeInt(w, r, "limit") + if err != nil { + return + } + testSub := r.URL.Query().Get("test") + + subszOpts := &SubszOptions{ + Subscriptions: subs, + Offset: offset, + Limit: limit, + Test: testSub, + } + + st, err := s.Subsz(subszOpts) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + return + } + + var b []byte + + if len(st.Subs) == 0 { + b, err = json.MarshalIndent(st.SublistStats, "", " ") + } else { + b, err = json.MarshalIndent(st, "", " ") + } + if err != nil { + s.Errorf("Error marshaling response to /subscriptionsz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// HandleStacksz processes HTTP requests for getting stacks +func (s *Server) HandleStacksz(w http.ResponseWriter, r *http.Request) { + // Do not get any lock here that would prevent getting the stacks + // if we were to have a deadlock somewhere. + var defaultBuf [defaultStackBufSize]byte + size := defaultStackBufSize + buf := defaultBuf[:size] + n := 0 + for { + n = runtime.Stack(buf, true) + if n < size { + break + } + size *= 2 + buf = make([]byte, size) + } + // Handle response + ResponseHandler(w, r, buf[:n]) +} + +// Varz will output server information on the monitoring port at /varz. +type Varz struct { + *Info + *Options + Port int `json:"port"` + MaxPayload int `json:"max_payload"` + Start time.Time `json:"start"` + Now time.Time `json:"now"` + Uptime string `json:"uptime"` + Mem int64 `json:"mem"` + Cores int `json:"cores"` + CPU float64 `json:"cpu"` + Connections int `json:"connections"` + TotalConnections uint64 `json:"total_connections"` + Routes int `json:"routes"` + Remotes int `json:"remotes"` + InMsgs int64 `json:"in_msgs"` + OutMsgs int64 `json:"out_msgs"` + InBytes int64 `json:"in_bytes"` + OutBytes int64 `json:"out_bytes"` + SlowConsumers int64 `json:"slow_consumers"` + MaxPending int64 `json:"max_pending"` + WriteDeadline time.Duration `json:"write_deadline"` + Subscriptions uint32 `json:"subscriptions"` + HTTPReqStats map[string]uint64 `json:"http_req_stats"` + ConfigLoadTime time.Time `json:"config_load_time"` +} + +// VarzOptions are the options passed to Varz(). +// Currently, there are no options defined. +type VarzOptions struct{} + +func myUptime(d time.Duration) string { + // Just use total seconds for uptime, and display days / years + tsecs := d / time.Second + tmins := tsecs / 60 + thrs := tmins / 60 + tdays := thrs / 24 + tyrs := tdays / 365 + + if tyrs > 0 { + return fmt.Sprintf("%dy%dd%dh%dm%ds", tyrs, tdays%365, thrs%24, tmins%60, tsecs%60) + } + if tdays > 0 { + return fmt.Sprintf("%dd%dh%dm%ds", tdays, thrs%24, tmins%60, tsecs%60) + } + if thrs > 0 { + return fmt.Sprintf("%dh%dm%ds", thrs, tmins%60, tsecs%60) + } + if tmins > 0 { + return fmt.Sprintf("%dm%ds", tmins, tsecs%60) + } + return fmt.Sprintf("%ds", tsecs) +} + +// HandleRoot will show basic info and links to others handlers. +func (s *Server) HandleRoot(w http.ResponseWriter, r *http.Request) { + // This feels dumb to me, but is required: https://code.google.com/p/go/issues/detail?id=4799 + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + s.mu.Lock() + s.httpReqStats[RootPath]++ + s.mu.Unlock() + fmt.Fprintf(w, ` + + + + + + NATS +
+ varz
+ connz
+ routez
+ subsz
+
+ help + +`) +} + +// Varz returns a Varz struct containing the server information. +func (s *Server) Varz(varzOpts *VarzOptions) (*Varz, error) { + // Snapshot server options. + opts := s.getOpts() + + v := &Varz{Info: &s.info, Options: opts, MaxPayload: opts.MaxPayload, Start: s.start} + v.Now = time.Now() + v.Uptime = myUptime(time.Since(s.start)) + v.Port = v.Info.Port + + updateUsage(v) + + s.mu.Lock() + v.Connections = len(s.clients) + v.TotalConnections = s.totalClients + v.Routes = len(s.routes) + v.Remotes = len(s.remotes) + v.InMsgs = atomic.LoadInt64(&s.inMsgs) + v.InBytes = atomic.LoadInt64(&s.inBytes) + v.OutMsgs = atomic.LoadInt64(&s.outMsgs) + v.OutBytes = atomic.LoadInt64(&s.outBytes) + v.SlowConsumers = atomic.LoadInt64(&s.slowConsumers) + v.MaxPending = opts.MaxPending + v.WriteDeadline = opts.WriteDeadline + v.Subscriptions = s.sl.Count() + v.ConfigLoadTime = s.configTime + // Need a copy here since s.httpReqStats can change while doing + // the marshaling down below. + v.HTTPReqStats = make(map[string]uint64, len(s.httpReqStats)) + for key, val := range s.httpReqStats { + v.HTTPReqStats[key] = val + } + s.mu.Unlock() + + return v, nil +} + +// HandleVarz will process HTTP requests for server information. +func (s *Server) HandleVarz(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + s.httpReqStats[VarzPath]++ + s.mu.Unlock() + + // As of now, no error is ever returned + v, _ := s.Varz(nil) + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + s.Errorf("Error marshaling response to /varz request: %v", err) + } + + // Handle response + ResponseHandler(w, r, b) +} + +// Grab RSS and PCPU +func updateUsage(v *Varz) { + var rss, vss int64 + var pcpu float64 + + pse.ProcUsage(&pcpu, &rss, &vss) + + v.Mem = rss + v.CPU = pcpu + v.Cores = numCores +} + +// ResponseHandler handles responses for monitoring routes +func ResponseHandler(w http.ResponseWriter, r *http.Request, data []byte) { + // Get callback from request + callback := r.URL.Query().Get("callback") + // If callback is not empty then + if callback != "" { + // Response for JSONP + w.Header().Set("Content-Type", "application/javascript") + fmt.Fprintf(w, "%s(%s)", callback, data) + } else { + // Otherwise JSON + w.Header().Set("Content-Type", "application/json") + w.Write(data) + } +} + +func (reason ClosedState) String() string { + switch reason { + case ClientClosed: + return "Client" + case AuthenticationTimeout: + return "Authentication Timeout" + case AuthenticationViolation: + return "Authentication Failure" + case TLSHandshakeError: + return "TLS Handshake Failure" + case SlowConsumerPendingBytes: + return "Slow Consumer (Pending Bytes)" + case SlowConsumerWriteDeadline: + return "Slow Consumer (Write Deadline)" + case WriteError: + return "Write Error" + case ReadError: + return "Read Error" + case ParseError: + return "Parse Error" + case StaleConnection: + return "Stale Connection" + case ProtocolViolation: + return "Protocol Violation" + case BadClientProtocolVersion: + return "Bad Client Protocol Version" + case WrongPort: + return "Incorrect Port" + case MaxConnectionsExceeded: + return "Maximum Connections Exceeded" + case MaxPayloadExceeded: + return "Maximum Message Payload Exceeded" + case MaxControlLineExceeded: + return "Maximum Control Line Exceeded" + case DuplicateRoute: + return "Duplicate Route" + case RouteRemoved: + return "Route Removed" + case ServerShutdown: + return "Server Shutdown" + } + return "Unknown State" +} diff --git a/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go b/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go new file mode 100644 index 00000000..926ceb26 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/monitor_sort_opts.go @@ -0,0 +1,147 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "time" +) + +// Represents a connection info list. We use pointers since it will be sorted. +type ConnInfos []*ConnInfo + +// For sorting +func (cl ConnInfos) Len() int { return len(cl) } +func (cl ConnInfos) Swap(i, j int) { cl[i], cl[j] = cl[j], cl[i] } + +// SortOpt is a helper type to sort clients +type SortOpt string + +// Possible sort options +const ( + ByCid SortOpt = "cid" // By connection ID + ByStart SortOpt = "start" // By connection start time, same as CID + BySubs SortOpt = "subs" // By number of subscriptions + ByPending SortOpt = "pending" // By amount of data in bytes waiting to be sent to client + ByOutMsgs SortOpt = "msgs_to" // By number of messages sent + ByInMsgs SortOpt = "msgs_from" // By number of messages received + ByOutBytes SortOpt = "bytes_to" // By amount of bytes sent + ByInBytes SortOpt = "bytes_from" // By amount of bytes received + ByLast SortOpt = "last" // By the last activity + ByIdle SortOpt = "idle" // By the amount of inactivity + ByUptime SortOpt = "uptime" // By the amount of time connections exist + ByStop SortOpt = "stop" // By the stop time for a closed connection + ByReason SortOpt = "reason" // By the reason for a closed connection + +) + +// Individual sort options provide the Less for sort.Interface. Len and Swap are on cList. +// CID +type byCid struct{ ConnInfos } + +func (l byCid) Less(i, j int) bool { return l.ConnInfos[i].Cid < l.ConnInfos[j].Cid } + +// Number of Subscriptions +type bySubs struct{ ConnInfos } + +func (l bySubs) Less(i, j int) bool { return l.ConnInfos[i].NumSubs < l.ConnInfos[j].NumSubs } + +// Pending Bytes +type byPending struct{ ConnInfos } + +func (l byPending) Less(i, j int) bool { return l.ConnInfos[i].Pending < l.ConnInfos[j].Pending } + +// Outbound Msgs +type byOutMsgs struct{ ConnInfos } + +func (l byOutMsgs) Less(i, j int) bool { return l.ConnInfos[i].OutMsgs < l.ConnInfos[j].OutMsgs } + +// Inbound Msgs +type byInMsgs struct{ ConnInfos } + +func (l byInMsgs) Less(i, j int) bool { return l.ConnInfos[i].InMsgs < l.ConnInfos[j].InMsgs } + +// Outbound Bytes +type byOutBytes struct{ ConnInfos } + +func (l byOutBytes) Less(i, j int) bool { return l.ConnInfos[i].OutBytes < l.ConnInfos[j].OutBytes } + +// Inbound Bytes +type byInBytes struct{ ConnInfos } + +func (l byInBytes) Less(i, j int) bool { return l.ConnInfos[i].InBytes < l.ConnInfos[j].InBytes } + +// Last Activity +type byLast struct{ ConnInfos } + +func (l byLast) Less(i, j int) bool { + return l.ConnInfos[i].LastActivity.UnixNano() < l.ConnInfos[j].LastActivity.UnixNano() +} + +// Idle time +type byIdle struct{ ConnInfos } + +func (l byIdle) Less(i, j int) bool { + ii := l.ConnInfos[i].LastActivity.Sub(l.ConnInfos[i].Start) + ij := l.ConnInfos[j].LastActivity.Sub(l.ConnInfos[j].Start) + return ii < ij +} + +// Uptime +type byUptime struct { + ConnInfos + now time.Time +} + +func (l byUptime) Less(i, j int) bool { + ci := l.ConnInfos[i] + cj := l.ConnInfos[j] + var upi, upj time.Duration + if ci.Stop == nil || ci.Stop.IsZero() { + upi = l.now.Sub(ci.Start) + } else { + upi = ci.Stop.Sub(ci.Start) + } + if cj.Stop == nil || cj.Stop.IsZero() { + upj = l.now.Sub(cj.Start) + } else { + upj = cj.Stop.Sub(cj.Start) + } + return upi < upj +} + +// Stop +type byStop struct{ ConnInfos } + +func (l byStop) Less(i, j int) bool { + ciStop := l.ConnInfos[i].Stop + cjStop := l.ConnInfos[j].Stop + return ciStop.Before(*cjStop) +} + +// Reason +type byReason struct{ ConnInfos } + +func (l byReason) Less(i, j int) bool { + return l.ConnInfos[i].Reason < l.ConnInfos[j].Reason +} + +// IsValid determines if a sort option is valid +func (s SortOpt) IsValid() bool { + switch s { + case "", ByCid, ByStart, BySubs, ByPending, ByOutMsgs, ByInMsgs, ByOutBytes, ByInBytes, ByLast, ByIdle, ByUptime, ByStop, ByReason: + return true + default: + return false + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/monitor_test.go b/vendor/github.com/nats-io/gnatsd/server/monitor_test.go new file mode 100644 index 00000000..f23bff8c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/monitor_test.go @@ -0,0 +1,1933 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "runtime" + "sort" + "strings" + "sync" + "testing" + "time" + "unicode" + + "net" + + "github.com/nats-io/go-nats" +) + +const CLIENT_PORT = -1 +const MONITOR_PORT = -1 +const CLUSTER_PORT = -1 + +func DefaultMonitorOptions() *Options { + return &Options{ + Host: "127.0.0.1", + Port: CLIENT_PORT, + HTTPHost: "127.0.0.1", + HTTPPort: MONITOR_PORT, + NoLog: true, + NoSigs: true, + } +} + +func runMonitorServer() *Server { + resetPreviousHTTPConnections() + opts := DefaultMonitorOptions() + return RunServer(opts) +} + +func runMonitorServerNoHTTPPort() *Server { + resetPreviousHTTPConnections() + opts := DefaultMonitorOptions() + opts.HTTPPort = 0 + return RunServer(opts) +} + +func resetPreviousHTTPConnections() { + http.DefaultTransport = &http.Transport{} +} + +func TestMyUptime(t *testing.T) { + // Make sure we print this stuff right. + var d time.Duration + var s string + + d = 22 * time.Second + s = myUptime(d) + if s != "22s" { + t.Fatalf("Expected `22s`, go ``%s`", s) + } + d = 4*time.Minute + d + s = myUptime(d) + if s != "4m22s" { + t.Fatalf("Expected `4m22s`, go ``%s`", s) + } + d = 4*time.Hour + d + s = myUptime(d) + if s != "4h4m22s" { + t.Fatalf("Expected `4h4m22s`, go ``%s`", s) + } + d = 32*24*time.Hour + d + s = myUptime(d) + if s != "32d4h4m22s" { + t.Fatalf("Expected `32d4h4m22s`, go ``%s`", s) + } + d = 22*365*24*time.Hour + d + s = myUptime(d) + if s != "22y32d4h4m22s" { + t.Fatalf("Expected `22y32d4h4m22s`, go ``%s`", s) + } +} + +// Make sure that we do not run the http server for monitoring unless asked. +func TestNoMonitorPort(t *testing.T) { + s := runMonitorServerNoHTTPPort() + defer s.Shutdown() + + // this test might be meaningless now that we're testing with random ports? + url := fmt.Sprintf("http://127.0.0.1:%d/", 11245) + if resp, err := http.Get(url + "varz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } + if resp, err := http.Get(url + "healthz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } + if resp, err := http.Get(url + "connz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } +} + +var ( + appJSONContent = "application/json" + appJSContent = "application/javascript" + textPlain = "text/plain; charset=utf-8" +) + +func readBodyEx(t *testing.T, url string, status int, content string) []byte { + resp, err := http.Get(url) + if err != nil { + stackFatalf(t, "Expected no error: Got %v\n", err) + } + defer resp.Body.Close() + if resp.StatusCode != status { + stackFatalf(t, "Expected a %d response, got %d\n", status, resp.StatusCode) + } + ct := resp.Header.Get("Content-Type") + if ct != content { + stackFatalf(t, "Expected %s content-type, got %s\n", content, ct) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + stackFatalf(t, "Got an error reading the body: %v\n", err) + } + return body +} + +func readBody(t *testing.T, url string) []byte { + return readBodyEx(t, url, http.StatusOK, appJSONContent) +} + +func pollVarz(t *testing.T, s *Server, mode int, url string, opts *VarzOptions) *Varz { + if mode == 0 { + v := &Varz{} + body := readBody(t, url) + if err := json.Unmarshal(body, v); err != nil { + stackFatalf(t, "Got an error unmarshalling the body: %v\n", err) + } + return v + } + v, _ := s.Varz(opts) + return v +} + +func TestHandleVarz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + v := pollVarz(t, s, mode, url+"varz", nil) + + // Do some sanity checks on values + if time.Since(v.Start) > 10*time.Second { + t.Fatal("Expected start time to be within 10 seconds.") + } + } + + time.Sleep(100 * time.Millisecond) + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + for mode := 0; mode < 2; mode++ { + v := pollVarz(t, s, mode, url+"varz", nil) + + if v.Connections != 1 { + t.Fatalf("Expected Connections of 1, got %v\n", v.Connections) + } + if v.TotalConnections < 1 { + t.Fatalf("Expected Total Connections of at least 1, got %v\n", v.TotalConnections) + } + if v.InMsgs != 1 { + t.Fatalf("Expected InMsgs of 1, got %v\n", v.InMsgs) + } + if v.OutMsgs != 1 { + t.Fatalf("Expected OutMsgs of 1, got %v\n", v.OutMsgs) + } + if v.InBytes != 5 { + t.Fatalf("Expected InBytes of 5, got %v\n", v.InBytes) + } + if v.OutBytes != 5 { + t.Fatalf("Expected OutBytes of 5, got %v\n", v.OutBytes) + } + if v.Subscriptions != 0 { + t.Fatalf("Expected Subscriptions of 0, got %v\n", v.Subscriptions) + } + } + + // Test JSONP + readBodyEx(t, url+"varz?callback=callback", http.StatusOK, appJSContent) +} + +func pollConz(t *testing.T, s *Server, mode int, url string, opts *ConnzOptions) *Connz { + if mode == 0 { + body := readBody(t, url) + c := &Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + return c + } + c, err := s.Connz(opts) + if err != nil { + stackFatalf(t, "Error on Connz(): %v", err) + } + return c +} + +func TestConnz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + testConnz := func(mode int) { + c := pollConz(t, s, mode, url+"connz", nil) + + // Test contents.. + if c.NumConns != 0 { + t.Fatalf("Expected 0 connections, got %d\n", c.NumConns) + } + if c.Total != 0 { + t.Fatalf("Expected 0 live connections, got %d\n", c.Total) + } + if c.Conns == nil || len(c.Conns) != 0 { + t.Fatalf("Expected 0 connections in array, got %p\n", c.Conns) + } + + // Test with connections. + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + time.Sleep(50 * time.Millisecond) + + c = pollConz(t, s, mode, url+"connz", nil) + + if c.NumConns != 1 { + t.Fatalf("Expected 1 connection, got %d\n", c.NumConns) + } + if c.Total != 1 { + t.Fatalf("Expected 1 live connection, got %d\n", c.Total) + } + if c.Conns == nil || len(c.Conns) != 1 { + t.Fatalf("Expected 1 connection in array, got %d\n", len(c.Conns)) + } + + if c.Limit != DefaultConnListSize { + t.Fatalf("Expected limit of %d, got %v\n", DefaultConnListSize, c.Limit) + } + + if c.Offset != 0 { + t.Fatalf("Expected offset of 0, got %v\n", c.Offset) + } + + // Test inside details of each connection + ci := c.Conns[0] + + if ci.Cid == 0 { + t.Fatalf("Expected non-zero cid, got %v\n", ci.Cid) + } + if ci.IP != "127.0.0.1" { + t.Fatalf("Expected \"127.0.0.1\" for IP, got %v\n", ci.IP) + } + if ci.Port == 0 { + t.Fatalf("Expected non-zero port, got %v\n", ci.Port) + } + if ci.NumSubs != 0 { + t.Fatalf("Expected num_subs of 0, got %v\n", ci.NumSubs) + } + if len(ci.Subs) != 0 { + t.Fatalf("Expected subs of 0, got %v\n", ci.Subs) + } + if ci.InMsgs != 1 { + t.Fatalf("Expected InMsgs of 1, got %v\n", ci.InMsgs) + } + if ci.OutMsgs != 1 { + t.Fatalf("Expected OutMsgs of 1, got %v\n", ci.OutMsgs) + } + if ci.InBytes != 5 { + t.Fatalf("Expected InBytes of 1, got %v\n", ci.InBytes) + } + if ci.OutBytes != 5 { + t.Fatalf("Expected OutBytes of 1, got %v\n", ci.OutBytes) + } + if ci.Start.IsZero() { + t.Fatal("Expected Start to be valid\n") + } + if ci.Uptime == "" { + t.Fatal("Expected Uptime to be valid\n") + } + if ci.LastActivity.IsZero() { + t.Fatal("Expected LastActivity to be valid\n") + } + if ci.LastActivity.UnixNano() < ci.Start.UnixNano() { + t.Fatalf("Expected LastActivity [%v] to be > Start [%v]\n", ci.LastActivity, ci.Start) + } + if ci.Idle == "" { + t.Fatal("Expected Idle to be valid\n") + } + if ci.RTT != "" { + t.Fatal("Expected RTT to NOT be set for new connection\n") + } + } + + for mode := 0; mode < 2; mode++ { + testConnz(mode) + checkClientsCount(t, s, 0) + } + + // Test JSONP + readBodyEx(t, url+"connz?callback=callback", http.StatusOK, appJSContent) +} + +func TestConnzBadParams(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/connz?", s.MonitorAddr().Port) + readBodyEx(t, url+"auth=xxx", http.StatusBadRequest, textPlain) + readBodyEx(t, url+"subs=xxx", http.StatusBadRequest, textPlain) + readBodyEx(t, url+"offset=xxx", http.StatusBadRequest, textPlain) + readBodyEx(t, url+"limit=xxx", http.StatusBadRequest, textPlain) + readBodyEx(t, url+"state=xxx", http.StatusBadRequest, textPlain) +} + +func TestConnzWithSubs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + nc.Subscribe("hello.foo", func(m *nats.Msg) {}) + ensureServerActivityRecorded(t, nc) + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?subs=1", &ConnzOptions{Subscriptions: true}) + // Test inside details of each connection + ci := c.Conns[0] + if len(ci.Subs) != 1 || ci.Subs[0] != "hello.foo" { + t.Fatalf("Expected subs of 1, got %v\n", ci.Subs) + } + } +} + +func TestConnzWithCID(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + // The one we will request + cid := 5 + total := 10 + + // Create 10 + for i := 1; i <= total; i++ { + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + if i == cid { + nc.Subscribe("hello.foo", func(m *nats.Msg) {}) + nc.Subscribe("hello.bar", func(m *nats.Msg) {}) + ensureServerActivityRecorded(t, nc) + } + } + + url := fmt.Sprintf("http://127.0.0.1:%d/connz?cid=%d", s.MonitorAddr().Port, cid) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url, &ConnzOptions{CID: uint64(cid)}) + // Test inside details of each connection + if len(c.Conns) != 1 { + t.Fatalf("Expected only one connection, but got %d\n", len(c.Conns)) + } + if c.NumConns != 1 { + t.Fatalf("Expected NumConns to be 1, but got %d\n", c.NumConns) + } + ci := c.Conns[0] + if ci.Cid != uint64(cid) { + t.Fatalf("Expected to receive connection %v, but received %v\n", cid, ci.Cid) + } + if ci.NumSubs != 2 { + t.Fatalf("Expected to receive connection with %d subs, but received %d\n", 2, ci.NumSubs) + } + // Now test a miss + badUrl := fmt.Sprintf("http://127.0.0.1:%d/connz?cid=%d", s.MonitorAddr().Port, 100) + c = pollConz(t, s, mode, badUrl, &ConnzOptions{CID: uint64(100)}) + if len(c.Conns) != 0 { + t.Fatalf("Expected no connections, got %d\n", len(c.Conns)) + } + if c.NumConns != 0 { + t.Fatalf("Expected NumConns of 0, got %d\n", c.NumConns) + } + } +} + +// Helper to map to connection name +func createConnMap(t *testing.T, cz *Connz) map[string]*ConnInfo { + cm := make(map[string]*ConnInfo) + for _, c := range cz.Conns { + cm[c.Name] = c + } + return cm +} + +func getFooAndBar(t *testing.T, cm map[string]*ConnInfo) (*ConnInfo, *ConnInfo) { + return cm["foo"], cm["bar"] +} + +func ensureServerActivityRecorded(t *testing.T, nc *nats.Conn) { + nc.Flush() + err := nc.Flush() + if err != nil { + t.Fatalf("Error flushing: %v\n", err) + } +} + +func TestConnzRTT(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + testRTT := func(mode int) { + // Test with connections. + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + c := pollConz(t, s, mode, url+"connz", nil) + + if c.NumConns != 1 { + t.Fatalf("Expected 1 connection, got %d\n", c.NumConns) + } + + // Send a server side PING to record RTT + s.mu.Lock() + ci := c.Conns[0] + sc := s.clients[ci.Cid] + if sc == nil { + t.Fatalf("Error looking up client %v\n", ci.Cid) + } + s.mu.Unlock() + sc.mu.Lock() + sc.sendPing() + sc.mu.Unlock() + + // Wait for client to respond with PONG + time.Sleep(20 * time.Millisecond) + + // Repoll for updated information. + c = pollConz(t, s, mode, url+"connz", nil) + ci = c.Conns[0] + + rtt, err := time.ParseDuration(ci.RTT) + if err != nil { + t.Fatalf("Could not parse RTT properly, %v (ci.RTT=%v)", err, ci.RTT) + } + if rtt <= 0 { + t.Fatal("Expected RTT to be valid and non-zero\n") + } + if rtt > 20*time.Millisecond || rtt < 100*time.Nanosecond { + t.Fatalf("Invalid RTT of %s\n", ci.RTT) + } + } + + for mode := 0; mode < 2; mode++ { + testRTT(mode) + checkClientsCount(t, s, 0) + } +} + +func TestConnzLastActivity(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + url += "connz?subs=1" + opts := &ConnzOptions{Subscriptions: true} + + testActivity := func(mode int) { + ncFoo := createClientConnWithName(t, "foo", s) + defer ncFoo.Close() + + ncBar := createClientConnWithName(t, "bar", s) + defer ncBar.Close() + + // Test inside details of each connection + ciFoo, ciBar := getFooAndBar(t, createConnMap(t, pollConz(t, s, mode, url, opts))) + + // Test that LastActivity is non-zero + if ciFoo.LastActivity.IsZero() { + t.Fatalf("Expected LastActivity for connection '%s'to be valid\n", ciFoo.Name) + } + if ciBar.LastActivity.IsZero() { + t.Fatalf("Expected LastActivity for connection '%s'to be valid\n", ciBar.Name) + } + // Foo should be older than Bar + if ciFoo.LastActivity.After(ciBar.LastActivity) { + t.Fatal("Expected connection 'foo' to be older than 'bar'\n") + } + + fooLA := ciFoo.LastActivity + barLA := ciBar.LastActivity + + ensureServerActivityRecorded(t, ncFoo) + ensureServerActivityRecorded(t, ncBar) + + // Sub should trigger update. + sub, _ := ncFoo.Subscribe("hello.world", func(m *nats.Msg) {}) + ensureServerActivityRecorded(t, ncFoo) + + ciFoo, _ = getFooAndBar(t, createConnMap(t, pollConz(t, s, mode, url, opts))) + nextLA := ciFoo.LastActivity + if fooLA.Equal(nextLA) { + t.Fatalf("Subscribe should have triggered update to LastActivity %+v\n", ciFoo) + } + fooLA = nextLA + + // Publish and Message Delivery should trigger as well. So both connections + // should have updates. + ncBar.Publish("hello.world", []byte("Hello")) + + ensureServerActivityRecorded(t, ncFoo) + ensureServerActivityRecorded(t, ncBar) + + ciFoo, ciBar = getFooAndBar(t, createConnMap(t, pollConz(t, s, mode, url, opts))) + nextLA = ciBar.LastActivity + if barLA.Equal(nextLA) { + t.Fatalf("Publish should have triggered update to LastActivity\n") + } + barLA = nextLA + + // Message delivery on ncFoo should have triggered as well. + nextLA = ciFoo.LastActivity + if fooLA.Equal(nextLA) { + t.Fatalf("Message delivery should have triggered update to LastActivity\n") + } + fooLA = nextLA + + // Unsub should trigger as well + sub.Unsubscribe() + ensureServerActivityRecorded(t, ncFoo) + + ciFoo, _ = getFooAndBar(t, createConnMap(t, pollConz(t, s, mode, url, opts))) + nextLA = ciFoo.LastActivity + if fooLA.Equal(nextLA) { + t.Fatalf("Message delivery should have triggered update to LastActivity\n") + } + } + + for mode := 0; mode < 2; mode++ { + testActivity(mode) + } +} + +func TestConnzWithOffsetAndLimit(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?offset=1&limit=1", &ConnzOptions{Offset: 1, Limit: 1}) + if c.Conns == nil || len(c.Conns) != 0 { + t.Fatalf("Expected 0 connections in array, got %p\n", c.Conns) + } + + // Test that when given negative values, 0 or default is used + c = pollConz(t, s, mode, url+"connz?offset=-1&limit=-1", &ConnzOptions{Offset: -11, Limit: -11}) + if c.Conns == nil || len(c.Conns) != 0 { + t.Fatalf("Expected 0 connections in array, got %p\n", c.Conns) + } + if c.Offset != 0 { + t.Fatalf("Expected offset to be 0, and limit to be %v, got %v and %v", + DefaultConnListSize, c.Offset, c.Limit) + } + } + + cl1 := createClientConnSubscribeAndPublish(t, s) + defer cl1.Close() + + cl2 := createClientConnSubscribeAndPublish(t, s) + defer cl2.Close() + + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?offset=1&limit=1", &ConnzOptions{Offset: 1, Limit: 1}) + if c.Limit != 1 { + t.Fatalf("Expected limit of 1, got %v\n", c.Limit) + } + + if c.Offset != 1 { + t.Fatalf("Expected offset of 1, got %v\n", c.Offset) + } + + if len(c.Conns) != 1 { + t.Fatalf("Expected conns of 1, got %v\n", len(c.Conns)) + } + + if c.NumConns != 1 { + t.Fatalf("Expected NumConns to be 1, got %v\n", c.NumConns) + } + + if c.Total != 2 { + t.Fatalf("Expected Total to be at least 2, got %v", c.Total) + } + + c = pollConz(t, s, mode, url+"connz?offset=2&limit=1", &ConnzOptions{Offset: 2, Limit: 1}) + if c.Limit != 1 { + t.Fatalf("Expected limit of 1, got %v\n", c.Limit) + } + + if c.Offset != 2 { + t.Fatalf("Expected offset of 2, got %v\n", c.Offset) + } + + if len(c.Conns) != 0 { + t.Fatalf("Expected conns of 0, got %v\n", len(c.Conns)) + } + + if c.NumConns != 0 { + t.Fatalf("Expected NumConns to be 0, got %v\n", c.NumConns) + } + + if c.Total != 2 { + t.Fatalf("Expected Total to be 2, got %v", c.Total) + } + } +} + +func TestConnzDefaultSorted(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + clients := make([]*nats.Conn, 4) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz", nil) + if c.Conns[0].Cid > c.Conns[1].Cid || + c.Conns[1].Cid > c.Conns[2].Cid || + c.Conns[2].Cid > c.Conns[3].Cid { + t.Fatalf("Expected conns sorted in ascending order by cid, got %v < %v\n", c.Conns[0].Cid, c.Conns[3].Cid) + } + } +} + +func TestConnzSortedByCid(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + clients := make([]*nats.Conn, 4) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=cid", &ConnzOptions{Sort: ByCid}) + if c.Conns[0].Cid > c.Conns[1].Cid || + c.Conns[1].Cid > c.Conns[2].Cid || + c.Conns[2].Cid > c.Conns[3].Cid { + t.Fatalf("Expected conns sorted in ascending order by cid, got [%v, %v, %v, %v]\n", + c.Conns[0].Cid, c.Conns[1].Cid, c.Conns[2].Cid, c.Conns[3].Cid) + } + } +} + +func TestConnzSortedByStart(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + clients := make([]*nats.Conn, 4) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=start", &ConnzOptions{Sort: ByStart}) + if c.Conns[0].Start.After(c.Conns[1].Start) || + c.Conns[1].Start.After(c.Conns[2].Start) || + c.Conns[2].Start.After(c.Conns[3].Start) { + t.Fatalf("Expected conns sorted in ascending order by startime, got [%v, %v, %v, %v]\n", + c.Conns[0].Start, c.Conns[1].Start, c.Conns[2].Start, c.Conns[3].Start) + } + } +} + +func TestConnzSortedByBytesAndMsgs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + // Create a connection and make it send more messages than others + firstClient := createClientConnSubscribeAndPublish(t, s) + for i := 0; i < 100; i++ { + firstClient.Publish("foo", []byte("Hello World")) + } + defer firstClient.Close() + firstClient.Flush() + + clients := make([]*nats.Conn, 3) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=bytes_to", &ConnzOptions{Sort: ByOutBytes}) + if c.Conns[0].OutBytes < c.Conns[1].OutBytes || + c.Conns[0].OutBytes < c.Conns[2].OutBytes || + c.Conns[0].OutBytes < c.Conns[3].OutBytes { + t.Fatalf("Expected conns sorted in descending order by bytes to, got %v < one of [%v, %v, %v]\n", + c.Conns[0].OutBytes, c.Conns[1].OutBytes, c.Conns[2].OutBytes, c.Conns[3].OutBytes) + } + + c = pollConz(t, s, mode, url+"connz?sort=msgs_to", &ConnzOptions{Sort: ByOutMsgs}) + if c.Conns[0].OutMsgs < c.Conns[1].OutMsgs || + c.Conns[0].OutMsgs < c.Conns[2].OutMsgs || + c.Conns[0].OutMsgs < c.Conns[3].OutMsgs { + t.Fatalf("Expected conns sorted in descending order by msgs from, got %v < one of [%v, %v, %v]\n", + c.Conns[0].OutMsgs, c.Conns[1].OutMsgs, c.Conns[2].OutMsgs, c.Conns[3].OutMsgs) + } + + c = pollConz(t, s, mode, url+"connz?sort=bytes_from", &ConnzOptions{Sort: ByInBytes}) + if c.Conns[0].InBytes < c.Conns[1].InBytes || + c.Conns[0].InBytes < c.Conns[2].InBytes || + c.Conns[0].InBytes < c.Conns[3].InBytes { + t.Fatalf("Expected conns sorted in descending order by bytes from, got %v < one of [%v, %v, %v]\n", + c.Conns[0].InBytes, c.Conns[1].InBytes, c.Conns[2].InBytes, c.Conns[3].InBytes) + } + + c = pollConz(t, s, mode, url+"connz?sort=msgs_from", &ConnzOptions{Sort: ByInMsgs}) + if c.Conns[0].InMsgs < c.Conns[1].InMsgs || + c.Conns[0].InMsgs < c.Conns[2].InMsgs || + c.Conns[0].InMsgs < c.Conns[3].InMsgs { + t.Fatalf("Expected conns sorted in descending order by msgs from, got %v < one of [%v, %v, %v]\n", + c.Conns[0].InMsgs, c.Conns[1].InMsgs, c.Conns[2].InMsgs, c.Conns[3].InMsgs) + } + } +} + +func TestConnzSortedByPending(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + firstClient := createClientConnSubscribeAndPublish(t, s) + firstClient.Subscribe("hello.world", func(m *nats.Msg) {}) + clients := make([]*nats.Conn, 3) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + defer firstClient.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=pending", &ConnzOptions{Sort: ByPending}) + if c.Conns[0].Pending < c.Conns[1].Pending || + c.Conns[0].Pending < c.Conns[2].Pending || + c.Conns[0].Pending < c.Conns[3].Pending { + t.Fatalf("Expected conns sorted in descending order by number of pending, got %v < one of [%v, %v, %v]\n", + c.Conns[0].Pending, c.Conns[1].Pending, c.Conns[2].Pending, c.Conns[3].Pending) + } + } +} + +func TestConnzSortedBySubs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + firstClient := createClientConnSubscribeAndPublish(t, s) + firstClient.Subscribe("hello.world", func(m *nats.Msg) {}) + defer firstClient.Close() + + clients := make([]*nats.Conn, 3) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=subs", &ConnzOptions{Sort: BySubs}) + if c.Conns[0].NumSubs < c.Conns[1].NumSubs || + c.Conns[0].NumSubs < c.Conns[2].NumSubs || + c.Conns[0].NumSubs < c.Conns[3].NumSubs { + t.Fatalf("Expected conns sorted in descending order by number of subs, got %v < one of [%v, %v, %v]\n", + c.Conns[0].NumSubs, c.Conns[1].NumSubs, c.Conns[2].NumSubs, c.Conns[3].NumSubs) + } + } +} + +func TestConnzSortedByLast(t *testing.T) { + opts := DefaultMonitorOptions() + s := RunServer(opts) + defer s.Shutdown() + + firstClient := createClientConnSubscribeAndPublish(t, s) + defer firstClient.Close() + firstClient.Subscribe("hello.world", func(m *nats.Msg) {}) + firstClient.Flush() + + clients := make([]*nats.Conn, 3) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + clients[i].Flush() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=last", &ConnzOptions{Sort: ByLast}) + if c.Conns[0].LastActivity.UnixNano() < c.Conns[1].LastActivity.UnixNano() || + c.Conns[1].LastActivity.UnixNano() < c.Conns[2].LastActivity.UnixNano() || + c.Conns[2].LastActivity.UnixNano() < c.Conns[3].LastActivity.UnixNano() { + t.Fatalf("Expected conns sorted in descending order by lastActivity, got %v < one of [%v, %v, %v]\n", + c.Conns[0].LastActivity, c.Conns[1].LastActivity, c.Conns[2].LastActivity, c.Conns[3].LastActivity) + } + } +} + +func TestConnzSortedByUptime(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + for i := 0; i < 4; i++ { + client := createClientConnSubscribeAndPublish(t, s) + defer client.Close() + // Since we check times (now-start) does not have to be big. + time.Sleep(50 * time.Millisecond) + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?sort=uptime", &ConnzOptions{Sort: ByUptime}) + now := time.Now() + ups := make([]int, 4) + for i := 0; i < 4; i++ { + ups[i] = int(now.Sub(c.Conns[i].Start)) + } + if !sort.IntsAreSorted(ups) { + d := make([]time.Duration, 4) + for i := 0; i < 4; i++ { + d[i] = time.Duration(ups[i]) + } + t.Fatalf("Expected conns sorted in ascending order by uptime (now-Start), got %+v\n", d) + } + } +} + +func TestConnzSortedByUptimeClosedConn(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + for i := time.Duration(1); i <= 4; i++ { + c := createClientConnSubscribeAndPublish(t, s) + + // Grab client and asjust start time such that + client := s.getClient(uint64(i)) + if client == nil { + t.Fatalf("Could nopt retrieve client for %d\n", i) + } + client.mu.Lock() + client.start = client.start.Add(-10 * (4 - i) * time.Second) + client.mu.Unlock() + + c.Close() + } + + checkClosedConns(t, s, 4, time.Second) + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?state=closed&sort=uptime", &ConnzOptions{State: ConnClosed, Sort: ByUptime}) + ups := make([]int, 4) + for i := 0; i < 4; i++ { + ups[i] = int(c.Conns[i].Stop.Sub(c.Conns[i].Start)) + } + if !sort.IntsAreSorted(ups) { + d := make([]time.Duration, 4) + for i := 0; i < 4; i++ { + d[i] = time.Duration(ups[i]) + } + t.Fatalf("Expected conns sorted in ascending order by uptime, got %+v\n", d) + } + } +} + +func TestConnzSortedByStopOnOpen(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + // 4 clients + for i := 0; i < 4; i++ { + c, err := nats.Connect(url) + if err != nil { + t.Fatalf("Could not create client: %v\n", err) + } + defer c.Close() + } + + c, err := s.Connz(&ConnzOptions{Sort: ByStop}) + if err == nil { + t.Fatalf("Expected err to be non-nil, got %+v\n", c) + } +} + +func TestConnzSortedByStopTimeClosedConn(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + // 4 clients + for i := 0; i < 4; i++ { + c, err := nats.Connect(url) + if err != nil { + t.Fatalf("Could not create client: %v\n", err) + } + c.Close() + } + checkClosedConns(t, s, 4, time.Second) + + //Now adjust the Stop times for these with some random values. + s.mu.Lock() + now := time.Now() + ccs := s.closed.closedClients() + for _, cc := range ccs { + newStop := now.Add(time.Duration(rand.Int()%120) * -time.Minute) + cc.Stop = &newStop + } + s.mu.Unlock() + + url = fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?state=closed&sort=stop", &ConnzOptions{State: ConnClosed, Sort: ByStop}) + ups := make([]int, 4) + nowU := time.Now().UnixNano() + for i := 0; i < 4; i++ { + ups[i] = int(nowU - c.Conns[i].Stop.UnixNano()) + } + if !sort.IntsAreSorted(ups) { + d := make([]time.Duration, 4) + for i := 0; i < 4; i++ { + d[i] = time.Duration(ups[i]) + } + t.Fatalf("Expected conns sorted in ascending order by stop time, got %+v\n", d) + } + } +} + +func TestConnzSortedByReason(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + // 20 clients + for i := 0; i < 20; i++ { + c, err := nats.Connect(url) + if err != nil { + t.Fatalf("Could not create client: %v\n", err) + } + c.Close() + } + checkClosedConns(t, s, 20, time.Second) + + //Now adjust the Reasons for these with some random values. + s.mu.Lock() + ccs := s.closed.closedClients() + max := int(ServerShutdown) + for _, cc := range ccs { + cc.Reason = ClosedState(rand.Int() % max).String() + } + s.mu.Unlock() + + url = fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz?state=closed&sort=reason", &ConnzOptions{State: ConnClosed, Sort: ByReason}) + rs := make([]string, 20) + for i := 0; i < 20; i++ { + rs[i] = c.Conns[i].Reason + } + if !sort.StringsAreSorted(rs) { + t.Fatalf("Expected conns sorted in order by stop reason, got %#v\n", rs) + } + } +} + +func TestConnzSortedByReasonOnOpen(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + // 4 clients + for i := 0; i < 4; i++ { + c, err := nats.Connect(url) + if err != nil { + t.Fatalf("Could not create client: %v\n", err) + } + defer c.Close() + } + + c, err := s.Connz(&ConnzOptions{Sort: ByReason}) + if err == nil { + t.Fatalf("Expected err to be non-nil, got %+v\n", c) + } +} + +func TestConnzSortedByIdle(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + testIdle := func(mode int) { + firstClient := createClientConnSubscribeAndPublish(t, s) + defer firstClient.Close() + firstClient.Subscribe("client.1", func(m *nats.Msg) {}) + firstClient.Flush() + + secondClient := createClientConnSubscribeAndPublish(t, s) + defer secondClient.Close() + + // Make it such that the second client started 10 secs ago. 10 is important since bug + // was strcmp, e.g. 1s vs 11s + var cid uint64 + switch mode { + case 0: + cid = uint64(2) + case 1: + cid = uint64(4) + } + client := s.getClient(cid) + if client == nil { + t.Fatalf("Error looking up client %v\n", 2) + } + + client.mu.Lock() + client.start = client.start.Add(-10 * time.Second) + client.last = client.start + client.mu.Unlock() + + // The Idle granularity is a whole second + time.Sleep(time.Second) + firstClient.Publish("client.1", []byte("new message")) + + c := pollConz(t, s, mode, url+"connz?sort=idle", &ConnzOptions{Sort: ByIdle}) + // Make sure we are returned 2 connections... + if len(c.Conns) != 2 { + t.Fatalf("Expected to get two connections, got %v", len(c.Conns)) + } + + // And that the Idle time is valid (even if equal to "0s") + if c.Conns[0].Idle == "" || c.Conns[1].Idle == "" { + t.Fatal("Expected Idle value to be valid") + } + + idle1, err := time.ParseDuration(c.Conns[0].Idle) + if err != nil { + t.Fatalf("Unable to parse duration %v, err=%v", c.Conns[0].Idle, err) + } + idle2, err := time.ParseDuration(c.Conns[1].Idle) + if err != nil { + t.Fatalf("Unable to parse duration %v, err=%v", c.Conns[0].Idle, err) + } + + if idle2 < idle1 { + t.Fatalf("Expected conns sorted in descending order by Idle, got %v < %v\n", + idle2, idle1) + } + } + for mode := 0; mode < 2; mode++ { + testIdle(mode) + } +} + +func TestConnzSortBadRequest(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + firstClient := createClientConnSubscribeAndPublish(t, s) + firstClient.Subscribe("hello.world", func(m *nats.Msg) {}) + clients := make([]*nats.Conn, 3) + for i := range clients { + clients[i] = createClientConnSubscribeAndPublish(t, s) + defer clients[i].Close() + } + defer firstClient.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + readBodyEx(t, url+"connz?sort=foo", http.StatusBadRequest, textPlain) + + if _, err := s.Connz(&ConnzOptions{Sort: "foo"}); err == nil { + t.Fatal("Expected error, got none") + } +} + +func pollRoutez(t *testing.T, s *Server, mode int, url string, opts *RoutezOptions) *Routez { + if mode == 0 { + rz := &Routez{} + body := readBody(t, url) + if err := json.Unmarshal(body, rz); err != nil { + stackFatalf(t, "Got an error unmarshalling the body: %v\n", err) + } + return rz + } + rz, _ := s.Routez(opts) + return rz +} + +func TestConnzWithRoutes(t *testing.T) { + + opts := DefaultMonitorOptions() + opts.Cluster.Host = "127.0.0.1" + opts.Cluster.Port = CLUSTER_PORT + + s := RunServer(opts) + defer s.Shutdown() + + opts = &Options{ + Host: "127.0.0.1", + Port: -1, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: -1, + }, + NoLog: true, + NoSigs: true, + } + routeURL, _ := url.Parse(fmt.Sprintf("nats-route://127.0.0.1:%d", s.ClusterAddr().Port)) + opts.Routes = []*url.URL{routeURL} + + sc := RunServer(opts) + defer sc.Shutdown() + + checkClusterFormed(t, s, sc) + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + c := pollConz(t, s, mode, url+"connz", nil) + // Test contents.. + // Make sure routes don't show up under connz, but do under routez + if c.NumConns != 0 { + t.Fatalf("Expected 0 connections, got %d\n", c.NumConns) + } + if c.Conns == nil || len(c.Conns) != 0 { + t.Fatalf("Expected 0 connections in array, got %p\n", c.Conns) + } + } + + nc := createClientConnSubscribeAndPublish(t, sc) + defer nc.Close() + + nc.Subscribe("hello.bar", func(m *nats.Msg) {}) + nc.Flush() + checkExpectedSubs(t, 1, s, sc) + + // Now check routez + urls := []string{"routez", "routez?subs=1"} + for subs, urlSuffix := range urls { + for mode := 0; mode < 2; mode++ { + rz := pollRoutez(t, s, mode, url+urlSuffix, &RoutezOptions{Subscriptions: subs == 1}) + + if rz.NumRoutes != 1 { + t.Fatalf("Expected 1 route, got %d\n", rz.NumRoutes) + } + + if len(rz.Routes) != 1 { + t.Fatalf("Expected route array of 1, got %v\n", len(rz.Routes)) + } + + route := rz.Routes[0] + + if route.DidSolicit { + t.Fatalf("Expected unsolicited route, got %v\n", route.DidSolicit) + } + + // Don't ask for subs, so there should not be any + if subs == 0 { + if len(route.Subs) != 0 { + t.Fatalf("There should not be subs, got %v", len(route.Subs)) + } + } else { + if len(route.Subs) != 1 { + t.Fatalf("There should be 1 sub, got %v", len(route.Subs)) + } + } + } + } + + // Test JSONP + readBodyEx(t, url+"routez?callback=callback", http.StatusOK, appJSContent) +} + +func TestRoutezWithBadParams(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/routez?", s.MonitorAddr().Port) + readBodyEx(t, url+"subs=xxx", http.StatusBadRequest, textPlain) +} + +func pollSubsz(t *testing.T, s *Server, mode int, url string, opts *SubszOptions) *Subsz { + if mode == 0 { + body := readBody(t, url) + sz := &Subsz{} + if err := json.Unmarshal(body, sz); err != nil { + stackFatalf(t, "Got an error unmarshalling the body: %v\n", err) + } + return sz + } + sz, _ := s.Subsz(opts) + return sz +} + +func TestSubsz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + sl := pollSubsz(t, s, mode, url+"subsz", nil) + if sl.NumSubs != 0 { + t.Fatalf("Expected NumSubs of 0, got %d\n", sl.NumSubs) + } + if sl.NumInserts != 1 { + t.Fatalf("Expected NumInserts of 1, got %d\n", sl.NumInserts) + } + if sl.NumMatches != 1 { + t.Fatalf("Expected NumMatches of 1, got %d\n", sl.NumMatches) + } + } + + // Test JSONP + readBodyEx(t, url+"subsz?callback=callback", http.StatusOK, appJSContent) +} + +func TestSubszDetails(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + nc.Subscribe("foo.*", func(m *nats.Msg) {}) + nc.Subscribe("foo.bar", func(m *nats.Msg) {}) + nc.Subscribe("foo.foo", func(m *nats.Msg) {}) + + nc.Publish("foo.bar", []byte("Hello")) + nc.Publish("foo.baz", []byte("Hello")) + nc.Publish("foo.foo", []byte("Hello")) + + nc.Flush() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + sl := pollSubsz(t, s, mode, url+"subsz?subs=1", &SubszOptions{Subscriptions: true}) + if sl.NumSubs != 3 { + t.Fatalf("Expected NumSubs of 3, got %d\n", sl.NumSubs) + } + if sl.Total != 3 { + t.Fatalf("Expected Total of 3, got %d\n", sl.Total) + } + if len(sl.Subs) != 3 { + t.Fatalf("Expected subscription details for 3 subs, got %d\n", len(sl.Subs)) + } + } +} + +func TestSubszWithOffsetAndLimit(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + for i := 0; i < 200; i++ { + nc.Subscribe(fmt.Sprintf("foo.%d", i), func(m *nats.Msg) {}) + } + nc.Flush() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + sl := pollSubsz(t, s, mode, url+"subsz?subs=1&offset=10&limit=100", &SubszOptions{Subscriptions: true, Offset: 10, Limit: 100}) + if sl.NumSubs != 200 { + t.Fatalf("Expected NumSubs of 200, got %d\n", sl.NumSubs) + } + if sl.Total != 100 { + t.Fatalf("Expected Total of 100, got %d\n", sl.Total) + } + if sl.Offset != 10 { + t.Fatalf("Expected Offset of 10, got %d\n", sl.Offset) + } + if sl.Limit != 100 { + t.Fatalf("Expected Total of 100, got %d\n", sl.Limit) + } + if len(sl.Subs) != 100 { + t.Fatalf("Expected subscription details for 100 subs, got %d\n", len(sl.Subs)) + } + } +} + +func TestSubszTestPubSubject(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + nc.Subscribe("foo.*", func(m *nats.Msg) {}) + nc.Subscribe("foo.bar", func(m *nats.Msg) {}) + nc.Subscribe("foo.foo", func(m *nats.Msg) {}) + nc.Flush() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + sl := pollSubsz(t, s, mode, url+"subsz?subs=1&test=foo.foo", &SubszOptions{Subscriptions: true, Test: "foo.foo"}) + if sl.Total != 2 { + t.Fatalf("Expected Total of 2 match, got %d\n", sl.Total) + } + if len(sl.Subs) != 2 { + t.Fatalf("Expected subscription details for 2 matching subs, got %d\n", len(sl.Subs)) + } + sl = pollSubsz(t, s, mode, url+"subsz?subs=1&test=foo", &SubszOptions{Subscriptions: true, Test: "foo"}) + if len(sl.Subs) != 0 { + t.Fatalf("Expected no matching subs, got %d\n", len(sl.Subs)) + } + } + // Make sure we get an error with invalid test subject. + testUrl := url + "subsz?subs=1&" + readBodyEx(t, testUrl+"test=*", http.StatusBadRequest, textPlain) + readBodyEx(t, testUrl+"test=foo.*", http.StatusBadRequest, textPlain) + readBodyEx(t, testUrl+"test=foo.>", http.StatusBadRequest, textPlain) + readBodyEx(t, testUrl+"test=foo..bar", http.StatusBadRequest, textPlain) +} + +// Tests handle root +func TestHandleRoot(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + nc := createClientConnSubscribeAndPublish(t, s) + defer nc.Close() + + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port)) + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected a %d response, got %d\n", http.StatusOK, resp.StatusCode) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Expected no error reading body: Got %v\n", err) + } + for _, b := range body { + if b > unicode.MaxASCII { + t.Fatalf("Expected body to contain only ASCII characters, but got %v\n", b) + } + } + + ct := resp.Header.Get("Content-Type") + if !strings.Contains(ct, "text/html") { + t.Fatalf("Expected text/html response, got %s\n", ct) + } + defer resp.Body.Close() +} + +func TestConnzWithNamedClient(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + clientName := "test-client" + nc := createClientConnWithName(t, clientName, s) + defer nc.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + // Confirm server is exposing client name in monitoring endpoint. + c := pollConz(t, s, mode, url+"connz", nil) + got := len(c.Conns) + expected := 1 + if got != expected { + t.Fatalf("Expected %d connection in array, got %d\n", expected, got) + } + + conn := c.Conns[0] + if conn.Name != clientName { + t.Fatalf("Expected client to have name %q. got %q", clientName, conn.Name) + } + } +} + +func TestConnzWithStateForClosedConns(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + numEach := 10 + // Create 10 closed, and 10 to leave open. + for i := 0; i < numEach; i++ { + nc := createClientConnSubscribeAndPublish(t, s) + nc.Subscribe("hello.closed.conns", func(m *nats.Msg) {}) + nc.Close() + nc = createClientConnSubscribeAndPublish(t, s) + nc.Subscribe("hello.open.conns", func(m *nats.Msg) {}) + defer nc.Close() + } + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + // Look at all open + c := pollConz(t, s, mode, url+"connz?state=open", &ConnzOptions{State: ConnOpen}) + if lc := len(c.Conns); lc != numEach { + t.Fatalf("Expected %d connections in array, got %d\n", numEach, lc) + } + // Look at all closed + c = pollConz(t, s, mode, url+"connz?state=closed", &ConnzOptions{State: ConnClosed}) + if lc := len(c.Conns); lc != numEach { + t.Fatalf("Expected %d connections in array, got %d\n", numEach, lc) + } + // Look at all + c = pollConz(t, s, mode, url+"connz?state=ALL", &ConnzOptions{State: ConnAll}) + if lc := len(c.Conns); lc != numEach*2 { + t.Fatalf("Expected %d connections in array, got %d\n", 2*numEach, lc) + } + // Look at CID #1, which is in closed. + c = pollConz(t, s, mode, url+"connz?cid=1&state=open", &ConnzOptions{CID: 1, State: ConnOpen}) + if lc := len(c.Conns); lc != 0 { + t.Fatalf("Expected no connections in open array, got %d\n", lc) + } + c = pollConz(t, s, mode, url+"connz?cid=1&state=closed", &ConnzOptions{CID: 1, State: ConnClosed}) + if lc := len(c.Conns); lc != 1 { + t.Fatalf("Expected a connection in closed array, got %d\n", lc) + } + c = pollConz(t, s, mode, url+"connz?cid=1&state=ALL", &ConnzOptions{CID: 1, State: ConnAll}) + if lc := len(c.Conns); lc != 1 { + t.Fatalf("Expected a connection in closed array, got %d\n", lc) + } + c = pollConz(t, s, mode, url+"connz?cid=1&state=closed&subs=true", + &ConnzOptions{CID: 1, State: ConnClosed, Subscriptions: true}) + if lc := len(c.Conns); lc != 1 { + t.Fatalf("Expected a connection in closed array, got %d\n", lc) + } + ci := c.Conns[0] + if ci.NumSubs != 1 { + t.Fatalf("Expected NumSubs to be 1, got %d\n", ci.NumSubs) + } + if len(ci.Subs) != 1 { + t.Fatalf("Expected len(ci.Subs) to be 1 also, got %d\n", len(ci.Subs)) + } + // Now ask for same thing without subs and make sure they are not returned. + c = pollConz(t, s, mode, url+"connz?cid=1&state=closed&subs=false", + &ConnzOptions{CID: 1, State: ConnClosed, Subscriptions: false}) + if lc := len(c.Conns); lc != 1 { + t.Fatalf("Expected a connection in closed array, got %d\n", lc) + } + ci = c.Conns[0] + if ci.NumSubs != 1 { + t.Fatalf("Expected NumSubs to be 1, got %d\n", ci.NumSubs) + } + if len(ci.Subs) != 0 { + t.Fatalf("Expected len(ci.Subs) to be 0 since subs=false, got %d\n", len(ci.Subs)) + } + + // CID #2 is in open + c = pollConz(t, s, mode, url+"connz?cid=2&state=open", &ConnzOptions{CID: 2, State: ConnOpen}) + if lc := len(c.Conns); lc != 1 { + t.Fatalf("Expected a connection in open array, got %d\n", lc) + } + c = pollConz(t, s, mode, url+"connz?cid=2&state=closed", &ConnzOptions{CID: 2, State: ConnClosed}) + if lc := len(c.Conns); lc != 0 { + t.Fatalf("Expected no connections in closed array, got %d\n", lc) + } + } +} + +// Make sure options for ConnInfo like subs=1, authuser, etc do not cause a race. +func TestConnzClosedConnsRace(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + // Create 100 closed connections. + for i := 0; i < 100; i++ { + nc := createClientConnSubscribeAndPublish(t, s) + nc.Close() + } + + urlWithoutSubs := fmt.Sprintf("http://127.0.0.1:%d/connz?state=closed", s.MonitorAddr().Port) + urlWithSubs := urlWithoutSubs + "&subs=true" + + checkClosedConns(t, s, 100, 2*time.Second) + + wg := &sync.WaitGroup{} + + fn := func(url string) { + deadline := time.Now().Add(1 * time.Second) + for time.Now().Before(deadline) { + c := pollConz(t, s, 0, url, nil) + if len(c.Conns) != 100 { + t.Errorf("Incorrect Results: %+v\n", c) + } + } + wg.Done() + } + + wg.Add(2) + go fn(urlWithSubs) + go fn(urlWithoutSubs) + wg.Wait() +} + +// Make sure a bad client that is disconnected right away has proper values. +func TestConnzClosedConnsBadClient(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := s.getOpts() + + rc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + rc.Close() + + checkClosedConns(t, s, 1, 2*time.Second) + + c := pollConz(t, s, 1, "", &ConnzOptions{State: ConnClosed}) + if len(c.Conns) != 1 { + t.Errorf("Incorrect Results: %+v\n", c) + } + ci := c.Conns[0] + + uptime := ci.Stop.Sub(ci.Start) + idle, err := time.ParseDuration(ci.Idle) + if err != nil { + t.Fatalf("Could not parse Idle: %v\n", err) + } + if idle > uptime { + t.Fatalf("Idle can't be larger then uptime, %v vs %v\n", idle, uptime) + } + if ci.LastActivity.IsZero() { + t.Fatalf("LastActivity should not be Zero\n") + } +} + +// Make sure a bad client that tries to connect plain to TLS has proper values. +func TestConnzClosedConnsBadTLSClient(t *testing.T) { + resetPreviousHTTPConnections() + + tc := &TLSConfigOpts{} + tc.CertFile = "configs/certs/server.pem" + tc.KeyFile = "configs/certs/key.pem" + + var err error + opts := DefaultMonitorOptions() + opts.TLSTimeout = 1.5 // 1.5 seconds + opts.TLSConfig, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating TSL config: %v", err) + } + + s := RunServer(opts) + defer s.Shutdown() + + opts = s.getOpts() + + rc, err := net.Dial("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + rc.Write([]byte("CONNECT {}\r\n")) + rc.Close() + + checkClosedConns(t, s, 1, 2*time.Second) + + c := pollConz(t, s, 1, "", &ConnzOptions{State: ConnClosed}) + if len(c.Conns) != 1 { + t.Errorf("Incorrect Results: %+v\n", c) + } + ci := c.Conns[0] + + uptime := ci.Stop.Sub(ci.Start) + idle, err := time.ParseDuration(ci.Idle) + if err != nil { + t.Fatalf("Could not parse Idle: %v\n", err) + } + if idle > uptime { + t.Fatalf("Idle can't be larger then uptime, %v vs %v\n", idle, uptime) + } + if ci.LastActivity.IsZero() { + t.Fatalf("LastActivity should not be Zero\n") + } +} + +// Create a connection to test ConnInfo +func createClientConnSubscribeAndPublish(t *testing.T, s *Server) *nats.Conn { + natsURL := fmt.Sprintf("nats://127.0.0.1:%d", s.Addr().(*net.TCPAddr).Port) + client := nats.DefaultOptions + client.Servers = []string{natsURL} + nc, err := client.Connect() + if err != nil { + t.Fatalf("Error creating client: %v to: %s\n", err, natsURL) + } + + ch := make(chan bool) + inbox := nats.NewInbox() + sub, err := nc.Subscribe(inbox, func(m *nats.Msg) { ch <- true }) + if err != nil { + t.Fatalf("Error subscribing to `%s`: %v\n", inbox, err) + } + nc.Publish(inbox, []byte("Hello")) + // Wait for message + <-ch + sub.Unsubscribe() + close(ch) + nc.Flush() + return nc +} + +func createClientConnWithName(t *testing.T, name string, s *Server) *nats.Conn { + natsURI := fmt.Sprintf("nats://127.0.0.1:%d", s.Addr().(*net.TCPAddr).Port) + + client := nats.DefaultOptions + client.Servers = []string{natsURI} + client.Name = name + nc, err := client.Connect() + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + return nc +} + +func TestStacksz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + body := readBody(t, url+"stacksz") + // Check content + str := string(body) + if !strings.Contains(str, "HandleStacksz") { + t.Fatalf("Result does not seem to contain server's stacks:\n%v", str) + } +} + +func TestConcurrentMonitoring(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + // Get some endpoints. Make sure we have at least varz, + // and the more the merrier. + endpoints := []string{"varz", "varz", "varz", "connz", "connz", "subsz", "subsz", "routez", "routez"} + wg := &sync.WaitGroup{} + wg.Add(len(endpoints)) + ech := make(chan string, len(endpoints)) + + for _, e := range endpoints { + go func(endpoint string) { + defer wg.Done() + for i := 0; i < 50; i++ { + resp, err := http.Get(url + endpoint) + if err != nil { + ech <- fmt.Sprintf("Expected no error: Got %v\n", err) + return + } + if resp.StatusCode != http.StatusOK { + ech <- fmt.Sprintf("Expected a %v response, got %d\n", http.StatusOK, resp.StatusCode) + return + } + ct := resp.Header.Get("Content-Type") + if ct != "application/json" { + ech <- fmt.Sprintf("Expected application/json content-type, got %s\n", ct) + return + } + defer resp.Body.Close() + if _, err := ioutil.ReadAll(resp.Body); err != nil { + ech <- fmt.Sprintf("Got an error reading the body: %v\n", err) + return + } + resp.Body.Close() + } + }(e) + } + wg.Wait() + // Check for any errors + select { + case err := <-ech: + t.Fatal(err) + default: + } +} + +func TestMonitorHandler(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + handler := s.HTTPHandler() + if handler == nil { + t.Fatal("HTTP Handler should be set") + } + s.Shutdown() + handler = s.HTTPHandler() + if handler != nil { + t.Fatal("HTTP Handler should be nil") + } +} + +func TestMonitorRoutezRace(t *testing.T) { + resetPreviousHTTPConnections() + srvAOpts := DefaultMonitorOptions() + srvAOpts.Cluster.Port = -1 + srvA := RunServer(srvAOpts) + defer srvA.Shutdown() + + srvBOpts := nextServerOpts(srvAOpts) + srvBOpts.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", srvA.ClusterAddr().Port)) + + url := fmt.Sprintf("http://127.0.0.1:%d/", srvA.MonitorAddr().Port) + doneCh := make(chan struct{}) + go func() { + defer func() { + doneCh <- struct{}{} + }() + for i := 0; i < 10; i++ { + time.Sleep(10 * time.Millisecond) + // Reset ports + srvBOpts.Port = -1 + srvBOpts.Cluster.Port = -1 + srvB := RunServer(srvBOpts) + time.Sleep(20 * time.Millisecond) + srvB.Shutdown() + } + }() + done := false + for !done { + if resp, err := http.Get(url + "routez"); err != nil { + time.Sleep(10 * time.Millisecond) + } else { + resp.Body.Close() + } + select { + case <-doneCh: + done = true + default: + } + } +} + +func TestConnzTLSInHandshake(t *testing.T) { + resetPreviousHTTPConnections() + + tc := &TLSConfigOpts{} + tc.CertFile = "configs/certs/server.pem" + tc.KeyFile = "configs/certs/key.pem" + + var err error + opts := DefaultMonitorOptions() + opts.TLSTimeout = 1.5 // 1.5 seconds + opts.TLSConfig, err = GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error creating TSL config: %v", err) + } + + s := RunServer(opts) + defer s.Shutdown() + + // Create bare TCP connection to delay client TLS handshake + c, err := net.Dial("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on dial: %v", err) + } + defer c.Close() + + // Wait for the connection to be registered + checkClientsCount(t, s, 1) + + start := time.Now() + endpoint := fmt.Sprintf("http://%s:%d/connz", opts.HTTPHost, s.MonitorAddr().Port) + for mode := 0; mode < 2; mode++ { + connz := pollConz(t, s, mode, endpoint, nil) + duration := time.Since(start) + if duration >= 1500*time.Millisecond { + t.Fatalf("Looks like connz blocked on handshake, took %v", duration) + } + if len(connz.Conns) != 1 { + t.Fatalf("Expected 1 conn, got %v", len(connz.Conns)) + } + conn := connz.Conns[0] + // TLS fields should be not set + if conn.TLSVersion != "" || conn.TLSCipher != "" { + t.Fatalf("Expected TLS fields to not be set, got version:%v cipher:%v", conn.TLSVersion, conn.TLSCipher) + } + } +} + +func TestServerIDs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + murl := fmt.Sprintf("http://127.0.0.1:%d/", s.MonitorAddr().Port) + + for mode := 0; mode < 2; mode++ { + v := pollVarz(t, s, mode, murl+"varz", nil) + if v.ID == "" { + t.Fatal("Varz ID is empty") + } + c := pollConz(t, s, mode, murl+"connz", nil) + if c.ID == "" { + t.Fatal("Connz ID is empty") + } + r := pollRoutez(t, s, mode, murl+"routez", nil) + if r.ID == "" { + t.Fatal("Routez ID is empty") + } + if v.ID != c.ID || v.ID != r.ID { + t.Fatalf("Varz ID [%s] is not equal to Connz ID [%s] or Routez ID [%s]", v.ID, c.ID, r.ID) + } + } +} + +func TestHttpStatsNoUpdatedWhenUsingServerFuncs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + for i := 0; i < 10; i++ { + s.Varz(nil) + s.Connz(nil) + s.Routez(nil) + s.Subsz(nil) + } + + v, _ := s.Varz(nil) + endpoints := []string{VarzPath, ConnzPath, RoutezPath, SubszPath} + for _, e := range endpoints { + stats := v.HTTPReqStats[e] + if stats != 0 { + t.Fatalf("Expected HTTPReqStats for %q to be 0, got %v", e, stats) + } + } +} + +func TestClusterEmptyWhenNotDefined(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + body := readBody(t, fmt.Sprintf("http://127.0.0.1:%d/varz", s.MonitorAddr().Port)) + var v map[string]interface{} + if err := json.Unmarshal(body, &v); err != nil { + stackFatalf(t, "Got an error unmarshalling the body: %v\n", err) + } + // Cluster can empty, or be defined but that needs to be empty. + c, ok := v["cluster"] + if !ok { + return + } + if len(c.(map[string]interface{})) != 0 { + t.Fatalf("Expected an empty cluster definition, instead got %+v\n", c) + } +} + +// Benchmark our Connz generation. Don't use HTTP here, just measure server endpoint. +func Benchmark_Connz(b *testing.B) { + runtime.MemProfileRate = 0 + + s := runMonitorServerNoHTTPPort() + defer s.Shutdown() + + opts := s.getOpts() + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + // Create 250 connections with 100 subs each. + for i := 0; i < 250; i++ { + nc, err := nats.Connect(url) + if err != nil { + b.Fatalf("Error on connection[%d] to %s: %v", i, url, err) + } + for x := 0; x < 100; x++ { + subj := fmt.Sprintf("foo.%d", x) + nc.Subscribe(subj, func(m *nats.Msg) {}) + } + nc.Flush() + defer nc.Close() + } + + b.ResetTimer() + runtime.MemProfileRate = 1 + + copts := &ConnzOptions{Subscriptions: false} + for i := 0; i < b.N; i++ { + _, err := s.Connz(copts) + if err != nil { + b.Fatalf("Error on Connz(): %v", err) + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/norace_test.go b/vendor/github.com/nats-io/gnatsd/server/norace_test.go new file mode 100644 index 00000000..f208be38 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/norace_test.go @@ -0,0 +1,87 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package server + +import ( + "fmt" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +// IMPORTANT: Tests in this file are not executed when running with the -race flag. + +func TestAvoidSlowConsumerBigMessages(t *testing.T) { + opts := DefaultOptions() // Use defaults to make sure they avoid pending slow consumer. + s := RunServer(opts) + defer s.Shutdown() + + nc1, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc1.Close() + + nc2, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc2.Close() + + data := make([]byte, 1024*1024) // 1MB payload + rand.Read(data) + + expected := int32(500) + received := int32(0) + + done := make(chan bool) + + // Create Subscription. + nc1.Subscribe("slow.consumer", func(m *nats.Msg) { + // Just eat it so that we are not measuring + // code time, just delivery. + atomic.AddInt32(&received, 1) + if received >= expected { + done <- true + } + }) + + // Create Error handler + nc1.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, err error) { + t.Fatalf("Received an error on the subscription's connection: %v\n", err) + }) + + nc1.Flush() + + for i := 0; i < int(expected); i++ { + nc2.Publish("slow.consumer", data) + } + nc2.Flush() + + select { + case <-done: + return + case <-time.After(10 * time.Second): + r := atomic.LoadInt32(&received) + if s.NumSlowConsumers() > 0 { + t.Fatalf("Did not receive all large messages due to slow consumer status: %d of %d", r, expected) + } + t.Fatalf("Failed to receive all large messages: %d of %d\n", r, expected) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/opts.go b/vendor/github.com/nats-io/gnatsd/server/opts.go new file mode 100644 index 00000000..ad4f38a3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/opts.go @@ -0,0 +1,1259 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "flag" + "fmt" + "io/ioutil" + "net" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/nats-io/gnatsd/conf" + "github.com/nats-io/gnatsd/util" +) + +// ClusterOpts are options for clusters. +type ClusterOpts struct { + Host string `json:"addr,omitempty"` + Port int `json:"cluster_port,omitempty"` + Username string `json:"-"` + Password string `json:"-"` + AuthTimeout float64 `json:"auth_timeout,omitempty"` + Permissions *RoutePermissions `json:"-"` + TLSTimeout float64 `json:"-"` + TLSConfig *tls.Config `json:"-"` + ListenStr string `json:"-"` + Advertise string `json:"-"` + NoAdvertise bool `json:"-"` + ConnectRetries int `json:"-"` +} + +// Options block for gnatsd server. +type Options struct { + ConfigFile string `json:"-"` + Host string `json:"addr"` + Port int `json:"port"` + ClientAdvertise string `json:"-"` + Trace bool `json:"-"` + Debug bool `json:"-"` + NoLog bool `json:"-"` + NoSigs bool `json:"-"` + Logtime bool `json:"-"` + MaxConn int `json:"max_connections"` + MaxSubs int `json:"max_subscriptions,omitempty"` + Users []*User `json:"-"` + Username string `json:"-"` + Password string `json:"-"` + Authorization string `json:"-"` + PingInterval time.Duration `json:"ping_interval"` + MaxPingsOut int `json:"ping_max"` + HTTPHost string `json:"http_host"` + HTTPPort int `json:"http_port"` + HTTPSPort int `json:"https_port"` + AuthTimeout float64 `json:"auth_timeout"` + MaxControlLine int `json:"max_control_line"` + MaxPayload int `json:"max_payload"` + MaxPending int64 `json:"max_pending"` + Cluster ClusterOpts `json:"cluster,omitempty"` + ProfPort int `json:"-"` + PidFile string `json:"-"` + PortsFileDir string `json:"-"` + LogFile string `json:"-"` + Syslog bool `json:"-"` + RemoteSyslog string `json:"-"` + Routes []*url.URL `json:"-"` + RoutesStr string `json:"-"` + TLSTimeout float64 `json:"tls_timeout"` + TLS bool `json:"-"` + TLSVerify bool `json:"-"` + TLSCert string `json:"-"` + TLSKey string `json:"-"` + TLSCaCert string `json:"-"` + TLSConfig *tls.Config `json:"-"` + WriteDeadline time.Duration `json:"-"` + RQSubsSweep time.Duration `json:"-"` + MaxClosedClients int `json:"-"` + + CustomClientAuthentication Authentication `json:"-"` + CustomRouterAuthentication Authentication `json:"-"` +} + +// Clone performs a deep copy of the Options struct, returning a new clone +// with all values copied. +func (o *Options) Clone() *Options { + if o == nil { + return nil + } + clone := &Options{} + *clone = *o + if o.Users != nil { + clone.Users = make([]*User, len(o.Users)) + for i, user := range o.Users { + clone.Users[i] = user.clone() + } + } + if o.Routes != nil { + clone.Routes = make([]*url.URL, len(o.Routes)) + for i, route := range o.Routes { + routeCopy := &url.URL{} + *routeCopy = *route + clone.Routes[i] = routeCopy + } + } + if o.TLSConfig != nil { + clone.TLSConfig = util.CloneTLSConfig(o.TLSConfig) + } + if o.Cluster.TLSConfig != nil { + clone.Cluster.TLSConfig = util.CloneTLSConfig(o.Cluster.TLSConfig) + } + return clone +} + +// Configuration file authorization section. +type authorization struct { + // Singles + user string + pass string + token string + // Multiple Users + users []*User + timeout float64 + defaultPermissions *Permissions +} + +// TLSConfigOpts holds the parsed tls config information, +// used with flag parsing +type TLSConfigOpts struct { + CertFile string + KeyFile string + CaFile string + Verify bool + Timeout float64 + Ciphers []uint16 + CurvePreferences []tls.CurveID +} + +var tlsUsage = ` +TLS configuration is specified in the tls section of a configuration file: + +e.g. + + tls { + cert_file: "./certs/server-cert.pem" + key_file: "./certs/server-key.pem" + ca_file: "./certs/ca.pem" + verify: true + + cipher_suites: [ + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + ] + curve_preferences: [ + "CurveP256", + "CurveP384", + "CurveP521" + ] + } + +Available cipher suites include: +` + +// ProcessConfigFile processes a configuration file. +// FIXME(dlc): Hacky +func ProcessConfigFile(configFile string) (*Options, error) { + opts := &Options{} + if err := opts.ProcessConfigFile(configFile); err != nil { + return nil, err + } + return opts, nil +} + +// ProcessConfigFile updates the Options structure with options +// present in the given configuration file. +// This version is convenient if one wants to set some default +// options and then override them with what is in the config file. +// For instance, this version allows you to do something such as: +// +// opts := &Options{Debug: true} +// opts.ProcessConfigFile(myConfigFile) +// +// If the config file contains "debug: false", after this call, +// opts.Debug would really be false. It would be impossible to +// achieve that with the non receiver ProcessConfigFile() version, +// since one would not know after the call if "debug" was not present +// or was present but set to false. +func (o *Options) ProcessConfigFile(configFile string) error { + o.ConfigFile = configFile + if configFile == "" { + return nil + } + + m, err := conf.ParseFile(configFile) + if err != nil { + return err + } + + for k, v := range m { + switch strings.ToLower(k) { + case "listen": + hp, err := parseListen(v) + if err != nil { + return err + } + o.Host = hp.host + o.Port = hp.port + case "client_advertise": + o.ClientAdvertise = v.(string) + case "port": + o.Port = int(v.(int64)) + case "host", "net": + o.Host = v.(string) + case "debug": + o.Debug = v.(bool) + case "trace": + o.Trace = v.(bool) + case "logtime": + o.Logtime = v.(bool) + case "authorization": + am := v.(map[string]interface{}) + auth, err := parseAuthorization(am) + if err != nil { + return err + } + o.Username = auth.user + o.Password = auth.pass + o.Authorization = auth.token + if (auth.user != "" || auth.pass != "") && auth.token != "" { + return fmt.Errorf("Cannot have a user/pass and token") + } + o.AuthTimeout = auth.timeout + // Check for multiple users defined + if auth.users != nil { + if auth.user != "" { + return fmt.Errorf("Can not have a single user/pass and a users array") + } + if auth.token != "" { + return fmt.Errorf("Can not have a token and a users array") + } + o.Users = auth.users + } + case "http": + hp, err := parseListen(v) + if err != nil { + return err + } + o.HTTPHost = hp.host + o.HTTPPort = hp.port + case "https": + hp, err := parseListen(v) + if err != nil { + return err + } + o.HTTPHost = hp.host + o.HTTPSPort = hp.port + case "http_port", "monitor_port": + o.HTTPPort = int(v.(int64)) + case "https_port": + o.HTTPSPort = int(v.(int64)) + case "cluster": + cm := v.(map[string]interface{}) + if err := parseCluster(cm, o); err != nil { + return err + } + case "logfile", "log_file": + o.LogFile = v.(string) + case "syslog": + o.Syslog = v.(bool) + case "remote_syslog": + o.RemoteSyslog = v.(string) + case "pidfile", "pid_file": + o.PidFile = v.(string) + case "ports_file_dir": + o.PortsFileDir = v.(string) + case "prof_port": + o.ProfPort = int(v.(int64)) + case "max_control_line": + o.MaxControlLine = int(v.(int64)) + case "max_payload": + o.MaxPayload = int(v.(int64)) + case "max_pending": + o.MaxPending = v.(int64) + case "max_connections", "max_conn": + o.MaxConn = int(v.(int64)) + case "max_subscriptions", "max_subs": + o.MaxSubs = int(v.(int64)) + case "ping_interval": + o.PingInterval = time.Duration(int(v.(int64))) * time.Second + case "ping_max": + o.MaxPingsOut = int(v.(int64)) + case "tls": + tlsm := v.(map[string]interface{}) + tc, err := parseTLS(tlsm) + if err != nil { + return err + } + if o.TLSConfig, err = GenTLSConfig(tc); err != nil { + return err + } + o.TLSTimeout = tc.Timeout + case "write_deadline": + wd, ok := v.(string) + if ok { + dur, err := time.ParseDuration(wd) + if err != nil { + return fmt.Errorf("error parsing write_deadline: %v", err) + } + o.WriteDeadline = dur + } else { + // Backward compatible with old type, assume this is the + // number of seconds. + o.WriteDeadline = time.Duration(v.(int64)) * time.Second + fmt.Printf("WARNING: write_deadline should be converted to a duration\n") + } + } + } + return nil +} + +// hostPort is simple struct to hold parsed listen/addr strings. +type hostPort struct { + host string + port int +} + +// parseListen will parse listen option which is replacing host/net and port +func parseListen(v interface{}) (*hostPort, error) { + hp := &hostPort{} + switch v.(type) { + // Only a port + case int64: + hp.port = int(v.(int64)) + case string: + host, port, err := net.SplitHostPort(v.(string)) + if err != nil { + return nil, fmt.Errorf("Could not parse address string %q", v) + } + hp.port, err = strconv.Atoi(port) + if err != nil { + return nil, fmt.Errorf("Could not parse port %q", port) + } + hp.host = host + } + return hp, nil +} + +// parseCluster will parse the cluster config. +func parseCluster(cm map[string]interface{}, opts *Options) error { + for mk, mv := range cm { + switch strings.ToLower(mk) { + case "listen": + hp, err := parseListen(mv) + if err != nil { + return err + } + opts.Cluster.Host = hp.host + opts.Cluster.Port = hp.port + case "port": + opts.Cluster.Port = int(mv.(int64)) + case "host", "net": + opts.Cluster.Host = mv.(string) + case "authorization": + am := mv.(map[string]interface{}) + auth, err := parseAuthorization(am) + if err != nil { + return err + } + if auth.users != nil { + return fmt.Errorf("Cluster authorization does not allow multiple users") + } + opts.Cluster.Username = auth.user + opts.Cluster.Password = auth.pass + opts.Cluster.AuthTimeout = auth.timeout + if auth.defaultPermissions != nil { + // Import is whether or not we will send a SUB for interest to the other side. + // Export is whether or not we will accept a SUB from the remote for a given subject. + // Both only effect interest registration. + // The parsing sets Import into Publish and Export into Subscribe, convert + // accordingly. + opts.Cluster.Permissions = &RoutePermissions{ + Import: auth.defaultPermissions.Publish, + Export: auth.defaultPermissions.Subscribe, + } + } + case "routes": + ra := mv.([]interface{}) + opts.Routes = make([]*url.URL, 0, len(ra)) + for _, r := range ra { + routeURL := r.(string) + url, err := url.Parse(routeURL) + if err != nil { + return fmt.Errorf("error parsing route url [%q]", routeURL) + } + opts.Routes = append(opts.Routes, url) + } + case "tls": + tlsm := mv.(map[string]interface{}) + tc, err := parseTLS(tlsm) + if err != nil { + return err + } + if opts.Cluster.TLSConfig, err = GenTLSConfig(tc); err != nil { + return err + } + // For clusters, we will force strict verification. We also act + // as both client and server, so will mirror the rootCA to the + // clientCA pool. + opts.Cluster.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + opts.Cluster.TLSConfig.RootCAs = opts.Cluster.TLSConfig.ClientCAs + opts.Cluster.TLSTimeout = tc.Timeout + case "cluster_advertise", "advertise": + opts.Cluster.Advertise = mv.(string) + case "no_advertise": + opts.Cluster.NoAdvertise = mv.(bool) + case "connect_retries": + opts.Cluster.ConnectRetries = int(mv.(int64)) + } + } + return nil +} + +// Helper function to parse Authorization configs. +func parseAuthorization(am map[string]interface{}) (*authorization, error) { + auth := &authorization{} + for mk, mv := range am { + switch strings.ToLower(mk) { + case "user", "username": + auth.user = mv.(string) + case "pass", "password": + auth.pass = mv.(string) + case "token": + auth.token = mv.(string) + case "timeout": + at := float64(1) + switch mv.(type) { + case int64: + at = float64(mv.(int64)) + case float64: + at = mv.(float64) + } + auth.timeout = at + case "users": + users, err := parseUsers(mv) + if err != nil { + return nil, err + } + auth.users = users + case "default_permission", "default_permissions", "permissions": + pm, ok := mv.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected default permissions to be a map/struct, got %+v", mv) + } + permissions, err := parseUserPermissions(pm) + if err != nil { + return nil, err + } + auth.defaultPermissions = permissions + } + + // Now check for permission defaults with multiple users, etc. + if auth.users != nil && auth.defaultPermissions != nil { + for _, user := range auth.users { + if user.Permissions == nil { + user.Permissions = auth.defaultPermissions + } + } + } + + } + return auth, nil +} + +// Helper function to parse multiple users array with optional permissions. +func parseUsers(mv interface{}) ([]*User, error) { + // Make sure we have an array + uv, ok := mv.([]interface{}) + if !ok { + return nil, fmt.Errorf("Expected users field to be an array, got %v", mv) + } + users := []*User{} + for _, u := range uv { + // Check its a map/struct + um, ok := u.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected user entry to be a map/struct, got %v", u) + } + user := &User{} + for k, v := range um { + switch strings.ToLower(k) { + case "user", "username": + user.Username = v.(string) + case "pass", "password": + user.Password = v.(string) + case "permission", "permissions", "authorization": + pm, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("Expected user permissions to be a map/struct, got %+v", v) + } + permissions, err := parseUserPermissions(pm) + if err != nil { + return nil, err + } + user.Permissions = permissions + } + } + // Check to make sure we have at least username and password + if user.Username == "" || user.Password == "" { + return nil, fmt.Errorf("User entry requires a user and a password") + } + users = append(users, user) + } + return users, nil +} + +// Helper function to parse user/account permissions +func parseUserPermissions(pm map[string]interface{}) (*Permissions, error) { + p := &Permissions{} + for k, v := range pm { + switch strings.ToLower(k) { + // For routes: + // Import is Publish + // Export is Subscribe + case "pub", "publish", "import": + subjects, err := parseSubjects(v) + if err != nil { + return nil, err + } + p.Publish = subjects + case "sub", "subscribe", "export": + subjects, err := parseSubjects(v) + if err != nil { + return nil, err + } + p.Subscribe = subjects + default: + return nil, fmt.Errorf("Unknown field %s parsing permissions", k) + } + } + return p, nil +} + +// Helper function to parse subject singeltons and/or arrays +func parseSubjects(v interface{}) ([]string, error) { + var subjects []string + switch v.(type) { + case string: + subjects = append(subjects, v.(string)) + case []string: + subjects = v.([]string) + case []interface{}: + for _, i := range v.([]interface{}) { + subject, ok := i.(string) + if !ok { + return nil, fmt.Errorf("Subject in permissions array cannot be cast to string") + } + subjects = append(subjects, subject) + } + default: + return nil, fmt.Errorf("Expected subject permissions to be a subject, or array of subjects, got %T", v) + } + return checkSubjectArray(subjects) +} + +// Helper function to validate subjects, etc for account permissioning. +func checkSubjectArray(sa []string) ([]string, error) { + for _, s := range sa { + if !IsValidSubject(s) { + return nil, fmt.Errorf("Subject %q is not a valid subject", s) + } + } + return sa, nil +} + +// PrintTLSHelpAndDie prints TLS usage and exits. +func PrintTLSHelpAndDie() { + fmt.Printf("%s", tlsUsage) + for k := range cipherMap { + fmt.Printf(" %s\n", k) + } + fmt.Printf("\nAvailable curve preferences include:\n") + for k := range curvePreferenceMap { + fmt.Printf(" %s\n", k) + } + os.Exit(0) +} + +func parseCipher(cipherName string) (uint16, error) { + + cipher, exists := cipherMap[cipherName] + if !exists { + return 0, fmt.Errorf("Unrecognized cipher %s", cipherName) + } + + return cipher, nil +} + +func parseCurvePreferences(curveName string) (tls.CurveID, error) { + curve, exists := curvePreferenceMap[curveName] + if !exists { + return 0, fmt.Errorf("Unrecognized curve preference %s", curveName) + } + return curve, nil +} + +// Helper function to parse TLS configs. +func parseTLS(tlsm map[string]interface{}) (*TLSConfigOpts, error) { + tc := TLSConfigOpts{} + for mk, mv := range tlsm { + switch strings.ToLower(mk) { + case "cert_file": + certFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'cert_file' to be filename") + } + tc.CertFile = certFile + case "key_file": + keyFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'key_file' to be filename") + } + tc.KeyFile = keyFile + case "ca_file": + caFile, ok := mv.(string) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'ca_file' to be filename") + } + tc.CaFile = caFile + case "verify": + verify, ok := mv.(bool) + if !ok { + return nil, fmt.Errorf("error parsing tls config, expected 'verify' to be a boolean") + } + tc.Verify = verify + case "cipher_suites": + ra := mv.([]interface{}) + if len(ra) == 0 { + return nil, fmt.Errorf("error parsing tls config, 'cipher_suites' cannot be empty") + } + tc.Ciphers = make([]uint16, 0, len(ra)) + for _, r := range ra { + cipher, err := parseCipher(r.(string)) + if err != nil { + return nil, err + } + tc.Ciphers = append(tc.Ciphers, cipher) + } + case "curve_preferences": + ra := mv.([]interface{}) + if len(ra) == 0 { + return nil, fmt.Errorf("error parsing tls config, 'curve_preferences' cannot be empty") + } + tc.CurvePreferences = make([]tls.CurveID, 0, len(ra)) + for _, r := range ra { + cps, err := parseCurvePreferences(r.(string)) + if err != nil { + return nil, err + } + tc.CurvePreferences = append(tc.CurvePreferences, cps) + } + case "timeout": + at := float64(0) + switch mv.(type) { + case int64: + at = float64(mv.(int64)) + case float64: + at = mv.(float64) + } + tc.Timeout = at + default: + return nil, fmt.Errorf("error parsing tls config, unknown field [%q]", mk) + } + } + + // If cipher suites were not specified then use the defaults + if tc.Ciphers == nil { + tc.Ciphers = defaultCipherSuites() + } + + // If curve preferences were not specified, then use the defaults + if tc.CurvePreferences == nil { + tc.CurvePreferences = defaultCurvePreferences() + } + + return &tc, nil +} + +// GenTLSConfig loads TLS related configuration parameters. +func GenTLSConfig(tc *TLSConfigOpts) (*tls.Config, error) { + + // Now load in cert and private key + cert, err := tls.LoadX509KeyPair(tc.CertFile, tc.KeyFile) + if err != nil { + return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err) + } + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, fmt.Errorf("error parsing certificate: %v", err) + } + + // Create TLSConfig + // We will determine the cipher suites that we prefer. + config := tls.Config{ + CurvePreferences: tc.CurvePreferences, + Certificates: []tls.Certificate{cert}, + PreferServerCipherSuites: true, + MinVersion: tls.VersionTLS12, + CipherSuites: tc.Ciphers, + } + + // Require client certificates as needed + if tc.Verify { + config.ClientAuth = tls.RequireAndVerifyClientCert + } + // Add in CAs if applicable. + if tc.CaFile != "" { + rootPEM, err := ioutil.ReadFile(tc.CaFile) + if err != nil || rootPEM == nil { + return nil, err + } + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(rootPEM) + if !ok { + return nil, fmt.Errorf("failed to parse root ca certificate") + } + config.ClientCAs = pool + } + + return &config, nil +} + +// MergeOptions will merge two options giving preference to the flagOpts +// if the item is present. +func MergeOptions(fileOpts, flagOpts *Options) *Options { + if fileOpts == nil { + return flagOpts + } + if flagOpts == nil { + return fileOpts + } + // Merge the two, flagOpts override + opts := *fileOpts + + if flagOpts.Port != 0 { + opts.Port = flagOpts.Port + } + if flagOpts.Host != "" { + opts.Host = flagOpts.Host + } + if flagOpts.ClientAdvertise != "" { + opts.ClientAdvertise = flagOpts.ClientAdvertise + } + if flagOpts.Username != "" { + opts.Username = flagOpts.Username + } + if flagOpts.Password != "" { + opts.Password = flagOpts.Password + } + if flagOpts.Authorization != "" { + opts.Authorization = flagOpts.Authorization + } + if flagOpts.HTTPPort != 0 { + opts.HTTPPort = flagOpts.HTTPPort + } + if flagOpts.Debug { + opts.Debug = true + } + if flagOpts.Trace { + opts.Trace = true + } + if flagOpts.Logtime { + opts.Logtime = true + } + if flagOpts.LogFile != "" { + opts.LogFile = flagOpts.LogFile + } + if flagOpts.PidFile != "" { + opts.PidFile = flagOpts.PidFile + } + if flagOpts.PortsFileDir != "" { + opts.PortsFileDir = flagOpts.PortsFileDir + } + if flagOpts.ProfPort != 0 { + opts.ProfPort = flagOpts.ProfPort + } + if flagOpts.Cluster.ListenStr != "" { + opts.Cluster.ListenStr = flagOpts.Cluster.ListenStr + } + if flagOpts.Cluster.NoAdvertise { + opts.Cluster.NoAdvertise = true + } + if flagOpts.Cluster.ConnectRetries != 0 { + opts.Cluster.ConnectRetries = flagOpts.Cluster.ConnectRetries + } + if flagOpts.Cluster.Advertise != "" { + opts.Cluster.Advertise = flagOpts.Cluster.Advertise + } + if flagOpts.RoutesStr != "" { + mergeRoutes(&opts, flagOpts) + } + return &opts +} + +// RoutesFromStr parses route URLs from a string +func RoutesFromStr(routesStr string) []*url.URL { + routes := strings.Split(routesStr, ",") + if len(routes) == 0 { + return nil + } + routeUrls := []*url.URL{} + for _, r := range routes { + r = strings.TrimSpace(r) + u, _ := url.Parse(r) + routeUrls = append(routeUrls, u) + } + return routeUrls +} + +// This will merge the flag routes and override anything that was present. +func mergeRoutes(opts, flagOpts *Options) { + routeUrls := RoutesFromStr(flagOpts.RoutesStr) + if routeUrls == nil { + return + } + opts.Routes = routeUrls + opts.RoutesStr = flagOpts.RoutesStr +} + +// RemoveSelfReference removes this server from an array of routes +func RemoveSelfReference(clusterPort int, routes []*url.URL) ([]*url.URL, error) { + var cleanRoutes []*url.URL + cport := strconv.Itoa(clusterPort) + + selfIPs, err := getInterfaceIPs() + if err != nil { + return nil, err + } + for _, r := range routes { + host, port, err := net.SplitHostPort(r.Host) + if err != nil { + return nil, err + } + + ipList, err := getURLIP(host) + if err != nil { + return nil, err + } + if cport == port && isIPInList(selfIPs, ipList) { + continue + } + cleanRoutes = append(cleanRoutes, r) + } + + return cleanRoutes, nil +} + +func isIPInList(list1 []net.IP, list2 []net.IP) bool { + for _, ip1 := range list1 { + for _, ip2 := range list2 { + if ip1.Equal(ip2) { + return true + } + } + } + return false +} + +func getURLIP(ipStr string) ([]net.IP, error) { + ipList := []net.IP{} + + ip := net.ParseIP(ipStr) + if ip != nil { + ipList = append(ipList, ip) + return ipList, nil + } + + hostAddr, err := net.LookupHost(ipStr) + if err != nil { + return nil, fmt.Errorf("Error looking up host with route hostname: %v", err) + } + for _, addr := range hostAddr { + ip = net.ParseIP(addr) + if ip != nil { + ipList = append(ipList, ip) + } + } + return ipList, nil +} + +func getInterfaceIPs() ([]net.IP, error) { + var localIPs []net.IP + + interfaceAddr, err := net.InterfaceAddrs() + if err != nil { + return nil, fmt.Errorf("Error getting self referencing address: %v", err) + } + + for i := 0; i < len(interfaceAddr); i++ { + interfaceIP, _, _ := net.ParseCIDR(interfaceAddr[i].String()) + if net.ParseIP(interfaceIP.String()) != nil { + localIPs = append(localIPs, interfaceIP) + } else { + return nil, fmt.Errorf("Error parsing self referencing address: %v", err) + } + } + return localIPs, nil +} + +func processOptions(opts *Options) { + // Setup non-standard Go defaults + if opts.Host == "" { + opts.Host = DEFAULT_HOST + } + if opts.HTTPHost == "" { + // Default to same bind from server if left undefined + opts.HTTPHost = opts.Host + } + if opts.Port == 0 { + opts.Port = DEFAULT_PORT + } else if opts.Port == RANDOM_PORT { + // Choose randomly inside of net.Listen + opts.Port = 0 + } + if opts.MaxConn == 0 { + opts.MaxConn = DEFAULT_MAX_CONNECTIONS + } + if opts.PingInterval == 0 { + opts.PingInterval = DEFAULT_PING_INTERVAL + } + if opts.MaxPingsOut == 0 { + opts.MaxPingsOut = DEFAULT_PING_MAX_OUT + } + if opts.TLSTimeout == 0 { + opts.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + if opts.AuthTimeout == 0 { + opts.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) + } + if opts.Cluster.Port != 0 { + if opts.Cluster.Host == "" { + opts.Cluster.Host = DEFAULT_HOST + } + if opts.Cluster.TLSTimeout == 0 { + opts.Cluster.TLSTimeout = float64(TLS_TIMEOUT) / float64(time.Second) + } + if opts.Cluster.AuthTimeout == 0 { + opts.Cluster.AuthTimeout = float64(AUTH_TIMEOUT) / float64(time.Second) + } + } + if opts.MaxControlLine == 0 { + opts.MaxControlLine = MAX_CONTROL_LINE_SIZE + } + if opts.MaxPayload == 0 { + opts.MaxPayload = MAX_PAYLOAD_SIZE + } + if opts.MaxPending == 0 { + opts.MaxPending = MAX_PENDING_SIZE + } + if opts.WriteDeadline == time.Duration(0) { + opts.WriteDeadline = DEFAULT_FLUSH_DEADLINE + } + if opts.RQSubsSweep == time.Duration(0) { + opts.RQSubsSweep = DEFAULT_REMOTE_QSUBS_SWEEPER + } + if opts.MaxClosedClients == 0 { + opts.MaxClosedClients = DEFAULT_MAX_CLOSED_CLIENTS + } +} + +// ConfigureOptions accepts a flag set and augment it with NATS Server +// specific flags. On success, an options structure is returned configured +// based on the selected flags and/or configuration file. +// The command line options take precedence to the ones in the configuration file. +func ConfigureOptions(fs *flag.FlagSet, args []string, printVersion, printHelp, printTLSHelp func()) (*Options, error) { + opts := &Options{} + var ( + showVersion bool + showHelp bool + showTLSHelp bool + signal string + configFile string + err error + ) + + fs.BoolVar(&showHelp, "h", false, "Show this message.") + fs.BoolVar(&showHelp, "help", false, "Show this message.") + fs.IntVar(&opts.Port, "port", 0, "Port to listen on.") + fs.IntVar(&opts.Port, "p", 0, "Port to listen on.") + fs.StringVar(&opts.Host, "addr", "", "Network host to listen on.") + fs.StringVar(&opts.Host, "a", "", "Network host to listen on.") + fs.StringVar(&opts.Host, "net", "", "Network host to listen on.") + fs.StringVar(&opts.ClientAdvertise, "client_advertise", "", "Client URL to advertise to other servers.") + fs.BoolVar(&opts.Debug, "D", false, "Enable Debug logging.") + fs.BoolVar(&opts.Debug, "debug", false, "Enable Debug logging.") + fs.BoolVar(&opts.Trace, "V", false, "Enable Trace logging.") + fs.BoolVar(&opts.Trace, "trace", false, "Enable Trace logging.") + fs.Bool("DV", false, "Enable Debug and Trace logging.") + fs.BoolVar(&opts.Logtime, "T", true, "Timestamp log entries.") + fs.BoolVar(&opts.Logtime, "logtime", true, "Timestamp log entries.") + fs.StringVar(&opts.Username, "user", "", "Username required for connection.") + fs.StringVar(&opts.Password, "pass", "", "Password required for connection.") + fs.StringVar(&opts.Authorization, "auth", "", "Authorization token required for connection.") + fs.IntVar(&opts.HTTPPort, "m", 0, "HTTP Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPPort, "http_port", 0, "HTTP Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPSPort, "ms", 0, "HTTPS Port for /varz, /connz endpoints.") + fs.IntVar(&opts.HTTPSPort, "https_port", 0, "HTTPS Port for /varz, /connz endpoints.") + fs.StringVar(&configFile, "c", "", "Configuration file.") + fs.StringVar(&configFile, "config", "", "Configuration file.") + fs.StringVar(&signal, "sl", "", "Send signal to gnatsd process (stop, quit, reopen, reload)") + fs.StringVar(&signal, "signal", "", "Send signal to gnatsd process (stop, quit, reopen, reload)") + fs.StringVar(&opts.PidFile, "P", "", "File to store process pid.") + fs.StringVar(&opts.PidFile, "pid", "", "File to store process pid.") + fs.StringVar(&opts.PortsFileDir, "ports_file_dir", "", "Creates a ports file in the specified directory (_.ports)") + fs.StringVar(&opts.LogFile, "l", "", "File to store logging output.") + fs.StringVar(&opts.LogFile, "log", "", "File to store logging output.") + fs.BoolVar(&opts.Syslog, "s", false, "Enable syslog as log method.") + fs.BoolVar(&opts.Syslog, "syslog", false, "Enable syslog as log method..") + fs.StringVar(&opts.RemoteSyslog, "r", "", "Syslog server addr (udp://127.0.0.1:514).") + fs.StringVar(&opts.RemoteSyslog, "remote_syslog", "", "Syslog server addr (udp://127.0.0.1:514).") + fs.BoolVar(&showVersion, "version", false, "Print version information.") + fs.BoolVar(&showVersion, "v", false, "Print version information.") + fs.IntVar(&opts.ProfPort, "profile", 0, "Profiling HTTP port") + fs.StringVar(&opts.RoutesStr, "routes", "", "Routes to actively solicit a connection.") + fs.StringVar(&opts.Cluster.ListenStr, "cluster", "", "Cluster url from which members can solicit routes.") + fs.StringVar(&opts.Cluster.ListenStr, "cluster_listen", "", "Cluster url from which members can solicit routes.") + fs.StringVar(&opts.Cluster.Advertise, "cluster_advertise", "", "Cluster URL to advertise to other servers.") + fs.BoolVar(&opts.Cluster.NoAdvertise, "no_advertise", false, "Advertise known cluster IPs to clients.") + fs.IntVar(&opts.Cluster.ConnectRetries, "connect_retries", 0, "For implicit routes, number of connect retries") + fs.BoolVar(&showTLSHelp, "help_tls", false, "TLS help.") + fs.BoolVar(&opts.TLS, "tls", false, "Enable TLS.") + fs.BoolVar(&opts.TLSVerify, "tlsverify", false, "Enable TLS with client verification.") + fs.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file.") + fs.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate.") + fs.StringVar(&opts.TLSCaCert, "tlscacert", "", "Client certificate CA for verification.") + + // The flags definition above set "default" values to some of the options. + // Calling Parse() here will override the default options with any value + // specified from the command line. This is ok. We will then update the + // options with the content of the configuration file (if present), and then, + // call Parse() again to override the default+config with command line values. + // Calling Parse() before processing config file is necessary since configFile + // itself is a command line argument, and also Parse() is required in order + // to know if user wants simply to show "help" or "version", etc... + if err := fs.Parse(args); err != nil { + return nil, err + } + + if showVersion { + printVersion() + return nil, nil + } + + if showHelp { + printHelp() + return nil, nil + } + + if showTLSHelp { + printTLSHelp() + return nil, nil + } + + // Process args looking for non-flag options, + // 'version' and 'help' only for now + showVersion, showHelp, err = ProcessCommandLineArgs(fs) + if err != nil { + return nil, err + } else if showVersion { + printVersion() + return nil, nil + } else if showHelp { + printHelp() + return nil, nil + } + + // Snapshot flag options. + FlagSnapshot = opts.Clone() + + // Process signal control. + if signal != "" { + if err := processSignal(signal); err != nil { + return nil, err + } + } + + // Parse config if given + if configFile != "" { + // This will update the options with values from the config file. + if err := opts.ProcessConfigFile(configFile); err != nil { + return nil, err + } + // Call this again to override config file options with options from command line. + // Note: We don't need to check error here since if there was an error, it would + // have been caught the first time this function was called (after setting up the + // flags). + fs.Parse(args) + } + + // Special handling of some flags + var ( + flagErr error + tlsDisabled bool + tlsOverride bool + ) + fs.Visit(func(f *flag.Flag) { + // short-circuit if an error was encountered + if flagErr != nil { + return + } + if strings.HasPrefix(f.Name, "tls") { + if f.Name == "tls" { + if !opts.TLS { + // User has specified "-tls=false", we need to disable TLS + opts.TLSConfig = nil + tlsDisabled = true + tlsOverride = false + return + } + tlsOverride = true + } else if !tlsDisabled { + tlsOverride = true + } + } else { + switch f.Name { + case "DV": + // Check value to support -DV=false + boolValue, _ := strconv.ParseBool(f.Value.String()) + opts.Trace, opts.Debug = boolValue, boolValue + case "cluster", "cluster_listen": + // Override cluster config if explicitly set via flags. + flagErr = overrideCluster(opts) + case "routes": + // Keep in mind that the flag has updated opts.RoutesStr at this point. + if opts.RoutesStr == "" { + // Set routes array to nil since routes string is empty + opts.Routes = nil + return + } + routeUrls := RoutesFromStr(opts.RoutesStr) + opts.Routes = routeUrls + } + } + }) + if flagErr != nil { + return nil, flagErr + } + + // This will be true if some of the `-tls` params have been set and + // `-tls=false` has not been set. + if tlsOverride { + if err := overrideTLS(opts); err != nil { + return nil, err + } + } + + // If we don't have cluster defined in the configuration + // file and no cluster listen string override, but we do + // have a routes override, we need to report misconfiguration. + if opts.RoutesStr != "" && opts.Cluster.ListenStr == "" && opts.Cluster.Host == "" && opts.Cluster.Port == 0 { + return nil, errors.New("solicited routes require cluster capabilities, e.g. --cluster") + } + + return opts, nil +} + +// overrideTLS is called when at least "-tls=true" has been set. +func overrideTLS(opts *Options) error { + if opts.TLSCert == "" { + return errors.New("TLS Server certificate must be present and valid") + } + if opts.TLSKey == "" { + return errors.New("TLS Server private key must be present and valid") + } + + tc := TLSConfigOpts{} + tc.CertFile = opts.TLSCert + tc.KeyFile = opts.TLSKey + tc.CaFile = opts.TLSCaCert + tc.Verify = opts.TLSVerify + + var err error + opts.TLSConfig, err = GenTLSConfig(&tc) + return err +} + +// overrideCluster updates Options.Cluster if that flag "cluster" (or "cluster_listen") +// has explicitly be set in the command line. If it is set to empty string, it will +// clear the Cluster options. +func overrideCluster(opts *Options) error { + if opts.Cluster.ListenStr == "" { + // This one is enough to disable clustering. + opts.Cluster.Port = 0 + return nil + } + clusterURL, err := url.Parse(opts.Cluster.ListenStr) + if err != nil { + return err + } + h, p, err := net.SplitHostPort(clusterURL.Host) + if err != nil { + return err + } + opts.Cluster.Host = h + _, err = fmt.Sscan(p, &opts.Cluster.Port) + if err != nil { + return err + } + + if clusterURL.User != nil { + pass, hasPassword := clusterURL.User.Password() + if !hasPassword { + return errors.New("expected cluster password to be set") + } + opts.Cluster.Password = pass + + user := clusterURL.User.Username() + opts.Cluster.Username = user + } else { + // Since we override from flag and there is no user/pwd, make + // sure we clear what we may have gotten from config file. + opts.Cluster.Username = "" + opts.Cluster.Password = "" + } + + return nil +} + +func processSignal(signal string) error { + var ( + pid string + commandAndPid = strings.Split(signal, "=") + ) + if l := len(commandAndPid); l == 2 { + pid = commandAndPid[1] + } else if l > 2 { + return fmt.Errorf("invalid signal parameters: %v", commandAndPid[2:]) + } + if err := ProcessSignal(Command(commandAndPid[0]), pid); err != nil { + return err + } + os.Exit(0) + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/opts_test.go b/vendor/github.com/nats-io/gnatsd/server/opts_test.go new file mode 100644 index 00000000..9c391c5a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/opts_test.go @@ -0,0 +1,1037 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "flag" + "io/ioutil" + "net/url" + "os" + "reflect" + "strings" + "testing" + "time" +) + +func TestDefaultOptions(t *testing.T) { + golden := &Options{ + Host: DEFAULT_HOST, + Port: DEFAULT_PORT, + MaxConn: DEFAULT_MAX_CONNECTIONS, + HTTPHost: DEFAULT_HOST, + PingInterval: DEFAULT_PING_INTERVAL, + MaxPingsOut: DEFAULT_PING_MAX_OUT, + TLSTimeout: float64(TLS_TIMEOUT) / float64(time.Second), + AuthTimeout: float64(AUTH_TIMEOUT) / float64(time.Second), + MaxControlLine: MAX_CONTROL_LINE_SIZE, + MaxPayload: MAX_PAYLOAD_SIZE, + MaxPending: MAX_PENDING_SIZE, + WriteDeadline: DEFAULT_FLUSH_DEADLINE, + RQSubsSweep: DEFAULT_REMOTE_QSUBS_SWEEPER, + MaxClosedClients: DEFAULT_MAX_CLOSED_CLIENTS, + } + + opts := &Options{} + processOptions(opts) + + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Default Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } +} + +func TestOptions_RandomPort(t *testing.T) { + opts := &Options{Port: RANDOM_PORT} + processOptions(opts) + + if opts.Port != 0 { + t.Fatalf("Process of options should have resolved random port to "+ + "zero.\nexpected: %d\ngot: %d\n", 0, opts.Port) + } +} + +func TestConfigFile(t *testing.T) { + golden := &Options{ + ConfigFile: "./configs/test.conf", + Host: "127.0.0.1", + Port: 4242, + Username: "derek", + Password: "porkchop", + AuthTimeout: 1.0, + Debug: false, + Trace: true, + Logtime: false, + HTTPPort: 8222, + PidFile: "/tmp/gnatsd.pid", + ProfPort: 6543, + Syslog: true, + RemoteSyslog: "udp://foo.com:33", + MaxControlLine: 2048, + MaxPayload: 65536, + MaxConn: 100, + MaxSubs: 1000, + MaxPending: 10000000, + PingInterval: 60 * time.Second, + MaxPingsOut: 3, + WriteDeadline: 3 * time.Second, + } + + opts, err := ProcessConfigFile("./configs/test.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } +} + +func TestTLSConfigFile(t *testing.T) { + golden := &Options{ + ConfigFile: "./configs/tls.conf", + Host: "127.0.0.1", + Port: 4443, + Username: "derek", + Password: "foo", + AuthTimeout: 1.0, + TLSTimeout: 2.0, + } + opts, err := ProcessConfigFile("./configs/tls.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + tlsConfig := opts.TLSConfig + if tlsConfig == nil { + t.Fatal("Expected opts.TLSConfig to be non-nil") + } + opts.TLSConfig = nil + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + // Now check TLSConfig a bit more closely + // CipherSuites + ciphers := defaultCipherSuites() + if !reflect.DeepEqual(tlsConfig.CipherSuites, ciphers) { + t.Fatalf("Got incorrect cipher suite list: [%+v]", tlsConfig.CipherSuites) + } + if tlsConfig.MinVersion != tls.VersionTLS12 { + t.Fatalf("Expected MinVersion of 1.2 [%v], got [%v]", tls.VersionTLS12, tlsConfig.MinVersion) + } + if !tlsConfig.PreferServerCipherSuites { + t.Fatal("Expected PreferServerCipherSuites to be true") + } + // Verify hostname is correct in certificate + if len(tlsConfig.Certificates) != 1 { + t.Fatal("Expected 1 certificate") + } + cert := tlsConfig.Certificates[0].Leaf + if err := cert.VerifyHostname("127.0.0.1"); err != nil { + t.Fatalf("Could not verify hostname in certificate: %v\n", err) + } + + // Now test adding cipher suites. + opts, err = ProcessConfigFile("./configs/tls_ciphers.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + tlsConfig = opts.TLSConfig + if tlsConfig == nil { + t.Fatal("Expected opts.TLSConfig to be non-nil") + } + + // CipherSuites listed in the config - test all of them. + ciphers = []uint16{ + tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + } + + if !reflect.DeepEqual(tlsConfig.CipherSuites, ciphers) { + t.Fatalf("Got incorrect cipher suite list: [%+v]", tlsConfig.CipherSuites) + } + + // Test an unrecognized/bad cipher + if _, err := ProcessConfigFile("./configs/tls_bad_cipher.conf"); err == nil { + t.Fatal("Did not receive an error from a unrecognized cipher") + } + + // Test an empty cipher entry in a config file. + if _, err := ProcessConfigFile("./configs/tls_empty_cipher.conf"); err == nil { + t.Fatal("Did not receive an error from empty cipher_suites") + } + + // Test a curve preference from the config. + curves := []tls.CurveID{ + tls.CurveP256, + } + + // test on a file that will load the curve preference defaults + opts, err = ProcessConfigFile("./configs/tls_ciphers.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + if !reflect.DeepEqual(opts.TLSConfig.CurvePreferences, defaultCurvePreferences()) { + t.Fatalf("Got incorrect curve preference list: [%+v]", tlsConfig.CurvePreferences) + } + + // Test specifying a single curve preference + opts, err = ProcessConfigFile("./configs/tls_curve_prefs.conf") + if err != nil { + t.Fatal("Did not receive an error from a unrecognized cipher.") + } + + if !reflect.DeepEqual(opts.TLSConfig.CurvePreferences, curves) { + t.Fatalf("Got incorrect cipher suite list: [%+v]", tlsConfig.CurvePreferences) + } + + // Test an unrecognized/bad curve preference + if _, err := ProcessConfigFile("./configs/tls_bad_curve_prefs.conf"); err == nil { + t.Fatal("Did not receive an error from a unrecognized curve preference") + } + // Test an empty curve preference + if _, err := ProcessConfigFile("./configs/tls_empty_curve_prefs.conf"); err == nil { + t.Fatal("Did not receive an error from empty curve preferences") + } +} + +func TestMergeOverrides(t *testing.T) { + golden := &Options{ + ConfigFile: "./configs/test.conf", + Host: "127.0.0.1", + Port: 2222, + Username: "derek", + Password: "porkchop", + AuthTimeout: 1.0, + Debug: true, + Trace: true, + Logtime: false, + HTTPPort: DEFAULT_HTTP_PORT, + PidFile: "/tmp/gnatsd.pid", + ProfPort: 6789, + Syslog: true, + RemoteSyslog: "udp://foo.com:33", + MaxControlLine: 2048, + MaxPayload: 65536, + MaxConn: 100, + MaxSubs: 1000, + MaxPending: 10000000, + PingInterval: 60 * time.Second, + MaxPingsOut: 3, + Cluster: ClusterOpts{ + NoAdvertise: true, + ConnectRetries: 2, + }, + WriteDeadline: 3 * time.Second, + } + fopts, err := ProcessConfigFile("./configs/test.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + // Overrides via flags + opts := &Options{ + Port: 2222, + Password: "porkchop", + Debug: true, + HTTPPort: DEFAULT_HTTP_PORT, + ProfPort: 6789, + Cluster: ClusterOpts{ + NoAdvertise: true, + ConnectRetries: 2, + }, + } + merged := MergeOptions(fopts, opts) + + if !reflect.DeepEqual(golden, merged) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, merged) + } +} + +func TestRemoveSelfReference(t *testing.T) { + url1, _ := url.Parse("nats-route://user:password@10.4.5.6:4223") + url2, _ := url.Parse("nats-route://user:password@127.0.0.1:4223") + url3, _ := url.Parse("nats-route://user:password@127.0.0.1:4223") + + routes := []*url.URL{url1, url2, url3} + + newroutes, err := RemoveSelfReference(4223, routes) + if err != nil { + t.Fatalf("Error during RemoveSelfReference: %v", err) + } + + if len(newroutes) != 1 { + t.Fatalf("Wrong number of routes: %d", len(newroutes)) + } + + if newroutes[0] != routes[0] { + t.Fatalf("Self reference IP address %s in Routes", routes[0]) + } +} + +func TestAllowRouteWithDifferentPort(t *testing.T) { + url1, _ := url.Parse("nats-route://user:password@127.0.0.1:4224") + routes := []*url.URL{url1} + + newroutes, err := RemoveSelfReference(4223, routes) + if err != nil { + t.Fatalf("Error during RemoveSelfReference: %v", err) + } + + if len(newroutes) != 1 { + t.Fatalf("Wrong number of routes: %d", len(newroutes)) + } +} + +func TestRouteFlagOverride(t *testing.T) { + routeFlag := "nats-route://ruser:top_secret@127.0.0.1:8246" + rurl, _ := url.Parse(routeFlag) + + golden := &Options{ + ConfigFile: "./configs/srv_a.conf", + Host: "127.0.0.1", + Port: 7222, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: 7244, + Username: "ruser", + Password: "top_secret", + AuthTimeout: 0.5, + }, + Routes: []*url.URL{rurl}, + RoutesStr: routeFlag, + } + + fopts, err := ProcessConfigFile("./configs/srv_a.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + // Overrides via flags + opts := &Options{ + RoutesStr: routeFlag, + } + merged := MergeOptions(fopts, opts) + + if !reflect.DeepEqual(golden, merged) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, merged) + } +} + +func TestClusterFlagsOverride(t *testing.T) { + routeFlag := "nats-route://ruser:top_secret@127.0.0.1:7246" + rurl, _ := url.Parse(routeFlag) + + // In this test, we override the cluster listen string. Note that in + // the golden options, the cluster other infos correspond to what + // is recovered from the configuration file, this explains the + // discrepency between ClusterListenStr and the rest. + // The server would then process the ClusterListenStr override and + // correctly override ClusterHost/ClustherPort/etc.. + golden := &Options{ + ConfigFile: "./configs/srv_a.conf", + Host: "127.0.0.1", + Port: 7222, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: 7244, + ListenStr: "nats://127.0.0.1:8224", + Username: "ruser", + Password: "top_secret", + AuthTimeout: 0.5, + }, + Routes: []*url.URL{rurl}, + } + + fopts, err := ProcessConfigFile("./configs/srv_a.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + // Overrides via flags + opts := &Options{ + Cluster: ClusterOpts{ + ListenStr: "nats://127.0.0.1:8224", + }, + } + merged := MergeOptions(fopts, opts) + + if !reflect.DeepEqual(golden, merged) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, merged) + } +} + +func TestRouteFlagOverrideWithMultiple(t *testing.T) { + routeFlag := "nats-route://ruser:top_secret@127.0.0.1:8246, nats-route://ruser:top_secret@127.0.0.1:8266" + rurls := RoutesFromStr(routeFlag) + + golden := &Options{ + ConfigFile: "./configs/srv_a.conf", + Host: "127.0.0.1", + Port: 7222, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: 7244, + Username: "ruser", + Password: "top_secret", + AuthTimeout: 0.5, + }, + Routes: rurls, + RoutesStr: routeFlag, + } + + fopts, err := ProcessConfigFile("./configs/srv_a.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + + // Overrides via flags + opts := &Options{ + RoutesStr: routeFlag, + } + merged := MergeOptions(fopts, opts) + + if !reflect.DeepEqual(golden, merged) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, merged) + } +} + +func TestDynamicPortOnListen(t *testing.T) { + opts, err := ProcessConfigFile("./configs/listen-1.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + if opts.Port != -1 { + t.Fatalf("Received incorrect port %v, expected -1\n", opts.Port) + } + if opts.HTTPPort != -1 { + t.Fatalf("Received incorrect monitoring port %v, expected -1\n", opts.HTTPPort) + } + if opts.HTTPSPort != -1 { + t.Fatalf("Received incorrect secure monitoring port %v, expected -1\n", opts.HTTPSPort) + } +} + +func TestListenConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/listen.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + processOptions(opts) + + // Normal clients + host := "10.0.1.22" + port := 4422 + monHost := "127.0.0.1" + if opts.Host != host { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, host) + } + if opts.HTTPHost != monHost { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.HTTPHost, monHost) + } + if opts.Port != port { + t.Fatalf("Received incorrect port %v, expected %v\n", opts.Port, port) + } + + // Clustering + clusterHost := "127.0.0.1" + clusterPort := 4244 + + if opts.Cluster.Host != clusterHost { + t.Fatalf("Received incorrect cluster host %q, expected %q\n", opts.Cluster.Host, clusterHost) + } + if opts.Cluster.Port != clusterPort { + t.Fatalf("Received incorrect cluster port %v, expected %v\n", opts.Cluster.Port, clusterPort) + } + + // HTTP + httpHost := "127.0.0.1" + httpPort := 8422 + + if opts.HTTPHost != httpHost { + t.Fatalf("Received incorrect http host %q, expected %q\n", opts.HTTPHost, httpHost) + } + if opts.HTTPPort != httpPort { + t.Fatalf("Received incorrect http port %v, expected %v\n", opts.HTTPPort, httpPort) + } + + // HTTPS + httpsPort := 9443 + if opts.HTTPSPort != httpsPort { + t.Fatalf("Received incorrect https port %v, expected %v\n", opts.HTTPSPort, httpsPort) + } +} + +func TestListenPortOnlyConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/listen_port.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + processOptions(opts) + + port := 8922 + + if opts.Host != DEFAULT_HOST { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, DEFAULT_HOST) + } + if opts.HTTPHost != DEFAULT_HOST { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, DEFAULT_HOST) + } + if opts.Port != port { + t.Fatalf("Received incorrect port %v, expected %v\n", opts.Port, port) + } +} + +func TestListenPortWithColonConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/listen_port_with_colon.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + processOptions(opts) + + port := 8922 + + if opts.Host != DEFAULT_HOST { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, DEFAULT_HOST) + } + if opts.HTTPHost != DEFAULT_HOST { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, DEFAULT_HOST) + } + if opts.Port != port { + t.Fatalf("Received incorrect port %v, expected %v\n", opts.Port, port) + } +} + +func TestListenMonitoringDefault(t *testing.T) { + opts := &Options{ + Host: "10.0.1.22", + } + processOptions(opts) + + host := "10.0.1.22" + if opts.Host != host { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, host) + } + if opts.HTTPHost != host { + t.Fatalf("Received incorrect host %q, expected %q\n", opts.Host, host) + } + if opts.Port != DEFAULT_PORT { + t.Fatalf("Received incorrect port %v, expected %v\n", opts.Port, DEFAULT_PORT) + } +} + +func TestMultipleUsersConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/multiple_users.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + processOptions(opts) +} + +// Test highly depends on contents of the config file listed below. Any changes to that file +// may very well break this test. +func TestAuthorizationConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/authorization.conf") + if err != nil { + t.Fatalf("Received an error reading config file: %v\n", err) + } + processOptions(opts) + lu := len(opts.Users) + if lu != 3 { + t.Fatalf("Expected 3 users, got %d\n", lu) + } + // Build a map + mu := make(map[string]*User) + for _, u := range opts.Users { + mu[u.Username] = u + } + + // Alice + alice, ok := mu["alice"] + if !ok { + t.Fatalf("Expected to see user Alice\n") + } + // Check for permissions details + if alice.Permissions == nil { + t.Fatalf("Expected Alice's permissions to be non-nil\n") + } + if alice.Permissions.Publish == nil { + t.Fatalf("Expected Alice's publish permissions to be non-nil\n") + } + if len(alice.Permissions.Publish) != 1 { + t.Fatalf("Expected Alice's publish permissions to have 1 element, got %d\n", + len(alice.Permissions.Publish)) + } + pubPerm := alice.Permissions.Publish[0] + if pubPerm != "*" { + t.Fatalf("Expected Alice's publish permissions to be '*', got %q\n", pubPerm) + } + if alice.Permissions.Subscribe == nil { + t.Fatalf("Expected Alice's subscribe permissions to be non-nil\n") + } + if len(alice.Permissions.Subscribe) != 1 { + t.Fatalf("Expected Alice's subscribe permissions to have 1 element, got %d\n", + len(alice.Permissions.Subscribe)) + } + subPerm := alice.Permissions.Subscribe[0] + if subPerm != ">" { + t.Fatalf("Expected Alice's subscribe permissions to be '>', got %q\n", subPerm) + } + + // Bob + bob, ok := mu["bob"] + if !ok { + t.Fatalf("Expected to see user Bob\n") + } + if bob.Permissions == nil { + t.Fatalf("Expected Bob's permissions to be non-nil\n") + } + + // Susan + susan, ok := mu["susan"] + if !ok { + t.Fatalf("Expected to see user Susan\n") + } + if susan.Permissions == nil { + t.Fatalf("Expected Susan's permissions to be non-nil\n") + } + // Check susan closely since she inherited the default permissions. + if susan.Permissions == nil { + t.Fatalf("Expected Susan's permissions to be non-nil\n") + } + if susan.Permissions.Publish != nil { + t.Fatalf("Expected Susan's publish permissions to be nil\n") + } + if susan.Permissions.Subscribe == nil { + t.Fatalf("Expected Susan's subscribe permissions to be non-nil\n") + } + if len(susan.Permissions.Subscribe) != 1 { + t.Fatalf("Expected Susan's subscribe permissions to have 1 element, got %d\n", + len(susan.Permissions.Subscribe)) + } + subPerm = susan.Permissions.Subscribe[0] + if subPerm != "PUBLIC.>" { + t.Fatalf("Expected Susan's subscribe permissions to be 'PUBLIC.>', got %q\n", subPerm) + } +} + +func TestTokenWithUserPass(t *testing.T) { + confFileName := "test.conf" + defer os.Remove(confFileName) + content := ` + authorization={ + user: user + pass: password + token: $2a$11$whatever + }` + if err := ioutil.WriteFile(confFileName, []byte(content), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + _, err := ProcessConfigFile(confFileName) + if err == nil { + t.Fatal("Expected error, got none") + } + if !strings.Contains(err.Error(), "token") { + t.Fatalf("Expected error related to token, got %v", err) + } +} + +func TestTokenWithUsers(t *testing.T) { + confFileName := "test.conf" + defer os.Remove(confFileName) + content := ` + authorization={ + token: $2a$11$whatever + users: [ + {user: test, password: test} + ] + }` + if err := ioutil.WriteFile(confFileName, []byte(content), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + _, err := ProcessConfigFile(confFileName) + if err == nil { + t.Fatal("Expected error, got none") + } + if !strings.Contains(err.Error(), "token") { + t.Fatalf("Expected error related to token, got %v", err) + } +} + +func TestParseWriteDeadline(t *testing.T) { + confFile := "test.conf" + defer os.Remove(confFile) + if err := ioutil.WriteFile(confFile, []byte("write_deadline: \"1x\"\n"), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + _, err := ProcessConfigFile(confFile) + if err == nil { + t.Fatal("Expected error, got none") + } + if !strings.Contains(err.Error(), "parsing") { + t.Fatalf("Expected error related to parsing, got %v", err) + } + os.Remove(confFile) + if err := ioutil.WriteFile(confFile, []byte("write_deadline: \"1s\"\n"), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + opts, err := ProcessConfigFile(confFile) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if opts.WriteDeadline != time.Second { + t.Fatalf("Expected write_deadline to be 1s, got %v", opts.WriteDeadline) + } + os.Remove(confFile) + oldStdout := os.Stdout + _, w, _ := os.Pipe() + defer func() { + w.Close() + os.Stdout = oldStdout + }() + os.Stdout = w + if err := ioutil.WriteFile(confFile, []byte("write_deadline: 2\n"), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + opts, err = ProcessConfigFile(confFile) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if opts.WriteDeadline != 2*time.Second { + t.Fatalf("Expected write_deadline to be 2s, got %v", opts.WriteDeadline) + } +} + +func TestOptionsClone(t *testing.T) { + opts := &Options{ + ConfigFile: "./configs/test.conf", + Host: "127.0.0.1", + Port: 2222, + Username: "derek", + Password: "porkchop", + AuthTimeout: 1.0, + Debug: true, + Trace: true, + Logtime: false, + HTTPPort: DEFAULT_HTTP_PORT, + PidFile: "/tmp/gnatsd.pid", + ProfPort: 6789, + Syslog: true, + RemoteSyslog: "udp://foo.com:33", + MaxControlLine: 2048, + MaxPayload: 65536, + MaxConn: 100, + PingInterval: 60 * time.Second, + MaxPingsOut: 3, + Cluster: ClusterOpts{ + NoAdvertise: true, + ConnectRetries: 2, + }, + WriteDeadline: 3 * time.Second, + Routes: []*url.URL{&url.URL{}}, + Users: []*User{&User{Username: "foo", Password: "bar"}}, + } + + clone := opts.Clone() + + if !reflect.DeepEqual(opts, clone) { + t.Fatalf("Cloned Options are incorrect.\nexpected: %+v\ngot: %+v", + clone, opts) + } + + clone.Users[0].Password = "baz" + if reflect.DeepEqual(opts, clone) { + t.Fatal("Expected Options to be different") + } +} + +func TestOptionsCloneNilLists(t *testing.T) { + opts := &Options{} + + clone := opts.Clone() + + if clone.Routes != nil { + t.Fatalf("Expected Routes to be nil, got: %v", clone.Routes) + } + if clone.Users != nil { + t.Fatalf("Expected Users to be nil, got: %v", clone.Users) + } +} + +func TestOptionsCloneNil(t *testing.T) { + opts := (*Options)(nil) + clone := opts.Clone() + if clone != nil { + t.Fatalf("Expected nil, got: %+v", clone) + } +} + +func TestEmptyConfig(t *testing.T) { + opts, err := ProcessConfigFile("") + + if err != nil { + t.Fatalf("Expected no error from empty config, got: %+v", err) + } + + if opts.ConfigFile != "" { + t.Fatalf("Expected empty config, got: %+v", opts) + } +} + +func TestMalformedListenAddress(t *testing.T) { + opts, err := ProcessConfigFile("./configs/malformed_listen_address.conf") + if err == nil { + t.Fatalf("Expected an error reading config file: got %+v\n", opts) + } +} + +func TestMalformedClusterAddress(t *testing.T) { + opts, err := ProcessConfigFile("./configs/malformed_cluster_address.conf") + if err == nil { + t.Fatalf("Expected an error reading config file: got %+v\n", opts) + } +} + +func TestOptionsProcessConfigFile(t *testing.T) { + // Create options with default values of Debug and Trace + // that are the opposite of what is in the config file. + // Set another option that is not present in the config file. + logFileName := "test.log" + opts := &Options{ + Debug: true, + Trace: false, + LogFile: logFileName, + } + configFileName := "./configs/test.conf" + if err := opts.ProcessConfigFile(configFileName); err != nil { + t.Fatalf("Error processing config file: %v", err) + } + // Verify that values are as expected + if opts.ConfigFile != configFileName { + t.Fatalf("Expected ConfigFile to be set to %q, got %v", configFileName, opts.ConfigFile) + } + if opts.Debug { + t.Fatal("Debug option should have been set to false from config file") + } + if !opts.Trace { + t.Fatal("Trace option should have been set to true from config file") + } + if opts.LogFile != logFileName { + t.Fatalf("Expected LogFile to be %q, got %q", logFileName, opts.LogFile) + } +} + +func TestConfigureOptions(t *testing.T) { + // Options.Configure() will snapshot the flags. This is used by the reload code. + // We need to set it back to nil otherwise it will impact reload tests. + defer func() { FlagSnapshot = nil }() + + ch := make(chan bool, 1) + checkPrintInvoked := func() { + ch <- true + } + usage := func() { panic("should not get there") } + var fs *flag.FlagSet + type testPrint struct { + args []string + version, help, tlsHelp func() + } + testFuncs := []testPrint{ + testPrint{[]string{"-v"}, checkPrintInvoked, usage, PrintTLSHelpAndDie}, + testPrint{[]string{"version"}, checkPrintInvoked, usage, PrintTLSHelpAndDie}, + testPrint{[]string{"-h"}, PrintServerAndExit, checkPrintInvoked, PrintTLSHelpAndDie}, + testPrint{[]string{"help"}, PrintServerAndExit, checkPrintInvoked, PrintTLSHelpAndDie}, + testPrint{[]string{"-help_tls"}, PrintServerAndExit, usage, checkPrintInvoked}, + } + for _, tf := range testFuncs { + fs = flag.NewFlagSet("test", flag.ContinueOnError) + opts, err := ConfigureOptions(fs, tf.args, tf.version, tf.help, tf.tlsHelp) + if err != nil { + t.Fatalf("Error on configure: %v", err) + } + if opts != nil { + t.Fatalf("Expected options to be nil, got %v", opts) + } + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("Should have invoked print function for args=%v", tf.args) + } + } + + // Helper function that expect parsing with given args to not produce an error. + mustNotFail := func(args []string) *Options { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + opts, err := ConfigureOptions(fs, args, PrintServerAndExit, fs.Usage, PrintTLSHelpAndDie) + if err != nil { + stackFatalf(t, "Error on configure: %v", err) + } + return opts + } + + // Helper function that expect configuration to fail. + expectToFail := func(args []string, errContent ...string) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + // Silence the flagSet so that on failure nothing is printed. + // (flagSet would print error message about unknown flags, etc..) + silenceOuput := &bytes.Buffer{} + fs.SetOutput(silenceOuput) + opts, err := ConfigureOptions(fs, args, PrintServerAndExit, fs.Usage, PrintTLSHelpAndDie) + if opts != nil || err == nil { + stackFatalf(t, "Expected no option and an error, got opts=%v and err=%v", opts, err) + } + for _, testErr := range errContent { + if strings.Contains(err.Error(), testErr) { + // We got the error we wanted. + return + } + } + stackFatalf(t, "Expected errors containing any of those %v, got %v", errContent, err) + } + + // Basic test with port number + opts := mustNotFail([]string{"-p", "1234"}) + if opts.Port != 1234 { + t.Fatalf("Expected port to be 1234, got %v", opts.Port) + } + + // Should fail because of unknown parameter + expectToFail([]string{"foo"}, "command") + + // Should fail because unknown flag + expectToFail([]string{"-xxx", "foo"}, "flag") + + // Should fail because of config file missing + expectToFail([]string{"-c", "xxx.cfg"}, "file") + + // Should fail because of too many args for signal command + expectToFail([]string{"-sl", "quit=pid=foo"}, "signal") + + // Should fail because of invalid pid + // On windows, if not running with admin privileges, you would get access denied. + expectToFail([]string{"-sl", "quit=pid"}, "pid", "denied") + + // The config file set Trace to true. + opts = mustNotFail([]string{"-c", "./configs/test.conf"}) + if !opts.Trace { + t.Fatal("Trace should have been set to true") + } + + // The config file set Trace to true, but was overridden by param -V=false + opts = mustNotFail([]string{"-c", "./configs/test.conf", "-V=false"}) + if opts.Trace { + t.Fatal("Trace should have been set to false") + } + + // The config file set Trace to true, but was overridden by param -DV=false + opts = mustNotFail([]string{"-c", "./configs/test.conf", "-DV=false"}) + if opts.Debug || opts.Trace { + t.Fatal("Debug and Trace should have been set to false") + } + + // The config file set Trace to true, but was overridden by param -DV + opts = mustNotFail([]string{"-c", "./configs/test.conf", "-DV"}) + if !opts.Debug || !opts.Trace { + t.Fatal("Debug and Trace should have been set to true") + } + + // This should fail since -cluster is missing + expectedURL, _ := url.Parse("nats://127.0.0.1:6223") + expectToFail([]string{"-routes", expectedURL.String()}, "solicited routes") + + // Ensure that we can set cluster and routes from command line + opts = mustNotFail([]string{"-cluster", "nats://127.0.0.1:6222", "-routes", expectedURL.String()}) + if opts.Cluster.ListenStr != "nats://127.0.0.1:6222" { + t.Fatalf("Unexpected Cluster.ListenStr=%q", opts.Cluster.ListenStr) + } + if opts.RoutesStr != "nats://127.0.0.1:6223" || len(opts.Routes) != 1 || opts.Routes[0].String() != expectedURL.String() { + t.Fatalf("Unexpected RoutesStr: %q and Routes: %v", opts.RoutesStr, opts.Routes) + } + + // Use a config with cluster configuration and explicit route defined. + // Override with empty routes string. + opts = mustNotFail([]string{"-c", "./configs/srv_a.conf", "-routes", ""}) + if opts.RoutesStr != "" || len(opts.Routes) != 0 { + t.Fatalf("Unexpected RoutesStr: %q and Routes: %v", opts.RoutesStr, opts.Routes) + } + + // Use a config with cluster configuration and override cluster listen string + expectedURL, _ = url.Parse("nats-route://ruser:top_secret@127.0.0.1:7246") + opts = mustNotFail([]string{"-c", "./configs/srv_a.conf", "-cluster", "nats://ivan:pwd@127.0.0.1:6222"}) + if opts.Cluster.Username != "ivan" || opts.Cluster.Password != "pwd" || opts.Cluster.Port != 6222 || + len(opts.Routes) != 1 || opts.Routes[0].String() != expectedURL.String() { + t.Fatalf("Unexpected Cluster and/or Routes: %#v - %v", opts.Cluster, opts.Routes) + } + + // Disable clustering from command line + opts = mustNotFail([]string{"-c", "./configs/srv_a.conf", "-cluster", ""}) + if opts.Cluster.Port != 0 { + t.Fatalf("Unexpected Cluster: %v", opts.Cluster) + } + + // Various erros due to malformed cluster listen string. + // (adding -routes to have more than 1 set flag to check + // that Visit() stops when an error is found). + expectToFail([]string{"-cluster", ":", "-routes", ""}, "protocol") + expectToFail([]string{"-cluster", "nats://127.0.0.1", "-routes", ""}, "port") + expectToFail([]string{"-cluster", "nats://127.0.0.1:xxx", "-routes", ""}, "integer") + expectToFail([]string{"-cluster", "nats://ivan:127.0.0.1:6222", "-routes", ""}, "colons") + expectToFail([]string{"-cluster", "nats://ivan@127.0.0.1:6222", "-routes", ""}, "password") + + // Override config file's TLS configuration from command line, and completely disable TLS + opts = mustNotFail([]string{"-c", "./configs/tls.conf", "-tls=false"}) + if opts.TLSConfig != nil || opts.TLS { + t.Fatal("Expected TLS to be disabled") + } + // Override config file's TLS configuration from command line, and force TLS verification. + // However, since TLS config has to be regenerated, user need to provide -tlscert and -tlskey too. + // So this should fail. + expectToFail([]string{"-c", "./configs/tls.conf", "-tlsverify"}, "valid") + + // Now same than above, but with all valid params. + opts = mustNotFail([]string{"-c", "./configs/tls.conf", "-tlsverify", "-tlscert", "./configs/certs/server.pem", "-tlskey", "./configs/certs/key.pem"}) + if opts.TLSConfig == nil || !opts.TLSVerify { + t.Fatal("Expected TLS to be configured and force verification") + } + + // Configure TLS, but some TLS params missing + expectToFail([]string{"-tls"}, "valid") + expectToFail([]string{"-tls", "-tlscert", "./configs/certs/server.pem"}, "valid") + // One of the file does not exist + expectToFail([]string{"-tls", "-tlscert", "./configs/certs/server.pem", "-tlskey", "./configs/certs/notfound.pem"}, "file") + + // Configure TLS and check that this results in a TLSConfig option. + opts = mustNotFail([]string{"-tls", "-tlscert", "./configs/certs/server.pem", "-tlskey", "./configs/certs/key.pem"}) + if opts.TLSConfig == nil || !opts.TLS { + t.Fatal("Expected TLSConfig to be set") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/parser.go b/vendor/github.com/nats-io/gnatsd/server/parser.go new file mode 100644 index 00000000..088894fb --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/parser.go @@ -0,0 +1,749 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" +) + +type pubArg struct { + subject []byte + reply []byte + sid []byte + szb []byte + size int +} + +type parseState struct { + state int + as int + drop int + pa pubArg + argBuf []byte + msgBuf []byte + scratch [MAX_CONTROL_LINE_SIZE]byte +} + +// Parser constants +const ( + OP_START = iota + OP_PLUS + OP_PLUS_O + OP_PLUS_OK + OP_MINUS + OP_MINUS_E + OP_MINUS_ER + OP_MINUS_ERR + OP_MINUS_ERR_SPC + MINUS_ERR_ARG + OP_C + OP_CO + OP_CON + OP_CONN + OP_CONNE + OP_CONNEC + OP_CONNECT + CONNECT_ARG + OP_P + OP_PU + OP_PUB + OP_PUB_SPC + PUB_ARG + OP_PI + OP_PIN + OP_PING + OP_PO + OP_PON + OP_PONG + MSG_PAYLOAD + MSG_END + OP_S + OP_SU + OP_SUB + OP_SUB_SPC + SUB_ARG + OP_U + OP_UN + OP_UNS + OP_UNSU + OP_UNSUB + OP_UNSUB_SPC + UNSUB_ARG + OP_M + OP_MS + OP_MSG + OP_MSG_SPC + MSG_ARG + OP_I + OP_IN + OP_INF + OP_INFO + INFO_ARG +) + +func (c *client) parse(buf []byte) error { + var i int + var b byte + + mcl := MAX_CONTROL_LINE_SIZE + if c.srv != nil && c.srv.getOpts() != nil { + mcl = c.srv.getOpts().MaxControlLine + } + + // snapshot this, and reset when we receive a + // proper CONNECT if needed. + authSet := c.isAuthTimerSet() + + // Move to loop instead of range syntax to allow jumping of i + for i = 0; i < len(buf); i++ { + b = buf[i] + + switch c.state { + case OP_START: + if b != 'C' && b != 'c' && authSet { + goto authErr + } + switch b { + case 'P', 'p': + c.state = OP_P + case 'S', 's': + c.state = OP_S + case 'U', 'u': + c.state = OP_U + case 'M', 'm': + if c.typ == CLIENT { + goto parseErr + } else { + c.state = OP_M + } + case 'C', 'c': + c.state = OP_C + case 'I', 'i': + c.state = OP_I + case '+': + c.state = OP_PLUS + case '-': + c.state = OP_MINUS + default: + goto parseErr + } + case OP_P: + switch b { + case 'U', 'u': + c.state = OP_PU + case 'I', 'i': + c.state = OP_PI + case 'O', 'o': + c.state = OP_PO + default: + goto parseErr + } + case OP_PU: + switch b { + case 'B', 'b': + c.state = OP_PUB + default: + goto parseErr + } + case OP_PUB: + switch b { + case ' ', '\t': + c.state = OP_PUB_SPC + default: + goto parseErr + } + case OP_PUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = PUB_ARG + c.as = i + } + case PUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processPub(arg); err != nil { + return err + } + c.drop, c.as, c.state = OP_START, i+1, MSG_PAYLOAD + // If we don't have a saved buffer then jump ahead with + // the index. If this overruns what is left we fall out + // and process split buffer. + if c.msgBuf == nil { + i = c.as + c.pa.size - LEN_CR_LF + } + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case MSG_PAYLOAD: + if c.msgBuf != nil { + // copy as much as we can to the buffer and skip ahead. + toCopy := c.pa.size - len(c.msgBuf) + avail := len(buf) - i + if avail < toCopy { + toCopy = avail + } + if toCopy > 0 { + start := len(c.msgBuf) + // This is needed for copy to work. + c.msgBuf = c.msgBuf[:start+toCopy] + copy(c.msgBuf[start:], buf[i:i+toCopy]) + // Update our index + i = (i + toCopy) - 1 + } else { + // Fall back to append if needed. + c.msgBuf = append(c.msgBuf, b) + } + if len(c.msgBuf) >= c.pa.size { + c.state = MSG_END + } + } else if i-c.as >= c.pa.size { + c.state = MSG_END + } + case MSG_END: + switch b { + case '\n': + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } else { + c.msgBuf = buf[c.as : i+1] + } + // strict check for proto + if len(c.msgBuf) != c.pa.size+LEN_CR_LF { + goto parseErr + } + c.processMsg(c.msgBuf) + c.argBuf, c.msgBuf = nil, nil + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } + continue + } + case OP_S: + switch b { + case 'U', 'u': + c.state = OP_SU + default: + goto parseErr + } + case OP_SU: + switch b { + case 'B', 'b': + c.state = OP_SUB + default: + goto parseErr + } + case OP_SUB: + switch b { + case ' ', '\t': + c.state = OP_SUB_SPC + default: + goto parseErr + } + case OP_SUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = SUB_ARG + c.as = i + } + case SUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processSub(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_U: + switch b { + case 'N', 'n': + c.state = OP_UN + default: + goto parseErr + } + case OP_UN: + switch b { + case 'S', 's': + c.state = OP_UNS + default: + goto parseErr + } + case OP_UNS: + switch b { + case 'U', 'u': + c.state = OP_UNSU + default: + goto parseErr + } + case OP_UNSU: + switch b { + case 'B', 'b': + c.state = OP_UNSUB + default: + goto parseErr + } + case OP_UNSUB: + switch b { + case ' ', '\t': + c.state = OP_UNSUB_SPC + default: + goto parseErr + } + case OP_UNSUB_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = UNSUB_ARG + c.as = i + } + case UNSUB_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processUnsub(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_PI: + switch b { + case 'N', 'n': + c.state = OP_PIN + default: + goto parseErr + } + case OP_PIN: + switch b { + case 'G', 'g': + c.state = OP_PING + default: + goto parseErr + } + case OP_PING: + switch b { + case '\n': + c.processPing() + c.drop, c.state = 0, OP_START + } + case OP_PO: + switch b { + case 'N', 'n': + c.state = OP_PON + default: + goto parseErr + } + case OP_PON: + switch b { + case 'G', 'g': + c.state = OP_PONG + default: + goto parseErr + } + case OP_PONG: + switch b { + case '\n': + c.processPong() + c.drop, c.state = 0, OP_START + } + case OP_C: + switch b { + case 'O', 'o': + c.state = OP_CO + default: + goto parseErr + } + case OP_CO: + switch b { + case 'N', 'n': + c.state = OP_CON + default: + goto parseErr + } + case OP_CON: + switch b { + case 'N', 'n': + c.state = OP_CONN + default: + goto parseErr + } + case OP_CONN: + switch b { + case 'E', 'e': + c.state = OP_CONNE + default: + goto parseErr + } + case OP_CONNE: + switch b { + case 'C', 'c': + c.state = OP_CONNEC + default: + goto parseErr + } + case OP_CONNEC: + switch b { + case 'T', 't': + c.state = OP_CONNECT + default: + goto parseErr + } + case OP_CONNECT: + switch b { + case ' ', '\t': + continue + default: + c.state = CONNECT_ARG + c.as = i + } + case CONNECT_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processConnect(arg); err != nil { + return err + } + c.drop, c.state = 0, OP_START + // Reset notion on authSet + authSet = c.isAuthTimerSet() + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_M: + switch b { + case 'S', 's': + c.state = OP_MS + default: + goto parseErr + } + case OP_MS: + switch b { + case 'G', 'g': + c.state = OP_MSG + default: + goto parseErr + } + case OP_MSG: + switch b { + case ' ', '\t': + c.state = OP_MSG_SPC + default: + goto parseErr + } + case OP_MSG_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = MSG_ARG + c.as = i + } + case MSG_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processMsgArgs(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, MSG_PAYLOAD + + // jump ahead with the index. If this overruns + // what is left we fall out and process split + // buffer. + i = c.as + c.pa.size - 1 + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_I: + switch b { + case 'N', 'n': + c.state = OP_IN + default: + goto parseErr + } + case OP_IN: + switch b { + case 'F', 'f': + c.state = OP_INF + default: + goto parseErr + } + case OP_INF: + switch b { + case 'O', 'o': + c.state = OP_INFO + default: + goto parseErr + } + case OP_INFO: + switch b { + case ' ', '\t': + continue + default: + c.state = INFO_ARG + c.as = i + } + case INFO_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processInfo(arg); err != nil { + return err + } + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + case OP_PLUS: + switch b { + case 'O', 'o': + c.state = OP_PLUS_O + default: + goto parseErr + } + case OP_PLUS_O: + switch b { + case 'K', 'k': + c.state = OP_PLUS_OK + default: + goto parseErr + } + case OP_PLUS_OK: + switch b { + case '\n': + c.drop, c.state = 0, OP_START + } + case OP_MINUS: + switch b { + case 'E', 'e': + c.state = OP_MINUS_E + default: + goto parseErr + } + case OP_MINUS_E: + switch b { + case 'R', 'r': + c.state = OP_MINUS_ER + default: + goto parseErr + } + case OP_MINUS_ER: + switch b { + case 'R', 'r': + c.state = OP_MINUS_ERR + default: + goto parseErr + } + case OP_MINUS_ERR: + switch b { + case ' ', '\t': + c.state = OP_MINUS_ERR_SPC + default: + goto parseErr + } + case OP_MINUS_ERR_SPC: + switch b { + case ' ', '\t': + continue + default: + c.state = MINUS_ERR_ARG + c.as = i + } + case MINUS_ERR_ARG: + switch b { + case '\r': + c.drop = 1 + case '\n': + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + c.argBuf = nil + } else { + arg = buf[c.as : i-c.drop] + } + c.processErr(string(arg)) + c.drop, c.as, c.state = 0, i+1, OP_START + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } + } + default: + goto parseErr + } + } + + // Check for split buffer scenarios for any ARG state. + if c.state == SUB_ARG || c.state == UNSUB_ARG || c.state == PUB_ARG || + c.state == MSG_ARG || c.state == MINUS_ERR_ARG || + c.state == CONNECT_ARG || c.state == INFO_ARG { + // Setup a holder buffer to deal with split buffer scenario. + if c.argBuf == nil { + c.argBuf = c.scratch[:0] + c.argBuf = append(c.argBuf, buf[c.as:i-c.drop]...) + } + // Check for violations of control line length here. Note that this is not + // exact at all but the performance hit is too great to be precise, and + // catching here should prevent memory exhaustion attacks. + if len(c.argBuf) > mcl { + c.sendErr("Maximum Control Line Exceeded") + c.closeConnection(MaxControlLineExceeded) + return ErrMaxControlLine + } + } + + // Check for split msg + if (c.state == MSG_PAYLOAD || c.state == MSG_END) && c.msgBuf == nil { + // We need to clone the pubArg if it is still referencing the + // read buffer and we are not able to process the msg. + if c.argBuf == nil { + // Works also for MSG_ARG, when message comes from ROUTE. + c.clonePubArg() + } + + // If we will overflow the scratch buffer, just create a + // new buffer to hold the split message. + if c.pa.size > cap(c.scratch)-len(c.argBuf) { + lrem := len(buf[c.as:]) + + // Consider it a protocol error when the remaining payload + // is larger than the reported size for PUB. It can happen + // when processing incomplete messages from rogue clients. + if lrem > c.pa.size+LEN_CR_LF { + goto parseErr + } + c.msgBuf = make([]byte, lrem, c.pa.size+LEN_CR_LF) + copy(c.msgBuf, buf[c.as:]) + } else { + c.msgBuf = c.scratch[len(c.argBuf):len(c.argBuf)] + c.msgBuf = append(c.msgBuf, (buf[c.as:])...) + } + } + + return nil + +authErr: + c.authViolation() + return ErrAuthorization + +parseErr: + c.sendErr("Unknown Protocol Operation") + snip := protoSnippet(i, buf) + err := fmt.Errorf("%s parser ERROR, state=%d, i=%d: proto='%s...'", + c.typeString(), c.state, i, snip) + return err +} + +func protoSnippet(start int, buf []byte) string { + stop := start + PROTO_SNIPPET_SIZE + bufSize := len(buf) + if start >= bufSize { + return `""` + } + if stop > bufSize { + stop = bufSize - 1 + } + return fmt.Sprintf("%q", buf[start:stop]) +} + +// clonePubArg is used when the split buffer scenario has the pubArg in the existing read buffer, but +// we need to hold onto it into the next read. +func (c *client) clonePubArg() { + c.argBuf = c.scratch[:0] + c.argBuf = append(c.argBuf, c.pa.subject...) + c.argBuf = append(c.argBuf, c.pa.reply...) + c.argBuf = append(c.argBuf, c.pa.sid...) + c.argBuf = append(c.argBuf, c.pa.szb...) + + c.pa.subject = c.argBuf[:len(c.pa.subject)] + + if c.pa.reply != nil { + c.pa.reply = c.argBuf[len(c.pa.subject) : len(c.pa.subject)+len(c.pa.reply)] + } + + if c.pa.sid != nil { + c.pa.sid = c.argBuf[len(c.pa.subject)+len(c.pa.reply) : len(c.pa.subject)+len(c.pa.reply)+len(c.pa.sid)] + } + + c.pa.szb = c.argBuf[len(c.pa.subject)+len(c.pa.reply)+len(c.pa.sid):] +} diff --git a/vendor/github.com/nats-io/gnatsd/server/parser_test.go b/vendor/github.com/nats-io/gnatsd/server/parser_test.go new file mode 100644 index 00000000..95631b08 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/parser_test.go @@ -0,0 +1,546 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "testing" +) + +func dummyClient() *client { + return &client{srv: New(&defaultServerOptions)} +} + +func dummyRouteClient() *client { + return &client{srv: New(&defaultServerOptions), typ: ROUTER} +} + +func TestParsePing(t *testing.T) { + c := dummyClient() + if c.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.state) + } + ping := []byte("PING\r\n") + err := c.parse(ping[:1]) + if err != nil || c.state != OP_P { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping[1:2]) + if err != nil || c.state != OP_PI { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping[2:3]) + if err != nil || c.state != OP_PIN { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping[3:4]) + if err != nil || c.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping[4:5]) + if err != nil || c.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping[5:6]) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(ping) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + // Should tolerate spaces + ping = []byte("PING \r") + err = c.parse(ping) + if err != nil || c.state != OP_PING { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + c.state = OP_START + ping = []byte("PING \r \n") + err = c.parse(ping) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } +} + +func TestParsePong(t *testing.T) { + c := dummyClient() + if c.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.state) + } + pong := []byte("PONG\r\n") + err := c.parse(pong[:1]) + if err != nil || c.state != OP_P { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(pong[1:2]) + if err != nil || c.state != OP_PO { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(pong[2:3]) + if err != nil || c.state != OP_PON { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(pong[3:4]) + if err != nil || c.state != OP_PONG { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(pong[4:5]) + if err != nil || c.state != OP_PONG { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(pong[5:6]) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if c.ping.out != 0 { + t.Fatalf("Unexpected ping.out value: %d vs 0\n", c.ping.out) + } + err = c.parse(pong) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if c.ping.out != 0 { + t.Fatalf("Unexpected ping.out value: %d vs 0\n", c.ping.out) + } + // Should tolerate spaces + pong = []byte("PONG \r") + err = c.parse(pong) + if err != nil || c.state != OP_PONG { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + c.state = OP_START + pong = []byte("PONG \r \n") + err = c.parse(pong) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if c.ping.out != 0 { + t.Fatalf("Unexpected ping.out value: %d vs 0\n", c.ping.out) + } + + // Should be adjusting c.pout (Pings Outstanding): reset to 0 + c.state = OP_START + c.ping.out = 10 + pong = []byte("PONG\r\n") + err = c.parse(pong) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if c.ping.out != 0 { + t.Fatalf("Unexpected ping.out: %d vs 0\n", c.ping.out) + } +} + +func TestParseConnect(t *testing.T) { + c := dummyClient() + connect := []byte("CONNECT {\"verbose\":false,\"pedantic\":true,\"tls_required\":false}\r\n") + err := c.parse(connect) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + // Check saved state + if c.as != 8 { + t.Fatalf("ArgStart state incorrect: 8 vs %d\n", c.as) + } +} + +func TestParseSub(t *testing.T) { + c := dummyClient() + sub := []byte("SUB foo 1\r") + err := c.parse(sub) + if err != nil || c.state != SUB_ARG { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + // Check saved state + if c.as != 4 { + t.Fatalf("ArgStart state incorrect: 4 vs %d\n", c.as) + } + if c.drop != 1 { + t.Fatalf("Drop state incorrect: 1 vs %d\n", c.as) + } + if !bytes.Equal(sub[c.as:], []byte("foo 1\r")) { + t.Fatalf("Arg state incorrect: %s\n", sub[c.as:]) + } +} + +func TestParsePub(t *testing.T) { + c := dummyClient() + + pub := []byte("PUB foo 5\r\nhello\r") + err := c.parse(pub) + if err != nil || c.state != MSG_END { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", string(c.pa.subject)) + } + if c.pa.reply != nil { + t.Fatalf("Did not parse reply correctly: 'nil' vs '%s'\n", string(c.pa.reply)) + } + if c.pa.size != 5 { + t.Fatalf("Did not parse msg size correctly: 5 vs %d\n", c.pa.size) + } + + // Clear snapshots + c.argBuf, c.msgBuf, c.state = nil, nil, OP_START + + pub = []byte("PUB foo.bar INBOX.22 11\r\nhello world\r") + err = c.parse(pub) + if err != nil || c.state != MSG_END { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if !bytes.Equal(c.pa.subject, []byte("foo.bar")) { + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", string(c.pa.subject)) + } + if !bytes.Equal(c.pa.reply, []byte("INBOX.22")) { + t.Fatalf("Did not parse reply correctly: 'INBOX.22' vs '%s'\n", string(c.pa.reply)) + } + if c.pa.size != 11 { + t.Fatalf("Did not parse msg size correctly: 11 vs %d\n", c.pa.size) + } +} + +func TestParsePubArg(t *testing.T) { + c := dummyClient() + + for _, test := range []struct { + arg string + subject string + reply string + size int + szb string + }{ + {arg: "a 2", + subject: "a", reply: "", size: 2, szb: "2"}, + {arg: "a 222", + subject: "a", reply: "", size: 222, szb: "222"}, + {arg: "foo 22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: " foo 22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo 22 ", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo 22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: " foo 22 ", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: " foo 22 ", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo bar 22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: " foo bar 22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "foo bar 22 ", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "foo bar 22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: " foo bar 22 ", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: " foo bar 22 ", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: " foo bar 2222 ", + subject: "foo", reply: "bar", size: 2222, szb: "2222"}, + {arg: " foo 2222 ", + subject: "foo", reply: "", size: 2222, szb: "2222"}, + {arg: "a\t2", + subject: "a", reply: "", size: 2, szb: "2"}, + {arg: "a\t222", + subject: "a", reply: "", size: 222, szb: "222"}, + {arg: "foo\t22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "\tfoo\t22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo\t22\t", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo\t\t\t22", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "\tfoo\t22\t", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "\tfoo\t\t\t22\t", + subject: "foo", reply: "", size: 22, szb: "22"}, + {arg: "foo\tbar\t22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "\tfoo\tbar\t22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "foo\tbar\t22\t", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "foo\t\tbar\t\t22", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "\tfoo\tbar\t22\t", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "\t \tfoo\t \t \tbar\t \t22\t \t", + subject: "foo", reply: "bar", size: 22, szb: "22"}, + {arg: "\t\tfoo\t\t\tbar\t\t2222\t\t", + subject: "foo", reply: "bar", size: 2222, szb: "2222"}, + {arg: "\t \tfoo\t \t \t\t\t2222\t \t", + subject: "foo", reply: "", size: 2222, szb: "2222"}, + } { + t.Run(test.arg, func(t *testing.T) { + if err := c.processPub([]byte(test.arg)); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if !bytes.Equal(c.pa.subject, []byte(test.subject)) { + t.Fatalf("Mismatched subject: '%s'\n", c.pa.subject) + } + if !bytes.Equal(c.pa.reply, []byte(test.reply)) { + t.Fatalf("Mismatched reply subject: '%s'\n", c.pa.reply) + } + if !bytes.Equal(c.pa.szb, []byte(test.szb)) { + t.Fatalf("Bad size buf: '%s'\n", c.pa.szb) + } + if c.pa.size != test.size { + t.Fatalf("Bad size: %d\n", c.pa.size) + } + }) + } +} + +func TestParsePubBadSize(t *testing.T) { + c := dummyClient() + // Setup localized max payload + c.mpay = 32768 + if err := c.processPub([]byte("foo 2222222222222222")); err == nil { + t.Fatalf("Expected parse error for size too large") + } +} + +func TestParseMsg(t *testing.T) { + c := dummyRouteClient() + + pub := []byte("MSG foo RSID:1:2 5\r\nhello\r") + err := c.parse(pub) + if err != nil || c.state != MSG_END { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", c.pa.subject) + } + if c.pa.reply != nil { + t.Fatalf("Did not parse reply correctly: 'nil' vs '%s'\n", c.pa.reply) + } + if c.pa.size != 5 { + t.Fatalf("Did not parse msg size correctly: 5 vs %d\n", c.pa.size) + } + if !bytes.Equal(c.pa.sid, []byte("RSID:1:2")) { + t.Fatalf("Did not parse sid correctly: 'RSID:1:2' vs '%s'\n", c.pa.sid) + } + + // Clear snapshots + c.argBuf, c.msgBuf, c.state = nil, nil, OP_START + + pub = []byte("MSG foo.bar RSID:1:2 INBOX.22 11\r\nhello world\r") + err = c.parse(pub) + if err != nil || c.state != MSG_END { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + if !bytes.Equal(c.pa.subject, []byte("foo.bar")) { + t.Fatalf("Did not parse subject correctly: 'foo' vs '%s'\n", c.pa.subject) + } + if !bytes.Equal(c.pa.reply, []byte("INBOX.22")) { + t.Fatalf("Did not parse reply correctly: 'INBOX.22' vs '%s'\n", c.pa.reply) + } + if c.pa.size != 11 { + t.Fatalf("Did not parse msg size correctly: 11 vs %d\n", c.pa.size) + } +} + +func testMsgArg(c *client, t *testing.T) { + if !bytes.Equal(c.pa.subject, []byte("foobar")) { + t.Fatalf("Mismatched subject: '%s'\n", c.pa.subject) + } + if !bytes.Equal(c.pa.szb, []byte("22")) { + t.Fatalf("Bad size buf: '%s'\n", c.pa.szb) + } + if c.pa.size != 22 { + t.Fatalf("Bad size: %d\n", c.pa.size) + } + if !bytes.Equal(c.pa.sid, []byte("RSID:22:1")) { + t.Fatalf("Bad sid: '%s'\n", c.pa.sid) + } +} + +func TestParseMsgArg(t *testing.T) { + c := dummyClient() + if err := c.processMsgArgs([]byte("foobar RSID:22:1 22")); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + testMsgArg(c, t) + if err := c.processMsgArgs([]byte(" foobar RSID:22:1 22")); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + testMsgArg(c, t) + if err := c.processMsgArgs([]byte(" foobar RSID:22:1 22 ")); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + testMsgArg(c, t) + if err := c.processMsgArgs([]byte("foobar RSID:22:1 \t22")); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if err := c.processMsgArgs([]byte("foobar\t\tRSID:22:1\t22\r")); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + testMsgArg(c, t) +} + +func TestParseMsgSpace(t *testing.T) { + c := dummyRouteClient() + + // Ivan bug he found + if err := c.parse([]byte("MSG \r\n")); err == nil { + t.Fatalf("Expected parse error for MSG ") + } + + c = dummyClient() + + // Anything with an M from a client should parse error + if err := c.parse([]byte("M")); err == nil { + t.Fatalf("Expected parse error for M* from a client") + } +} + +func TestShouldFail(t *testing.T) { + wrongProtos := []string{ + "xxx", + "Px", "PIx", "PINx", " PING", + "POx", "PONx", + "+x", "+Ox", + "-x", "-Ex", "-ERx", "-ERRx", + "Cx", "COx", "CONx", "CONNx", "CONNEx", "CONNECx", "CONNECx", "CONNECT \r\n", + "PUx", "PUB foo\r\n", "PUB \r\n", "PUB foo bar \r\n", + "PUB foo 2\r\nok \r\n", "PUB foo 2\r\nok\r \n", + "Sx", "SUx", "SUB\r\n", "SUB \r\n", "SUB foo\r\n", + "SUB foo bar baz 22\r\n", + "Ux", "UNx", "UNSx", "UNSUx", "UNSUBx", "UNSUBUNSUB 1\r\n", "UNSUB_2\r\n", + "UNSUB_UNSUB_UNSUB 2\r\n", "UNSUB_\t2\r\n", "UNSUB\r\n", "UNSUB \r\n", + "UNSUB \t \r\n", + "Ix", "INx", "INFx", "INFO \r\n", + } + for _, proto := range wrongProtos { + c := dummyClient() + if err := c.parse([]byte(proto)); err == nil { + t.Fatalf("Should have received a parse error for: %v", proto) + } + } + + // Special case for MSG, type needs to not be client. + wrongProtos = []string{"Mx", "MSx", "MSGx", "MSG \r\n"} + for _, proto := range wrongProtos { + c := dummyClient() + c.typ = ROUTER + if err := c.parse([]byte(proto)); err == nil { + t.Fatalf("Should have received a parse error for: %v", proto) + } + } +} + +func TestProtoSnippet(t *testing.T) { + sample := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + + tests := []struct { + input int + expected string + }{ + {0, `"abcdefghijklmnopqrstuvwxyzABCDEF"`}, + {1, `"bcdefghijklmnopqrstuvwxyzABCDEFG"`}, + {2, `"cdefghijklmnopqrstuvwxyzABCDEFGH"`}, + {3, `"defghijklmnopqrstuvwxyzABCDEFGHI"`}, + {4, `"efghijklmnopqrstuvwxyzABCDEFGHIJ"`}, + {5, `"fghijklmnopqrstuvwxyzABCDEFGHIJK"`}, + {6, `"ghijklmnopqrstuvwxyzABCDEFGHIJKL"`}, + {7, `"hijklmnopqrstuvwxyzABCDEFGHIJKLM"`}, + {8, `"ijklmnopqrstuvwxyzABCDEFGHIJKLMN"`}, + {9, `"jklmnopqrstuvwxyzABCDEFGHIJKLMNO"`}, + {10, `"klmnopqrstuvwxyzABCDEFGHIJKLMNOP"`}, + {11, `"lmnopqrstuvwxyzABCDEFGHIJKLMNOPQ"`}, + {12, `"mnopqrstuvwxyzABCDEFGHIJKLMNOPQR"`}, + {13, `"nopqrstuvwxyzABCDEFGHIJKLMNOPQRS"`}, + {14, `"opqrstuvwxyzABCDEFGHIJKLMNOPQRST"`}, + {15, `"pqrstuvwxyzABCDEFGHIJKLMNOPQRSTU"`}, + {16, `"qrstuvwxyzABCDEFGHIJKLMNOPQRSTUV"`}, + {17, `"rstuvwxyzABCDEFGHIJKLMNOPQRSTUVW"`}, + {18, `"stuvwxyzABCDEFGHIJKLMNOPQRSTUVWX"`}, + {19, `"tuvwxyzABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {20, `"uvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"`}, + {21, `"vwxyzABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {22, `"wxyzABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {23, `"xyzABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {24, `"yzABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {25, `"zABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {26, `"ABCDEFGHIJKLMNOPQRSTUVWXY"`}, + {27, `"BCDEFGHIJKLMNOPQRSTUVWXY"`}, + {28, `"CDEFGHIJKLMNOPQRSTUVWXY"`}, + {29, `"DEFGHIJKLMNOPQRSTUVWXY"`}, + {30, `"EFGHIJKLMNOPQRSTUVWXY"`}, + {31, `"FGHIJKLMNOPQRSTUVWXY"`}, + {32, `"GHIJKLMNOPQRSTUVWXY"`}, + {33, `"HIJKLMNOPQRSTUVWXY"`}, + {34, `"IJKLMNOPQRSTUVWXY"`}, + {35, `"JKLMNOPQRSTUVWXY"`}, + {36, `"KLMNOPQRSTUVWXY"`}, + {37, `"LMNOPQRSTUVWXY"`}, + {38, `"MNOPQRSTUVWXY"`}, + {39, `"NOPQRSTUVWXY"`}, + {40, `"OPQRSTUVWXY"`}, + {41, `"PQRSTUVWXY"`}, + {42, `"QRSTUVWXY"`}, + {43, `"RSTUVWXY"`}, + {44, `"STUVWXY"`}, + {45, `"TUVWXY"`}, + {46, `"UVWXY"`}, + {47, `"VWXY"`}, + {48, `"WXY"`}, + {49, `"XY"`}, + {50, `"Y"`}, + {51, `""`}, + {52, `""`}, + {53, `""`}, + {54, `""`}, + } + + for _, tt := range tests { + got := protoSnippet(tt.input, sample) + if tt.expected != got { + t.Errorf("Expected protocol snippet to be %s when start=%d but got %s\n", tt.expected, tt.input, got) + } + } +} + +func TestParseOK(t *testing.T) { + c := dummyClient() + if c.state != OP_START { + t.Fatalf("Expected OP_START vs %d\n", c.state) + } + okProto := []byte("+OK\r\n") + err := c.parse(okProto[:1]) + if err != nil || c.state != OP_PLUS { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(okProto[1:2]) + if err != nil || c.state != OP_PLUS_O { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(okProto[2:3]) + if err != nil || c.state != OP_PLUS_OK { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(okProto[3:4]) + if err != nil || c.state != OP_PLUS_OK { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } + err = c.parse(okProto[4:5]) + if err != nil || c.state != OP_START { + t.Fatalf("Unexpected: %d : %v\n", c.state, err) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/ping_test.go b/vendor/github.com/nats-io/gnatsd/server/ping_test.go new file mode 100644 index 00000000..aa56f6b4 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/ping_test.go @@ -0,0 +1,44 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +const PING_CLIENT_PORT = 11228 + +var DefaultPingOptions = Options{ + Host: "127.0.0.1", + Port: PING_CLIENT_PORT, + NoLog: true, + NoSigs: true, + PingInterval: 5 * time.Millisecond, +} + +func TestPing(t *testing.T) { + s := RunServer(&DefaultPingOptions) + defer s.Shutdown() + + nc, err := nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", PING_CLIENT_PORT)) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc.Close() + time.Sleep(10 * time.Millisecond) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go new file mode 100644 index 00000000..b00f1e00 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_darwin.go @@ -0,0 +1,34 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +import ( + "fmt" + "os" + "os/exec" +) + +// ProcUsage returns CPU usage +func ProcUsage(pcpu *float64, rss, vss *int64) error { + pidStr := fmt.Sprintf("%d", os.Getpid()) + out, err := exec.Command("ps", "o", "pcpu=,rss=,vsz=", "-p", pidStr).Output() + if err != nil { + *rss, *vss = -1, -1 + return fmt.Errorf("ps call failed:%v", err) + } + fmt.Sscanf(string(out), "%f %d %d", pcpu, rss, vss) + *rss *= 1024 // 1k blocks, want bytes. + *vss *= 1024 // 1k blocks, want bytes. + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go new file mode 100644 index 00000000..40c52847 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_freebsd.go @@ -0,0 +1,83 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +/* +#include +#include +#include +#include +#include + +long pagetok(long size) +{ + int pageshift, pagesize; + + pagesize = getpagesize(); + pageshift = 0; + + while (pagesize > 1) { + pageshift++; + pagesize >>= 1; + } + + return (size << pageshift); +} + +int getusage(double *pcpu, unsigned int *rss, unsigned int *vss) +{ + int mib[4], ret; + size_t len; + struct kinfo_proc kp; + + len = 4; + sysctlnametomib("kern.proc.pid", mib, &len); + + mib[3] = getpid(); + len = sizeof(kp); + + ret = sysctl(mib, 4, &kp, &len, NULL, 0); + if (ret != 0) { + return (errno); + } + + *rss = pagetok(kp.ki_rssize); + *vss = kp.ki_size; + *pcpu = kp.ki_pctcpu; + + return 0; +} + +*/ +import "C" + +import ( + "syscall" +) + +// This is a placeholder for now. +func ProcUsage(pcpu *float64, rss, vss *int64) error { + var r, v C.uint + var c C.double + + if ret := C.getusage(&c, &r, &v); ret != 0 { + return syscall.Errno(ret) + } + + *pcpu = float64(c) + *rss = int64(r) + *vss = int64(v) + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go new file mode 100644 index 00000000..9fea3e07 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_linux.go @@ -0,0 +1,126 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "sync/atomic" + "syscall" + "time" +) + +var ( + procStatFile string + ticks int64 + lastTotal int64 + lastSeconds int64 + ipcpu int64 +) + +const ( + utimePos = 13 + stimePos = 14 + startPos = 21 + vssPos = 22 + rssPos = 23 +) + +func init() { + // Avoiding to generate docker image without CGO + ticks = 100 // int64(C.sysconf(C._SC_CLK_TCK)) + procStatFile = fmt.Sprintf("/proc/%d/stat", os.Getpid()) + periodic() +} + +// Sampling function to keep pcpu relevant. +func periodic() { + contents, err := ioutil.ReadFile(procStatFile) + if err != nil { + return + } + fields := bytes.Fields(contents) + + // PCPU + pstart := parseInt64(fields[startPos]) + utime := parseInt64(fields[utimePos]) + stime := parseInt64(fields[stimePos]) + total := utime + stime + + var sysinfo syscall.Sysinfo_t + if err := syscall.Sysinfo(&sysinfo); err != nil { + return + } + + seconds := int64(sysinfo.Uptime) - (pstart / ticks) + + // Save off temps + lt := lastTotal + ls := lastSeconds + + // Update last sample + lastTotal = total + lastSeconds = seconds + + // Adjust to current time window + total -= lt + seconds -= ls + + if seconds > 0 { + atomic.StoreInt64(&ipcpu, (total*1000/ticks)/seconds) + } + + time.AfterFunc(1*time.Second, periodic) +} + +func ProcUsage(pcpu *float64, rss, vss *int64) error { + contents, err := ioutil.ReadFile(procStatFile) + if err != nil { + return err + } + fields := bytes.Fields(contents) + + // Memory + *rss = (parseInt64(fields[rssPos])) << 12 + *vss = parseInt64(fields[vssPos]) + + // PCPU + // We track this with periodic sampling, so just load and go. + *pcpu = float64(atomic.LoadInt64(&ipcpu)) / 10.0 + + return nil +} + +// Ascii numbers 0-9 +const ( + asciiZero = 48 + asciiNine = 57 +) + +// parseInt64 expects decimal positive numbers. We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int64(dec) - asciiZero) + } + return n +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go new file mode 100644 index 00000000..260f1a7c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_openbsd.go @@ -0,0 +1,36 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Copied from pse_darwin.go + +package pse + +import ( + "fmt" + "os" + "os/exec" +) + +// ProcUsage returns CPU usage +func ProcUsage(pcpu *float64, rss, vss *int64) error { + pidStr := fmt.Sprintf("%d", os.Getpid()) + out, err := exec.Command("ps", "o", "pcpu=,rss=,vsz=", "-p", pidStr).Output() + if err != nil { + *rss, *vss = -1, -1 + return fmt.Errorf("ps call failed:%v", err) + } + fmt.Sscanf(string(out), "%f %d %d", pcpu, rss, vss) + *rss *= 1024 // 1k blocks, want bytes. + *vss *= 1024 // 1k blocks, want bytes. + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go new file mode 100644 index 00000000..48e80fca --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_rumprun.go @@ -0,0 +1,25 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build rumprun + +package pse + +// This is a placeholder for now. +func ProcUsage(pcpu *float64, rss, vss *int64) error { + *pcpu = 0.0 + *rss = 0 + *vss = 0 + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go new file mode 100644 index 00000000..8e40d2ed --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_solaris.go @@ -0,0 +1,23 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +// This is a placeholder for now. +func ProcUsage(pcpu *float64, rss, vss *int64) error { + *pcpu = 0.0 + *rss = 0 + *vss = 0 + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_test.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_test.go new file mode 100644 index 00000000..890f9645 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_test.go @@ -0,0 +1,67 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pse + +import ( + "fmt" + "os" + "os/exec" + "runtime" + "testing" +) + +func TestPSEmulation(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skipf("Skipping this test on Windows") + } + var rss, vss, psRss, psVss int64 + var pcpu, psPcpu float64 + + runtime.GC() + + // PS version first + pidStr := fmt.Sprintf("%d", os.Getpid()) + out, err := exec.Command("ps", "o", "pcpu=,rss=,vsz=", "-p", pidStr).Output() + if err != nil { + t.Fatalf("Failed to execute ps command: %v\n", err) + } + + fmt.Sscanf(string(out), "%f %d %d", &psPcpu, &psRss, &psVss) + psRss *= 1024 // 1k blocks, want bytes. + psVss *= 1024 // 1k blocks, want bytes. + + runtime.GC() + + // Our internal version + ProcUsage(&pcpu, &rss, &vss) + + if pcpu != psPcpu { + delta := int64(pcpu - psPcpu) + if delta < 0 { + delta = -delta + } + if delta > 30 { // 30%? + t.Fatalf("CPUs did not match close enough: %f vs %f", pcpu, psPcpu) + } + } + if rss != psRss { + delta := rss - psRss + if delta < 0 { + delta = -delta + } + if delta > 1024*1024 { // 1MB + t.Fatalf("RSSs did not match close enough: %d vs %d", rss, psRss) + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go new file mode 100644 index 00000000..a8b11070 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows.go @@ -0,0 +1,280 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package pse + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + "unsafe" +) + +var ( + pdh = syscall.NewLazyDLL("pdh.dll") + winPdhOpenQuery = pdh.NewProc("PdhOpenQuery") + winPdhAddCounter = pdh.NewProc("PdhAddCounterW") + winPdhCollectQueryData = pdh.NewProc("PdhCollectQueryData") + winPdhGetFormattedCounterValue = pdh.NewProc("PdhGetFormattedCounterValue") + winPdhGetFormattedCounterArray = pdh.NewProc("PdhGetFormattedCounterArrayW") +) + +// global performance counter query handle and counters +var ( + pcHandle PDH_HQUERY + pidCounter, cpuCounter, rssCounter, vssCounter PDH_HCOUNTER + prevCPU float64 + prevRss int64 + prevVss int64 + lastSampleTime time.Time + processPid int + pcQueryLock sync.Mutex + initialSample = true +) + +// maxQuerySize is the number of values to return from a query. +// It represents the maximum # of servers that can be queried +// simultaneously running on a machine. +const maxQuerySize = 512 + +// Keep static memory around to reuse; this works best for passing +// into the pdh API. +var counterResults [maxQuerySize]PDH_FMT_COUNTERVALUE_ITEM_DOUBLE + +// PDH Types +type ( + PDH_HQUERY syscall.Handle + PDH_HCOUNTER syscall.Handle +) + +// PDH constants used here +const ( + PDH_FMT_DOUBLE = 0x00000200 + PDH_INVALID_DATA = 0xC0000BC6 + PDH_MORE_DATA = 0x800007D2 +) + +// PDH_FMT_COUNTERVALUE_DOUBLE - double value +type PDH_FMT_COUNTERVALUE_DOUBLE struct { + CStatus uint32 + DoubleValue float64 +} + +// PDH_FMT_COUNTERVALUE_ITEM_DOUBLE is an array +// element of a double value +type PDH_FMT_COUNTERVALUE_ITEM_DOUBLE struct { + SzName *uint16 // pointer to a string + FmtValue PDH_FMT_COUNTERVALUE_DOUBLE +} + +func pdhAddCounter(hQuery PDH_HQUERY, szFullCounterPath string, dwUserData uintptr, phCounter *PDH_HCOUNTER) error { + ptxt, _ := syscall.UTF16PtrFromString(szFullCounterPath) + r0, _, _ := winPdhAddCounter.Call( + uintptr(hQuery), + uintptr(unsafe.Pointer(ptxt)), + dwUserData, + uintptr(unsafe.Pointer(phCounter))) + + if r0 != 0 { + return fmt.Errorf("pdhAddCounter failed. %d", r0) + } + return nil +} + +func pdhOpenQuery(datasrc *uint16, userdata uint32, query *PDH_HQUERY) error { + r0, _, _ := syscall.Syscall(winPdhOpenQuery.Addr(), 3, 0, uintptr(userdata), uintptr(unsafe.Pointer(query))) + if r0 != 0 { + return fmt.Errorf("pdhOpenQuery failed - %d", r0) + } + return nil +} + +func pdhCollectQueryData(hQuery PDH_HQUERY) error { + r0, _, _ := winPdhCollectQueryData.Call(uintptr(hQuery)) + if r0 != 0 { + return fmt.Errorf("pdhCollectQueryData failed - %d", r0) + } + return nil +} + +// pdhGetFormattedCounterArrayDouble returns the value of return code +// rather than error, to easily check return codes +func pdhGetFormattedCounterArrayDouble(hCounter PDH_HCOUNTER, lpdwBufferSize *uint32, lpdwBufferCount *uint32, itemBuffer *PDH_FMT_COUNTERVALUE_ITEM_DOUBLE) uint32 { + ret, _, _ := winPdhGetFormattedCounterArray.Call( + uintptr(hCounter), + uintptr(PDH_FMT_DOUBLE), + uintptr(unsafe.Pointer(lpdwBufferSize)), + uintptr(unsafe.Pointer(lpdwBufferCount)), + uintptr(unsafe.Pointer(itemBuffer))) + + return uint32(ret) +} + +func getCounterArrayData(counter PDH_HCOUNTER) ([]float64, error) { + var bufSize uint32 + var bufCount uint32 + + // Retrieving array data requires two calls, the first which + // requires an addressable empty buffer, and sets size fields. + // The second call returns the data. + initialBuf := make([]PDH_FMT_COUNTERVALUE_ITEM_DOUBLE, 1) + ret := pdhGetFormattedCounterArrayDouble(counter, &bufSize, &bufCount, &initialBuf[0]) + if ret == PDH_MORE_DATA { + // we'll likely never get here, but be safe. + if bufCount > maxQuerySize { + bufCount = maxQuerySize + } + ret = pdhGetFormattedCounterArrayDouble(counter, &bufSize, &bufCount, &counterResults[0]) + if ret == 0 { + rv := make([]float64, bufCount) + for i := 0; i < int(bufCount); i++ { + rv[i] = counterResults[i].FmtValue.DoubleValue + } + return rv, nil + } + } + if ret != 0 { + return nil, fmt.Errorf("getCounterArrayData failed - %d", ret) + } + + return nil, nil +} + +// getProcessImageName returns the name of the process image, as expected by +// the performance counter API. +func getProcessImageName() (name string) { + name = filepath.Base(os.Args[0]) + name = strings.TrimRight(name, ".exe") + return +} + +// initialize our counters +func initCounters() (err error) { + + processPid = os.Getpid() + // require an addressible nil pointer + var source uint16 + if err := pdhOpenQuery(&source, 0, &pcHandle); err != nil { + return err + } + + // setup the performance counters, search for all server instances + name := fmt.Sprintf("%s*", getProcessImageName()) + pidQuery := fmt.Sprintf("\\Process(%s)\\ID Process", name) + cpuQuery := fmt.Sprintf("\\Process(%s)\\%% Processor Time", name) + rssQuery := fmt.Sprintf("\\Process(%s)\\Working Set - Private", name) + vssQuery := fmt.Sprintf("\\Process(%s)\\Virtual Bytes", name) + + if err = pdhAddCounter(pcHandle, pidQuery, 0, &pidCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, cpuQuery, 0, &cpuCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, rssQuery, 0, &rssCounter); err != nil { + return err + } + if err = pdhAddCounter(pcHandle, vssQuery, 0, &vssCounter); err != nil { + return err + } + + // prime the counters by collecting once, and sleep to get somewhat + // useful information the first request. Counters for the CPU require + // at least two collect calls. + if err = pdhCollectQueryData(pcHandle); err != nil { + return err + } + time.Sleep(50) + + return nil +} + +// ProcUsage returns process CPU and memory statistics +func ProcUsage(pcpu *float64, rss, vss *int64) error { + var err error + + // For simplicity, protect the entire call. + // Most simultaneous requests will immediately return + // with cached values. + pcQueryLock.Lock() + defer pcQueryLock.Unlock() + + // First time through, initialize counters. + if initialSample { + if err = initCounters(); err != nil { + return err + } + initialSample = false + } else if time.Since(lastSampleTime) < (2 * time.Second) { + // only refresh every two seconds as to minimize impact + // on the server. + *pcpu = prevCPU + *rss = prevRss + *vss = prevVss + return nil + } + + // always save the sample time, even on errors. + defer func() { + lastSampleTime = time.Now() + }() + + // refresh the performance counter data + if err = pdhCollectQueryData(pcHandle); err != nil { + return err + } + + // retrieve the data + var pidAry, cpuAry, rssAry, vssAry []float64 + if pidAry, err = getCounterArrayData(pidCounter); err != nil { + return err + } + if cpuAry, err = getCounterArrayData(cpuCounter); err != nil { + return err + } + if rssAry, err = getCounterArrayData(rssCounter); err != nil { + return err + } + if vssAry, err = getCounterArrayData(vssCounter); err != nil { + return err + } + // find the index of the entry for this process + idx := int(-1) + for i := range pidAry { + if int(pidAry[i]) == processPid { + idx = i + break + } + } + // no pid found... + if idx < 0 { + return fmt.Errorf("could not find pid in performance counter results") + } + // assign values from the performance counters + *pcpu = cpuAry[idx] + *rss = int64(rssAry[idx]) + *vss = int64(vssAry[idx]) + + // save off cache values + prevCPU = *pcpu + prevRss = *rss + prevVss = *vss + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows_test.go b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows_test.go new file mode 100644 index 00000000..bae1dd58 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/pse/pse_windows_test.go @@ -0,0 +1,97 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build windows + +package pse + +import ( + "fmt" + "os/exec" + "runtime" + "strconv" + "strings" + "testing" +) + +func checkValues(t *testing.T, pcpu, tPcpu float64, rss, tRss int64) { + if pcpu != tPcpu { + delta := int64(pcpu - tPcpu) + if delta < 0 { + delta = -delta + } + if delta > 30 { // 30%? + t.Fatalf("CPUs did not match close enough: %f vs %f", pcpu, tPcpu) + } + } + if rss != tRss { + delta := rss - tRss + if delta < 0 { + delta = -delta + } + if delta > 1024*1024 { // 1MB + t.Fatalf("RSSs did not match close enough: %d vs %d", rss, tRss) + } + } +} + +func TestPSEmulationWin(t *testing.T) { + var pcpu, tPcpu float64 + var rss, vss, tRss int64 + + runtime.GC() + + if err := ProcUsage(&pcpu, &rss, &vss); err != nil { + t.Fatalf("Error: %v", err) + } + + runtime.GC() + + imageName := getProcessImageName() + // query the counters using typeperf + out, err := exec.Command("typeperf.exe", + fmt.Sprintf("\\Process(%s)\\%% Processor Time", imageName), + fmt.Sprintf("\\Process(%s)\\Working Set - Private", imageName), + fmt.Sprintf("\\Process(%s)\\Virtual Bytes", imageName), + "-sc", "1").Output() + if err != nil { + t.Fatal("unable to run command", err) + } + + // parse out results - refer to comments in procUsage for detail + results := strings.Split(string(out), "\r\n") + values := strings.Split(results[2], ",") + + // parse pcpu + tPcpu, err = strconv.ParseFloat(strings.Trim(values[1], "\""), 64) + if err != nil { + t.Fatalf("Unable to parse percent cpu: %s", values[1]) + } + + // parse private bytes (rss) + fval, err := strconv.ParseFloat(strings.Trim(values[2], "\""), 64) + if err != nil { + t.Fatalf("Unable to parse private bytes: %s", values[2]) + } + tRss = int64(fval) + + checkValues(t, pcpu, tPcpu, rss, tRss) + + runtime.GC() + + // Again to test caching + if err = ProcUsage(&pcpu, &rss, &vss); err != nil { + t.Fatalf("Error: %v", err) + } + checkValues(t, pcpu, tPcpu, rss, tRss) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/reload.go b/vendor/github.com/nats-io/gnatsd/server/reload.go new file mode 100644 index 00000000..8839cf99 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/reload.go @@ -0,0 +1,714 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "crypto/tls" + "errors" + "fmt" + "net/url" + "reflect" + "strings" + "sync/atomic" + "time" +) + +// FlagSnapshot captures the server options as specified by CLI flags at +// startup. This should not be modified once the server has started. +var FlagSnapshot *Options + +// option is a hot-swappable configuration setting. +type option interface { + // Apply the server option. + Apply(server *Server) + + // IsLoggingChange indicates if this option requires reloading the logger. + IsLoggingChange() bool + + // IsAuthChange indicates if this option requires reloading authorization. + IsAuthChange() bool +} + +// loggingOption is a base struct that provides default option behaviors for +// logging-related options. +type loggingOption struct{} + +func (l loggingOption) IsLoggingChange() bool { + return true +} + +func (l loggingOption) IsAuthChange() bool { + return false +} + +// traceOption implements the option interface for the `trace` setting. +type traceOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (t *traceOption) Apply(server *Server) { + server.Noticef("Reloaded: trace = %v", t.newValue) +} + +// debugOption implements the option interface for the `debug` setting. +type debugOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (d *debugOption) Apply(server *Server) { + server.Noticef("Reloaded: debug = %v", d.newValue) +} + +// logtimeOption implements the option interface for the `logtime` setting. +type logtimeOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (l *logtimeOption) Apply(server *Server) { + server.Noticef("Reloaded: logtime = %v", l.newValue) +} + +// logfileOption implements the option interface for the `log_file` setting. +type logfileOption struct { + loggingOption + newValue string +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (l *logfileOption) Apply(server *Server) { + server.Noticef("Reloaded: log_file = %v", l.newValue) +} + +// syslogOption implements the option interface for the `syslog` setting. +type syslogOption struct { + loggingOption + newValue bool +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (s *syslogOption) Apply(server *Server) { + server.Noticef("Reloaded: syslog = %v", s.newValue) +} + +// remoteSyslogOption implements the option interface for the `remote_syslog` +// setting. +type remoteSyslogOption struct { + loggingOption + newValue string +} + +// Apply is a no-op because logging will be reloaded after options are applied. +func (r *remoteSyslogOption) Apply(server *Server) { + server.Noticef("Reloaded: remote_syslog = %v", r.newValue) +} + +// noopOption is a base struct that provides default no-op behaviors. +type noopOption struct{} + +func (n noopOption) IsLoggingChange() bool { + return false +} + +func (n noopOption) IsAuthChange() bool { + return false +} + +// tlsOption implements the option interface for the `tls` setting. +type tlsOption struct { + noopOption + newValue *tls.Config +} + +// Apply the tls change. +func (t *tlsOption) Apply(server *Server) { + server.mu.Lock() + tlsRequired := t.newValue != nil + server.info.TLSRequired = tlsRequired + message := "disabled" + if tlsRequired { + server.info.TLSVerify = (t.newValue.ClientAuth == tls.RequireAndVerifyClientCert) + message = "enabled" + } + server.mu.Unlock() + server.Noticef("Reloaded: tls = %s", message) +} + +// tlsTimeoutOption implements the option interface for the tls `timeout` +// setting. +type tlsTimeoutOption struct { + noopOption + newValue float64 +} + +// Apply is a no-op because the timeout will be reloaded after options are +// applied. +func (t *tlsTimeoutOption) Apply(server *Server) { + server.Noticef("Reloaded: tls timeout = %v", t.newValue) +} + +// authOption is a base struct that provides default option behaviors. +type authOption struct{} + +func (o authOption) IsLoggingChange() bool { + return false +} + +func (o authOption) IsAuthChange() bool { + return true +} + +// usernameOption implements the option interface for the `username` setting. +type usernameOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. +func (u *usernameOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization username") +} + +// passwordOption implements the option interface for the `password` setting. +type passwordOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. +func (p *passwordOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization password") +} + +// authorizationOption implements the option interface for the `token` +// authorization setting. +type authorizationOption struct { + authOption +} + +// Apply is a no-op because authorization will be reloaded after options are +// applied. +func (a *authorizationOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization token") +} + +// authTimeoutOption implements the option interface for the authorization +// `timeout` setting. +type authTimeoutOption struct { + noopOption // Not authOption because this is a no-op; will be reloaded with options. + newValue float64 +} + +// Apply is a no-op because the timeout will be reloaded after options are +// applied. +func (a *authTimeoutOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization timeout = %v", a.newValue) +} + +// usersOption implements the option interface for the authorization `users` +// setting. +type usersOption struct { + authOption + newValue []*User +} + +func (u *usersOption) Apply(server *Server) { + server.Noticef("Reloaded: authorization users") +} + +// clusterOption implements the option interface for the `cluster` setting. +type clusterOption struct { + authOption + newValue ClusterOpts +} + +// Apply the cluster change. +func (c *clusterOption) Apply(server *Server) { + // TODO: support enabling/disabling clustering. + server.mu.Lock() + tlsRequired := c.newValue.TLSConfig != nil + server.routeInfo.TLSRequired = tlsRequired + server.routeInfo.TLSVerify = tlsRequired + server.routeInfo.AuthRequired = c.newValue.Username != "" + if c.newValue.NoAdvertise { + server.routeInfo.ClientConnectURLs = nil + } else { + server.routeInfo.ClientConnectURLs = server.clientConnectURLs + } + server.setRouteInfoHostPortAndIP() + server.mu.Unlock() + server.Noticef("Reloaded: cluster") +} + +// routesOption implements the option interface for the cluster `routes` +// setting. +type routesOption struct { + noopOption + add []*url.URL + remove []*url.URL +} + +// Apply the route changes by adding and removing the necessary routes. +func (r *routesOption) Apply(server *Server) { + server.mu.Lock() + routes := make([]*client, len(server.routes)) + i := 0 + for _, client := range server.routes { + routes[i] = client + i++ + } + server.mu.Unlock() + + // Remove routes. + for _, remove := range r.remove { + for _, client := range routes { + if client.route.url == remove { + // Do not attempt to reconnect when route is removed. + client.setRouteNoReconnectOnClose() + client.closeConnection(RouteRemoved) + server.Noticef("Removed route %v", remove) + } + } + } + + // Add routes. + server.solicitRoutes(r.add) + + server.Noticef("Reloaded: cluster routes") +} + +// maxConnOption implements the option interface for the `max_connections` +// setting. +type maxConnOption struct { + noopOption + newValue int +} + +// Apply the max connections change by closing random connections til we are +// below the limit if necessary. +func (m *maxConnOption) Apply(server *Server) { + server.mu.Lock() + var ( + clients = make([]*client, len(server.clients)) + i = 0 + ) + // Map iteration is random, which allows us to close random connections. + for _, client := range server.clients { + clients[i] = client + i++ + } + server.mu.Unlock() + + if m.newValue > 0 && len(clients) > m.newValue { + // Close connections til we are within the limit. + var ( + numClose = len(clients) - m.newValue + closed = 0 + ) + for _, client := range clients { + client.maxConnExceeded() + closed++ + if closed >= numClose { + break + } + } + server.Noticef("Closed %d connections to fall within max_connections", closed) + } + server.Noticef("Reloaded: max_connections = %v", m.newValue) +} + +// pidFileOption implements the option interface for the `pid_file` setting. +type pidFileOption struct { + noopOption + newValue string +} + +// Apply the setting by logging the pid to the new file. +func (p *pidFileOption) Apply(server *Server) { + if p.newValue == "" { + return + } + if err := server.logPid(); err != nil { + server.Errorf("Failed to write pidfile: %v", err) + } + server.Noticef("Reloaded: pid_file = %v", p.newValue) +} + +// portsFileDirOption implements the option interface for the `portFileDir` setting. +type portsFileDirOption struct { + noopOption + oldValue string + newValue string +} + +func (p *portsFileDirOption) Apply(server *Server) { + server.deletePortsFile(p.oldValue) + server.logPorts() + server.Noticef("Reloaded: ports_file_dir = %v", p.newValue) +} + +// maxControlLineOption implements the option interface for the +// `max_control_line` setting. +type maxControlLineOption struct { + noopOption + newValue int +} + +// Apply is a no-op because the max control line will be reloaded after options +// are applied +func (m *maxControlLineOption) Apply(server *Server) { + server.Noticef("Reloaded: max_control_line = %d", m.newValue) +} + +// maxPayloadOption implements the option interface for the `max_payload` +// setting. +type maxPayloadOption struct { + noopOption + newValue int +} + +// Apply the setting by updating the server info and each client. +func (m *maxPayloadOption) Apply(server *Server) { + server.mu.Lock() + server.info.MaxPayload = m.newValue + for _, client := range server.clients { + atomic.StoreInt64(&client.mpay, int64(m.newValue)) + } + server.mu.Unlock() + server.Noticef("Reloaded: max_payload = %d", m.newValue) +} + +// pingIntervalOption implements the option interface for the `ping_interval` +// setting. +type pingIntervalOption struct { + noopOption + newValue time.Duration +} + +// Apply is a no-op because the ping interval will be reloaded after options +// are applied. +func (p *pingIntervalOption) Apply(server *Server) { + server.Noticef("Reloaded: ping_interval = %s", p.newValue) +} + +// maxPingsOutOption implements the option interface for the `ping_max` +// setting. +type maxPingsOutOption struct { + noopOption + newValue int +} + +// Apply is a no-op because the ping interval will be reloaded after options +// are applied. +func (m *maxPingsOutOption) Apply(server *Server) { + server.Noticef("Reloaded: ping_max = %d", m.newValue) +} + +// writeDeadlineOption implements the option interface for the `write_deadline` +// setting. +type writeDeadlineOption struct { + noopOption + newValue time.Duration +} + +// Apply is a no-op because the write deadline will be reloaded after options +// are applied. +func (w *writeDeadlineOption) Apply(server *Server) { + server.Noticef("Reloaded: write_deadline = %s", w.newValue) +} + +// clientAdvertiseOption implements the option interface for the `client_advertise` setting. +type clientAdvertiseOption struct { + noopOption + newValue string +} + +// Apply the setting by updating the server info and regenerate the infoJSON byte array. +func (c *clientAdvertiseOption) Apply(server *Server) { + server.mu.Lock() + server.setInfoHostPortAndGenerateJSON() + server.mu.Unlock() + server.Noticef("Reload: client_advertise = %s", c.newValue) +} + +// Reload reads the current configuration file and applies any supported +// changes. This returns an error if the server was not started with a config +// file or an option which doesn't support hot-swapping was changed. +func (s *Server) Reload() error { + s.mu.Lock() + if s.configFile == "" { + s.mu.Unlock() + return errors.New("Can only reload config when a file is provided using -c or --config") + } + newOpts, err := ProcessConfigFile(s.configFile) + if err != nil { + s.mu.Unlock() + // TODO: Dump previous good config to a .bak file? + return err + } + clientOrgPort := s.clientActualPort + clusterOrgPort := s.clusterActualPort + s.mu.Unlock() + + // Apply flags over config file settings. + newOpts = MergeOptions(newOpts, FlagSnapshot) + processOptions(newOpts) + + // processOptions sets Port to 0 if set to -1 (RANDOM port) + // If that's the case, set it to the saved value when the accept loop was + // created. + if newOpts.Port == 0 { + newOpts.Port = clientOrgPort + } + // We don't do that for cluster, so check against -1. + if newOpts.Cluster.Port == -1 { + newOpts.Cluster.Port = clusterOrgPort + } + + if err := s.reloadOptions(newOpts); err != nil { + return err + } + s.mu.Lock() + s.configTime = time.Now() + s.mu.Unlock() + return nil +} + +// reloadOptions reloads the server config with the provided options. If an +// option that doesn't support hot-swapping is changed, this returns an error. +func (s *Server) reloadOptions(newOpts *Options) error { + changed, err := s.diffOptions(newOpts) + if err != nil { + return err + } + s.setOpts(newOpts) + s.applyOptions(changed) + return nil +} + +// diffOptions returns a slice containing options which have been changed. If +// an option that doesn't support hot-swapping is changed, this returns an +// error. +func (s *Server) diffOptions(newOpts *Options) ([]option, error) { + var ( + oldConfig = reflect.ValueOf(s.getOpts()).Elem() + newConfig = reflect.ValueOf(newOpts).Elem() + diffOpts = []option{} + ) + + for i := 0; i < oldConfig.NumField(); i++ { + var ( + field = oldConfig.Type().Field(i) + oldValue = oldConfig.Field(i).Interface() + newValue = newConfig.Field(i).Interface() + changed = !reflect.DeepEqual(oldValue, newValue) + ) + if !changed { + continue + } + switch strings.ToLower(field.Name) { + case "trace": + diffOpts = append(diffOpts, &traceOption{newValue: newValue.(bool)}) + case "debug": + diffOpts = append(diffOpts, &debugOption{newValue: newValue.(bool)}) + case "logtime": + diffOpts = append(diffOpts, &logtimeOption{newValue: newValue.(bool)}) + case "logfile": + diffOpts = append(diffOpts, &logfileOption{newValue: newValue.(string)}) + case "syslog": + diffOpts = append(diffOpts, &syslogOption{newValue: newValue.(bool)}) + case "remotesyslog": + diffOpts = append(diffOpts, &remoteSyslogOption{newValue: newValue.(string)}) + case "tlsconfig": + diffOpts = append(diffOpts, &tlsOption{newValue: newValue.(*tls.Config)}) + case "tlstimeout": + diffOpts = append(diffOpts, &tlsTimeoutOption{newValue: newValue.(float64)}) + case "username": + diffOpts = append(diffOpts, &usernameOption{}) + case "password": + diffOpts = append(diffOpts, &passwordOption{}) + case "authorization": + diffOpts = append(diffOpts, &authorizationOption{}) + case "authtimeout": + diffOpts = append(diffOpts, &authTimeoutOption{newValue: newValue.(float64)}) + case "users": + diffOpts = append(diffOpts, &usersOption{newValue: newValue.([]*User)}) + case "cluster": + newClusterOpts := newValue.(ClusterOpts) + if err := validateClusterOpts(oldValue.(ClusterOpts), newClusterOpts); err != nil { + return nil, err + } + diffOpts = append(diffOpts, &clusterOption{newValue: newClusterOpts}) + case "routes": + add, remove := diffRoutes(oldValue.([]*url.URL), newValue.([]*url.URL)) + diffOpts = append(diffOpts, &routesOption{add: add, remove: remove}) + case "maxconn": + diffOpts = append(diffOpts, &maxConnOption{newValue: newValue.(int)}) + case "pidfile": + diffOpts = append(diffOpts, &pidFileOption{newValue: newValue.(string)}) + case "portsfiledir": + diffOpts = append(diffOpts, &portsFileDirOption{newValue: newValue.(string), oldValue: oldValue.(string)}) + case "maxcontrolline": + diffOpts = append(diffOpts, &maxControlLineOption{newValue: newValue.(int)}) + case "maxpayload": + diffOpts = append(diffOpts, &maxPayloadOption{newValue: newValue.(int)}) + case "pinginterval": + diffOpts = append(diffOpts, &pingIntervalOption{newValue: newValue.(time.Duration)}) + case "maxpingsout": + diffOpts = append(diffOpts, &maxPingsOutOption{newValue: newValue.(int)}) + case "writedeadline": + diffOpts = append(diffOpts, &writeDeadlineOption{newValue: newValue.(time.Duration)}) + case "clientadvertise": + cliAdv := newValue.(string) + if cliAdv != "" { + // Validate ClientAdvertise syntax + if _, _, err := parseHostPort(cliAdv, 0); err != nil { + return nil, fmt.Errorf("invalid ClientAdvertise value of %s, err=%v", cliAdv, err) + } + } + diffOpts = append(diffOpts, &clientAdvertiseOption{newValue: cliAdv}) + case "nolog", "nosigs": + // Ignore NoLog and NoSigs options since they are not parsed and only used in + // testing. + continue + case "port": + // check to see if newValue == 0 and continue if so. + if newValue == 0 { + // ignore RANDOM_PORT + continue + } + fallthrough + default: + // Bail out if attempting to reload any unsupported options. + return nil, fmt.Errorf("Config reload not supported for %s: old=%v, new=%v", + field.Name, oldValue, newValue) + } + } + + return diffOpts, nil +} + +func (s *Server) applyOptions(opts []option) { + var ( + reloadLogging = false + reloadAuth = false + ) + for _, opt := range opts { + opt.Apply(s) + if opt.IsLoggingChange() { + reloadLogging = true + } + if opt.IsAuthChange() { + reloadAuth = true + } + } + + if reloadLogging { + s.ConfigureLogger() + } + if reloadAuth { + s.reloadAuthorization() + } + + s.Noticef("Reloaded server configuration") +} + +// reloadAuthorization reconfigures the server authorization settings, +// disconnects any clients who are no longer authorized, and removes any +// unauthorized subscriptions. +func (s *Server) reloadAuthorization() { + s.mu.Lock() + s.configureAuthorization() + clients := make(map[uint64]*client, len(s.clients)) + for i, client := range s.clients { + clients[i] = client + } + routes := make(map[uint64]*client, len(s.routes)) + for i, route := range s.routes { + routes[i] = route + } + s.mu.Unlock() + + for _, client := range clients { + // Disconnect any unauthorized clients. + if !s.isClientAuthorized(client) { + client.authViolation() + continue + } + + // Remove any unauthorized subscriptions. + s.removeUnauthorizedSubs(client) + } + + for _, client := range routes { + // Disconnect any unauthorized routes. + if !s.isRouterAuthorized(client) { + client.setRouteNoReconnectOnClose() + client.authViolation() + } + } +} + +// validateClusterOpts ensures the new ClusterOpts does not change host or +// port, which do not support reload. +func validateClusterOpts(old, new ClusterOpts) error { + if old.Host != new.Host { + return fmt.Errorf("Config reload not supported for cluster host: old=%s, new=%s", + old.Host, new.Host) + } + if old.Port != new.Port { + return fmt.Errorf("Config reload not supported for cluster port: old=%d, new=%d", + old.Port, new.Port) + } + // Validate Cluster.Advertise syntax + if new.Advertise != "" { + if _, _, err := parseHostPort(new.Advertise, 0); err != nil { + return fmt.Errorf("invalid Cluster.Advertise value of %s, err=%v", new.Advertise, err) + } + } + return nil +} + +// diffRoutes diffs the old routes and the new routes and returns the ones that +// should be added and removed from the server. +func diffRoutes(old, new []*url.URL) (add, remove []*url.URL) { + // Find routes to remove. +removeLoop: + for _, oldRoute := range old { + for _, newRoute := range new { + if oldRoute == newRoute { + continue removeLoop + } + } + remove = append(remove, oldRoute) + } + + // Find routes to add. +addLoop: + for _, newRoute := range new { + for _, oldRoute := range old { + if oldRoute == newRoute { + continue addLoop + } + } + add = append(add, newRoute) + } + + return add, remove +} diff --git a/vendor/github.com/nats-io/gnatsd/server/reload_test.go b/vendor/github.com/nats-io/gnatsd/server/reload_test.go new file mode 100644 index 00000000..eda457a0 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/reload_test.go @@ -0,0 +1,1901 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +// Ensure Reload returns an error when attempting to reload a server that did +// not start with a config file. +func TestConfigReloadNoConfigFile(t *testing.T) { + server := New(&Options{NoSigs: true}) + loaded := server.ConfigTime() + if server.Reload() == nil { + t.Fatal("Expected Reload to return an error") + } + if reloaded := server.ConfigTime(); reloaded != loaded { + t.Fatalf("ConfigTime is incorrect.\nexpected: %s\ngot: %s", loaded, reloaded) + } +} + +// Ensure Reload returns an error when attempting to change an option which +// does not support reloading. +func TestConfigReloadUnsupported(t *testing.T) { + server, opts, config := newServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/test.conf") + defer os.Remove(config) + defer server.Shutdown() + + loaded := server.ConfigTime() + + golden := &Options{ + ConfigFile: config, + Host: "0.0.0.0", + Port: 2233, + AuthTimeout: 1.0, + Debug: false, + Trace: false, + Logtime: false, + MaxControlLine: 1024, + MaxPayload: 1048576, + MaxConn: 65536, + PingInterval: 2 * time.Minute, + MaxPingsOut: 2, + WriteDeadline: 2 * time.Second, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: -1, + }, + NoSigs: true, + } + processOptions(golden) + + if !reflect.DeepEqual(golden, server.getOpts()) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + + // Change config file to bad config by replacing symlink. + createSymlink(t, config, "./configs/reload/reload_unsupported.conf") + + // This should fail because `cluster` host cannot be changed. + if err := server.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } + + // Ensure config didn't change. + if !reflect.DeepEqual(golden, server.getOpts()) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + + if reloaded := server.ConfigTime(); reloaded != loaded { + t.Fatalf("ConfigTime is incorrect.\nexpected: %s\ngot: %s", loaded, reloaded) + } +} + +// This checks that if we change an option that does not support hot-swapping +// we get an error. Using `listen` for now (test may need to be updated if +// server is changed to support change of listen spec). +func TestConfigReloadUnsupportedHotSwapping(t *testing.T) { + orgConfig := "tmp_a.conf" + newConfig := "tmp_b.conf" + defer os.Remove(orgConfig) + defer os.Remove(newConfig) + if err := ioutil.WriteFile(orgConfig, []byte("listen: 127.0.0.1:-1"), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + if err := ioutil.WriteFile(newConfig, []byte("listen: 127.0.0.1:9999"), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + + server, _, config := newServerWithSymlinkConfig(t, "tmp.conf", orgConfig) + defer os.Remove(config) + defer server.Shutdown() + + loaded := server.ConfigTime() + + time.Sleep(time.Millisecond) + + // Change config file with unsupported option hot-swap + createSymlink(t, config, newConfig) + + // This should fail because `listen` host cannot be changed. + if err := server.Reload(); err == nil || !strings.Contains(err.Error(), "not supported") { + t.Fatalf("Expected Reload to return a not supported error, got %v", err) + } + + if reloaded := server.ConfigTime(); reloaded != loaded { + t.Fatalf("ConfigTime is incorrect.\nexpected: %s\ngot: %s", loaded, reloaded) + } +} + +// Ensure Reload returns an error when reloading from a bad config file. +func TestConfigReloadInvalidConfig(t *testing.T) { + server, opts, config := newServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/test.conf") + defer os.Remove(config) + defer server.Shutdown() + + loaded := server.ConfigTime() + + golden := &Options{ + ConfigFile: config, + Host: "0.0.0.0", + Port: 2233, + AuthTimeout: 1.0, + Debug: false, + Trace: false, + Logtime: false, + MaxControlLine: 1024, + MaxPayload: 1048576, + MaxConn: 65536, + PingInterval: 2 * time.Minute, + MaxPingsOut: 2, + WriteDeadline: 2 * time.Second, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: -1, + }, + NoSigs: true, + } + processOptions(golden) + + if !reflect.DeepEqual(golden, server.getOpts()) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + + // Change config file to bad config by replacing symlink. + createSymlink(t, config, "./configs/reload/invalid.conf") + + // This should fail because the new config should not parse. + if err := server.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } + + // Ensure config didn't change. + if !reflect.DeepEqual(golden, server.getOpts()) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + + if reloaded := server.ConfigTime(); reloaded != loaded { + t.Fatalf("ConfigTime is incorrect.\nexpected: %s\ngot: %s", loaded, reloaded) + } +} + +// Ensure Reload returns nil and the config is changed on success. +func TestConfigReload(t *testing.T) { + var content []byte + if runtime.GOOS != "windows" { + content = []byte(` + remote_syslog: "udp://127.0.0.1:514" # change on reload + log_file: "/tmp/gnatsd-2.log" # change on reload + `) + } + platformConf := "platform.conf" + defer os.Remove(platformConf) + if err := ioutil.WriteFile(platformConf, content, 0666); err != nil { + t.Fatalf("Unable to write config file: %v", err) + } + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/test.conf") + defer os.Remove(config) + defer server.Shutdown() + + loaded := server.ConfigTime() + + golden := &Options{ + ConfigFile: config, + Host: "0.0.0.0", + Port: 2233, + AuthTimeout: 1.0, + Debug: false, + Trace: false, + NoLog: true, + Logtime: false, + MaxControlLine: 1024, + MaxPayload: 1048576, + MaxConn: 65536, + PingInterval: 2 * time.Minute, + MaxPingsOut: 2, + WriteDeadline: 2 * time.Second, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: server.ClusterAddr().Port, + }, + NoSigs: true, + } + processOptions(golden) + + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } + + // Change config file to new config by replacing symlink. + createSymlink(t, config, "./configs/reload/reload.conf") + + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure config changed. + updated := server.getOpts() + if !updated.Trace { + t.Fatal("Expected Trace to be true") + } + if !updated.Debug { + t.Fatal("Expected Debug to be true") + } + if !updated.Logtime { + t.Fatal("Expected Logtime to be true") + } + if !updated.Syslog { + t.Fatal("Expected Syslog to be true") + } + if runtime.GOOS != "windows" { + if updated.RemoteSyslog != "udp://127.0.0.1:514" { + t.Fatalf("RemoteSyslog is incorrect.\nexpected: udp://127.0.0.1:514\ngot: %s", updated.RemoteSyslog) + } + if updated.LogFile != "/tmp/gnatsd-2.log" { + t.Fatalf("LogFile is incorrect.\nexpected: /tmp/gnatsd-2.log\ngot: %s", updated.LogFile) + } + } + if updated.TLSConfig == nil { + t.Fatal("Expected TLSConfig to be non-nil") + } + if !server.info.TLSRequired { + t.Fatal("Expected TLSRequired to be true") + } + if !server.info.TLSVerify { + t.Fatal("Expected TLSVerify to be true") + } + if updated.Username != "tyler" { + t.Fatalf("Username is incorrect.\nexpected: tyler\ngot: %s", updated.Username) + } + if updated.Password != "T0pS3cr3t" { + t.Fatalf("Password is incorrect.\nexpected: T0pS3cr3t\ngot: %s", updated.Password) + } + if updated.AuthTimeout != 2 { + t.Fatalf("AuthTimeout is incorrect.\nexpected: 2\ngot: %f", updated.AuthTimeout) + } + if !server.info.AuthRequired { + t.Fatal("Expected AuthRequired to be true") + } + if !updated.Cluster.NoAdvertise { + t.Fatal("Expected NoAdvertise to be true") + } + if updated.PidFile != "/tmp/gnatsd.pid" { + t.Fatalf("PidFile is incorrect.\nexpected: /tmp/gnatsd.pid\ngot: %s", updated.PidFile) + } + if updated.MaxControlLine != 512 { + t.Fatalf("MaxControlLine is incorrect.\nexpected: 512\ngot: %d", updated.MaxControlLine) + } + if updated.PingInterval != 5*time.Second { + t.Fatalf("PingInterval is incorrect.\nexpected 5s\ngot: %s", updated.PingInterval) + } + if updated.MaxPingsOut != 1 { + t.Fatalf("MaxPingsOut is incorrect.\nexpected 1\ngot: %d", updated.MaxPingsOut) + } + if updated.WriteDeadline != 3*time.Second { + t.Fatalf("WriteDeadline is incorrect.\nexpected 3s\ngot: %s", updated.WriteDeadline) + } + if updated.MaxPayload != 1024 { + t.Fatalf("MaxPayload is incorrect.\nexpected 1024\ngot: %d", updated.MaxPayload) + } + + if reloaded := server.ConfigTime(); !reloaded.After(loaded) { + t.Fatalf("ConfigTime is incorrect.\nexpected greater than: %s\ngot: %s", loaded, reloaded) + } +} + +// Ensure Reload supports TLS config changes. Test this by starting a server +// with TLS enabled, connect to it to verify, reload config using a different +// key pair and client verification enabled, ensure reconnect fails, then +// ensure reconnect succeeds when the client provides a cert. +func TestConfigReloadRotateTLS(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/tls_test.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + + nc, err := nats.Connect(addr, nats.Secure()) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + + // Rotate cert and enable client verification. + createSymlink(t, config, "./configs/reload/tls_verify_test.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr, nats.Secure()); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when client presents cert. + cert := nats.ClientCert("./configs/certs/cert.new.pem", "./configs/certs/key.new.pem") + conn, err := nats.Connect(addr, cert, nats.RootCAs("./configs/certs/cert.new.pem")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the original connection can still publish/receive. + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc.Flush() + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving msg: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } +} + +// Ensure Reload supports enabling TLS. Test this by starting a server without +// TLS enabled, connect to it to verify, reload config with TLS enabled, ensure +// reconnect fails, then ensure reconnect succeeds when using secure. +func TestConfigReloadEnableTLS(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + nc.Close() + + // Enable TLS. + createSymlink(t, config, "./configs/reload/tls_test.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using secure. + nc, err = nats.Connect(addr, nats.Secure()) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + nc.Close() +} + +// Ensure Reload supports disabling TLS. Test this by starting a server with +// TLS enabled, connect to it to verify, reload config with TLS disabled, +// ensure reconnect fails, then ensure reconnect succeeds when connecting +// without secure. +func TestConfigReloadDisableTLS(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/tls_test.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + nc, err := nats.Connect(addr, nats.Secure()) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + nc.Close() + + // Disable TLS. + createSymlink(t, config, "./configs/reload/basic.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr, nats.Secure()); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when not using secure. + nc, err = nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + nc.Close() +} + +// Ensure Reload supports single user authentication config changes. Test this +// by starting a server with authentication enabled, connect to it to verify, +// reload config using a different username/password, ensure reconnect fails, +// then ensure reconnect succeeds when using the correct credentials. +func TestConfigReloadRotateUserAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", + "./configs/reload/single_user_authentication_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.UserInfo("tyler", "T0pS3cr3t")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + disconnected := make(chan struct{}) + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + nc.SetDisconnectHandler(func(*nats.Conn) { + disconnected <- struct{}{} + }) + + // Change user credentials. + createSymlink(t, config, "./configs/reload/single_user_authentication_2.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr, nats.UserInfo("tyler", "T0pS3cr3t")); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.UserInfo("derek", "passw0rd")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(2 * time.Second): + t.Fatal("Expected connection to be disconnected") + } +} + +// Ensure Reload supports enabling single user authentication. Test this by +// starting a server with authentication disabled, connect to it to verify, +// reload config using with a username/password, ensure reconnect fails, then +// ensure reconnect succeeds when using the correct credentials. +func TestConfigReloadEnableUserAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + disconnected := make(chan struct{}) + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + nc.SetDisconnectHandler(func(*nats.Conn) { + disconnected <- struct{}{} + }) + + // Enable authentication. + createSymlink(t, config, "./configs/reload/single_user_authentication_1.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.UserInfo("tyler", "T0pS3cr3t")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(2 * time.Second): + t.Fatal("Expected connection to be disconnected") + } +} + +// Ensure Reload supports disabling single user authentication. Test this by +// starting a server with authentication enabled, connect to it to verify, +// reload config using with authentication disabled, then ensure connecting +// with no credentials succeeds. +func TestConfigReloadDisableUserAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", + "./configs/reload/single_user_authentication_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.UserInfo("tyler", "T0pS3cr3t")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + t.Fatalf("Client received an unexpected error: %v", err) + }) + + // Disable authentication. + createSymlink(t, config, "./configs/reload/basic.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting succeeds with no credentials. + conn, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() +} + +// Ensure Reload supports token authentication config changes. Test this by +// starting a server with token authentication enabled, connect to it to +// verify, reload config using a different token, ensure reconnect fails, then +// ensure reconnect succeeds when using the correct token. +func TestConfigReloadRotateTokenAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/token_authentication_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + disconnected := make(chan struct{}) + asyncErr := make(chan error) + eh := func(nc *nats.Conn, sub *nats.Subscription, err error) { asyncErr <- err } + dh := func(*nats.Conn) { disconnected <- struct{}{} } + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.Token("T0pS3cr3t"), nats.ErrorHandler(eh), nats.DisconnectHandler(dh)) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + + // Change authentication token. + createSymlink(t, config, "./configs/reload/token_authentication_2.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr, nats.Token("T0pS3cr3t")); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.Token("passw0rd")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(2 * time.Second): + t.Fatal("Expected connection to be disconnected") + } +} + +// Ensure Reload supports enabling token authentication. Test this by starting +// a server with authentication disabled, connect to it to verify, reload +// config using with a token, ensure reconnect fails, then ensure reconnect +// succeeds when using the correct token. +func TestConfigReloadEnableTokenAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + disconnected := make(chan struct{}) + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + nc.SetDisconnectHandler(func(*nats.Conn) { + disconnected <- struct{}{} + }) + + // Enable authentication. + createSymlink(t, config, "./configs/reload/token_authentication_1.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.Token("T0pS3cr3t")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(2 * time.Second): + t.Fatal("Expected connection to be disconnected") + } +} + +// Ensure Reload supports disabling single token authentication. Test this by +// starting a server with authentication enabled, connect to it to verify, +// reload config using with authentication disabled, then ensure connecting +// with no token succeeds. +func TestConfigReloadDisableTokenAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/token_authentication_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.Token("T0pS3cr3t")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + t.Fatalf("Client received an unexpected error: %v", err) + }) + + // Disable authentication. + createSymlink(t, config, "./configs/reload/basic.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting succeeds with no credentials. + conn, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() +} + +// Ensure Reload supports users authentication config changes. Test this by +// starting a server with users authentication enabled, connect to it to +// verify, reload config using a different user, ensure reconnect fails, then +// ensure reconnect succeeds when using the correct credentials. +func TestConfigReloadRotateUsersAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/multiple_users_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.UserInfo("alice", "foo")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + disconnected := make(chan struct{}) + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + nc.SetDisconnectHandler(func(*nats.Conn) { + disconnected <- struct{}{} + }) + + // These credentials won't change. + nc2, err := nats.Connect(addr, nats.UserInfo("bob", "bar")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc2.Close() + sub, err := nc2.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + + // Change users credentials. + createSymlink(t, config, "./configs/reload/multiple_users_2.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr, nats.UserInfo("alice", "foo")); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.UserInfo("alice", "baz")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(2 * time.Second): + t.Fatal("Expected connection to be disconnected") + } + + // Ensure the connection using unchanged credentials can still + // publish/receive. + if err := nc2.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc2.Flush() + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving msg: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } +} + +// Ensure Reload supports enabling users authentication. Test this by starting +// a server with authentication disabled, connect to it to verify, reload +// config using with users, ensure reconnect fails, then ensure reconnect +// succeeds when using the correct credentials. +func TestConfigReloadEnableUsersAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + disconnected := make(chan struct{}) + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + nc.SetDisconnectHandler(func(*nats.Conn) { + disconnected <- struct{}{} + }) + + // Enable authentication. + createSymlink(t, config, "./configs/reload/multiple_users_1.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting fails. + if _, err := nats.Connect(addr); err == nil { + t.Fatal("Expected connect to fail") + } + + // Ensure connecting succeeds when using new credentials. + conn, err := nats.Connect(addr, nats.UserInfo("alice", "foo")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() + + // Ensure the previous connection received an authorization error. + select { + case err := <-asyncErr: + if err != nats.ErrAuthorization { + t.Fatalf("Expected ErrAuthorization, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Expected authorization error") + } + + // Ensure the previous connection was disconnected. + select { + case <-disconnected: + case <-time.After(5 * time.Second): + t.Fatal("Expected connection to be disconnected") + } +} + +// Ensure Reload supports disabling users authentication. Test this by starting +// a server with authentication enabled, connect to it to verify, +// reload config using with authentication disabled, then ensure connecting +// with no credentials succeeds. +func TestConfigReloadDisableUsersAuthentication(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/multiple_users_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Ensure we can connect as a sanity check. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.UserInfo("alice", "foo")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + t.Fatalf("Client received an unexpected error: %v", err) + }) + + // Disable authentication. + createSymlink(t, config, "./configs/reload/basic.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure connecting succeeds with no credentials. + conn, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + conn.Close() +} + +// Ensure Reload supports changing permissions. Test this by starting a server +// with a user configured with certain permissions, test publish and subscribe, +// reload config with new permissions, ensure the previous subscription was +// closed and publishes fail, then ensure the new permissions succeed. +func TestConfigReloadChangePermissions(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/authorization_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr, nats.UserInfo("bob", "bar")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + asyncErr := make(chan error) + nc.SetErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + asyncErr <- err + }) + // Ensure we can publish and receive messages as a sanity check. + sub, err := nc.SubscribeSync("_INBOX.>") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + nc.Flush() + + conn, err := nats.Connect(addr, nats.UserInfo("alice", "foo")) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer conn.Close() + + sub2, err := conn.SubscribeSync("req.foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + if err := conn.Publish("_INBOX.foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing message: %v", err) + } + conn.Flush() + + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving msg: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } + + if err := nc.Publish("req.foo", []byte("world")); err != nil { + t.Fatalf("Error publishing message: %v", err) + } + nc.Flush() + + msg, err = sub2.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving msg: %v", err) + } + if string(msg.Data) != "world" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("world"), msg.Data) + } + + // Change permissions. + createSymlink(t, config, "./configs/reload/authorization_2.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure we receive an error for the subscription that is no longer + // authorized. + select { + case err := <-asyncErr: + if !strings.Contains(err.Error(), "permissions violation for subscription to \"_inbox.>\"") { + t.Fatalf("Expected permissions violation error, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Expected permissions violation error") + } + + // Ensure we receive an error when publishing to req.foo and we no longer + // receive messages on _INBOX.>. + if err := nc.Publish("req.foo", []byte("hola")); err != nil { + t.Fatalf("Error publishing message: %v", err) + } + nc.Flush() + if err := conn.Publish("_INBOX.foo", []byte("mundo")); err != nil { + t.Fatalf("Error publishing message: %v", err) + } + conn.Flush() + + select { + case err := <-asyncErr: + if !strings.Contains(err.Error(), "permissions violation for publish to \"req.foo\"") { + t.Fatalf("Expected permissions violation error, got %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("Expected permissions violation error") + } + + queued, _, err := sub2.Pending() + if err != nil { + t.Fatalf("Failed to get pending messaged: %v", err) + } + if queued != 0 { + t.Fatalf("Pending is incorrect.\nexpected: 0\ngot: %d", queued) + } + + queued, _, err = sub.Pending() + if err != nil { + t.Fatalf("Failed to get pending messaged: %v", err) + } + if queued != 0 { + t.Fatalf("Pending is incorrect.\nexpected: 0\ngot: %d", queued) + } + + // Ensure we can publish to _INBOX.foo.bar and subscribe to _INBOX.foo.>. + sub, err = nc.SubscribeSync("_INBOX.foo.>") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + nc.Flush() + if err := nc.Publish("_INBOX.foo.bar", []byte("testing")); err != nil { + t.Fatalf("Error publishing message: %v", err) + } + nc.Flush() + msg, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving msg: %v", err) + } + if string(msg.Data) != "testing" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("testing"), msg.Data) + } + + select { + case err := <-asyncErr: + t.Fatalf("Received unexpected error: %v", err) + default: + } +} + +// Ensure Reload returns an error when attempting to change cluster address +// host. +func TestConfigReloadClusterHostUnsupported(t *testing.T) { + server, _, config := newServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/srv_a_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Attempt to change cluster listen host. + createSymlink(t, config, "./configs/reload/srv_c_1.conf") + + // This should fail because cluster address cannot be changed. + if err := server.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } +} + +// Ensure Reload returns an error when attempting to change cluster address +// port. +func TestConfigReloadClusterPortUnsupported(t *testing.T) { + server, _, config := newServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/srv_a_1.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Attempt to change cluster listen port. + createSymlink(t, config, "./configs/reload/srv_b_1.conf") + + // This should fail because cluster address cannot be changed. + if err := server.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } +} + +// Ensure Reload supports enabling route authorization. Test this by starting +// two servers in a cluster without authorization, ensuring messages flow +// between them, then reloading with authorization and ensuring messages no +// longer flow until reloading with the correct credentials. +func TestConfigReloadEnableClusterAuthorization(t *testing.T) { + srvb, srvbOpts, srvbConfig := runServerWithSymlinkConfig(t, "tmp_b.conf", "./configs/reload/srv_b_1.conf") + defer os.Remove(srvbConfig) + defer srvb.Shutdown() + + srva, srvaOpts, srvaConfig := runServerWithSymlinkConfig(t, "tmp_a.conf", "./configs/reload/srv_a_1.conf") + defer os.Remove(srvaConfig) + defer srva.Shutdown() + + checkClusterFormed(t, srva, srvb) + + srvaAddr := fmt.Sprintf("nats://%s:%d", srvaOpts.Host, srvaOpts.Port) + srvaConn, err := nats.Connect(srvaAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvaConn.Close() + sub, err := srvaConn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + if err := srvaConn.Flush(); err != nil { + t.Fatalf("Error flushing: %v", err) + } + + srvbAddr := fmt.Sprintf("nats://%s:%d", srvbOpts.Host, srvbOpts.Port) + srvbConn, err := nats.Connect(srvbAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvbConn.Close() + + if numRoutes := srvb.NumRoutes(); numRoutes != 1 { + t.Fatalf("Expected 1 route, got %d", numRoutes) + } + + // Ensure messages flow through the cluster as a sanity check. + if err := srvbConn.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } + + // Enable route authorization. + createSymlink(t, srvbConfig, "./configs/reload/srv_b_2.conf") + if err := srvb.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + if numRoutes := srvb.NumRoutes(); numRoutes != 0 { + t.Fatalf("Expected 0 routes, got %d", numRoutes) + } + + // Ensure messages no longer flow through the cluster. + for i := 0; i < 5; i++ { + if err := srvbConn.Publish("foo", []byte("world")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + } + if _, err := sub.NextMsg(50 * time.Millisecond); err != nats.ErrTimeout { + t.Fatalf("Expected ErrTimeout, got %v", err) + } + + // Reload Server A with correct route credentials. + createSymlink(t, srvaConfig, "./configs/reload/srv_a_2.conf") + defer os.Remove(srvaConfig) + if err := srva.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + checkClusterFormed(t, srva, srvb) + + if numRoutes := srvb.NumRoutes(); numRoutes != 1 { + t.Fatalf("Expected 1 route, got %d", numRoutes) + } + + // Ensure messages flow through the cluster now. + if err := srvbConn.Publish("foo", []byte("hola")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hola" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hola"), msg.Data) + } +} + +// Ensure Reload supports disabling route authorization. Test this by starting +// two servers in a cluster with authorization, ensuring messages flow +// between them, then reloading without authorization and ensuring messages +// still flow. +func TestConfigReloadDisableClusterAuthorization(t *testing.T) { + srvb, srvbOpts, srvbConfig := runServerWithSymlinkConfig(t, "tmp_b.conf", "./configs/reload/srv_b_2.conf") + defer os.Remove(srvbConfig) + defer srvb.Shutdown() + + srva, srvaOpts, srvaConfig := runServerWithSymlinkConfig(t, "tmp_a.conf", "./configs/reload/srv_a_2.conf") + defer os.Remove(srvaConfig) + defer srva.Shutdown() + + checkClusterFormed(t, srva, srvb) + + srvaAddr := fmt.Sprintf("nats://%s:%d", srvaOpts.Host, srvaOpts.Port) + srvaConn, err := nats.Connect(srvaAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvaConn.Close() + + sub, err := srvaConn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + if err := srvaConn.Flush(); err != nil { + t.Fatalf("Error flushing: %v", err) + } + + srvbAddr := fmt.Sprintf("nats://%s:%d", srvbOpts.Host, srvbOpts.Port) + srvbConn, err := nats.Connect(srvbAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvbConn.Close() + + if numRoutes := srvb.NumRoutes(); numRoutes != 1 { + t.Fatalf("Expected 1 route, got %d", numRoutes) + } + + // Ensure messages flow through the cluster as a sanity check. + if err := srvbConn.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } + + // Disable route authorization. + createSymlink(t, srvbConfig, "./configs/reload/srv_b_1.conf") + if err := srvb.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + checkClusterFormed(t, srva, srvb) + + if numRoutes := srvb.NumRoutes(); numRoutes != 1 { + t.Fatalf("Expected 1 route, got %d", numRoutes) + } + + // Ensure messages still flow through the cluster. + if err := srvbConn.Publish("foo", []byte("hola")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hola" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hola"), msg.Data) + } +} + +// Ensure Reload supports changing cluster routes. Test this by starting +// two servers in a cluster, ensuring messages flow between them, then +// reloading with a different route and ensuring messages flow through the new +// cluster. +func TestConfigReloadClusterRoutes(t *testing.T) { + srvb, srvbOpts, srvbConfig := runServerWithSymlinkConfig(t, "tmp_b.conf", "./configs/reload/srv_b_1.conf") + defer os.Remove(srvbConfig) + defer srvb.Shutdown() + + srva, srvaOpts, srvaConfig := runServerWithSymlinkConfig(t, "tmp_a.conf", "./configs/reload/srv_a_1.conf") + defer os.Remove(srvaConfig) + defer srva.Shutdown() + + checkClusterFormed(t, srva, srvb) + + srvcOpts, err := ProcessConfigFile("./configs/reload/srv_c_1.conf") + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + srvcOpts.NoLog = true + srvcOpts.NoSigs = true + + srvc := RunServer(srvcOpts) + defer srvc.Shutdown() + + srvaAddr := fmt.Sprintf("nats://%s:%d", srvaOpts.Host, srvaOpts.Port) + srvaConn, err := nats.Connect(srvaAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvaConn.Close() + + sub, err := srvaConn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + if err := srvaConn.Flush(); err != nil { + t.Fatalf("Error flushing: %v", err) + } + + srvbAddr := fmt.Sprintf("nats://%s:%d", srvbOpts.Host, srvbOpts.Port) + srvbConn, err := nats.Connect(srvbAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvbConn.Close() + + if numRoutes := srvb.NumRoutes(); numRoutes != 1 { + t.Fatalf("Expected 1 route, got %d", numRoutes) + } + + // Ensure messages flow through the cluster as a sanity check. + if err := srvbConn.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err := sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } + + // Reload cluster routes. + createSymlink(t, srvaConfig, "./configs/reload/srv_a_3.conf") + if err := srva.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Kill old route server. + srvbConn.Close() + srvb.Shutdown() + + checkClusterFormed(t, srva, srvc) + + srvcAddr := fmt.Sprintf("nats://%s:%d", srvcOpts.Host, srvcOpts.Port) + srvcConn, err := nats.Connect(srvcAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvcConn.Close() + + // Ensure messages flow through the new cluster. + for i := 0; i < 5; i++ { + if err := srvcConn.Publish("foo", []byte("hola")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvcConn.Flush() + } + msg, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hola" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hola"), msg.Data) + } +} + +// Ensure Reload supports removing a solicited route. In this case from A->B +// Test this by starting two servers in a cluster, ensuring messages flow between them. +// Then stop server B, and have server A continue to try to connect. Reload A with a config +// that removes the route and make sure it does not connect to server B when its restarted. +func TestConfigReloadClusterRemoveSolicitedRoutes(t *testing.T) { + srvb, srvbOpts := RunServerWithConfig("./configs/reload/srv_b_1.conf") + defer srvb.Shutdown() + + srva, srvaOpts, srvaConfig := runServerWithSymlinkConfig(t, "tmp_a.conf", "./configs/reload/srv_a_1.conf") + defer os.Remove(srvaConfig) + defer srva.Shutdown() + + checkClusterFormed(t, srva, srvb) + + srvaAddr := fmt.Sprintf("nats://%s:%d", srvaOpts.Host, srvaOpts.Port) + srvaConn, err := nats.Connect(srvaAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvaConn.Close() + sub, err := srvaConn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + if err := srvaConn.Flush(); err != nil { + t.Fatalf("Error flushing: %v", err) + } + + srvbAddr := fmt.Sprintf("nats://%s:%d", srvbOpts.Host, srvbOpts.Port) + srvbConn, err := nats.Connect(srvbAddr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer srvbConn.Close() + + if err := srvbConn.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + srvbConn.Flush() + msg, err := sub.NextMsg(5 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + if string(msg.Data) != "hello" { + t.Fatalf("Msg is incorrect.\nexpected: %+v\ngot: %+v", []byte("hello"), msg.Data) + } + + // Now stop server B. + srvb.Shutdown() + + // Wait til route is dropped. + checkNumRoutes(t, srva, 0) + + // Now change config for server A to not solicit a route to server B. + createSymlink(t, srvaConfig, "./configs/reload/srv_a_4.conf") + defer os.Remove(srvaConfig) + if err := srva.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Restart server B. + srvb, _ = RunServerWithConfig("./configs/reload/srv_b_1.conf") + defer srvb.Shutdown() + + // We should not have a cluster formed here. + numRoutes := 0 + deadline := time.Now().Add(2 * DEFAULT_ROUTE_RECONNECT) + for time.Now().Before(deadline) { + if numRoutes = srva.NumRoutes(); numRoutes != 0 { + break + } else { + time.Sleep(100 * time.Millisecond) + } + } + if numRoutes != 0 { + t.Fatalf("Expected 0 routes for server A, got %d", numRoutes) + } +} + +func reloadUpdateConfig(t *testing.T, s *Server, conf, content string) { + if err := ioutil.WriteFile(conf, []byte(content), 0666); err != nil { + stackFatalf(t, "Error creating config file: %v", err) + } + if err := s.Reload(); err != nil { + stackFatalf(t, "Error on reload: %v", err) + } +} + +func TestConfigReloadClusterAdvertise(t *testing.T) { + conf := "routeadv.conf" + if err := ioutil.WriteFile(conf, []byte(` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + } + `), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + orgClusterPort := s.ClusterAddr().Port + + verify := func(expectedHost string, expectedPort int, expectedIP string) { + s.mu.Lock() + routeInfo := s.routeInfo + routeInfoJSON := Info{} + err = json.Unmarshal(s.routeInfoJSON[5:], &routeInfoJSON) // Skip "INFO " + s.mu.Unlock() + if err != nil { + t.Fatalf("Error on Unmarshal: %v", err) + } + if routeInfo.Host != expectedHost || routeInfo.Port != expectedPort || routeInfo.IP != expectedIP { + t.Fatalf("Expected host/port/IP to be %s:%v, %q, got %s:%d, %q", + expectedHost, expectedPort, expectedIP, routeInfo.Host, routeInfo.Port, routeInfo.IP) + } + // Check that server routeInfoJSON was updated too + if !reflect.DeepEqual(routeInfo, routeInfoJSON) { + t.Fatalf("Expected routeInfoJSON to be %+v, got %+v", routeInfo, routeInfoJSON) + } + } + + // Update config with cluster_advertise + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me:1" + } + `) + verify("me", 1, "nats-route://me:1/") + + // Update config with cluster_advertise (no port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me" + } + `) + verify("me", orgClusterPort, fmt.Sprintf("nats-route://me:%d/", orgClusterPort)) + + // Update config with cluster_advertise (-1 port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + cluster_advertise: "me:-1" + } + `) + verify("me", orgClusterPort, fmt.Sprintf("nats-route://me:%d/", orgClusterPort)) + + // Update to remove cluster_advertise + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + cluster: { + listen: "0.0.0.0:-1" + } + `) + verify("0.0.0.0", orgClusterPort, "") +} + +func TestConfigReloadClusterNoAdvertise(t *testing.T) { + conf := "routeadv.conf" + if err := ioutil.WriteFile(conf, []byte(` + listen: "0.0.0.0:-1" + client_advertise: "me:1" + cluster: { + listen: "0.0.0.0:-1" + } + `), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + s.mu.Lock() + ccurls := s.routeInfo.ClientConnectURLs + s.mu.Unlock() + if len(ccurls) != 1 && ccurls[0] != "me:1" { + t.Fatalf("Unexpected routeInfo.ClientConnectURLS: %v", ccurls) + } + + // Update config with no_advertise + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + client_advertise: "me:1" + cluster: { + listen: "0.0.0.0:-1" + no_advertise: true + } + `) + + s.mu.Lock() + ccurls = s.routeInfo.ClientConnectURLs + s.mu.Unlock() + if len(ccurls) != 0 { + t.Fatalf("Unexpected routeInfo.ClientConnectURLS: %v", ccurls) + } + + // Update config with cluster_advertise (no port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + client_advertise: "me:1" + cluster: { + listen: "0.0.0.0:-1" + } + `) + s.mu.Lock() + ccurls = s.routeInfo.ClientConnectURLs + s.mu.Unlock() + if len(ccurls) != 1 && ccurls[0] != "me:1" { + t.Fatalf("Unexpected routeInfo.ClientConnectURLS: %v", ccurls) + } +} + +func TestConfigReloadMaxSubsUnsupported(t *testing.T) { + conf := "maxsubs.conf" + if err := ioutil.WriteFile(conf, []byte(`max_subs: 1`), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + stackFatalf(t, "Error processing config file: %v", err) + } + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + if err := ioutil.WriteFile(conf, []byte(`max_subs: 10`), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + if err := s.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } +} + +func TestConfigReloadClientAdvertise(t *testing.T) { + conf := "clientadv.conf" + if err := ioutil.WriteFile(conf, []byte(`listen: "0.0.0.0:-1"`), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + stackFatalf(t, "Error processing config file: %v", err) + } + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + orgPort := s.Addr().(*net.TCPAddr).Port + + verify := func(expectedHost string, expectedPort int) { + s.mu.Lock() + info := s.info + s.mu.Unlock() + if info.Host != expectedHost || info.Port != expectedPort { + stackFatalf(t, "Expected host/port to be %s:%d, got %s:%d", + expectedHost, expectedPort, info.Host, info.Port) + } + } + + // Update config with ClientAdvertise (port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + client_advertise: "me:1" + `) + verify("me", 1) + + // Update config with ClientAdvertise (no port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + client_advertise: "me" + `) + verify("me", orgPort) + + // Update config with ClientAdvertise (-1 port specified) + reloadUpdateConfig(t, s, conf, ` + listen: "0.0.0.0:-1" + client_advertise: "me:-1" + `) + verify("me", orgPort) + + // Now remove ClientAdvertise to check that original values + // are restored. + reloadUpdateConfig(t, s, conf, `listen: "0.0.0.0:-1"`) + verify("0.0.0.0", orgPort) +} + +// Ensure Reload supports changing the max connections. Test this by starting a +// server with no max connections, connecting two clients, reloading with a +// max connections of one, and ensuring one client is disconnected. +func TestConfigReloadMaxConnections(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + // Make two connections. + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + nc1, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc1.Close() + closed := make(chan struct{}, 1) + nc1.SetDisconnectHandler(func(*nats.Conn) { + closed <- struct{}{} + }) + nc2, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc2.Close() + nc2.SetDisconnectHandler(func(*nats.Conn) { + closed <- struct{}{} + }) + + if numClients := server.NumClients(); numClients != 2 { + t.Fatalf("Expected 2 clients, got %d", numClients) + } + + // Set max connections to one. + createSymlink(t, config, "./configs/reload/max_connections.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure one connection was closed. + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatal("Expected to be disconnected") + } + + if numClients := server.NumClients(); numClients != 1 { + t.Fatalf("Expected 1 client, got %d", numClients) + } + + // Ensure new connections fail. + _, err = nats.Connect(addr) + if err == nil { + t.Fatal("Expected error on connect") + } +} + +// Ensure reload supports changing the max payload size. Test this by starting +// a server with the default size limit, ensuring publishes work, reloading +// with a restrictive limit, and ensuring publishing an oversized message fails +// and disconnects the client. +func TestConfigReloadMaxPayload(t *testing.T) { + server, opts, config := runServerWithSymlinkConfig(t, "tmp.conf", "./configs/reload/basic.conf") + defer os.Remove(config) + defer server.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, server.Addr().(*net.TCPAddr).Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer nc.Close() + closed := make(chan struct{}) + nc.SetDisconnectHandler(func(*nats.Conn) { + closed <- struct{}{} + }) + + conn, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v", err) + } + defer conn.Close() + sub, err := conn.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + conn.Flush() + + // Ensure we can publish as a sanity check. + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc.Flush() + _, err = sub.NextMsg(2 * time.Second) + if err != nil { + t.Fatalf("Error receiving message: %v", err) + } + + // Set max payload to one. + createSymlink(t, config, "./configs/reload/max_payload.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Ensure oversized messages don't get delivered and the client is + // disconnected. + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error publishing: %v", err) + } + nc.Flush() + _, err = sub.NextMsg(20 * time.Millisecond) + if err != nats.ErrTimeout { + t.Fatalf("Expected ErrTimeout, got: %v", err) + } + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatal("Expected to be disconnected") + } +} + +// Ensure reload supports rotating out files. Test this by starting +// a server with log and pid files, reloading new ones, then check that +// we can rename and delete the old log/pid files. +func TestConfigReloadRotateFiles(t *testing.T) { + opts, config := newOptionsWithSymlinkConfig(t, "tmp.conf", "./configs/reload/file_rotate.conf") + server := RunServer(opts) + defer func() { + os.Remove(config) + os.Remove("log1.txt") + os.Remove("gnatsd1.pid") + }() + defer server.Shutdown() + + // Configure the logger to enable actual logging + server.ConfigureLogger() + + // Load a config that renames the files. + createSymlink(t, config, "./configs/reload/file_rotate1.conf") + if err := server.Reload(); err != nil { + t.Fatalf("Error reloading config: %v", err) + } + + // Make sure the new files exist. + if _, err := os.Stat("log1.txt"); os.IsNotExist(err) { + t.Fatalf("Error reloading config, no new file: %v", err) + } + if _, err := os.Stat("gnatsd1.pid"); os.IsNotExist(err) { + t.Fatalf("Error reloading config, no new file: %v", err) + } + + // Check that old file can be renamed. + if err := os.Rename("log.txt", "log_old.txt"); err != nil { + t.Fatalf("Error reloading config, cannot rename file: %v", err) + } + if err := os.Rename("gnatsd.pid", "gnatsd_old.pid"); err != nil { + t.Fatalf("Error reloading config, cannot rename file: %v", err) + } + + // Check that the old files can be removed after rename. + if err := os.Remove("log_old.txt"); err != nil { + t.Fatalf("Error reloading config, cannot delete file: %v", err) + } + if err := os.Remove("gnatsd_old.pid"); err != nil { + t.Fatalf("Error reloading config, cannot delete file: %v", err) + } +} + +func runServerWithSymlinkConfig(t *testing.T, symlinkName, configName string) (*Server, *Options, string) { + t.Helper() + opts, config := newOptionsWithSymlinkConfig(t, symlinkName, configName) + opts.NoLog = true + opts.NoSigs = true + return RunServer(opts), opts, config +} + +func newServerWithSymlinkConfig(t *testing.T, symlinkName, configName string) (*Server, *Options, string) { + t.Helper() + opts, config := newOptionsWithSymlinkConfig(t, symlinkName, configName) + return New(opts), opts, config +} + +func newOptionsWithSymlinkConfig(t *testing.T, symlinkName, configName string) (*Options, string) { + t.Helper() + dir, err := os.Getwd() + if err != nil { + t.Fatalf("Error getting working directory: %v", err) + } + config := filepath.Join(dir, symlinkName) + createSymlink(t, config, configName) + opts, err := ProcessConfigFile(config) + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.NoSigs = true + return opts, config +} + +func createSymlink(t *testing.T, symlinkName, fileName string) { + t.Helper() + os.Remove(symlinkName) + if err := os.Symlink(fileName, symlinkName); err != nil { + t.Fatalf("Error creating symlink: %v (ensure you have privileges)", err) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/ring.go b/vendor/github.com/nats-io/gnatsd/server/ring.go new file mode 100644 index 00000000..b9232ca9 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/ring.go @@ -0,0 +1,75 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +// We wrap to hold onto optional items for /connz. +type closedClient struct { + ConnInfo + subs []string + user string +} + +// Fixed sized ringbuffer for closed connections. +type closedRingBuffer struct { + total uint64 + conns []*closedClient +} + +// Create a new ring buffer with at most max items. +func newClosedRingBuffer(max int) *closedRingBuffer { + rb := &closedRingBuffer{} + rb.conns = make([]*closedClient, max) + return rb +} + +// Adds in a new closed connection. If there is no more room, +// remove the oldest. +func (rb *closedRingBuffer) append(cc *closedClient) { + rb.conns[rb.next()] = cc + rb.total++ +} + +func (rb *closedRingBuffer) next() int { + return int(rb.total % uint64(cap(rb.conns))) +} + +func (rb *closedRingBuffer) len() int { + if rb.total > uint64(cap(rb.conns)) { + return cap(rb.conns) + } + return int(rb.total) +} + +func (rb *closedRingBuffer) totalConns() uint64 { + return rb.total +} + +// This will not be sorted. Will return a copy of the list +// which recipient can modify. If the contents of the client +// itself need to be modified, meaning swapping in any optional items, +// a copy should be made. We could introduce a new lock and hold that +// but since we return this list inside monitor which allows programatic +// access, we do not know when it would be done. +func (rb *closedRingBuffer) closedClients() []*closedClient { + dup := make([]*closedClient, rb.len()) + if rb.total <= uint64(cap(rb.conns)) { + copy(dup, rb.conns[:rb.len()]) + } else { + first := rb.next() + next := cap(rb.conns) - first + copy(dup, rb.conns[first:]) + copy(dup[next:], rb.conns[:next]) + } + return dup +} diff --git a/vendor/github.com/nats-io/gnatsd/server/route.go b/vendor/github.com/nats-io/gnatsd/server/route.go new file mode 100644 index 00000000..c12a713d --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/route.go @@ -0,0 +1,1103 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "math/rand" + "net" + "net/url" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/nats-io/gnatsd/util" +) + +// RouteType designates the router type +type RouteType int + +// Type of Route +const ( + // This route we learned from speaking to other routes. + Implicit RouteType = iota + // This route was explicitly configured. + Explicit +) + +type route struct { + remoteID string + didSolicit bool + retry bool + routeType RouteType + url *url.URL + authRequired bool + tlsRequired bool + closed bool + connectURLs []string +} + +type connectInfo struct { + Echo bool `json:"echo"` + Verbose bool `json:"verbose"` + Pedantic bool `json:"pedantic"` + User string `json:"user,omitempty"` + Pass string `json:"pass,omitempty"` + TLS bool `json:"tls_required"` + Name string `json:"name"` +} + +// Used to hold onto mappings for unsubscribed +// routed queue subscribers. +type rqsub struct { + group []byte + atime time.Time +} + +// Route protocol constants +const ( + ConProto = "CONNECT %s" + _CRLF_ + InfoProto = "INFO %s" + _CRLF_ +) + +// Clear up the timer and any map held for remote qsubs. +func (s *Server) clearRemoteQSubs() { + s.rqsMu.Lock() + defer s.rqsMu.Unlock() + if s.rqsubsTimer != nil { + s.rqsubsTimer.Stop() + s.rqsubsTimer = nil + } + s.rqsubs = nil +} + +// Check to see if we can remove any of the remote qsubs mappings +func (s *Server) purgeRemoteQSubs() { + ri := s.getOpts().RQSubsSweep + s.rqsMu.Lock() + exp := time.Now().Add(-ri) + for k, rqsub := range s.rqsubs { + if exp.After(rqsub.atime) { + delete(s.rqsubs, k) + } + } + if s.rqsubsTimer != nil { + // Reset timer. + s.rqsubsTimer = time.AfterFunc(ri, s.purgeRemoteQSubs) + } + s.rqsMu.Unlock() +} + +// Lookup a remote queue group sid. +func (s *Server) lookupRemoteQGroup(sid string) []byte { + s.rqsMu.RLock() + rqsub := s.rqsubs[sid] + s.rqsMu.RUnlock() + return rqsub.group +} + +// This will hold onto a remote queue subscriber to allow +// for mapping and handling if we get a message after the +// subscription goes away. +func (s *Server) holdRemoteQSub(sub *subscription) { + // Should not happen, but protect anyway. + if len(sub.queue) == 0 { + return + } + // Add the entry + s.rqsMu.Lock() + // Start timer if needed. + if s.rqsubsTimer == nil { + ri := s.getOpts().RQSubsSweep + s.rqsubsTimer = time.AfterFunc(ri, s.purgeRemoteQSubs) + } + // Create map if needed. + if s.rqsubs == nil { + s.rqsubs = make(map[string]rqsub) + } + group := make([]byte, len(sub.queue)) + copy(group, sub.queue) + rqsub := rqsub{group: group, atime: time.Now()} + s.rqsubs[routeSid(sub)] = rqsub + s.rqsMu.Unlock() +} + +// This is for when we receive a directed message for a queue subscriber +// that has gone away. We reroute like a new message but scope to only +// the queue subscribers that it was originally intended for. We will +// prefer local clients, but will bounce to another route if needed. +func (c *client) reRouteQMsg(r *SublistResult, msgh, msg, group []byte) { + c.Debugf("Attempting redelivery of message for absent queue subscriber on group '%q'", group) + + // We only care about qsubs here. Data structure not setup for optimized + // lookup for our specific group however. + + var qsubs []*subscription + for _, qs := range r.qsubs { + if len(qs) != 0 && bytes.Equal(group, qs[0].queue) { + qsubs = qs + break + } + } + + // If no match return. + if qsubs == nil { + c.Debugf("Redelivery failed, no queue subscribers for message on group '%q'", group) + return + } + + // We have a matched group of queue subscribers. + // We prefer a local subscriber since that was the original target. + + // Spin prand if needed. + if c.in.prand == nil { + c.in.prand = rand.New(rand.NewSource(time.Now().UnixNano())) + } + + // Hold onto a remote if we come across it to utilize in case no locals exist. + var rsub *subscription + + startIndex := c.in.prand.Intn(len(qsubs)) + for i := 0; i < len(qsubs); i++ { + index := (startIndex + i) % len(qsubs) + sub := qsubs[index] + if sub == nil { + continue + } + if rsub == nil && bytes.HasPrefix(sub.sid, []byte(QRSID)) { + rsub = sub + continue + } + mh := c.msgHeader(msgh[:], sub) + if c.deliverMsg(sub, mh, msg) { + c.Debugf("Redelivery succeeded for message on group '%q'", group) + return + } + } + // If we are here we failed to find a local, see if we snapshotted a + // remote sub, and if so deliver to that. + if rsub != nil { + mh := c.msgHeader(msgh[:], rsub) + if c.deliverMsg(rsub, mh, msg) { + c.Debugf("Re-routing message on group '%q' to remote server", group) + return + } + } + c.Debugf("Redelivery failed, no queue subscribers for message on group '%q'", group) +} + +// processRoutedMsg processes messages inbound from a route. +func (c *client) processRoutedMsg(r *SublistResult, msg []byte) { + // Snapshot server. + srv := c.srv + + msgh := c.prepMsgHeader() + si := len(msgh) + + // If we have a queue subscription, deliver direct + // since they are sent direct via L2 semantics over routes. + // If the match is a queue subscription, we will return from + // here regardless if we find a sub. + isq, sub, err := srv.routeSidQueueSubscriber(c.pa.sid) + if isq { + if err != nil { + // We got an invalid QRSID, so stop here + c.Errorf("Unable to deliver routed queue message: %v", err) + return + } + didDeliver := false + if sub != nil { + mh := c.msgHeader(msgh[:si], sub) + didDeliver = c.deliverMsg(sub, mh, msg) + } + if !didDeliver && c.srv != nil { + group := c.srv.lookupRemoteQGroup(string(c.pa.sid)) + c.reRouteQMsg(r, msgh, msg, group) + } + return + } + // Normal pub/sub message here + // Loop over all normal subscriptions that match. + for _, sub := range r.psubs { + // Check if this is a send to a ROUTER, if so we ignore to + // enforce 1-hop semantics. + if sub.client.typ == ROUTER { + continue + } + sub.client.mu.Lock() + if sub.client.nc == nil { + sub.client.mu.Unlock() + continue + } + sub.client.mu.Unlock() + + // Normal delivery + mh := c.msgHeader(msgh[:si], sub) + c.deliverMsg(sub, mh, msg) + } +} + +// Lock should be held entering here. +func (c *client) sendConnect(tlsRequired bool) { + var user, pass string + if userInfo := c.route.url.User; userInfo != nil { + user = userInfo.Username() + pass, _ = userInfo.Password() + } + cinfo := connectInfo{ + Echo: true, + Verbose: false, + Pedantic: false, + User: user, + Pass: pass, + TLS: tlsRequired, + Name: c.srv.info.ID, + } + b, err := json.Marshal(cinfo) + if err != nil { + c.Errorf("Error marshaling CONNECT to route: %v\n", err) + c.closeConnection(ProtocolViolation) + return + } + c.sendProto([]byte(fmt.Sprintf(ConProto, b)), true) +} + +// Process the info message if we are a route. +func (c *client) processRouteInfo(info *Info) { + c.mu.Lock() + // Connection can be closed at any time (by auth timeout, etc). + // Does not make sense to continue here if connection is gone. + if c.route == nil || c.nc == nil { + c.mu.Unlock() + return + } + + s := c.srv + remoteID := c.route.remoteID + + // We receive an INFO from a server that informs us about another server, + // so the info.ID in the INFO protocol does not match the ID of this route. + if remoteID != "" && remoteID != info.ID { + c.mu.Unlock() + + // Process this implicit route. We will check that it is not an explicit + // route and/or that it has not been connected already. + s.processImplicitRoute(info) + return + } + + // Need to set this for the detection of the route to self to work + // in closeConnection(). + c.route.remoteID = info.ID + + // Detect route to self. + if c.route.remoteID == s.info.ID { + c.mu.Unlock() + c.closeConnection(DuplicateRoute) + return + } + + // Copy over important information. + c.route.authRequired = info.AuthRequired + c.route.tlsRequired = info.TLSRequired + + // If we do not know this route's URL, construct one on the fly + // from the information provided. + if c.route.url == nil { + // Add in the URL from host and port + hp := net.JoinHostPort(info.Host, strconv.Itoa(info.Port)) + url, err := url.Parse(fmt.Sprintf("nats-route://%s/", hp)) + if err != nil { + c.Errorf("Error parsing URL from INFO: %v\n", err) + c.mu.Unlock() + c.closeConnection(ParseError) + return + } + c.route.url = url + } + + // Check to see if we have this remote already registered. + // This can happen when both servers have routes to each other. + c.mu.Unlock() + + if added, sendInfo := s.addRoute(c, info); added { + c.Debugf("Registering remote route %q", info.ID) + // Send our local subscriptions to this route. + s.sendLocalSubsToRoute(c) + // sendInfo will be false if the route that we just accepted + // is the only route there is. + if sendInfo { + // The incoming INFO from the route will have IP set + // if it has Cluster.Advertise. In that case, use that + // otherwise contruct it from the remote TCP address. + if info.IP == "" { + // Need to get the remote IP address. + c.mu.Lock() + switch conn := c.nc.(type) { + case *net.TCPConn, *tls.Conn: + addr := conn.RemoteAddr().(*net.TCPAddr) + info.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(addr.IP.String(), + strconv.Itoa(info.Port))) + default: + info.IP = c.route.url.String() + } + c.mu.Unlock() + } + // Now let the known servers know about this new route + s.forwardNewRouteInfoToKnownServers(info) + } + // Unless disabled, possibly update the server's INFO protocol + // and send to clients that know how to handle async INFOs. + if !s.getOpts().Cluster.NoAdvertise { + s.addClientConnectURLsAndSendINFOToClients(info.ClientConnectURLs) + } + } else { + c.Debugf("Detected duplicate remote route %q", info.ID) + c.closeConnection(DuplicateRoute) + } +} + +// sendAsyncInfoToClients sends an INFO protocol to all +// connected clients that accept async INFO updates. +// The server lock is held on entry. +func (s *Server) sendAsyncInfoToClients() { + // If there are no clients supporting async INFO protocols, we are done. + // Also don't send if we are shutting down... + if s.cproto == 0 || s.shutdown { + return + } + + for _, c := range s.clients { + c.mu.Lock() + // Here, we are going to send only to the clients that are fully + // registered (server has received CONNECT and first PING). For + // clients that are not at this stage, this will happen in the + // processing of the first PING (see client.processPing) + if c.opts.Protocol >= ClientProtoInfo && c.flags.isSet(firstPongSent) { + // sendInfo takes care of checking if the connection is still + // valid or not, so don't duplicate tests here. + c.sendInfo(c.generateClientInfoJSON(s.copyInfo())) + } + c.mu.Unlock() + } +} + +// This will process implicit route information received from another server. +// We will check to see if we have configured or are already connected, +// and if so we will ignore. Otherwise we will attempt to connect. +func (s *Server) processImplicitRoute(info *Info) { + remoteID := info.ID + + s.mu.Lock() + defer s.mu.Unlock() + + // Don't connect to ourself + if remoteID == s.info.ID { + return + } + // Check if this route already exists + if _, exists := s.remotes[remoteID]; exists { + return + } + // Check if we have this route as a configured route + if s.hasThisRouteConfigured(info) { + return + } + + // Initiate the connection, using info.IP instead of info.URL here... + r, err := url.Parse(info.IP) + if err != nil { + s.Errorf("Error parsing URL from INFO: %v\n", err) + return + } + + // Snapshot server options. + opts := s.getOpts() + + if info.AuthRequired { + r.User = url.UserPassword(opts.Cluster.Username, opts.Cluster.Password) + } + s.startGoRoutine(func() { s.connectToRoute(r, false) }) +} + +// hasThisRouteConfigured returns true if info.Host:info.Port is present +// in the server's opts.Routes, false otherwise. +// Server lock is assumed to be held by caller. +func (s *Server) hasThisRouteConfigured(info *Info) bool { + urlToCheckExplicit := strings.ToLower(net.JoinHostPort(info.Host, strconv.Itoa(info.Port))) + for _, ri := range s.getOpts().Routes { + if strings.ToLower(ri.Host) == urlToCheckExplicit { + return true + } + } + return false +} + +// forwardNewRouteInfoToKnownServers sends the INFO protocol of the new route +// to all routes known by this server. In turn, each server will contact this +// new route. +func (s *Server) forwardNewRouteInfoToKnownServers(info *Info) { + s.mu.Lock() + defer s.mu.Unlock() + + b, _ := json.Marshal(info) + infoJSON := []byte(fmt.Sprintf(InfoProto, b)) + + for _, r := range s.routes { + r.mu.Lock() + if r.route.remoteID != info.ID { + r.sendInfo(infoJSON) + } + r.mu.Unlock() + } +} + +// canImport is whether or not we will send a SUB for interest to the other side. +// This is for ROUTER connections only. +// Lock is held on entry. +func (c *client) canImport(subject []byte) bool { + // Use pubAllowed() since this checks Publish permissions which + // is what Import maps to. + return c.pubAllowed(subject) +} + +// canExport is whether or not we will accept a SUB from the remote for a given subject. +// This is for ROUTER connections only. +// Lock is held on entry +func (c *client) canExport(subject []byte) bool { + // Use canSubscribe() since this checks Subscribe permissions which + // is what Export maps to. + return c.canSubscribe(subject) +} + +// Initialize or reset cluster's permissions. +// This is for ROUTER connections only. +// Client lock is held on entry +func (c *client) setRoutePermissions(perms *RoutePermissions) { + // Reset if some were set + if perms == nil { + c.perms = nil + return + } + // Convert route permissions to user permissions. + // The Import permission is mapped to Publish + // and Export permission is mapped to Subscribe. + // For meaning of Import/Export, see canImport and canExport. + p := &Permissions{ + Publish: perms.Import, + Subscribe: perms.Export, + } + c.setPermissions(p) +} + +// This will send local subscription state to a new route connection. +// FIXME(dlc) - This could be a DOS or perf issue with many clients +// and large subscription space. Plus buffering in place not a good idea. +func (s *Server) sendLocalSubsToRoute(route *client) { + var raw [4096]*subscription + subs := raw[:0] + + s.sl.localSubs(&subs) + + route.mu.Lock() + for _, sub := range subs { + // Send SUB interest only if subject has a match in import permissions + if !route.canImport(sub.subject) { + continue + } + proto := fmt.Sprintf(subProto, sub.subject, sub.queue, routeSid(sub)) + route.queueOutbound([]byte(proto)) + if route.out.pb > int64(route.out.sz*2) { + route.flushSignal() + } + } + route.flushSignal() + route.mu.Unlock() + + route.Debugf("Sent local subscriptions to route") +} + +func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { + // Snapshot server options. + opts := s.getOpts() + + didSolicit := rURL != nil + r := &route{didSolicit: didSolicit} + for _, route := range opts.Routes { + if rURL != nil && (strings.ToLower(rURL.Host) == strings.ToLower(route.Host)) { + r.routeType = Explicit + } + } + + c := &client{srv: s, nc: conn, opts: clientOpts{}, typ: ROUTER, route: r} + + // Grab server variables + s.mu.Lock() + infoJSON := s.routeInfoJSON + authRequired := s.routeInfo.AuthRequired + tlsRequired := s.routeInfo.TLSRequired + s.mu.Unlock() + + // Grab lock + c.mu.Lock() + + // Initialize + c.initClient() + + if didSolicit { + // Do this before the TLS code, otherwise, in case of failure + // and if route is explicit, it would try to reconnect to 'nil'... + r.url = rURL + + // Set permissions associated with the route user (if applicable). + // No lock needed since we are already under client lock. + c.setRoutePermissions(opts.Cluster.Permissions) + } + + // Check for TLS + if tlsRequired { + // Copy off the config to add in ServerName if we + tlsConfig := util.CloneTLSConfig(opts.Cluster.TLSConfig) + + // If we solicited, we will act like the client, otherwise the server. + if didSolicit { + c.Debugf("Starting TLS route client handshake") + // Specify the ServerName we are expecting. + host, _, _ := net.SplitHostPort(rURL.Host) + tlsConfig.ServerName = host + c.nc = tls.Client(c.nc, tlsConfig) + } else { + c.Debugf("Starting TLS route server handshake") + c.nc = tls.Server(c.nc, tlsConfig) + } + + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(opts.Cluster.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS route handshake error: %v", err) + c.sendErr("Secure Connection - TLS Required") + c.closeConnection(TLSHandshakeError) + return nil + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Verify that the connection did not go away while we released the lock. + if c.nc == nil { + c.mu.Unlock() + return nil + } + } + + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // For routes, the "client" is added to s.routes only when processing + // the INFO protocol, that is much later. + // In the meantime, if the server shutsdown, there would be no reference + // to the client (connection) to be closed, leaving this readLoop + // uinterrupted, causing the Shutdown() to wait indefinitively. + // We need to store the client in a special map, under a special lock. + s.grMu.Lock() + running := s.grRunning + if running { + s.grTmpClients[c.cid] = c + } + s.grMu.Unlock() + if !running { + c.mu.Unlock() + c.setRouteNoReconnectOnClose() + c.closeConnection(ServerShutdown) + return nil + } + + // Check for Auth required state for incoming connections. + // Make sure to do this before spinning up readLoop. + if authRequired && !didSolicit { + ttl := secondsToDuration(opts.Cluster.AuthTimeout) + c.setAuthTimer(ttl) + } + + // Spin up the read loop. + s.startGoRoutine(func() { c.readLoop() }) + + // Spin up the write loop. + s.startGoRoutine(c.writeLoop) + + if tlsRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + // Queue Connect proto if we solicited the connection. + if didSolicit { + c.Debugf("Route connect msg sent") + c.sendConnect(tlsRequired) + } + + // Send our info to the other side. + c.sendInfo(infoJSON) + + c.mu.Unlock() + + c.Noticef("Route connection created") + return c +} + +const ( + _CRLF_ = "\r\n" + _EMPTY_ = "" +) + +const ( + subProto = "SUB %s %s %s" + _CRLF_ + unsubProto = "UNSUB %s" + _CRLF_ +) + +// FIXME(dlc) - Make these reserved and reject if they come in as a sid +// from a client connection. +// Route constants +const ( + RSID = "RSID" + QRSID = "QRSID" + + QRSID_LEN = len(QRSID) +) + +// Parse the given rsid. If the protocol does not start with QRSID, +// returns false and no subscription nor error. +// If it does start with QRSID, returns true and possibly a subscription +// or an error if the QRSID protocol is malformed. +func (s *Server) routeSidQueueSubscriber(rsid []byte) (bool, *subscription, error) { + if !bytes.HasPrefix(rsid, []byte(QRSID)) { + return false, nil, nil + } + cid, sid, err := parseRouteQueueSid(rsid) + if err != nil { + return true, nil, err + } + + s.mu.Lock() + client := s.clients[cid] + s.mu.Unlock() + + if client == nil { + return true, nil, nil + } + + client.mu.Lock() + sub, ok := client.subs[string(sid)] + client.mu.Unlock() + if ok { + return true, sub, nil + } + return true, nil, nil +} + +// Creates a routable sid that can be used +// to reach remote subscriptions. +func routeSid(sub *subscription) string { + var qi string + if len(sub.queue) > 0 { + qi = "Q" + } + return fmt.Sprintf("%s%s:%d:%s", qi, RSID, sub.client.cid, sub.sid) +} + +// Parse the given `rsid` knowing that it starts with `QRSID`. +// Returns the cid and sid or an error not a valid QRSID. +func parseRouteQueueSid(rsid []byte) (uint64, []byte, error) { + var ( + cid uint64 + sid []byte + cidFound bool + sidFound bool + ) + // A valid QRSID needs to be at least QRSID:x:y + // First character here should be `:` + if len(rsid) >= QRSID_LEN+4 { + if rsid[QRSID_LEN] == ':' { + for i, count := QRSID_LEN+1, len(rsid); i < count; i++ { + switch rsid[i] { + case ':': + cid = uint64(parseInt64(rsid[QRSID_LEN+1 : i])) + cidFound = true + sid = rsid[i+1:] + } + } + if cidFound { + // We can't assume the content of sid, so as long + // as it is not len 0, we have to say it is a valid one. + if len(rsid) > 0 { + sidFound = true + } + } + } + } + if cidFound && sidFound { + return cid, sid, nil + } + return 0, nil, fmt.Errorf("invalid QRSID: %s", rsid) +} + +func (s *Server) addRoute(c *client, info *Info) (bool, bool) { + id := c.route.remoteID + sendInfo := false + + s.mu.Lock() + if !s.running { + s.mu.Unlock() + return false, false + } + remote, exists := s.remotes[id] + if !exists { + s.routes[c.cid] = c + s.remotes[id] = c + c.mu.Lock() + c.route.connectURLs = info.ClientConnectURLs + cid := c.cid + c.mu.Unlock() + + // Remove from the temporary map + s.grMu.Lock() + delete(s.grTmpClients, cid) + s.grMu.Unlock() + + // we don't need to send if the only route is the one we just accepted. + sendInfo = len(s.routes) > 1 + } + s.mu.Unlock() + + if exists { + var r *route + + c.mu.Lock() + // upgrade to solicited? + if c.route.didSolicit { + // Make a copy + rs := *c.route + r = &rs + } + c.mu.Unlock() + + remote.mu.Lock() + // r will be not nil if c.route.didSolicit was true + if r != nil { + // If we upgrade to solicited, we still want to keep the remote's + // connectURLs. So transfer those. + r.connectURLs = remote.route.connectURLs + remote.route = r + } + // This is to mitigate the issue where both sides add the route + // on the opposite connection, and therefore end-up with both + // connections being dropped. + remote.route.retry = true + remote.mu.Unlock() + } + + return !exists, sendInfo +} + +func (s *Server) broadcastInterestToRoutes(sub *subscription, proto string) { + var arg []byte + if atomic.LoadInt32(&s.logging.trace) == 1 { + arg = []byte(proto[:len(proto)-LEN_CR_LF]) + } + protoAsBytes := []byte(proto) + s.mu.Lock() + for _, route := range s.routes { + // FIXME(dlc) - Make same logic as deliverMsg + route.mu.Lock() + // The permission of this cluster applies to all routes, and each + // route will have the same `perms`, so check with the first route + // and send SUB interest only if subject has a match in import permissions. + // If there is no match, we stop here. + if !route.canImport(sub.subject) { + route.mu.Unlock() + break + } + route.sendProto(protoAsBytes, true) + route.mu.Unlock() + route.traceOutOp("", arg) + } + s.mu.Unlock() +} + +// broadcastSubscribe will forward a client subscription +// to all active routes. +func (s *Server) broadcastSubscribe(sub *subscription) { + if s.numRoutes() == 0 { + return + } + rsid := routeSid(sub) + proto := fmt.Sprintf(subProto, sub.subject, sub.queue, rsid) + s.broadcastInterestToRoutes(sub, proto) +} + +// broadcastUnSubscribe will forward a client unsubscribe +// action to all active routes. +func (s *Server) broadcastUnSubscribe(sub *subscription) { + if s.numRoutes() == 0 { + return + } + sub.client.mu.Lock() + // Max has no meaning on the other side of a route, so do not send. + hasMax := sub.max > 0 && sub.nm < sub.max + sub.client.mu.Unlock() + if hasMax { + return + } + rsid := routeSid(sub) + proto := fmt.Sprintf(unsubProto, rsid) + s.broadcastInterestToRoutes(sub, proto) +} + +func (s *Server) routeAcceptLoop(ch chan struct{}) { + defer func() { + if ch != nil { + close(ch) + } + }() + + // Snapshot server options. + opts := s.getOpts() + + // Snapshot server options. + port := opts.Cluster.Port + + if port == -1 { + port = 0 + } + + hp := net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(port)) + l, e := net.Listen("tcp", hp) + if e != nil { + s.Fatalf("Error listening on router port: %d - %v", opts.Cluster.Port, e) + return + } + s.Noticef("Listening for route connections on %s", + net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + + s.mu.Lock() + // Check for TLSConfig + tlsReq := opts.Cluster.TLSConfig != nil + info := Info{ + ID: s.info.ID, + Version: s.info.Version, + AuthRequired: false, + TLSRequired: tlsReq, + TLSVerify: tlsReq, + MaxPayload: s.info.MaxPayload, + } + // Set this if only if advertise is not disabled + if !opts.Cluster.NoAdvertise { + info.ClientConnectURLs = s.clientConnectURLs + } + // If we have selected a random port... + if port == 0 { + // Write resolved port back to options. + opts.Cluster.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. + s.clusterActualPort = opts.Cluster.Port + // Check for Auth items + if opts.Cluster.Username != "" { + info.AuthRequired = true + } + s.routeInfo = info + // Possibly override Host/Port and set IP based on Cluster.Advertise + if err := s.setRouteInfoHostPortAndIP(); err != nil { + s.Fatalf("Error setting route INFO with Cluster.Advertise value of %s, err=%v", s.opts.Cluster.Advertise, err) + l.Close() + s.mu.Unlock() + return + } + // Setup state that can enable shutdown + s.routeListener = l + s.mu.Unlock() + + // Let them know we are up + close(ch) + ch = nil + + tmpDelay := ACCEPT_MIN_SLEEP + + for s.isRunning() { + conn, err := l.Accept() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + s.Debugf("Temporary Route Accept Errorf(%v), sleeping %dms", + ne, tmpDelay/time.Millisecond) + time.Sleep(tmpDelay) + tmpDelay *= 2 + if tmpDelay > ACCEPT_MAX_SLEEP { + tmpDelay = ACCEPT_MAX_SLEEP + } + } else if s.isRunning() { + s.Noticef("Accept error: %v", err) + } + continue + } + tmpDelay = ACCEPT_MIN_SLEEP + s.startGoRoutine(func() { + s.createRoute(conn, nil) + s.grWG.Done() + }) + } + s.Debugf("Router accept loop exiting..") + s.done <- true +} + +// Similar to setInfoHostPortAndGenerateJSON, but for routeInfo. +func (s *Server) setRouteInfoHostPortAndIP() error { + if s.opts.Cluster.Advertise != "" { + advHost, advPort, err := parseHostPort(s.opts.Cluster.Advertise, s.opts.Cluster.Port) + if err != nil { + return err + } + s.routeInfo.Host = advHost + s.routeInfo.Port = advPort + s.routeInfo.IP = fmt.Sprintf("nats-route://%s/", net.JoinHostPort(advHost, strconv.Itoa(advPort))) + } else { + s.routeInfo.Host = s.opts.Cluster.Host + s.routeInfo.Port = s.opts.Cluster.Port + s.routeInfo.IP = "" + } + // (re)generate the routeInfoJSON byte array + s.generateRouteInfoJSON() + return nil +} + +// StartRouting will start the accept loop on the cluster host:port +// and will actively try to connect to listed routes. +func (s *Server) StartRouting(clientListenReady chan struct{}) { + defer s.grWG.Done() + + // Wait for the client listen port to be opened, and + // the possible ephemeral port to be selected. + <-clientListenReady + + // Spin up the accept loop + ch := make(chan struct{}) + go s.routeAcceptLoop(ch) + <-ch + + // Solicit Routes if needed. + s.solicitRoutes(s.getOpts().Routes) +} + +func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType) { + tryForEver := rtype == Explicit + // If A connects to B, and B to A (regardless if explicit or + // implicit - due to auto-discovery), and if each server first + // registers the route on the opposite TCP connection, the + // two connections will end-up being closed. + // Add some random delay to reduce risk of repeated failures. + delay := time.Duration(rand.Intn(100)) * time.Millisecond + if tryForEver { + delay += DEFAULT_ROUTE_RECONNECT + } + time.Sleep(delay) + s.connectToRoute(rURL, tryForEver) +} + +// Checks to make sure the route is still valid. +func (s *Server) routeStillValid(rURL *url.URL) bool { + for _, ri := range s.getOpts().Routes { + if ri == rURL { + return true + } + } + return false +} + +func (s *Server) connectToRoute(rURL *url.URL, tryForEver bool) { + // Snapshot server options. + opts := s.getOpts() + + defer s.grWG.Done() + + attempts := 0 + for s.isRunning() && rURL != nil { + if tryForEver && !s.routeStillValid(rURL) { + return + } + s.Debugf("Trying to connect to route on %s", rURL.Host) + conn, err := net.DialTimeout("tcp", rURL.Host, DEFAULT_ROUTE_DIAL) + if err != nil { + s.Errorf("Error trying to connect to route: %v", err) + if !tryForEver { + if opts.Cluster.ConnectRetries <= 0 { + return + } + attempts++ + if attempts > opts.Cluster.ConnectRetries { + return + } + } + select { + case <-s.quitCh: + return + case <-time.After(DEFAULT_ROUTE_CONNECT): + continue + } + } + + if tryForEver && !s.routeStillValid(rURL) { + conn.Close() + return + } + + // We have a route connection here. + // Go ahead and create it and exit this func. + s.createRoute(conn, rURL) + return + } +} + +func (c *client) isSolicitedRoute() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.typ == ROUTER && c.route != nil && c.route.didSolicit +} + +func (s *Server) solicitRoutes(routes []*url.URL) { + for _, r := range routes { + route := r + s.startGoRoutine(func() { s.connectToRoute(route, true) }) + } +} + +func (s *Server) numRoutes() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.routes) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/routes_test.go b/vendor/github.com/nats-io/gnatsd/server/routes_test.go new file mode 100644 index 00000000..de1226eb --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/routes_test.go @@ -0,0 +1,1000 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "net" + "net/url" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +func checkNumRoutes(t *testing.T, s *Server, expected int) { + t.Helper() + checkFor(t, 5*time.Second, 15*time.Millisecond, func() error { + if nr := s.NumRoutes(); nr != expected { + return fmt.Errorf("Expected %v routes, got %v", expected, nr) + } + return nil + }) +} + +func TestRouteConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/cluster.conf") + if err != nil { + t.Fatalf("Received an error reading route config file: %v\n", err) + } + + golden := &Options{ + ConfigFile: "./configs/cluster.conf", + Host: "127.0.0.1", + Port: 4242, + Username: "derek", + Password: "porkchop", + AuthTimeout: 1.0, + Cluster: ClusterOpts{ + Host: "127.0.0.1", + Port: 4244, + Username: "route_user", + Password: "top_secret", + AuthTimeout: 1.0, + NoAdvertise: true, + ConnectRetries: 2, + }, + PidFile: "/tmp/nats_cluster_test.pid", + } + + // Setup URLs + r1, _ := url.Parse("nats-route://foo:bar@127.0.0.1:4245") + r2, _ := url.Parse("nats-route://foo:bar@127.0.0.1:4246") + + golden.Routes = []*url.URL{r1, r2} + + if !reflect.DeepEqual(golden, opts) { + t.Fatalf("Options are incorrect.\nexpected: %+v\ngot: %+v", + golden, opts) + } +} + +func TestClusterAdvertise(t *testing.T) { + lst, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Error starting listener: %v", err) + } + ch := make(chan error) + go func() { + c, err := lst.Accept() + if err != nil { + ch <- err + return + } + c.Close() + ch <- nil + }() + + optsA, _ := ProcessConfigFile("./configs/seed.conf") + optsA.NoSigs, optsA.NoLog = true, true + srvA := RunServer(optsA) + defer srvA.Shutdown() + + srvARouteURL := fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, srvA.ClusterAddr().Port) + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(srvARouteURL) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + // Wait for these 2 to connect to each other + checkClusterFormed(t, srvA, srvB) + + // Now start server C that connects to A. A should ask B to connect to C, + // based on C's URL. But since C configures a Cluster.Advertise, it will connect + // to our listener. + optsC := nextServerOpts(optsB) + optsC.Cluster.Advertise = lst.Addr().String() + optsC.ClientAdvertise = "me:1" + optsC.Routes = RoutesFromStr(srvARouteURL) + + srvC := RunServer(optsC) + defer srvC.Shutdown() + + select { + case e := <-ch: + if e != nil { + t.Fatalf("Error: %v", e) + } + case <-time.After(2 * time.Second): + t.Fatalf("Test timed out") + } +} + +func TestClusterAdvertiseErrorOnStartup(t *testing.T) { + opts := DefaultOptions() + // Set invalid address + opts.Cluster.Advertise = "addr:::123" + s := New(opts) + defer s.Shutdown() + dl := &DummyLogger{} + s.SetLogger(dl, false, false) + + // Start will keep running, so start in a go-routine. + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + s.Start() + wg.Done() + }() + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + dl.Lock() + msg := dl.msg + dl.Unlock() + if strings.Contains(msg, "Cluster.Advertise") { + return nil + } + return fmt.Errorf("Did not get expected error, got %v", msg) + }) + s.Shutdown() + wg.Wait() +} + +func TestClientAdvertise(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/seed.conf") + optsA.NoSigs, optsA.NoLog = true, true + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, optsA.Cluster.Port)) + optsB.ClientAdvertise = "me:1" + srvB := RunServer(optsB) + defer srvB.Shutdown() + + checkClusterFormed(t, srvA, srvB) + + nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", optsA.Host, optsA.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + checkFor(t, time.Second, 15*time.Millisecond, func() error { + ds := nc.DiscoveredServers() + if len(ds) == 1 { + if ds[0] == "nats://me:1" { + return nil + } + } + return fmt.Errorf("Did not get expected discovered servers: %v", nc.DiscoveredServers()) + }) +} + +func TestServerRoutesWithClients(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/srv_a.conf") + optsB, _ := ProcessConfigFile("./configs/srv_b.conf") + + optsA.NoSigs, optsA.NoLog = true, true + optsB.NoSigs, optsB.NoLog = true, true + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, optsB.Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + ch := make(chan bool) + sub, _ := nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + nc1.QueueSubscribe("foo", "bar", func(m *nats.Msg) {}) + nc1.Publish("foo", []byte("Hello")) + // Wait for message + <-ch + sub.Unsubscribe() + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + // Wait for route to form. + checkClusterFormed(t, srvA, srvB) + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + nc2.Publish("foo", []byte("Hello")) + nc2.Flush() +} + +func TestServerRoutesWithAuthAndBCrypt(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/srv_a_bcrypt.conf") + optsB, _ := ProcessConfigFile("./configs/srv_b_bcrypt.conf") + + optsA.NoSigs, optsA.NoLog = true, true + optsB.NoSigs, optsB.NoLog = true, true + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + // Wait for route to form. + checkClusterFormed(t, srvA, srvB) + + urlA := fmt.Sprintf("nats://%s:%s@%s:%d/", optsA.Username, optsA.Password, optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%s@%s:%d/", optsB.Username, optsB.Password, optsB.Host, optsB.Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + // Test that we are connected. + ch := make(chan bool) + sub, err := nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + if err != nil { + t.Fatalf("Error creating subscription: %v\n", err) + } + nc1.Flush() + defer sub.Unsubscribe() + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + nc2.Publish("foo", []byte("Hello")) + nc2.Flush() + + // Wait for message + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message across route") + } +} + +// Helper function to check that a cluster is formed +func checkClusterFormed(t *testing.T, servers ...*Server) { + t.Helper() + expectedNumRoutes := len(servers) - 1 + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { + for _, s := range servers { + if numRoutes := s.NumRoutes(); numRoutes != expectedNumRoutes { + return fmt.Errorf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) + } + } + return nil + }) +} + +// Helper function to generate next opts to make sure no port conflicts etc. +func nextServerOpts(opts *Options) *Options { + nopts := *opts + nopts.Port = -1 + nopts.Cluster.Port = -1 + nopts.HTTPPort = -1 + return &nopts +} + +func TestSeedSolicitWorks(t *testing.T) { + optsSeed, _ := ProcessConfigFile("./configs/seed.conf") + + optsSeed.NoSigs, optsSeed.NoLog = true, true + + srvSeed := RunServer(optsSeed) + defer srvSeed.Shutdown() + + optsA := nextServerOpts(optsSeed) + optsA.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, + srvSeed.ClusterAddr().Port)) + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, srvA.ClusterAddr().Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + // Test that we are connected. + ch := make(chan bool) + nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + nc1.Flush() + + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, + srvSeed.ClusterAddr().Port)) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, srvB.ClusterAddr().Port) + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + + checkClusterFormed(t, srvSeed, srvA, srvB) + + nc2.Publish("foo", []byte("Hello")) + + // Wait for message + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message across route") + } +} + +func TestTLSSeedSolicitWorks(t *testing.T) { + optsSeed, _ := ProcessConfigFile("./configs/seed_tls.conf") + + optsSeed.NoSigs, optsSeed.NoLog = true, true + + srvSeed := RunServer(optsSeed) + defer srvSeed.Shutdown() + + seedRouteUrl := fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, + srvSeed.ClusterAddr().Port) + optsA := nextServerOpts(optsSeed) + optsA.Routes = RoutesFromStr(seedRouteUrl) + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, srvA.Addr().(*net.TCPAddr).Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + // Test that we are connected. + ch := make(chan bool) + nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + nc1.Flush() + + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(seedRouteUrl) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, srvB.Addr().(*net.TCPAddr).Port) + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + + checkClusterFormed(t, srvSeed, srvA, srvB) + + nc2.Publish("foo", []byte("Hello")) + + // Wait for message + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message across route") + } +} + +func TestChainedSolicitWorks(t *testing.T) { + optsSeed, _ := ProcessConfigFile("./configs/seed.conf") + + optsSeed.NoSigs, optsSeed.NoLog = true, true + + srvSeed := RunServer(optsSeed) + defer srvSeed.Shutdown() + + seedRouteUrl := fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, + srvSeed.ClusterAddr().Port) + optsA := nextServerOpts(optsSeed) + optsA.Routes = RoutesFromStr(seedRouteUrl) + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + urlSeed := fmt.Sprintf("nats://%s:%d/", optsSeed.Host, srvA.Addr().(*net.TCPAddr).Port) + + nc1, err := nats.Connect(urlSeed) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + // Test that we are connected. + ch := make(chan bool) + nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + nc1.Flush() + + optsB := nextServerOpts(optsA) + // Server B connects to A + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, + srvA.ClusterAddr().Port)) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, srvB.Addr().(*net.TCPAddr).Port) + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + + checkClusterFormed(t, srvSeed, srvA, srvB) + + nc2.Publish("foo", []byte("Hello")) + + // Wait for message + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message across route") + } +} + +// Helper function to check that a server (or list of servers) have the +// expected number of subscriptions. +func checkExpectedSubs(t *testing.T, expected int, servers ...*Server) { + t.Helper() + checkFor(t, 10*time.Second, 10*time.Millisecond, func() error { + for _, s := range servers { + if numSubs := int(s.NumSubscriptions()); numSubs != expected { + return fmt.Errorf("Expected %d subscriptions for server %q, got %d", expected, s.ID(), numSubs) + } + } + return nil + }) +} + +func TestTLSChainedSolicitWorks(t *testing.T) { + optsSeed, _ := ProcessConfigFile("./configs/seed_tls.conf") + + optsSeed.NoSigs, optsSeed.NoLog = true, true + + srvSeed := RunServer(optsSeed) + defer srvSeed.Shutdown() + + urlSeedRoute := fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, + srvSeed.ClusterAddr().Port) + optsA := nextServerOpts(optsSeed) + optsA.Routes = RoutesFromStr(urlSeedRoute) + + srvA := RunServer(optsA) + defer srvA.Shutdown() + + urlSeed := fmt.Sprintf("nats://%s:%d/", optsSeed.Host, srvSeed.Addr().(*net.TCPAddr).Port) + + nc1, err := nats.Connect(urlSeed) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc1.Close() + + // Test that we are connected. + ch := make(chan bool) + nc1.Subscribe("foo", func(m *nats.Msg) { ch <- true }) + nc1.Flush() + + optsB := nextServerOpts(optsA) + // Server B connects to A + optsB.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, + srvA.ClusterAddr().Port)) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + checkClusterFormed(t, srvSeed, srvA, srvB) + checkExpectedSubs(t, 1, srvA, srvB) + + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, srvB.Addr().(*net.TCPAddr).Port) + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc2.Close() + + nc2.Publish("foo", []byte("Hello")) + + // Wait for message + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for message across route") + } +} + +func TestRouteTLSHandshakeError(t *testing.T) { + optsSeed, _ := ProcessConfigFile("./configs/seed_tls.conf") + optsSeed.NoLog = true + optsSeed.NoSigs = true + srvSeed := RunServer(optsSeed) + defer srvSeed.Shutdown() + + opts := DefaultOptions() + opts.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, optsSeed.Cluster.Port)) + + srv := RunServer(opts) + defer srv.Shutdown() + + time.Sleep(500 * time.Millisecond) + + checkNumRoutes(t, srv, 0) +} + +func TestBlockedShutdownOnRouteAcceptLoopFailure(t *testing.T) { + opts := DefaultOptions() + opts.Cluster.Host = "x.x.x.x" + opts.Cluster.Port = 7222 + + s := New(opts) + go s.Start() + // Wait a second + time.Sleep(time.Second) + ch := make(chan bool) + go func() { + s.Shutdown() + ch <- true + }() + + timeout := time.NewTimer(5 * time.Second) + select { + case <-ch: + return + case <-timeout.C: + t.Fatal("Shutdown did not complete") + } +} + +func TestRouteUseIPv6(t *testing.T) { + opts := DefaultOptions() + opts.Cluster.Host = "::" + opts.Cluster.Port = 6222 + + // I believe that there is no IPv6 support on Travis... + // Regardless, cannot have this test fail simply because IPv6 is disabled + // on the host. + hp := net.JoinHostPort(opts.Cluster.Host, strconv.Itoa(opts.Cluster.Port)) + _, err := net.ResolveTCPAddr("tcp", hp) + if err != nil { + t.Skipf("Skipping this test since there is no IPv6 support on this host: %v", err) + } + + s := RunServer(opts) + defer s.Shutdown() + + routeUp := false + timeout := time.Now().Add(5 * time.Second) + for time.Now().Before(timeout) && !routeUp { + // We know that the server is local and listening to + // all IPv6 interfaces. Try connect using IPv6 loopback. + if conn, err := net.Dial("tcp", "[::1]:6222"); err != nil { + // Travis seem to have the server actually listening to 0.0.0.0, + // so try with 127.0.0.1 + if conn, err := net.Dial("tcp", "127.0.0.1:6222"); err != nil { + time.Sleep(time.Second) + continue + } else { + conn.Close() + } + } else { + conn.Close() + } + routeUp = true + } + if !routeUp { + t.Fatal("Server failed to start route accept loop") + } +} + +func TestClientConnectToRoutePort(t *testing.T) { + opts := DefaultOptions() + + // Since client will first connect to the route listen port, set the + // cluster's Host to 127.0.0.1 so it works on Windows too, since on + // Windows, a client can't use 0.0.0.0 in a connect. + opts.Cluster.Host = "127.0.0.1" + s := RunServer(opts) + defer s.Shutdown() + + url := fmt.Sprintf("nats://%s:%d", opts.Cluster.Host, s.ClusterAddr().Port) + clientURL := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + // When connecting to the ROUTE port, the client library will receive the + // CLIENT port in the INFO protocol. This URL is added to the client's pool + // and will be tried after the initial connect failure. So all those + // nats.Connect() should succeed. + // The only reason for a failure would be if there are too many FDs in time-wait + // which would delay the creation of TCP connection. So keep the total of + // attempts rather small. + total := 10 + for i := 0; i < total; i++ { + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Unexepected error on connect: %v", err) + } + defer nc.Close() + if nc.ConnectedUrl() != clientURL { + t.Fatalf("Expected client to be connected to %v, got %v", clientURL, nc.ConnectedUrl()) + } + } + + s.Shutdown() + // Try again with NoAdvertise and this time, the client should fail to connect. + opts.Cluster.NoAdvertise = true + s = RunServer(opts) + defer s.Shutdown() + + for i := 0; i < total; i++ { + nc, err := nats.Connect(url) + if err == nil { + nc.Close() + t.Fatal("Expected error on connect, got none") + } + } +} + +type checkDuplicateRouteLogger struct { + sync.Mutex + gotDuplicate bool +} + +func (l *checkDuplicateRouteLogger) Noticef(format string, v ...interface{}) {} +func (l *checkDuplicateRouteLogger) Errorf(format string, v ...interface{}) {} +func (l *checkDuplicateRouteLogger) Fatalf(format string, v ...interface{}) {} +func (l *checkDuplicateRouteLogger) Tracef(format string, v ...interface{}) {} +func (l *checkDuplicateRouteLogger) Debugf(format string, v ...interface{}) { + l.Lock() + defer l.Unlock() + msg := fmt.Sprintf(format, v...) + if strings.Contains(msg, "duplicate remote route") { + l.gotDuplicate = true + } +} + +func TestRoutesToEachOther(t *testing.T) { + optsA := DefaultOptions() + optsA.Cluster.Port = 7246 + optsA.Routes = RoutesFromStr("nats://127.0.0.1:7247") + + optsB := DefaultOptions() + optsB.Cluster.Port = 7247 + optsB.Routes = RoutesFromStr("nats://127.0.0.1:7246") + + srvALogger := &checkDuplicateRouteLogger{} + srvA := New(optsA) + srvA.SetLogger(srvALogger, true, false) + defer srvA.Shutdown() + + srvBLogger := &checkDuplicateRouteLogger{} + srvB := New(optsB) + srvB.SetLogger(srvBLogger, true, false) + defer srvB.Shutdown() + + go srvA.Start() + go srvB.Start() + + start := time.Now() + checkClusterFormed(t, srvA, srvB) + end := time.Now() + + srvALogger.Lock() + gotIt := srvALogger.gotDuplicate + srvALogger.Unlock() + if !gotIt { + srvBLogger.Lock() + gotIt = srvBLogger.gotDuplicate + srvBLogger.Unlock() + } + if gotIt { + dur := end.Sub(start) + // It should not take too long to have a successful connection + // between the 2 servers. + if dur > 5*time.Second { + t.Logf("Cluster formed, but took a long time: %v", dur) + } + } else { + t.Log("Was not able to get duplicate route this time!") + } +} + +func wait(ch chan bool) error { + select { + case <-ch: + return nil + case <-time.After(5 * time.Second): + } + return fmt.Errorf("timeout") +} + +func TestServerPoolUpdatedWhenRouteGoesAway(t *testing.T) { + s1Opts := DefaultOptions() + s1Opts.Host = "127.0.0.1" + s1Opts.Port = 4222 + s1Opts.Cluster.Host = "127.0.0.1" + s1Opts.Cluster.Port = 6222 + s1Opts.Routes = RoutesFromStr("nats://127.0.0.1:6223,nats://127.0.0.1:6224") + s1 := RunServer(s1Opts) + defer s1.Shutdown() + + s1Url := "nats://127.0.0.1:4222" + s2Url := "nats://127.0.0.1:4223" + s3Url := "nats://127.0.0.1:4224" + + ch := make(chan bool, 1) + chch := make(chan bool, 1) + connHandler := func(_ *nats.Conn) { + chch <- true + } + nc, err := nats.Connect(s1Url, + nats.ReconnectHandler(connHandler), + nats.DiscoveredServersHandler(func(_ *nats.Conn) { + ch <- true + })) + if err != nil { + t.Fatalf("Error on connect") + } + + s2Opts := DefaultOptions() + s2Opts.Host = "127.0.0.1" + s2Opts.Port = s1Opts.Port + 1 + s2Opts.Cluster.Host = "127.0.0.1" + s2Opts.Cluster.Port = 6223 + s2Opts.Routes = RoutesFromStr("nats://127.0.0.1:6222,nats://127.0.0.1:6224") + s2 := RunServer(s2Opts) + defer s2.Shutdown() + + // Wait to be notified + if err := wait(ch); err != nil { + t.Fatal("New server callback was not invoked") + } + + checkPool := func(expected []string) { + // Don't use discovered here, but Servers to have the full list. + // Also, there may be cases where the mesh is not formed yet, + // so try again on failure. + checkFor(t, 5*time.Second, 50*time.Millisecond, func() error { + ds := nc.Servers() + if len(ds) == len(expected) { + m := make(map[string]struct{}, len(ds)) + for _, url := range ds { + m[url] = struct{}{} + } + ok := true + for _, url := range expected { + if _, present := m[url]; !present { + ok = false + break + } + } + if ok { + return nil + } + } + return fmt.Errorf("Expected %v, got %v", expected, ds) + }) + } + // Verify that we now know about s2 + checkPool([]string{s1Url, s2Url}) + + s3Opts := DefaultOptions() + s3Opts.Host = "127.0.0.1" + s3Opts.Port = s2Opts.Port + 1 + s3Opts.Cluster.Host = "127.0.0.1" + s3Opts.Cluster.Port = 6224 + s3Opts.Routes = RoutesFromStr("nats://127.0.0.1:6222,nats://127.0.0.1:6223") + s3 := RunServer(s3Opts) + defer s3.Shutdown() + + // Wait to be notified + if err := wait(ch); err != nil { + t.Fatal("New server callback was not invoked") + } + // Verify that we now know about s3 + checkPool([]string{s1Url, s2Url, s3Url}) + + // Stop s1. Since this was passed to the Connect() call, this one should + // still be present. + s1.Shutdown() + // Wait for reconnect + if err := wait(chch); err != nil { + t.Fatal("Reconnect handler not invoked") + } + checkPool([]string{s1Url, s2Url, s3Url}) + + // Check the server we reconnected to. + reConnectedTo := nc.ConnectedUrl() + expected := []string{s1Url} + if reConnectedTo == s2Url { + s2.Shutdown() + expected = append(expected, s3Url) + } else if reConnectedTo == s3Url { + s3.Shutdown() + expected = append(expected, s2Url) + } else { + t.Fatalf("Unexpected server client has reconnected to: %v", reConnectedTo) + } + // Wait for reconnect + if err := wait(chch); err != nil { + t.Fatal("Reconnect handler not invoked") + } + // The implicit server that we just shutdown should have been removed from the pool + checkPool(expected) + nc.Close() +} + +func TestRoutedQueueAutoUnsubscribe(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/seed.conf") + optsA.NoSigs, optsA.NoLog = true, true + optsA.RQSubsSweep = 500 * time.Millisecond + srvA := RunServer(optsA) + defer srvA.Shutdown() + + srvARouteURL := fmt.Sprintf("nats://%s:%d", optsA.Cluster.Host, srvA.ClusterAddr().Port) + optsB := nextServerOpts(optsA) + optsB.Routes = RoutesFromStr(srvARouteURL) + + srvB := RunServer(optsB) + defer srvB.Shutdown() + + // Wait for these 2 to connect to each other + checkClusterFormed(t, srvA, srvB) + + // Have a client connection to each server + ncA, err := nats.Connect(fmt.Sprintf("nats://%s:%d", optsA.Host, optsA.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer ncA.Close() + + ncB, err := nats.Connect(fmt.Sprintf("nats://%s:%d", optsB.Host, optsB.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer ncB.Close() + + rbar := int32(0) + barCb := func(m *nats.Msg) { + atomic.AddInt32(&rbar, 1) + } + rbaz := int32(0) + bazCb := func(m *nats.Msg) { + atomic.AddInt32(&rbaz, 1) + } + + // Create 125 queue subs with auto-unsubscribe to each server for + // group bar and group baz. So 250 total per queue group. + cons := []*nats.Conn{ncA, ncB} + for _, c := range cons { + for i := 0; i < 125; i++ { + qsub, err := c.QueueSubscribe("foo", "bar", barCb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := qsub.AutoUnsubscribe(1); err != nil { + t.Fatalf("Error on auto-unsubscribe: %v", err) + } + qsub, err = c.QueueSubscribe("foo", "baz", bazCb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := qsub.AutoUnsubscribe(1); err != nil { + t.Fatalf("Error on auto-unsubscribe: %v", err) + } + } + c.Flush() + } + + expected := int32(250) + // Now send messages from each server + for i := int32(0); i < expected; i++ { + c := cons[i%2] + c.Publish("foo", []byte("Don't Drop Me!")) + } + for _, c := range cons { + c.Flush() + } + + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { + nbar := atomic.LoadInt32(&rbar) + nbaz := atomic.LoadInt32(&rbaz) + if nbar == expected && nbaz == expected { + time.Sleep(500 * time.Millisecond) + // Now check all mappings are gone. + srvA.rqsMu.RLock() + nrqsa := len(srvA.rqsubs) + srvA.rqsMu.RUnlock() + srvB.rqsMu.RLock() + nrqsb := len(srvB.rqsubs) + srvB.rqsMu.RUnlock() + if nrqsa != 0 || nrqsb != 0 { + return fmt.Errorf("Expected rqs mappings to have cleared, but got A:%d, B:%d", + nrqsa, nrqsb) + } + return nil + } + return fmt.Errorf("Did not receive all %d queue messages, received %d for 'bar' and %d for 'baz'", + expected, atomic.LoadInt32(&rbar), atomic.LoadInt32(&rbaz)) + }) +} + +func TestRouteFailedConnRemovedFromTmpMap(t *testing.T) { + optsA, _ := ProcessConfigFile("./configs/srv_a.conf") + optsA.NoSigs, optsA.NoLog = true, true + + optsB, _ := ProcessConfigFile("./configs/srv_b.conf") + optsB.NoSigs, optsB.NoLog = true, true + + srvA := New(optsA) + defer srvA.Shutdown() + srvB := New(optsB) + defer srvB.Shutdown() + + // Start this way to increase chance of having the two connect + // to each other at the same time. This will cause one of the + // route to be dropped. + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + srvA.Start() + wg.Done() + }() + go func() { + srvB.Start() + wg.Done() + }() + + checkClusterFormed(t, srvA, srvB) + + // Ensure that maps are empty + checkMap := func(s *Server) { + s.grMu.Lock() + l := len(s.grTmpClients) + s.grMu.Unlock() + if l != 0 { + stackFatalf(t, "grTmpClients map should be empty, got %v", l) + } + } + checkMap(srvA) + checkMap(srvB) + + srvB.Shutdown() + srvA.Shutdown() + wg.Wait() +} diff --git a/vendor/github.com/nats-io/gnatsd/server/server.go b/vendor/github.com/nats-io/gnatsd/server/server.go new file mode 100644 index 00000000..bddb9de3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/server.go @@ -0,0 +1,1420 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "path" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + // Allow dynamic profiling. + _ "net/http/pprof" + + "github.com/nats-io/gnatsd/util" +) + +// Info is the information sent to clients to help them understand information +// about this server. +type Info struct { + ID string `json:"server_id"` + Version string `json:"version"` + Proto int `json:"proto"` + GitCommit string `json:"git_commit,omitempty"` + GoVersion string `json:"go"` + Host string `json:"host"` + Port int `json:"port"` + AuthRequired bool `json:"auth_required,omitempty"` + TLSRequired bool `json:"tls_required,omitempty"` + TLSVerify bool `json:"tls_verify,omitempty"` + MaxPayload int `json:"max_payload"` + IP string `json:"ip,omitempty"` + CID uint64 `json:"client_id,omitempty"` + ClientConnectURLs []string `json:"connect_urls,omitempty"` // Contains URLs a client can connect to. +} + +// Server is our main struct. +type Server struct { + gcid uint64 + stats + mu sync.Mutex + info Info + sl *Sublist + configFile string + optsMu sync.RWMutex + opts *Options + running bool + shutdown bool + listener net.Listener + clients map[uint64]*client + routes map[uint64]*client + remotes map[string]*client + users map[string]*User + totalClients uint64 + closed *closedRingBuffer + done chan bool + start time.Time + http net.Listener + httpHandler http.Handler + profiler net.Listener + httpReqStats map[string]uint64 + routeListener net.Listener + routeInfo Info + routeInfoJSON []byte + quitCh chan struct{} + + // Tracking for remote QRSID tags. + rqsMu sync.RWMutex + rqsubs map[string]rqsub + rqsubsTimer *time.Timer + + // Tracking Go routines + grMu sync.Mutex + grTmpClients map[uint64]*client + grRunning bool + grWG sync.WaitGroup // to wait on various go routines + + cproto int64 // number of clients supporting async INFO + configTime time.Time // last time config was loaded + + logging struct { + sync.RWMutex + logger Logger + trace int32 + debug int32 + } + + clientConnectURLs []string + + // Used internally for quick look-ups. + clientConnectURLsMap map[string]struct{} + + lastCURLsUpdate int64 + + // These store the real client/cluster listen ports. They are + // required during config reload to reset the Options (after + // reload) to the actual listen port values. + clientActualPort int + clusterActualPort int + + // Used by tests to check that http.Servers do + // not set any timeout. + monitoringServer *http.Server + profilingServer *http.Server +} + +// Make sure all are 64bits for atomic use +type stats struct { + inMsgs int64 + outMsgs int64 + inBytes int64 + outBytes int64 + slowConsumers int64 +} + +// New will setup a new server struct after parsing the options. +func New(opts *Options) *Server { + processOptions(opts) + + // Process TLS options, including whether we require client certificates. + tlsReq := opts.TLSConfig != nil + verify := (tlsReq && opts.TLSConfig.ClientAuth == tls.RequireAndVerifyClientCert) + + info := Info{ + ID: genID(), + Version: VERSION, + Proto: PROTO, + GitCommit: gitCommit, + GoVersion: runtime.Version(), + Host: opts.Host, + Port: opts.Port, + AuthRequired: false, + TLSRequired: tlsReq, + TLSVerify: verify, + MaxPayload: opts.MaxPayload, + } + + now := time.Now() + s := &Server{ + configFile: opts.ConfigFile, + info: info, + sl: NewSublist(), + opts: opts, + done: make(chan bool, 1), + start: now, + configTime: now, + } + + s.mu.Lock() + defer s.mu.Unlock() + + // This is normally done in the AcceptLoop, once the + // listener has been created (possibly with random port), + // but since some tests may expect the INFO to be properly + // set after New(), let's do it now. + s.setInfoHostPortAndGenerateJSON() + + // Used internally for quick look-ups. + s.clientConnectURLsMap = make(map[string]struct{}) + + // For tracking clients + s.clients = make(map[uint64]*client) + + // For tracking closed clients. + s.closed = newClosedRingBuffer(opts.MaxClosedClients) + + // For tracking connections that are not yet registered + // in s.routes, but for which readLoop has started. + s.grTmpClients = make(map[uint64]*client) + + // For tracking routes and their remote ids + s.routes = make(map[uint64]*client) + s.remotes = make(map[string]*client) + + // Used to kick out all go routines possibly waiting on server + // to shutdown. + s.quitCh = make(chan struct{}) + + // Used to setup Authorization. + s.configureAuthorization() + + // Start signal handler + s.handleSignals() + + return s +} + +func (s *Server) getOpts() *Options { + s.optsMu.RLock() + opts := s.opts + s.optsMu.RUnlock() + return opts +} + +func (s *Server) setOpts(opts *Options) { + s.optsMu.Lock() + s.opts = opts + s.optsMu.Unlock() +} + +func (s *Server) generateRouteInfoJSON() { + b, _ := json.Marshal(s.routeInfo) + pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} + s.routeInfoJSON = bytes.Join(pcs, []byte(" ")) +} + +// PrintAndDie is exported for access in other packages. +func PrintAndDie(msg string) { + fmt.Fprintf(os.Stderr, "%s\n", msg) + os.Exit(1) +} + +// PrintServerAndExit will print our version and exit. +func PrintServerAndExit() { + fmt.Printf("nats-server version %s\n", VERSION) + os.Exit(0) +} + +// ProcessCommandLineArgs takes the command line arguments +// validating and setting flags for handling in case any +// sub command was present. +func ProcessCommandLineArgs(cmd *flag.FlagSet) (showVersion bool, showHelp bool, err error) { + if len(cmd.Args()) > 0 { + arg := cmd.Args()[0] + switch strings.ToLower(arg) { + case "version": + return true, false, nil + case "help": + return false, true, nil + default: + return false, false, fmt.Errorf("unrecognized command: %q", arg) + } + } + + return false, false, nil +} + +// Protected check on running state +func (s *Server) isRunning() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.running +} + +func (s *Server) logPid() error { + pidStr := strconv.Itoa(os.Getpid()) + return ioutil.WriteFile(s.getOpts().PidFile, []byte(pidStr), 0660) +} + +// Start up the server, this will block. +// Start via a Go routine if needed. +func (s *Server) Start() { + s.Noticef("Starting nats-server version %s", VERSION) + s.Debugf("Go build version %s", s.info.GoVersion) + gc := gitCommit + if gc == "" { + gc = "not set" + } + s.Noticef("Git commit [%s]", gc) + + // Avoid RACE between Start() and Shutdown() + s.mu.Lock() + s.running = true + s.mu.Unlock() + + s.grMu.Lock() + s.grRunning = true + s.grMu.Unlock() + + // Snapshot server options. + opts := s.getOpts() + + // Log the pid to a file + if opts.PidFile != _EMPTY_ { + if err := s.logPid(); err != nil { + PrintAndDie(fmt.Sprintf("Could not write pidfile: %v\n", err)) + } + } + + // Start monitoring if needed + if err := s.StartMonitoring(); err != nil { + s.Fatalf("Can't start monitoring: %v", err) + return + } + + // The Routing routine needs to wait for the client listen + // port to be opened and potential ephemeral port selected. + clientListenReady := make(chan struct{}) + + // Start up routing as well if needed. + if opts.Cluster.Port != 0 { + s.startGoRoutine(func() { + s.StartRouting(clientListenReady) + }) + } + + // Pprof http endpoint for the profiler. + if opts.ProfPort != 0 { + s.StartProfiler() + } + + if opts.PortsFileDir != _EMPTY_ { + s.logPorts() + } + + // Wait for clients. + s.AcceptLoop(clientListenReady) +} + +// Shutdown will shutdown the server instance by kicking out the AcceptLoop +// and closing all associated clients. +func (s *Server) Shutdown() { + s.mu.Lock() + // Prevent issues with multiple calls. + if s.shutdown { + s.mu.Unlock() + return + } + + opts := s.getOpts() + + s.shutdown = true + s.running = false + s.grMu.Lock() + s.grRunning = false + s.grMu.Unlock() + + conns := make(map[uint64]*client) + + // Copy off the clients + for i, c := range s.clients { + conns[i] = c + } + // Copy off the connections that are not yet registered + // in s.routes, but for which the readLoop has started + s.grMu.Lock() + for i, c := range s.grTmpClients { + conns[i] = c + } + s.grMu.Unlock() + // Copy off the routes + for i, r := range s.routes { + r.setRouteNoReconnectOnClose() + conns[i] = r + } + + // Number of done channel responses we expect. + doneExpected := 0 + + // Kick client AcceptLoop() + if s.listener != nil { + doneExpected++ + s.listener.Close() + s.listener = nil + } + + // Kick route AcceptLoop() + if s.routeListener != nil { + doneExpected++ + s.routeListener.Close() + s.routeListener = nil + } + + // Kick HTTP monitoring if its running + if s.http != nil { + doneExpected++ + s.http.Close() + s.http = nil + } + + // Kick Profiling if its running + if s.profiler != nil { + doneExpected++ + s.profiler.Close() + } + + // Clear any remote qsub mappings + s.clearRemoteQSubs() + s.mu.Unlock() + + // Release go routines that wait on that channel + close(s.quitCh) + + // Close client and route connections + for _, c := range conns { + c.closeConnection(ServerShutdown) + } + + // Block until the accept loops exit + for doneExpected > 0 { + <-s.done + doneExpected-- + } + + // Wait for go routines to be done. + s.grWG.Wait() + + if opts.PortsFileDir != _EMPTY_ { + s.deletePortsFile(opts.PortsFileDir) + } +} + +// AcceptLoop is exported for easier testing. +func (s *Server) AcceptLoop(clr chan struct{}) { + // If we were to exit before the listener is setup properly, + // make sure we close the channel. + defer func() { + if clr != nil { + close(clr) + } + }() + + // Snapshot server options. + opts := s.getOpts() + + hp := net.JoinHostPort(opts.Host, strconv.Itoa(opts.Port)) + l, e := net.Listen("tcp", hp) + if e != nil { + s.Fatalf("Error listening on port: %s, %q", hp, e) + return + } + s.Noticef("Listening for client connections on %s", + net.JoinHostPort(opts.Host, strconv.Itoa(l.Addr().(*net.TCPAddr).Port))) + + // Alert of TLS enabled. + if opts.TLSConfig != nil { + s.Noticef("TLS required for client connections") + } + + s.Debugf("Server id is %s", s.info.ID) + s.Noticef("Server is ready") + + // Setup state that can enable shutdown + s.mu.Lock() + s.listener = l + + // If server was started with RANDOM_PORT (-1), opts.Port would be equal + // to 0 at the beginning this function. So we need to get the actual port + if opts.Port == 0 { + // Write resolved port back to options. + opts.Port = l.Addr().(*net.TCPAddr).Port + } + // Keep track of actual listen port. This will be needed in case of + // config reload. + s.clientActualPort = opts.Port + + // Now that port has been set (if it was set to RANDOM), set the + // server's info Host/Port with either values from Options or + // ClientAdvertise. Also generate the JSON byte array. + if err := s.setInfoHostPortAndGenerateJSON(); err != nil { + s.Fatalf("Error setting server INFO with ClientAdvertise value of %s, err=%v", s.opts.ClientAdvertise, err) + s.mu.Unlock() + return + } + // Keep track of client connect URLs. We may need them later. + s.clientConnectURLs = s.getClientConnectURLs() + s.mu.Unlock() + + // Let the caller know that we are ready + close(clr) + clr = nil + + tmpDelay := ACCEPT_MIN_SLEEP + + for s.isRunning() { + conn, err := l.Accept() + if err != nil { + if ne, ok := err.(net.Error); ok && ne.Temporary() { + s.Errorf("Temporary Client Accept Error (%v), sleeping %dms", + ne, tmpDelay/time.Millisecond) + time.Sleep(tmpDelay) + tmpDelay *= 2 + if tmpDelay > ACCEPT_MAX_SLEEP { + tmpDelay = ACCEPT_MAX_SLEEP + } + } else if s.isRunning() { + s.Errorf("Client Accept Error: %v", err) + } + continue + } + tmpDelay = ACCEPT_MIN_SLEEP + s.startGoRoutine(func() { + s.createClient(conn) + s.grWG.Done() + }) + } + s.Noticef("Server Exiting..") + s.done <- true +} + +// This function sets the server's info Host/Port based on server Options. +// Note that this function may be called during config reload, this is why +// Host/Port may be reset to original Options if the ClientAdvertise option +// is not set (since it may have previously been). +// The function then generates the server infoJSON. +func (s *Server) setInfoHostPortAndGenerateJSON() error { + // When this function is called, opts.Port is set to the actual listen + // port (if option was originally set to RANDOM), even during a config + // reload. So use of s.opts.Port is safe. + if s.opts.ClientAdvertise != "" { + h, p, err := parseHostPort(s.opts.ClientAdvertise, s.opts.Port) + if err != nil { + return err + } + s.info.Host = h + s.info.Port = p + } else { + s.info.Host = s.opts.Host + s.info.Port = s.opts.Port + } + return nil +} + +// StartProfiler is called to enable dynamic profiling. +func (s *Server) StartProfiler() { + // Snapshot server options. + opts := s.getOpts() + + port := opts.ProfPort + + // Check for Random Port + if port == -1 { + port = 0 + } + + hp := net.JoinHostPort(opts.Host, strconv.Itoa(port)) + + l, err := net.Listen("tcp", hp) + s.Noticef("profiling port: %d", l.Addr().(*net.TCPAddr).Port) + + if err != nil { + s.Fatalf("error starting profiler: %s", err) + } + + srv := &http.Server{ + Addr: hp, + Handler: http.DefaultServeMux, + MaxHeaderBytes: 1 << 20, + } + + s.mu.Lock() + s.profiler = l + s.profilingServer = srv + s.mu.Unlock() + + go func() { + // if this errors out, it's probably because the server is being shutdown + err := srv.Serve(l) + if err != nil { + s.mu.Lock() + shutdown := s.shutdown + s.mu.Unlock() + if !shutdown { + s.Fatalf("error starting profiler: %s", err) + } + } + s.done <- true + }() +} + +// StartHTTPMonitoring will enable the HTTP monitoring port. +// DEPRECATED: Should use StartMonitoring. +func (s *Server) StartHTTPMonitoring() { + s.startMonitoring(false) +} + +// StartHTTPSMonitoring will enable the HTTPS monitoring port. +// DEPRECATED: Should use StartMonitoring. +func (s *Server) StartHTTPSMonitoring() { + s.startMonitoring(true) +} + +// StartMonitoring starts the HTTP or HTTPs server if needed. +func (s *Server) StartMonitoring() error { + // Snapshot server options. + opts := s.getOpts() + + // Specifying both HTTP and HTTPS ports is a misconfiguration + if opts.HTTPPort != 0 && opts.HTTPSPort != 0 { + return fmt.Errorf("can't specify both HTTP (%v) and HTTPs (%v) ports", opts.HTTPPort, opts.HTTPSPort) + } + var err error + if opts.HTTPPort != 0 { + err = s.startMonitoring(false) + } else if opts.HTTPSPort != 0 { + if opts.TLSConfig == nil { + return fmt.Errorf("TLS cert and key required for HTTPS") + } + err = s.startMonitoring(true) + } + return err +} + +// HTTP endpoints +const ( + RootPath = "/" + VarzPath = "/varz" + ConnzPath = "/connz" + RoutezPath = "/routez" + SubszPath = "/subsz" + StackszPath = "/stacksz" +) + +// Start the monitoring server +func (s *Server) startMonitoring(secure bool) error { + // Snapshot server options. + opts := s.getOpts() + + // Used to track HTTP requests + s.httpReqStats = map[string]uint64{ + RootPath: 0, + VarzPath: 0, + ConnzPath: 0, + RoutezPath: 0, + SubszPath: 0, + } + + var ( + hp string + err error + httpListener net.Listener + port int + ) + + monitorProtocol := "http" + + if secure { + monitorProtocol += "s" + port = opts.HTTPSPort + if port == -1 { + port = 0 + } + hp = net.JoinHostPort(opts.HTTPHost, strconv.Itoa(port)) + config := util.CloneTLSConfig(opts.TLSConfig) + config.ClientAuth = tls.NoClientCert + httpListener, err = tls.Listen("tcp", hp, config) + + } else { + port = opts.HTTPPort + if port == -1 { + port = 0 + } + hp = net.JoinHostPort(opts.HTTPHost, strconv.Itoa(port)) + httpListener, err = net.Listen("tcp", hp) + } + + if err != nil { + return fmt.Errorf("can't listen to the monitor port: %v", err) + } + + s.Noticef("Starting %s monitor on %s", monitorProtocol, + net.JoinHostPort(opts.HTTPHost, strconv.Itoa(httpListener.Addr().(*net.TCPAddr).Port))) + + mux := http.NewServeMux() + + // Root + mux.HandleFunc(RootPath, s.HandleRoot) + // Varz + mux.HandleFunc(VarzPath, s.HandleVarz) + // Connz + mux.HandleFunc(ConnzPath, s.HandleConnz) + // Routez + mux.HandleFunc(RoutezPath, s.HandleRoutez) + // Subz + mux.HandleFunc(SubszPath, s.HandleSubsz) + // Subz alias for backwards compatibility + mux.HandleFunc("/subscriptionsz", s.HandleSubsz) + // Stacksz + mux.HandleFunc(StackszPath, s.HandleStacksz) + + // Do not set a WriteTimeout because it could cause cURL/browser + // to return empty response or unable to display page if the + // server needs more time to build the response. + srv := &http.Server{ + Addr: hp, + Handler: mux, + MaxHeaderBytes: 1 << 20, + } + s.mu.Lock() + s.http = httpListener + s.httpHandler = mux + s.monitoringServer = srv + s.mu.Unlock() + + go func() { + srv.Serve(httpListener) + srv.Handler = nil + s.mu.Lock() + s.httpHandler = nil + s.mu.Unlock() + s.done <- true + }() + + return nil +} + +// HTTPHandler returns the http.Handler object used to handle monitoring +// endpoints. It will return nil if the server is not configured for +// monitoring, or if the server has not been started yet (Server.Start()). +func (s *Server) HTTPHandler() http.Handler { + s.mu.Lock() + defer s.mu.Unlock() + return s.httpHandler +} + +// Perform a conditional deep copy due to reference nature of ClientConnectURLs. +// If updates are made to Info, this function should be consulted and updated. +// Assume lock is held. +func (s *Server) copyInfo() Info { + info := s.info + if info.ClientConnectURLs != nil { + info.ClientConnectURLs = make([]string, len(s.info.ClientConnectURLs)) + copy(info.ClientConnectURLs, s.info.ClientConnectURLs) + } + return info +} + +func (s *Server) createClient(conn net.Conn) *client { + // Snapshot server options. + opts := s.getOpts() + + max_pay := int64(opts.MaxPayload) + max_subs := opts.MaxSubs + now := time.Now() + + c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: max_pay, msubs: max_subs, start: now, last: now} + + // Grab JSON info string + s.mu.Lock() + info := s.copyInfo() + s.totalClients++ + s.mu.Unlock() + + // Grab lock + c.mu.Lock() + + // Initialize + c.initClient() + + c.Debugf("Client connection created") + + // Send our information. + c.sendInfo(c.generateClientInfoJSON(info)) + + // Unlock to register + c.mu.Unlock() + + // Register with the server. + s.mu.Lock() + // If server is not running, Shutdown() may have already gathered the + // list of connections to close. It won't contain this one, so we need + // to bail out now otherwise the readLoop started down there would not + // be interrupted. + if !s.running { + s.mu.Unlock() + return c + } + + // If there is a max connections specified, check that adding + // this new client would not push us over the max + if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } + s.clients[c.cid] = c + s.mu.Unlock() + + // Re-Grab lock + c.mu.Lock() + + // Check for TLS + if info.TLSRequired { + c.Debugf("Starting TLS client connection handshake") + c.nc = tls.Server(c.nc, opts.TLSConfig) + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(opts.TLSTimeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + // Force handshake + c.mu.Unlock() + if err := conn.Handshake(); err != nil { + c.Errorf("TLS handshake error: %v", err) + c.closeConnection(TLSHandshakeError) + return nil + } + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // Indicate that handshake is complete (used in monitoring) + c.flags.set(handshakeComplete) + } + + // The connection may have been closed + if c.nc == nil { + c.mu.Unlock() + return c + } + + // Check for Auth. We schedule this timer after the TLS handshake to avoid + // the race where the timer fires during the handshake and causes the + // server to write bad data to the socket. See issue #432. + if info.AuthRequired { + c.setAuthTimer(secondsToDuration(opts.AuthTimeout)) + } + + // Do final client initialization + + // Set the Ping timer + c.setPingTimer() + + // Spin up the read loop. + s.startGoRoutine(c.readLoop) + + // Spin up the write loop. + s.startGoRoutine(c.writeLoop) + + if info.TLSRequired { + c.Debugf("TLS handshake complete") + cs := c.nc.(*tls.Conn).ConnectionState() + c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite)) + } + + c.mu.Unlock() + + return c +} + +// This will save off a closed client in a ring buffer such that +// /connz can inspect. Useful for debugging, etc. +func (s *Server) saveClosedClient(c *client, nc net.Conn, reason ClosedState) { + now := time.Now() + + c.mu.Lock() + + cc := &closedClient{} + cc.fill(c, nc, now) + cc.Stop = &now + cc.Reason = reason.String() + + // Do subs, do not place by default in main ConnInfo + if len(c.subs) > 0 { + cc.subs = make([]string, 0, len(c.subs)) + for _, sub := range c.subs { + cc.subs = append(cc.subs, string(sub.subject)) + } + } + // Hold user as well. + cc.user = c.opts.Username + c.mu.Unlock() + + // Place in the ring buffer + s.mu.Lock() + s.closed.append(cc) + s.mu.Unlock() +} + +// Adds the given array of urls to the server's INFO.ClientConnectURLs +// array. The server INFO JSON is regenerated. +// Note that a check is made to ensure that given URLs are not +// already present. So the INFO JSON is regenerated only if new ULRs +// were added. +// If there was a change, an INFO protocol is sent to registered clients +// that support async INFO protocols. +func (s *Server) addClientConnectURLsAndSendINFOToClients(urls []string) { + s.updateServerINFOAndSendINFOToClients(urls, true) +} + +// Removes the given array of urls from the server's INFO.ClientConnectURLs +// array. The server INFO JSON is regenerated if needed. +// If there was a change, an INFO protocol is sent to registered clients +// that support async INFO protocols. +func (s *Server) removeClientConnectURLsAndSendINFOToClients(urls []string) { + s.updateServerINFOAndSendINFOToClients(urls, false) +} + +// Updates the server's Info object with the given array of URLs and re-generate +// the infoJSON byte array, then send an (async) INFO protocol to clients that +// support it. +func (s *Server) updateServerINFOAndSendINFOToClients(urls []string, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + + // Will be set to true if we alter the server's Info object. + wasUpdated := false + remove := !add + for _, url := range urls { + _, present := s.clientConnectURLsMap[url] + if add && !present { + s.clientConnectURLsMap[url] = struct{}{} + wasUpdated = true + } else if remove && present { + delete(s.clientConnectURLsMap, url) + wasUpdated = true + } + } + if wasUpdated { + // Recreate the info.ClientConnectURL array from the map + s.info.ClientConnectURLs = s.info.ClientConnectURLs[:0] + // Add this server client connect ULRs first... + s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, s.clientConnectURLs...) + for url := range s.clientConnectURLsMap { + s.info.ClientConnectURLs = append(s.info.ClientConnectURLs, url) + } + // Update the time of this update + s.lastCURLsUpdate = time.Now().UnixNano() + // Send to all registered clients that support async INFO protocols. + s.sendAsyncInfoToClients() + } +} + +// Handle closing down a connection when the handshake has timedout. +func tlsTimeout(c *client, conn *tls.Conn) { + c.mu.Lock() + nc := c.nc + c.mu.Unlock() + // Check if already closed + if nc == nil { + return + } + cs := conn.ConnectionState() + if !cs.HandshakeComplete { + c.Errorf("TLS handshake timeout") + c.sendErr("Secure Connection - TLS Required") + c.closeConnection(TLSHandshakeError) + } +} + +// Seems silly we have to write these +func tlsVersion(ver uint16) string { + switch ver { + case tls.VersionTLS10: + return "1.0" + case tls.VersionTLS11: + return "1.1" + case tls.VersionTLS12: + return "1.2" + } + return fmt.Sprintf("Unknown [%x]", ver) +} + +// We use hex here so we don't need multiple versions +func tlsCipher(cs uint16) string { + name, present := cipherMapByID[cs] + if present { + return name + } + return fmt.Sprintf("Unknown [%x]", cs) +} + +// Remove a client or route from our internal accounting. +func (s *Server) removeClient(c *client) { + var rID string + c.mu.Lock() + cid := c.cid + typ := c.typ + r := c.route + if r != nil { + rID = r.remoteID + } + updateProtoInfoCount := false + if typ == CLIENT && c.opts.Protocol >= ClientProtoInfo { + updateProtoInfoCount = true + } + c.mu.Unlock() + + s.mu.Lock() + switch typ { + case CLIENT: + delete(s.clients, cid) + if updateProtoInfoCount { + s.cproto-- + } + case ROUTER: + delete(s.routes, cid) + if r != nil { + rc, ok := s.remotes[rID] + // Only delete it if it is us.. + if ok && c == rc { + delete(s.remotes, rID) + } + } + // Remove from temporary map in case it is there. + s.grMu.Lock() + delete(s.grTmpClients, cid) + s.grMu.Unlock() + } + s.mu.Unlock() +} + +///////////////////////////////////////////////////////////////// +// These are some helpers for accounting in functional tests. +///////////////////////////////////////////////////////////////// + +// NumRoutes will report the number of registered routes. +func (s *Server) NumRoutes() int { + s.mu.Lock() + nr := len(s.routes) + s.mu.Unlock() + return nr +} + +// NumRemotes will report number of registered remotes. +func (s *Server) NumRemotes() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.remotes) +} + +// NumClients will report the number of registered clients. +func (s *Server) NumClients() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.clients) +} + +// getClient will return the client associated with cid. +func (s *Server) getClient(cid uint64) *client { + s.mu.Lock() + defer s.mu.Unlock() + return s.clients[cid] +} + +// NumSubscriptions will report how many subscriptions are active. +func (s *Server) NumSubscriptions() uint32 { + s.mu.Lock() + subs := s.sl.Count() + s.mu.Unlock() + return subs +} + +// NumSlowConsumers will report the number of slow consumers. +func (s *Server) NumSlowConsumers() int64 { + return atomic.LoadInt64(&s.slowConsumers) +} + +// ConfigTime will report the last time the server configuration was loaded. +func (s *Server) ConfigTime() time.Time { + s.mu.Lock() + defer s.mu.Unlock() + return s.configTime +} + +// Addr will return the net.Addr object for the current listener. +func (s *Server) Addr() net.Addr { + s.mu.Lock() + defer s.mu.Unlock() + if s.listener == nil { + return nil + } + return s.listener.Addr() +} + +// MonitorAddr will return the net.Addr object for the monitoring listener. +func (s *Server) MonitorAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.http == nil { + return nil + } + return s.http.Addr().(*net.TCPAddr) +} + +// ClusterAddr returns the net.Addr object for the route listener. +func (s *Server) ClusterAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.routeListener == nil { + return nil + } + return s.routeListener.Addr().(*net.TCPAddr) +} + +// ProfilerAddr returns the net.Addr object for the route listener. +func (s *Server) ProfilerAddr() *net.TCPAddr { + s.mu.Lock() + defer s.mu.Unlock() + if s.profiler == nil { + return nil + } + return s.profiler.Addr().(*net.TCPAddr) +} + +// ReadyForConnections returns `true` if the server is ready to accept client +// and, if routing is enabled, route connections. If after the duration +// `dur` the server is still not ready, returns `false`. +func (s *Server) ReadyForConnections(dur time.Duration) bool { + // Snapshot server options. + opts := s.getOpts() + + end := time.Now().Add(dur) + for time.Now().Before(end) { + s.mu.Lock() + ok := s.listener != nil && (opts.Cluster.Port == 0 || s.routeListener != nil) + s.mu.Unlock() + if ok { + return true + } + time.Sleep(25 * time.Millisecond) + } + return false +} + +// ID returns the server's ID +func (s *Server) ID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.info.ID +} + +func (s *Server) startGoRoutine(f func()) { + s.grMu.Lock() + if s.grRunning { + s.grWG.Add(1) + go f() + } + s.grMu.Unlock() +} + +func (s *Server) numClosedConns() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.len() +} + +func (s *Server) totalClosedConns() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.totalConns() +} + +func (s *Server) closedClients() []*closedClient { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed.closedClients() +} + +// getClientConnectURLs returns suitable URLs for clients to connect to the listen +// port based on the server options' Host and Port. If the Host corresponds to +// "any" interfaces, this call returns the list of resolved IP addresses. +// If ClientAdvertise is set, returns the client advertise host and port. +// The server lock is assumed held on entry. +func (s *Server) getClientConnectURLs() []string { + // Snapshot server options. + opts := s.getOpts() + + urls := make([]string, 0, 1) + + // short circuit if client advertise is set + if opts.ClientAdvertise != "" { + // just use the info host/port. This is updated in s.New() + urls = append(urls, net.JoinHostPort(s.info.Host, strconv.Itoa(s.info.Port))) + } else { + sPort := strconv.Itoa(opts.Port) + ipAddr, err := net.ResolveIPAddr("ip", opts.Host) + // If the host is "any" (0.0.0.0 or ::), get specific IPs from available + // interfaces. + if err == nil && ipAddr.IP.IsUnspecified() { + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // Skip non global unicast addresses + if !ip.IsGlobalUnicast() || ip.IsUnspecified() { + ip = nil + continue + } + urls = append(urls, net.JoinHostPort(ip.String(), sPort)) + } + } + } + if err != nil || len(urls) == 0 { + // We are here if s.opts.Host is not "0.0.0.0" nor "::", or if for some + // reason we could not add any URL in the loop above. + // We had a case where a Windows VM was hosed and would have err == nil + // and not add any address in the array in the loop above, and we + // ended-up returning 0.0.0.0, which is problematic for Windows clients. + // Check for 0.0.0.0 or :: specifically, and ignore if that's the case. + if opts.Host == "0.0.0.0" || opts.Host == "::" { + s.Errorf("Address %q can not be resolved properly", opts.Host) + } else { + urls = append(urls, net.JoinHostPort(opts.Host, sPort)) + } + } + } + + return urls +} + +// if the ip is not specified, attempt to resolve it +func resolveHostPorts(addr net.Listener) []string { + hostPorts := make([]string, 0) + hp := addr.Addr().(*net.TCPAddr) + port := strconv.Itoa(hp.Port) + if hp.IP.IsUnspecified() { + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + hostPorts = append(hostPorts, net.JoinHostPort(ip.String(), port)) + case *net.IPAddr: + ip = v.IP + hostPorts = append(hostPorts, net.JoinHostPort(ip.String(), port)) + default: + continue + } + } + } + } else { + hostPorts = append(hostPorts, net.JoinHostPort(hp.IP.String(), port)) + } + return hostPorts +} + +// format the address of a net.Listener with a protocol +func formatURL(protocol string, addr net.Listener) []string { + hostports := resolveHostPorts(addr) + for i, hp := range hostports { + hostports[i] = fmt.Sprintf("%s://%s", protocol, hp) + } + return hostports +} + +// Ports describes URLs that the server can be contacted in +type Ports struct { + Nats []string `json:"nats,omitempty"` + Monitoring []string `json:"monitoring,omitempty"` + Cluster []string `json:"cluster,omitempty"` + Profile []string `json:"profile,omitempty"` +} + +// Attempts to resolve all the ports. If after maxWait the ports are not +// resolved, it returns nil. Otherwise it returns a Ports struct +// describing ports where the server can be contacted +func (s *Server) PortsInfo(maxWait time.Duration) *Ports { + if s.readyForListeners(maxWait) { + opts := s.getOpts() + + s.mu.Lock() + info := s.copyInfo() + listener := s.listener + httpListener := s.http + clusterListener := s.routeListener + profileListener := s.profiler + s.mu.Unlock() + + ports := Ports{} + + if listener != nil { + natsProto := "nats" + if info.TLSRequired { + natsProto = "tls" + } + ports.Nats = formatURL(natsProto, listener) + } + + if httpListener != nil { + monProto := "http" + if opts.HTTPSPort != 0 { + monProto = "https" + } + ports.Monitoring = formatURL(monProto, httpListener) + } + + if clusterListener != nil { + clusterProto := "nats" + if opts.Cluster.TLSConfig != nil { + clusterProto = "tls" + } + ports.Cluster = formatURL(clusterProto, clusterListener) + } + + if profileListener != nil { + ports.Profile = formatURL("http", profileListener) + } + + return &ports + } + + return nil +} + +// Returns the portsFile. If a non-empty dirHint is provided, the dirHint +// path is used instead of the server option value +func (s *Server) portFile(dirHint string) string { + dirname := s.getOpts().PortsFileDir + if dirHint != "" { + dirname = dirHint + } + if dirname == _EMPTY_ { + return _EMPTY_ + } + return path.Join(dirname, fmt.Sprintf("%s_%d.ports", path.Base(os.Args[0]), os.Getpid())) +} + +// Delete the ports file. If a non-empty dirHint is provided, the dirHint +// path is used instead of the server option value +func (s *Server) deletePortsFile(hintDir string) { + portsFile := s.portFile(hintDir) + if portsFile != "" { + if err := os.Remove(portsFile); err != nil { + s.Errorf("Error cleaning up ports file %s: %v", portsFile, err) + } + } +} + +// Writes a file with a serialized Ports to the specified ports_file_dir. +// The name of the file is `exename_pid.ports`, typically gnatsd_pid.ports. +// if ports file is not set, this function has no effect +func (s *Server) logPorts() { + opts := s.getOpts() + portsFile := s.portFile(opts.PortsFileDir) + if portsFile != _EMPTY_ { + go func() { + info := s.PortsInfo(5 * time.Second) + if info == nil { + s.Errorf("Unable to resolve the ports in the specified time") + return + } + data, err := json.Marshal(info) + if err != nil { + s.Errorf("Error marshaling ports file: %v", err) + return + } + if err := ioutil.WriteFile(portsFile, data, 0666); err != nil { + s.Errorf("Error writing ports file (%s): %v", portsFile, err) + return + } + + }() + } +} + +// waits until a calculated list of listeners is resolved or a timeout +func (s *Server) readyForListeners(dur time.Duration) bool { + end := time.Now().Add(dur) + for time.Now().Before(end) { + s.mu.Lock() + listeners := s.serviceListeners() + s.mu.Unlock() + if len(listeners) == 0 { + return false + } + + ok := true + for _, l := range listeners { + if l == nil { + ok = false + break + } + } + if ok { + return true + } + select { + case <-s.quitCh: + return false + case <-time.After(25 * time.Millisecond): + // continue - unable to select from quit - we are still running + } + } + return false +} + +// returns a list of listeners that are intended for the process +// if the entry is nil, the interface is yet to be resolved +func (s *Server) serviceListeners() []net.Listener { + listeners := make([]net.Listener, 0) + opts := s.getOpts() + listeners = append(listeners, s.listener) + if opts.Cluster.Port != 0 { + listeners = append(listeners, s.routeListener) + } + if opts.HTTPPort != 0 || opts.HTTPSPort != 0 { + listeners = append(listeners, s.http) + } + if opts.ProfPort != 0 { + listeners = append(listeners, s.profiler) + } + return listeners +} diff --git a/vendor/github.com/nats-io/gnatsd/server/server_test.go b/vendor/github.com/nats-io/gnatsd/server/server_test.go new file mode 100644 index 00000000..20b8dbe2 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/server_test.go @@ -0,0 +1,700 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "flag" + "fmt" + "net" + "os" + "strings" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +func checkFor(t *testing.T, totalWait, sleepDur time.Duration, f func() error) { + t.Helper() + timeout := time.Now().Add(totalWait) + var err error + for time.Now().Before(timeout) { + err = f() + if err == nil { + return + } + time.Sleep(sleepDur) + } + if err != nil { + t.Fatal(err.Error()) + } +} + +func DefaultOptions() *Options { + return &Options{ + Host: "127.0.0.1", + Port: -1, + HTTPPort: -1, + Cluster: ClusterOpts{Port: -1}, + NoLog: true, + NoSigs: true, + Debug: true, + Trace: true, + } +} + +// New Go Routine based server +func RunServer(opts *Options) *Server { + if opts == nil { + opts = DefaultOptions() + } + s := New(opts) + + if s == nil { + panic("No NATS Server object returned.") + } + + if !opts.NoLog { + s.ConfigureLogger() + } + + // Run server in Go routine. + go s.Start() + + // Wait for accept loop(s) to be started + if !s.ReadyForConnections(10 * time.Second) { + panic("Unable to start NATS Server in Go Routine") + } + return s +} + +// LoadConfig loads a configuration from a filename +func LoadConfig(configFile string) (opts *Options) { + opts, err := ProcessConfigFile(configFile) + if err != nil { + panic(fmt.Sprintf("Error processing configuration file: %v", err)) + } + opts.NoSigs, opts.NoLog = true, true + return +} + +// RunServerWithConfig starts a new Go routine based server with a configuration file. +func RunServerWithConfig(configFile string) (srv *Server, opts *Options) { + opts = LoadConfig(configFile) + srv = RunServer(opts) + return +} + +func TestVersionMatchesTag(t *testing.T) { + tag := os.Getenv("TRAVIS_TAG") + if tag == "" { + t.SkipNow() + } + // We expect a tag of the form vX.Y.Z. If that's not the case, + // we need someone to have a look. So fail if first letter is not + // a `v` + if tag[0] != 'v' { + t.Fatalf("Expect tag to start with `v`, tag is: %s", tag) + } + // Strip the `v` from the tag for the version comparison. + if VERSION != tag[1:] { + t.Fatalf("Version (%s) does not match tag (%s)", VERSION, tag[1:]) + } +} + +func TestStartProfiler(t *testing.T) { + s := New(DefaultOptions()) + s.StartProfiler() + s.mu.Lock() + s.profiler.Close() + s.mu.Unlock() +} + +func TestStartupAndShutdown(t *testing.T) { + + opts := DefaultOptions() + + s := RunServer(opts) + defer s.Shutdown() + + if !s.isRunning() { + t.Fatal("Could not run server") + } + + // Debug stuff. + numRoutes := s.NumRoutes() + if numRoutes != 0 { + t.Fatalf("Expected numRoutes to be 0 vs %d\n", numRoutes) + } + + numRemotes := s.NumRemotes() + if numRemotes != 0 { + t.Fatalf("Expected numRemotes to be 0 vs %d\n", numRemotes) + } + + numClients := s.NumClients() + if numClients != 0 && numClients != 1 { + t.Fatalf("Expected numClients to be 1 or 0 vs %d\n", numClients) + } + + numSubscriptions := s.NumSubscriptions() + if numSubscriptions != 0 { + t.Fatalf("Expected numSubscriptions to be 0 vs %d\n", numSubscriptions) + } +} + +func TestTlsCipher(t *testing.T) { + if strings.Compare(tlsCipher(0x0005), "TLS_RSA_WITH_RC4_128_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0x000a), "TLS_RSA_WITH_3DES_EDE_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0x002f), "TLS_RSA_WITH_AES_128_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0x0035), "TLS_RSA_WITH_AES_256_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc007), "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc009), "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc00a), "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc011), "TLS_ECDHE_RSA_WITH_RC4_128_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc012), "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc013), "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc014), "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA") != 0 { + t.Fatalf("IUnknownnvalid tls cipher") + } + if strings.Compare(tlsCipher(0xc02f), "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc02b), "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc030), "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384") != 0 { + t.Fatalf("Invalid tls cipher") + } + if strings.Compare(tlsCipher(0xc02c), "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384") != 0 { + t.Fatalf("Invalid tls cipher") + } + if !strings.Contains(tlsCipher(0x9999), "Unknown") { + t.Fatalf("Expected an unknown cipher.") + } +} + +func TestGetConnectURLs(t *testing.T) { + opts := DefaultOptions() + opts.Port = 4222 + + var globalIP net.IP + + checkGlobalConnectURLs := func() { + s := New(opts) + defer s.Shutdown() + + s.mu.Lock() + urls := s.getClientConnectURLs() + s.mu.Unlock() + if len(urls) == 0 { + t.Fatalf("Expected to get a list of urls, got none for listen addr: %v", opts.Host) + } + for _, u := range urls { + tcpaddr, err := net.ResolveTCPAddr("tcp", u) + if err != nil { + t.Fatalf("Error resolving: %v", err) + } + ip := tcpaddr.IP + if !ip.IsGlobalUnicast() { + t.Fatalf("IP %v is not global", ip.String()) + } + if ip.IsUnspecified() { + t.Fatalf("IP %v is unspecified", ip.String()) + } + addr := strings.TrimSuffix(u, ":4222") + if addr == opts.Host { + t.Fatalf("Returned url is not right: %v", u) + } + if globalIP == nil { + globalIP = ip + } + } + } + + listenAddrs := []string{"0.0.0.0", "::"} + for _, listenAddr := range listenAddrs { + opts.Host = listenAddr + checkGlobalConnectURLs() + } + + checkConnectURLsHasOnlyOne := func() { + s := New(opts) + defer s.Shutdown() + + s.mu.Lock() + urls := s.getClientConnectURLs() + s.mu.Unlock() + if len(urls) != 1 { + t.Fatalf("Expected one URL, got %v", urls) + } + tcpaddr, err := net.ResolveTCPAddr("tcp", urls[0]) + if err != nil { + t.Fatalf("Error resolving: %v", err) + } + ip := tcpaddr.IP + if ip.String() != opts.Host { + t.Fatalf("Expected connect URL to be %v, got %v", opts.Host, ip.String()) + } + } + + singleConnectReturned := []string{"127.0.0.1", "::1"} + if globalIP != nil { + singleConnectReturned = append(singleConnectReturned, globalIP.String()) + } + for _, listenAddr := range singleConnectReturned { + opts.Host = listenAddr + checkConnectURLsHasOnlyOne() + } +} + +func TestClientAdvertiseConnectURL(t *testing.T) { + opts := DefaultOptions() + opts.Port = 4222 + opts.ClientAdvertise = "nats.example.com" + s := New(opts) + defer s.Shutdown() + + s.mu.Lock() + urls := s.getClientConnectURLs() + s.mu.Unlock() + if len(urls) != 1 { + t.Fatalf("Expected to get one url, got none: %v with ClientAdvertise %v", + opts.Host, opts.ClientAdvertise) + } + if urls[0] != "nats.example.com:4222" { + t.Fatalf("Expected to get '%s', got: '%v'", "nats.example.com:4222", urls[0]) + } + s.Shutdown() + + opts.ClientAdvertise = "nats.example.com:7777" + s = New(opts) + s.mu.Lock() + urls = s.getClientConnectURLs() + s.mu.Unlock() + if len(urls) != 1 { + t.Fatalf("Expected to get one url, got none: %v with ClientAdvertise %v", + opts.Host, opts.ClientAdvertise) + } + if urls[0] != "nats.example.com:7777" { + t.Fatalf("Expected 'nats.example.com:7777', got: '%v'", urls[0]) + } + if s.info.Host != "nats.example.com" { + t.Fatalf("Expected host to be set to nats.example.com") + } + if s.info.Port != 7777 { + t.Fatalf("Expected port to be set to 7777") + } + s.Shutdown() + + opts = DefaultOptions() + opts.Port = 0 + opts.ClientAdvertise = "nats.example.com:7777" + s = New(opts) + if s.info.Host != "nats.example.com" && s.info.Port != 7777 { + t.Fatalf("Expected Client Advertise Host:Port to be nats.example.com:7777, got: %s:%d", + s.info.Host, s.info.Port) + } + s.Shutdown() +} + +func TestClientAdvertiseErrorOnStartup(t *testing.T) { + opts := DefaultOptions() + // Set invalid address + opts.ClientAdvertise = "addr:::123" + s := New(opts) + defer s.Shutdown() + dl := &DummyLogger{} + s.SetLogger(dl, false, false) + + // Expect this to return due to failure + s.Start() + dl.Lock() + msg := dl.msg + dl.Unlock() + if !strings.Contains(msg, "ClientAdvertise") { + t.Fatalf("Unexpected error: %v", msg) + } +} + +func TestNoDeadlockOnStartFailure(t *testing.T) { + opts := DefaultOptions() + opts.Host = "x.x.x.x" // bad host + opts.Port = 4222 + opts.HTTPHost = opts.Host + opts.Cluster.Host = "127.0.0.1" + opts.Cluster.Port = -1 + opts.ProfPort = -1 + s := New(opts) + + // This should return since it should fail to start a listener + // on x.x.x.x:4222 + s.Start() + + // We should be able to shutdown + s.Shutdown() +} + +func TestMaxConnections(t *testing.T) { + opts := DefaultOptions() + opts.MaxConn = 1 + s := RunServer(opts) + defer s.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc.Close() + + nc2, err := nats.Connect(addr) + if err == nil { + nc2.Close() + t.Fatal("Expected connection to fail") + } +} + +func TestMaxSubscriptions(t *testing.T) { + opts := DefaultOptions() + opts.MaxSubs = 10 + s := RunServer(opts) + defer s.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc.Close() + + for i := 0; i < 10; i++ { + _, err := nc.Subscribe(fmt.Sprintf("foo.%d", i), func(*nats.Msg) {}) + if err != nil { + t.Fatalf("Error subscribing: %v\n", err) + } + } + // This should cause the error. + nc.Subscribe("foo.22", func(*nats.Msg) {}) + nc.Flush() + if err := nc.LastError(); err == nil { + t.Fatal("Expected an error but got none\n") + } +} + +func TestProcessCommandLineArgs(t *testing.T) { + var host string + var port int + cmd := flag.NewFlagSet("gnatsd", flag.ExitOnError) + cmd.StringVar(&host, "a", "0.0.0.0", "Host.") + cmd.IntVar(&port, "p", 4222, "Port.") + + cmd.Parse([]string{"-a", "127.0.0.1", "-p", "9090"}) + showVersion, showHelp, err := ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if showVersion || showHelp { + t.Errorf("Expected not having to handle subcommands") + } + + cmd.Parse([]string{"version"}) + showVersion, showHelp, err = ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if !showVersion { + t.Errorf("Expected having to handle version command") + } + if showHelp { + t.Errorf("Expected not having to handle help command") + } + + cmd.Parse([]string{"help"}) + showVersion, showHelp, err = ProcessCommandLineArgs(cmd) + if err != nil { + t.Errorf("Expected no errors, got: %s", err) + } + if showVersion { + t.Errorf("Expected not having to handle version command") + } + if !showHelp { + t.Errorf("Expected having to handle help command") + } + + cmd.Parse([]string{"foo", "-p", "9090"}) + _, _, err = ProcessCommandLineArgs(cmd) + if err == nil { + t.Errorf("Expected an error handling the command arguments") + } +} + +func TestWriteDeadline(t *testing.T) { + opts := DefaultOptions() + opts.WriteDeadline = 30 * time.Millisecond + s := RunServer(opts) + defer s.Shutdown() + + c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port), 3*time.Second) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer c.Close() + if _, err := c.Write([]byte("CONNECT {}\r\nPING\r\nSUB foo 1\r\n")); err != nil { + t.Fatalf("Error sending protocols to server: %v", err) + } + // Reduce socket buffer to increase reliability of getting + // write deadline errors. + c.(*net.TCPConn).SetReadBuffer(4) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + sender, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer sender.Close() + + payload := make([]byte, 1000000) + for i := 0; i < 10; i++ { + if err := sender.Publish("foo", payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + // Flush sender connection to ensure that all data has been sent. + if err := sender.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + + // At this point server should have closed connection c. + + // On certain platforms, it may take more than one call before + // getting the error. + for i := 0; i < 100; i++ { + if _, err := c.Write([]byte("PUB bar 5\r\nhello\r\n")); err != nil { + // ok + return + } + } + t.Fatal("Connection should have been closed") +} + +func TestSlowConsumerPendingBytes(t *testing.T) { + opts := DefaultOptions() + opts.WriteDeadline = 30 * time.Second // Wait for long time so write deadline does not trigger slow consumer. + opts.MaxPending = 1 * 1024 * 1024 // Set to low value (1MB) to allow SC to trip. + s := RunServer(opts) + defer s.Shutdown() + + c, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", opts.Host, opts.Port), 3*time.Second) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer c.Close() + if _, err := c.Write([]byte("CONNECT {}\r\nPING\r\nSUB foo 1\r\n")); err != nil { + t.Fatalf("Error sending protocols to server: %v", err) + } + // Reduce socket buffer to increase reliability of data backing up in the server destined + // for our subscribed client. + c.(*net.TCPConn).SetReadBuffer(128) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + sender, err := nats.Connect(url) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer sender.Close() + + payload := make([]byte, 1024*1024) + for i := 0; i < 100; i++ { + if err := sender.Publish("foo", payload); err != nil { + t.Fatalf("Error on publish: %v", err) + } + } + + // Flush sender connection to ensure that all data has been sent. + if err := sender.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + + // At this point server should have closed connection c. + + // On certain platforms, it may take more than one call before + // getting the error. + for i := 0; i < 100; i++ { + if _, err := c.Write([]byte("PUB bar 5\r\nhello\r\n")); err != nil { + // ok + return + } + } + t.Fatal("Connection should have been closed") +} + +func TestRandomPorts(t *testing.T) { + opts := DefaultOptions() + opts.HTTPPort = -1 + opts.Port = -1 + s := RunServer(opts) + + defer s.Shutdown() + + if s.Addr() == nil || s.Addr().(*net.TCPAddr).Port <= 0 { + t.Fatal("Should have dynamically assigned server port.") + } + + if s.Addr() == nil || s.Addr().(*net.TCPAddr).Port == 4222 { + t.Fatal("Should not have dynamically assigned default port: 4222.") + } + + if s.MonitorAddr() == nil || s.MonitorAddr().Port <= 0 { + t.Fatal("Should have dynamically assigned monitoring port.") + } + +} + +func TestNilMonitoringPort(t *testing.T) { + opts := DefaultOptions() + opts.HTTPPort = 0 + opts.HTTPSPort = 0 + s := RunServer(opts) + + defer s.Shutdown() + + if s.MonitorAddr() != nil { + t.Fatal("HttpAddr should be nil.") + } +} + +type DummyAuth struct{} + +func (d *DummyAuth) Check(c ClientAuthentication) bool { + return c.GetOpts().Username == "valid" +} + +func TestCustomClientAuthentication(t *testing.T) { + var clientAuth DummyAuth + + opts := DefaultOptions() + opts.CustomClientAuthentication = &clientAuth + + s := RunServer(opts) + + defer s.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(addr, nats.UserInfo("valid", "")) + if err != nil { + t.Fatalf("Expected client to connect, got: %s", err) + } + nc.Close() + if _, err := nats.Connect(addr, nats.UserInfo("invalid", "")); err == nil { + t.Fatal("Expected client to fail to connect") + } +} + +func TestCustomRouterAuthentication(t *testing.T) { + opts := DefaultOptions() + opts.CustomRouterAuthentication = &DummyAuth{} + opts.Cluster.Host = "127.0.0.1" + s := RunServer(opts) + defer s.Shutdown() + clusterPort := s.ClusterAddr().Port + + opts2 := DefaultOptions() + opts2.Cluster.Host = "127.0.0.1" + opts2.Routes = RoutesFromStr(fmt.Sprintf("nats://invalid@127.0.0.1:%d", clusterPort)) + s2 := RunServer(opts2) + defer s2.Shutdown() + + // s2 will attempt to connect to s, which should reject. + // Keep in mind that s2 will try again... + time.Sleep(50 * time.Millisecond) + checkNumRoutes(t, s2, 0) + + opts3 := DefaultOptions() + opts3.Cluster.Host = "127.0.0.1" + opts3.Routes = RoutesFromStr(fmt.Sprintf("nats://valid@127.0.0.1:%d", clusterPort)) + s3 := RunServer(opts3) + defer s3.Shutdown() + checkClusterFormed(t, s, s3) + checkNumRoutes(t, s3, 1) +} + +func TestMonitoringNoTimeout(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + s.mu.Lock() + srv := s.monitoringServer + s.mu.Unlock() + + if srv == nil { + t.Fatalf("Monitoring server not set") + } + if srv.ReadTimeout != 0 { + t.Fatalf("ReadTimeout should not be set, was set to %v", srv.ReadTimeout) + } + if srv.WriteTimeout != 0 { + t.Fatalf("WriteTimeout should not be set, was set to %v", srv.WriteTimeout) + } +} + +func TestProfilingNoTimeout(t *testing.T) { + opts := DefaultOptions() + opts.ProfPort = -1 + s := RunServer(opts) + defer s.Shutdown() + + paddr := s.ProfilerAddr() + if paddr == nil { + t.Fatalf("Profiler not started") + } + pport := paddr.Port + if pport <= 0 { + t.Fatalf("Expected profiler port to be set, got %v", pport) + } + s.mu.Lock() + srv := s.profilingServer + s.mu.Unlock() + + if srv == nil { + t.Fatalf("Profiling server not set") + } + if srv.ReadTimeout != 0 { + t.Fatalf("ReadTimeout should not be set, was set to %v", srv.ReadTimeout) + } + if srv.WriteTimeout != 0 { + t.Fatalf("WriteTimeout should not be set, was set to %v", srv.WriteTimeout) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/service.go b/vendor/github.com/nats-io/gnatsd/server/service.go new file mode 100644 index 00000000..a44cbac3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/service.go @@ -0,0 +1,28 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +// Run starts the NATS server. This wrapper function allows Windows to add a +// hook for running NATS as a service. +func Run(server *Server) error { + server.Start() + return nil +} + +// isWindowsService indicates if NATS is running as a Windows service. +func isWindowsService() bool { + return false +} diff --git a/vendor/github.com/nats-io/gnatsd/server/service_test.go b/vendor/github.com/nats-io/gnatsd/server/service_test.go new file mode 100644 index 00000000..8811c185 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/service_test.go @@ -0,0 +1,53 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + "errors" + "testing" + "time" +) + +func TestRun(t *testing.T) { + var ( + s = New(DefaultOptions()) + started = make(chan error, 1) + errC = make(chan error, 1) + ) + go func() { + errC <- Run(s) + }() + go func() { + if !s.ReadyForConnections(time.Second) { + started <- errors.New("failed to start in time") + return + } + s.Shutdown() + close(started) + }() + + select { + case err := <-errC: + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Timed out") + } + if err := <-started; err != nil { + t.Fatalf("Unexpected error: %v", err) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/service_windows.go b/vendor/github.com/nats-io/gnatsd/server/service_windows.go new file mode 100644 index 00000000..43cc2b5c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/service_windows.go @@ -0,0 +1,121 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "os" + "time" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/debug" +) + +const ( + serviceName = "gnatsd" + reopenLogCode = 128 + reopenLogCmd = svc.Cmd(reopenLogCode) + acceptReopenLog = svc.Accepted(reopenLogCode) +) + +// winServiceWrapper implements the svc.Handler interface for implementing +// gnatsd as a Windows service. +type winServiceWrapper struct { + server *Server +} + +var dockerized = false + +func init() { + if v, exists := os.LookupEnv("NATS_DOCKERIZED"); exists && v == "1" { + dockerized = true + } +} + +// Execute will be called by the package code at the start of +// the service, and the service will exit once Execute completes. +// Inside Execute you must read service change requests from r and +// act accordingly. You must keep service control manager up to date +// about state of your service by writing into s as required. +// args contains service name followed by argument strings passed +// to the service. +// You can provide service exit code in exitCode return parameter, +// with 0 being "no error". You can also indicate if exit code, +// if any, is service specific or not by using svcSpecificEC +// parameter. +func (w *winServiceWrapper) Execute(args []string, changes <-chan svc.ChangeRequest, + status chan<- svc.Status) (bool, uint32) { + + status <- svc.Status{State: svc.StartPending} + go w.server.Start() + + // Wait for accept loop(s) to be started + if !w.server.ReadyForConnections(10 * time.Second) { + // Failed to start. + return false, 1 + } + + status <- svc.Status{ + State: svc.Running, + Accepts: svc.AcceptStop | svc.AcceptShutdown | svc.AcceptParamChange | acceptReopenLog, + } + +loop: + for change := range changes { + switch change.Cmd { + case svc.Interrogate: + status <- change.CurrentStatus + case svc.Stop, svc.Shutdown: + w.server.Shutdown() + break loop + case reopenLogCmd: + // File log re-open for rotating file logs. + w.server.ReOpenLogFile() + case svc.ParamChange: + if err := w.server.Reload(); err != nil { + w.server.Errorf("Failed to reload server configuration: %s", err) + } + default: + w.server.Debugf("Unexpected control request: %v", change.Cmd) + } + } + + status <- svc.Status{State: svc.StopPending} + return false, 0 +} + +// Run starts the NATS server as a Windows service. +func Run(server *Server) error { + if dockerized { + server.Start() + return nil + } + run := svc.Run + isInteractive, err := svc.IsAnInteractiveSession() + if err != nil { + return err + } + if isInteractive { + run = debug.Run + } + return run(serviceName, &winServiceWrapper{server}) +} + +// isWindowsService indicates if NATS is running as a Windows service. +func isWindowsService() bool { + if dockerized { + return false + } + isInteractive, _ := svc.IsAnInteractiveSession() + return !isInteractive +} diff --git a/vendor/github.com/nats-io/gnatsd/server/signal.go b/vendor/github.com/nats-io/gnatsd/server/signal.go new file mode 100644 index 00000000..6a432f7c --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/signal.go @@ -0,0 +1,158 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + "errors" + "fmt" + "os" + "os/exec" + "os/signal" + "strconv" + "strings" + "syscall" +) + +const processName = "gnatsd" + +// Signal Handling +func (s *Server) handleSignals() { + if s.getOpts().NoSigs { + return + } + c := make(chan os.Signal, 1) + + signal.Notify(c, syscall.SIGINT, syscall.SIGUSR1, syscall.SIGHUP) + + s.grWG.Add(1) + go func() { + defer s.grWG.Done() + for { + select { + case sig := <-c: + s.Debugf("Trapped %q signal", sig) + switch sig { + case syscall.SIGINT: + s.Noticef("Server Exiting..") + os.Exit(0) + case syscall.SIGUSR1: + // File log re-open for rotating file logs. + s.ReOpenLogFile() + case syscall.SIGHUP: + // Config reload. + if err := s.Reload(); err != nil { + s.Errorf("Failed to reload server configuration: %s", err) + } + } + case <-s.quitCh: + return + } + } + }() +} + +// ProcessSignal sends the given signal command to the given process. If pidStr +// is empty, this will send the signal to the single running instance of +// gnatsd. If multiple instances are running, it returns an error. This returns +// an error if the given process is not running or the command is invalid. +func ProcessSignal(command Command, pidStr string) error { + var pid int + if pidStr == "" { + pids, err := resolvePids() + if err != nil { + return err + } + if len(pids) == 0 { + return errors.New("no gnatsd processes running") + } + if len(pids) > 1 { + errStr := "multiple gnatsd processes running:\n" + prefix := "" + for _, p := range pids { + errStr += fmt.Sprintf("%s%d", prefix, p) + prefix = "\n" + } + return errors.New(errStr) + } + pid = pids[0] + } else { + p, err := strconv.Atoi(pidStr) + if err != nil { + return fmt.Errorf("invalid pid: %s", pidStr) + } + pid = p + } + + var err error + switch command { + case CommandStop: + err = kill(pid, syscall.SIGKILL) + case CommandQuit: + err = kill(pid, syscall.SIGINT) + case CommandReopen: + err = kill(pid, syscall.SIGUSR1) + case CommandReload: + err = kill(pid, syscall.SIGHUP) + default: + err = fmt.Errorf("unknown signal %q", command) + } + return err +} + +// resolvePids returns the pids for all running gnatsd processes. +func resolvePids() ([]int, error) { + // If pgrep isn't available, this will just bail out and the user will be + // required to specify a pid. + output, err := pgrep() + if err != nil { + switch err.(type) { + case *exec.ExitError: + // ExitError indicates non-zero exit code, meaning no processes + // found. + break + default: + return nil, errors.New("unable to resolve pid, try providing one") + } + } + var ( + myPid = os.Getpid() + pidStrs = strings.Split(string(output), "\n") + pids = make([]int, 0, len(pidStrs)) + ) + for _, pidStr := range pidStrs { + if pidStr == "" { + continue + } + pid, err := strconv.Atoi(pidStr) + if err != nil { + return nil, errors.New("unable to resolve pid, try providing one") + } + // Ignore the current process. + if pid == myPid { + continue + } + pids = append(pids, pid) + } + return pids, nil +} + +var kill = func(pid int, signal syscall.Signal) error { + return syscall.Kill(pid, signal) +} + +var pgrep = func() ([]byte, error) { + return exec.Command("pgrep", processName).Output() +} diff --git a/vendor/github.com/nats-io/gnatsd/server/signal_test.go b/vendor/github.com/nats-io/gnatsd/server/signal_test.go new file mode 100644 index 00000000..3fa27559 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/signal_test.go @@ -0,0 +1,314 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !windows + +package server + +import ( + "errors" + "fmt" + "io/ioutil" + "os" + "os/exec" + "strings" + "syscall" + "testing" + "time" + + "github.com/nats-io/gnatsd/logger" +) + +func TestSignalToReOpenLogFile(t *testing.T) { + logFile := "test.log" + defer os.Remove(logFile) + defer os.Remove(logFile + ".bak") + opts := &Options{ + Host: "127.0.0.1", + Port: -1, + NoSigs: false, + LogFile: logFile, + } + s := RunServer(opts) + defer s.SetLogger(nil, false, false) + defer s.Shutdown() + + // Set the file log + fileLog := logger.NewFileLogger(s.opts.LogFile, s.opts.Logtime, s.opts.Debug, s.opts.Trace, true) + s.SetLogger(fileLog, false, false) + + // Add a trace + expectedStr := "This is a Notice" + s.Noticef(expectedStr) + buf, err := ioutil.ReadFile(logFile) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + if !strings.Contains(string(buf), expectedStr) { + t.Fatalf("Expected log to contain %q, got %q", expectedStr, string(buf)) + } + // Rename the file + if err := os.Rename(logFile, logFile+".bak"); err != nil { + t.Fatalf("Unable to rename file: %v", err) + } + // This should cause file to be reopened. + syscall.Kill(syscall.Getpid(), syscall.SIGUSR1) + // Wait a bit for action to be performed + time.Sleep(500 * time.Millisecond) + buf, err = ioutil.ReadFile(logFile) + if err != nil { + t.Fatalf("Error reading file: %v", err) + } + expectedStr = "File log re-opened" + if !strings.Contains(string(buf), expectedStr) { + t.Fatalf("Expected log to contain %q, got %q", expectedStr, string(buf)) + } +} + +func TestSignalToReloadConfig(t *testing.T) { + opts, err := ProcessConfigFile("./configs/reload/basic.conf") + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + opts.NoLog = true + s := RunServer(opts) + defer s.Shutdown() + + // Repeat test to make sure that server services signals more than once... + for i := 0; i < 2; i++ { + loaded := s.ConfigTime() + + // Wait a bit to ensure ConfigTime changes. + time.Sleep(5 * time.Millisecond) + + // This should cause config to be reloaded. + syscall.Kill(syscall.Getpid(), syscall.SIGHUP) + // Wait a bit for action to be performed + time.Sleep(500 * time.Millisecond) + + if reloaded := s.ConfigTime(); !reloaded.After(loaded) { + t.Fatalf("ConfigTime is incorrect.\nexpected greater than: %s\ngot: %s", loaded, reloaded) + } + } +} + +func TestProcessSignalNoProcesses(t *testing.T) { + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return nil, &exec.ExitError{} + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "no gnatsd processes running" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalMultipleProcesses(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n456\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "multiple gnatsd processes running:\n123\n456" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalPgrepError(t *testing.T) { + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return nil, errors.New("error") + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unable to resolve pid, try providing one" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalPgrepMangled(t *testing.T) { + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte("12x"), nil + } + defer func() { + pgrep = pgrepBefore + }() + + err := ProcessSignal(CommandStop, "") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unable to resolve pid, try providing one" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalResolveSingleProcess(t *testing.T) { + pid := os.Getpid() + pgrepBefore := pgrep + pgrep = func() ([]byte, error) { + return []byte(fmt.Sprintf("123\n%d\n", pid)), nil + } + defer func() { + pgrep = pgrepBefore + }() + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGKILL { + t.Fatalf("signal is incorrect.\nexpected: killed\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandStop, ""); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalInvalidCommand(t *testing.T) { + err := ProcessSignal(Command("invalid"), "123") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "unknown signal \"invalid\"" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalInvalidPid(t *testing.T) { + err := ProcessSignal(CommandStop, "abc") + if err == nil { + t.Fatal("Expected error") + } + expectedStr := "invalid pid: abc" + if err.Error() != expectedStr { + t.Fatalf("Error is incorrect.\nexpected: %s\ngot: %s", expectedStr, err.Error()) + } +} + +func TestProcessSignalQuitProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGINT { + t.Fatalf("signal is incorrect.\nexpected: interrupt\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandQuit, "123"); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalReopenProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGUSR1 { + t.Fatalf("signal is incorrect.\nexpected: user defined signal 1\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandReopen, "123"); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} + +func TestProcessSignalReloadProcess(t *testing.T) { + killBefore := kill + called := false + kill = func(pid int, signal syscall.Signal) error { + called = true + if pid != 123 { + t.Fatalf("pid is incorrect.\nexpected: 123\ngot: %d", pid) + } + if signal != syscall.SIGHUP { + t.Fatalf("signal is incorrect.\nexpected: hangup\ngot: %v", signal) + } + return nil + } + defer func() { + kill = killBefore + }() + + if err := ProcessSignal(CommandReload, "123"); err != nil { + t.Fatalf("ProcessSignal failed: %v", err) + } + + if !called { + t.Fatal("Expected kill to be called") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/signal_windows.go b/vendor/github.com/nats-io/gnatsd/server/signal_windows.go new file mode 100644 index 00000000..368077dd --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/signal_windows.go @@ -0,0 +1,101 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "os" + "os/signal" + "time" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +// Signal Handling +func (s *Server) handleSignals() { + if s.getOpts().NoSigs { + return + } + c := make(chan os.Signal, 1) + + signal.Notify(c, os.Interrupt) + + go func() { + for sig := range c { + s.Debugf("Trapped %q signal", sig) + s.Noticef("Server Exiting..") + os.Exit(0) + } + }() +} + +// ProcessSignal sends the given signal command to the running gnatsd service. +// If service is empty, this signals the "gnatsd" service. This returns an +// error is the given service is not running or the command is invalid. +func ProcessSignal(command Command, service string) error { + if service == "" { + service = serviceName + } + + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + + s, err := m.OpenService(service) + if err != nil { + return fmt.Errorf("could not access service: %v", err) + } + defer s.Close() + + var ( + cmd svc.Cmd + to svc.State + ) + + switch command { + case CommandStop, CommandQuit: + cmd = svc.Stop + to = svc.Stopped + case CommandReopen: + cmd = reopenLogCmd + to = svc.Running + case CommandReload: + cmd = svc.ParamChange + to = svc.Running + default: + return fmt.Errorf("unknown signal %q", command) + } + + status, err := s.Control(cmd) + if err != nil { + return fmt.Errorf("could not send control=%d: %v", cmd, err) + } + + timeout := time.Now().Add(10 * time.Second) + for status.State != to { + if timeout.Before(time.Now()) { + return fmt.Errorf("timeout waiting for service to go to state=%d", to) + } + time.Sleep(300 * time.Millisecond) + status, err = s.Query() + if err != nil { + return fmt.Errorf("could not retrieve service status: %v", err) + } + } + + return nil +} diff --git a/vendor/github.com/nats-io/gnatsd/server/split_test.go b/vendor/github.com/nats-io/gnatsd/server/split_test.go new file mode 100644 index 00000000..77dd2efb --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/split_test.go @@ -0,0 +1,517 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "bytes" + "net" + "testing" +) + +func TestSplitBufferSubOp(t *testing.T) { + cli, trash := net.Pipe() + defer cli.Close() + defer trash.Close() + + s := &Server{sl: NewSublist()} + c := &client{srv: s, subs: make(map[string]*subscription), nc: cli} + + subop := []byte("SUB foo 1\r\n") + subop1 := subop[:6] + subop2 := subop[6:] + + if err := c.parse(subop1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != SUB_ARG { + t.Fatalf("Expected SUB_ARG state vs %d\n", c.state) + } + if err := c.parse(subop2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected OP_START state vs %d\n", c.state) + } + r := s.sl.Match("foo") + if r == nil || len(r.psubs) != 1 { + t.Fatalf("Did not match subscription properly: %+v\n", r) + } + sub := r.psubs[0] + if !bytes.Equal(sub.subject, []byte("foo")) { + t.Fatalf("Subject did not match expected 'foo' : '%s'\n", sub.subject) + } + if !bytes.Equal(sub.sid, []byte("1")) { + t.Fatalf("Sid did not match expected '1' : '%s'\n", sub.sid) + } + if sub.queue != nil { + t.Fatalf("Received a non-nil queue: '%s'\n", sub.queue) + } +} + +func TestSplitBufferUnsubOp(t *testing.T) { + s := &Server{sl: NewSublist()} + c := &client{srv: s, subs: make(map[string]*subscription)} + + subop := []byte("SUB foo 1024\r\n") + if err := c.parse(subop); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected OP_START state vs %d\n", c.state) + } + + unsubop := []byte("UNSUB 1024\r\n") + unsubop1 := unsubop[:8] + unsubop2 := unsubop[8:] + + if err := c.parse(unsubop1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != UNSUB_ARG { + t.Fatalf("Expected UNSUB_ARG state vs %d\n", c.state) + } + if err := c.parse(unsubop2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected OP_START state vs %d\n", c.state) + } + r := s.sl.Match("foo") + if r != nil && len(r.psubs) != 0 { + t.Fatalf("Should be no subscriptions in results: %+v\n", r) + } +} + +func TestSplitBufferPubOp(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + pub := []byte("PUB foo.bar INBOX.22 11\r\nhello world\r") + pub1 := pub[:2] + pub2 := pub[2:9] + pub3 := pub[9:15] + pub4 := pub[15:22] + pub5 := pub[22:25] + pub6 := pub[25:33] + pub7 := pub[33:] + + if err := c.parse(pub1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_PU { + t.Fatalf("Expected OP_PU state vs %d\n", c.state) + } + if err := c.parse(pub2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != PUB_ARG { + t.Fatalf("Expected OP_PU state vs %d\n", c.state) + } + if err := c.parse(pub3); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != PUB_ARG { + t.Fatalf("Expected OP_PU state vs %d\n", c.state) + } + if err := c.parse(pub4); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != PUB_ARG { + t.Fatalf("Expected PUB_ARG state vs %d\n", c.state) + } + if err := c.parse(pub5); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_PAYLOAD { + t.Fatalf("Expected MSG_PAYLOAD state vs %d\n", c.state) + } + + // Check c.pa + if !bytes.Equal(c.pa.subject, []byte("foo.bar")) { + t.Fatalf("PUB arg subject incorrect: '%s'\n", c.pa.subject) + } + if !bytes.Equal(c.pa.reply, []byte("INBOX.22")) { + t.Fatalf("PUB arg reply subject incorrect: '%s'\n", c.pa.reply) + } + if c.pa.size != 11 { + t.Fatalf("PUB arg msg size incorrect: %d\n", c.pa.size) + } + if err := c.parse(pub6); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_PAYLOAD { + t.Fatalf("Expected MSG_PAYLOAD state vs %d\n", c.state) + } + if err := c.parse(pub7); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_END { + t.Fatalf("Expected MSG_END state vs %d\n", c.state) + } +} + +func TestSplitBufferPubOp2(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + pub := []byte("PUB foo.bar INBOX.22 11\r\nhello world\r\n") + pub1 := pub[:30] + pub2 := pub[30:] + + if err := c.parse(pub1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_PAYLOAD { + t.Fatalf("Expected MSG_PAYLOAD state vs %d\n", c.state) + } + if err := c.parse(pub2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_START { + t.Fatalf("Expected OP_START state vs %d\n", c.state) + } +} + +func TestSplitBufferPubOp3(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + pubAll := []byte("PUB foo bar 11\r\nhello world\r\n") + pub := pubAll[:16] + + if err := c.parse(pub); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Unexpected subject: '%s' vs '%s'\n", c.pa.subject, "foo") + } + + // Simulate next read of network, make sure pub state is saved + // until msg payload has cleared. + copy(pubAll, "XXXXXXXXXXXXXXXX") + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Unexpected subject: '%s' vs '%s'\n", c.pa.subject, "foo") + } + if !bytes.Equal(c.pa.reply, []byte("bar")) { + t.Fatalf("Unexpected reply: '%s' vs '%s'\n", c.pa.reply, "bar") + } + if !bytes.Equal(c.pa.szb, []byte("11")) { + t.Fatalf("Unexpected size bytes: '%s' vs '%s'\n", c.pa.szb, "11") + } +} + +func TestSplitBufferPubOp4(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + pubAll := []byte("PUB foo 11\r\nhello world\r\n") + pub := pubAll[:12] + + if err := c.parse(pub); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Unexpected subject: '%s' vs '%s'\n", c.pa.subject, "foo") + } + + // Simulate next read of network, make sure pub state is saved + // until msg payload has cleared. + copy(pubAll, "XXXXXXXXXXXX") + if !bytes.Equal(c.pa.subject, []byte("foo")) { + t.Fatalf("Unexpected subject: '%s' vs '%s'\n", c.pa.subject, "foo") + } + if !bytes.Equal(c.pa.reply, []byte("")) { + t.Fatalf("Unexpected reply: '%s' vs '%s'\n", c.pa.reply, "") + } + if !bytes.Equal(c.pa.szb, []byte("11")) { + t.Fatalf("Unexpected size bytes: '%s' vs '%s'\n", c.pa.szb, "11") + } +} + +func TestSplitBufferPubOp5(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + pubAll := []byte("PUB foo 11\r\nhello world\r\n") + + // Splits need to be on MSG_END now too, so make sure we check that. + // Split between \r and \n + pub := pubAll[:len(pubAll)-1] + + if err := c.parse(pub); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.msgBuf == nil { + t.Fatalf("msgBuf should not be nil!\n") + } + if !bytes.Equal(c.msgBuf, []byte("hello world\r")) { + t.Fatalf("c.msgBuf did not snaphot the msg") + } +} + +func TestSplitConnectArg(t *testing.T) { + c := &client{subs: make(map[string]*subscription)} + connectAll := []byte("CONNECT {\"verbose\":false,\"tls_required\":false," + + "\"user\":\"test\",\"pedantic\":true,\"pass\":\"pass\"}\r\n") + + argJSON := connectAll[8:] + + c1 := connectAll[:5] + c2 := connectAll[5:22] + c3 := connectAll[22 : len(connectAll)-2] + c4 := connectAll[len(connectAll)-2:] + + if err := c.parse(c1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf != nil { + t.Fatalf("Unexpected argBug placeholder.\n") + } + + if err := c.parse(c2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf == nil { + t.Fatalf("Expected argBug to not be nil.\n") + } + if !bytes.Equal(c.argBuf, argJSON[:14]) { + t.Fatalf("argBuf not correct, received %q, wanted %q\n", argJSON[:14], c.argBuf) + } + + if err := c.parse(c3); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf == nil { + t.Fatalf("Expected argBug to not be nil.\n") + } + if !bytes.Equal(c.argBuf, argJSON[:len(argJSON)-2]) { + t.Fatalf("argBuf not correct, received %q, wanted %q\n", + argJSON[:len(argJSON)-2], c.argBuf) + } + + if err := c.parse(c4); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf != nil { + t.Fatalf("Unexpected argBuf placeholder.\n") + } +} + +func TestSplitDanglingArgBuf(t *testing.T) { + s := New(&defaultServerOptions) + c := &client{srv: s, subs: make(map[string]*subscription)} + + // We test to make sure we do not dangle any argBufs after processing + // since that could lead to performance issues. + + // SUB + subop := []byte("SUB foo 1\r\n") + c.parse(subop[:6]) + c.parse(subop[6:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // UNSUB + unsubop := []byte("UNSUB 1024\r\n") + c.parse(unsubop[:8]) + c.parse(unsubop[8:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // PUB + pubop := []byte("PUB foo.bar INBOX.22 11\r\nhello world\r\n") + c.parse(pubop[:22]) + c.parse(pubop[22:25]) + if c.argBuf == nil { + t.Fatal("Expected a non-nil argBuf!") + } + c.parse(pubop[25:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // MINUS_ERR + errop := []byte("-ERR Too Long\r\n") + c.parse(errop[:8]) + c.parse(errop[8:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // CONNECT_ARG + connop := []byte("CONNECT {\"verbose\":false,\"tls_required\":false," + + "\"user\":\"test\",\"pedantic\":true,\"pass\":\"pass\"}\r\n") + c.parse(connop[:22]) + c.parse(connop[22:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // INFO_ARG + infoop := []byte("INFO {\"server_id\":\"id\"}\r\n") + c.parse(infoop[:8]) + c.parse(infoop[8:]) + if c.argBuf != nil { + t.Fatalf("Expected c.argBuf to be nil: %q\n", c.argBuf) + } + + // MSG (the client has to be a ROUTE) + c = &client{subs: make(map[string]*subscription), typ: ROUTER} + msgop := []byte("MSG foo RSID:2:1 5\r\nhello\r\n") + c.parse(msgop[:5]) + c.parse(msgop[5:10]) + if c.argBuf == nil { + t.Fatal("Expected a non-nil argBuf") + } + if string(c.argBuf) != "foo RS" { + t.Fatalf("Expected argBuf to be \"foo 1 \", got %q", string(c.argBuf)) + } + c.parse(msgop[10:]) + if c.argBuf != nil { + t.Fatalf("Expected argBuf to be nil: %q", c.argBuf) + } + if c.msgBuf != nil { + t.Fatalf("Expected msgBuf to be nil: %q", c.msgBuf) + } + + c.state = OP_START + // Parse up-to somewhere in the middle of the payload. + // Verify that we have saved the MSG_ARG info + c.parse(msgop[:23]) + if c.argBuf == nil { + t.Fatal("Expected a non-nil argBuf") + } + if string(c.pa.subject) != "foo" { + t.Fatalf("Expected subject to be \"foo\", got %q", c.pa.subject) + } + if string(c.pa.reply) != "" { + t.Fatalf("Expected reply to be \"\", got %q", c.pa.reply) + } + if string(c.pa.sid) != "RSID:2:1" { + t.Fatalf("Expected sid to \"RSID:2:1\", got %q", c.pa.sid) + } + if c.pa.size != 5 { + t.Fatalf("Expected sid to 5, got %v", c.pa.size) + } + // msg buffer should be + if c.msgBuf == nil || string(c.msgBuf) != "hel" { + t.Fatalf("Expected msgBuf to be \"hel\", got %q", c.msgBuf) + } + c.parse(msgop[23:]) + // At the end, we should have cleaned-up both arg and msg buffers. + if c.argBuf != nil { + t.Fatalf("Expected argBuf to be nil: %q", c.argBuf) + } + if c.msgBuf != nil { + t.Fatalf("Expected msgBuf to be nil: %q", c.msgBuf) + } +} + +func TestSplitMsgArg(t *testing.T) { + _, c, _ := setupClient() + // Allow parser to process MSG + c.typ = ROUTER + + b := make([]byte, 1024) + + copy(b, []byte("MSG hello.world RSID:14:8 6040\r\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")) + c.parse(b) + + copy(b, []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB\r\n")) + c.parse(b) + + wantSubject := "hello.world" + wantSid := "RSID:14:8" + wantSzb := "6040" + + if string(c.pa.subject) != wantSubject { + t.Fatalf("Incorrect subject: want %q, got %q", wantSubject, c.pa.subject) + } + + if string(c.pa.sid) != wantSid { + t.Fatalf("Incorrect sid: want %q, got %q", wantSid, c.pa.sid) + } + + if string(c.pa.szb) != wantSzb { + t.Fatalf("Incorrect szb: want %q, got %q", wantSzb, c.pa.szb) + } +} + +func TestSplitBufferMsgOp(t *testing.T) { + c := &client{subs: make(map[string]*subscription), typ: ROUTER} + msg := []byte("MSG foo.bar QRSID:15:3 _INBOX.22 11\r\nhello world\r") + msg1 := msg[:2] + msg2 := msg[2:9] + msg3 := msg[9:15] + msg4 := msg[15:22] + msg5 := msg[22:25] + msg6 := msg[25:37] + msg7 := msg[37:42] + msg8 := msg[42:] + + if err := c.parse(msg1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != OP_MS { + t.Fatalf("Expected OP_MS state vs %d\n", c.state) + } + if err := c.parse(msg2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_ARG { + t.Fatalf("Expected MSG_ARG state vs %d\n", c.state) + } + if err := c.parse(msg3); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_ARG { + t.Fatalf("Expected MSG_ARG state vs %d\n", c.state) + } + if err := c.parse(msg4); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_ARG { + t.Fatalf("Expected MSG_ARG state vs %d\n", c.state) + } + if err := c.parse(msg5); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_ARG { + t.Fatalf("Expected MSG_ARG state vs %d\n", c.state) + } + if err := c.parse(msg6); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_PAYLOAD { + t.Fatalf("Expected MSG_PAYLOAD state vs %d\n", c.state) + } + + // Check c.pa + if !bytes.Equal(c.pa.subject, []byte("foo.bar")) { + t.Fatalf("MSG arg subject incorrect: '%s'\n", c.pa.subject) + } + if !bytes.Equal(c.pa.sid, []byte("QRSID:15:3")) { + t.Fatalf("MSG arg sid incorrect: '%s'\n", c.pa.sid) + } + if !bytes.Equal(c.pa.reply, []byte("_INBOX.22")) { + t.Fatalf("MSG arg reply subject incorrect: '%s'\n", c.pa.reply) + } + if c.pa.size != 11 { + t.Fatalf("MSG arg msg size incorrect: %d\n", c.pa.size) + } + if err := c.parse(msg7); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_PAYLOAD { + t.Fatalf("Expected MSG_PAYLOAD state vs %d\n", c.state) + } + if err := c.parse(msg8); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.state != MSG_END { + t.Fatalf("Expected MSG_END state vs %d\n", c.state) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/server/sublist.go b/vendor/github.com/nats-io/gnatsd/server/sublist.go new file mode 100644 index 00000000..e699f2af --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/sublist.go @@ -0,0 +1,775 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sublist is a routing mechanism to handle subject distribution +// and provides a facility to match subjects from published messages to +// interested subscribers. Subscribers can have wildcard subjects to match +// multiple published subjects. +package server + +import ( + "bytes" + "errors" + "strings" + "sync" + "sync/atomic" +) + +// Common byte variables for wildcards and token separator. +const ( + pwc = '*' + fwc = '>' + tsep = "." + btsep = '.' +) + +// Sublist related errors +var ( + ErrInvalidSubject = errors.New("sublist: Invalid Subject") + ErrNotFound = errors.New("sublist: No Matches Found") +) + +const ( + // cacheMax is used to bound limit the frontend cache + slCacheMax = 1024 + // plistMin is our lower bounds to create a fast plist for Match. + plistMin = 256 +) + +// A result structure better optimized for queue subs. +type SublistResult struct { + psubs []*subscription + qsubs [][]*subscription // don't make this a map, too expensive to iterate +} + +// A Sublist stores and efficiently retrieves subscriptions. +type Sublist struct { + sync.RWMutex + genid uint64 + matches uint64 + cacheHits uint64 + inserts uint64 + removes uint64 + cache map[string]*SublistResult + root *level + count uint32 +} + +// A node contains subscriptions and a pointer to the next level. +type node struct { + next *level + psubs map[*subscription]*subscription + qsubs map[string](map[*subscription]*subscription) + plist []*subscription +} + +// A level represents a group of nodes and special pointers to +// wildcard nodes. +type level struct { + nodes map[string]*node + pwc, fwc *node +} + +// Create a new default node. +func newNode() *node { + return &node{psubs: make(map[*subscription]*subscription)} +} + +// Create a new default level. +func newLevel() *level { + return &level{nodes: make(map[string]*node)} +} + +// New will create a default sublist +func NewSublist() *Sublist { + return &Sublist{root: newLevel(), cache: make(map[string]*SublistResult)} +} + +// Insert adds a subscription into the sublist +func (s *Sublist) Insert(sub *subscription) error { + // copy the subject since we hold this and this might be part of a large byte slice. + subject := string(sub.subject) + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + s.Lock() + + sfwc := false + l := s.root + var n *node + + for _, t := range tokens { + lt := len(t) + if lt == 0 || sfwc { + s.Unlock() + return ErrInvalidSubject + } + + if lt > 1 { + n = l.nodes[t] + } else { + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + } + if n == nil { + n = newNode() + if lt > 1 { + l.nodes[t] = n + } else { + switch t[0] { + case pwc: + l.pwc = n + case fwc: + l.fwc = n + default: + l.nodes[t] = n + } + } + } + if n.next == nil { + n.next = newLevel() + } + l = n.next + } + if sub.queue == nil { + n.psubs[sub] = sub + if n.plist != nil { + n.plist = append(n.plist, sub) + } else if len(n.psubs) > plistMin { + n.plist = make([]*subscription, 0, len(n.psubs)) + // Populate + for _, psub := range n.psubs { + n.plist = append(n.plist, psub) + } + } + } else { + if n.qsubs == nil { + n.qsubs = make(map[string]map[*subscription]*subscription) + } + qname := string(sub.queue) + // This is a queue subscription + subs, ok := n.qsubs[qname] + if !ok { + subs = make(map[*subscription]*subscription) + n.qsubs[qname] = subs + } + subs[sub] = sub + } + + s.count++ + s.inserts++ + + s.addToCache(subject, sub) + atomic.AddUint64(&s.genid, 1) + + s.Unlock() + return nil +} + +// Deep copy +func copyResult(r *SublistResult) *SublistResult { + nr := &SublistResult{} + nr.psubs = append([]*subscription(nil), r.psubs...) + for _, qr := range r.qsubs { + nqr := append([]*subscription(nil), qr...) + nr.qsubs = append(nr.qsubs, nqr) + } + return nr +} + +// addToCache will add the new entry to existing cache +// entries if needed. Assumes write lock is held. +func (s *Sublist) addToCache(subject string, sub *subscription) { + for k, r := range s.cache { + if matchLiteral(k, subject) { + // Copy since others may have a reference. + nr := copyResult(r) + if sub.queue == nil { + nr.psubs = append(nr.psubs, sub) + } else { + if i := findQSliceForSub(sub, nr.qsubs); i >= 0 { + nr.qsubs[i] = append(nr.qsubs[i], sub) + } else { + nr.qsubs = append(nr.qsubs, []*subscription{sub}) + } + } + s.cache[k] = nr + } + } +} + +// removeFromCache will remove the sub from any active cache entries. +// Assumes write lock is held. +func (s *Sublist) removeFromCache(subject string, sub *subscription) { + for k := range s.cache { + if !matchLiteral(k, subject) { + continue + } + // Since someone else may be referecing, can't modify the list + // safely, just let it re-populate. + delete(s.cache, k) + } +} + +// Match will match all entries to the literal subject. +// It will return a set of results for both normal and queue subscribers. +func (s *Sublist) Match(subject string) *SublistResult { + s.RLock() + atomic.AddUint64(&s.matches, 1) + rc, ok := s.cache[subject] + s.RUnlock() + if ok { + atomic.AddUint64(&s.cacheHits, 1) + return rc + } + + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + // FIXME(dlc) - Make shared pool between sublist and client readLoop? + result := &SublistResult{} + + s.Lock() + matchLevel(s.root, tokens, result) + + // Add to our cache + s.cache[subject] = result + // Bound the number of entries to sublistMaxCache + if len(s.cache) > slCacheMax { + for k := range s.cache { + delete(s.cache, k) + break + } + } + s.Unlock() + + return result +} + +// This will add in a node's results to the total results. +func addNodeToResults(n *node, results *SublistResult) { + // Normal subscriptions + if n.plist != nil { + results.psubs = append(results.psubs, n.plist...) + } else { + for _, psub := range n.psubs { + results.psubs = append(results.psubs, psub) + } + } + // Queue subscriptions + for qname, qr := range n.qsubs { + if len(qr) == 0 { + continue + } + tsub := &subscription{subject: nil, queue: []byte(qname)} + // Need to find matching list in results + if i := findQSliceForSub(tsub, results.qsubs); i >= 0 { + for _, sub := range qr { + results.qsubs[i] = append(results.qsubs[i], sub) + } + } else { + var nqsub []*subscription + for _, sub := range qr { + nqsub = append(nqsub, sub) + } + results.qsubs = append(results.qsubs, nqsub) + } + } +} + +// We do not use a map here since we want iteration to be past when +// processing publishes in L1 on client. So we need to walk sequentially +// for now. Keep an eye on this in case we start getting large number of +// different queue subscribers for the same subject. +func findQSliceForSub(sub *subscription, qsl [][]*subscription) int { + if sub.queue == nil { + return -1 + } + for i, qr := range qsl { + if len(qr) > 0 && bytes.Equal(sub.queue, qr[0].queue) { + return i + } + } + return -1 +} + +// matchLevel is used to recursively descend into the trie. +func matchLevel(l *level, toks []string, results *SublistResult) { + var pwc, n *node + for i, t := range toks { + if l == nil { + return + } + if l.fwc != nil { + addNodeToResults(l.fwc, results) + } + if pwc = l.pwc; pwc != nil { + matchLevel(pwc.next, toks[i+1:], results) + } + n = l.nodes[t] + if n != nil { + l = n.next + } else { + l = nil + } + } + if n != nil { + addNodeToResults(n, results) + } + if pwc != nil { + addNodeToResults(pwc, results) + } +} + +// lnt is used to track descent into levels for a removal for pruning. +type lnt struct { + l *level + n *node + t string +} + +// Raw low level remove, can do batches with lock held outside. +func (s *Sublist) remove(sub *subscription, shouldLock bool) error { + subject := string(sub.subject) + tsa := [32]string{} + tokens := tsa[:0] + start := 0 + for i := 0; i < len(subject); i++ { + if subject[i] == btsep { + tokens = append(tokens, subject[start:i]) + start = i + 1 + } + } + tokens = append(tokens, subject[start:]) + + if shouldLock { + s.Lock() + defer s.Unlock() + } + + sfwc := false + l := s.root + var n *node + + // Track levels for pruning + var lnts [32]lnt + levels := lnts[:0] + + for _, t := range tokens { + lt := len(t) + if lt == 0 || sfwc { + return ErrInvalidSubject + } + if l == nil { + return ErrNotFound + } + if lt > 1 { + n = l.nodes[t] + } else { + switch t[0] { + case pwc: + n = l.pwc + case fwc: + n = l.fwc + sfwc = true + default: + n = l.nodes[t] + } + } + if n != nil { + levels = append(levels, lnt{l, n, t}) + l = n.next + } else { + l = nil + } + } + if !s.removeFromNode(n, sub) { + return ErrNotFound + } + + s.count-- + s.removes++ + + for i := len(levels) - 1; i >= 0; i-- { + l, n, t := levels[i].l, levels[i].n, levels[i].t + if n.isEmpty() { + l.pruneNode(n, t) + } + } + s.removeFromCache(subject, sub) + atomic.AddUint64(&s.genid, 1) + + return nil +} + +// Remove will remove a subscription. +func (s *Sublist) Remove(sub *subscription) error { + return s.remove(sub, true) +} + +// RemoveBatch will remove a list of subscriptions. +func (s *Sublist) RemoveBatch(subs []*subscription) error { + s.Lock() + defer s.Unlock() + + for _, sub := range subs { + if err := s.remove(sub, false); err != nil { + return err + } + } + return nil +} + +// pruneNode is used to prune an empty node from the tree. +func (l *level) pruneNode(n *node, t string) { + if n == nil { + return + } + if n == l.fwc { + l.fwc = nil + } else if n == l.pwc { + l.pwc = nil + } else { + delete(l.nodes, t) + } +} + +// isEmpty will test if the node has any entries. Used +// in pruning. +func (n *node) isEmpty() bool { + if len(n.psubs) == 0 && len(n.qsubs) == 0 { + if n.next == nil || n.next.numNodes() == 0 { + return true + } + } + return false +} + +// Return the number of nodes for the given level. +func (l *level) numNodes() int { + num := len(l.nodes) + if l.pwc != nil { + num++ + } + if l.fwc != nil { + num++ + } + return num +} + +// Remove the sub for the given node. +func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) { + if n == nil { + return false + } + if sub.queue == nil { + _, found = n.psubs[sub] + delete(n.psubs, sub) + if found && n.plist != nil { + // This will brute force remove the plist to perform + // correct behavior. Will get repopulated on a call + //to Match as needed. + n.plist = nil + } + return found + } + + // We have a queue group subscription here + qname := string(sub.queue) + qsub := n.qsubs[qname] + _, found = qsub[sub] + delete(qsub, sub) + if len(qsub) == 0 { + delete(n.qsubs, qname) + } + return found +} + +// Count returns the number of subscriptions. +func (s *Sublist) Count() uint32 { + s.RLock() + defer s.RUnlock() + return s.count +} + +// CacheCount returns the number of result sets in the cache. +func (s *Sublist) CacheCount() int { + s.RLock() + defer s.RUnlock() + return len(s.cache) +} + +// Public stats for the sublist +type SublistStats struct { + NumSubs uint32 `json:"num_subscriptions"` + NumCache uint32 `json:"num_cache"` + NumInserts uint64 `json:"num_inserts"` + NumRemoves uint64 `json:"num_removes"` + NumMatches uint64 `json:"num_matches"` + CacheHitRate float64 `json:"cache_hit_rate"` + MaxFanout uint32 `json:"max_fanout"` + AvgFanout float64 `json:"avg_fanout"` +} + +// Stats will return a stats structure for the current state. +func (s *Sublist) Stats() *SublistStats { + s.Lock() + defer s.Unlock() + + st := &SublistStats{} + st.NumSubs = s.count + st.NumCache = uint32(len(s.cache)) + st.NumInserts = s.inserts + st.NumRemoves = s.removes + st.NumMatches = atomic.LoadUint64(&s.matches) + if st.NumMatches > 0 { + st.CacheHitRate = float64(atomic.LoadUint64(&s.cacheHits)) / float64(st.NumMatches) + } + // whip through cache for fanout stats + tot, max := 0, 0 + for _, r := range s.cache { + l := len(r.psubs) + len(r.qsubs) + tot += l + if l > max { + max = l + } + } + st.MaxFanout = uint32(max) + if tot > 0 { + st.AvgFanout = float64(tot) / float64(len(s.cache)) + } + return st +} + +// numLevels will return the maximum number of levels +// contained in the Sublist tree. +func (s *Sublist) numLevels() int { + return visitLevel(s.root, 0) +} + +// visitLevel is used to descend the Sublist tree structure +// recursively. +func visitLevel(l *level, depth int) int { + if l == nil || l.numNodes() == 0 { + return depth + } + + depth++ + maxDepth := depth + + for _, n := range l.nodes { + if n == nil { + continue + } + newDepth := visitLevel(n.next, depth) + if newDepth > maxDepth { + maxDepth = newDepth + } + } + if l.pwc != nil { + pwcDepth := visitLevel(l.pwc.next, depth) + if pwcDepth > maxDepth { + maxDepth = pwcDepth + } + } + if l.fwc != nil { + fwcDepth := visitLevel(l.fwc.next, depth) + if fwcDepth > maxDepth { + maxDepth = fwcDepth + } + } + return maxDepth +} + +// IsValidSubject returns true if a subject is valid, false otherwise +func IsValidSubject(subject string) bool { + if subject == "" { + return false + } + sfwc := false + tokens := strings.Split(subject, tsep) + for _, t := range tokens { + if len(t) == 0 || sfwc { + return false + } + if len(t) > 1 { + continue + } + switch t[0] { + case fwc: + sfwc = true + } + } + return true +} + +// IsValidLiteralSubject returns true if a subject is valid and literal (no wildcards), false otherwise +func IsValidLiteralSubject(subject string) bool { + tokens := strings.Split(subject, tsep) + for _, t := range tokens { + if len(t) == 0 { + return false + } + if len(t) > 1 { + continue + } + switch t[0] { + case pwc, fwc: + return false + } + } + return true +} + +// matchLiteral is used to test literal subjects, those that do not have any +// wildcards, with a target subject. This is used in the cache layer. +func matchLiteral(literal, subject string) bool { + li := 0 + ll := len(literal) + ls := len(subject) + for i := 0; i < ls; i++ { + if li >= ll { + return false + } + // This function has been optimized for speed. + // For instance, do not set b:=subject[i] here since + // we may bump `i` in this loop to avoid `continue` or + // skiping common test in a particular test. + // Run Benchmark_SublistMatchLiteral before making any change. + switch subject[i] { + case pwc: + // NOTE: This is not testing validity of a subject, instead ensures + // that wildcards are treated as such if they follow some basic rules, + // namely that they are a token on their own. + if i == 0 || subject[i-1] == btsep { + if i == ls-1 { + // There is no more token in the subject after this wildcard. + // Skip token in literal and expect to not find a separator. + for { + // End of literal, this is a match. + if li >= ll { + return true + } + // Presence of separator, this can't be a match. + if literal[li] == btsep { + return false + } + li++ + } + } else if subject[i+1] == btsep { + // There is another token in the subject after this wildcard. + // Skip token in literal and expect to get a separator. + for { + // We found the end of the literal before finding a separator, + // this can't be a match. + if li >= ll { + return false + } + if literal[li] == btsep { + break + } + li++ + } + // Bump `i` since we know there is a `.` following, we are + // safe. The common test below is going to check `.` with `.` + // which is good. A `continue` here is too costly. + i++ + } + } + case fwc: + // For `>` to be a wildcard, it means being the only or last character + // in the string preceded by a `.` + if (i == 0 || subject[i-1] == btsep) && i == ls-1 { + return true + } + } + if subject[i] != literal[li] { + return false + } + li++ + } + // Make sure we have processed all of the literal's chars.. + return li >= ll +} + +func addLocalSub(sub *subscription, subs *[]*subscription) { + if sub != nil && sub.client != nil && sub.client.typ == CLIENT { + *subs = append(*subs, sub) + } +} + +func (s *Sublist) addNodeToSubs(n *node, subs *[]*subscription) { + // Normal subscriptions + if n.plist != nil { + for _, sub := range n.plist { + addLocalSub(sub, subs) + } + } else { + for _, sub := range n.psubs { + addLocalSub(sub, subs) + } + } + // Queue subscriptions + for _, qr := range n.qsubs { + for _, sub := range qr { + addLocalSub(sub, subs) + } + } +} + +func (s *Sublist) collectLocalSubs(l *level, subs *[]*subscription) { + if len(l.nodes) > 0 { + for _, n := range l.nodes { + s.addNodeToSubs(n, subs) + s.collectLocalSubs(n.next, subs) + } + } + if l.pwc != nil { + s.addNodeToSubs(l.pwc, subs) + s.collectLocalSubs(l.pwc.next, subs) + } + if l.fwc != nil { + s.addNodeToSubs(l.fwc, subs) + s.collectLocalSubs(l.fwc.next, subs) + } +} + +// Return all local client subscriptions. Use the supplied slice. +func (s *Sublist) localSubs(subs *[]*subscription) { + s.RLock() + s.collectLocalSubs(s.root, subs) + s.RUnlock() +} diff --git a/vendor/github.com/nats-io/gnatsd/server/sublist_test.go b/vendor/github.com/nats-io/gnatsd/server/sublist_test.go new file mode 100644 index 00000000..a534dd4a --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/sublist_test.go @@ -0,0 +1,1019 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + "runtime" + "strings" + "sync" + "testing" + + dbg "runtime/debug" +) + +func stackFatalf(t *testing.T, f string, args ...interface{}) { + lines := make([]string, 0, 32) + msg := fmt.Sprintf(f, args...) + lines = append(lines, msg) + + // Generate the Stack of callers: Skip us and verify* frames. + for i := 2; true; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + msg := fmt.Sprintf("%d - %s:%d", i, file, line) + lines = append(lines, msg) + } + t.Fatalf("%s", strings.Join(lines, "\n")) +} + +func verifyCount(s *Sublist, count uint32, t *testing.T) { + if s.Count() != count { + stackFatalf(t, "Count is %d, should be %d", s.Count(), count) + } +} + +func verifyLen(r []*subscription, l int, t *testing.T) { + if len(r) != l { + stackFatalf(t, "Results len is %d, should be %d", len(r), l) + } +} + +func verifyQLen(r [][]*subscription, l int, t *testing.T) { + if len(r) != l { + stackFatalf(t, "Queue Results len is %d, should be %d", len(r), l) + } +} + +func verifyNumLevels(s *Sublist, expected int, t *testing.T) { + dl := s.numLevels() + if dl != expected { + stackFatalf(t, "NumLevels is %d, should be %d", dl, expected) + } +} + +func verifyQMember(qsubs [][]*subscription, val *subscription, t *testing.T) { + verifyMember(qsubs[findQSliceForSub(val, qsubs)], val, t) +} + +func verifyMember(r []*subscription, val *subscription, t *testing.T) { + for _, v := range r { + if v == nil { + continue + } + if v == val { + return + } + } + stackFatalf(t, "Subscription (%p) for [%s : %s] not found in results", val, val.subject, val.queue) +} + +// Helpers to generate test subscriptions. +func newSub(subject string) *subscription { + return &subscription{subject: []byte(subject)} +} + +func newQSub(subject, queue string) *subscription { + if queue != "" { + return &subscription{subject: []byte(subject), queue: []byte(queue)} + } + return newSub(subject) +} + +func TestSublistInit(t *testing.T) { + s := NewSublist() + verifyCount(s, 0, t) +} + +func TestSublistInsertCount(t *testing.T) { + s := NewSublist() + s.Insert(newSub("foo")) + s.Insert(newSub("bar")) + s.Insert(newSub("foo.bar")) + verifyCount(s, 3, t) +} + +func TestSublistSimple(t *testing.T) { + s := NewSublist() + subject := "foo" + sub := newSub(subject) + s.Insert(sub) + r := s.Match(subject) + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, sub, t) +} + +func TestSublistSimpleMultiTokens(t *testing.T) { + s := NewSublist() + subject := "foo.bar.baz" + sub := newSub(subject) + s.Insert(sub) + r := s.Match(subject) + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, sub, t) +} + +func TestSublistPartialWildcard(t *testing.T) { + s := NewSublist() + lsub := newSub("a.b.c") + psub := newSub("a.*.c") + s.Insert(lsub) + s.Insert(psub) + r := s.Match("a.b.c") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, lsub, t) + verifyMember(r.psubs, psub, t) +} + +func TestSublistPartialWildcardAtEnd(t *testing.T) { + s := NewSublist() + lsub := newSub("a.b.c") + psub := newSub("a.b.*") + s.Insert(lsub) + s.Insert(psub) + r := s.Match("a.b.c") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, lsub, t) + verifyMember(r.psubs, psub, t) +} + +func TestSublistFullWildcard(t *testing.T) { + s := NewSublist() + lsub := newSub("a.b.c") + fsub := newSub("a.>") + s.Insert(lsub) + s.Insert(fsub) + r := s.Match("a.b.c") + verifyLen(r.psubs, 2, t) + verifyMember(r.psubs, lsub, t) + verifyMember(r.psubs, fsub, t) +} + +func TestSublistRemove(t *testing.T) { + s := NewSublist() + subject := "a.b.c.d" + sub := newSub(subject) + s.Insert(sub) + verifyCount(s, 1, t) + r := s.Match(subject) + verifyLen(r.psubs, 1, t) + s.Remove(newSub("a.b.c")) + verifyCount(s, 1, t) + s.Remove(sub) + verifyCount(s, 0, t) + r = s.Match(subject) + verifyLen(r.psubs, 0, t) +} + +func TestSublistRemoveWildcard(t *testing.T) { + s := NewSublist() + subject := "a.b.c.d" + sub := newSub(subject) + psub := newSub("a.b.*.d") + fsub := newSub("a.b.>") + s.Insert(sub) + s.Insert(psub) + s.Insert(fsub) + verifyCount(s, 3, t) + r := s.Match(subject) + verifyLen(r.psubs, 3, t) + s.Remove(sub) + verifyCount(s, 2, t) + s.Remove(fsub) + verifyCount(s, 1, t) + s.Remove(psub) + verifyCount(s, 0, t) + r = s.Match(subject) + verifyLen(r.psubs, 0, t) +} + +func TestSublistRemoveCleanup(t *testing.T) { + s := NewSublist() + literal := "a.b.c.d.e.f" + depth := len(strings.Split(literal, tsep)) + sub := newSub(literal) + verifyNumLevels(s, 0, t) + s.Insert(sub) + verifyNumLevels(s, depth, t) + s.Remove(sub) + verifyNumLevels(s, 0, t) +} + +func TestSublistRemoveCleanupWildcards(t *testing.T) { + s := NewSublist() + subject := "a.b.*.d.e.>" + depth := len(strings.Split(subject, tsep)) + sub := newSub(subject) + verifyNumLevels(s, 0, t) + s.Insert(sub) + verifyNumLevels(s, depth, t) + s.Remove(sub) + verifyNumLevels(s, 0, t) +} + +func TestSublistRemoveWithLargeSubs(t *testing.T) { + subject := "foo" + s := NewSublist() + for i := 0; i < plistMin*2; i++ { + sub := newSub(subject) + s.Insert(sub) + } + r := s.Match(subject) + verifyLen(r.psubs, plistMin*2, t) + // Remove one that is in the middle + s.Remove(r.psubs[plistMin]) + // Remove first one + s.Remove(r.psubs[0]) + // Remove last one + s.Remove(r.psubs[len(r.psubs)-1]) + // Check len again + r = s.Match(subject) + verifyLen(r.psubs, plistMin*2-3, t) +} + +func TestSublistInvalidSubjectsInsert(t *testing.T) { + s := NewSublist() + + // Insert, or subscriptions, can have wildcards, but not empty tokens, + // and can not have a FWC that is not the terminal token. + + // beginning empty token + if err := s.Insert(newSub(".foo")); err != ErrInvalidSubject { + t.Fatal("Expected invalid subject error") + } + + // trailing empty token + if err := s.Insert(newSub("foo.")); err != ErrInvalidSubject { + t.Fatal("Expected invalid subject error") + } + // empty middle token + if err := s.Insert(newSub("foo..bar")); err != ErrInvalidSubject { + t.Fatal("Expected invalid subject error") + } + // empty middle token #2 + if err := s.Insert(newSub("foo.bar..baz")); err != ErrInvalidSubject { + t.Fatal("Expected invalid subject error") + } + // fwc not terminal + if err := s.Insert(newSub("foo.>.bar")); err != ErrInvalidSubject { + t.Fatal("Expected invalid subject error") + } +} + +func TestSublistCache(t *testing.T) { + s := NewSublist() + + // Test add a remove logistics + subject := "a.b.c.d" + sub := newSub(subject) + psub := newSub("a.b.*.d") + fsub := newSub("a.b.>") + s.Insert(sub) + r := s.Match(subject) + verifyLen(r.psubs, 1, t) + s.Insert(psub) + s.Insert(fsub) + verifyCount(s, 3, t) + r = s.Match(subject) + verifyLen(r.psubs, 3, t) + s.Remove(sub) + verifyCount(s, 2, t) + s.Remove(fsub) + verifyCount(s, 1, t) + s.Remove(psub) + verifyCount(s, 0, t) + + // Check that cache is now empty + if cc := s.CacheCount(); cc != 0 { + t.Fatalf("Cache should be zero, got %d\n", cc) + } + + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + + for i := 0; i < 2*slCacheMax; i++ { + s.Match(fmt.Sprintf("foo-%d\n", i)) + } + + if cc := s.CacheCount(); cc > slCacheMax { + t.Fatalf("Cache should be constrained by cacheMax, got %d for current count\n", cc) + } +} + +func TestSublistBasicQueueResults(t *testing.T) { + s := NewSublist() + + // Test some basics + subject := "foo" + sub := newSub(subject) + sub1 := newQSub(subject, "bar") + sub2 := newQSub(subject, "baz") + + s.Insert(sub1) + r := s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 1, t) + verifyLen(r.qsubs[0], 1, t) + verifyQMember(r.qsubs, sub1, t) + + s.Insert(sub2) + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 2, t) + verifyLen(r.qsubs[0], 1, t) + verifyLen(r.qsubs[1], 1, t) + verifyQMember(r.qsubs, sub1, t) + verifyQMember(r.qsubs, sub2, t) + + s.Insert(sub) + r = s.Match(subject) + verifyLen(r.psubs, 1, t) + verifyQLen(r.qsubs, 2, t) + verifyLen(r.qsubs[0], 1, t) + verifyLen(r.qsubs[1], 1, t) + verifyQMember(r.qsubs, sub1, t) + verifyQMember(r.qsubs, sub2, t) + verifyMember(r.psubs, sub, t) + + sub3 := newQSub(subject, "bar") + sub4 := newQSub(subject, "baz") + + s.Insert(sub3) + s.Insert(sub4) + + r = s.Match(subject) + verifyLen(r.psubs, 1, t) + verifyQLen(r.qsubs, 2, t) + verifyLen(r.qsubs[0], 2, t) + verifyLen(r.qsubs[1], 2, t) + verifyQMember(r.qsubs, sub1, t) + verifyQMember(r.qsubs, sub2, t) + verifyQMember(r.qsubs, sub3, t) + verifyQMember(r.qsubs, sub4, t) + verifyMember(r.psubs, sub, t) + + // Now removal + s.Remove(sub) + + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 2, t) + verifyLen(r.qsubs[0], 2, t) + verifyLen(r.qsubs[1], 2, t) + verifyQMember(r.qsubs, sub1, t) + verifyQMember(r.qsubs, sub2, t) + + s.Remove(sub1) + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 2, t) + verifyLen(r.qsubs[findQSliceForSub(sub1, r.qsubs)], 1, t) + verifyLen(r.qsubs[findQSliceForSub(sub2, r.qsubs)], 2, t) + verifyQMember(r.qsubs, sub2, t) + verifyQMember(r.qsubs, sub3, t) + verifyQMember(r.qsubs, sub4, t) + + s.Remove(sub3) // Last one + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 1, t) + verifyLen(r.qsubs[0], 2, t) // this is sub2/baz now + verifyQMember(r.qsubs, sub2, t) + + s.Remove(sub2) + s.Remove(sub4) + r = s.Match(subject) + verifyLen(r.psubs, 0, t) + verifyQLen(r.qsubs, 0, t) +} + +func checkBool(b, expected bool, t *testing.T) { + if b != expected { + dbg.PrintStack() + t.Fatalf("Expected %v, but got %v\n", expected, b) + } +} + +func TestSublistValidLiteralSubjects(t *testing.T) { + checkBool(IsValidLiteralSubject("foo"), true, t) + checkBool(IsValidLiteralSubject(".foo"), false, t) + checkBool(IsValidLiteralSubject("foo."), false, t) + checkBool(IsValidLiteralSubject("foo..bar"), false, t) + checkBool(IsValidLiteralSubject("foo.bar.*"), false, t) + checkBool(IsValidLiteralSubject("foo.bar.>"), false, t) + checkBool(IsValidLiteralSubject("*"), false, t) + checkBool(IsValidLiteralSubject(">"), false, t) + // The followings have widlcards characters but are not + // considered as such because they are not individual tokens. + checkBool(IsValidLiteralSubject("foo*"), true, t) + checkBool(IsValidLiteralSubject("foo**"), true, t) + checkBool(IsValidLiteralSubject("foo.**"), true, t) + checkBool(IsValidLiteralSubject("foo*bar"), true, t) + checkBool(IsValidLiteralSubject("foo.*bar"), true, t) + checkBool(IsValidLiteralSubject("foo*.bar"), true, t) + checkBool(IsValidLiteralSubject("*bar"), true, t) + checkBool(IsValidLiteralSubject("foo>"), true, t) + checkBool(IsValidLiteralSubject("foo>>"), true, t) + checkBool(IsValidLiteralSubject("foo.>>"), true, t) + checkBool(IsValidLiteralSubject("foo>bar"), true, t) + checkBool(IsValidLiteralSubject("foo.>bar"), true, t) + checkBool(IsValidLiteralSubject("foo>.bar"), true, t) + checkBool(IsValidLiteralSubject(">bar"), true, t) +} + +func TestSublistValidlSubjects(t *testing.T) { + checkBool(IsValidSubject("."), false, t) + checkBool(IsValidSubject(".foo"), false, t) + checkBool(IsValidSubject("foo."), false, t) + checkBool(IsValidSubject("foo..bar"), false, t) + checkBool(IsValidSubject(">.bar"), false, t) + checkBool(IsValidSubject("foo.>.bar"), false, t) + checkBool(IsValidSubject("foo"), true, t) + checkBool(IsValidSubject("foo.bar.*"), true, t) + checkBool(IsValidSubject("foo.bar.>"), true, t) + checkBool(IsValidSubject("*"), true, t) + checkBool(IsValidSubject(">"), true, t) + checkBool(IsValidSubject("foo*"), true, t) + checkBool(IsValidSubject("foo**"), true, t) + checkBool(IsValidSubject("foo.**"), true, t) + checkBool(IsValidSubject("foo*bar"), true, t) + checkBool(IsValidSubject("foo.*bar"), true, t) + checkBool(IsValidSubject("foo*.bar"), true, t) + checkBool(IsValidSubject("*bar"), true, t) + checkBool(IsValidSubject("foo>"), true, t) + checkBool(IsValidSubject("foo.>>"), true, t) + checkBool(IsValidSubject("foo>bar"), true, t) + checkBool(IsValidSubject("foo.>bar"), true, t) + checkBool(IsValidSubject("foo>.bar"), true, t) + checkBool(IsValidSubject(">bar"), true, t) +} + +func TestSublistMatchLiterals(t *testing.T) { + checkBool(matchLiteral("foo", "foo"), true, t) + checkBool(matchLiteral("foo", "bar"), false, t) + checkBool(matchLiteral("foo", "*"), true, t) + checkBool(matchLiteral("foo", ">"), true, t) + checkBool(matchLiteral("foo.bar", ">"), true, t) + checkBool(matchLiteral("foo.bar", "foo.>"), true, t) + checkBool(matchLiteral("foo.bar", "bar.>"), false, t) + checkBool(matchLiteral("stats.test.22", "stats.>"), true, t) + checkBool(matchLiteral("stats.test.22", "stats.*.*"), true, t) + checkBool(matchLiteral("foo.bar", "foo"), false, t) + checkBool(matchLiteral("stats.test.foos", "stats.test.foos"), true, t) + checkBool(matchLiteral("stats.test.foos", "stats.test.foo"), false, t) + checkBool(matchLiteral("stats.test", "stats.test.*"), false, t) + checkBool(matchLiteral("stats.test.foos", "stats.*"), false, t) + checkBool(matchLiteral("stats.test.foos", "stats.*.*.foos"), false, t) + + // These are cases where wildcards characters should not be considered + // wildcards since they do not follow the rules of wildcards. + checkBool(matchLiteral("*bar", "*bar"), true, t) + checkBool(matchLiteral("foo*", "foo*"), true, t) + checkBool(matchLiteral("foo*bar", "foo*bar"), true, t) + checkBool(matchLiteral("foo.***.bar", "foo.***.bar"), true, t) + checkBool(matchLiteral(">bar", ">bar"), true, t) + checkBool(matchLiteral("foo>", "foo>"), true, t) + checkBool(matchLiteral("foo>bar", "foo>bar"), true, t) + checkBool(matchLiteral("foo.>>>.bar", "foo.>>>.bar"), true, t) +} + +func TestSublistBadSubjectOnRemove(t *testing.T) { + bad := "a.b..d" + sub := newSub(bad) + + s := NewSublist() + if err := s.Insert(sub); err != ErrInvalidSubject { + t.Fatalf("Expected ErrInvalidSubject, got %v\n", err) + } + + if err := s.Remove(sub); err != ErrInvalidSubject { + t.Fatalf("Expected ErrInvalidSubject, got %v\n", err) + } + + badfwc := "a.>.b" + if err := s.Remove(newSub(badfwc)); err != ErrInvalidSubject { + t.Fatalf("Expected ErrInvalidSubject, got %v\n", err) + } +} + +// This is from bug report #18 +func TestSublistTwoTokenPubMatchSingleTokenSub(t *testing.T) { + s := NewSublist() + sub := newSub("foo") + s.Insert(sub) + r := s.Match("foo") + verifyLen(r.psubs, 1, t) + verifyMember(r.psubs, sub, t) + r = s.Match("foo.bar") + verifyLen(r.psubs, 0, t) +} + +func TestSublistInsertWithWildcardsAsLiterals(t *testing.T) { + s := NewSublist() + subjects := []string{"foo.*-", "foo.>-"} + for _, subject := range subjects { + sub := newSub(subject) + s.Insert(sub) + // Should find no match + r := s.Match("foo.bar") + verifyLen(r.psubs, 0, t) + // Should find a match + r = s.Match(subject) + verifyLen(r.psubs, 1, t) + } +} + +func TestSublistRemoveWithWildcardsAsLiterals(t *testing.T) { + s := NewSublist() + subjects := []string{"foo.*-", "foo.>-"} + for _, subject := range subjects { + sub := newSub(subject) + s.Insert(sub) + // Should find no match + rsub := newSub("foo.bar") + s.Remove(rsub) + if c := s.Count(); c != 1 { + t.Fatalf("Expected sublist to still contain sub, got %v", c) + } + s.Remove(sub) + if c := s.Count(); c != 0 { + t.Fatalf("Expected sublist to be empty, got %v", c) + } + } +} + +func TestSublistRaceOnRemove(t *testing.T) { + s := NewSublist() + + var ( + total = 100 + subs = make(map[int]*subscription, total) // use map for randomness + ) + for i := 0; i < total; i++ { + sub := newQSub("foo", "bar") + subs[i] = sub + } + + for i := 0; i < 2; i++ { + for _, sub := range subs { + s.Insert(sub) + } + // Call Match() once or twice, to make sure we get from cache + if i == 1 { + s.Match("foo") + } + // This will be from cache when i==1 + r := s.Match("foo") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + for _, sub := range subs { + s.Remove(sub) + } + wg.Done() + }() + for _, qsub := range r.qsubs { + for i := 0; i < len(qsub); i++ { + sub := qsub[i] + if string(sub.queue) != "bar" { + t.Fatalf("Queue name should be bar, got %s", qsub[i].queue) + } + } + } + wg.Wait() + } + + // Repeat tests with regular subs + for i := 0; i < total; i++ { + sub := newSub("foo") + subs[i] = sub + } + + for i := 0; i < 2; i++ { + for _, sub := range subs { + s.Insert(sub) + } + // Call Match() once or twice, to make sure we get from cache + if i == 1 { + s.Match("foo") + } + // This will be from cache when i==1 + r := s.Match("foo") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + for _, sub := range subs { + s.Remove(sub) + } + wg.Done() + }() + for i := 0; i < len(r.psubs); i++ { + sub := r.psubs[i] + if string(sub.subject) != "foo" { + t.Fatalf("Subject should be foo, got %s", sub.subject) + } + } + wg.Wait() + } +} + +func TestSublistRaceOnInsert(t *testing.T) { + s := NewSublist() + + var ( + total = 100 + subs = make(map[int]*subscription, total) // use map for randomness + wg sync.WaitGroup + ) + for i := 0; i < total; i++ { + sub := newQSub("foo", "bar") + subs[i] = sub + } + wg.Add(1) + go func() { + for _, sub := range subs { + s.Insert(sub) + } + wg.Done() + }() + for i := 0; i < 1000; i++ { + r := s.Match("foo") + for _, qsubs := range r.qsubs { + for _, qsub := range qsubs { + if string(qsub.queue) != "bar" { + t.Fatalf("Expected queue name to be bar, got %v", string(qsub.queue)) + } + } + } + } + wg.Wait() + + // Repeat the test with plain subs + for i := 0; i < total; i++ { + sub := newSub("foo") + subs[i] = sub + } + wg.Add(1) + go func() { + for _, sub := range subs { + s.Insert(sub) + } + wg.Done() + }() + for i := 0; i < 1000; i++ { + r := s.Match("foo") + for _, sub := range r.psubs { + if string(sub.subject) != "foo" { + t.Fatalf("Expected subject to be foo, got %v", string(sub.subject)) + } + } + } + wg.Wait() +} + +func TestSublistRaceOnMatch(t *testing.T) { + s := NewSublist() + s.Insert(newQSub("foo.*", "workers")) + s.Insert(newQSub("foo.bar", "workers")) + s.Insert(newSub("foo.*")) + s.Insert(newSub("foo.bar")) + + wg := sync.WaitGroup{} + wg.Add(2) + errCh := make(chan error, 2) + f := func() { + defer wg.Done() + for i := 0; i < 10; i++ { + r := s.Match("foo.bar") + for _, sub := range r.psubs { + if !strings.HasPrefix(string(sub.subject), "foo.") { + errCh <- fmt.Errorf("Wrong subject: %s", sub.subject) + return + } + } + for _, qsub := range r.qsubs { + for _, sub := range qsub { + if string(sub.queue) != "workers" { + errCh <- fmt.Errorf("Wrong queue name: %s", sub.queue) + return + } + } + } + // Empty cache to maximize chance for race + s.Lock() + delete(s.cache, "foo.bar") + s.Unlock() + } + } + go f() + go f() + wg.Wait() + select { + case e := <-errCh: + t.Fatalf(e.Error()) + default: + } +} + +// -- Benchmarks Setup -- + +var subs []*subscription +var toks = []string{"apcera", "continuum", "component", "router", "api", "imgr", "jmgr", "auth"} +var sl = NewSublist() + +func init() { + subs = make([]*subscription, 0, 256*1024) + subsInit("") + for i := 0; i < len(subs); i++ { + sl.Insert(subs[i]) + } + addWildcards() +} + +func subsInit(pre string) { + var sub string + for _, t := range toks { + if len(pre) > 0 { + sub = pre + tsep + t + } else { + sub = t + } + subs = append(subs, newSub(sub)) + if len(strings.Split(sub, tsep)) < 5 { + subsInit(sub) + } + } +} + +func addWildcards() { + sl.Insert(newSub("cloud.>")) + sl.Insert(newSub("cloud.continuum.component.>")) + sl.Insert(newSub("cloud.*.*.router.*")) +} + +// -- Benchmarks Setup End -- + +func Benchmark______________________SublistInsert(b *testing.B) { + s := NewSublist() + for i, l := 0, len(subs); i < b.N; i++ { + index := i % l + s.Insert(subs[index]) + } +} + +func Benchmark____________SublistMatchSingleToken(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera") + } +} + +func Benchmark______________SublistMatchTwoTokens(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera.continuum") + } +} + +func Benchmark____________SublistMatchThreeTokens(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera.continuum.component") + } +} + +func Benchmark_____________SublistMatchFourTokens(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera.continuum.component.router") + } +} + +func Benchmark_SublistMatchFourTokensSingleResult(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera.continuum.component.router") + } +} + +func Benchmark_SublistMatchFourTokensMultiResults(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("cloud.continuum.component.router") + } +} + +func Benchmark_______SublistMissOnLastTokenOfFive(b *testing.B) { + for i := 0; i < b.N; i++ { + sl.Match("apcera.continuum.component.router.ZZZZ") + } +} + +func multiRead(b *testing.B, num int) { + b.StopTimer() + var swg, fwg sync.WaitGroup + swg.Add(num) + fwg.Add(num) + s := "apcera.continuum.component.router" + for i := 0; i < num; i++ { + go func() { + swg.Done() + swg.Wait() + for i := 0; i < b.N; i++ { + sl.Match(s) + } + fwg.Done() + }() + } + swg.Wait() + b.StartTimer() + fwg.Wait() +} + +func Benchmark____________Sublist10XMultipleReads(b *testing.B) { + multiRead(b, 10) +} + +func Benchmark___________Sublist100XMultipleReads(b *testing.B) { + multiRead(b, 100) +} + +func Benchmark__________Sublist1000XMultipleReads(b *testing.B) { + multiRead(b, 1000) +} + +func Benchmark________________SublistMatchLiteral(b *testing.B) { + b.StopTimer() + cachedSubj := "foo.foo.foo.foo.foo.foo.foo.foo.foo.foo" + subjects := []string{ + "foo.foo.foo.foo.foo.foo.foo.foo.foo.foo", + "foo.foo.foo.foo.foo.foo.foo.foo.foo.>", + "foo.foo.foo.foo.foo.foo.foo.foo.>", + "foo.foo.foo.foo.foo.foo.foo.>", + "foo.foo.foo.foo.foo.foo.>", + "foo.foo.foo.foo.foo.>", + "foo.foo.foo.foo.>", + "foo.foo.foo.>", + "foo.foo.>", + "foo.>", + ">", + "foo.foo.foo.foo.foo.foo.foo.foo.foo.*", + "foo.foo.foo.foo.foo.foo.foo.foo.*.*", + "foo.foo.foo.foo.foo.foo.foo.*.*.*", + "foo.foo.foo.foo.foo.foo.*.*.*.*", + "foo.foo.foo.foo.foo.*.*.*.*.*", + "foo.foo.foo.foo.*.*.*.*.*.*", + "foo.foo.foo.*.*.*.*.*.*.*", + "foo.foo.*.*.*.*.*.*.*.*", + "foo.*.*.*.*.*.*.*.*.*", + "*.*.*.*.*.*.*.*.*.*", + } + b.StartTimer() + for i := 0; i < b.N; i++ { + for _, subject := range subjects { + if !matchLiteral(cachedSubj, subject) { + b.Fatalf("Subject %q no match with %q", cachedSubj, subject) + } + } + } +} + +func Benchmark_____SublistMatch10kSubsWithNoCache(b *testing.B) { + var nsubs = 512 + b.StopTimer() + s := NewSublist() + subject := "foo" + for i := 0; i < nsubs; i++ { + s.Insert(newSub(subject)) + } + b.StartTimer() + for i := 0; i < b.N; i++ { + r := s.Match(subject) + if len(r.psubs) != nsubs { + b.Fatalf("Results len is %d, should be %d", len(r.psubs), nsubs) + } + delete(s.cache, subject) + } +} + +func removeTest(b *testing.B, singleSubject, doBatch bool, qgroup string) { + b.StopTimer() + s := NewSublist() + subject := "foo" + + subs := make([]*subscription, 0, b.N) + for i := 0; i < b.N; i++ { + var sub *subscription + if singleSubject { + sub = newQSub(subject, qgroup) + } else { + sub = newQSub(fmt.Sprintf("%s.%d\n", subject, i), qgroup) + } + s.Insert(sub) + subs = append(subs, sub) + } + + // Actual test on Remove + b.StartTimer() + if doBatch { + s.RemoveBatch(subs) + } else { + for _, sub := range subs { + s.Remove(sub) + } + } +} + +func Benchmark__________SublistRemove1TokenSingle(b *testing.B) { + removeTest(b, true, false, "") +} + +func Benchmark___________SublistRemove1TokenBatch(b *testing.B) { + removeTest(b, true, true, "") +} + +func Benchmark_________SublistRemove2TokensSingle(b *testing.B) { + removeTest(b, false, false, "") +} + +func Benchmark__________SublistRemove2TokensBatch(b *testing.B) { + removeTest(b, false, true, "") +} + +func Benchmark________SublistRemove1TokenQGSingle(b *testing.B) { + removeTest(b, true, false, "bar") +} + +func Benchmark_________SublistRemove1TokenQGBatch(b *testing.B) { + removeTest(b, true, true, "bar") +} + +func removeMultiTest(b *testing.B, singleSubject, doBatch bool) { + b.StopTimer() + s := NewSublist() + subject := "foo" + var swg, fwg sync.WaitGroup + swg.Add(b.N) + fwg.Add(b.N) + + // We will have b.N go routines each with 1k subscriptions. + sc := 1000 + + for i := 0; i < b.N; i++ { + go func() { + subs := make([]*subscription, 0, sc) + for n := 0; n < sc; n++ { + var sub *subscription + if singleSubject { + sub = newSub(subject) + } else { + sub = newSub(fmt.Sprintf("%s.%d\n", subject, n)) + } + s.Insert(sub) + subs = append(subs, sub) + } + // Wait to start test + swg.Done() + swg.Wait() + // Actual test on Remove + if doBatch { + s.RemoveBatch(subs) + } else { + for _, sub := range subs { + s.Remove(sub) + } + } + fwg.Done() + }() + } + swg.Wait() + b.StartTimer() + fwg.Wait() +} + +// Check contention rates for remove from multiple Go routines. +// Reason for BatchRemove. +func Benchmark_________SublistRemove1kSingleMulti(b *testing.B) { + removeMultiTest(b, true, false) +} + +// Batch version +func Benchmark__________SublistRemove1kBatchMulti(b *testing.B) { + removeMultiTest(b, true, true) +} + +func Benchmark__SublistRemove1kSingle2TokensMulti(b *testing.B) { + removeMultiTest(b, false, false) +} + +// Batch version +func Benchmark___SublistRemove1kBatch2TokensMulti(b *testing.B) { + removeMultiTest(b, false, true) +} diff --git a/vendor/github.com/nats-io/gnatsd/server/util.go b/vendor/github.com/nats-io/gnatsd/server/util.go new file mode 100644 index 00000000..3a2ffe66 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/util.go @@ -0,0 +1,111 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/nats-io/nuid" +) + +// Use nuid. +func genID() string { + return nuid.Next() +} + +// Ascii numbers 0-9 +const ( + asciiZero = 48 + asciiNine = 57 +) + +// parseSize expects decimal positive numbers. We +// return -1 to signal error. +func parseSize(d []byte) (n int) { + l := len(d) + if l == 0 { + return -1 + } + var ( + i int + dec byte + ) + + // Note: Use `goto` here to avoid for loop in order + // to have the function be inlined. + // See: https://github.com/golang/go/issues/14768 +loop: + dec = d[i] + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int(dec) - asciiZero) + + i++ + if i < l { + goto loop + } + return n +} + +// parseInt64 expects decimal positive numbers. We +// return -1 to signal error +func parseInt64(d []byte) (n int64) { + if len(d) == 0 { + return -1 + } + for _, dec := range d { + if dec < asciiZero || dec > asciiNine { + return -1 + } + n = n*10 + (int64(dec) - asciiZero) + } + return n +} + +// Helper to move from float seconds to time.Duration +func secondsToDuration(seconds float64) time.Duration { + ttl := seconds * float64(time.Second) + return time.Duration(ttl) +} + +// Parse a host/port string with a default port to use +// if none (or 0 or -1) is specified in `hostPort` string. +func parseHostPort(hostPort string, defaultPort int) (host string, port int, err error) { + if hostPort != "" { + host, sPort, err := net.SplitHostPort(hostPort) + switch err.(type) { + case *net.AddrError: + // try appending the current port + host, sPort, err = net.SplitHostPort(fmt.Sprintf("%s:%d", hostPort, defaultPort)) + } + if err != nil { + return "", -1, err + } + port, err = strconv.Atoi(strings.TrimSpace(sPort)) + if err != nil { + return "", -1, err + } + if port == 0 || port == -1 { + port = defaultPort + } + return strings.TrimSpace(host), port, nil + } + return "", -1, errors.New("No hostport specified") +} diff --git a/vendor/github.com/nats-io/gnatsd/server/util_test.go b/vendor/github.com/nats-io/gnatsd/server/util_test.go new file mode 100644 index 00000000..29748a42 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/server/util_test.go @@ -0,0 +1,168 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "math/rand" + "strconv" + "sync" + "testing" + "time" +) + +func TestParseSize(t *testing.T) { + if parseSize(nil) != -1 { + t.Fatal("Should error on nil byte slice") + } + n := []byte("12345678") + if pn := parseSize(n); pn != 12345678 { + t.Fatalf("Did not parse %q correctly, res=%d\n", n, pn) + } +} + +func TestParseSInt64(t *testing.T) { + if parseInt64(nil) != -1 { + t.Fatal("Should error on nil byte slice") + } + n := []byte("12345678") + if pn := parseInt64(n); pn != 12345678 { + t.Fatalf("Did not parse %q correctly, res=%d\n", n, pn) + } +} + +func TestParseHostPort(t *testing.T) { + check := func(hostPort string, defaultPort int, expectedHost string, expectedPort int, expectedErr bool) { + h, p, err := parseHostPort(hostPort, defaultPort) + if expectedErr { + if err == nil { + stackFatalf(t, "Expected an error, did not get one") + } + // expected error, so we are done + return + } + if !expectedErr && err != nil { + stackFatalf(t, "Unexpected error: %v", err) + } + if expectedHost != h { + stackFatalf(t, "Expected host %q, got %q", expectedHost, h) + } + if expectedPort != p { + stackFatalf(t, "Expected port %d, got %d", expectedPort, p) + } + } + check("addr:1234", 5678, "addr", 1234, false) + check(" addr:1234 ", 5678, "addr", 1234, false) + check(" addr : 1234 ", 5678, "addr", 1234, false) + check("addr", 5678, "addr", 5678, false) + check(" addr ", 5678, "addr", 5678, false) + check("addr:-1", 5678, "addr", 5678, false) + check(" addr:-1 ", 5678, "addr", 5678, false) + check(" addr : -1 ", 5678, "addr", 5678, false) + check("addr:0", 5678, "addr", 5678, false) + check(" addr:0 ", 5678, "addr", 5678, false) + check(" addr : 0 ", 5678, "addr", 5678, false) + check("addr:addr", 0, "", 0, true) + check("addr:::1234", 0, "", 0, true) + check("", 0, "", 0, true) +} + +func BenchmarkParseInt(b *testing.B) { + b.SetBytes(1) + n := "12345678" + for i := 0; i < b.N; i++ { + strconv.ParseInt(n, 10, 0) + } +} + +func BenchmarkParseSize(b *testing.B) { + b.SetBytes(1) + n := []byte("12345678") + for i := 0; i < b.N; i++ { + parseSize(n) + } +} + +func deferUnlock(mu *sync.Mutex) { + mu.Lock() + defer mu.Unlock() + // see noDeferUnlock + if false { + return + } +} + +func BenchmarkDeferMutex(b *testing.B) { + var mu sync.Mutex + b.SetBytes(1) + for i := 0; i < b.N; i++ { + deferUnlock(&mu) + } +} + +func noDeferUnlock(mu *sync.Mutex) { + mu.Lock() + // prevent staticcheck warning about empty critical section + if false { + return + } + mu.Unlock() +} + +func BenchmarkNoDeferMutex(b *testing.B) { + var mu sync.Mutex + b.SetBytes(1) + for i := 0; i < b.N; i++ { + noDeferUnlock(&mu) + } +} + +func createTestSub() *subscription { + return &subscription{ + subject: []byte("foo"), + queue: []byte("bar"), + sid: []byte("22"), + } +} + +func BenchmarkArrayRand(b *testing.B) { + b.StopTimer() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + // Create an array of 10 items + subs := []*subscription{} + for i := 0; i < 10; i++ { + subs = append(subs, createTestSub()) + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + index := r.Intn(len(subs)) + _ = subs[index] + } +} + +func BenchmarkMapRange(b *testing.B) { + b.StopTimer() + // Create an map of 10 items + subs := map[int]*subscription{} + for i := 0; i < 10; i++ { + subs[i] = createTestSub() + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + for range subs { + break + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/auth_test.go b/vendor/github.com/nats-io/gnatsd/test/auth_test.go new file mode 100644 index 00000000..ed391b25 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/auth_test.go @@ -0,0 +1,236 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "encoding/json" + "fmt" + "net" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +func doAuthConnect(t tLogger, c net.Conn, token, user, pass string) { + cs := fmt.Sprintf("CONNECT {\"verbose\":true,\"auth_token\":\"%s\",\"user\":\"%s\",\"pass\":\"%s\"}\r\n", token, user, pass) + sendProto(t, c, cs) +} + +func testInfoForAuth(t tLogger, infojs []byte) bool { + var sinfo server.Info + err := json.Unmarshal(infojs, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } + return sinfo.AuthRequired +} + +func expectAuthRequired(t tLogger, c net.Conn) { + buf := expectResult(t, c, infoRe) + infojs := infoRe.FindAllSubmatch(buf, 1)[0][1] + if !testInfoForAuth(t, infojs) { + t.Fatalf("Expected server to require authorization: '%s'", infojs) + } +} + +//////////////////////////////////////////////////////////// +// The authorization token version +//////////////////////////////////////////////////////////// + +const AUTH_PORT = 10422 +const AUTH_TOKEN = "_YZZ22_" + +func runAuthServerWithToken() *server.Server { + opts := DefaultTestOptions + opts.Port = AUTH_PORT + opts.Authorization = AUTH_TOKEN + return RunServer(&opts) +} + +func TestNoAuthClient(t *testing.T) { + s := runAuthServerWithToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "", "") + expectResult(t, c, errRe) +} + +func TestAuthClientBadToken(t *testing.T) { + s := runAuthServerWithToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "ZZZ", "", "") + expectResult(t, c, errRe) +} + +func TestAuthClientNoConnect(t *testing.T) { + s := runAuthServerWithToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + // This is timing dependent.. + time.Sleep(server.AUTH_TIMEOUT) + expectResult(t, c, errRe) +} + +func TestAuthClientGoodConnect(t *testing.T) { + s := runAuthServerWithToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, AUTH_TOKEN, "", "") + expectResult(t, c, okRe) +} + +func TestAuthClientFailOnEverythingElse(t *testing.T) { + s := runAuthServerWithToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + sendProto(t, c, "PUB foo 2\r\nok\r\n") + expectResult(t, c, errRe) +} + +//////////////////////////////////////////////////////////// +// The username/password version +//////////////////////////////////////////////////////////// + +const AUTH_USER = "derek" +const AUTH_PASS = "foobar" + +func runAuthServerWithUserPass() *server.Server { + opts := DefaultTestOptions + opts.Port = AUTH_PORT + opts.Username = AUTH_USER + opts.Password = AUTH_PASS + return RunServer(&opts) +} + +func TestNoUserOrPasswordClient(t *testing.T) { + s := runAuthServerWithUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "", "") + expectResult(t, c, errRe) +} + +func TestBadUserClient(t *testing.T) { + s := runAuthServerWithUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "derekzz", AUTH_PASS) + expectResult(t, c, errRe) +} + +func TestBadPasswordClient(t *testing.T) { + s := runAuthServerWithUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, "ZZ") + expectResult(t, c, errRe) +} + +func TestPasswordClientGoodConnect(t *testing.T) { + s := runAuthServerWithUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, AUTH_PASS) + expectResult(t, c, okRe) +} + +//////////////////////////////////////////////////////////// +// The bcrypt username/password version +//////////////////////////////////////////////////////////// + +// Generated with util/mkpasswd (Cost 4 because of cost of --race, default is 11) +const BCRYPT_AUTH_PASS = "IW@$6v(y1(t@fhPDvf!5^%" +const BCRYPT_AUTH_HASH = "$2a$04$Q.CgCP2Sl9pkcTXEZHazaeMwPaAkSHk7AI51HkyMt5iJQQyUA4qxq" + +func runAuthServerWithBcryptUserPass() *server.Server { + opts := DefaultTestOptions + opts.Port = AUTH_PORT + opts.Username = AUTH_USER + opts.Password = BCRYPT_AUTH_HASH + return RunServer(&opts) +} + +func TestBadBcryptPassword(t *testing.T) { + s := runAuthServerWithBcryptUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, BCRYPT_AUTH_HASH) + expectResult(t, c, errRe) +} + +func TestGoodBcryptPassword(t *testing.T) { + s := runAuthServerWithBcryptUserPass() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, BCRYPT_AUTH_PASS) + expectResult(t, c, okRe) +} + +//////////////////////////////////////////////////////////// +// The bcrypt authorization token version +//////////////////////////////////////////////////////////// + +const BCRYPT_AUTH_TOKEN = "0uhJOSr3GW7xvHvtd^K6pa" +const BCRYPT_AUTH_TOKEN_HASH = "$2a$04$u5ZClXpcjHgpfc61Ee0VKuwI1K3vTC4zq7SjphjnlHMeb1Llkb5Y6" + +func runAuthServerWithBcryptToken() *server.Server { + opts := DefaultTestOptions + opts.Port = AUTH_PORT + opts.Authorization = BCRYPT_AUTH_TOKEN_HASH + return RunServer(&opts) +} + +func TestBadBcryptToken(t *testing.T) { + s := runAuthServerWithBcryptToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, BCRYPT_AUTH_TOKEN_HASH, "", "") + expectResult(t, c, errRe) +} + +func TestGoodBcryptToken(t *testing.T) { + s := runAuthServerWithBcryptToken() + defer s.Shutdown() + c := createClientConn(t, "127.0.0.1", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, BCRYPT_AUTH_TOKEN, "", "") + expectResult(t, c, okRe) +} diff --git a/vendor/github.com/nats-io/gnatsd/test/bench_results.txt b/vendor/github.com/nats-io/gnatsd/test/bench_results.txt new file mode 100644 index 00000000..89850ca5 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/bench_results.txt @@ -0,0 +1,79 @@ +2017 iMac Pro 3Ghz (Turbo 4Ghz) 10-Core Skylake +OSX High Sierra 10.13.2 + +=================== +Go version go1.9.2 +=================== + +Benchmark_____Pub0b_Payload-20 30000000 55.1 ns/op 199.78 MB/s +Benchmark_____Pub8b_Payload-20 30000000 55.8 ns/op 340.21 MB/s +Benchmark____Pub32b_Payload-20 20000000 63.4 ns/op 694.34 MB/s +Benchmark___Pub128B_Payload-20 20000000 79.8 ns/op 1766.47 MB/s +Benchmark___Pub256B_Payload-20 20000000 98.1 ns/op 2741.51 MB/s +Benchmark_____Pub1K_Payload-20 5000000 283 ns/op 3660.72 MB/s +Benchmark_____Pub4K_Payload-20 1000000 1395 ns/op 2945.30 MB/s +Benchmark_____Pub8K_Payload-20 500000 2846 ns/op 2882.35 MB/s +Benchmark_AuthPub0b_Payload-20 10000000 126 ns/op 86.82 MB/s +Benchmark____________PubSub-20 10000000 135 ns/op +Benchmark____PubSubTwoConns-20 10000000 136 ns/op +Benchmark____PubTwoQueueSub-20 10000000 152 ns/op +Benchmark___PubFourQueueSub-20 10000000 152 ns/op +Benchmark__PubEightQueueSub-20 10000000 152 ns/op +Benchmark___RoutedPubSub_0b-20 5000000 385 ns/op +Benchmark___RoutedPubSub_1K-20 1000000 1076 ns/op +Benchmark_RoutedPubSub_100K-20 20000 78501 ns/op + + +2015 iMac5k 4Ghz i7 Haswell +OSX El Capitan 10.11.3 + +=================== +Go version go1.6 +=================== + +Benchmark____PubNo_Payload-8 20000000 88.6 ns/op 124.11 MB/s +Benchmark____Pub8b_Payload-8 20000000 89.8 ns/op 211.63 MB/s +Benchmark___Pub32b_Payload-8 20000000 97.3 ns/op 452.20 MB/s +Benchmark__Pub256B_Payload-8 10000000 129 ns/op 2078.43 MB/s +Benchmark____Pub1K_Payload-8 5000000 216 ns/op 4791.00 MB/s +Benchmark____Pub4K_Payload-8 1000000 1123 ns/op 3657.53 MB/s +Benchmark____Pub8K_Payload-8 500000 2309 ns/op 3553.09 MB/s +Benchmark___________PubSub-8 10000000 210 ns/op +Benchmark___PubSubTwoConns-8 10000000 205 ns/op +Benchmark___PubTwoQueueSub-8 10000000 231 ns/op +Benchmark__PubFourQueueSub-8 10000000 233 ns/op +Benchmark_PubEightQueueSub-8 5000000 231 ns/op + +OSX Yosemite 10.10.5 + +=================== +Go version go1.4.2 +=================== + +Benchmark___PubNo_Payload 10000000 133 ns/op 82.44 MB/s +Benchmark___Pub8b_Payload 10000000 135 ns/op 140.27 MB/s +Benchmark__Pub32b_Payload 10000000 147 ns/op 297.56 MB/s +Benchmark_Pub256B_Payload 10000000 211 ns/op 1273.82 MB/s +Benchmark___Pub1K_Payload 3000000 447 ns/op 2321.55 MB/s +Benchmark___Pub4K_Payload 1000000 1677 ns/op 2450.43 MB/s +Benchmark___Pub8K_Payload 300000 3670 ns/op 2235.80 MB/s +Benchmark__________PubSub 5000000 263 ns/op +Benchmark__PubSubTwoConns 5000000 268 ns/op +Benchmark__PubTwoQueueSub 2000000 936 ns/op +Benchmark_PubFourQueueSub 1000000 1103 ns/op + +=================== +Go version go1.5.0 +=================== + +Benchmark___PubNo_Payload-8 10000000 122 ns/op 89.94 MB/s +Benchmark___Pub8b_Payload-8 10000000 124 ns/op 152.72 MB/s +Benchmark__Pub32b_Payload-8 10000000 135 ns/op 325.73 MB/s +Benchmark_Pub256B_Payload-8 10000000 159 ns/op 1685.78 MB/s +Benchmark___Pub1K_Payload-8 5000000 256 ns/op 4047.90 MB/s +Benchmark___Pub4K_Payload-8 1000000 1164 ns/op 3530.77 MB/s +Benchmark___Pub8K_Payload-8 500000 2444 ns/op 3357.34 MB/s +Benchmark__________PubSub-8 5000000 254 ns/op +Benchmark__PubSubTwoConns-8 5000000 245 ns/op +Benchmark__PubTwoQueueSub-8 2000000 845 ns/op +Benchmark_PubFourQueueSub-8 1000000 1004 ns/op diff --git a/vendor/github.com/nats-io/gnatsd/test/bench_test.go b/vendor/github.com/nats-io/gnatsd/test/bench_test.go new file mode 100644 index 00000000..4f90b411 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/bench_test.go @@ -0,0 +1,674 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bufio" + "fmt" + "math/rand" + "net" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +const PERF_PORT = 8422 + +// For Go routine based server. +func runBenchServer() *server.Server { + opts := DefaultTestOptions + opts.Port = PERF_PORT + return RunServer(&opts) +} + +const defaultRecBufSize = 32768 +const defaultSendBufSize = 32768 + +func flushConnection(b *testing.B, c net.Conn) { + buf := make([]byte, 32) + c.Write([]byte("PING\r\n")) + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + n, err := c.Read(buf) + c.SetReadDeadline(time.Time{}) + if err != nil { + b.Fatalf("Failed read: %v\n", err) + } + if n != 6 && buf[0] != 'P' && buf[1] != 'O' { + b.Fatalf("Failed read of PONG: %s\n", buf) + } +} + +func benchPub(b *testing.B, subject, payload string) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", subject, len(payload), payload)) + b.SetBytes(int64(len(sendOp))) + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + bw.Flush() + flushConnection(b, c) + b.StopTimer() + c.Close() + s.Shutdown() +} + +var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") + +func sizedBytes(sz int) []byte { + b := make([]byte, sz) + for i := range b { + b[i] = ch[rand.Intn(len(ch))] + } + return b +} + +func sizedString(sz int) string { + return string(sizedBytes(sz)) +} + +// Publish subject for pub benchmarks. +var psub = "a" + +func Benchmark______Pub0b_Payload(b *testing.B) { + benchPub(b, psub, "") +} + +func Benchmark______Pub8b_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(8) + benchPub(b, psub, s) +} + +func Benchmark_____Pub32b_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(32) + benchPub(b, psub, s) +} + +func Benchmark____Pub128B_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(128) + benchPub(b, psub, s) +} + +func Benchmark____Pub256B_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(256) + benchPub(b, psub, s) +} + +func Benchmark______Pub1K_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(1024) + benchPub(b, psub, s) +} + +func Benchmark______Pub4K_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(4 * 1024) + benchPub(b, psub, s) +} + +func Benchmark______Pub8K_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(8 * 1024) + benchPub(b, psub, s) +} + +func Benchmark______Pub32K_Payload(b *testing.B) { + b.StopTimer() + s := sizedString(32 * 1024) + benchPub(b, psub, s) +} + +func drainConnection(b *testing.B, c net.Conn, ch chan bool, expected int) { + buf := make([]byte, defaultRecBufSize) + bytes := 0 + + for { + c.SetReadDeadline(time.Now().Add(30 * time.Second)) + n, err := c.Read(buf) + if err != nil { + b.Errorf("Error on read: %v\n", err) + break + } + bytes += n + if bytes >= expected { + break + } + } + if bytes != expected { + b.Errorf("Did not receive all bytes: %d vs %d\n", bytes, expected) + } + ch <- true +} + +// Benchmark the authorization code path. +func Benchmark__AuthPub0b_Payload(b *testing.B) { + b.StopTimer() + + srv, opts := RunServerWithConfig("./configs/authorization.conf") + defer srv.Shutdown() + + c := createClientConn(b, opts.Host, opts.Port) + defer c.Close() + expectAuthRequired(b, c) + + cs := fmt.Sprintf("CONNECT {\"verbose\":false,\"user\":\"%s\",\"pass\":\"%s\"}\r\n", "bench", DefaultPass) + sendProto(b, c, cs) + + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte("PUB a 0\r\n\r\n") + b.SetBytes(int64(len(sendOp))) + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + bw.Flush() + flushConnection(b, c) + b.StopTimer() +} + +func Benchmark_____________PubSub(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + sendProto(b, c, "SUB foo 1\r\n") + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB foo 2\r\nok\r\n")) + ch := make(chan bool) + expected := len("MSG foo 1 2\r\nok\r\n") * b.N + go drainConnection(b, c, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Errorf("Received error on PUB write: %v\n", err) + } + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + s.Shutdown() +} + +func Benchmark_____PubSubTwoConns(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + bw := bufio.NewWriterSize(c, defaultSendBufSize) + + c2 := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c2) + sendProto(b, c2, "SUB foo 1\r\n") + flushConnection(b, c2) + + sendOp := []byte(fmt.Sprintf("PUB foo 2\r\nok\r\n")) + ch := make(chan bool) + + expected := len("MSG foo 1 2\r\nok\r\n") * b.N + go drainConnection(b, c2, ch, expected) + + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + c2.Close() + s.Shutdown() +} + +func Benchmark_PubSub512kTwoConns(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + bw := bufio.NewWriterSize(c, defaultSendBufSize) + + c2 := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c2) + sendProto(b, c2, "SUB foo 1\r\n") + flushConnection(b, c2) + + sz := 1024 * 512 + payload := sizedString(sz) + + sendOp := []byte(fmt.Sprintf("PUB foo %d\r\n%s\r\n", sz, payload)) + ch := make(chan bool) + + expected := len(fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", sz, payload)) * b.N + go drainConnection(b, c2, ch, expected) + + b.StartTimer() + for i := 0; i < b.N; i++ { + bw.Write(sendOp) + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + c2.Close() + s.Shutdown() +} + +func Benchmark_____PubTwoQueueSub(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + sendProto(b, c, "SUB foo group1 1\r\n") + sendProto(b, c, "SUB foo group1 2\r\n") + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB foo 2\r\nok\r\n")) + ch := make(chan bool) + expected := len("MSG foo 1 2\r\nok\r\n") * b.N + go drainConnection(b, c, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Fatalf("Received error on PUB write: %v\n", err) + } + } + err := bw.Flush() + if err != nil { + b.Fatalf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + s.Shutdown() +} + +func Benchmark____PubFourQueueSub(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + sendProto(b, c, "SUB foo group1 1\r\n") + sendProto(b, c, "SUB foo group1 2\r\n") + sendProto(b, c, "SUB foo group1 3\r\n") + sendProto(b, c, "SUB foo group1 4\r\n") + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB foo 2\r\nok\r\n")) + ch := make(chan bool) + expected := len("MSG foo 1 2\r\nok\r\n") * b.N + go drainConnection(b, c, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Fatalf("Received error on PUB write: %v\n", err) + } + } + err := bw.Flush() + if err != nil { + b.Fatalf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + s.Shutdown() +} + +func Benchmark___PubEightQueueSub(b *testing.B) { + b.StopTimer() + s := runBenchServer() + c := createClientConn(b, "127.0.0.1", PERF_PORT) + doDefaultConnect(b, c) + sendProto(b, c, "SUB foo group1 1\r\n") + sendProto(b, c, "SUB foo group1 2\r\n") + sendProto(b, c, "SUB foo group1 3\r\n") + sendProto(b, c, "SUB foo group1 4\r\n") + sendProto(b, c, "SUB foo group1 5\r\n") + sendProto(b, c, "SUB foo group1 6\r\n") + sendProto(b, c, "SUB foo group1 7\r\n") + sendProto(b, c, "SUB foo group1 8\r\n") + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB foo 2\r\nok\r\n")) + ch := make(chan bool) + expected := len("MSG foo 1 2\r\nok\r\n") * b.N + go drainConnection(b, c, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Fatalf("Received error on PUB write: %v\n", err) + } + } + err := bw.Flush() + if err != nil { + b.Fatalf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + c.Close() + s.Shutdown() +} + +func routePubSub(b *testing.B, size int) { + b.StopTimer() + + s1, o1 := RunServerWithConfig("./configs/srv_a.conf") + defer s1.Shutdown() + s2, o2 := RunServerWithConfig("./configs/srv_b.conf") + defer s2.Shutdown() + + sub := createClientConn(b, o1.Host, o1.Port) + doDefaultConnect(b, sub) + sendProto(b, sub, "SUB foo 1\r\n") + flushConnection(b, sub) + + payload := sizedString(size) + + pub := createClientConn(b, o2.Host, o2.Port) + doDefaultConnect(b, pub) + bw := bufio.NewWriterSize(pub, defaultSendBufSize) + + ch := make(chan bool) + sendOp := []byte(fmt.Sprintf("PUB foo %d\r\n%s\r\n", len(payload), payload)) + expected := len(fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)) * b.N + go drainConnection(b, sub, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Fatalf("Received error on PUB write: %v\n", err) + } + + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + pub.Close() + sub.Close() +} + +func Benchmark____RoutedPubSub_0b(b *testing.B) { + routePubSub(b, 2) +} + +func Benchmark____RoutedPubSub_1K(b *testing.B) { + routePubSub(b, 1024) +} + +func Benchmark__RoutedPubSub_100K(b *testing.B) { + routePubSub(b, 100*1024) +} + +func routeQueue(b *testing.B, numQueueSubs, size int) { + b.StopTimer() + + s1, o1 := RunServerWithConfig("./configs/srv_a.conf") + defer s1.Shutdown() + s2, o2 := RunServerWithConfig("./configs/srv_b.conf") + defer s2.Shutdown() + + sub := createClientConn(b, o1.Host, o1.Port) + doDefaultConnect(b, sub) + for i := 0; i < numQueueSubs; i++ { + sendProto(b, sub, fmt.Sprintf("SUB foo bar %d\r\n", 100+i)) + } + flushConnection(b, sub) + + payload := sizedString(size) + + pub := createClientConn(b, o2.Host, o2.Port) + doDefaultConnect(b, pub) + bw := bufio.NewWriterSize(pub, defaultSendBufSize) + + ch := make(chan bool) + sendOp := []byte(fmt.Sprintf("PUB foo %d\r\n%s\r\n", len(payload), payload)) + expected := len(fmt.Sprintf("MSG foo 100 %d\r\n%s\r\n", len(payload), payload)) * b.N + go drainConnection(b, sub, ch, expected) + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Fatalf("Received error on PUB write: %v\n", err) + } + + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connection to be drained + <-ch + + b.StopTimer() + pub.Close() + sub.Close() +} + +func Benchmark____Routed2QueueSub(b *testing.B) { + routeQueue(b, 2, 2) +} + +func Benchmark____Routed4QueueSub(b *testing.B) { + routeQueue(b, 4, 2) +} + +func Benchmark____Routed8QueueSub(b *testing.B) { + routeQueue(b, 8, 2) +} + +func Benchmark___Routed16QueueSub(b *testing.B) { + routeQueue(b, 16, 2) +} + +func doFanout(b *testing.B, numServers, numConnections, subsPerConnection int, subject, payload string) { + var s1, s2 *server.Server + var o1, o2 *server.Options + + switch numServers { + case 1: + s1, o1 = RunServerWithConfig("./configs/srv_a.conf") + defer s1.Shutdown() + s2, o2 = s1, o1 + case 2: + s1, o1 = RunServerWithConfig("./configs/srv_a.conf") + defer s1.Shutdown() + s2, o2 = RunServerWithConfig("./configs/srv_b.conf") + defer s2.Shutdown() + default: + b.Fatalf("%d servers not supported for this test\n", numServers) + } + + // To get a consistent length sid in MSG sent to us for drainConnection. + var sidFloor int + switch { + case subsPerConnection <= 100: + sidFloor = 100 + case subsPerConnection <= 1000: + sidFloor = 1000 + case subsPerConnection <= 10000: + sidFloor = 10000 + default: + b.Fatalf("Unsupported SubsPerConnection argument of %d\n", subsPerConnection) + } + + msgOp := fmt.Sprintf("MSG %s %d %d\r\n%s\r\n", subject, sidFloor, len(payload), payload) + expected := len(msgOp) * subsPerConnection * b.N + + // Client connections and subscriptions. + clients := make([]chan bool, 0, numConnections) + for i := 0; i < numConnections; i++ { + c := createClientConn(b, o2.Host, o2.Port) + doDefaultConnect(b, c) + defer c.Close() + + ch := make(chan bool) + clients = append(clients, ch) + + for s := 0; s < subsPerConnection; s++ { + subOp := fmt.Sprintf("SUB %s %d\r\n", subject, sidFloor+s) + sendProto(b, c, subOp) + } + flushConnection(b, c) + go drainConnection(b, c, ch, expected) + } + // Publish Connection + c := createClientConn(b, o1.Host, o1.Port) + doDefaultConnect(b, c) + bw := bufio.NewWriterSize(c, defaultSendBufSize) + sendOp := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", subject, len(payload), payload)) + flushConnection(b, c) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := bw.Write(sendOp) + if err != nil { + b.Errorf("Received error on PUB write: %v\n", err) + } + } + err := bw.Flush() + if err != nil { + b.Errorf("Received error on FLUSH write: %v\n", err) + } + + // Wait for connections to be drained + for i := 0; i < numConnections; i++ { + <-clients[i] + } + b.StopTimer() +} + +var sub = "x" +var payload = "12345678" + +func Benchmark___FanOut_512x1kx1k(b *testing.B) { + doFanout(b, 1, 1000, 1000, sub, sizedString(512)) +} + +func Benchmark__FanOut_8x1000x100(b *testing.B) { + doFanout(b, 1, 1000, 100, sub, payload) +} + +func Benchmark______FanOut_8x1x10(b *testing.B) { + doFanout(b, 1, 1, 10, sub, payload) +} + +func Benchmark_____FanOut_8x1x100(b *testing.B) { + doFanout(b, 1, 1, 100, sub, payload) +} + +func Benchmark____FanOut_8x10x100(b *testing.B) { + doFanout(b, 1, 10, 100, sub, payload) +} + +func Benchmark___FanOut_8x10x1000(b *testing.B) { + doFanout(b, 1, 10, 1000, sub, payload) +} + +func Benchmark___FanOut_8x100x100(b *testing.B) { + doFanout(b, 1, 100, 100, sub, payload) +} + +func Benchmark__FanOut_8x100x1000(b *testing.B) { + doFanout(b, 1, 100, 1000, sub, payload) +} + +func Benchmark__FanOut_8x10x10000(b *testing.B) { + doFanout(b, 1, 10, 10000, sub, payload) +} + +func Benchmark__FanOut_1kx10x1000(b *testing.B) { + doFanout(b, 1, 10, 1000, sub, sizedString(1024)) +} + +func Benchmark_____RFanOut_8x1x10(b *testing.B) { + doFanout(b, 2, 1, 10, sub, payload) +} + +func Benchmark____RFanOut_8x1x100(b *testing.B) { + doFanout(b, 2, 1, 100, sub, payload) +} + +func Benchmark___RFanOut_8x10x100(b *testing.B) { + doFanout(b, 2, 10, 100, sub, payload) +} + +func Benchmark__RFanOut_8x10x1000(b *testing.B) { + doFanout(b, 2, 10, 1000, sub, payload) +} + +func Benchmark__RFanOut_8x100x100(b *testing.B) { + doFanout(b, 2, 100, 100, sub, payload) +} + +func Benchmark_RFanOut_8x100x1000(b *testing.B) { + doFanout(b, 2, 100, 1000, sub, payload) +} + +func Benchmark_RFanOut_8x10x10000(b *testing.B) { + doFanout(b, 2, 10, 10000, sub, payload) +} + +func Benchmark_RFanOut_1kx10x1000(b *testing.B) { + doFanout(b, 2, 10, 1000, sub, sizedString(1024)) +} diff --git a/vendor/github.com/nats-io/gnatsd/test/client_auth_test.go b/vendor/github.com/nats-io/gnatsd/test/client_auth_test.go new file mode 100644 index 00000000..9255c4bd --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/client_auth_test.go @@ -0,0 +1,92 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "io/ioutil" + "os" + "testing" + + "github.com/nats-io/go-nats" +) + +func TestMultipleUserAuth(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/multi_user.conf") + defer srv.Shutdown() + + if opts.Users == nil { + t.Fatal("Expected a user array that is not nil") + } + if len(opts.Users) != 2 { + t.Fatal("Expected a user array that had 2 users") + } + + // Test first user + url := fmt.Sprintf("nats://%s:%s@%s:%d/", + opts.Users[0].Username, + opts.Users[0].Password, + opts.Host, opts.Port) + + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Expected a successful connect, got %v\n", err) + } + defer nc.Close() + + if !nc.AuthRequired() { + t.Fatal("Expected auth to be required for the server") + } + + // Test second user + url = fmt.Sprintf("nats://%s:%s@%s:%d/", + opts.Users[1].Username, + opts.Users[1].Password, + opts.Host, opts.Port) + + nc, err = nats.Connect(url) + if err != nil { + t.Fatalf("Expected a successful connect, got %v\n", err) + } + defer nc.Close() +} + +// Resolves to "test" +const testToken = "$2a$05$3sSWEVA1eMCbV0hWavDjXOx.ClBjI6u1CuUdLqf22cbJjXsnzz8/." + +func TestTokenInConfig(t *testing.T) { + confFileName := "test.conf" + defer os.Remove(confFileName) + content := ` + listen: 127.0.0.1:4567 + authorization={ + token: ` + testToken + ` + timeout: 5 + }` + if err := ioutil.WriteFile(confFileName, []byte(content), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + s, opts := RunServerWithConfig(confFileName) + defer s.Shutdown() + + url := fmt.Sprintf("nats://test@%s:%d/", opts.Host, opts.Port) + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Expected a successful connect, got %v\n", err) + } + defer nc.Close() + if !nc.AuthRequired() { + t.Fatal("Expected auth to be required for the server") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/client_cluster_test.go b/vendor/github.com/nats-io/gnatsd/test/client_cluster_test.go new file mode 100644 index 00000000..503a3e5e --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/client_cluster_test.go @@ -0,0 +1,377 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +func TestServerRestartReSliceIssue(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, optsB.Port) + + // msg to send.. + msg := []byte("Hello World") + + servers := []string{urlA, urlB} + + opts := nats.GetDefaultOptions() + opts.Timeout = (5 * time.Second) + opts.ReconnectWait = (50 * time.Millisecond) + opts.MaxReconnect = 1000 + + numClients := 20 + + reconnects := int32(0) + reconnectsDone := make(chan bool, numClients) + opts.ReconnectedCB = func(nc *nats.Conn) { + atomic.AddInt32(&reconnects, 1) + reconnectsDone <- true + } + + clients := make([]*nats.Conn, numClients) + + // Create 20 random clients. + // Half connected to A and half to B.. + for i := 0; i < numClients; i++ { + opts.Url = servers[i%2] + nc, err := opts.Connect() + if err != nil { + t.Fatalf("Failed to create connection: %v\n", err) + } + clients[i] = nc + defer nc.Close() + + // Create 10 subscriptions each.. + for x := 0; x < 10; x++ { + subject := fmt.Sprintf("foo.%d", (rand.Int()%50)+1) + nc.Subscribe(subject, func(m *nats.Msg) { + // Just eat it.. + }) + } + // Pick one subject to send to.. + subject := fmt.Sprintf("foo.%d", (rand.Int()%50)+1) + go func() { + time.Sleep(10 * time.Millisecond) + for i := 1; i <= 100; i++ { + if err := nc.Publish(subject, msg); err != nil { + return + } + if i%10 == 0 { + time.Sleep(time.Millisecond) + } + } + }() + } + + // Wait for a short bit.. + time.Sleep(20 * time.Millisecond) + + // Restart SrvB + srvB.Shutdown() + srvB = RunServer(optsB) + defer srvB.Shutdown() + + // Check that all expected clients have reconnected + done := false + for i := 0; i < numClients/2 && !done; i++ { + select { + case <-reconnectsDone: + done = true + case <-time.After(3 * time.Second): + t.Fatalf("Expected %d reconnects, got %d\n", numClients/2, reconnects) + } + } + + // Since srvB was restarted, its defer Shutdown() was last, so will + // exectue first, which would cause clients that have reconnected to + // it to try to reconnect (causing delays on Windows). So let's + // explicitly close them here. + // NOTE: With fix of NATS GO client (reconnect loop yields to Close()), + // this change would not be required, however, it still speeeds up + // the test, from more than 7s to less than one. + for i := 0; i < numClients; i++ { + nc := clients[i] + nc.Close() + } +} + +// This will test queue subscriber semantics across a cluster in the presence +// of server restarts. +func TestServerRestartAndQueueSubs(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, optsB.Port) + + // Client options + opts := nats.GetDefaultOptions() + opts.Timeout = (5 * time.Second) + opts.ReconnectWait = (50 * time.Millisecond) + opts.MaxReconnect = 1000 + opts.NoRandomize = true + + // Allow us to block on a reconnect completion. + reconnectsDone := make(chan bool) + opts.ReconnectedCB = func(nc *nats.Conn) { + reconnectsDone <- true + } + + // Helper to wait on a reconnect. + waitOnReconnect := func() { + var rcs int64 + for { + select { + case <-reconnectsDone: + atomic.AddInt64(&rcs, 1) + if rcs >= 2 { + return + } + case <-time.After(2 * time.Second): + t.Fatalf("Expected a reconnect, timedout!\n") + } + } + } + + // Create two clients.. + opts.Servers = []string{urlA} + nc1, err := opts.Connect() + if err != nil { + t.Fatalf("Failed to create connection for nc1: %v\n", err) + } + + opts.Servers = []string{urlB} + nc2, err := opts.Connect() + if err != nil { + t.Fatalf("Failed to create connection for nc2: %v\n", err) + } + + c1, _ := nats.NewEncodedConn(nc1, "json") + defer c1.Close() + c2, _ := nats.NewEncodedConn(nc2, "json") + defer c2.Close() + + // Flusher helper function. + flush := func() { + // Wait for processing. + c1.Flush() + c2.Flush() + // Wait for a short bit for cluster propagation. + time.Sleep(50 * time.Millisecond) + } + + // To hold queue results. + results := make(map[int]int) + var mu sync.Mutex + + // This corresponds to the subsriptions below. + const ExpectedMsgCount = 3 + + // Make sure we got what we needed, 1 msg only and all seqnos accounted for.. + checkResults := func(numSent int) { + mu.Lock() + defer mu.Unlock() + + for i := 0; i < numSent; i++ { + if results[i] != ExpectedMsgCount { + t.Fatalf("Received incorrect number of messages, [%d] vs [%d] for seq: %d\n", results[i], ExpectedMsgCount, i) + } + } + + // Auto reset results map + results = make(map[int]int) + } + + subj := "foo.bar" + qgroup := "workers" + + cb := func(seqno int) { + mu.Lock() + defer mu.Unlock() + results[seqno] = results[seqno] + 1 + } + + // Create queue subscribers + c1.QueueSubscribe(subj, qgroup, cb) + c2.QueueSubscribe(subj, qgroup, cb) + + // Do a wildcard subscription. + c1.Subscribe("foo.*", cb) + c2.Subscribe("foo.*", cb) + + // Wait for processing. + flush() + + sendAndCheckMsgs := func(numToSend int) { + for i := 0; i < numToSend; i++ { + if i%2 == 0 { + c1.Publish(subj, i) + } else { + c2.Publish(subj, i) + } + } + // Wait for processing. + flush() + // Check Results + checkResults(numToSend) + } + + //////////////////////////////////////////////////////////////////////////// + // Base Test + //////////////////////////////////////////////////////////////////////////// + + // Make sure subscriptions are propagated in the cluster + if err := checkExpectedSubs(4, srvA, srvB); err != nil { + t.Fatalf("%v", err) + } + + // Now send 10 messages, from each client.. + sendAndCheckMsgs(10) + + //////////////////////////////////////////////////////////////////////////// + // Now restart SrvA and srvB, re-run test + //////////////////////////////////////////////////////////////////////////// + + srvA.Shutdown() + srvA = RunServer(optsA) + defer srvA.Shutdown() + + srvB.Shutdown() + srvB = RunServer(optsB) + defer srvB.Shutdown() + + waitOnReconnect() + + // Make sure the cluster is reformed + checkClusterFormed(t, srvA, srvB) + + // Make sure subscriptions are propagated in the cluster + if err := checkExpectedSubs(4, srvA, srvB); err != nil { + t.Fatalf("%v", err) + } + + // Now send another 10 messages, from each client.. + sendAndCheckMsgs(10) + + // Since servers are restarted after all client's close defer calls, + // their defer Shutdown() are last, and so will be executed first, + // which would cause clients to try to reconnect on exit, causing + // delays on Windows. So let's explicitly close them here. + c1.Close() + c2.Close() +} + +// This will test request semantics across a route +func TestRequestsAcrossRoutes(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, optsB.Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Failed to create connection for nc1: %v\n", err) + } + defer nc1.Close() + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Failed to create connection for nc2: %v\n", err) + } + defer nc2.Close() + + ec2, _ := nats.NewEncodedConn(nc2, nats.JSON_ENCODER) + + response := []byte("I will help you") + + // Connect responder to srvA + nc1.Subscribe("foo-req", func(m *nats.Msg) { + nc1.Publish(m.Reply, response) + }) + // Make sure the route and the subscription are propagated. + nc1.Flush() + + var resp string + + for i := 0; i < 100; i++ { + if err := ec2.Request("foo-req", i, &resp, 100*time.Millisecond); err != nil { + t.Fatalf("Received an error on Request test [%d]: %s", i, err) + } + } +} + +// This will test request semantics across a route to queues +func TestRequestsAcrossRoutesToQueues(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + urlA := fmt.Sprintf("nats://%s:%d/", optsA.Host, optsA.Port) + urlB := fmt.Sprintf("nats://%s:%d/", optsB.Host, optsB.Port) + + nc1, err := nats.Connect(urlA) + if err != nil { + t.Fatalf("Failed to create connection for nc1: %v\n", err) + } + defer nc1.Close() + + nc2, err := nats.Connect(urlB) + if err != nil { + t.Fatalf("Failed to create connection for nc2: %v\n", err) + } + defer nc2.Close() + + ec1, _ := nats.NewEncodedConn(nc1, nats.JSON_ENCODER) + ec2, _ := nats.NewEncodedConn(nc2, nats.JSON_ENCODER) + + response := []byte("I will help you") + + // Connect one responder to srvA + nc1.QueueSubscribe("foo-req", "booboo", func(m *nats.Msg) { + nc1.Publish(m.Reply, response) + }) + // Make sure the route and the subscription are propagated. + nc1.Flush() + + // Connect the other responder to srvB + nc2.QueueSubscribe("foo-req", "booboo", func(m *nats.Msg) { + nc2.Publish(m.Reply, response) + }) + + var resp string + + for i := 0; i < 100; i++ { + if err := ec2.Request("foo-req", i, &resp, 500*time.Millisecond); err != nil { + t.Fatalf("Received an error on Request test [%d]: %s", i, err) + } + } + + for i := 0; i < 100; i++ { + if err := ec1.Request("foo-req", i, &resp, 500*time.Millisecond); err != nil { + t.Fatalf("Received an error on Request test [%d]: %s", i, err) + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/cluster_test.go b/vendor/github.com/nats-io/gnatsd/test/cluster_test.go new file mode 100644 index 00000000..4bca2fc3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/cluster_test.go @@ -0,0 +1,491 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "errors" + "fmt" + "runtime" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +// Helper function to check that a cluster is formed +func checkClusterFormed(t *testing.T, servers ...*server.Server) { + t.Helper() + expectedNumRoutes := len(servers) - 1 + checkFor(t, 10*time.Second, 100*time.Millisecond, func() error { + for _, s := range servers { + if numRoutes := s.NumRoutes(); numRoutes != expectedNumRoutes { + return fmt.Errorf("Expected %d routes for server %q, got %d", expectedNumRoutes, s.ID(), numRoutes) + } + } + return nil + }) +} + +func checkNumRoutes(t *testing.T, s *server.Server, expected int) { + t.Helper() + checkFor(t, 5*time.Second, 15*time.Millisecond, func() error { + if nr := s.NumRoutes(); nr != expected { + return fmt.Errorf("Expected %v routes, got %v", expected, nr) + } + return nil + }) +} + +// Helper function to check that a server (or list of servers) have the +// expected number of subscriptions. +func checkExpectedSubs(expected int, servers ...*server.Server) error { + var err string + maxTime := time.Now().Add(10 * time.Second) + for time.Now().Before(maxTime) { + err = "" + for _, s := range servers { + if numSubs := int(s.NumSubscriptions()); numSubs != expected { + err = fmt.Sprintf("Expected %d subscriptions for server %q, got %d", expected, s.ID(), numSubs) + break + } + } + if err != "" { + time.Sleep(10 * time.Millisecond) + } else { + break + } + } + if err != "" { + return errors.New(err) + } + return nil +} + +func runServers(t *testing.T) (srvA, srvB *server.Server, optsA, optsB *server.Options) { + srvA, optsA = RunServerWithConfig("./configs/srv_a.conf") + srvB, optsB = RunServerWithConfig("./configs/srv_b.conf") + + checkClusterFormed(t, srvA, srvB) + return +} + +func TestProperServerWithRoutesShutdown(t *testing.T) { + before := runtime.NumGoroutine() + srvA, srvB, _, _ := runServers(t) + srvA.Shutdown() + srvB.Shutdown() + time.Sleep(100 * time.Millisecond) + + after := runtime.NumGoroutine() + delta := after - before + // There may be some finalizers or IO, but in general more than + // 2 as a delta represents a problem. + if delta > 2 { + t.Fatalf("Expected same number of goroutines, %d vs %d\n", before, after) + } +} + +func TestDoubleRouteConfig(t *testing.T) { + srvA, srvB, _, _ := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() +} + +func TestBasicClusterPubSub(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA, expectA := setupConn(t, clientA) + sendA("SUB foo 22\r\n") + sendA("PING\r\n") + expectA(pongRe) + + if err := checkExpectedSubs(1, srvA, srvB); err != nil { + t.Fatalf("%v", err) + } + + sendB, expectB := setupConn(t, clientB) + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + expectMsgs := expectMsgsCommand(t, expectA) + + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "22", "", "2", "ok") +} + +func TestClusterQueueSubs(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA, expectA := setupConn(t, clientA) + sendB, expectB := setupConn(t, clientB) + + expectMsgsA := expectMsgsCommand(t, expectA) + expectMsgsB := expectMsgsCommand(t, expectB) + + // Capture sids for checking later. + qg1SidsA := []string{"1", "2", "3"} + + // Three queue subscribers + for _, sid := range qg1SidsA { + sendA(fmt.Sprintf("SUB foo qg1 %s\r\n", sid)) + } + sendA("PING\r\n") + expectA(pongRe) + + // Make sure the subs have propagated to srvB before continuing + if err := checkExpectedSubs(len(qg1SidsA), srvB); err != nil { + t.Fatalf("%v", err) + } + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Make sure we get only 1. + matches := expectMsgsA(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + + // Capture sids for checking later. + pSids := []string{"4", "5", "6"} + + // Create 3 normal subscribers + for _, sid := range pSids { + sendA(fmt.Sprintf("SUB foo %s\r\n", sid)) + } + + // Create a FWC Subscriber + pSids = append(pSids, "7") + sendA("SUB > 7\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Make sure the subs have propagated to srvB before continuing + if err := checkExpectedSubs(len(qg1SidsA)+len(pSids), srvB); err != nil { + t.Fatalf("%v", err) + } + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Should receive 5. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1SidsA) + checkForPubSids(t, matches, pSids) + + // Send to A + sendA("PUB foo 2\r\nok\r\n") + + // Should receive 5. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1SidsA) + checkForPubSids(t, matches, pSids) + + // Now add queue subscribers to B + qg2SidsB := []string{"1", "2", "3"} + for _, sid := range qg2SidsB { + sendB(fmt.Sprintf("SUB foo qg2 %s\r\n", sid)) + } + sendB("PING\r\n") + expectB(pongRe) + + // Make sure the subs have propagated to srvA before continuing + if err := checkExpectedSubs(len(qg1SidsA)+len(pSids)+len(qg2SidsB), srvA); err != nil { + t.Fatalf("%v", err) + } + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + + // Should receive 1 from B. + matches = expectMsgsB(1) + checkForQueueSid(t, matches, qg2SidsB) + + // Should receive 5 still from A. + matches = expectMsgsA(5) + checkForQueueSid(t, matches, qg1SidsA) + checkForPubSids(t, matches, pSids) + + // Now drop queue subscribers from A + for _, sid := range qg1SidsA { + sendA(fmt.Sprintf("UNSUB %s\r\n", sid)) + } + sendA("PING\r\n") + expectA(pongRe) + + // Make sure the subs have propagated to srvB before continuing + if err := checkExpectedSubs(len(pSids)+len(qg2SidsB), srvB); err != nil { + t.Fatalf("%v", err) + } + + // Send to B + sendB("PUB foo 2\r\nok\r\n") + + // Should receive 1 from B. + matches = expectMsgsB(1) + checkForQueueSid(t, matches, qg2SidsB) + + sendB("PING\r\n") + expectB(pongRe) + + // Should receive 4 now. + matches = expectMsgsA(4) + checkForPubSids(t, matches, pSids) + + // Send to A + sendA("PUB foo 2\r\nok\r\n") + + // Should receive 4 now. + matches = expectMsgsA(4) + checkForPubSids(t, matches, pSids) +} + +// Issue #22 +func TestClusterDoubleMsgs(t *testing.T) { + srvA, srvB, optsA, optsB := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA1 := createClientConn(t, optsA.Host, optsA.Port) + defer clientA1.Close() + + clientA2 := createClientConn(t, optsA.Host, optsA.Port) + defer clientA2.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA1, expectA1 := setupConn(t, clientA1) + sendA2, expectA2 := setupConn(t, clientA2) + sendB, expectB := setupConn(t, clientB) + + expectMsgsA1 := expectMsgsCommand(t, expectA1) + expectMsgsA2 := expectMsgsCommand(t, expectA2) + + // Capture sids for checking later. + qg1SidsA := []string{"1", "2", "3"} + + // Three queue subscribers + for _, sid := range qg1SidsA { + sendA1(fmt.Sprintf("SUB foo qg1 %s\r\n", sid)) + } + sendA1("PING\r\n") + expectA1(pongRe) + + // Make sure the subs have propagated to srvB before continuing + if err := checkExpectedSubs(len(qg1SidsA), srvB); err != nil { + t.Fatalf("%v", err) + } + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + // Make sure we get only 1. + matches := expectMsgsA1(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForQueueSid(t, matches, qg1SidsA) + + // Add a FWC subscriber on A2 + sendA2("SUB > 1\r\n") + sendA2("SUB foo 2\r\n") + sendA2("PING\r\n") + expectA2(pongRe) + pSids := []string{"1", "2"} + + // Make sure the subs have propagated to srvB before continuing + if err := checkExpectedSubs(len(qg1SidsA)+2, srvB); err != nil { + t.Fatalf("%v", err) + } + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + matches = expectMsgsA1(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForQueueSid(t, matches, qg1SidsA) + + matches = expectMsgsA2(2) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForPubSids(t, matches, pSids) + + // Close ClientA1 + clientA1.Close() + + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + matches = expectMsgsA2(2) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkForPubSids(t, matches, pSids) +} + +// This will test that we drop remote sids correctly. +func TestClusterDropsRemoteSids(t *testing.T) { + srvA, srvB, optsA, _ := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + sendA, expectA := setupConn(t, clientA) + + // Add a subscription + sendA("SUB foo 1\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propagation. + time.Sleep(100 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvB, got %d\n", sc) + } + + // Add another subscription + sendA("SUB bar 2\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propagation. + time.Sleep(100 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 2 { + t.Fatalf("Expected two subscriptions for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 2 { + t.Fatalf("Expected two subscriptions for srvB, got %d\n", sc) + } + + // unsubscription + sendA("UNSUB 1\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Wait for propagation. + time.Sleep(100 * time.Millisecond) + + if sc := srvA.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 1 { + t.Fatalf("Expected one subscription for srvB, got %d\n", sc) + } + + // Close the client and make sure we remove subscription state. + clientA.Close() + + // Wait for propagation. + time.Sleep(100 * time.Millisecond) + if sc := srvA.NumSubscriptions(); sc != 0 { + t.Fatalf("Expected no subscriptions for srvA, got %d\n", sc) + } + if sc := srvB.NumSubscriptions(); sc != 0 { + t.Fatalf("Expected no subscriptions for srvB, got %d\n", sc) + } +} + +// This will test that we drop remote sids correctly. +func TestAutoUnsubscribePropagation(t *testing.T) { + srvA, srvB, optsA, _ := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + sendA, expectA := setupConn(t, clientA) + expectMsgs := expectMsgsCommand(t, expectA) + + // We will create subscriptions that will auto-unsubscribe and make sure + // we are not accumulating orphan subscriptions on the other side. + for i := 1; i <= 100; i++ { + sub := fmt.Sprintf("SUB foo %d\r\n", i) + auto := fmt.Sprintf("UNSUB %d 1\r\n", i) + sendA(sub) + sendA(auto) + // This will trip the auto-unsubscribe + sendA("PUB foo 2\r\nok\r\n") + expectMsgs(1) + } + + sendA("PING\r\n") + expectA(pongRe) + + time.Sleep(50 * time.Millisecond) + + // Make sure number of subscriptions on B is correct + if subs := srvB.NumSubscriptions(); subs != 0 { + t.Fatalf("Expected no subscriptions on remote server, got %d\n", subs) + } +} + +func TestAutoUnsubscribePropagationOnClientDisconnect(t *testing.T) { + srvA, srvB, optsA, _ := runServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + cluster := []*server.Server{srvA, srvB} + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + sendA, expectA := setupConn(t, clientA) + + // No subscriptions. Ready to test. + if err := checkExpectedSubs(0, cluster...); err != nil { + t.Fatalf("%v", err) + } + + sendA("SUB foo 1\r\n") + sendA("UNSUB 1 1\r\n") + sendA("PING\r\n") + expectA(pongRe) + + // Waiting cluster subs propagation + if err := checkExpectedSubs(1, cluster...); err != nil { + t.Fatalf("%v", err) + } + + clientA.Close() + + // No subs should be on the cluster when all clients is disconnected + if err := checkExpectedSubs(0, cluster...); err != nil { + t.Fatalf("%v", err) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/cluster_tls_test.go b/vendor/github.com/nats-io/gnatsd/test/cluster_tls_test.go new file mode 100644 index 00000000..e215667f --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/cluster_tls_test.go @@ -0,0 +1,64 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "testing" + + "github.com/nats-io/gnatsd/server" +) + +func runTLSServers(t *testing.T) (srvA, srvB *server.Server, optsA, optsB *server.Options) { + srvA, optsA = RunServerWithConfig("./configs/srv_a_tls.conf") + srvB, optsB = RunServerWithConfig("./configs/srv_b_tls.conf") + checkClusterFormed(t, srvA, srvB) + return +} + +func TestTLSClusterConfig(t *testing.T) { + srvA, srvB, _, _ := runTLSServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() +} + +func TestBasicTLSClusterPubSub(t *testing.T) { + srvA, srvB, optsA, optsB := runTLSServers(t) + defer srvA.Shutdown() + defer srvB.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + clientB := createClientConn(t, optsB.Host, optsB.Port) + defer clientB.Close() + + sendA, expectA := setupConn(t, clientA) + sendA("SUB foo 22\r\n") + sendA("PING\r\n") + expectA(pongRe) + + sendB, expectB := setupConn(t, clientB) + sendB("PUB foo 2\r\nok\r\n") + sendB("PING\r\n") + expectB(pongRe) + + if err := checkExpectedSubs(1, srvA, srvB); err != nil { + t.Fatalf("%v", err) + } + + expectMsgs := expectMsgsCommand(t, expectA) + + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "22", "", "2", "ok") +} diff --git a/vendor/github.com/nats-io/gnatsd/test/fanout_test.go b/vendor/github.com/nats-io/gnatsd/test/fanout_test.go new file mode 100644 index 00000000..c7350100 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/fanout_test.go @@ -0,0 +1,128 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package test + +import ( + "fmt" + "sync" + "testing" + + "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" +) + +// IMPORTANT: Tests in this file are not executed when running with the -race flag. + +// As we look to improve high fanout situations make sure we +// have a test that checks ordering for all subscriptions from a single subscriber. +func TestHighFanoutOrdering(t *testing.T) { + opts := &server.Options{Host: "127.0.0.1", Port: server.RANDOM_PORT} + + s := RunServer(opts) + defer s.Shutdown() + + url := fmt.Sprintf("nats://%s", s.Addr()) + + const ( + nconns = 100 + nsubs = 100 + npubs = 500 + ) + + // make unique + subj := nats.NewInbox() + + var wg sync.WaitGroup + wg.Add(nconns * nsubs) + + for i := 0; i < nconns; i++ { + nc, err := nats.Connect(url) + if err != nil { + t.Fatalf("Expected a successful connect on %d, got %v\n", i, err) + } + + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) { + t.Fatalf("Got an error %v for %+v\n", s, err) + }) + + ec, _ := nats.NewEncodedConn(nc, nats.DEFAULT_ENCODER) + + for y := 0; y < nsubs; y++ { + expected := 0 + ec.Subscribe(subj, func(n int) { + if n != expected { + t.Fatalf("Expected %d but received %d\n", expected, n) + } + expected++ + if expected >= npubs { + wg.Done() + } + }) + } + ec.Flush() + defer ec.Close() + } + + nc, _ := nats.Connect(url) + ec, _ := nats.NewEncodedConn(nc, nats.DEFAULT_ENCODER) + + for i := 0; i < npubs; i++ { + ec.Publish(subj, i) + } + defer ec.Close() + + wg.Wait() +} + +func TestRouteFormTimeWithHighSubscriptions(t *testing.T) { + srvA, optsA := RunServerWithConfig("./configs/srv_a.conf") + defer srvA.Shutdown() + + clientA := createClientConn(t, optsA.Host, optsA.Port) + defer clientA.Close() + + sendA, expectA := setupConn(t, clientA) + + // Now add lots of subscriptions. These will need to be forwarded + // to new routes when they are added. + subsTotal := 100000 + for i := 0; i < subsTotal; i++ { + subject := fmt.Sprintf("FOO.BAR.BAZ.%d", i) + sendA(fmt.Sprintf("SUB %s %d\r\n", subject, i)) + } + sendA("PING\r\n") + expectA(pongRe) + + srvB, _ := RunServerWithConfig("./configs/srv_b.conf") + defer srvB.Shutdown() + + checkClusterFormed(t, srvA, srvB) + + // Now wait for all subscriptions to be processed. + if err := checkExpectedSubs(subsTotal, srvB); err != nil { + // Make sure we are not a slow consumer + // Check for slow consumer status + if srvA.NumSlowConsumers() > 0 { + t.Fatal("Did not receive all subscriptions due to slow consumer") + } else { + t.Fatalf("%v", err) + } + } + // Just double check the slow consumer status. + if srvA.NumSlowConsumers() > 0 { + t.Fatalf("Received a slow consumer notification: %d", srvA.NumSlowConsumers()) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/gosrv_test.go b/vendor/github.com/nats-io/gnatsd/test/gosrv_test.go new file mode 100644 index 00000000..c7940cad --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/gosrv_test.go @@ -0,0 +1,62 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "net" + "runtime" + "testing" + "time" +) + +func TestSimpleGoServerShutdown(t *testing.T) { + base := runtime.NumGoroutine() + opts := DefaultTestOptions + opts.Port = -1 + s := RunServer(&opts) + s.Shutdown() + time.Sleep(100 * time.Millisecond) + delta := (runtime.NumGoroutine() - base) + if delta > 1 { + t.Fatalf("%d Go routines still exist post Shutdown()", delta) + } +} + +func TestGoServerShutdownWithClients(t *testing.T) { + base := runtime.NumGoroutine() + opts := DefaultTestOptions + opts.Port = -1 + s := RunServer(&opts) + addr := s.Addr().(*net.TCPAddr) + for i := 0; i < 50; i++ { + createClientConn(t, "127.0.0.1", addr.Port) + } + s.Shutdown() + // Wait longer for client connections + time.Sleep(1 * time.Second) + delta := (runtime.NumGoroutine() - base) + // There may be some finalizers or IO, but in general more than + // 2 as a delta represents a problem. + if delta > 2 { + t.Fatalf("%d Go routines still exist post Shutdown()", delta) + } +} + +func TestGoServerMultiShutdown(t *testing.T) { + opts := DefaultTestOptions + opts.Port = -1 + s := RunServer(&opts) + s.Shutdown() + s.Shutdown() +} diff --git a/vendor/github.com/nats-io/gnatsd/test/maxpayload_test.go b/vendor/github.com/nats-io/gnatsd/test/maxpayload_test.go new file mode 100644 index 00000000..9fe68782 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/maxpayload_test.go @@ -0,0 +1,99 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "net" + "runtime" + "strings" + "testing" + "time" + + "github.com/nats-io/go-nats" +) + +func TestMaxPayload(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/override.conf") + defer srv.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(fmt.Sprintf("nats://%s/", endpoint)) + if err != nil { + t.Fatalf("Could not connect to server: %v", err) + } + defer nc.Close() + + size := 4 * 1024 * 1024 + big := sizedBytes(size) + err = nc.Publish("foo", big) + + if err != nats.ErrMaxPayload { + t.Fatalf("Expected a Max Payload error") + } + + conn, err := net.DialTimeout("tcp", endpoint, nc.Opts.Timeout) + if err != nil { + t.Fatalf("Could not make a raw connection to the server: %v", err) + } + defer conn.Close() + info := make([]byte, 512) + _, err = conn.Read(info) + if err != nil { + t.Fatalf("Expected an info message to be sent by the server: %s", err) + } + pub := fmt.Sprintf("PUB bar %d\r\n", size) + conn.Write([]byte(pub)) + if err != nil { + t.Fatalf("Could not publish event to the server: %s", err) + } + + errMsg := make([]byte, 35) + _, err = conn.Read(errMsg) + if err != nil { + t.Fatalf("Expected an error message to be sent by the server: %s", err) + } + + if !strings.Contains(string(errMsg), "Maximum Payload Violation") { + t.Errorf("Received wrong error message (%v)\n", string(errMsg)) + } + + // Client proactively omits sending the message so server + // does not close the connection. + if nc.IsClosed() { + t.Errorf("Expected connection to not be closed.") + } + + // On the other hand client which did not proactively omitted + // publishing the bytes following what is suggested by server + // in the info message has its connection closed. + _, err = conn.Write(big) + if err == nil && runtime.GOOS != "windows" { + t.Errorf("Expected error due to maximum payload transgression.") + } + + // On windows, the previous write will not fail because the connection + // is not fully closed at this stage. + if runtime.GOOS == "windows" { + // Issuing a PING and not expecting the PONG. + _, err = conn.Write([]byte("PING\r\n")) + if err == nil { + conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + _, err = conn.Read(big) + if err == nil { + t.Errorf("Expected closed connection due to maximum payload transgression.") + } + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/monitor_test.go b/vendor/github.com/nats-io/gnatsd/test/monitor_test.go new file mode 100644 index 00000000..a0851126 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/monitor_test.go @@ -0,0 +1,739 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" +) + +const CLIENT_PORT = 11422 +const MONITOR_PORT = 11522 + +func runMonitorServer() *server.Server { + resetPreviousHTTPConnections() + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPPort = MONITOR_PORT + opts.HTTPHost = "127.0.0.1" + + return RunServer(&opts) +} + +// Runs a clustered pair of monitor servers for testing the /routez endpoint +func runMonitorServerClusteredPair(t *testing.T) (*server.Server, *server.Server) { + resetPreviousHTTPConnections() + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPPort = MONITOR_PORT + opts.HTTPHost = "127.0.0.1" + opts.Cluster = server.ClusterOpts{Host: "127.0.0.1", Port: 10223} + opts.Routes = server.RoutesFromStr("nats-route://127.0.0.1:10222") + + s1 := RunServer(&opts) + + opts2 := DefaultTestOptions + opts2.Port = CLIENT_PORT + 1 + opts2.HTTPPort = MONITOR_PORT + 1 + opts2.HTTPHost = "127.0.0.1" + opts2.Cluster = server.ClusterOpts{Host: "127.0.0.1", Port: 10222} + opts2.Routes = server.RoutesFromStr("nats-route://127.0.0.1:10223") + + s2 := RunServer(&opts2) + + checkClusterFormed(t, s1, s2) + + return s1, s2 +} + +func runMonitorServerNoHTTPPort() *server.Server { + resetPreviousHTTPConnections() + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPPort = 0 + + return RunServer(&opts) +} + +func resetPreviousHTTPConnections() { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() +} + +// Make sure that we do not run the http server for monitoring unless asked. +func TestNoMonitorPort(t *testing.T) { + s := runMonitorServerNoHTTPPort() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + if resp, err := http.Get(url + "varz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } + if resp, err := http.Get(url + "healthz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } + if resp, err := http.Get(url + "connz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } +} + +// testEndpointDataRace tests a monitoring endpoint for data races by polling +// while client code acts to ensure statistics are updated. It is designed to +// run under the -race flag to catch violations. The caller must start the +// NATS server. +func testEndpointDataRace(endpoint string, t *testing.T) { + var doneWg sync.WaitGroup + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + + // Poll as fast as we can, while creating connections, publishing, + // and subscribing. + clientDone := int64(0) + doneWg.Add(1) + go func() { + for atomic.LoadInt64(&clientDone) == 0 { + resp, err := http.Get(url + endpoint) + if err != nil { + t.Errorf("Expected no error: Got %v\n", err) + } else { + resp.Body.Close() + } + } + doneWg.Done() + }() + + // create connections, subscriptions, and publish messages to + // update the monitor variables. + var conns []net.Conn + for i := 0; i < 50; i++ { + cl := createClientConnSubscribeAndPublish(t) + // keep a few connections around to test monitor variables. + if i%10 == 0 { + conns = append(conns, cl) + } else { + cl.Close() + } + } + atomic.AddInt64(&clientDone, 1) + + // wait for the endpoint polling goroutine to exit + doneWg.Wait() + + // cleanup the conns + for _, cl := range conns { + cl.Close() + } +} + +func TestEndpointDataRaces(t *testing.T) { + // setup a small cluster to test /routez + s1, s2 := runMonitorServerClusteredPair(t) + defer s1.Shutdown() + defer s2.Shutdown() + + // test all of our endpoints + testEndpointDataRace("varz", t) + testEndpointDataRace("connz", t) + testEndpointDataRace("routez", t) + testEndpointDataRace("subsz", t) + testEndpointDataRace("stacksz", t) +} + +func TestVarz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + resp, err := http.Get(url + "varz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + v := server.Varz{} + if err := json.Unmarshal(body, &v); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + // Do some sanity checks on values + if time.Since(v.Start) > 10*time.Second { + t.Fatal("Expected start time to be within 10 seconds.") + } + + cl := createClientConnSubscribeAndPublish(t) + defer cl.Close() + + resp, err = http.Get(url + "varz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + if strings.Contains(string(body), "cluster_port") { + t.Fatal("Varz body contains cluster information when no cluster is defined.") + } + + v = server.Varz{} + if err := json.Unmarshal(body, &v); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + if v.Connections != 1 { + t.Fatalf("Expected Connections of 1, got %v\n", v.Connections) + } + if v.InMsgs != 1 { + t.Fatalf("Expected InMsgs of 1, got %v\n", v.InMsgs) + } + if v.OutMsgs != 1 { + t.Fatalf("Expected OutMsgs of 1, got %v\n", v.OutMsgs) + } + if v.InBytes != 5 { + t.Fatalf("Expected InBytes of 5, got %v\n", v.InBytes) + } + if v.OutBytes != 5 { + t.Fatalf("Expected OutBytes of 5, got %v\n", v.OutBytes) + } + if v.MaxPending != server.MAX_PENDING_SIZE { + t.Fatalf("Expected MaxPending of %d, got %v\n", + server.MAX_PENDING_SIZE, v.MaxPending) + } + if v.WriteDeadline != server.DEFAULT_FLUSH_DEADLINE { + t.Fatalf("Expected WriteDeadline of %d, got %v\n", + server.DEFAULT_FLUSH_DEADLINE, v.WriteDeadline) + } +} + +func TestConnz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + resp, err := http.Get(url + "connz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + c := server.Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + // Test contents.. + if c.NumConns != 0 { + t.Fatalf("Expected 0 connections, got %d\n", c.NumConns) + } + if c.Total != 0 { + t.Fatalf("Expected 0 live connections, got %d\n", c.Total) + } + if c.Conns == nil || len(c.Conns) != 0 { + t.Fatalf("Expected 0 connections in array, got %p\n", c.Conns) + } + + cl := createClientConnSubscribeAndPublish(t) + defer cl.Close() + + resp, err = http.Get(url + "connz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + if c.NumConns != 1 { + t.Fatalf("Expected 1 connection, got %d\n", c.NumConns) + } + if c.Total != 1 { + t.Fatalf("Expected 1 live connection, got %d\n", c.Total) + } + if c.Conns == nil || len(c.Conns) != 1 { + t.Fatalf("Expected 1 connection in array, got %p\n", c.Conns) + } + + if c.Limit != server.DefaultConnListSize { + t.Fatalf("Expected limit of %d, got %v\n", server.DefaultConnListSize, c.Limit) + } + + if c.Offset != 0 { + t.Fatalf("Expected offset of 0, got %v\n", c.Offset) + } + + // Test inside details of each connection + ci := c.Conns[0] + + if ci.Cid == 0 { + t.Fatalf("Expected non-zero cid, got %v\n", ci.Cid) + } + if ci.IP != "127.0.0.1" { + t.Fatalf("Expected \"127.0.0.1\" for IP, got %v\n", ci.IP) + } + if ci.Port == 0 { + t.Fatalf("Expected non-zero port, got %v\n", ci.Port) + } + if ci.NumSubs != 1 { + t.Fatalf("Expected num_subs of 1, got %v\n", ci.NumSubs) + } + if len(ci.Subs) != 0 { + t.Fatalf("Expected subs of 0, got %v\n", ci.Subs) + } + if ci.InMsgs != 1 { + t.Fatalf("Expected InMsgs of 1, got %v\n", ci.InMsgs) + } + if ci.OutMsgs != 1 { + t.Fatalf("Expected OutMsgs of 1, got %v\n", ci.OutMsgs) + } + if ci.InBytes != 5 { + t.Fatalf("Expected InBytes of 1, got %v\n", ci.InBytes) + } + if ci.OutBytes != 5 { + t.Fatalf("Expected OutBytes of 1, got %v\n", ci.OutBytes) + } +} + +func TestTLSConnz(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + rootCAFile := "./configs/certs/ca.pem" + clientCertFile := "./configs/certs/client-cert.pem" + clientKeyFile := "./configs/certs/client-key.pem" + + // Test with secure connection + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nurl := fmt.Sprintf("tls://%s:%s@%s/", opts.Username, opts.Password, endpoint) + nc, err := nats.Connect(nurl, nats.RootCAs(rootCAFile)) + if err != nil { + t.Fatalf("Got an error on Connect with Secure Options: %+v\n", err) + } + defer nc.Close() + ch := make(chan struct{}) + nc.Subscribe("foo", func(m *nats.Msg) { ch <- struct{}{} }) + nc.Publish("foo", []byte("Hello")) + + // Wait for message + <-ch + + url := fmt.Sprintf("https://127.0.0.1:%d/", opts.HTTPSPort) + tlsConfig := &tls.Config{} + caCert, err := ioutil.ReadFile(rootCAFile) + if err != nil { + t.Fatalf("Got error reading RootCA file: %s", err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + tlsConfig.RootCAs = caCertPool + + cert, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile) + if err != nil { + t.Fatalf("Got error reading client certificates: %s", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + transport := &http.Transport{TLSClientConfig: tlsConfig} + httpClient := &http.Client{Transport: transport} + + resp, err := httpClient.Get(url + "connz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + c := server.Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + if c.NumConns != 1 { + t.Fatalf("Expected 1 connection, got %d\n", c.NumConns) + } + if c.Total != 1 { + t.Fatalf("Expected 1 live connection, got %d\n", c.Total) + } + if c.Conns == nil || len(c.Conns) != 1 { + t.Fatalf("Expected 1 connection in array, got %d\n", len(c.Conns)) + } + + // Test inside details of each connection + ci := c.Conns[0] + + if ci.Cid == 0 { + t.Fatalf("Expected non-zero cid, got %v\n", ci.Cid) + } + if ci.IP != "127.0.0.1" { + t.Fatalf("Expected \"127.0.0.1\" for IP, got %v\n", ci.IP) + } + if ci.Port == 0 { + t.Fatalf("Expected non-zero port, got %v\n", ci.Port) + } + if ci.NumSubs != 1 { + t.Fatalf("Expected num_subs of 1, got %v\n", ci.NumSubs) + } + if len(ci.Subs) != 0 { + t.Fatalf("Expected subs of 0, got %v\n", ci.Subs) + } + if ci.InMsgs != 1 { + t.Fatalf("Expected InMsgs of 1, got %v\n", ci.InMsgs) + } + if ci.OutMsgs != 1 { + t.Fatalf("Expected OutMsgs of 1, got %v\n", ci.OutMsgs) + } + if ci.InBytes != 5 { + t.Fatalf("Expected InBytes of 1, got %v\n", ci.InBytes) + } + if ci.OutBytes != 5 { + t.Fatalf("Expected OutBytes of 1, got %v\n", ci.OutBytes) + } + if ci.Start.IsZero() { + t.Fatalf("Expected Start to be valid\n") + } + if ci.Uptime == "" { + t.Fatalf("Expected Uptime to be valid\n") + } + if ci.LastActivity.IsZero() { + t.Fatalf("Expected LastActivity to be valid\n") + } + if ci.LastActivity.UnixNano() < ci.Start.UnixNano() { + t.Fatalf("Expected LastActivity [%v] to be > Start [%v]\n", ci.LastActivity, ci.Start) + } + if ci.Idle == "" { + t.Fatalf("Expected Idle to be valid\n") + } +} + +func TestConnzWithSubs(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + cl := createClientConnSubscribeAndPublish(t) + defer cl.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + resp, err := http.Get(url + "connz?subs=1") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + c := server.Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + // Test inside details of each connection + ci := c.Conns[0] + if len(ci.Subs) != 1 || ci.Subs[0] != "foo" { + t.Fatalf("Expected subs of 1, got %v\n", ci.Subs) + } +} + +func TestConnzWithAuth(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/multi_user.conf") + defer srv.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + curl := fmt.Sprintf("nats://%s:%s@%s/", opts.Users[0].Username, opts.Users[0].Password, endpoint) + nc, err := nats.Connect(curl) + if err != nil { + t.Fatalf("Got an error on Connect: %+v\n", err) + } + defer nc.Close() + + ch := make(chan struct{}) + nc.Subscribe("foo", func(m *nats.Msg) { ch <- struct{}{} }) + nc.Publish("foo", []byte("Hello")) + + // Wait for message + <-ch + + url := fmt.Sprintf("http://127.0.0.1:%d/", opts.HTTPPort) + + resp, err := http.Get(url + "connz?auth=1") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + c := server.Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + // Test that we have authorized_user and its Alice. + ci := c.Conns[0] + if ci.AuthorizedUser != opts.Users[0].Username { + t.Fatalf("Expected authorized_user to be %q, got %q\n", + opts.Users[0].Username, ci.AuthorizedUser) + } + +} + +func TestConnzWithOffsetAndLimit(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + cl1 := createClientConnSubscribeAndPublish(t) + defer cl1.Close() + + cl2 := createClientConnSubscribeAndPublish(t) + defer cl2.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + resp, err := http.Get(url + "connz?offset=1&limit=1") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + c := server.Connz{} + if err := json.Unmarshal(body, &c); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + if c.Limit != 1 { + t.Fatalf("Expected limit of 1, got %v\n", c.Limit) + } + + if c.Offset != 1 { + t.Fatalf("Expected offset of 1, got %v\n", c.Offset) + } + + if len(c.Conns) != 1 { + t.Fatalf("Expected conns of 1, got %v\n", len(c.Conns)) + } +} + +func TestSubsz(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + cl := createClientConnSubscribeAndPublish(t) + defer cl.Close() + + url := fmt.Sprintf("http://127.0.0.1:%d/", MONITOR_PORT) + resp, err := http.Get(url + "subscriptionsz") + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + if resp.StatusCode != 200 { + t.Fatalf("Expected a 200 response, got %d\n", resp.StatusCode) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Got an error reading the body: %v\n", err) + } + + su := server.Subsz{} + if err := json.Unmarshal(body, &su); err != nil { + t.Fatalf("Got an error unmarshalling the body: %v\n", err) + } + + // Do some sanity checks on values + if su.NumSubs != 1 { + t.Fatalf("Expected num_subs of 1, got %v\n", su.NumSubs) + } +} + +func TestHTTPHost(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + // Grab non-127.0.0.1 address and try to use that to connect. + // Should fail. + var ip net.IP + ifaces, _ := net.Interfaces() + for _, i := range ifaces { + addrs, _ := i.Addrs() + for _, addr := range addrs { + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + // Skip loopback/127.0.0.1 or any ipv6 for now. + if ip.IsLoopback() || ip.To4() == nil { + ip = nil + continue + } + break + } + if ip != nil { + break + } + } + if ip == nil { + t.Fatalf("Could not find non-loopback IPV4 address") + } + url := fmt.Sprintf("http://%v:%d/", ip, MONITOR_PORT) + if resp, err := http.Get(url + "varz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } +} + +// Create a connection to test ConnInfo +func createClientConnSubscribeAndPublish(t *testing.T) net.Conn { + cl := createClientConn(t, "127.0.0.1", CLIENT_PORT) + send, expect := setupConn(t, cl) + expectMsgs := expectMsgsCommand(t, expect) + + send("SUB foo 1\r\nPUB foo 5\r\nhello\r\n") + expectMsgs(1) + + return cl +} + +func TestMonitorNoTLSConfig(t *testing.T) { + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPHost = "127.0.0.1" + opts.HTTPSPort = MONITOR_PORT + s := server.New(&opts) + defer s.Shutdown() + // Check with manually starting the monitoring, which should return an error + if err := s.StartMonitoring(); err == nil || !strings.Contains(err.Error(), "TLS") { + t.Fatalf("Expected error about missing TLS config, got %v", err) + } + // Also check by calling Start(), which should produce a fatal error + dl := &dummyLogger{} + s.SetLogger(dl, false, false) + defer s.SetLogger(nil, false, false) + s.Start() + if !strings.Contains(dl.msg, "TLS") { + t.Fatalf("Expected error about missing TLS config, got %v", dl.msg) + } +} + +func TestMonitorErrorOnListen(t *testing.T) { + s := runMonitorServer() + defer s.Shutdown() + + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + 1 + opts.HTTPHost = "127.0.0.1" + opts.HTTPPort = MONITOR_PORT + s2 := server.New(&opts) + defer s2.Shutdown() + if err := s2.StartMonitoring(); err == nil || !strings.Contains(err.Error(), "listen") { + t.Fatalf("Expected error about not able to start listener, got %v", err) + } +} + +func TestMonitorBothPortsConfigured(t *testing.T) { + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPHost = "127.0.0.1" + opts.HTTPPort = MONITOR_PORT + opts.HTTPSPort = MONITOR_PORT + 1 + s := server.New(&opts) + defer s.Shutdown() + if err := s.StartMonitoring(); err == nil || !strings.Contains(err.Error(), "specify both") { + t.Fatalf("Expected error about ports configured, got %v", err) + } +} + +func TestMonitorStop(t *testing.T) { + resetPreviousHTTPConnections() + opts := DefaultTestOptions + opts.Port = CLIENT_PORT + opts.HTTPHost = "127.0.0.1" + opts.HTTPPort = MONITOR_PORT + url := fmt.Sprintf("http://%v:%d/", opts.HTTPHost, MONITOR_PORT) + // Create a server instance and start only the monitoring http server. + s := server.New(&opts) + if err := s.StartMonitoring(); err != nil { + t.Fatalf("Error starting monitoring: %v", err) + } + // Make sure http server is started + resp, err := http.Get(url + "varz") + if err != nil { + t.Fatalf("Error on http request: %v", err) + } + resp.Body.Close() + // Although the server itself was not started (we did not call s.Start()), + // Shutdown() should stop the http server. + s.Shutdown() + // HTTP request should now fail + if resp, err := http.Get(url + "varz"); err == nil { + t.Fatalf("Expected error: Got %+v\n", resp) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/opts_test.go b/vendor/github.com/nats-io/gnatsd/test/opts_test.go new file mode 100644 index 00000000..0275989e --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/opts_test.go @@ -0,0 +1,43 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import "testing" + +func TestServerConfig(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/override.conf") + defer srv.Shutdown() + + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + + sinfo := checkInfoMsg(t, c) + if sinfo.MaxPayload != opts.MaxPayload { + t.Fatalf("Expected max_payload from server, got %d vs %d", + opts.MaxPayload, sinfo.MaxPayload) + } +} + +func TestTLSConfig(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + + sinfo := checkInfoMsg(t, c) + if !sinfo.TLSRequired { + t.Fatal("Expected TLSRequired to be true when configured") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/pedantic_test.go b/vendor/github.com/nats-io/gnatsd/test/pedantic_test.go new file mode 100644 index 00000000..62585336 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/pedantic_test.go @@ -0,0 +1,109 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "testing" + + "github.com/nats-io/gnatsd/server" +) + +func runPedanticServer() *server.Server { + opts := DefaultTestOptions + + opts.NoLog = false + opts.Trace = true + + opts.Port = PROTO_TEST_PORT + return RunServer(&opts) +} + +func TestPedanticSub(t *testing.T) { + s := runPedanticServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send := sendCommand(t, c) + expect := expectCommand(t, c) + doConnect(t, c, false, true, false) + + // Ping should still be same + send("PING\r\n") + expect(pongRe) + + // Test malformed subjects for SUB + // Sub can contain wildcards, but + // subject must still be legit. + + // Empty terminal token + send("SUB foo. 1\r\n") + expect(errRe) + + // Empty beginning token + send("SUB .foo. 1\r\n") + expect(errRe) + + // Empty middle token + send("SUB foo..bar 1\r\n") + expect(errRe) + + // Bad non-terminal FWC + send("SUB foo.>.bar 1\r\n") + buf := expect(errRe) + + // Check that itr is 'Invalid Subject' + matches := errRe.FindAllSubmatch(buf, -1) + if len(matches) != 1 { + t.Fatal("Wanted one overall match") + } + if string(matches[0][1]) != "'Invalid Subject'" { + t.Fatalf("Expected 'Invalid Subject', got %s", string(matches[0][1])) + } +} + +func TestPedanticPub(t *testing.T) { + s := runPedanticServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send := sendCommand(t, c) + expect := expectCommand(t, c) + doConnect(t, c, false, true, false) + + // Ping should still be same + send("PING\r\n") + expect(pongRe) + + // Test malformed subjects for PUB + // PUB subjects can not have wildcards + // This will error in pedantic mode + send("PUB foo.* 2\r\nok\r\n") + expect(errRe) + + send("PUB foo.> 2\r\nok\r\n") + expect(errRe) + + send("PUB foo. 2\r\nok\r\n") + expect(errRe) + + send("PUB .foo 2\r\nok\r\n") + expect(errRe) + + send("PUB foo..* 2\r\nok\r\n") + expect(errRe) +} diff --git a/vendor/github.com/nats-io/gnatsd/test/pid_test.go b/vendor/github.com/nats-io/gnatsd/test/pid_test.go new file mode 100644 index 00000000..a4bd9f31 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/pid_test.go @@ -0,0 +1,55 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "io/ioutil" + "os" + "testing" +) + +func TestPidFile(t *testing.T) { + opts := DefaultTestOptions + + tmpDir, err := ioutil.TempDir("", "_gnatsd") + if err != nil { + t.Fatal("Could not create tmp dir") + } + defer os.RemoveAll(tmpDir) + + file, err := ioutil.TempFile(tmpDir, "gnatsd:pid_") + if err != nil { + t.Fatalf("Unable to create temp file: %v", err) + } + file.Close() + opts.PidFile = file.Name() + + s := RunServer(&opts) + s.Shutdown() + + buf, err := ioutil.ReadFile(opts.PidFile) + if err != nil { + t.Fatalf("Could not read pid_file: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length pid_file") + } + + pid := 0 + fmt.Sscanf(string(buf), "%d", &pid) + if pid != os.Getpid() { + t.Fatalf("Expected pid to be %d, got %d\n", os.Getpid(), pid) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/ping_test.go b/vendor/github.com/nats-io/gnatsd/test/ping_test.go new file mode 100644 index 00000000..ceb14814 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/ping_test.go @@ -0,0 +1,189 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "crypto/tls" + "fmt" + "net" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +const ( + PING_TEST_PORT = 9972 + PING_INTERVAL = 50 * time.Millisecond + PING_MAX = 2 +) + +func runPingServer() *server.Server { + opts := DefaultTestOptions + opts.Port = PING_TEST_PORT + opts.PingInterval = PING_INTERVAL + opts.MaxPingsOut = PING_MAX + return RunServer(&opts) +} + +func TestPingSentToTLSConnection(t *testing.T) { + opts := DefaultTestOptions + opts.Port = PING_TEST_PORT + opts.PingInterval = PING_INTERVAL + opts.MaxPingsOut = PING_MAX + opts.TLSCert = "configs/certs/server-cert.pem" + opts.TLSKey = "configs/certs/server-key.pem" + opts.TLSCaCert = "configs/certs/ca.pem" + + tc := server.TLSConfigOpts{} + tc.CertFile = opts.TLSCert + tc.KeyFile = opts.TLSKey + tc.CaFile = opts.TLSCaCert + + opts.TLSConfig, _ = server.GenTLSConfig(&tc) + opts.TLSTimeout = 5 + s := RunServer(&opts) + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PING_TEST_PORT) + defer c.Close() + + checkInfoMsg(t, c) + c = tls.Client(c, &tls.Config{InsecureSkipVerify: true}) + tlsConn := c.(*tls.Conn) + tlsConn.Handshake() + + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v}\r\n", false, false, true) + sendProto(t, c, cs) + + expect := expectCommand(t, c) + + // Expect the max to be delivered correctly.. + for i := 0; i < PING_MAX; i++ { + time.Sleep(PING_INTERVAL / 2) + expect(pingRe) + } + + // We should get an error from the server + time.Sleep(PING_INTERVAL) + expect(errRe) + + // Server should close the connection at this point.. + time.Sleep(PING_INTERVAL) + c.SetWriteDeadline(time.Now().Add(PING_INTERVAL)) + + var err error + for { + _, err = c.Write([]byte("PING\r\n")) + if err != nil { + break + } + } + c.SetWriteDeadline(time.Time{}) + + if err == nil { + t.Fatal("No error: Expected to have connection closed") + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + t.Fatal("timeout: Expected to have connection closed") + } +} + +func TestPingInterval(t *testing.T) { + s := runPingServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PING_TEST_PORT) + defer c.Close() + + doConnect(t, c, false, false, false) + + expect := expectCommand(t, c) + + // Expect the max to be delivered correctly.. + for i := 0; i < PING_MAX; i++ { + time.Sleep(PING_INTERVAL / 2) + expect(pingRe) + } + + // We should get an error from the server + time.Sleep(PING_INTERVAL) + expect(errRe) + + // Server should close the connection at this point.. + time.Sleep(PING_INTERVAL) + c.SetWriteDeadline(time.Now().Add(PING_INTERVAL)) + + var err error + for { + _, err = c.Write([]byte("PING\r\n")) + if err != nil { + break + } + } + c.SetWriteDeadline(time.Time{}) + + if err == nil { + t.Fatal("No error: Expected to have connection closed") + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + t.Fatal("timeout: Expected to have connection closed") + } +} + +func TestUnpromptedPong(t *testing.T) { + s := runPingServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PING_TEST_PORT) + defer c.Close() + + doConnect(t, c, false, false, false) + + expect := expectCommand(t, c) + + // Send lots of PONGs in a row... + for i := 0; i < 100; i++ { + c.Write([]byte("PONG\r\n")) + } + + // The server should still send the max number of PINGs and then + // close the connection. + for i := 0; i < PING_MAX; i++ { + time.Sleep(PING_INTERVAL / 2) + expect(pingRe) + } + + // We should get an error from the server + time.Sleep(PING_INTERVAL) + expect(errRe) + + // Server should close the connection at this point.. + c.SetWriteDeadline(time.Now().Add(PING_INTERVAL)) + var err error + for { + _, err = c.Write([]byte("PING\r\n")) + if err != nil { + break + } + } + c.SetWriteDeadline(time.Time{}) + + if err == nil { + t.Fatal("No error: Expected to have connection closed") + } + if ne, ok := err.(net.Error); ok && ne.Timeout() { + t.Fatal("timeout: Expected to have connection closed") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/port_test.go b/vendor/github.com/nats-io/gnatsd/test/port_test.go new file mode 100644 index 00000000..1c750384 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/port_test.go @@ -0,0 +1,52 @@ +// Copyright 2014-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "net" + "strconv" + "testing" + + "github.com/nats-io/gnatsd/server" +) + +func TestResolveRandomPort(t *testing.T) { + opts := &server.Options{Host: "127.0.0.1", Port: server.RANDOM_PORT, NoSigs: true} + s := RunServer(opts) + defer s.Shutdown() + + addr := s.Addr() + _, port, err := net.SplitHostPort(addr.String()) + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + + portNum, err := strconv.Atoi(port) + if err != nil { + t.Fatalf("Expected no error: Got %v\n", err) + } + + if portNum == server.DEFAULT_PORT { + t.Fatalf("Expected server to choose a random port\nGot: %d", server.DEFAULT_PORT) + } + + if portNum == server.RANDOM_PORT { + t.Fatalf("Expected server to choose a random port\nGot: %d", server.RANDOM_PORT) + } + + if opts.Port != portNum { + t.Fatalf("Options port (%d) should have been overridden by chosen random port (%d)", + opts.Port, portNum) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/ports_test.go b/vendor/github.com/nats-io/gnatsd/test/ports_test.go new file mode 100644 index 00000000..ec7237b4 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/ports_test.go @@ -0,0 +1,195 @@ +// Copyright 2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "os" + "path" + "strings" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +// waits until a calculated list of listeners is resolved or a timeout +func waitForFile(path string, dur time.Duration) ([]byte, error) { + end := time.Now().Add(dur) + for time.Now().Before(end) { + if _, err := os.Stat(path); os.IsNotExist(err) { + time.Sleep(25 * time.Millisecond) + continue + } else { + return ioutil.ReadFile(path) + } + } + return nil, errors.New("Timeout") +} + +func portFile(dirname string) string { + return path.Join(dirname, fmt.Sprintf("%s_%d.ports", path.Base(os.Args[0]), os.Getpid())) +} + +func TestPortsFile(t *testing.T) { + portFileDir := os.TempDir() + + opts := DefaultTestOptions + opts.PortsFileDir = portFileDir + opts.Port = -1 + opts.HTTPPort = -1 + opts.ProfPort = -1 + opts.Cluster.Port = -1 + + s := RunServer(&opts) + // this for test cleanup in case we fail - will be ignored if server already shutdown + defer s.Shutdown() + + ports := s.PortsInfo(5 * time.Second) + + if ports == nil { + t.Fatal("services failed to start in 5 seconds") + } + + // the pid file should be + portsFile := portFile(portFileDir) + + if portsFile == "" { + t.Fatal("Expected a ports file") + } + + // try to read a file here - the file should be a json + buf, err := waitForFile(portsFile, 5*time.Second) + if err != nil { + t.Fatalf("Could not read ports file: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length ports file") + } + + readPorts := server.Ports{} + json.Unmarshal(buf, &readPorts) + + if len(readPorts.Nats) == 0 || !strings.HasPrefix(readPorts.Nats[0], "nats://") { + t.Fatal("Expected at least one nats url") + } + + if len(readPorts.Monitoring) == 0 || !strings.HasPrefix(readPorts.Monitoring[0], "http://") { + t.Fatal("Expected at least one monitoring url") + } + + if len(readPorts.Cluster) == 0 || !strings.HasPrefix(readPorts.Cluster[0], "nats://") { + t.Fatal("Expected at least one cluster listen url") + } + + if len(readPorts.Profile) == 0 || !strings.HasPrefix(readPorts.Profile[0], "http://") { + t.Fatal("Expected at least one profile listen url") + } + + // testing cleanup + s.Shutdown() + // if we called shutdown, the cleanup code should have kicked + if _, err := os.Stat(portsFile); os.IsNotExist(err) { + // good + } else { + t.Fatalf("the port file %s was not deleted", portsFile) + } +} + +// makes a temp directory with two directories 'A' and 'B' +// the location of the ports file is changed from dir A to dir B. +func TestPortsFileReload(t *testing.T) { + // make a temp dir + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("Error creating temp director (%s): %v", tempDir, err) + } + defer os.RemoveAll(tempDir) + + // make child temp dir A + dirA := path.Join(tempDir, "A") + os.MkdirAll(dirA, 0777) + + // write the config file with a reference to A + config := fmt.Sprintf("ports_file_dir %s\nport -1", dirA) + confPath := path.Join(tempDir, fmt.Sprintf("%d.conf", os.Getpid())) + if err := ioutil.WriteFile(confPath, []byte(config), 0666); err != nil { + t.Fatalf("Error writing ports file (%s): %v", confPath, err) + } + + opts, err := server.ProcessConfigFile(confPath) + if err != nil { + t.Fatalf("Error processing the configuration: %v", err) + } + + s := RunServer(opts) + defer s.Shutdown() + + ports := s.PortsInfo(5 * time.Second) + if ports == nil { + t.Fatal("services failed to start in 5 seconds") + } + + // get the ports file path name + portsFileInA := portFile(dirA) + // the file should be in dirA + if !strings.HasPrefix(portsFileInA, dirA) { + t.Fatalf("expected ports file to be in [%s] but was in [%s]", dirA, portsFileInA) + } + // wait for it + buf, err := waitForFile(portsFileInA, 5*time.Second) + if err != nil { + t.Fatalf("Could not read ports file: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length ports file") + } + + // change the configuration for the ports file to dirB + dirB := path.Join(tempDir, "B") + os.MkdirAll(dirB, 0777) + + config = fmt.Sprintf("ports_file_dir %s\nport -1", dirB) + if err := ioutil.WriteFile(confPath, []byte(config), 0666); err != nil { + t.Fatalf("Error writing ports file (%s): %v", confPath, err) + } + + // reload the server + if err := s.Reload(); err != nil { + t.Fatalf("error reloading server: %v", err) + } + + // wait for the new file to show up + portsFileInB := portFile(dirB) + buf, err = waitForFile(portsFileInB, 5*time.Second) + if !strings.HasPrefix(portsFileInB, dirB) { + t.Fatalf("expected ports file to be in [%s] but was in [%s]", dirB, portsFileInB) + } + if err != nil { + t.Fatalf("Could not read ports file: %v", err) + } + if len(buf) <= 0 { + t.Fatal("Expected a non-zero length ports file") + } + + // the file in dirA should have deleted + if _, err := os.Stat(portsFileInA); os.IsNotExist(err) { + // good + } else { + t.Fatalf("the port file %s was not deleted", portsFileInA) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/proto_test.go b/vendor/github.com/nats-io/gnatsd/test/proto_test.go new file mode 100644 index 00000000..200c49bf --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/proto_test.go @@ -0,0 +1,317 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +const PROTO_TEST_PORT = 9922 + +func runProtoServer() *server.Server { + opts := DefaultTestOptions + opts.Port = PROTO_TEST_PORT + return RunServer(&opts) +} + +func TestProtoBasics(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + expectMsgs := expectMsgsCommand(t, expect) + + // Ping + send("PING\r\n") + expect(pongRe) + + // Single Msg + send("SUB foo 1\r\nPUB foo 5\r\nhello\r\n") + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "1", "", "5", "hello") + + // 2 Messages + send("SUB * 2\r\nPUB foo 2\r\nok\r\n") + matches = expectMsgs(2) + // Could arrive in any order + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkMsg(t, matches[1], "foo", "", "", "2", "ok") +} + +func TestProtoErr(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + + // Make sure we get an error on bad proto + send("ZZZ") + expect(errRe) +} + +func TestUnsubMax(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + expectMsgs := expectMsgsCommand(t, expect) + + send("SUB foo 22\r\n") + send("UNSUB 22 2\r\n") + for i := 0; i < 100; i++ { + send("PUB foo 2\r\nok\r\n") + } + + time.Sleep(50 * time.Millisecond) + + matches := expectMsgs(2) + checkMsg(t, matches[0], "foo", "22", "", "2", "ok") + checkMsg(t, matches[1], "foo", "22", "", "2", "ok") +} + +func TestQueueSub(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + expectMsgs := expectMsgsCommand(t, expect) + + sent := 100 + send("SUB foo qgroup1 22\r\n") + send("SUB foo qgroup1 32\r\n") + for i := 0; i < sent; i++ { + send("PUB foo 2\r\nok\r\n") + } + // Wait for responses + time.Sleep(250 * time.Millisecond) + + matches := expectMsgs(sent) + sids := make(map[string]int) + for _, m := range matches { + sids[string(m[sidIndex])]++ + } + if len(sids) != 2 { + t.Fatalf("Expected only 2 sids, got %d\n", len(sids)) + } + for k, c := range sids { + if c < 35 { + t.Fatalf("Expected ~50 (+-15) msgs for sid:'%s', got %d\n", k, c) + } + } +} + +func TestMultipleQueueSub(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + expectMsgs := expectMsgsCommand(t, expect) + + sent := 100 + send("SUB foo g1 1\r\n") + send("SUB foo g1 2\r\n") + send("SUB foo g2 3\r\n") + send("SUB foo g2 4\r\n") + + for i := 0; i < sent; i++ { + send("PUB foo 2\r\nok\r\n") + } + // Wait for responses + time.Sleep(250 * time.Millisecond) + + matches := expectMsgs(sent * 2) + sids := make(map[string]int) + for _, m := range matches { + sids[string(m[sidIndex])]++ + } + if len(sids) != 4 { + t.Fatalf("Expected 4 sids, got %d\n", len(sids)) + } + for k, c := range sids { + if c < 35 { + t.Fatalf("Expected ~50 (+-15) msgs for '%s', got %d\n", k, c) + } + } +} + +func TestPubToArgState(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + + send("PUBS foo 2\r\nok\r\n") + expect(errRe) +} + +func TestSubToArgState(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + + send("SUBZZZ foo 1\r\n") + expect(errRe) +} + +// Issue #63 +func TestProtoCrash(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := sendCommand(t, c), expectCommand(t, c) + + checkInfoMsg(t, c) + + send("CONNECT {\"verbose\":true,\"tls_required\":false,\"user\":\"test\",\"pedantic\":true,\"pass\":\"password\"}") + + time.Sleep(100 * time.Millisecond) + + send("\r\n") + expect(okRe) +} + +// Issue #136 +func TestDuplicateProtoSub(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + + send("PING\r\n") + expect(pongRe) + + send("SUB foo 1\r\n") + + send("SUB foo 1\r\n") + + ns := 0 + + for i := 0; i < 5; i++ { + ns = int(s.NumSubscriptions()) + if ns == 0 { + time.Sleep(50 * time.Millisecond) + } else { + break + } + } + + if ns != 1 { + t.Fatalf("Expected 1 subscription, got %d\n", ns) + } +} + +func TestIncompletePubArg(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + send, expect := setupConn(t, c) + + size := 10000 + goodBuf := "" + for i := 0; i < size; i++ { + goodBuf += "A" + } + goodBuf += "\r\n" + + badSize := 3371 + badBuf := "" + for i := 0; i < badSize; i++ { + badBuf += "B" + } + // Message is corrupted and since we are still reading from client, + // next PUB accidentally becomes part of the payload of the + // incomplete message thus breaking the protocol. + badBuf2 := "" + for i := 0; i < size; i++ { + badBuf2 += "C" + } + badBuf2 += "\r\n" + + pub := "PUB example 10000\r\n" + send(pub + goodBuf + pub + goodBuf + pub + badBuf + pub + badBuf2) + expect(errRe) +} + +func TestControlLineMaximums(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + send, expect := setupConn(t, c) + + pubTooLong := "PUB foo " + for i := 0; i < 32; i++ { + pubTooLong += "2222222222" + } + send(pubTooLong) + expect(errRe) +} + +func TestServerInfoWithClientAdvertise(t *testing.T) { + opts := DefaultTestOptions + opts.Port = PROTO_TEST_PORT + opts.ClientAdvertise = "me:1" + s := RunServer(&opts) + defer s.Shutdown() + + c := createClientConn(t, opts.Host, PROTO_TEST_PORT) + defer c.Close() + + buf := expectResult(t, c, infoRe) + js := infoRe.FindAllSubmatch(buf, 1)[0][1] + var sinfo server.Info + err := json.Unmarshal(js, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } + if sinfo.Host != "me" || sinfo.Port != 1 { + t.Fatalf("Expected INFO Host:Port to be me:1, got %s:%d", sinfo.Host, sinfo.Port) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/route_discovery_test.go b/vendor/github.com/nats-io/gnatsd/test/route_discovery_test.go new file mode 100644 index 00000000..417993d3 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/route_discovery_test.go @@ -0,0 +1,665 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bufio" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "runtime" + "strconv" + "strings" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" +) + +func runSeedServer(t *testing.T) (*server.Server, *server.Options) { + return RunServerWithConfig("./configs/seed.conf") +} + +func runAuthSeedServer(t *testing.T) (*server.Server, *server.Options) { + return RunServerWithConfig("./configs/auth_seed.conf") +} + +func TestSeedFirstRouteInfo(t *testing.T) { + s, opts := runSeedServer(t) + defer s.Shutdown() + + rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc.Close() + + _, routeExpect := setupRoute(t, rc, opts) + buf := routeExpect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if info.ID != s.ID() { + t.Fatalf("Expected seed's ID %q, got %q", s.ID(), info.ID) + } +} + +func TestSeedMultipleRouteInfo(t *testing.T) { + s, opts := runSeedServer(t) + defer s.Shutdown() + + rc1 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc1.Close() + + rc1ID := "2222" + rc1Port := 22 + rc1Host := "127.0.0.1" + + routeSend1, route1Expect := setupRouteEx(t, rc1, opts, rc1ID) + route1Expect(infoRe) + + // register ourselves via INFO + r1Info := server.Info{ID: rc1ID, Host: rc1Host, Port: rc1Port} + b, _ := json.Marshal(r1Info) + infoJSON := fmt.Sprintf(server.InfoProto, b) + routeSend1(infoJSON) + routeSend1("PING\r\n") + route1Expect(pongRe) + + rc2 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc2.Close() + + rc2ID := "2224" + rc2Port := 24 + rc2Host := "127.0.0.1" + + routeSend2, route2Expect := setupRouteEx(t, rc2, opts, rc2ID) + + hp2 := fmt.Sprintf("nats-route://%s/", net.JoinHostPort(rc2Host, strconv.Itoa(rc2Port))) + + // register ourselves via INFO + r2Info := server.Info{ID: rc2ID, Host: rc2Host, Port: rc2Port} + b, _ = json.Marshal(r2Info) + infoJSON = fmt.Sprintf(server.InfoProto, b) + routeSend2(infoJSON) + + // Now read back the second INFO route1 should receive letting + // it know about route2 + buf := route1Expect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if info.ID != rc2ID { + t.Fatalf("Expected info.ID to be %q, got %q", rc2ID, info.ID) + } + if info.IP == "" { + t.Fatalf("Expected a IP for the implicit route") + } + if info.IP != hp2 { + t.Fatalf("Expected IP Host of %s, got %s\n", hp2, info.IP) + } + + route2Expect(infoRe) + routeSend2("PING\r\n") + route2Expect(pongRe) + + // Now let's do a third. + rc3 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc3.Close() + + rc3ID := "2226" + rc3Port := 26 + rc3Host := "127.0.0.1" + + routeSend3, _ := setupRouteEx(t, rc3, opts, rc3ID) + + // register ourselves via INFO + r3Info := server.Info{ID: rc3ID, Host: rc3Host, Port: rc3Port} + b, _ = json.Marshal(r3Info) + infoJSON = fmt.Sprintf(server.InfoProto, b) + routeSend3(infoJSON) + + // Now read back out the info from the seed route + buf = route1Expect(infoRe) + + info = server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if info.ID != rc3ID { + t.Fatalf("Expected info.ID to be %q, got %q", rc3ID, info.ID) + } + + // Now read back out the info from the seed route + buf = route2Expect(infoRe) + + info = server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if info.ID != rc3ID { + t.Fatalf("Expected info.ID to be %q, got %q", rc3ID, info.ID) + } +} + +func TestSeedSolicitWorks(t *testing.T) { + s1, opts := runSeedServer(t) + defer s1.Shutdown() + + // Create the routes string for others to connect to the seed. + routesStr := fmt.Sprintf("nats-route://%s:%d/", opts.Cluster.Host, opts.Cluster.Port) + + // Run Server #2 + s2Opts := nextServerOpts(opts) + s2Opts.Routes = server.RoutesFromStr(routesStr) + + s2 := RunServer(s2Opts) + defer s2.Shutdown() + + // Run Server #3 + s3Opts := nextServerOpts(s2Opts) + + s3 := RunServer(s3Opts) + defer s3.Shutdown() + + // Wait for a bit for graph to connect + time.Sleep(500 * time.Millisecond) + + // Grab Routez from monitor ports, make sure we are fully connected + url := fmt.Sprintf("http://%s:%d/", opts.Host, opts.HTTPPort) + rz := readHTTPRoutez(t, url) + ris := expectRids(t, rz, []string{s2.ID(), s3.ID()}) + if ris[s2.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s2Opts.Host, s2Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s3.ID()}) + if !ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s3Opts.Host, s3Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s2.ID()}) + if !ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server to be configured\n") + } + if ris[s2.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } +} + +type serverInfo struct { + server *server.Server + opts *server.Options +} + +func checkConnected(t *testing.T, servers []serverInfo, current int, oneSeed bool) error { + s := servers[current] + + // Grab Routez from monitor ports, make sure we are fully connected + url := fmt.Sprintf("http://%s:%d/", s.opts.Host, s.opts.HTTPPort) + rz := readHTTPRoutez(t, url) + total := len(servers) + var ids []string + for i := 0; i < total; i++ { + if i == current { + continue + } + ids = append(ids, servers[i].server.ID()) + } + ris, err := expectRidsNoFatal(t, true, rz, ids) + if err != nil { + return err + } + for i := 0; i < total; i++ { + if i == current { + continue + } + s := servers[i] + if current == 0 || ((oneSeed && i > 0) || (!oneSeed && (i != current-1))) { + if ris[s.server.ID()].IsConfigured { + return fmt.Errorf("Expected server %s:%d not to be configured", s.opts.Host, s.opts.Port) + } + } else if oneSeed || (i == current-1) { + if !ris[s.server.ID()].IsConfigured { + return fmt.Errorf("Expected server %s:%d to be configured", s.opts.Host, s.opts.Port) + } + } + } + return nil +} + +func TestStressSeedSolicitWorks(t *testing.T) { + s1, opts := runSeedServer(t) + defer s1.Shutdown() + + // Create the routes string for others to connect to the seed. + routesStr := fmt.Sprintf("nats-route://%s:%d/", opts.Cluster.Host, opts.Cluster.Port) + + s2Opts := nextServerOpts(opts) + s2Opts.Routes = server.RoutesFromStr(routesStr) + + s3Opts := nextServerOpts(s2Opts) + s4Opts := nextServerOpts(s3Opts) + + for i := 0; i < 10; i++ { + func() { + // Run these servers manually, because we want them to start and + // connect to s1 as fast as possible. + + s2 := server.New(s2Opts) + if s2 == nil { + panic("No NATS Server object returned.") + } + defer s2.Shutdown() + go s2.Start() + + s3 := server.New(s3Opts) + if s3 == nil { + panic("No NATS Server object returned.") + } + defer s3.Shutdown() + go s3.Start() + + s4 := server.New(s4Opts) + if s4 == nil { + panic("No NATS Server object returned.") + } + defer s4.Shutdown() + go s4.Start() + + serversInfo := []serverInfo{{s1, opts}, {s2, s2Opts}, {s3, s3Opts}, {s4, s4Opts}} + + checkFor(t, 5*time.Second, 100*time.Millisecond, func() error { + for j := 0; j < len(serversInfo); j++ { + if err := checkConnected(t, serversInfo, j, true); err != nil { + return err + } + } + return nil + }) + }() + checkNumRoutes(t, s1, 0) + } +} + +func TestChainedSolicitWorks(t *testing.T) { + s1, opts := runSeedServer(t) + defer s1.Shutdown() + + // Create the routes string for others to connect to the seed. + routesStr := fmt.Sprintf("nats-route://%s:%d/", opts.Cluster.Host, opts.Cluster.Port) + + // Run Server #2 + s2Opts := nextServerOpts(opts) + s2Opts.Routes = server.RoutesFromStr(routesStr) + + s2 := RunServer(s2Opts) + defer s2.Shutdown() + + // Run Server #3 + s3Opts := nextServerOpts(s2Opts) + // We will have s3 connect to s2, not the seed. + routesStr = fmt.Sprintf("nats-route://%s:%d/", s2Opts.Cluster.Host, s2Opts.Cluster.Port) + s3Opts.Routes = server.RoutesFromStr(routesStr) + + s3 := RunServer(s3Opts) + defer s3.Shutdown() + + // Wait for a bit for graph to connect + time.Sleep(500 * time.Millisecond) + + // Grab Routez from monitor ports, make sure we are fully connected + url := fmt.Sprintf("http://%s:%d/", opts.Host, opts.HTTPPort) + rz := readHTTPRoutez(t, url) + ris := expectRids(t, rz, []string{s2.ID(), s3.ID()}) + if ris[s2.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s2Opts.Host, s2Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s3.ID()}) + if !ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s3Opts.Host, s3Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s2.ID()}) + if !ris[s2.ID()].IsConfigured { + t.Fatalf("Expected s2 server to be configured\n") + } + if ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server not to be configured\n") + } +} + +func TestStressChainedSolicitWorks(t *testing.T) { + s1, opts := runSeedServer(t) + defer s1.Shutdown() + + // Create the routes string for s2 to connect to the seed + routesStr := fmt.Sprintf("nats-route://%s:%d/", opts.Cluster.Host, opts.Cluster.Port) + s2Opts := nextServerOpts(opts) + s2Opts.Routes = server.RoutesFromStr(routesStr) + + s3Opts := nextServerOpts(s2Opts) + // Create the routes string for s3 to connect to s2 + routesStr = fmt.Sprintf("nats-route://%s:%d/", s2Opts.Cluster.Host, s2Opts.Cluster.Port) + s3Opts.Routes = server.RoutesFromStr(routesStr) + + s4Opts := nextServerOpts(s3Opts) + // Create the routes string for s4 to connect to s3 + routesStr = fmt.Sprintf("nats-route://%s:%d/", s3Opts.Cluster.Host, s3Opts.Cluster.Port) + s4Opts.Routes = server.RoutesFromStr(routesStr) + + for i := 0; i < 10; i++ { + func() { + // Run these servers manually, because we want them to start and + // connect to s1 as fast as possible. + + s2 := server.New(s2Opts) + if s2 == nil { + panic("No NATS Server object returned.") + } + defer s2.Shutdown() + go s2.Start() + + s3 := server.New(s3Opts) + if s3 == nil { + panic("No NATS Server object returned.") + } + defer s3.Shutdown() + go s3.Start() + + s4 := server.New(s4Opts) + if s4 == nil { + panic("No NATS Server object returned.") + } + defer s4.Shutdown() + go s4.Start() + + serversInfo := []serverInfo{{s1, opts}, {s2, s2Opts}, {s3, s3Opts}, {s4, s4Opts}} + + checkFor(t, 5*time.Second, 100*time.Millisecond, func() error { + for j := 0; j < len(serversInfo); j++ { + if err := checkConnected(t, serversInfo, j, false); err != nil { + return err + } + } + return nil + }) + }() + checkNumRoutes(t, s1, 0) + } +} + +func TestAuthSeedSolicitWorks(t *testing.T) { + s1, opts := runAuthSeedServer(t) + defer s1.Shutdown() + + // Create the routes string for others to connect to the seed. + routesStr := fmt.Sprintf("nats-route://%s:%s@%s:%d/", opts.Cluster.Username, opts.Cluster.Password, opts.Cluster.Host, opts.Cluster.Port) + + // Run Server #2 + s2Opts := nextServerOpts(opts) + s2Opts.Routes = server.RoutesFromStr(routesStr) + + s2 := RunServer(s2Opts) + defer s2.Shutdown() + + // Run Server #3 + s3Opts := nextServerOpts(s2Opts) + + s3 := RunServer(s3Opts) + defer s3.Shutdown() + + // Wait for a bit for graph to connect + time.Sleep(500 * time.Millisecond) + + // Grab Routez from monitor ports, make sure we are fully connected + url := fmt.Sprintf("http://%s:%d/", opts.Host, opts.HTTPPort) + rz := readHTTPRoutez(t, url) + ris := expectRids(t, rz, []string{s2.ID(), s3.ID()}) + if ris[s2.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s2Opts.Host, s2Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s3.ID()}) + if !ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server to be configured\n") + } + if ris[s3.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } + + url = fmt.Sprintf("http://%s:%d/", s3Opts.Host, s3Opts.HTTPPort) + rz = readHTTPRoutez(t, url) + ris = expectRids(t, rz, []string{s1.ID(), s2.ID()}) + if !ris[s1.ID()].IsConfigured { + t.Fatalf("Expected seed server to be configured\n") + } + if ris[s2.ID()].IsConfigured { + t.Fatalf("Expected server not to be configured\n") + } +} + +// Helper to check for correct route memberships +func expectRids(t *testing.T, rz *server.Routez, rids []string) map[string]*server.RouteInfo { + ri, err := expectRidsNoFatal(t, false, rz, rids) + if err != nil { + t.Fatalf("%v", err) + } + return ri +} + +func expectRidsNoFatal(t *testing.T, direct bool, rz *server.Routez, rids []string) (map[string]*server.RouteInfo, error) { + caller := 1 + if !direct { + caller++ + } + if len(rids) != rz.NumRoutes { + _, fn, line, _ := runtime.Caller(caller) + return nil, fmt.Errorf("[%s:%d] Expecting %d routes, got %d\n", fn, line, len(rids), rz.NumRoutes) + } + set := make(map[string]bool) + for _, v := range rids { + set[v] = true + } + // Make result map for additional checking + ri := make(map[string]*server.RouteInfo) + for _, r := range rz.Routes { + if !set[r.RemoteID] { + _, fn, line, _ := runtime.Caller(caller) + return nil, fmt.Errorf("[%s:%d] Route with rid %s unexpected, expected %+v\n", fn, line, r.RemoteID, rids) + } + ri[r.RemoteID] = r + } + return ri, nil +} + +// Helper to easily grab routez info. +func readHTTPRoutez(t *testing.T, url string) *server.Routez { + resetPreviousHTTPConnections() + resp, err := http.Get(url + "routez") + if err != nil { + stackFatalf(t, "Expected no error: Got %v\n", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + stackFatalf(t, "Expected a 200 response, got %d\n", resp.StatusCode) + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + stackFatalf(t, "Got an error reading the body: %v\n", err) + } + r := server.Routez{} + if err := json.Unmarshal(body, &r); err != nil { + stackFatalf(t, "Got an error unmarshalling the body: %v\n", err) + } + return &r +} + +func TestSeedReturnIPInInfo(t *testing.T) { + s, opts := runSeedServer(t) + defer s.Shutdown() + + rc1 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc1.Close() + + rc1ID := "2222" + rc1Port := 22 + rc1Host := "127.0.0.1" + + routeSend1, route1Expect := setupRouteEx(t, rc1, opts, rc1ID) + route1Expect(infoRe) + + // register ourselves via INFO + r1Info := server.Info{ID: rc1ID, Host: rc1Host, Port: rc1Port} + b, _ := json.Marshal(r1Info) + infoJSON := fmt.Sprintf(server.InfoProto, b) + routeSend1(infoJSON) + routeSend1("PING\r\n") + route1Expect(pongRe) + + rc2 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc2.Close() + + rc2ID := "2224" + rc2Port := 24 + rc2Host := "127.0.0.1" + + routeSend2, _ := setupRouteEx(t, rc2, opts, rc2ID) + + // register ourselves via INFO + r2Info := server.Info{ID: rc2ID, Host: rc2Host, Port: rc2Port} + b, _ = json.Marshal(r2Info) + infoJSON = fmt.Sprintf(server.InfoProto, b) + routeSend2(infoJSON) + + // Now read info that route1 should have received from the seed + buf := route1Expect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if info.IP == "" { + t.Fatal("Expected to have IP in INFO") + } + rip, _, err := net.SplitHostPort(strings.TrimPrefix(info.IP, "nats-route://")) + if err != nil { + t.Fatalf("Error parsing url: %v", err) + } + addr, ok := rc1.RemoteAddr().(*net.TCPAddr) + if !ok { + t.Fatal("Unable to get IP address from route") + } + s1 := strings.ToLower(addr.IP.String()) + s2 := strings.ToLower(rip) + if s1 != s2 { + t.Fatalf("Expected IP %s, got %s", s1, s2) + } +} + +func TestImplicitRouteRetry(t *testing.T) { + srvSeed, optsSeed := runSeedServer(t) + defer srvSeed.Shutdown() + + optsA := nextServerOpts(optsSeed) + optsA.Routes = server.RoutesFromStr(fmt.Sprintf("nats://%s:%d", optsSeed.Cluster.Host, optsSeed.Cluster.Port)) + optsA.Cluster.ConnectRetries = 5 + srvA := RunServer(optsA) + defer srvA.Shutdown() + + optsB := nextServerOpts(optsA) + rcb := createRouteConn(t, optsSeed.Cluster.Host, optsSeed.Cluster.Port) + defer rcb.Close() + rcbID := "ServerB" + routeBSend, routeBExpect := setupRouteEx(t, rcb, optsB, rcbID) + routeBExpect(infoRe) + // register ourselves via INFO + rbInfo := server.Info{ID: rcbID, Host: optsB.Cluster.Host, Port: optsB.Cluster.Port} + b, _ := json.Marshal(rbInfo) + infoJSON := fmt.Sprintf(server.InfoProto, b) + routeBSend(infoJSON) + routeBSend("PING\r\n") + routeBExpect(pongRe) + + // srvA should try to connect. Wait to make sure that it fails. + time.Sleep(1200 * time.Millisecond) + + // Setup a fake route listen for routeB + rbListen, err := net.Listen("tcp", fmt.Sprintf("%s:%d", optsB.Cluster.Host, optsB.Cluster.Port)) + if err != nil { + t.Fatalf("Error during listen: %v", err) + } + c, err := rbListen.Accept() + if err != nil { + t.Fatalf("Error during accept: %v", err) + } + defer c.Close() + + br := bufio.NewReaderSize(c, 32768) + // Consume CONNECT and INFO + for i := 0; i < 2; i++ { + c.SetReadDeadline(time.Now().Add(2 * time.Second)) + buf, _, err := br.ReadLine() + c.SetReadDeadline(time.Time{}) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + if i == 0 { + continue + } + buf = buf[len("INFO "):] + info := &server.Info{} + if err := json.Unmarshal(buf, info); err != nil { + t.Fatalf("Error during unmarshal: %v", err) + } + // Check INFO is from server A. + if info.ID != srvA.ID() { + t.Fatalf("Expected CONNECT from %v, got CONNECT from %v", srvA.ID(), info.ID) + } + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/routes_test.go b/vendor/github.com/nats-io/gnatsd/test/routes_test.go new file mode 100644 index 00000000..3328f649 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/routes_test.go @@ -0,0 +1,1052 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" +) + +const clientProtoInfo = 1 + +func runRouteServer(t *testing.T) (*server.Server, *server.Options) { + return RunServerWithConfig("./configs/cluster.conf") +} + +func TestRouterListeningSocket(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + // Check that the cluster socket is able to be connected. + addr := fmt.Sprintf("%s:%d", opts.Cluster.Host, opts.Cluster.Port) + checkSocket(t, addr, 2*time.Second) +} + +func TestRouteGoServerShutdown(t *testing.T) { + base := runtime.NumGoroutine() + s, _ := runRouteServer(t) + s.Shutdown() + time.Sleep(50 * time.Millisecond) + delta := (runtime.NumGoroutine() - base) + if delta > 1 { + t.Fatalf("%d Go routines still exist post Shutdown()", delta) + } +} + +func TestSendRouteInfoOnConnect(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc.Close() + + routeSend, routeExpect := setupRoute(t, rc, opts) + buf := routeExpect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if !info.AuthRequired { + t.Fatal("Expected to see AuthRequired") + } + if info.Port != opts.Cluster.Port { + t.Fatalf("Received wrong information for port, expected %d, got %d", + info.Port, opts.Cluster.Port) + } + + // Need to send a different INFO than the one received, otherwise the server + // will detect as a "cycle" and close the connection. + info.ID = "RouteID" + b, err := json.Marshal(info) + if err != nil { + t.Fatalf("Could not marshal test route info: %v", err) + } + infoJSON := fmt.Sprintf("INFO %s\r\n", b) + routeSend(infoJSON) + routeSend("PING\r\n") + routeExpect(pongRe) +} + +func TestRouteToSelf(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc.Close() + + routeSend, routeExpect := setupRouteEx(t, rc, opts, s.ID()) + buf := routeExpect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + + if !info.AuthRequired { + t.Fatal("Expected to see AuthRequired") + } + if info.Port != opts.Cluster.Port { + t.Fatalf("Received wrong information for port, expected %d, got %d", + info.Port, opts.Cluster.Port) + } + + // Now send it back and that should be detected as a route to self and the + // connection closed. + routeSend(string(buf)) + routeSend("PING\r\n") + rc.SetReadDeadline(time.Now().Add(2 * time.Second)) + if _, err := rc.Read(buf); err == nil { + t.Fatal("Expected route connection to be closed") + } +} + +func TestSendRouteSubAndUnsub(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + + send, _ := setupConn(t, c) + + // We connect to the route. + rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer rc.Close() + + expectAuthRequired(t, rc) + routeSend, routeExpect := setupRouteEx(t, rc, opts, "ROUTER:xyz") + routeSend("INFO {\"server_id\":\"ROUTER:xyz\"}\r\n") + + routeSend("PING\r\n") + routeExpect(pongRe) + + // Send SUB via client connection + send("SUB foo 22\r\n") + + // Make sure the SUB is broadcast via the route + buf := expectResult(t, rc, subRe) + matches := subRe.FindAllSubmatch(buf, -1) + rsid := string(matches[0][5]) + if !strings.HasPrefix(rsid, "RSID:") { + t.Fatalf("Got wrong RSID: %s\n", rsid) + } + + // Send UNSUB via client connection + send("UNSUB 22\r\n") + + // Make sure the SUB is broadcast via the route + buf = expectResult(t, rc, unsubRe) + matches = unsubRe.FindAllSubmatch(buf, -1) + rsid2 := string(matches[0][1]) + + if rsid2 != rsid { + t.Fatalf("Expected rsid's to match. %q vs %q\n", rsid, rsid2) + } + + // Explicitly shutdown the server, otherwise this test would + // cause following test to fail. + s.Shutdown() +} + +func TestSendRouteSolicit(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + // Listen for a connection from the server on the first route. + if len(opts.Routes) <= 0 { + t.Fatalf("Need an outbound solicted route for this test") + } + rURL := opts.Routes[0] + + conn := acceptRouteConn(t, rURL.Host, server.DEFAULT_ROUTE_CONNECT) + defer conn.Close() + + // We should receive a connect message right away due to auth. + buf := expectResult(t, conn, connectRe) + + // Check INFO follows. Could be inline, with first result, if not + // check follow-on buffer. + if !infoRe.Match(buf) { + expectResult(t, conn, infoRe) + } +} + +func TestRouteForwardsMsgFromClients(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + client := createClientConn(t, opts.Host, opts.Port) + defer client.Close() + + clientSend, clientExpect := setupConn(t, client) + + route := acceptRouteConn(t, opts.Routes[0].Host, server.DEFAULT_ROUTE_CONNECT) + defer route.Close() + + routeSend, routeExpect := setupRoute(t, route, opts) + expectMsgs := expectMsgsCommand(t, routeExpect) + + // Eat the CONNECT and INFO protos + buf := routeExpect(connectRe) + if !infoRe.Match(buf) { + routeExpect(infoRe) + } + + // Send SUB via route connection + routeSend("SUB foo RSID:2:22\r\n") + routeSend("PING\r\n") + routeExpect(pongRe) + + // Send PUB via client connection + clientSend("PUB foo 2\r\nok\r\n") + clientSend("PING\r\n") + clientExpect(pongRe) + + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "RSID:2:22", "", "2", "ok") +} + +func TestRouteForwardsMsgToClients(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + client := createClientConn(t, opts.Host, opts.Port) + defer client.Close() + + clientSend, clientExpect := setupConn(t, client) + expectMsgs := expectMsgsCommand(t, clientExpect) + + route := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + expectAuthRequired(t, route) + routeSend, _ := setupRoute(t, route, opts) + + // Subscribe to foo + clientSend("SUB foo 1\r\n") + // Use ping roundtrip to make sure its processed. + clientSend("PING\r\n") + clientExpect(pongRe) + + // Send MSG proto via route connection + routeSend("MSG foo 1 2\r\nok\r\n") + + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "1", "", "2", "ok") +} + +func TestRouteOneHopSemantics(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + route := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + + expectAuthRequired(t, route) + routeSend, _ := setupRoute(t, route, opts) + + // Express interest on this route for foo. + routeSend("SUB foo RSID:2:2\r\n") + + // Send MSG proto via route connection + routeSend("MSG foo 1 2\r\nok\r\n") + + // Make sure it does not come back! + expectNothing(t, route) +} + +func TestRouteOnlySendOnce(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + client := createClientConn(t, opts.Host, opts.Port) + defer client.Close() + + clientSend, clientExpect := setupConn(t, client) + + route := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + + expectAuthRequired(t, route) + routeSend, routeExpect := setupRoute(t, route, opts) + expectMsgs := expectMsgsCommand(t, routeExpect) + + // Express multiple interest on this route for foo. + routeSend("SUB foo RSID:2:1\r\n") + routeSend("SUB foo RSID:2:2\r\n") + routeSend("PING\r\n") + routeExpect(pongRe) + + // Send PUB via client connection + clientSend("PUB foo 2\r\nok\r\n") + clientSend("PING\r\n") + clientExpect(pongRe) + + expectMsgs(1) + routeSend("PING\r\n") + routeExpect(pongRe) +} + +func TestRouteQueueSemantics(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + client := createClientConn(t, opts.Host, opts.Port) + clientSend, clientExpect := setupConn(t, client) + clientExpectMsgs := expectMsgsCommand(t, clientExpect) + + defer client.Close() + + route := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + + expectAuthRequired(t, route) + routeSend, routeExpect := setupRouteEx(t, route, opts, "ROUTER:xyz") + routeSend("INFO {\"server_id\":\"ROUTER:xyz\"}\r\n") + expectMsgs := expectMsgsCommand(t, routeExpect) + + // Express multiple interest on this route for foo, queue group bar. + qrsid1 := "QRSID:1:1" + routeSend(fmt.Sprintf("SUB foo bar %s\r\n", qrsid1)) + qrsid2 := "QRSID:1:2" + routeSend(fmt.Sprintf("SUB foo bar %s\r\n", qrsid2)) + + // Use ping roundtrip to make sure its processed. + routeSend("PING\r\n") + routeExpect(pongRe) + + // Send PUB via client connection + clientSend("PUB foo 2\r\nok\r\n") + // Use ping roundtrip to make sure its processed. + clientSend("PING\r\n") + clientExpect(pongRe) + + // Only 1 + matches := expectMsgs(1) + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + + // Add normal Interest as well to route interest. + routeSend("SUB foo RSID:1:4\r\n") + + // Use ping roundtrip to make sure its processed. + routeSend("PING\r\n") + routeExpect(pongRe) + + // Send PUB via client connection + clientSend("PUB foo 2\r\nok\r\n") + // Use ping roundtrip to make sure its processed. + clientSend("PING\r\n") + clientExpect(pongRe) + + // Should be 2 now, 1 for all normal, and one for specific queue subscriber. + matches = expectMsgs(2) + + // Expect first to be the normal subscriber, next will be the queue one. + if string(matches[0][sidIndex]) != "RSID:1:4" && + string(matches[1][sidIndex]) != "RSID:1:4" { + t.Fatalf("Did not received routed sid\n") + } + checkMsg(t, matches[0], "foo", "", "", "2", "ok") + checkMsg(t, matches[1], "foo", "", "", "2", "ok") + + // Check the rsid to verify it is one of the queue group subscribers. + var rsid string + if matches[0][sidIndex][0] == 'Q' { + rsid = string(matches[0][sidIndex]) + } else { + rsid = string(matches[1][sidIndex]) + } + if rsid != qrsid1 && rsid != qrsid2 { + t.Fatalf("Expected a queue group rsid, got %s\n", rsid) + } + + // Now create a queue subscription for the client as well as a normal one. + clientSend("SUB foo 1\r\n") + // Use ping roundtrip to make sure its processed. + clientSend("PING\r\n") + clientExpect(pongRe) + routeExpect(subRe) + + clientSend("SUB foo bar 2\r\n") + // Use ping roundtrip to make sure its processed. + clientSend("PING\r\n") + clientExpect(pongRe) + routeExpect(subRe) + + // Deliver a MSG from the route itself, make sure the client receives both. + routeSend("MSG foo RSID:1:1 2\r\nok\r\n") + // Queue group one. + routeSend("MSG foo QRSID:1:2 2\r\nok\r\n") + // Invlaid queue sid. + routeSend("MSG foo QRSID 2\r\nok\r\n") // cid and sid missing + routeSend("MSG foo QRSID:1 2\r\nok\r\n") // cid not terminated with ':' + routeSend("MSG foo QRSID:1: 2\r\nok\r\n") // cid==1 but sid missing. It needs to be at least one character. + + // Use ping roundtrip to make sure its processed. + routeSend("PING\r\n") + routeExpect(pongRe) + + // Should be 2 now, 1 for all normal, and one for specific queue subscriber. + matches = clientExpectMsgs(2) + + // Expect first to be the normal subscriber, next will be the queue one. + checkMsg(t, matches[0], "foo", "1", "", "2", "ok") + checkMsg(t, matches[1], "foo", "2", "", "2", "ok") +} + +func TestSolicitRouteReconnect(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + rURL := opts.Routes[0] + + route := acceptRouteConn(t, rURL.Host, 2*server.DEFAULT_ROUTE_CONNECT) + + // Go ahead and close the Route. + route.Close() + + // We expect to get called back.. + route = acceptRouteConn(t, rURL.Host, 2*server.DEFAULT_ROUTE_CONNECT) + route.Close() +} + +func TestMultipleRoutesSameId(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + route1 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route1.Close() + + expectAuthRequired(t, route1) + route1Send, _ := setupRouteEx(t, route1, opts, "ROUTE:2222") + + route2 := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route2.Close() + + expectAuthRequired(t, route2) + route2Send, _ := setupRouteEx(t, route2, opts, "ROUTE:2222") + + // Send SUB via route connections + sub := "SUB foo RSID:2:22\r\n" + route1Send(sub) + route2Send(sub) + + // Make sure we do not get anything on a MSG send to a router. + // Send MSG proto via route connection + route1Send("MSG foo 1 2\r\nok\r\n") + + expectNothing(t, route1) + expectNothing(t, route2) + + // Setup a client + client := createClientConn(t, opts.Host, opts.Port) + clientSend, clientExpect := setupConn(t, client) + defer client.Close() + + // Send PUB via client connection + clientSend("PUB foo 2\r\nok\r\n") + clientSend("PING\r\n") + clientExpect(pongRe) + + // We should only receive on one route, not both. + // Check both manually. + route1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + buf, _ := ioutil.ReadAll(route1) + route1.SetReadDeadline(time.Time{}) + if len(buf) <= 0 { + route2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + buf, _ = ioutil.ReadAll(route2) + route2.SetReadDeadline(time.Time{}) + if len(buf) <= 0 { + t.Fatal("Expected to get one message on a route, received none.") + } + } + + matches := msgRe.FindAllSubmatch(buf, -1) + if len(matches) != 1 { + t.Fatalf("Expected 1 msg, got %d\n", len(matches)) + } + checkMsg(t, matches[0], "foo", "", "", "2", "ok") +} + +func TestRouteResendsLocalSubsOnReconnect(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + client := createClientConn(t, opts.Host, opts.Port) + defer client.Close() + + clientSend, clientExpect := setupConn(t, client) + + // Setup a local subscription, make sure it reaches. + clientSend("SUB foo 1\r\n") + clientSend("PING\r\n") + clientExpect(pongRe) + + route := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + routeSend, routeExpect := setupRouteEx(t, route, opts, "ROUTE:1234") + + // Expect to see the local sub echoed through after we send our INFO. + time.Sleep(50 * time.Millisecond) + buf := routeExpect(infoRe) + + // Generate our own INFO so we can send one to trigger the local subs. + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + t.Fatalf("Could not unmarshal route info: %v", err) + } + info.ID = "ROUTE:1234" + b, err := json.Marshal(info) + if err != nil { + t.Fatalf("Could not marshal test route info: %v", err) + } + infoJSON := fmt.Sprintf("INFO %s\r\n", b) + + // Trigger the send of local subs. + routeSend(infoJSON) + + routeExpect(subRe) + + // Close and then re-open + route.Close() + + route = createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + defer route.Close() + + routeSend, routeExpect = setupRouteEx(t, route, opts, "ROUTE:1234") + + routeExpect(infoRe) + + routeSend(infoJSON) + routeExpect(subRe) +} + +type ignoreLogger struct{} + +func (l *ignoreLogger) Fatalf(f string, args ...interface{}) {} +func (l *ignoreLogger) Errorf(f string, args ...interface{}) {} + +func TestRouteConnectOnShutdownRace(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + l := &ignoreLogger{} + + var wg sync.WaitGroup + + cQuit := make(chan bool, 1) + + wg.Add(1) + + go func() { + defer wg.Done() + for { + route := createRouteConn(l, opts.Cluster.Host, opts.Cluster.Port) + if route != nil { + setupRouteEx(l, route, opts, "ROUTE:1234") + route.Close() + } + select { + case <-cQuit: + return + default: + } + } + }() + + time.Sleep(5 * time.Millisecond) + s.Shutdown() + + cQuit <- true + + wg.Wait() +} + +func TestRouteSendAsyncINFOToClients(t *testing.T) { + f := func(opts *server.Options) { + s := RunServer(opts) + defer s.Shutdown() + + clientURL := net.JoinHostPort(opts.Host, strconv.Itoa(opts.Port)) + + oldClient := createClientConn(t, opts.Host, opts.Port) + defer oldClient.Close() + + oldClientSend, oldClientExpect := setupConn(t, oldClient) + oldClientSend("PING\r\n") + oldClientExpect(pongRe) + + newClient := createClientConn(t, opts.Host, opts.Port) + defer newClient.Close() + + newClientSend, newClientExpect := setupConnWithProto(t, newClient, clientProtoInfo) + newClientSend("PING\r\n") + newClientExpect(pongRe) + + // Check that even a new client does not receive an async INFO at this point + // since there is no route created yet. + expectNothing(t, newClient) + + routeID := "Server-B" + + createRoute := func() (net.Conn, sendFun, expectFun) { + rc := createRouteConn(t, opts.Cluster.Host, opts.Cluster.Port) + routeSend, routeExpect := setupRouteEx(t, rc, opts, routeID) + + buf := routeExpect(infoRe) + + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + stackFatalf(t, "Could not unmarshal route info: %v", err) + } + if opts.Cluster.NoAdvertise { + if len(info.ClientConnectURLs) != 0 { + stackFatalf(t, "Expected ClientConnectURLs to be empty, got %v", info.ClientConnectURLs) + } + } else { + if len(info.ClientConnectURLs) == 0 { + stackFatalf(t, "Expected a list of URLs, got none") + } + if info.ClientConnectURLs[0] != clientURL { + stackFatalf(t, "Expected ClientConnectURLs to be %q, got %q", clientURL, info.ClientConnectURLs[0]) + } + } + + return rc, routeSend, routeExpect + } + + sendRouteINFO := func(routeSend sendFun, routeExpect expectFun, urls []string) { + routeInfo := server.Info{} + routeInfo.ID = routeID + routeInfo.Host = "127.0.0.1" + routeInfo.Port = 5222 + routeInfo.ClientConnectURLs = urls + b, err := json.Marshal(routeInfo) + if err != nil { + stackFatalf(t, "Could not marshal test route info: %v", err) + } + infoJSON := fmt.Sprintf("INFO %s\r\n", b) + routeSend(infoJSON) + routeSend("PING\r\n") + routeExpect(pongRe) + } + + checkClientConnectURLS := func(urls, expected []string) { + // Order of array is not guaranteed. + ok := false + if len(urls) == len(expected) { + m := make(map[string]struct{}, len(expected)) + for _, url := range expected { + m[url] = struct{}{} + } + ok = true + for _, url := range urls { + if _, present := m[url]; !present { + ok = false + break + } + } + } + if !ok { + stackFatalf(t, "Expected ClientConnectURLs to be %v, got %v", expected, urls) + } + } + + checkINFOReceived := func(client net.Conn, clientExpect expectFun, expectedURLs []string) { + if opts.Cluster.NoAdvertise { + expectNothing(t, client) + return + } + buf := clientExpect(infoRe) + info := server.Info{} + if err := json.Unmarshal(buf[4:], &info); err != nil { + stackFatalf(t, "Could not unmarshal route info: %v", err) + } + checkClientConnectURLS(info.ClientConnectURLs, expectedURLs) + } + + // Create a route + rc, routeSend, routeExpect := createRoute() + defer rc.Close() + + // Send an INFO with single URL + routeClientConnectURLs := []string{"127.0.0.1:5222"} + sendRouteINFO(routeSend, routeExpect, routeClientConnectURLs) + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // We expect to get the one from the server we connect to and the other route. + expectedURLs := []string{clientURL, routeClientConnectURLs[0]} + + // Expect new client to receive an INFO (unless disabled) + checkINFOReceived(newClient, newClientExpect, expectedURLs) + + // Disconnect the route + rc.Close() + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // Expect new client to receive an INFO (unless disabled). + // The content will now have the disconnected route ClientConnectURLs + // removed from the INFO. So it should be the one from the server the + // client is connected to. + checkINFOReceived(newClient, newClientExpect, []string{clientURL}) + + // Reconnect the route. + rc, routeSend, routeExpect = createRoute() + defer rc.Close() + + // Resend the same route INFO json. The server will now send + // the INFO since the disconnected route ClientConnectURLs was + // removed in previous step. + sendRouteINFO(routeSend, routeExpect, routeClientConnectURLs) + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // Expect new client to receive an INFO (unless disabled) + checkINFOReceived(newClient, newClientExpect, expectedURLs) + + // Now stop the route and restart with an additional URL + rc.Close() + + // On route disconnect, clients will receive an updated INFO + expectNothing(t, oldClient) + checkINFOReceived(newClient, newClientExpect, []string{clientURL}) + + rc, routeSend, routeExpect = createRoute() + defer rc.Close() + + // Create a client not sending the CONNECT until after route is added + clientNoConnect := createClientConn(t, opts.Host, opts.Port) + defer clientNoConnect.Close() + + // Create a client that does not send the first PING yet + clientNoPing := createClientConn(t, opts.Host, opts.Port) + defer clientNoPing.Close() + clientNoPingSend, clientNoPingExpect := setupConnWithProto(t, clientNoPing, clientProtoInfo) + + // The route now has an additional URL + routeClientConnectURLs = append(routeClientConnectURLs, "127.0.0.1:7777") + expectedURLs = append(expectedURLs, "127.0.0.1:7777") + // This causes the server to add the route and send INFO to clients + sendRouteINFO(routeSend, routeExpect, routeClientConnectURLs) + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // Expect new client to receive an INFO, and verify content as expected. + checkINFOReceived(newClient, newClientExpect, expectedURLs) + + // Expect nothing yet for client that did not send the PING + expectNothing(t, clientNoPing) + + // Now send the first PING + clientNoPingSend("PING\r\n") + // Should receive PONG followed by INFO + // Receive PONG only first + pongBuf := make([]byte, len("PONG\r\n")) + clientNoPing.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := clientNoPing.Read(pongBuf) + clientNoPing.SetReadDeadline(time.Time{}) + if n <= 0 && err != nil { + t.Fatalf("Error reading from conn: %v\n", err) + } + if !pongRe.Match(pongBuf) { + t.Fatalf("Response did not match expected: \n\tReceived:'%q'\n\tExpected:'%s'\n", pongBuf, pongRe) + } + checkINFOReceived(clientNoPing, clientNoPingExpect, expectedURLs) + + // Have the client that did not send the connect do it now + clientNoConnectSend, clientNoConnectExpect := setupConnWithProto(t, clientNoConnect, clientProtoInfo) + // Send the PING + clientNoConnectSend("PING\r\n") + // Should receive PONG followed by INFO + // Receive PONG only first + clientNoConnect.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err = clientNoConnect.Read(pongBuf) + clientNoConnect.SetReadDeadline(time.Time{}) + if n <= 0 && err != nil { + t.Fatalf("Error reading from conn: %v\n", err) + } + if !pongRe.Match(pongBuf) { + t.Fatalf("Response did not match expected: \n\tReceived:'%q'\n\tExpected:'%s'\n", pongBuf, pongRe) + } + checkINFOReceived(clientNoConnect, clientNoConnectExpect, expectedURLs) + + // Create a client connection and verify content of initial INFO contains array + // (but empty if no advertise option is set) + cli := createClientConn(t, opts.Host, opts.Port) + defer cli.Close() + buf := expectResult(t, cli, infoRe) + js := infoRe.FindAllSubmatch(buf, 1)[0][1] + var sinfo server.Info + err = json.Unmarshal(js, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } + if opts.Cluster.NoAdvertise { + if len(sinfo.ClientConnectURLs) != 0 { + t.Fatalf("Expected ClientConnectURLs to be empty, got %v", sinfo.ClientConnectURLs) + } + } else { + checkClientConnectURLS(sinfo.ClientConnectURLs, expectedURLs) + } + + // Add a new route + routeID = "Server-C" + rc2, route2Send, route2Expect := createRoute() + defer rc2.Close() + + // Send an INFO with single URL + rc2ConnectURLs := []string{"127.0.0.1:8888"} + sendRouteINFO(route2Send, route2Expect, rc2ConnectURLs) + + // This is the combined client connect URLs array + totalConnectURLs := expectedURLs + totalConnectURLs = append(totalConnectURLs, rc2ConnectURLs...) + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // Expect new client to receive an INFO (unless disabled) + checkINFOReceived(newClient, newClientExpect, totalConnectURLs) + + // Make first route disconnect + rc.Close() + + // Expect nothing for old clients + expectNothing(t, oldClient) + + // Expect new client to receive an INFO (unless disabled) + // The content should be the server client is connected to and the last route + checkINFOReceived(newClient, newClientExpect, []string{"127.0.0.1:5242", "127.0.0.1:8888"}) + } + + opts := LoadConfig("./configs/cluster.conf") + // For this test, be explicit about listen spec. + opts.Host = "127.0.0.1" + opts.Port = 5242 + + f(opts) + opts.Cluster.NoAdvertise = true + f(opts) +} + +func TestRouteBasicPermissions(t *testing.T) { + srvA, optsA := RunServerWithConfig("./configs/srv_a_perms.conf") + defer srvA.Shutdown() + + srvB, optsB := RunServerWithConfig("./configs/srv_b.conf") + defer srvB.Shutdown() + + checkClusterFormed(t, srvA, srvB) + + // Create a connection to server B + ncb, err := nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", optsB.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer ncb.Close() + ch := make(chan bool, 1) + cb := func(_ *nats.Msg) { + ch <- true + } + // Subscribe on on "bar" and "baz", which should be accepted by server A + subBbar, err := ncb.Subscribe("bar", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subBbar.Unsubscribe() + subBbaz, err := ncb.Subscribe("baz", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subBbaz.Unsubscribe() + ncb.Flush() + if err := checkExpectedSubs(2, srvA, srvB); err != nil { + t.Fatal(err.Error()) + } + + // Create a connection to server A + nca, err := nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", optsA.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nca.Close() + // Publish on bar and baz, messages should be received. + if err := nca.Publish("bar", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if err := nca.Publish("baz", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + for i := 0; i < 2; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("Did not get the messages") + } + } + + // From B, start a subscription on "foo", which server A should drop since + // it only exports on "bar" and "baz" + subBfoo, err := ncb.Subscribe("foo", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subBfoo.Unsubscribe() + ncb.Flush() + // B should have now 3 subs + if err := checkExpectedSubs(3, srvB); err != nil { + t.Fatal(err.Error()) + } + // and A still 2. + if err := checkExpectedSubs(2, srvA); err != nil { + t.Fatal(err.Error()) + } + // So producing on "foo" from A should not be forwarded to B. + if err := nca.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + select { + case <-ch: + t.Fatal("Message should not have been received") + case <-time.After(100 * time.Millisecond): + } + + // Now on A, create a subscription on something that A does not import, + // like "bat". + subAbat, err := nca.Subscribe("bat", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subAbat.Unsubscribe() + nca.Flush() + // A should have 3 subs + if err := checkExpectedSubs(3, srvA); err != nil { + t.Fatal(err.Error()) + } + // And from B, send a message on that subject and make sure it is not received. + if err := ncb.Publish("bat", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + select { + case <-ch: + t.Fatal("Message should not have been received") + case <-time.After(100 * time.Millisecond): + } + + // Stop subscription on foo from B + subBfoo.Unsubscribe() + ncb.Flush() + // Back to 2 subs on B + if err := checkExpectedSubs(2, srvB); err != nil { + t.Fatal(err.Error()) + } + + // Create subscription on foo from A, this should be forwared to B. + subAfoo, err := nca.Subscribe("foo", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subAfoo.Unsubscribe() + // Create another one so that test the import permissions cache + subAfoo2, err := nca.Subscribe("foo", cb) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + defer subAfoo2.Unsubscribe() + nca.Flush() + // A should have 5 subs + if err := checkExpectedSubs(5, srvA); err != nil { + t.Fatal(err.Error()) + } + // B should have 4 + if err := checkExpectedSubs(4, srvB); err != nil { + t.Fatal(err.Error()) + } + // Send a message from B and check that it is received. + if err := ncb.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + for i := 0; i < 2; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("Did not get the message") + } + } + + // Close connection from B, and restart server B too. + // We want to make sure that + ncb.Close() + srvB.Shutdown() + + // Since B had 2 local subs, A should go from 5 to 3 + if err := checkExpectedSubs(3, srvA); err != nil { + t.Fatal(err.Error()) + } + + // Restart server B + srvB, optsB = RunServerWithConfig("./configs/srv_b.conf") + defer srvB.Shutdown() + // Check that subs from A that can be sent to B are sent. + // That would be 2 (the 2 subscriptions on foo). + if err := checkExpectedSubs(2, srvB); err != nil { + t.Fatal(err.Error()) + } + + // Connect to B and send on "foo" and make sure we receive + ncb, err = nats.Connect(fmt.Sprintf("nats://127.0.0.1:%d", optsB.Port)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer ncb.Close() + if err := ncb.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + for i := 0; i < 2; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("Did not get the message") + } + } + + // Send on "bat" and make sure that this is not received. + if err := ncb.Publish("bat", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + select { + case <-ch: + t.Fatal("Message should not have been received") + case <-time.After(100 * time.Millisecond): + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/test.go b/vendor/github.com/nats-io/gnatsd/test/test.go new file mode 100644 index 00000000..43684528 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/test.go @@ -0,0 +1,378 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "regexp" + "runtime" + "strings" + "time" + + "github.com/nats-io/gnatsd/server" +) + +// So we can pass tests and benchmarks.. +type tLogger interface { + Fatalf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +// DefaultTestOptions are default options for the unit tests. +var DefaultTestOptions = server.Options{ + Host: "127.0.0.1", + Port: 4222, + NoLog: true, + NoSigs: true, + MaxControlLine: 256, +} + +// RunDefaultServer starts a new Go routine based server using the default options +func RunDefaultServer() *server.Server { + return RunServer(&DefaultTestOptions) +} + +// RunServer starts a new Go routine based server +func RunServer(opts *server.Options) *server.Server { + if opts == nil { + opts = &DefaultTestOptions + } + s := server.New(opts) + if s == nil { + panic("No NATS Server object returned.") + } + + // Run server in Go routine. + go s.Start() + + // Wait for accept loop(s) to be started + if !s.ReadyForConnections(10 * time.Second) { + panic("Unable to start NATS Server in Go Routine") + } + return s +} + +// LoadConfig loads a configuration from a filename +func LoadConfig(configFile string) (opts *server.Options) { + opts, err := server.ProcessConfigFile(configFile) + if err != nil { + panic(fmt.Sprintf("Error processing configuration file: %v", err)) + } + opts.NoSigs, opts.NoLog = true, true + return +} + +// RunServerWithConfig starts a new Go routine based server with a configuration file. +func RunServerWithConfig(configFile string) (srv *server.Server, opts *server.Options) { + opts = LoadConfig(configFile) + srv = RunServer(opts) + return +} + +func stackFatalf(t tLogger, f string, args ...interface{}) { + lines := make([]string, 0, 32) + msg := fmt.Sprintf(f, args...) + lines = append(lines, msg) + + // Ignore ourselves + _, testFile, _, _ := runtime.Caller(0) + + // Generate the Stack of callers: + for i := 0; true; i++ { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } + if file == testFile { + continue + } + msg := fmt.Sprintf("%d - %s:%d", i, file, line) + lines = append(lines, msg) + } + + t.Fatalf("%s", strings.Join(lines, "\n")) +} + +func acceptRouteConn(t tLogger, host string, timeout time.Duration) net.Conn { + l, e := net.Listen("tcp", host) + if e != nil { + stackFatalf(t, "Error listening for route connection on %v: %v", host, e) + } + defer l.Close() + + tl := l.(*net.TCPListener) + tl.SetDeadline(time.Now().Add(timeout)) + conn, err := l.Accept() + tl.SetDeadline(time.Time{}) + + if err != nil { + stackFatalf(t, "Did not receive a route connection request: %v", err) + } + return conn +} + +func createRouteConn(t tLogger, host string, port int) net.Conn { + return createClientConn(t, host, port) +} + +func createClientConn(t tLogger, host string, port int) net.Conn { + addr := fmt.Sprintf("%s:%d", host, port) + c, err := net.DialTimeout("tcp", addr, 3*time.Second) + if err != nil { + stackFatalf(t, "Could not connect to server: %v\n", err) + } + return c +} + +func checkSocket(t tLogger, addr string, wait time.Duration) { + end := time.Now().Add(wait) + for time.Now().Before(end) { + conn, err := net.Dial("tcp", addr) + if err != nil { + // Retry after 50ms + time.Sleep(50 * time.Millisecond) + continue + } + conn.Close() + // Wait a bit to give a chance to the server to remove this + // "client" from its state, which may otherwise interfere with + // some tests. + time.Sleep(25 * time.Millisecond) + return + } + // We have failed to bind the socket in the time allowed. + t.Fatalf("Failed to connect to the socket: %q", addr) +} + +func checkInfoMsg(t tLogger, c net.Conn) server.Info { + buf := expectResult(t, c, infoRe) + js := infoRe.FindAllSubmatch(buf, 1)[0][1] + var sinfo server.Info + err := json.Unmarshal(js, &sinfo) + if err != nil { + stackFatalf(t, "Could not unmarshal INFO json: %v\n", err) + } + return sinfo +} + +func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { + checkInfoMsg(t, c) + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v}\r\n", verbose, pedantic, ssl) + sendProto(t, c, cs) +} + +func doDefaultConnect(t tLogger, c net.Conn) { + // Basic Connect + doConnect(t, c, false, false, false) +} + +const connectProto = "CONNECT {\"verbose\":false,\"user\":\"%s\",\"pass\":\"%s\",\"name\":\"%s\"}\r\n" + +func doRouteAuthConnect(t tLogger, c net.Conn, user, pass, id string) { + cs := fmt.Sprintf(connectProto, user, pass, id) + sendProto(t, c, cs) +} + +func setupRouteEx(t tLogger, c net.Conn, opts *server.Options, id string) (sendFun, expectFun) { + user := opts.Cluster.Username + pass := opts.Cluster.Password + doRouteAuthConnect(t, c, user, pass, id) + return sendCommand(t, c), expectCommand(t, c) +} + +func setupRoute(t tLogger, c net.Conn, opts *server.Options) (sendFun, expectFun) { + u := make([]byte, 16) + io.ReadFull(rand.Reader, u) + id := fmt.Sprintf("ROUTER:%s", hex.EncodeToString(u)) + return setupRouteEx(t, c, opts, id) +} + +func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) { + doDefaultConnect(t, c) + return sendCommand(t, c), expectCommand(t, c) +} + +func setupConnWithProto(t tLogger, c net.Conn, proto int) (sendFun, expectFun) { + checkInfoMsg(t, c) + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"tls_required\":%v,\"protocol\":%d}\r\n", false, false, false, proto) + sendProto(t, c, cs) + return sendCommand(t, c), expectCommand(t, c) +} + +type sendFun func(string) +type expectFun func(*regexp.Regexp) []byte + +// Closure version for easier reading +func sendCommand(t tLogger, c net.Conn) sendFun { + return func(op string) { + sendProto(t, c, op) + } +} + +// Closure version for easier reading +func expectCommand(t tLogger, c net.Conn) expectFun { + return func(re *regexp.Regexp) []byte { + return expectResult(t, c, re) + } +} + +// Send the protocol command to the server. +func sendProto(t tLogger, c net.Conn, op string) { + n, err := c.Write([]byte(op)) + if err != nil { + stackFatalf(t, "Error writing command to conn: %v\n", err) + } + if n != len(op) { + stackFatalf(t, "Partial write: %d vs %d\n", n, len(op)) + } +} + +var ( + infoRe = regexp.MustCompile(`INFO\s+([^\r\n]+)\r\n`) + pingRe = regexp.MustCompile(`PING\r\n`) + pongRe = regexp.MustCompile(`PONG\r\n`) + msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\s*\r\n([^\\r\\n]*?)\r\n)+?)`) + okRe = regexp.MustCompile(`\A\+OK\r\n`) + errRe = regexp.MustCompile(`\A\-ERR\s+([^\r\n]+)\r\n`) + subRe = regexp.MustCompile(`SUB\s+([^\s]+)((\s+)([^\s]+))?\s+([^\s]+)\r\n`) + unsubRe = regexp.MustCompile(`UNSUB\s+([^\s]+)(\s+(\d+))?\r\n`) + connectRe = regexp.MustCompile(`CONNECT\s+([^\r\n]+)\r\n`) +) + +const ( + subIndex = 1 + sidIndex = 2 + replyIndex = 4 + lenIndex = 5 + msgIndex = 6 +) + +// Test result from server against regexp +func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte { + expBuf := make([]byte, 32768) + // Wait for commands to be processed and results queued for read + c.SetReadDeadline(time.Now().Add(5 * time.Second)) + n, err := c.Read(expBuf) + c.SetReadDeadline(time.Time{}) + + if n <= 0 && err != nil { + stackFatalf(t, "Error reading from conn: %v\n", err) + } + buf := expBuf[:n] + + if !re.Match(buf) { + stackFatalf(t, "Response did not match expected: \n\tReceived:'%q'\n\tExpected:'%s'\n", buf, re) + } + return buf +} + +func expectNothing(t tLogger, c net.Conn) { + expBuf := make([]byte, 32) + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + n, err := c.Read(expBuf) + c.SetReadDeadline(time.Time{}) + if err == nil && n > 0 { + stackFatalf(t, "Expected nothing, received: '%q'\n", expBuf[:n]) + } +} + +// This will check that we got what we expected. +func checkMsg(t tLogger, m [][]byte, subject, sid, reply, len, msg string) { + if string(m[subIndex]) != subject { + stackFatalf(t, "Did not get correct subject: expected '%s' got '%s'\n", subject, m[subIndex]) + } + if sid != "" && string(m[sidIndex]) != sid { + stackFatalf(t, "Did not get correct sid: expected '%s' got '%s'\n", sid, m[sidIndex]) + } + if string(m[replyIndex]) != reply { + stackFatalf(t, "Did not get correct reply: expected '%s' got '%s'\n", reply, m[replyIndex]) + } + if string(m[lenIndex]) != len { + stackFatalf(t, "Did not get correct msg length: expected '%s' got '%s'\n", len, m[lenIndex]) + } + if string(m[msgIndex]) != msg { + stackFatalf(t, "Did not get correct msg: expected '%s' got '%s'\n", msg, m[msgIndex]) + } +} + +// Closure for expectMsgs +func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte { + return func(expected int) [][][]byte { + buf := ef(msgRe) + matches := msgRe.FindAllSubmatch(buf, -1) + if len(matches) != expected { + stackFatalf(t, "Did not get correct # msgs: %d vs %d\n", len(matches), expected) + } + return matches + } +} + +// This will check that the matches include at least one of the sids. Useful for checking +// that we received messages on a certain queue group. +func checkForQueueSid(t tLogger, matches [][][]byte, sids []string) { + seen := make(map[string]int, len(sids)) + for _, sid := range sids { + seen[sid] = 0 + } + for _, m := range matches { + sid := string(m[sidIndex]) + if _, ok := seen[sid]; ok { + seen[sid]++ + } + } + // Make sure we only see one and exactly one. + total := 0 + for _, n := range seen { + total += n + } + if total != 1 { + stackFatalf(t, "Did not get a msg for queue sids group: expected 1 got %d\n", total) + } +} + +// This will check that the matches include all of the sids. Useful for checking +// that we received messages on all subscribers. +func checkForPubSids(t tLogger, matches [][][]byte, sids []string) { + seen := make(map[string]int, len(sids)) + for _, sid := range sids { + seen[sid] = 0 + } + for _, m := range matches { + sid := string(m[sidIndex]) + if _, ok := seen[sid]; ok { + seen[sid]++ + } + } + // Make sure we only see one and exactly one for each sid. + for sid, n := range seen { + if n != 1 { + stackFatalf(t, "Did not get a msg for sid[%s]: expected 1 got %d\n", sid, n) + + } + } +} + +// Helper function to generate next opts to make sure no port conflicts etc. +func nextServerOpts(opts *server.Options) *server.Options { + nopts := opts.Clone() + nopts.Port++ + nopts.Cluster.Port++ + nopts.HTTPPort++ + return nopts +} diff --git a/vendor/github.com/nats-io/gnatsd/test/test_test.go b/vendor/github.com/nats-io/gnatsd/test/test_test.go new file mode 100644 index 00000000..ac041747 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/test_test.go @@ -0,0 +1,72 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "fmt" + "strings" + "sync" + "testing" + "time" +) + +func checkFor(t *testing.T, totalWait, sleepDur time.Duration, f func() error) { + t.Helper() + timeout := time.Now().Add(totalWait) + var err error + for time.Now().Before(timeout) { + err = f() + if err == nil { + return + } + time.Sleep(sleepDur) + } + if err != nil { + t.Fatal(err.Error()) + } +} + +type dummyLogger struct { + sync.Mutex + msg string +} + +func (d *dummyLogger) Fatalf(format string, args ...interface{}) { + d.Lock() + d.msg = fmt.Sprintf(format, args...) + d.Unlock() +} + +func (d *dummyLogger) Errorf(format string, args ...interface{}) { +} + +func (d *dummyLogger) Debugf(format string, args ...interface{}) { +} + +func (d *dummyLogger) Tracef(format string, args ...interface{}) { +} + +func (d *dummyLogger) Noticef(format string, args ...interface{}) { +} + +func TestStackFatal(t *testing.T) { + d := &dummyLogger{} + stackFatalf(d, "test stack %d", 1) + if !strings.HasPrefix(d.msg, "test stack 1") { + t.Fatalf("Unexpected start of stack: %v", d.msg) + } + if !strings.Contains(d.msg, "test_test.go") { + t.Fatalf("Unexpected stack: %v", d.msg) + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/tls_test.go b/vendor/github.com/nats-io/gnatsd/test/tls_test.go new file mode 100644 index 00000000..38a90f81 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/tls_test.go @@ -0,0 +1,334 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "strings" + "sync" + "testing" + "time" + + "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" +) + +func TestTLSConnection(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nurl := fmt.Sprintf("tls://%s:%s@%s/", opts.Username, opts.Password, endpoint) + nc, err := nats.Connect(nurl) + if err == nil { + nc.Close() + t.Fatalf("Expected error trying to connect to secure server") + } + + // Do simple SecureConnect + nc, err = nats.Connect(fmt.Sprintf("tls://%s/", endpoint)) + if err == nil { + nc.Close() + t.Fatalf("Expected error trying to connect to secure server with no auth") + } + + // Now do more advanced checking, verifying servername and using rootCA. + + nc, err = nats.Connect(nurl, nats.RootCAs("./configs/certs/ca.pem")) + if err != nil { + t.Fatalf("Got an error on Connect with Secure Options: %+v\n", err) + } + defer nc.Close() + + subj := "foo-tls" + sub, _ := nc.SubscribeSync(subj) + + nc.Publish(subj, []byte("We are Secure!")) + nc.Flush() + nmsgs, _ := sub.QueuedMsgs() + if nmsgs != 1 { + t.Fatalf("Expected to receive a message over the TLS connection") + } +} + +func TestTLSClientCertificate(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tlsverify.conf") + defer srv.Shutdown() + + nurl := fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port) + + _, err := nats.Connect(nurl) + if err == nil { + t.Fatalf("Expected error trying to connect to secure server without a certificate") + } + + // Load client certificate to successfully connect. + certFile := "./configs/certs/client-cert.pem" + keyFile := "./configs/certs/client-key.pem" + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + t.Fatalf("error parsing X509 certificate/key pair: %v", err) + } + + // Load in root CA for server verification + rootPEM, err := ioutil.ReadFile("./configs/certs/ca.pem") + if err != nil || rootPEM == nil { + t.Fatalf("failed to read root certificate") + } + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM([]byte(rootPEM)) + if !ok { + t.Fatalf("failed to parse root certificate") + } + + config := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ServerName: opts.Host, + RootCAs: pool, + MinVersion: tls.VersionTLS12, + } + + copts := nats.GetDefaultOptions() + copts.Url = nurl + copts.Secure = true + copts.TLSConfig = config + + nc, err := copts.Connect() + if err != nil { + t.Fatalf("Got an error on Connect with Secure Options: %+v\n", err) + } + nc.Flush() + defer nc.Close() +} + +func TestTLSVerifyClientCertificate(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tlsverify_noca.conf") + defer srv.Shutdown() + + nurl := fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port) + + // The client is configured properly, but the server has no CA + // to verify the client certificate. Connection should fail. + nc, err := nats.Connect(nurl, + nats.ClientCert("./configs/certs/client-cert.pem", "./configs/certs/client-key.pem"), + nats.RootCAs("./configs/certs/ca.pem")) + if err == nil { + nc.Close() + t.Fatal("Expected failure to connect, did not") + } +} + +func TestTLSConnectionTimeout(t *testing.T) { + opts := LoadConfig("./configs/tls.conf") + opts.TLSTimeout = 0.25 + + srv := RunServer(opts) + defer srv.Shutdown() + + // Dial with normal TCP + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + conn, err := net.Dial("tcp", endpoint) + if err != nil { + t.Fatalf("Could not connect to %q", endpoint) + } + defer conn.Close() + + // Read deadlines + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + + // Read the INFO string. + br := bufio.NewReader(conn) + info, err := br.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read INFO - %v", err) + } + if !strings.HasPrefix(info, "INFO ") { + t.Fatalf("INFO response incorrect: %s\n", info) + } + wait := time.Duration(opts.TLSTimeout * float64(time.Second)) + time.Sleep(wait) + // Read deadlines + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + tlsErr, err := br.ReadString('\n') + if err == nil && !strings.Contains(tlsErr, "-ERR 'Secure Connection - TLS Required") { + t.Fatalf("TLS Timeout response incorrect: %q\n", tlsErr) + } +} + +// Ensure there is no race between authorization timeout and TLS handshake. +func TestTLSAuthorizationShortTimeout(t *testing.T) { + opts := LoadConfig("./configs/tls.conf") + opts.AuthTimeout = 0.001 + + srv := RunServer(opts) + defer srv.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nurl := fmt.Sprintf("tls://%s:%s@%s/", opts.Username, opts.Password, endpoint) + + // Expect an error here (no CA) but not a TLS oversized record error which + // indicates the authorization timeout fired too soon. + _, err := nats.Connect(nurl) + if err == nil { + t.Fatal("Expected error trying to connect to secure server") + } + if strings.Contains(err.Error(), "oversized record") { + t.Fatal("Corrupted TLS handshake:", err) + } +} + +func stressConnect(t *testing.T, wg *sync.WaitGroup, errCh chan error, url string, index int) { + defer wg.Done() + + subName := fmt.Sprintf("foo.%d", index) + + for i := 0; i < 33; i++ { + nc, err := nats.Connect(url, nats.RootCAs("./configs/certs/ca.pem")) + if err != nil { + errCh <- fmt.Errorf("Unable to create TLS connection: %v\n", err) + return + } + defer nc.Close() + + sub, err := nc.SubscribeSync(subName) + if err != nil { + errCh <- fmt.Errorf("Unable to subscribe on '%s': %v\n", subName, err) + return + } + + if err := nc.Publish(subName, []byte("secure data")); err != nil { + errCh <- fmt.Errorf("Unable to send on '%s': %v\n", subName, err) + } + + if _, err := sub.NextMsg(2 * time.Second); err != nil { + errCh <- fmt.Errorf("Unable to get next message: %v\n", err) + } + + nc.Close() + } + + errCh <- nil +} + +func TestTLSStressConnect(t *testing.T) { + opts, err := server.ProcessConfigFile("./configs/tls.conf") + if err != nil { + panic(fmt.Sprintf("Error processing configuration file: %v", err)) + } + opts.NoSigs, opts.NoLog = true, true + + // For this test, remove the authorization + opts.Username = "" + opts.Password = "" + + // Increase ssl timeout + opts.TLSTimeout = 2.0 + + srv := RunServer(opts) + defer srv.Shutdown() + + nurl := fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port) + + threadCount := 3 + + errCh := make(chan error, threadCount) + + var wg sync.WaitGroup + wg.Add(threadCount) + + for i := 0; i < threadCount; i++ { + go stressConnect(t, &wg, errCh, nurl, i) + } + + wg.Wait() + + var lastError error + for i := 0; i < threadCount; i++ { + err := <-errCh + if err != nil { + lastError = err + } + } + + if lastError != nil { + t.Fatalf("%v\n", lastError) + } +} + +func TestTLSBadAuthError(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls.conf") + defer srv.Shutdown() + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nurl := fmt.Sprintf("tls://%s:%s@%s/", opts.Username, "NOT_THE_PASSWORD", endpoint) + + _, err := nats.Connect(nurl, nats.RootCAs("./configs/certs/ca.pem")) + if err == nil { + t.Fatalf("Expected error trying to connect to secure server") + } + if err.Error() != nats.ErrAuthorization.Error() { + t.Fatalf("Excpected and auth violation, got %v\n", err) + } +} + +func TestTLSConnectionCurvePref(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/tls_curve_pref.conf") + defer srv.Shutdown() + + if len(opts.TLSConfig.CurvePreferences) != 1 { + t.Fatal("Invalid curve preference loaded.") + } + + if opts.TLSConfig.CurvePreferences[0] != tls.CurveP256 { + t.Fatalf("Invalid curve preference loaded [%v].", opts.TLSConfig.CurvePreferences[0]) + } + + endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) + nurl := fmt.Sprintf("tls://%s:%s@%s/", opts.Username, opts.Password, endpoint) + nc, err := nats.Connect(nurl) + if err == nil { + nc.Close() + t.Fatalf("Expected error trying to connect to secure server") + } + + // Do simple SecureConnect + nc, err = nats.Connect(fmt.Sprintf("tls://%s/", endpoint)) + if err == nil { + nc.Close() + t.Fatalf("Expected error trying to connect to secure server with no auth") + } + + // Now do more advanced checking, verifying servername and using rootCA. + + nc, err = nats.Connect(nurl, nats.RootCAs("./configs/certs/ca.pem")) + if err != nil { + t.Fatalf("Got an error on Connect with Secure Options: %+v\n", err) + } + defer nc.Close() + + subj := "foo-tls" + sub, _ := nc.SubscribeSync(subj) + + nc.Publish(subj, []byte("We are Secure!")) + nc.Flush() + nmsgs, _ := sub.QueuedMsgs() + if nmsgs != 1 { + t.Fatalf("Expected to receive a message over the TLS connection") + } +} diff --git a/vendor/github.com/nats-io/gnatsd/test/user_authorization_test.go b/vendor/github.com/nats-io/gnatsd/test/user_authorization_test.go new file mode 100644 index 00000000..851c917d --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/user_authorization_test.go @@ -0,0 +1,101 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "regexp" + "testing" +) + +const DefaultPass = "foo" + +var permErrRe = regexp.MustCompile(`\A\-ERR\s+'Permissions Violation([^\r\n]+)\r\n`) + +func TestUserAuthorizationProto(t *testing.T) { + srv, opts := RunServerWithConfig("./configs/authorization.conf") + defer srv.Shutdown() + + // Alice can do anything, check a few for OK result. + c := createClientConn(t, opts.Host, opts.Port) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "alice", DefaultPass) + expectResult(t, c, okRe) + sendProto(t, c, "PUB foo 2\r\nok\r\n") + expectResult(t, c, okRe) + sendProto(t, c, "SUB foo 1\r\n") + expectResult(t, c, okRe) + + // Check that we now reserve _SYS.> though for internal, so no clients. + sendProto(t, c, "PUB _SYS.HB 2\r\nok\r\n") + expectResult(t, c, permErrRe) + + // Check that _ is ok + sendProto(t, c, "PUB _ 2\r\nok\r\n") + expectResult(t, c, okRe) + + c.Close() + + // Bob is a requestor only, e.g. req.foo, req.bar for publish, subscribe only to INBOXes. + c = createClientConn(t, opts.Host, opts.Port) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "bob", DefaultPass) + expectResult(t, c, okRe) + + // These should error. + sendProto(t, c, "SUB foo 1\r\n") + expectResult(t, c, permErrRe) + sendProto(t, c, "PUB foo 2\r\nok\r\n") + expectResult(t, c, permErrRe) + + // These should work ok. + sendProto(t, c, "SUB _INBOX.abcd 1\r\n") + expectResult(t, c, okRe) + sendProto(t, c, "PUB req.foo 2\r\nok\r\n") + expectResult(t, c, okRe) + sendProto(t, c, "PUB req.bar 2\r\nok\r\n") + expectResult(t, c, okRe) + c.Close() + + // Joe is a default user + c = createClientConn(t, opts.Host, opts.Port) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "joe", DefaultPass) + expectResult(t, c, okRe) + + // These should error. + sendProto(t, c, "SUB foo.bar.* 1\r\n") + expectResult(t, c, permErrRe) + sendProto(t, c, "PUB foo.bar.baz 2\r\nok\r\n") + expectResult(t, c, permErrRe) + + // These should work ok. + sendProto(t, c, "SUB _INBOX.abcd 1\r\n") + expectResult(t, c, okRe) + sendProto(t, c, "SUB PUBLIC.abcd 1\r\n") + expectResult(t, c, okRe) + + sendProto(t, c, "PUB SANDBOX.foo 2\r\nok\r\n") + expectResult(t, c, okRe) + sendProto(t, c, "PUB SANDBOX.bar 2\r\nok\r\n") + expectResult(t, c, okRe) + + // Since only PWC, this should fail (too many tokens). + sendProto(t, c, "PUB SANDBOX.foo.bar 2\r\nok\r\n") + expectResult(t, c, permErrRe) + + c.Close() +} diff --git a/vendor/github.com/nats-io/gnatsd/test/verbose_test.go b/vendor/github.com/nats-io/gnatsd/test/verbose_test.go new file mode 100644 index 00000000..d370a2df --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/test/verbose_test.go @@ -0,0 +1,82 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "testing" +) + +func TestVerbosePing(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + doConnect(t, c, true, false, false) + + send := sendCommand(t, c) + expect := expectCommand(t, c) + + expect(okRe) + + // Ping should still be same + send("PING\r\n") + expect(pongRe) +} + +func TestVerboseConnect(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + doConnect(t, c, true, false, false) + + send := sendCommand(t, c) + expect := expectCommand(t, c) + + expect(okRe) + + // Connect + send("CONNECT {\"verbose\":true,\"pedantic\":true,\"tls_required\":false}\r\n") + expect(okRe) +} + +func TestVerbosePubSub(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "127.0.0.1", PROTO_TEST_PORT) + defer c.Close() + + doConnect(t, c, true, false, false) + send := sendCommand(t, c) + expect := expectCommand(t, c) + + expect(okRe) + + // Pub + send("PUB foo 2\r\nok\r\n") + expect(okRe) + + // Sub + send("SUB foo 1\r\n") + expect(okRe) + + // UnSub + send("UNSUB 1\r\n") + expect(okRe) +} diff --git a/vendor/github.com/nats-io/gnatsd/util/gnatsd.service b/vendor/github.com/nats-io/gnatsd/util/gnatsd.service new file mode 100644 index 00000000..259196f4 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/gnatsd.service @@ -0,0 +1,16 @@ +[Unit] +Description=NATS messaging server +After=network.target + +[Service] +PrivateTmp=true +Type=simple +ExecStart=/usr/sbin/gnatsd -c /etc/gnatsd.conf +ExecReload=/bin/kill -s HUP $MAINPID +ExecStop=/bin/kill -s SIGINT $MAINPID +User=gnatsd +Group=gnatsd + +[Install] +WantedBy=multi-user.target + diff --git a/vendor/github.com/nats-io/gnatsd/util/tls.go b/vendor/github.com/nats-io/gnatsd/util/tls.go new file mode 100644 index 00000000..87907eeb --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls.go @@ -0,0 +1,25 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. +func CloneTLSConfig(c *tls.Config) *tls.Config { + return c.Clone() +} diff --git a/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go b/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go new file mode 100644 index 00000000..99ea32b4 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls_pre17.go @@ -0,0 +1,47 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.5,!go1.7 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. +// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + } +} diff --git a/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go b/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go new file mode 100644 index 00000000..7df47261 --- /dev/null +++ b/vendor/github.com/nats-io/gnatsd/util/tls_pre18.go @@ -0,0 +1,49 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.7,!go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. +// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/nats-io/go-nats/GOVERNANCE.md b/vendor/github.com/nats-io/go-nats/GOVERNANCE.md new file mode 100644 index 00000000..1d5a7be3 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/GOVERNANCE.md @@ -0,0 +1,3 @@ +# NATS Go Client Governance + +NATS Go Client (go-nats) is part of the NATS project and is subject to the [NATS Governance](https://github.com/nats-io/nats-general/blob/master/GOVERNANCE.md). \ No newline at end of file diff --git a/vendor/github.com/nats-io/go-nats/LICENSE b/vendor/github.com/nats-io/go-nats/LICENSE index 9798d4ef..261eeb9e 100644 --- a/vendor/github.com/nats-io/go-nats/LICENSE +++ b/vendor/github.com/nats-io/go-nats/LICENSE @@ -1,20 +1,201 @@ -The MIT License (MIT) + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Copyright (c) 2012-2017 Apcera Inc. + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: + 1. Definitions. -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/nats-io/go-nats/MAINTAINERS.md b/vendor/github.com/nats-io/go-nats/MAINTAINERS.md new file mode 100644 index 00000000..323faa8e --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/MAINTAINERS.md @@ -0,0 +1,10 @@ +# Maintainers + +Maintainership is on a per project basis. + +### Core-maintainers + - Derek Collison [@derekcollison](https://github.com/derekcollison) + - Ivan Kozlovic [@kozlovic](https://github.com/kozlovic) + +### Maintainers + - Waldemar Quevedo [@wallyqs](https://github.com/wallyqs) \ No newline at end of file diff --git a/vendor/github.com/nats-io/go-nats/README.md b/vendor/github.com/nats-io/go-nats/README.md index ae6868c5..f109f3d2 100644 --- a/vendor/github.com/nats-io/go-nats/README.md +++ b/vendor/github.com/nats-io/go-nats/README.md @@ -1,7 +1,8 @@ # NATS - Go Client A [Go](http://golang.org) client for the [NATS messaging system](https://nats.io). -[![License MIT](https://img.shields.io/badge/License-MIT-blue.svg)](http://opensource.org/licenses/MIT) +[![License Apache 2](https://img.shields.io/badge/License-Apache2-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0) +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fnats-io%2Fgo-nats.svg?type=shield)](https://app.fossa.io/projects/git%2Bgithub.com%2Fnats-io%2Fgo-nats?ref=badge_shield) [![Go Report Card](https://goreportcard.com/badge/github.com/nats-io/go-nats)](https://goreportcard.com/report/github.com/nats-io/go-nats) [![Build Status](https://travis-ci.org/nats-io/go-nats.svg?branch=master)](http://travis-ci.org/nats-io/go-nats) [![GoDoc](https://godoc.org/github.com/nats-io/go-nats?status.svg)](http://godoc.org/github.com/nats-io/go-nats) [![Coverage Status](https://coveralls.io/repos/nats-io/go-nats/badge.svg?branch=master)](https://coveralls.io/r/nats-io/go-nats?branch=master) ## Installation @@ -327,24 +328,7 @@ err := c.RequestWithContext(ctx, "foo", req, resp) ## License -(The MIT License) +Unless otherwise noted, the NATS source files are distributed +under the Apache Version 2.0 license found in the LICENSE file. -Copyright (c) 2012-2017 Apcera Inc. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to -deal in the Software without restriction, including without limitation the -rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -sell copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS -IN THE SOFTWARE. +[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fnats-io%2Fgo-nats.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fnats-io%2Fgo-nats?ref=badge_large) diff --git a/vendor/github.com/nats-io/go-nats/context.go b/vendor/github.com/nats-io/go-nats/context.go index be6ada4a..4f9ec67d 100644 --- a/vendor/github.com/nats-io/go-nats/context.go +++ b/vendor/github.com/nats-io/go-nats/context.go @@ -1,4 +1,15 @@ -// Copyright 2012-2017 Apcera Inc. All rights reserved. +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. // +build go1.7 diff --git a/vendor/github.com/nats-io/go-nats/enc.go b/vendor/github.com/nats-io/go-nats/enc.go index 291b7826..d35b5b68 100644 --- a/vendor/github.com/nats-io/go-nats/enc.go +++ b/vendor/github.com/nats-io/go-nats/enc.go @@ -1,4 +1,15 @@ -// Copyright 2012-2015 Apcera Inc. All rights reserved. +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package nats @@ -196,9 +207,9 @@ func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscriptio } if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { if c.Conn.Opts.AsyncErrorCB != nil { - c.Conn.ach <- func() { + c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, errors.New("nats: Got an error trying to unmarshal: "+err.Error())) - } + }) } return } diff --git a/vendor/github.com/nats-io/go-nats/enc_test.go b/vendor/github.com/nats-io/go-nats/enc_test.go index ada5b024..ca170e8b 100644 --- a/vendor/github.com/nats-io/go-nats/enc_test.go +++ b/vendor/github.com/nats-io/go-nats/enc_test.go @@ -1,3 +1,16 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package nats_test import ( diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go new file mode 100644 index 00000000..46d918ee --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/default_enc.go @@ -0,0 +1,117 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin + +import ( + "bytes" + "fmt" + "reflect" + "strconv" + "unsafe" +) + +// DefaultEncoder implementation for EncodedConn. +// This encoder will leave []byte and string untouched, but will attempt to +// turn numbers into appropriate strings that can be decoded. It will also +// propely encoded and decode bools. If will encode a struct, but if you want +// to properly handle structures you should use JsonEncoder. +type DefaultEncoder struct { + // Empty +} + +var trueB = []byte("true") +var falseB = []byte("false") +var nilB = []byte("") + +// Encode +func (je *DefaultEncoder) Encode(subject string, v interface{}) ([]byte, error) { + switch arg := v.(type) { + case string: + bytes := *(*[]byte)(unsafe.Pointer(&arg)) + return bytes, nil + case []byte: + return arg, nil + case bool: + if arg { + return trueB, nil + } else { + return falseB, nil + } + case nil: + return nilB, nil + default: + var buf bytes.Buffer + fmt.Fprintf(&buf, "%+v", arg) + return buf.Bytes(), nil + } +} + +// Decode +func (je *DefaultEncoder) Decode(subject string, data []byte, vPtr interface{}) error { + // Figure out what it's pointing to... + sData := *(*string)(unsafe.Pointer(&data)) + switch arg := vPtr.(type) { + case *string: + *arg = sData + return nil + case *[]byte: + *arg = data + return nil + case *int: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int(n) + return nil + case *int32: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(sData, 10, 64) + if err != nil { + return err + } + *arg = int64(n) + return nil + case *float32: + n, err := strconv.ParseFloat(sData, 32) + if err != nil { + return err + } + *arg = float32(n) + return nil + case *float64: + n, err := strconv.ParseFloat(sData, 64) + if err != nil { + return err + } + *arg = float64(n) + return nil + case *bool: + b, err := strconv.ParseBool(sData) + if err != nil { + return err + } + *arg = b + return nil + default: + vt := reflect.TypeOf(arg).Elem() + return fmt.Errorf("nats: Default Encoder can't decode to type %s", vt) + } +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go new file mode 100644 index 00000000..632bcbd3 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/gob_enc.go @@ -0,0 +1,45 @@ +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin + +import ( + "bytes" + "encoding/gob" +) + +// GobEncoder is a Go specific GOB Encoder implementation for EncodedConn. +// This encoder will use the builtin encoding/gob to Marshal +// and Unmarshal most types, including structs. +type GobEncoder struct { + // Empty +} + +// FIXME(dlc) - This could probably be more efficient. + +// Encode +func (ge *GobEncoder) Encode(subject string, v interface{}) ([]byte, error) { + b := new(bytes.Buffer) + enc := gob.NewEncoder(b) + if err := enc.Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// Decode +func (ge *GobEncoder) Decode(subject string, data []byte, vPtr interface{}) (err error) { + dec := gob.NewDecoder(bytes.NewBuffer(data)) + err = dec.Decode(vPtr) + return +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go b/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go new file mode 100644 index 00000000..c9670f31 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/builtin/json_enc.go @@ -0,0 +1,56 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package builtin + +import ( + "encoding/json" + "strings" +) + +// JsonEncoder is a JSON Encoder implementation for EncodedConn. +// This encoder will use the builtin encoding/json to Marshal +// and Unmarshal most types, including structs. +type JsonEncoder struct { + // Empty +} + +// Encode +func (je *JsonEncoder) Encode(subject string, v interface{}) ([]byte, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + return b, nil +} + +// Decode +func (je *JsonEncoder) Decode(subject string, data []byte, vPtr interface{}) (err error) { + switch arg := vPtr.(type) { + case *string: + // If they want a string and it is a JSON string, strip quotes + // This allows someone to send a struct but receive as a plain string + // This cast should be efficient for Go 1.3 and beyond. + str := string(data) + if strings.HasPrefix(str, `"`) && strings.HasSuffix(str, `"`) { + *arg = str[1 : len(str)-1] + } else { + *arg = str + } + case *[]byte: + *arg = data + default: + err = json.Unmarshal(data, arg) + } + return +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/protobuf/protobuf_enc.go b/vendor/github.com/nats-io/go-nats/encoders/protobuf/protobuf_enc.go new file mode 100644 index 00000000..1a731056 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/protobuf/protobuf_enc.go @@ -0,0 +1,73 @@ +// Copyright 2015-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package protobuf + +import ( + "errors" + + "github.com/golang/protobuf/proto" + "github.com/nats-io/go-nats" +) + +// Additional index for registered Encoders. +const ( + PROTOBUF_ENCODER = "protobuf" +) + +func init() { + // Register protobuf encoder + nats.RegisterEncoder(PROTOBUF_ENCODER, &ProtobufEncoder{}) +} + +// ProtobufEncoder is a protobuf implementation for EncodedConn +// This encoder will use the builtin protobuf lib to Marshal +// and Unmarshal structs. +type ProtobufEncoder struct { + // Empty +} + +var ( + ErrInvalidProtoMsgEncode = errors.New("nats: Invalid protobuf proto.Message object passed to encode") + ErrInvalidProtoMsgDecode = errors.New("nats: Invalid protobuf proto.Message object passed to decode") +) + +// Encode +func (pb *ProtobufEncoder) Encode(subject string, v interface{}) ([]byte, error) { + if v == nil { + return nil, nil + } + i, found := v.(proto.Message) + if !found { + return nil, ErrInvalidProtoMsgEncode + } + + b, err := proto.Marshal(i) + if err != nil { + return nil, err + } + return b, nil +} + +// Decode +func (pb *ProtobufEncoder) Decode(subject string, data []byte, vPtr interface{}) error { + if _, ok := vPtr.(*interface{}); ok { + return nil + } + i, found := vPtr.(proto.Message) + if !found { + return ErrInvalidProtoMsgDecode + } + + return proto.Unmarshal(data, i) +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.pb.go b/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.pb.go new file mode 100644 index 00000000..718e5722 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.pb.go @@ -0,0 +1,40 @@ +// Code generated by protoc-gen-go. +// source: pbtest.proto +// DO NOT EDIT! + +/* +Package testdata is a generated protocol buffer package. + +It is generated from these files: + pbtest.proto + +It has these top-level messages: + Person +*/ +package testdata + +import proto "github.com/golang/protobuf/proto" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal + +type Person struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Age int32 `protobuf:"varint,2,opt,name=age" json:"age,omitempty"` + Address string `protobuf:"bytes,3,opt,name=address" json:"address,omitempty"` + Children map[string]*Person `protobuf:"bytes,10,rep,name=children" json:"children,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` +} + +func (m *Person) Reset() { *m = Person{} } +func (m *Person) String() string { return proto.CompactTextString(m) } +func (*Person) ProtoMessage() {} + +func (m *Person) GetChildren() map[string]*Person { + if m != nil { + return m.Children + } + return nil +} + +func init() { +} diff --git a/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.proto b/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.proto new file mode 100644 index 00000000..010f8081 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/encoders/protobuf/testdata/pbtest.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package testdata; + +message Person { + string name = 1; + int32 age = 2; + string address = 3; + + map children = 10; +} diff --git a/vendor/github.com/nats-io/go-nats/example_test.go b/vendor/github.com/nats-io/go-nats/example_test.go index 64a65867..f0101763 100644 --- a/vendor/github.com/nats-io/go-nats/example_test.go +++ b/vendor/github.com/nats-io/go-nats/example_test.go @@ -1,3 +1,16 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package nats_test import ( diff --git a/vendor/github.com/nats-io/go-nats/nats.go b/vendor/github.com/nats-io/go-nats/nats.go index fbb86c03..b2869a27 100644 --- a/vendor/github.com/nats-io/go-nats/nats.go +++ b/vendor/github.com/nats-io/go-nats/nats.go @@ -1,4 +1,15 @@ -// Copyright 2012-2017 Apcera Inc. All rights reserved. +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. // A Go client for the NATS messaging system (https://nats.io). package nats @@ -11,15 +22,16 @@ import ( "encoding/json" "errors" "fmt" + "io" "io/ioutil" "math/rand" "net" "net/url" - "regexp" "runtime" "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/nats-io/go-nats/util" @@ -28,7 +40,7 @@ import ( // Default Constants const ( - Version = "1.3.1" + Version = "1.6.0" DefaultURL = "nats://localhost:4222" DefaultPort = 4222 DefaultMaxReconnect = 60 @@ -76,6 +88,7 @@ var ( ErrInvalidMsg = errors.New("nats: invalid message or message nil") ErrInvalidArg = errors.New("nats: invalid argument") ErrInvalidContext = errors.New("nats: invalid context") + ErrNoEchoNotSupported = errors.New("nats: no echo option not supported by this server") ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) ) @@ -118,11 +131,27 @@ type ConnHandler func(*Conn) type ErrHandler func(*Conn, *Subscription, error) // asyncCB is used to preserve order for async callbacks. -type asyncCB func() +type asyncCB struct { + f func() + next *asyncCB +} + +type asyncCallbacksHandler struct { + mu sync.Mutex + cond *sync.Cond + head *asyncCB + tail *asyncCB +} // Option is a function on the options for a connection. type Option func(*Options) error +// CustomDialer can be used to specify any dialer, not necessarily +// a *net.Dialer. +type CustomDialer interface { + Dial(network, address string) (net.Conn, error) +} + // Options can be used to create a customized connection. type Options struct { @@ -139,6 +168,11 @@ type Options struct { // server pool. NoRandomize bool + // NoEcho configures whether the server will echo back messages + // that are sent on this connection if we also have matching subscriptions. + // Note this is supported on servers >= version 1.2. Proto 1 or greater. + NoEcho bool + // Name is an optional name label which will be sent to the server // on CONNECT to identify the client. Name string @@ -225,9 +259,14 @@ type Options struct { // Token sets the token to be used when connecting to a server. Token string - // Dialer allows a custom Dialer when forming connections. + // Dialer allows a custom net.Dialer when forming connections. + // DEPRECATED: should use CustomDialer instead. Dialer *net.Dialer + // CustomDialer allows to specify a custom dialer (not necessarily + // a *net.Dialer). + CustomDialer CustomDialer + // UseOldRequestStyle forces the old method of Requests that utilize // a new Inbox and a new Subscription for each request. UseOldRequestStyle bool @@ -246,9 +285,6 @@ const ( // Default server pool size srvPoolSize = 4 - // Channel size for the async callback handler. - asyncCBChanSize = 32 - // NUID size nuidSize = 22 ) @@ -261,9 +297,11 @@ type Conn struct { // atomic.* functions crash on 32bit machines if operand is not aligned // at 64bit. See https://github.com/golang/go/issues/599 Statistics - mu sync.Mutex + mu sync.Mutex + // Opts holds the configuration of the Conn. + // Modifying the configuration of a running Conn is a race. Opts Options - wg *sync.WaitGroup + wg sync.WaitGroup url *url.URL conn net.Conn srvPool []*srv @@ -275,7 +313,7 @@ type Conn struct { ssid int64 subsMu sync.RWMutex subs map[int64]*Subscription - ach chan asyncCB + ach *asyncCallbacksHandler pongs []chan struct{} scratch [scratchSize]byte status Status @@ -340,6 +378,12 @@ type Msg struct { Data []byte Sub *Subscription next *Msg + barrier *barrierInfo +} + +type barrierInfo struct { + refs int64 + f func() } // Tracks various stats received and sent on this connection, @@ -370,6 +414,7 @@ type serverInfo struct { TLSRequired bool `json:"tls_required"` MaxPayload int64 `json:"max_payload"` ConnectURLs []string `json:"connect_urls,omitempty"` + Proto int `json:"proto,omitempty"` } const ( @@ -392,6 +437,7 @@ type connectInfo struct { Lang string `json:"lang"` Version string `json:"version"` Protocol int `json:"protocol"` + Echo bool `json:"echo"` } // MsgHandler is a callback function that processes messages delivered to @@ -501,6 +547,15 @@ func DontRandomize() Option { } } +// NoEcho is an Option to turn off messages echoing back from a server. +// Note this is supported on servers >= version 1.2. Proto 1 or greater. +func NoEcho() Option { + return func(o *Options) error { + o.NoEcho = true + return nil + } +} + // ReconnectWait is an Option to set the wait time between reconnect attempts. func ReconnectWait(t time.Duration) Option { return func(o *Options) error { @@ -517,6 +572,22 @@ func MaxReconnects(max int) Option { } } +// PingInterval is an Option to set the period for client ping commands +func PingInterval(t time.Duration) Option { + return func(o *Options) error { + o.PingInterval = t + return nil + } +} + +// ReconnectBufSize sets the buffer size of messages kept while busy reconnecting +func ReconnectBufSize(size int) Option { + return func(o *Options) error { + o.ReconnectBufSize = size + return nil + } +} + // Timeout is an Option to set the timeout for Dial on a connection. func Timeout(t time.Duration) Option { return func(o *Options) error { @@ -557,7 +628,7 @@ func DiscoveredServersHandler(cb ConnHandler) Option { } } -// ErrHandler is an Option to set the async error handler. +// ErrorHandler is an Option to set the async error handler. func ErrorHandler(cb ErrHandler) Option { return func(o *Options) error { o.AsyncErrorCB = cb @@ -586,6 +657,7 @@ func Token(token string) Option { // Dialer is an Option to set the dialer which will be used when // attempting to establish a connection. +// DEPRECATED: Should use CustomDialer instead. func Dialer(dialer *net.Dialer) Option { return func(o *Options) error { o.Dialer = dialer @@ -593,7 +665,17 @@ func Dialer(dialer *net.Dialer) Option { } } -// UseOldRequestyStyle is an Option to force usage of the old Request style. +// SetCustomDialer is an Option to set a custom dialer which will be +// used when attempting to establish a connection. If both Dialer +// and CustomDialer are specified, CustomDialer takes precedence. +func SetCustomDialer(dialer CustomDialer) Option { + return func(o *Options) error { + o.CustomDialer = dialer + return nil + } +} + +// UseOldRequestStyle is an Option to force usage of the old Request style. func UseOldRequestStyle() Option { return func(o *Options) error { o.UseOldRequestStyle = true @@ -643,7 +725,7 @@ func (nc *Conn) SetClosedHandler(cb ConnHandler) { nc.Opts.ClosedCB = cb } -// SetErrHandler will set the async error handler. +// SetErrorHandler will set the async error handler. func (nc *Conn) SetErrorHandler(cb ErrHandler) { if nc == nil { return @@ -695,15 +777,16 @@ func (o Options) Connect() (*Conn, error) { return nil, err } - // Create the async callback channel. - nc.ach = make(chan asyncCB, asyncCBChanSize) + // Create the async callback handler. + nc.ach = &asyncCallbacksHandler{} + nc.ach.cond = sync.NewCond(&nc.ach.mu) if err := nc.connect(); err != nil { return nil, err } // Spin up the async cb dispatcher on success - go nc.asyncDispatch() + go nc.ach.asyncCBDispatcher() return nc, nil } @@ -877,7 +960,13 @@ func (nc *Conn) createConn() (err error) { cur.lastAttempt = time.Now() } - dialer := nc.Opts.Dialer + // CustomDialer takes precedence. If not set, use Opts.Dialer which + // is set to a default *net.Dialer (in Connect()) if not explicitly + // set by the user. + dialer := nc.Opts.CustomDialer + if dialer == nil { + dialer = nc.Opts.Dialer + } nc.conn, err = dialer.Dial("tcp", nc.url.Host) if err != nil { return err @@ -920,7 +1009,7 @@ func (nc *Conn) makeTLSConn() { // waitForExits will wait for all socket watcher Go routines to // be shutdown before proceeding. -func (nc *Conn) waitForExits(wg *sync.WaitGroup) { +func (nc *Conn) waitForExits() { // Kick old flusher forcefully. select { case nc.fch <- struct{}{}: @@ -928,38 +1017,7 @@ func (nc *Conn) waitForExits(wg *sync.WaitGroup) { } // Wait for any previous go routines. - if wg != nil { - wg.Wait() - } -} - -// spinUpGoRoutines will launch the Go routines responsible for -// reading and writing to the socket. This will be launched via a -// go routine itself to release any locks that may be held. -// We also use a WaitGroup to make sure we only start them on a -// reconnect when the previous ones have exited. -func (nc *Conn) spinUpGoRoutines() { - // Make sure everything has exited. - nc.waitForExits(nc.wg) - - // Create a new waitGroup instance for this run. - nc.wg = &sync.WaitGroup{} - // We will wait on both. - nc.wg.Add(2) - - // Spin up the readLoop and the socket flusher. - go nc.readLoop(nc.wg) - go nc.flusher(nc.wg) - - nc.mu.Lock() - if nc.Opts.PingInterval > 0 { - if nc.ptmr == nil { - nc.ptmr = time.AfterFunc(nc.Opts.PingInterval, nc.processPingTimer) - } else { - nc.ptmr.Reset(nc.Opts.PingInterval) - } - } - nc.mu.Unlock() + nc.wg.Wait() } // Report the connected server's Url @@ -1003,7 +1061,7 @@ func (nc *Conn) setup() { // Process a connected connection and initialize properly. func (nc *Conn) processConnectInit() error { - // Set out deadline for the whole connect process + // Set our deadline for the whole connect process nc.conn.SetDeadline(time.Now().Add(nc.Opts.Timeout)) defer nc.conn.SetDeadline(time.Time{}) @@ -1026,7 +1084,19 @@ func (nc *Conn) processConnectInit() error { // Reset the number of PING sent out nc.pout = 0 - go nc.spinUpGoRoutines() + // Start or reset Timer + if nc.Opts.PingInterval > 0 { + if nc.ptmr == nil { + nc.ptmr = time.AfterFunc(nc.Opts.PingInterval, nc.processPingTimer) + } else { + nc.ptmr.Reset(nc.Opts.PingInterval) + } + } + + // Start the readLoop and flusher go routines, we will wait on both on a reconnect event. + nc.wg.Add(2) + go nc.readLoop() + go nc.flusher() return nil } @@ -1066,7 +1136,7 @@ func (nc *Conn) connect() error { } else { // Cancel out default connection refused, will trigger the // No servers error conditional - if matched, _ := regexp.Match(`connection refused`, []byte(err.Error())); matched { + if strings.Contains(err.Error(), "connection refused") { returnedErr = nil } } @@ -1150,18 +1220,25 @@ func (nc *Conn) connectProto() (string, error) { pass, _ = u.Password() } } else { - // Take from options (pssibly all empty strings) + // Take from options (possibly all empty strings) user = nc.Opts.User pass = nc.Opts.Password token = nc.Opts.Token } - cinfo := connectInfo{o.Verbose, o.Pedantic, - user, pass, token, - o.Secure, o.Name, LangString, Version, clientProtoInfo} + + cinfo := connectInfo{o.Verbose, o.Pedantic, user, pass, token, + o.Secure, o.Name, LangString, Version, clientProtoInfo, !o.NoEcho} + b, err := json.Marshal(cinfo) if err != nil { return _EMPTY_, ErrJsonParse } + + // Check if NoEcho is set and we have a server that supports it. + if o.NoEcho && nc.info.Proto < 1 { + return _EMPTY_, ErrNoEchoNotSupported + } + return fmt.Sprintf(conProto, b), nil } @@ -1201,38 +1278,40 @@ func (nc *Conn) sendConnect() error { return err } - // Now read the response from the server. - br := bufio.NewReaderSize(nc.conn, defaultBufSize) - line, err := br.ReadString('\n') + // We don't want to read more than we need here, otherwise + // we would need to transfer the excess read data to the readLoop. + // Since in normal situations we just are looking for a PONG\r\n, + // reading byte-by-byte here is ok. + proto, err := nc.readProto() if err != nil { return err } // If opts.Verbose is set, handle +OK - if nc.Opts.Verbose && line == okProto { + if nc.Opts.Verbose && proto == okProto { // Read the rest now... - line, err = br.ReadString('\n') + proto, err = nc.readProto() if err != nil { return err } } // We expect a PONG - if line != pongProto { + if proto != pongProto { // But it could be something else, like -ERR // Since we no longer use ReadLine(), trim the trailing "\r\n" - line = strings.TrimRight(line, "\r\n") + proto = strings.TrimRight(proto, "\r\n") // If it's a server error... - if strings.HasPrefix(line, _ERR_OP_) { + if strings.HasPrefix(proto, _ERR_OP_) { // Remove -ERR, trim spaces and quotes, and convert to lower case. - line = normalizeErr(line) - return errors.New("nats: " + line) + proto = normalizeErr(proto) + return errors.New("nats: " + proto) } // Notify that we got an unexpected protocol. - return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, line) + return fmt.Errorf("nats: expected '%s', got '%s'", _PONG_OP_, proto) } // This is where we are truly connected. @@ -1241,6 +1320,29 @@ func (nc *Conn) sendConnect() error { return nil } +// reads a protocol one byte at a time. +func (nc *Conn) readProto() (string, error) { + var ( + _buf = [10]byte{} + buf = _buf[:0] + b = [1]byte{} + protoEnd = byte('\n') + ) + for { + if _, err := nc.conn.Read(b[:1]); err != nil { + // Do not report EOF error + if err == io.EOF { + return string(buf), nil + } + return "", err + } + buf = append(buf, b[0]) + if b[0] == protoEnd { + return string(buf), nil + } + } +} + // A control protocol line. type control struct { op, args string @@ -1281,15 +1383,20 @@ func (nc *Conn) flushReconnectPendingItems() { } } +// Stops the ping timer if set. +// Connection lock is held on entry. +func (nc *Conn) stopPingTimer() { + if nc.ptmr != nil { + nc.ptmr.Stop() + } +} + // Try to reconnect using the option parameters. // This function assumes we are allowed to reconnect. func (nc *Conn) doReconnect() { // We want to make sure we have the other watchers shutdown properly // here before we proceed past this point. - nc.mu.Lock() - wg := nc.wg - nc.mu.Unlock() - nc.waitForExits(wg) + nc.waitForExits() // FIXME(dlc) - We have an issue here if we have // outstanding flush points (pongs) and they were not @@ -1304,12 +1411,15 @@ func (nc *Conn) doReconnect() { // Clear any errors. nc.err = nil - // Perform appropriate callback if needed for a disconnect. if nc.Opts.DisconnectedCB != nil { - nc.ach <- func() { nc.Opts.DisconnectedCB(nc) } + nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) } + // This is used to wait on go routines exit if we start them in the loop + // but an error occurs after that. + waitForGoRoutines := false + for len(nc.srvPool) > 0 { cur, err := nc.selectNextServer() if err != nil { @@ -1337,6 +1447,11 @@ func (nc *Conn) doReconnect() { } else { time.Sleep(time.Duration(sleepTime)) } + // If the readLoop, etc.. go routines were started, wait for them to complete. + if waitForGoRoutines { + nc.waitForExits() + waitForGoRoutines = false + } nc.mu.Lock() // Check if we have been closed first. @@ -1363,6 +1478,9 @@ func (nc *Conn) doReconnect() { // Process connect logic if nc.err = nc.processConnectInit(); nc.err != nil { nc.status = RECONNECTING + // Reset the buffered writer to the pending buffer + // (was set to a buffered writer on nc.conn in createConn) + nc.bw.Reset(nc.pending) continue } @@ -1380,6 +1498,14 @@ func (nc *Conn) doReconnect() { nc.err = nc.bw.Flush() if nc.err != nil { nc.status = RECONNECTING + // Reset the buffered writer to the pending buffer (bytes.Buffer). + nc.bw.Reset(nc.pending) + // Stop the ping timer (if set) + nc.stopPingTimer() + // Since processConnectInit() returned without error, the + // go routines were started, so wait for them to return + // on the next iteration (after releasing the lock). + waitForGoRoutines = true continue } @@ -1391,9 +1517,8 @@ func (nc *Conn) doReconnect() { // Queue up the reconnect callback. if nc.Opts.ReconnectedCB != nil { - nc.ach <- func() { nc.Opts.ReconnectedCB(nc) } + nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) }) } - // Release lock here, we will return below. nc.mu.Unlock() @@ -1423,19 +1548,17 @@ func (nc *Conn) processOpErr(err error) { if nc.Opts.AllowReconnect && nc.status == CONNECTED { // Set our new status nc.status = RECONNECTING - if nc.ptmr != nil { - nc.ptmr.Stop() - } + // Stop ping timer if set + nc.stopPingTimer() if nc.conn != nil { nc.bw.Flush() nc.conn.Close() nc.conn = nil } - // Create a new pending buffer to underpin the bufio Writer while - // we are reconnecting. - nc.pending = &bytes.Buffer{} - nc.bw = bufio.NewWriterSize(nc.pending, nc.Opts.ReconnectBufSize) + // Create pending buffer before reconnecting. + nc.pending = new(bytes.Buffer) + nc.bw.Reset(nc.pending) go nc.doReconnect() nc.mu.Unlock() @@ -1448,41 +1571,73 @@ func (nc *Conn) processOpErr(err error) { nc.Close() } -// Marker to close the channel to kick out the Go routine. -func (nc *Conn) closeAsyncFunc() asyncCB { - return func() { - nc.mu.Lock() - if nc.ach != nil { - close(nc.ach) - nc.ach = nil +// dispatch is responsible for calling any async callbacks +func (ac *asyncCallbacksHandler) asyncCBDispatcher() { + for { + ac.mu.Lock() + // Protect for spurious wakeups. We should get out of the + // wait only if there is an element to pop from the list. + for ac.head == nil { + ac.cond.Wait() } - nc.mu.Unlock() + cur := ac.head + ac.head = cur.next + if cur == ac.tail { + ac.tail = nil + } + ac.mu.Unlock() + + // This signals that the dispatcher has been closed and all + // previous callbacks have been dispatched. + if cur.f == nil { + return + } + // Invoke callback outside of handler's lock + cur.f() } } -// asyncDispatch is responsible for calling any async callbacks -func (nc *Conn) asyncDispatch() { - // snapshot since they can change from underneath of us. - nc.mu.Lock() - ach := nc.ach - nc.mu.Unlock() +// Add the given function to the tail of the list and +// signals the dispatcher. +func (ac *asyncCallbacksHandler) push(f func()) { + ac.pushOrClose(f, false) +} - // Loop on the channel and process async callbacks. - for { - if f, ok := <-ach; !ok { - return - } else { - f() - } +// Signals that we are closing... +func (ac *asyncCallbacksHandler) close() { + ac.pushOrClose(nil, true) +} + +// Add the given function to the tail of the list and +// signals the dispatcher. +func (ac *asyncCallbacksHandler) pushOrClose(f func(), close bool) { + ac.mu.Lock() + defer ac.mu.Unlock() + // Make sure that library is not calling push with nil function, + // since this is used to notify the dispatcher that it should stop. + if !close && f == nil { + panic("pushing a nil callback") + } + cb := &asyncCB{f: f} + if ac.tail != nil { + ac.tail.next = cb + } else { + ac.head = cb + } + ac.tail = cb + if close { + ac.cond.Broadcast() + } else { + ac.cond.Signal() } } // readLoop() will sit on the socket reading and processing the // protocol from the server. It will dispatch appropriately based // on the op type. -func (nc *Conn) readLoop(wg *sync.WaitGroup) { +func (nc *Conn) readLoop() { // Release the wait group on exit - defer wg.Done() + defer nc.wg.Done() // Create a parseState if needed. nc.mu.Lock() @@ -1543,6 +1698,13 @@ func (nc *Conn) waitForMsgs(s *Subscription) { if s.pHead == nil { s.pTail = nil } + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + continue + } s.pMsgs-- s.pBytes -= len(m.Data) } @@ -1571,6 +1733,19 @@ func (nc *Conn) waitForMsgs(s *Subscription) { break } } + // Check for barrier messages + s.mu.Lock() + for m := s.pHead; m != nil; m = s.pHead { + if m.barrier != nil { + s.mu.Unlock() + if atomic.AddInt64(&m.barrier.refs, -1) == 0 { + m.barrier.f() + } + s.mu.Lock() + } + s.pHead = m.next + } + s.mu.Unlock() } // processMsg is called by parse and will place the msg on the @@ -1672,7 +1847,7 @@ slowConsumer: nc.mu.Lock() nc.err = ErrSlowConsumer if nc.Opts.AsyncErrorCB != nil { - nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) } + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) }) } nc.mu.Unlock() } @@ -1686,7 +1861,7 @@ func (nc *Conn) processPermissionsViolation(err string) { e := errors.New("nats: " + err) nc.err = e if nc.Opts.AsyncErrorCB != nil { - nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, nil, e) } + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, e) }) } nc.mu.Unlock() } @@ -1697,16 +1872,16 @@ func (nc *Conn) processAuthorizationViolation(err string) { nc.mu.Lock() nc.err = ErrAuthorization if nc.Opts.AsyncErrorCB != nil { - nc.ach <- func() { nc.Opts.AsyncErrorCB(nc, nil, ErrAuthorization) } + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, ErrAuthorization) }) } nc.mu.Unlock() } // flusher is a separate Go routine that will process flush requests for the write // bufio. This allows coalescing of writes to the underlying socket. -func (nc *Conn) flusher(wg *sync.WaitGroup) { +func (nc *Conn) flusher() { // Release the wait group - defer wg.Done() + defer nc.wg.Done() // snapshot the bw and conn since they can change from underneath of us. nc.mu.Lock() @@ -1784,32 +1959,67 @@ func (nc *Conn) processInfo(info string) error { if info == _EMPTY_ { return nil } - if err := json.Unmarshal([]byte(info), &nc.info); err != nil { + ncInfo := serverInfo{} + if err := json.Unmarshal([]byte(info), &ncInfo); err != nil { return err } + // Copy content into connection's info structure. + nc.info = ncInfo + // The array could be empty/not present on initial connect, + // if advertise is disabled on that server, or servers that + // did not include themselves in the async INFO protocol. + // If empty, do not remove the implicit servers from the pool. + if len(ncInfo.ConnectURLs) == 0 { + return nil + } + // Note about pool randomization: when the pool was first created, + // it was randomized (if allowed). We keep the order the same (removing + // implicit servers that are no longer sent to us). New URLs are sent + // to us in no specific order so don't need extra randomization. + hasNew := false + // This is what we got from the server we are connected to. urls := nc.info.ConnectURLs - if len(urls) > 0 { - added := false - // If randomization is allowed, shuffle the received array, not the - // entire pool. We want to preserve the pool's order up to this point - // (this would otherwise be problematic for the (re)connect loop). - if !nc.Opts.NoRandomize { - for i := range urls { - j := rand.Intn(i + 1) - urls[i], urls[j] = urls[j], urls[i] - } + // Transform that to a map for easy lookups + tmp := make(map[string]struct{}, len(urls)) + for _, curl := range urls { + tmp[curl] = struct{}{} + } + // Walk the pool and removed the implicit servers that are no longer in the + // given array/map + sp := nc.srvPool + for i := 0; i < len(sp); i++ { + srv := sp[i] + curl := srv.url.Host + // Check if this URL is in the INFO protocol + _, inInfo := tmp[curl] + // Remove from the temp map so that at the end we are left with only + // new (or restarted) servers that need to be added to the pool. + delete(tmp, curl) + // Keep servers that were set through Options, but also the one that + // we are currently connected to (even if it is a discovered server). + if !srv.isImplicit || srv.url == nc.url { + continue } - for _, curl := range urls { - if _, present := nc.urls[curl]; !present { - if err := nc.addURLToPool(fmt.Sprintf("nats://%s", curl), true); err != nil { - continue - } - added = true - } + if !inInfo { + // Remove from server pool. Keep current order. + copy(sp[i:], sp[i+1:]) + nc.srvPool = sp[:len(sp)-1] + sp = nc.srvPool + i-- } - if added && !nc.initc && nc.Opts.DiscoveredServersCB != nil { - nc.ach <- func() { nc.Opts.DiscoveredServersCB(nc) } + } + // If there are any left in the tmp map, these are new (or restarted) servers + // and need to be added to the pool. + for curl := range tmp { + // Before adding, check if this is a new (as in never seen) URL. + // This is used to figure out if we invoke the DiscoveredServersCB + if _, present := nc.urls[curl]; !present { + hasNew = true } + nc.addURLToPool(fmt.Sprintf("nats://%s", curl), true) + } + if hasNew && !nc.initc && nc.Opts.DiscoveredServersCB != nil { + nc.ach.push(func() { nc.Opts.DiscoveredServersCB(nc) }) } return nil } @@ -2802,9 +3012,9 @@ func (nc *Conn) close(status Status, doCBs bool) { // Clear any queued and blocking Requests. nc.clearPendingRequestCalls() - if nc.ptmr != nil { - nc.ptmr.Stop() - } + // Stop ping timer if set. + nc.stopPingTimer() + nc.ptmr = nil // Go ahead and make sure we have flushed the outbound if nc.conn != nil { @@ -2837,17 +3047,18 @@ func (nc *Conn) close(status Status, doCBs bool) { nc.subs = nil nc.subsMu.Unlock() + nc.status = status + // Perform appropriate callback if needed for a disconnect. if doCBs { if nc.Opts.DisconnectedCB != nil && nc.conn != nil { - nc.ach <- func() { nc.Opts.DisconnectedCB(nc) } + nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) } if nc.Opts.ClosedCB != nil { - nc.ach <- func() { nc.Opts.ClosedCB(nc) } + nc.ach.push(func() { nc.Opts.ClosedCB(nc) }) } - nc.ach <- nc.closeAsyncFunc() + nc.ach.close() } - nc.status = status nc.mu.Unlock() } @@ -2978,3 +3189,51 @@ func (nc *Conn) TLSRequired() bool { defer nc.mu.Unlock() return nc.info.TLSRequired } + +// Barrier schedules the given function `f` to all registered asynchronous +// subscriptions. +// Only the last subscription to see this barrier will invoke the function. +// If no subscription is registered at the time of this call, `f()` is invoked +// right away. +// ErrConnectionClosed is returned if the connection is closed prior to +// the call. +func (nc *Conn) Barrier(f func()) error { + nc.mu.Lock() + if nc.isClosed() { + nc.mu.Unlock() + return ErrConnectionClosed + } + nc.subsMu.Lock() + // Need to figure out how many non chan subscriptions there are + numSubs := 0 + for _, sub := range nc.subs { + if sub.typ == AsyncSubscription { + numSubs++ + } + } + if numSubs == 0 { + nc.subsMu.Unlock() + nc.mu.Unlock() + f() + return nil + } + barrier := &barrierInfo{refs: int64(numSubs), f: f} + for _, sub := range nc.subs { + sub.mu.Lock() + if sub.mch == nil { + msg := &Msg{barrier: barrier} + // Push onto the async pList + if sub.pTail != nil { + sub.pTail.next = msg + } else { + sub.pHead = msg + sub.pCond.Signal() + } + sub.pTail = msg + } + sub.mu.Unlock() + } + nc.subsMu.Unlock() + nc.mu.Unlock() + return nil +} diff --git a/vendor/github.com/nats-io/go-nats/nats_test.go b/vendor/github.com/nats-io/go-nats/nats_test.go index cbd95632..680d72f0 100644 --- a/vendor/github.com/nats-io/go-nats/nats_test.go +++ b/vendor/github.com/nats-io/go-nats/nats_test.go @@ -1,3 +1,16 @@ +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package nats //////////////////////////////////////////////////////////////////////////////// @@ -10,9 +23,12 @@ import ( "encoding/json" "errors" "fmt" + "os" "reflect" "runtime" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -40,7 +56,7 @@ func stackFatalf(t *testing.T, f string, args ...interface{}) { lines = append(lines, msg) // Generate the Stack of callers: Skip us and verify* frames. - for i := 2; true; i++ { + for i := 1; true; i++ { _, file, line, ok := runtime.Caller(i) if !ok { break @@ -51,6 +67,23 @@ func stackFatalf(t *testing.T, f string, args ...interface{}) { t.Fatalf("%s", strings.Join(lines, "\n")) } +func TestVersionMatchesTag(t *testing.T) { + tag := os.Getenv("TRAVIS_TAG") + if tag == "" { + t.SkipNow() + } + // We expect a tag of the form vX.Y.Z. If that's not the case, + // we need someone to have a look. So fail if first letter is not + // a `v` + if tag[0] != 'v' { + t.Fatalf("Expect tag to start with `v`, tag is: %s", tag) + } + // Strip the `v` from the tag for the version comparison. + if Version != tag[1:] { + t.Fatalf("Version (%s) does not match tag (%s)", Version, tag[1:]) + } +} + //////////////////////////////////////////////////////////////////////////////// // Reconnect tests //////////////////////////////////////////////////////////////////////////////// @@ -935,7 +968,7 @@ func TestAsyncINFO(t *testing.T) { } } - checkPool := func(inThatOrder bool, urls ...string) { + checkPool := func(urls ...string) { // Check both pool and urls map if len(c.srvPool) != len(urls) { stackFatalf(t, "Pool should have %d elements, has %d", len(urls), len(c.srvPool)) @@ -943,35 +976,27 @@ func TestAsyncINFO(t *testing.T) { if len(c.urls) != len(urls) { stackFatalf(t, "Map should have %d elements, has %d", len(urls), len(c.urls)) } - for i, url := range urls { - if inThatOrder { - if c.srvPool[i].url.Host != url { - stackFatalf(t, "Pool should have %q at index %q, has %q", url, i, c.srvPool[i].url.Host) - } - } else { - if _, present := c.urls[url]; !present { - stackFatalf(t, "Pool should have %q", url) - } + for _, url := range urls { + if _, present := c.urls[url]; !present { + stackFatalf(t, "Pool should have %q", url) } } } // Now test the decoding of "connect_urls" - // No randomize for now - c.Opts.NoRandomize = true // Reset the pool c.setupServerPool() // Reinitialize the parser c.ps = &parseState{} - info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n") + info = []byte("INFO {\"connect_urls\":[\"localhost:4222\", \"localhost:5222\"]}\r\n") err = c.parse(info) if err != nil || c.ps.state != OP_START { t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) } // Pool now should contain localhost:4222 (the default URL) and localhost:5222 - checkPool(true, "localhost:4222", "localhost:5222") + checkPool("localhost:4222", "localhost:5222") // Make sure that if client receives the same, it is not added again. err = c.parse(info) @@ -979,84 +1004,16 @@ func TestAsyncINFO(t *testing.T) { t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) } // Pool should still contain localhost:4222 (the default URL) and localhost:5222 - checkPool(true, "localhost:4222", "localhost:5222") + checkPool("localhost:4222", "localhost:5222") // Receive a new URL - info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n") + info = []byte("INFO {\"connect_urls\":[\"localhost:4222\", \"localhost:5222\", \"localhost:6222\"]}\r\n") err = c.parse(info) if err != nil || c.ps.state != OP_START { t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) } // Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222 - checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222") - - // Receive more than 1 URL at once - info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\"]}\r\n") - err = c.parse(info) - if err != nil || c.ps.state != OP_START { - t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) - } - // Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222 - // localhost:7222 and localhost:8222 - checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222") - - // Test with pool randomization now. Note that with randominzation, - // the initial pool is randomize, then each array of urls that the - // client gets from the INFO protocol is randomized, but added to - // the end of the pool. - c.Opts.NoRandomize = false - c.setupServerPool() - - info = []byte("INFO {\"connect_urls\":[\"localhost:5222\"]}\r\n") - err = c.parse(info) - if err != nil || c.ps.state != OP_START { - t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) - } - // Pool now should contain localhost:4222 (the default URL) and localhost:5222 - checkPool(true, "localhost:4222", "localhost:5222") - - // Make sure that if client receives the same, it is not added again. - err = c.parse(info) - if err != nil || c.ps.state != OP_START { - t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) - } - // Pool should still contain localhost:4222 (the default URL) and localhost:5222 - checkPool(true, "localhost:4222", "localhost:5222") - - // Receive a new URL - info = []byte("INFO {\"connect_urls\":[\"localhost:6222\"]}\r\n") - err = c.parse(info) - if err != nil || c.ps.state != OP_START { - t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) - } - // Pool now should contain localhost:4222 (the default URL) localhost:5222 and localhost:6222 - checkPool(true, "localhost:4222", "localhost:5222", "localhost:6222") - - // Receive more than 1 URL at once. Add more than 2 to increase the chance of - // the array being shuffled. - info = []byte("INFO {\"connect_urls\":[\"localhost:7222\", \"localhost:8222\", " + - "\"localhost:9222\", \"localhost:10222\", \"localhost:11222\"]}\r\n") - err = c.parse(info) - if err != nil || c.ps.state != OP_START { - t.Fatalf("Unexpected: %d : %v\n", c.ps.state, err) - } - // Pool now should contain localhost:4222 (the default URL) localhost:5222, localhost:6222 - // localhost:7222, localhost:8222, localhost:9222, localhost:10222 and localhost:11222 - checkPool(false, "localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222", - "localhost:9222", "localhost:10222", "localhost:11222") - - // Finally, check that (part of) the pool should be randomized. - allUrls := []string{"localhost:4222", "localhost:5222", "localhost:6222", "localhost:7222", "localhost:8222", - "localhost:9222", "localhost:10222", "localhost:11222"} - same := 0 - for i, url := range c.srvPool { - if url.url.Host == allUrls[i] { - same++ - } - } - if same == len(allUrls) { - t.Fatal("Pool does not seem to be randomized") - } + checkPool("localhost:4222", "localhost:5222", "localhost:6222") // Check that pool may be randomized on setup, but new URLs are always // added at end of pool. @@ -1147,31 +1104,121 @@ func TestConnServers(t *testing.T) { validateURLs(c.Servers(), "nats://localhost:4333", "nats://localhost:4444") } -func TestProcessErrAuthorizationError(t *testing.T) { - ach := make(chan asyncCB, 1) - called := make(chan error, 1) - c := &Conn{ - ach: ach, - Opts: Options{ - AsyncErrorCB: func(nc *Conn, sub *Subscription, err error) { - called <- err - }, - }, +func TestConnAsyncCBDeadlock(t *testing.T) { + s := RunServerOnPort(TEST_PORT) + defer s.Shutdown() + + ch := make(chan bool) + o := GetDefaultOptions() + o.Url = fmt.Sprintf("nats://127.0.0.1:%d", TEST_PORT) + o.ClosedCB = func(_ *Conn) { + ch <- true } - c.processErr("Authorization Violation") - select { - case cb := <-ach: - cb() - default: - t.Fatal("Expected callback on channel") + o.AsyncErrorCB = func(nc *Conn, sub *Subscription, err error) { + // do something with nc that requires locking behind the scenes + _ = nc.LastError() + } + nc, err := o.Connect() + if err != nil { + t.Fatalf("Should have connected ok: %v", err) } - select { - case err := <-called: - if err != ErrAuthorization { - t.Fatalf("Expected ErrAuthorization, got: %v", err) - } - default: - t.Fatal("Expected error on channel") + total := 300 + wg := &sync.WaitGroup{} + wg.Add(total) + for i := 0; i < total; i++ { + go func() { + // overwhelm asyncCB with errors + nc.processErr(AUTHORIZATION_ERR) + wg.Done() + }() + } + wg.Wait() + + nc.Close() + if e := Wait(ch); e != nil { + t.Fatal("Deadlock") + } +} + +func TestPingTimerLeakedOnClose(t *testing.T) { + s := RunServerOnPort(TEST_PORT) + defer s.Shutdown() + + nc, err := Connect(fmt.Sprintf("nats://127.0.0.1:%d", TEST_PORT)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + nc.Close() + // There was a bug (issue #338) that if connection + // was created and closed quickly, the pinger would + // be created from a go-routine and would cause the + // connection object to be retained until the ping + // timer fired. + // Wait a little bit and check if the timer is set. + // With the defect it would be. + time.Sleep(100 * time.Millisecond) + nc.mu.Lock() + pingTimerSet := nc.ptmr != nil + nc.mu.Unlock() + if pingTimerSet { + t.Fatal("Pinger timer should not be set") + } +} + +func TestNoEcho(t *testing.T) { + s := RunServerOnPort(TEST_PORT) + defer s.Shutdown() + + url := fmt.Sprintf("nats://127.0.0.1:%d", TEST_PORT) + + nc, err := Connect(url, NoEcho()) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + r := int32(0) + _, err = nc.Subscribe("foo", func(m *Msg) { + atomic.AddInt32(&r, 1) + }) + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + err = nc.Publish("foo", []byte("Hello World")) + if err != nil { + t.Fatalf("Error on publish: %v", err) + } + nc.Flush() + nc.Flush() + + if nr := atomic.LoadInt32(&r); nr != 0 { + t.Fatalf("Expected no messages echoed back, received %d\n", nr) + } +} + +func TestNoEchoOldServer(t *testing.T) { + opts := GetDefaultOptions() + opts.Url = DefaultURL + opts.NoEcho = true + + nc := &Conn{Opts: opts} + if err := nc.setupServerPool(); err != nil { + t.Fatalf("Problem setting up Server Pool: %v\n", err) + } + + // Old style with no proto, meaning 0. We need Proto:1 for NoEcho support. + oldInfo := "{\"server_id\":\"22\",\"version\":\"1.1.0\",\"go\":\"go1.10.2\",\"port\":4222,\"max_payload\":1048576}" + + err := nc.processInfo(oldInfo) + if err != nil { + t.Fatalf("Error processing old style INFO: %v\n", err) + } + + // Make sure connectProto generates an error. + _, err = nc.connectProto() + if err == nil { + t.Fatalf("Expected an error but got none\n") } } diff --git a/vendor/github.com/nats-io/go-nats/netchan.go b/vendor/github.com/nats-io/go-nats/netchan.go index 0608fd7a..add3cba5 100644 --- a/vendor/github.com/nats-io/go-nats/netchan.go +++ b/vendor/github.com/nats-io/go-nats/netchan.go @@ -1,4 +1,15 @@ -// Copyright 2013-2017 Apcera Inc. All rights reserved. +// Copyright 2013-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package nats @@ -41,7 +52,7 @@ func chPublish(c *EncodedConn, chVal reflect.Value, subject string) { if c.Conn.isClosed() { go c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) } else { - c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) } + c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) }) } } return @@ -77,7 +88,7 @@ func (c *EncodedConn) bindRecvChan(subject, queue string, channel interface{}) ( if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil { c.Conn.err = errors.New("nats: Got an error trying to unmarshal: " + err.Error()) if c.Conn.Opts.AsyncErrorCB != nil { - c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) } + c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) }) } return } diff --git a/vendor/github.com/nats-io/go-nats/parser.go b/vendor/github.com/nats-io/go-nats/parser.go index 8359b8bc..a4b3ea0e 100644 --- a/vendor/github.com/nats-io/go-nats/parser.go +++ b/vendor/github.com/nats-io/go-nats/parser.go @@ -1,4 +1,15 @@ -// Copyright 2012-2017 Apcera Inc. All rights reserved. +// Copyright 2012-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package nats diff --git a/vendor/github.com/nats-io/go-nats/timer.go b/vendor/github.com/nats-io/go-nats/timer.go index 1b96fd52..1216762d 100644 --- a/vendor/github.com/nats-io/go-nats/timer.go +++ b/vendor/github.com/nats-io/go-nats/timer.go @@ -1,3 +1,16 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package nats import ( diff --git a/vendor/github.com/nats-io/go-nats/timer_test.go b/vendor/github.com/nats-io/go-nats/timer_test.go index fb02a769..d561f967 100644 --- a/vendor/github.com/nats-io/go-nats/timer_test.go +++ b/vendor/github.com/nats-io/go-nats/timer_test.go @@ -1,3 +1,16 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package nats import ( diff --git a/vendor/github.com/nats-io/go-nats/util/tls.go b/vendor/github.com/nats-io/go-nats/util/tls.go new file mode 100644 index 00000000..53ff9aa2 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/util/tls.go @@ -0,0 +1,27 @@ +// Copyright 2017-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.8 + +package util + +import "crypto/tls" + +// CloneTLSConfig returns a copy of c. +func CloneTLSConfig(c *tls.Config) *tls.Config { + if c == nil { + return &tls.Config{} + } + + return c.Clone() +} diff --git a/vendor/github.com/nats-io/go-nats/util/tls_go17.go b/vendor/github.com/nats-io/go-nats/util/tls_go17.go new file mode 100644 index 00000000..fd646d31 --- /dev/null +++ b/vendor/github.com/nats-io/go-nats/util/tls_go17.go @@ -0,0 +1,49 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.7,!go1.8 + +package util + +import ( + "crypto/tls" +) + +// CloneTLSConfig returns a copy of c. Only the exported fields are copied. +// This is temporary, until this is provided by the language. +// https://go-review.googlesource.com/#/c/28075/ +func CloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/nats-io/nuid/GOVERNANCE.md b/vendor/github.com/nats-io/nuid/GOVERNANCE.md new file mode 100644 index 00000000..01aee70d --- /dev/null +++ b/vendor/github.com/nats-io/nuid/GOVERNANCE.md @@ -0,0 +1,3 @@ +# NATS NUID Governance + +NATS NUID is part of the NATS project and is subject to the [NATS Governance](https://github.com/nats-io/nats-general/blob/master/GOVERNANCE.md). \ No newline at end of file diff --git a/vendor/github.com/nats-io/nuid/LICENSE b/vendor/github.com/nats-io/nuid/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/vendor/github.com/nats-io/nuid/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/nats-io/nuid/MAINTAINERS.md b/vendor/github.com/nats-io/nuid/MAINTAINERS.md new file mode 100644 index 00000000..6d0ed3e3 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/MAINTAINERS.md @@ -0,0 +1,6 @@ +# Maintainers + +Maintainership is on a per project basis. + +### Core-maintainers + - Derek Collison [@derekcollison](https://github.com/derekcollison) \ No newline at end of file diff --git a/vendor/github.com/nats-io/nuid/README.md b/vendor/github.com/nats-io/nuid/README.md new file mode 100644 index 00000000..16d8a735 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/README.md @@ -0,0 +1,47 @@ +# NUID + +[![License Apache 2](https://img.shields.io/badge/License-Apache2-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0) +[![ReportCard](http://goreportcard.com/badge/nats-io/nuid)](http://goreportcard.com/report/nats-io/nuid) +[![Build Status](https://travis-ci.org/nats-io/nuid.svg?branch=master)](http://travis-ci.org/nats-io/nuid) +[![Release](https://img.shields.io/badge/release-v1.0.0-1eb0fc.svg)](https://github.com/nats-io/nuid/releases/tag/v1.0.0) +[![GoDoc](http://godoc.org/github.com/nats-io/nuid?status.png)](http://godoc.org/github.com/nats-io/nuid) +[![Coverage Status](https://coveralls.io/repos/github/nats-io/nuid/badge.svg?branch=master)](https://coveralls.io/github/nats-io/nuid?branch=master) + +A highly performant unique identifier generator. + +## Installation + +Use the `go` command: + + $ go get github.com/nats-io/nuid + +## Basic Usage +```go + +// Utilize the global locked instance +nuid := nuid.Next() + +// Create an instance, these are not locked. +n := nuid.New() +nuid = n.Next() + +// Generate a new crypto/rand seeded prefix. +// Generally not needed, happens automatically. +n.RandomizePrefix() +``` + +## Performance +NUID needs to be very fast to generate and be truly unique, all while being entropy pool friendly. +NUID uses 12 bytes of crypto generated data (entropy draining), and 10 bytes of pseudo-random +sequential data that increments with a pseudo-random increment. + +Total length of a NUID string is 22 bytes of base 62 ascii text, so 62^22 or +2707803647802660400290261537185326956544 possibilities. + +NUID can generate identifiers as fast as 60ns, or ~16 million per second. There is an associated +benchmark you can use to test performance on your own hardware. + +## License + +Unless otherwise noted, the NATS source files are distributed +under the Apache Version 2.0 license found in the LICENSE file. diff --git a/vendor/github.com/nats-io/nuid/nuid.go b/vendor/github.com/nats-io/nuid/nuid.go new file mode 100644 index 00000000..d79e9ce1 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/nuid.go @@ -0,0 +1,135 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A unique identifier generator that is high performance, very fast, and tries to be entropy pool friendly. +package nuid + +import ( + "crypto/rand" + "fmt" + "math" + "math/big" + "sync" + "time" + + prand "math/rand" +) + +// NUID needs to be very fast to generate and truly unique, all while being entropy pool friendly. +// We will use 12 bytes of crypto generated data (entropy draining), and 10 bytes of sequential data +// that is started at a pseudo random number and increments with a pseudo-random increment. +// Total is 22 bytes of base 62 ascii text :) + +// Version of the library +const Version = "1.0.0" + +const ( + digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + base = 62 + preLen = 12 + seqLen = 10 + maxSeq = int64(839299365868340224) // base^seqLen == 62^10 + minInc = int64(33) + maxInc = int64(333) + totalLen = preLen + seqLen +) + +type NUID struct { + pre []byte + seq int64 + inc int64 +} + +type lockedNUID struct { + sync.Mutex + *NUID +} + +// Global NUID +var globalNUID *lockedNUID + +// Seed sequential random with crypto or math/random and current time +// and generate crypto prefix. +func init() { + r, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) + if err != nil { + prand.Seed(time.Now().UnixNano()) + } else { + prand.Seed(r.Int64()) + } + globalNUID = &lockedNUID{NUID: New()} + globalNUID.RandomizePrefix() +} + +// New will generate a new NUID and properly initialize the prefix, sequential start, and sequential increment. +func New() *NUID { + n := &NUID{ + seq: prand.Int63n(maxSeq), + inc: minInc + prand.Int63n(maxInc-minInc), + pre: make([]byte, preLen), + } + n.RandomizePrefix() + return n +} + +// Generate the next NUID string from the global locked NUID instance. +func Next() string { + globalNUID.Lock() + nuid := globalNUID.Next() + globalNUID.Unlock() + return nuid +} + +// Generate the next NUID string. +func (n *NUID) Next() string { + // Increment and capture. + n.seq += n.inc + if n.seq >= maxSeq { + n.RandomizePrefix() + n.resetSequential() + } + seq := n.seq + + // Copy prefix + var b [totalLen]byte + bs := b[:preLen] + copy(bs, n.pre) + + // copy in the seq in base36. + for i, l := len(b), seq; i > preLen; l /= base { + i -= 1 + b[i] = digits[l%base] + } + return string(b[:]) +} + +// Resets the sequential portion of the NUID. +func (n *NUID) resetSequential() { + n.seq = prand.Int63n(maxSeq) + n.inc = minInc + prand.Int63n(maxInc-minInc) +} + +// Generate a new prefix from crypto/rand. +// This call *can* drain entropy and will be called automatically when we exhaust the sequential range. +// Will panic if it gets an error from rand.Int() +func (n *NUID) RandomizePrefix() { + var cb [preLen]byte + cbs := cb[:] + if nb, err := rand.Read(cbs); nb != preLen || err != nil { + panic(fmt.Sprintf("nuid: failed generating crypto random number: %v\n", err)) + } + + for i := 0; i < preLen; i++ { + n.pre[i] = digits[int(cbs[i])%base] + } +} diff --git a/vendor/github.com/nats-io/nuid/nuid_test.go b/vendor/github.com/nats-io/nuid/nuid_test.go new file mode 100644 index 00000000..671a5539 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/nuid_test.go @@ -0,0 +1,92 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nuid + +import ( + "bytes" + "testing" +) + +func TestDigits(t *testing.T) { + if len(digits) != base { + t.Fatalf("digits length does not match base modulo") + } +} + +func TestGlobalNUIDInit(t *testing.T) { + if globalNUID == nil { + t.Fatalf("Expected g to be non-nil\n") + } + if globalNUID.pre == nil || len(globalNUID.pre) != preLen { + t.Fatalf("Expected prefix to be initialized\n") + } + if globalNUID.seq == 0 { + t.Fatalf("Expected seq to be non-zero\n") + } +} + +func TestNUIDRollover(t *testing.T) { + globalNUID.seq = maxSeq + // copy + oldPre := append([]byte{}, globalNUID.pre...) + Next() + if bytes.Equal(globalNUID.pre, oldPre) { + t.Fatalf("Expected new pre, got the old one\n") + } +} + +func TestGUIDLen(t *testing.T) { + nuid := Next() + if len(nuid) != totalLen { + t.Fatalf("Expected len of %d, got %d\n", totalLen, len(nuid)) + } +} + +func TestProperPrefix(t *testing.T) { + min := byte(255) + max := byte(0) + for i := 0; i < len(digits); i++ { + if digits[i] < min { + min = digits[i] + } + if digits[i] > max { + max = digits[i] + } + } + total := 100000 + for i := 0; i < total; i++ { + n := New() + for j := 0; j < preLen; j++ { + if n.pre[j] < min || n.pre[j] > max { + t.Fatalf("Iter %d. Valid range for bytes prefix: [%d..%d]\nIncorrect prefix at pos %d: %v (%s)", + i, min, max, j, n.pre, string(n.pre)) + } + } + } +} + +func BenchmarkNUIDSpeed(b *testing.B) { + n := New() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + n.Next() + } +} + +func BenchmarkGlobalNUIDSpeed(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Next() + } +} diff --git a/vendor/github.com/nats-io/nuid/unique_test.go b/vendor/github.com/nats-io/nuid/unique_test.go new file mode 100644 index 00000000..df979015 --- /dev/null +++ b/vendor/github.com/nats-io/nuid/unique_test.go @@ -0,0 +1,32 @@ +// Copyright 2016-2018 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package nuid + +import ( + "testing" +) + +func TestBasicUniqueness(t *testing.T) { + n := 10000000 + m := make(map[string]struct{}, n) + for i := 0; i < n; i++ { + n := Next() + if _, ok := m[n]; ok { + t.Fatalf("Duplicate NUID found: %v\n", n) + } + m[n] = struct{}{} + } +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 84f657a3..a84f8d2c 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -3,10 +3,76 @@ "ignore": "", "package": [ { - "checksumSHA1": "nWIa0L7ux21Cb8kzB4rJHXMblpI=", + "checksumSHA1": "6i5WXEn6aAYiWSNCH0TsCyAhgv4=", + "path": "github.com/nats-io/gnatsd/conf", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "+BI6v9+uBu6jwqNPS0PbPvepW2Y=", + "path": "github.com/nats-io/gnatsd/logger", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "mklbylVp2WB0G8kkDmBaRFBqLqs=", + "path": "github.com/nats-io/gnatsd/server", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "P5raehO2Vgq7FKicqSdeRYoUJlo=", + "path": "github.com/nats-io/gnatsd/server/pse", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "7q/nF5pQYkUKwXkXM1H4yO7XgHU=", + "path": "github.com/nats-io/gnatsd/test", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "GfQ4LMqxsTdxY+4VjBPzejbMI6Q=", + "path": "github.com/nats-io/gnatsd/util", + "revision": "6608e9ac3be979dcb0614b772cc86a87b71acaa3", + "revisionTime": "2018-07-05T16:41:09Z" + }, + { + "checksumSHA1": "RNbZJ1hCzyADhvkfy0TwbJ6d3MM=", "path": "github.com/nats-io/go-nats", - "revision": "f0d9c5988d4c2a17ad466fcdffe010165c46434e", - "revisionTime": "2017-11-14T23:23:38Z" + "revision": "ff578ff05d4180a259da6d729af7b55285d24b78", + "revisionTime": "2018-07-25T00:55:40Z" + }, + { + "checksumSHA1": "C/TsNyFOKX2bTW3U9O3m5mQhJo8=", + "path": "github.com/nats-io/go-nats/encoders/builtin", + "revision": "ff578ff05d4180a259da6d729af7b55285d24b78", + "revisionTime": "2018-07-25T00:55:40Z" + }, + { + "checksumSHA1": "pG50W6cz2QHdf4/oW8sQQzPF7Wo=", + "path": "github.com/nats-io/go-nats/encoders/protobuf", + "revision": "ff578ff05d4180a259da6d729af7b55285d24b78", + "revisionTime": "2018-07-25T00:55:40Z" + }, + { + "checksumSHA1": "wXnHPa+satCH17fbk9XxK7MrxXw=", + "path": "github.com/nats-io/go-nats/encoders/protobuf/testdata", + "revision": "ff578ff05d4180a259da6d729af7b55285d24b78", + "revisionTime": "2018-07-25T00:55:40Z" + }, + { + "checksumSHA1": "mBZQa+wv/u+0qA1ceh6MrWBEgxs=", + "path": "github.com/nats-io/go-nats/util", + "revision": "ff578ff05d4180a259da6d729af7b55285d24b78", + "revisionTime": "2018-07-25T00:55:40Z" + }, + { + "checksumSHA1": "ufhg4R5aa96zJxjSy0JfliML97Y=", + "path": "github.com/nats-io/nuid", + "revision": "3024a71c3cbe30667286099921591e6fcc328230", + "revisionTime": "2018-07-12T04:49:59Z" } ], "rootPath": "github.com/tidwall/tile38"