Allow for standard SQS URLs

Both now work:

https://sqs.us-east-1.amazonaws.com/349840735605/TestTile38Queue
sqs://us-east-1:349840735605/TestTile38Queue
This commit is contained in:
tidwall 2019-03-13 15:41:49 -07:00
parent ec57aaee1a
commit 5335aec942
2 changed files with 58 additions and 22 deletions

View File

@ -90,6 +90,7 @@ type Endpoint struct {
KeyFile string KeyFile string
} }
SQS struct { SQS struct {
PlainURL string
QueueID string QueueID string
Region string Region string
CredPath string CredPath string
@ -217,7 +218,12 @@ func parseEndpoint(s string) (Endpoint, error) {
case strings.HasPrefix(s, "http:"): case strings.HasPrefix(s, "http:"):
endpoint.Protocol = HTTP endpoint.Protocol = HTTP
case strings.HasPrefix(s, "https:"): case strings.HasPrefix(s, "https:"):
endpoint.Protocol = HTTP if probeSQS(s) {
endpoint.SQS.PlainURL = s
endpoint.Protocol = SQS
} else {
endpoint.Protocol = HTTP
}
case strings.HasPrefix(s, "disque:"): case strings.HasPrefix(s, "disque:"):
endpoint.Protocol = Disque endpoint.Protocol = Disque
case strings.HasPrefix(s, "grpc:"): case strings.HasPrefix(s, "grpc:"):
@ -469,22 +475,28 @@ func parseEndpoint(s string) (Endpoint, error) {
// credpath - path where aws credentials are located // credpath - path where aws credentials are located
// credprofile - credential profile // credprofile - credential profile
if endpoint.Protocol == SQS { if endpoint.Protocol == SQS {
// Parsing connection from URL string if endpoint.SQS.PlainURL == "" {
hp := strings.Split(s, ":") // Parsing connection from URL string
switch len(hp) { hp := strings.Split(s, ":")
default: switch len(hp) {
return endpoint, errors.New("invalid SQS url") default:
case 2: return endpoint, errors.New("invalid SQS url")
endpoint.SQS.Region = hp[0] case 2:
endpoint.SQS.QueueID = hp[1] endpoint.SQS.Region = hp[0]
} endpoint.SQS.QueueID = hp[1]
}
// Parsing SQS queue name // Parsing SQS queue name
if len(sp) > 1 { if len(sp) > 1 {
var err error var err error
endpoint.SQS.QueueName, err = url.QueryUnescape(sp[1]) endpoint.SQS.QueueName, err = url.QueryUnescape(sp[1])
if err != nil { if err != nil {
return endpoint, errors.New("invalid SQS queue name") return endpoint, errors.New("invalid SQS queue name")
}
}
// Throw error if we not provide any queue name
if endpoint.SQS.QueueName == "" {
return endpoint, errors.New("missing SQS queue name")
} }
} }
@ -512,10 +524,6 @@ func parseEndpoint(s string) (Endpoint, error) {
} }
} }
} }
// Throw error if we not provide any queue name
if endpoint.SQS.QueueName == "" {
return endpoint, errors.New("missing SQS queue name")
}
} }
// Basic AMQP connection strings in HOOKS interface // Basic AMQP connection strings in HOOKS interface

View File

@ -3,6 +3,7 @@ package endpoint
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
"sync" "sync"
"time" "time"
@ -31,7 +32,11 @@ type SQSConn struct {
} }
func (conn *SQSConn) generateSQSURL() string { func (conn *SQSConn) generateSQSURL() string {
return "https://sqs." + conn.ep.SQS.Region + "amazonaws.com/" + conn.ep.SQS.QueueID + "/" + conn.ep.SQS.QueueName if conn.ep.SQS.PlainURL != "" {
return conn.ep.SQS.PlainURL
}
return "https://sqs." + conn.ep.SQS.Region + ".amazonaws.com/" +
conn.ep.SQS.QueueID + "/" + conn.ep.SQS.QueueName
} }
// Expired returns true if the connection has expired // Expired returns true if the connection has expired
@ -74,8 +79,14 @@ func (conn *SQSConn) Send(msg string) error {
} }
creds = credentials.NewSharedCredentials(credPath, credProfile) creds = credentials.NewSharedCredentials(credPath, credProfile)
} }
var region string
if conn.ep.SQS.Region != "" {
region = conn.ep.SQS.Region
} else {
region = sqsRegionFromPlainURL(conn.ep.SQS.PlainURL)
}
sess := session.Must(session.NewSession(&aws.Config{ sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(conn.ep.SQS.Region), Region: &region,
Credentials: creds, Credentials: creds,
MaxRetries: aws.Int(5), MaxRetries: aws.Int(5),
})) }))
@ -114,3 +125,20 @@ func newSQSConn(ep Endpoint) *SQSConn {
t: time.Now(), t: time.Now(),
} }
} }
func probeSQS(s string) bool {
// https://sqs.eu-central-1.amazonaws.com/123456789/myqueue
return strings.HasPrefix(s, "https://sqs.") &&
strings.Contains(s, ".amazonaws.com/")
}
func sqsRegionFromPlainURL(s string) string {
parts := strings.Split(s, "https://sqs.")
if len(parts) > 1 {
parts = strings.Split(parts[1], ".amazonaws.com/")
if len(parts) > 1 {
return parts[0]
}
}
return ""
}