Update aws-sdk-go package

This commit is contained in:
tidwall 2019-03-11 09:18:55 -07:00
parent e05b3dc25c
commit 3e64497a51
2170 changed files with 991682 additions and 225284 deletions

24
Gopkg.lock generated
View File

@ -10,7 +10,7 @@
version = "v1.13.0"
[[projects]]
digest = "1:11c577e4f8478a8b397adb7976f1f201b31951a5473d7e794c6423a1224dab80"
digest = "1:718dd81b0dcfc49323360bcc7b4f940c9d5f27106d17eddb25a714e9b2563d24"
name = "github.com/aws/aws-sdk-go"
packages = [
"aws",
@ -22,13 +22,19 @@
"aws/credentials",
"aws/credentials/ec2rolecreds",
"aws/credentials/endpointcreds",
"aws/credentials/processcreds",
"aws/credentials/stscreds",
"aws/csm",
"aws/defaults",
"aws/ec2metadata",
"aws/endpoints",
"aws/request",
"aws/session",
"aws/signer/v4",
"internal/ini",
"internal/sdkio",
"internal/sdkrand",
"internal/sdkuri",
"internal/shareddefaults",
"private/protocol",
"private/protocol/query",
@ -39,8 +45,8 @@
"service/sts",
]
pruneopts = ""
revision = "da415b5fa0ff3f91d4707348a8ea1be53f700c22"
version = "v1.12.6"
revision = "8b705a6dec722bcda3a9309c0924d4eca24f7c72"
version = "v1.17.14"
[[projects]]
digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b"
@ -85,14 +91,6 @@
revision = "aff15770515e3c57fc6109da73d42b0d46f7f483"
version = "v1.1.0"
[[projects]]
digest = "1:e26d0f8ccaf4087b93306d5e20e4816258116868ad6c6da0f2b4926c72fb92fa"
name = "github.com/go-ini/ini"
packages = ["."]
pruneopts = ""
revision = "20b96f641a5ea98f2f8619ff4f3e061cff4833bd"
version = "v1.28.2"
[[projects]]
branch = "master"
digest = "1:27854310d59099f8dcc61dd8af4a69f0a3597f001154b2fb4d1c41baf2e31ec1"
@ -124,11 +122,11 @@
revision = "e8fc0692a7e26a05b06517348ed466349062eb47"
[[projects]]
digest = "1:6f49eae0c1e5dab1dafafee34b207aeb7a42303105960944828c2079b92fc88e"
digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40"
name = "github.com/jmespath/go-jmespath"
packages = ["."]
pruneopts = ""
revision = "0b12d6b5"
revision = "c2b33e84"
[[projects]]
digest = "1:be1d3b7623f11b628068933282015f3dbcd522fa8e6b16d2931edffd42ef2c0b"

View File

@ -10,5 +10,5 @@ Please fill out the sections below to help us address your issue.
### Steps to reproduce
If you have have an runnable example, please include it.
If you have an runnable example, please include it.

View File

@ -3,7 +3,7 @@
"Pattern": "/sdk-for-go/api/",
"StripPrefix": "/sdk-for-go/api",
"Include": ["/src/github.com/aws/aws-sdk-go/aws", "/src/github.com/aws/aws-sdk-go/service"],
"Exclude": ["/src/cmd", "/src/github.com/aws/aws-sdk-go/awstesting", "/src/github.com/aws/aws-sdk-go/awsmigrate"],
"Exclude": ["/src/cmd", "/src/github.com/aws/aws-sdk-go/awstesting", "/src/github.com/aws/aws-sdk-go/awsmigrate", "/src/github.com/aws/aws-sdk-go/private"],
"IgnoredSuffixes": ["iface"]
},
"Github": {

View File

@ -2,24 +2,41 @@ language: go
sudo: required
os:
- linux
- osx
go:
- 1.5
- 1.6
- 1.7
- 1.8
- 1.9
- tip
# Use Go 1.5's vendoring experiment for 1.5 tests.
env:
- GO15VENDOREXPERIMENT=1
install:
- make get-deps
script:
- make unit-with-race-cover
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- tip
matrix:
allow_failures:
- go: tip
allow_failures:
- go: tip
exclude:
# OSX 1.6.4 is not present in travis.
# https://github.com/travis-ci/travis-ci/issues/10309
- go: 1.6.x
os: osx
include:
- os: linux
go: 1.5.x
# Use Go 1.5's vendoring experiment for 1.5 tests.
env: GO15VENDOREXPERIMENT=1
script:
- if [ $TRAVIS_GO_VERSION == "tip" ] ||
[ $TRAVIS_GO_VERSION == "1.11.x" ] ||
[ $TRAVIS_GO_VERSION == "1.10.x" ]; then
make ci-test;
else
make unit-old-go-race-cover;
fi
branches:
only:
- master

File diff suppressed because it is too large Load Diff

4
vendor/github.com/aws/aws-sdk-go/CODE_OF_CONDUCT.md generated vendored Normal file
View File

@ -0,0 +1,4 @@
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.

View File

@ -56,7 +56,7 @@ Please be aware of the following notes prior to opening a pull request:
3. Wherever possible, pull requests should contain tests as appropriate.
Bugfixes should contain tests that exercise the corrected behavior (i.e., the
test should fail without the bugfix and pass with it), and new features
test should fail without the bugfix and pass with it), and new features
should be accompanied by tests exercising the feature.
4. Pull requests that contain failing tests will not be merged until the test
@ -67,11 +67,11 @@ Please be aware of the following notes prior to opening a pull request:
5. The JSON files under the SDK's `models` folder are sourced from outside the SDK.
Such as `models/apis/ec2/2016-11-15/api.json`. We will not accept pull requests
directly on these models. If you discover an issue with the models please
create a Github [issue](issues) describing the issue.
create a [GitHub issue][issues] describing the issue.
### Testing
To run the tests locally, running the `make unit` command will `go get` the
To run the tests locally, running the `make unit` command will `go get` the
SDK's testing dependencies, and run vet, link and unit tests for the SDK.
```
@ -88,8 +88,8 @@ go test -tags codegen ./private/...
See the `Makefile` for additional testing tags that can be used in testing.
To test on multiple platform the SDK includes several DockerFiles under the
`awstesting/sandbox` folder, and associated make recipes to to execute
To test on multiple platform the SDK includes several DockerFiles under the
`awstesting/sandbox` folder, and associated make recipes to execute
unit testing within environments configured for specific Go versions.
```

View File

@ -2,19 +2,15 @@
[[projects]]
name = "github.com/go-ini/ini"
packages = ["."]
revision = "300e940a926eb277d3901b20bdfcc54928ad3642"
version = "v1.25.4"
[[projects]]
digest = "1:13fe471d0ed891e8544eddfeeb0471fd3c9f2015609a1c000aefdedf52a19d40"
name = "github.com/jmespath/go-jmespath"
packages = ["."]
revision = "0b12d6b5"
pruneopts = ""
revision = "c2b33e84"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "51a86a867df617990082dec6b868e4efe2fdb2ed0e02a3daa93cd30f962b5085"
input-imports = ["github.com/jmespath/go-jmespath"]
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -37,12 +37,7 @@ ignored = [
"golang.org/x/tools/go/loader",
]
[[constraint]]
name = "github.com/go-ini/ini"
version = "1.25.4"
[[constraint]]
name = "github.com/jmespath/go-jmespath"
revision = "0b12d6b5"
revision = "c2b33e84"
#version = "0.2.2"

View File

@ -1,27 +1,30 @@
LINTIGNOREDOT='awstesting/integration.+should not use dot imports'
LINTIGNOREDOC='service/[^/]+/(api|service|waiters)\.go:.+(comment on exported|should have comment or be unexported)'
LINTIGNORECONST='service/[^/]+/(api|service|waiters)\.go:.+(type|struct field|const|func) ([^ ]+) should be ([^ ]+)'
LINTIGNORESTUTTER='service/[^/]+/(api|service)\.go:.+(and that stutters)'
LINTIGNOREINFLECT='service/[^/]+/(api|errors|service)\.go:.+(method|const) .+ should be '
LINTIGNOREINFLECTS3UPLOAD='service/s3/s3manager/upload\.go:.+struct field SSEKMSKeyId should be '
LINTIGNOREENDPOINTS='aws/endpoints/(defaults|dep_service_ids).go:.+(method|const) .+ should be '
LINTIGNOREDEPS='vendor/.+\.go'
LINTIGNOREPKGCOMMENT='service/[^/]+/doc_custom.go:.+package comment should be of the form'
UNIT_TEST_TAGS="example codegen awsinclude"
SDK_WITH_VENDOR_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/src")
SDK_ONLY_PKGS=$(shell go list ./... | grep -v "/vendor/")
SDK_UNIT_TEST_ONLY_PKGS=$(shell go list -tags ${UNIT_TEST_TAGS} ./... | grep -v "/vendor/")
SDK_GO_1_4=$(shell go version | grep "go1.4")
SDK_GO_1_5=$(shell go version | grep "go1.5")
SDK_GO_VERSION=$(shell go version | awk '''{print $$3}''' | tr -d '''\n''')
# SDK's Core and client packages that are compatable with Go 1.5+.
SDK_CORE_PKGS=./aws/... ./private/... ./internal/...
SDK_CLIENT_PKGS=./service/...
SDK_COMPA_PKGS=${SDK_CORE_PKGS} ${SDK_CLIENT_PKGS}
all: get-deps generate unit
# SDK additional packages that are used for development of the SDK.
SDK_EXAMPLES_PKGS=./examples/...
SDK_TESTING_PKGS=./awstesting/...
SDK_MODELS_PKGS=./models/...
SDK_ALL_PKGS=${SDK_COMPA_PKGS} ${SDK_TESTING_PKGS} ${SDK_EXAMPLES_PKGS} ${SDK_MODELS_PKGS}
all: generate unit
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " api_info to print a list of services and versions"
@echo " docs to build SDK documentation"
@echo " build to go build the SDK"
@echo " unit to run unit tests"
@echo " integration to run integration tests"
@echo " performance to run performance tests"
@ -32,49 +35,83 @@ help:
@echo " gen-test to generate protocol tests"
@echo " gen-services to generate services"
@echo " get-deps to go get the SDK dependencies"
@echo " get-deps-tests to get the SDK's test dependencies"
@echo " get-deps-verify to get the SDK's verification dependencies"
generate: gen-test gen-endpoints gen-services
###################
# Code Generation #
###################
generate: cleanup-models gen-test gen-endpoints gen-services
gen-test: gen-protocol-test
gen-test: gen-protocol-test gen-codegen-test
gen-services:
gen-codegen-test: get-deps-codegen
@echo "Generating SDK API tests"
go generate ./private/model/api/codegentest/service
gen-services: get-deps-codegen
@echo "Generating SDK clients"
go generate ./service
gen-protocol-test:
gen-protocol-test: get-deps-codegen
@echo "Generating SDK protocol tests"
go generate ./private/protocol/...
gen-endpoints:
gen-endpoints: get-deps-codegen
@echo "Generating SDK endpoints"
go generate ./models/endpoints/
build:
@echo "go build SDK and vendor packages"
@go build ${SDK_ONLY_PKGS}
cleanup-models: get-deps-codegen
@echo "Cleaning up stale model versions"
go run -tags codegen ./private/model/cli/cleanup-models/* "./models/apis/*/*/api-2.json"
unit: get-deps-tests build verify
###################
# Unit/CI Testing #
###################
unit: get-deps verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} $(SDK_UNIT_TEST_ONLY_PKGS)
go test -tags ${UNIT_TEST_TAGS} ${SDK_ALL_PKGS}
unit-with-race-cover: get-deps-tests build verify
unit-with-race-cover: get-deps verify
@echo "go test SDK and vendor packages"
@go test -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 $(SDK_UNIT_TEST_ONLY_PKGS)
go test -tags ${UNIT_TEST_TAGS} -race -cpu=1,2,4 ${SDK_ALL_PKGS}
integration: get-deps-tests integ-custom smoke-tests performance
unit-old-go-race-cover: get-deps-tests
@echo "go test SDK only packages for old Go versions"
go test -race -cpu=1,2,4 ${SDK_COMPA_PKGS}
integ-custom:
go test -tags "integration" ./awstesting/integration/customizations/...
ci-test: generate unit-with-race-cover ci-test-generate-validate
cleanup-integ:
ci-test-generate-validate:
@echo "CI test validate no generated code changes"
git add . -A
gitstatus=`git diff --cached --ignore-space-change`; \
echo "$$gitstatus"; \
if [ "$$gitstatus" != "" ] && [ "$$gitstatus" != "skipping validation" ]; then echo "$$gitstatus"; exit 1; fi
#######################
# Integration Testing #
#######################
integration: core-integ client-integ
core-integ:
@echo "Integration Testing SDK core"
AWS_REGION="" go test -count=1 -tags "integration" -v -run '^TestInteg_' ./aws/... ./private/... ./internal/... ./awstesting/...
client-integ:
@echo "Integration Testing SDK clients"
AWS_REGION="" go test -count=1 -tags "integration" -v -run '^TestInteg_' ./service/...
s3crypto-integ:
@echo "Integration Testing S3 Cyrpto utility"
AWS_REGION="" go test -count=1 -tags "s3crypto_integ integration" -v -run '^TestInteg_' ./service/s3/s3crypto
cleanup-integ-buckets:
@echo "Cleaning up SDK integraiton resources"
go run -tags "integration" ./awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
smoke-tests: get-deps-tests
gucumber -go-tags "integration" ./awstesting/integration/smoke
performance: get-deps-tests
AWS_TESTING_LOG_RESULTS=${log-detailed} AWS_TESTING_REGION=$(region) AWS_TESTING_DB_TABLE=$(table) gucumber -go-tags "integration" ./awstesting/performance
sandbox-tests: sandbox-test-go15 sandbox-test-go15-novendorexp sandbox-test-go16 sandbox-test-go17 sandbox-test-go18 sandbox-test-go19 sandbox-test-gotip
###################
# Sandbox Testing #
###################
sandbox-tests: sandbox-test-go15 sandbox-test-go16 sandbox-test-go17 sandbox-test-go18 sandbox-test-go19 sandbox-test-go110 sandbox-test-go111 sandbox-test-gotip
sandbox-build-go15:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.5 -t "aws-sdk-go-1.5" .
@ -112,12 +149,26 @@ sandbox-test-go18: sandbox-build-go18
docker run -t aws-sdk-go-1.8
sandbox-build-go19:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.8 -t "aws-sdk-go-1.9" .
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.9 -t "aws-sdk-go-1.9" .
sandbox-go19: sandbox-build-go19
docker run -i -t aws-sdk-go-1.9 bash
sandbox-test-go19: sandbox-build-go19
docker run -t aws-sdk-go-1.9
sandbox-build-go110:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.10 -t "aws-sdk-go-1.10" .
sandbox-go110: sandbox-build-go110
docker run -i -t aws-sdk-go-1.10 bash
sandbox-test-go110: sandbox-build-go110
docker run -t aws-sdk-go-1.10
sandbox-build-go111:
docker build -f ./awstesting/sandbox/Dockerfile.test.go1.11 -t "aws-sdk-go-1.11" .
sandbox-go111: sandbox-build-go111
docker run -i -t aws-sdk-go-1.11 bash
sandbox-test-go111: sandbox-build-go111
docker run -t aws-sdk-go-1.11
sandbox-build-gotip:
@echo "Run make update-aws-golang-tip, if this test fails because missing aws-golang:tip container"
docker build -f ./awstesting/sandbox/Dockerfile.test.gotip -t "aws-sdk-go-tip" .
@ -129,59 +180,56 @@ sandbox-test-gotip: sandbox-build-gotip
update-aws-golang-tip:
docker build --no-cache=true -f ./awstesting/sandbox/Dockerfile.golang-tip -t "aws-golang:tip" .
##################
# Linting/Verify #
##################
verify: get-deps-verify lint vet
lint:
@echo "go lint SDK and vendor packages"
@lint=`if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then golint ./...; else echo "skipping golint"; fi`; \
lint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOT} -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT}`; \
echo "$$lint"; \
if [ "$$lint" != "" ] && [ "$$lint" != "skipping golint" ]; then exit 1; fi
SDK_BASE_FOLDERS=$(shell ls -d */ | grep -v vendor | grep -v awsmigrate)
ifneq (,$(findstring go1.4, ${SDK_GO_VERSION}))
GO_VET_CMD=echo skipping go vet, ${SDK_GO_VERSION}
else ifneq (,$(findstring go1.6, ${SDK_GO_VERSION}))
GO_VET_CMD=go tool vet --all -shadow -example=false
else
GO_VET_CMD=go tool vet --all -shadow
endif
@lint=`golint ./...`; \
dolint=`echo "$$lint" | grep -E -v -e ${LINTIGNOREDOC} -e ${LINTIGNORECONST} -e ${LINTIGNORESTUTTER} -e ${LINTIGNOREINFLECT} -e ${LINTIGNOREDEPS} -e ${LINTIGNOREINFLECTS3UPLOAD} -e ${LINTIGNOREPKGCOMMENT} -e ${LINTIGNOREENDPOINTS}`; \
echo "$$dolint"; \
if [ "$$dolint" != "" ]; then exit 1; fi
vet:
${GO_VET_CMD} ${SDK_BASE_FOLDERS}
go vet -tags "example codegen awsinclude integration" --all ${SDK_ALL_PKGS}
get-deps: get-deps-tests get-deps-verify
@echo "go get SDK dependencies"
@go get -v $(SDK_ONLY_PKGS)
################
# Dependencies #
################
get-deps: get-deps-tests get-deps-x-tests get-deps-codegen get-deps-verify
get-deps-tests:
@echo "go get SDK testing dependencies"
go get github.com/gucumber/gucumber/cmd/gucumber
go get github.com/stretchr/testify
go get github.com/smartystreets/goconvey
go get golang.org/x/net/html
get-deps-x-tests:
@echo "go get SDK testing golang.org/x dependencies"
go get golang.org/x/net/http2
get-deps-codegen: get-deps-x-tests
@echo "go get SDK codegen dependencies"
go get golang.org/x/net/html
get-deps-verify:
@echo "go get SDK verification utilities"
@if [ \( -z "${SDK_GO_1_4}" \) -a \( -z "${SDK_GO_1_5}" \) ]; then go get github.com/golang/lint/golint; else echo "skipped getting golint"; fi
go get golang.org/x/lint/golint
bench:
@echo "go bench SDK packages"
@go test -run NONE -bench . -benchmem -tags 'bench' $(SDK_ONLY_PKGS)
go test -run NONE -bench . -benchmem -tags 'bench' ${SDK_ALL_PKGS}
bench-protocol:
@echo "go bench SDK protocol marshallers"
@go test -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
go test -run NONE -bench . -benchmem -tags 'bench' ./private/protocol/...
#############
# Utilities #
#############
docs:
@echo "generate SDK docs"
@# This env variable, DOCS, is for internal use
@if [ -z ${AWS_DOC_GEN_TOOL} ]; then\
rm -rf doc && bundle install && bundle exec yard;\
else\
$(AWS_DOC_GEN_TOOL) `pwd`;\
fi
$(AWS_DOC_GEN_TOOL) `pwd`
api_info:
@go run private/model/cli/api-info/api-info.go

View File

@ -1,3 +1,3 @@
AWS SDK for Go
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copyright 2014-2015 Stripe, Inc.

View File

@ -1,33 +1,59 @@
[![API Reference](http://img.shields.io/badge/api-reference-blue.svg)](http://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](http://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
[![API Reference](https://img.shields.io/badge/api-reference-blue.svg)](https://docs.aws.amazon.com/sdk-for-go/api) [![Join the chat at https://gitter.im/aws/aws-sdk-go](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/aws/aws-sdk-go?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://img.shields.io/travis/aws/aws-sdk-go.svg)](https://travis-ci.org/aws/aws-sdk-go) [![Apache V2 License](https://img.shields.io/badge/license-Apache%20V2-blue.svg)](https://github.com/aws/aws-sdk-go/blob/master/LICENSE.txt)
# AWS SDK for Go
aws-sdk-go is the official AWS SDK for the Go programming language.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for information about the latest bug fixes, updates, and features added to the SDK.
Checkout our [release notes](https://github.com/aws/aws-sdk-go/releases) for
information about the latest bug fixes, updates, and features added to the SDK.
We [announced](https://aws.amazon.com/blogs/developer/aws-sdk-for-go-2-0-developer-preview/) the Developer Preview for the [v2 AWS SDK for Go](https://github.com/aws/aws-sdk-go-v2). The v2 SDK source is available at https://github.com/aws/aws-sdk-go-v2, and add it to your project with `go get github.com/aws/aws-sdk-go-v2`. Check out the v2 SDK's [changes and updates](https://github.com/aws/aws-sdk-go-v2/blob/master/CHANGELOG.md), and let us know what you think. We want your feedback.
## Installing
If you are using Go 1.5 with the `GO15VENDOREXPERIMENT=1` vendoring flag, or 1.6 and higher you can use the following command to retrieve the SDK. The SDK's non-testing dependencies will be included and are vendored in the `vendor` folder.
Use `go get` to retrieve the SDK to add it to your `GOPATH` workspace, or
project's Go module dependencies.
go get -u github.com/aws/aws-sdk-go
go get github.com/aws/aws-sdk-go
Otherwise if your Go environment does not have vendoring support enabled, or you do not want to include the vendored SDK's dependencies you can use the following command to retrieve the SDK and its non-testing dependencies using `go get`.
To update the SDK use `go get -u` to retrieve the latest version of the SDK.
go get -u github.com/aws/aws-sdk-go/aws/...
go get -u github.com/aws/aws-sdk-go/service/...
go get -u github.com/aws/aws-sdk-go
If you're looking to retrieve just the SDK without any dependencies use the following command.
### Dependencies
go get -d github.com/aws/aws-sdk-go/
The SDK includes a `vendor` folder containing the runtime dependencies of the
SDK. The metadata of the SDK's dependencies can be found in the Go module file
`go.mod` or Dep file `Gopkg.toml`.
These two processes will still include the `vendor` folder and it should be deleted if its not going to be used by your environment.
### Go Modules
If you are using Go modules, your `go get` will default to the latest tagged
release version of the SDK. To get a specific release version of the SDK use
`@<tag>` in your `go get` command.
go get github.com/aws/aws-sdk-go@v1.15.77
To get the latest SDK repository change use `@latest`.
go get github.com/aws/aws-sdk-go@latest
### Go 1.5
If you are using Go 1.5 without vendoring enabled, (`GO15VENDOREXPERIMENT=1`),
you will need to use `...` when retrieving the SDK to get its dependencies.
go get github.com/aws/aws-sdk-go/...
This will still include the `vendor` folder. The `vendor` folder can be deleted
if not used by your environment.
rm -rf $GOPATH/src/github.com/aws/aws-sdk-go/vendor
## Getting Help
Please use these community resources for getting help. We use the GitHub issues for tracking bugs and feature requests.
Please use these community resources for getting help. We use the GitHub issues
for tracking bugs and feature requests.
* Ask a question on [StackOverflow](http://stackoverflow.com/) and tag it with the [`aws-sdk-go`](http://stackoverflow.com/questions/tagged/aws-sdk-go) tag.
* Come join the AWS SDK for Go community chat on [gitter](https://gitter.im/aws/aws-sdk-go).
@ -36,19 +62,44 @@ Please use these community resources for getting help. We use the GitHub issues
## Opening Issues
If you encounter a bug with the AWS SDK for Go we would like to hear about it. Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see if others are also experiencing the issue before opening a new issue. Please include the version of AWS SDK for Go, Go language, and OS youre using. Please also include repro case when appropriate.
If you encounter a bug with the AWS SDK for Go we would like to hear about it.
Search the [existing issues](https://github.com/aws/aws-sdk-go/issues) and see
if others are also experiencing the issue before opening a new issue. Please
include the version of AWS SDK for Go, Go language, and OS youre using. Please
also include reproduction case when appropriate.
The GitHub issues are intended for bug reports and feature requests. For help and questions with using AWS SDK for GO please make use of the resources listed in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section. Keeping the list of open issues lean will help us respond in a timely manner.
The GitHub issues are intended for bug reports and feature requests. For help
and questions with using AWS SDK for GO please make use of the resources listed
in the [Getting Help](https://github.com/aws/aws-sdk-go#getting-help) section.
Keeping the list of open issues lean will help us respond in a timely manner.
## Reference Documentation
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document is a general introduction how to configure and make requests with the SDK. If this is your first time using the SDK, this documentation and the API documentation will help you get started. This document focuses on the syntax and behavior of the SDK. The [Service Developer Guide](https://aws.amazon.com/documentation/) will help you get started using specific AWS services.
[`Getting Started Guide`](https://aws.amazon.com/sdk-for-go/) - This document
is a general introduction on how to configure and make requests with the SDK.
If this is your first time using the SDK, this documentation and the API
documentation will help you get started. This document focuses on the syntax
and behavior of the SDK. The [Service Developer
Guide](https://aws.amazon.com/documentation/) will help you get started using
specific AWS services.
[`SDK API Reference Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this document to look up all API operation input and output parameters for AWS services supported by the SDK. The API reference also includes documentation of the SDK, and examples how to using the SDK, service client API operations, and API operation require parameters.
[`SDK API Reference
Documentation`](https://docs.aws.amazon.com/sdk-for-go/api/) - Use this
document to look up all API operation input and output parameters for AWS
services supported by the SDK. The API reference also includes documentation of
the SDK, and examples how to using the SDK, service client API operations, and
API operation require parameters.
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this documentation to learn how to interface with an AWS service. These are great guides both, if you're getting started with a service, or looking for more information on a service. You should not need this document for coding, though in some cases, services may supply helpful samples that you might want to look out for.
[`Service Developer Guide`](https://aws.amazon.com/documentation/) - Use this
documentation to learn how to interface with AWS services. These are great
guides both, if you're getting started with a service, or looking for more
information on a service. You should not need this document for coding, though
in some cases, services may supply helpful samples that you might want to look
out for.
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) - Included in the SDK's repo are a several hand crafted examples using the SDK features and AWS services.
[`SDK Examples`](https://github.com/aws/aws-sdk-go/tree/master/example) -
Included in the SDK's repo are several hand crafted examples using the SDK
features and AWS services.
## Overview of SDK's Packages
@ -185,7 +236,7 @@ Option's SharedConfigState parameter.
}))
```
[credentials_pkg]: ttps://docs.aws.amazon.com/sdk-for-go/api/aws/credentials
[credentials_pkg]: https://docs.aws.amazon.com/sdk-for-go/api/aws/credentials
### Configuring AWS Region
@ -305,7 +356,7 @@ documentation for the errors that could be returned.
// will leak connections.
defer result.Body.Close()
fmt.Println("Object Size:", aws.StringValue(result.ContentLength))
fmt.Println("Object Size:", aws.Int64Value(result.ContentLength))
```
### API Request Pagination and Resource Waiters
@ -332,7 +383,7 @@ take a callback function that will be called for each page of the API's response
```
Waiter helper methods provide the functionality to wait for an AWS resource
state. These methods abstract the logic needed to to check the state of an
state. These methods abstract the logic needed to check the state of an
AWS resource, and wait until that resource is in a desired state. The waiter
will block until the resource is in the state that is desired, an error occurs,
or the waiter times out. If a resource times out the error code returned will
@ -418,7 +469,9 @@ response.
}
// Ensure the context is canceled to prevent leaking.
// See context package for more information, https://golang.org/pkg/context/
defer cancelFn()
if cancelFn {
defer cancelFn()
}
// Uploads the object to S3. The Context will interrupt the request if the
// timeout expires.

View File

@ -5,11 +5,11 @@ import (
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func ExampleCopy() {
@ -81,12 +81,24 @@ func TestCopy1(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, f2.D, f1.D)
assert.Equal(t, f2.E.B, f1.E.B)
assert.Equal(t, f2.E.D, f1.E.D)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.D, f1.D; !v1.Equal(*v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.B, f1.E.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.E.D, f1.E.D; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// But pointers are not!
str3 := "nothello"
@ -99,14 +111,30 @@ func TestCopy1(t *testing.T) {
*f2.E.B = int3
f2.E.c = 5
f2.E.D = 5
assert.NotEqual(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.C, f1.C)
assert.NotEqual(t, f2.D, f1.D)
assert.NotEqual(t, f2.E.a, f1.E.a)
assert.NotEqual(t, f2.E.B, f1.E.B)
assert.NotEqual(t, f2.E.c, f1.E.c)
assert.NotEqual(t, f2.E.D, f1.E.D)
if v1, v2 := f2.A, f1.A; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B, f1.B; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.C, f1.C; reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.D, f1.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.a, f1.E.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.B, f1.E.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.c, f1.E.c; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.E.D, f1.E.D; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
}
func TestCopyNestedWithUnexported(t *testing.T) {
@ -125,10 +153,18 @@ func TestCopyNestedWithUnexported(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values match
assert.Equal(t, f2.A, f1.A)
assert.NotEqual(t, f2.B, f1.B)
assert.NotEqual(t, f2.B.a, f1.B.a)
assert.Equal(t, f2.B.B, f2.B.B)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.a, f1.B.a; v1 == v2 {
t.Errorf("expected values to be not equivalent, but received %v", v1)
}
if v1, v2 := f2.B.B, f2.B.B; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyIgnoreNilMembers(t *testing.T) {
@ -139,34 +175,56 @@ func TestCopyIgnoreNilMembers(t *testing.T) {
}
f := &Foo{}
assert.Nil(t, f.A)
assert.Nil(t, f.B)
assert.Nil(t, f.C)
if v1 := f.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
var f2 Foo
awsutil.Copy(&f2, f)
assert.Nil(t, f2.A)
assert.Nil(t, f2.B)
assert.Nil(t, f2.C)
if v1 := f2.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f2.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
fcopy := awsutil.CopyOf(f)
f3 := fcopy.(*Foo)
assert.Nil(t, f3.A)
assert.Nil(t, f3.B)
assert.Nil(t, f3.C)
if v1 := f3.A; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.B; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1 := f3.C; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
}
func TestCopyPrimitive(t *testing.T) {
str := "hello"
var s string
awsutil.Copy(&s, &str)
assert.Equal(t, "hello", s)
if v1, v2 := "hello", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyNil(t *testing.T) {
var s string
awsutil.Copy(&s, nil)
assert.Equal(t, "", s)
if v1, v2 := "", s; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyReader(t *testing.T) {
@ -174,13 +232,21 @@ func TestCopyReader(t *testing.T) {
var r io.Reader
awsutil.Copy(&r, buf)
b, err := ioutil.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world"), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte("hello world"), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
// empty bytes because this is not a deep copy
b, err = ioutil.ReadAll(buf)
assert.NoError(t, err)
assert.Equal(t, []byte(""), b)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}
if v1, v2 := []byte(""), b; !bytes.Equal(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func TestCopyDifferentStructs(t *testing.T) {
@ -226,17 +292,39 @@ func TestCopyDifferentStructs(t *testing.T) {
awsutil.Copy(&f2, f1)
// Values are equal
assert.Equal(t, f2.A, f1.A)
assert.Equal(t, f2.B, f1.B)
assert.Equal(t, f2.C, f1.C)
assert.Equal(t, "unique", f1.SrcUnique)
assert.Equal(t, 1, f1.SameNameDiffType)
assert.Equal(t, 0, f2.DstUnique)
assert.Equal(t, "", f2.SameNameDiffType)
assert.Equal(t, int1, *f1.unexportedPtr)
assert.Nil(t, f2.unexportedPtr)
assert.Equal(t, int2, *f1.ExportedPtr)
assert.Equal(t, int2, *f2.ExportedPtr)
if v1, v2 := f2.A, f1.A; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.B, f1.B; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := f2.C, f1.C; !reflect.DeepEqual(v1, v2) {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "unique", f1.SrcUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 1, f1.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := 0, f2.DstUnique; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := "", f2.SameNameDiffType; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int1, *f1.unexportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1 := f2.unexportedPtr; v1 != nil {
t.Errorf("expected nil, but received %v", v1)
}
if v1, v2 := int2, *f1.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
if v1, v2 := int2, *f2.ExportedPtr; v1 != v2 {
t.Errorf("expected values to be equivalent but received %v and %v", v1, v2)
}
}
func ExampleCopyOf() {

View File

@ -5,7 +5,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
func TestDeepEqual(t *testing.T) {
@ -24,6 +23,8 @@ func TestDeepEqual(t *testing.T) {
}
for i, c := range cases {
assert.Equal(t, c.equal, awsutil.DeepEqual(c.a, c.b), "%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
if awsutil.DeepEqual(c.a, c.b) != c.equal {
t.Errorf("%d, a:%v b:%v, %t", i, c.a, c.b, c.equal)
}
}
}

View File

@ -1,10 +1,10 @@
package awsutil_test
import (
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws/awsutil"
"github.com/stretchr/testify/assert"
)
type Struct struct {
@ -50,8 +50,12 @@ func TestValueAtPathSuccess(t *testing.T) {
}
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
}
}
@ -78,12 +82,18 @@ func TestValueAtPathFailure(t *testing.T) {
for i, c := range testCases {
v, err := awsutil.ValuesAtPath(c.data, c.path)
if c.errContains != "" {
assert.Contains(t, err.Error(), c.errContains, "case %d, expected error, %s", i, c.path)
if !strings.Contains(err.Error(), c.errContains) {
t.Errorf("case %v, expected error, %v", i, c.path)
}
continue
} else {
assert.NoError(t, err, "case %d, expected no error, %s", i, c.path)
if err != nil {
t.Errorf("case %v, expected no error, %v", i, c.path)
}
}
if e, a := c.expect, v; !awsutil.DeepEqual(e, a) {
t.Errorf("case %v, %v", i, c.path)
}
assert.Equal(t, c.expect, v, "case %d, %s", i, c.path)
}
}
@ -92,51 +102,81 @@ func TestSetValueAtPathSuccess(t *testing.T) {
awsutil.SetValueAtPath(&s, "C", "test1")
awsutil.SetValueAtPath(&s, "B.B.C", "test2")
awsutil.SetValueAtPath(&s, "B.D.C", "test3")
assert.Equal(t, "test1", s.C)
assert.Equal(t, "test2", s.B.B.C)
assert.Equal(t, "test3", s.B.D.C)
if e, a := "test1", s.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test2", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test3", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s, "B.*.C", "test0")
assert.Equal(t, "test0", s.B.B.C)
assert.Equal(t, "test0", s.B.D.C)
if e, a := "test0", s.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := "test0", s.B.D.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s2 Struct
awsutil.SetValueAtPath(&s2, "b.b.c", "test0")
assert.Equal(t, "test0", s2.B.B.C)
if e, a := "test0", s2.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
awsutil.SetValueAtPath(&s2, "A", []Struct{{}})
assert.Equal(t, []Struct{{}}, s2.A)
if e, a := []Struct{{}}, s2.A; !awsutil.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
str := "foo"
s3 := Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{B: &Struct{B: &Struct{C: str}}}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", nil)
assert.Equal(t, "", s3.B.B.C)
if e, a := "", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s3 = Struct{}
awsutil.SetValueAtPath(&s3, "b.b.c", &str)
assert.Equal(t, "foo", s3.B.B.C)
if e, a := "foo", s3.B.B.C; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
var s4 struct{ Name *string }
awsutil.SetValueAtPath(&s4, "Name", str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{Name: &str}
awsutil.SetValueAtPath(&s4, "Name", nil)
assert.Equal(t, (*string)(nil), s4.Name)
if e, a := (*string)(nil), s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
s4 = struct{ Name *string }{}
awsutil.SetValueAtPath(&s4, "Name", &str)
assert.Equal(t, str, *s4.Name)
if e, a := str, *s4.Name; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
}

View File

@ -23,28 +23,27 @@ func stringValue(v reflect.Value, indent int, buf *bytes.Buffer) {
case reflect.Struct:
buf.WriteString("{\n")
names := []string{}
for i := 0; i < v.Type().NumField(); i++ {
name := v.Type().Field(i).Name
f := v.Field(i)
if name[0:1] == strings.ToLower(name[0:1]) {
ft := v.Type().Field(i)
fv := v.Field(i)
if ft.Name[0:1] == strings.ToLower(ft.Name[0:1]) {
continue // ignore unexported fields
}
if (f.Kind() == reflect.Ptr || f.Kind() == reflect.Slice) && f.IsNil() {
if (fv.Kind() == reflect.Ptr || fv.Kind() == reflect.Slice) && fv.IsNil() {
continue // ignore unset fields
}
names = append(names, name)
}
for i, n := range names {
val := v.FieldByName(n)
buf.WriteString(strings.Repeat(" ", indent+2))
buf.WriteString(n + ": ")
stringValue(val, indent+2, buf)
buf.WriteString(ft.Name + ": ")
if i < len(names)-1 {
buf.WriteString(",\n")
if tag := ft.Tag.Get("sensitive"); tag == "true" {
buf.WriteString("<sensitive>")
} else {
stringValue(fv, indent+2, buf)
}
buf.WriteString(",\n")
}
buf.WriteString("\n" + strings.Repeat(" ", indent) + "}")

View File

@ -0,0 +1,51 @@
// +build go1.7
package awsutil
import (
"testing"
"github.com/aws/aws-sdk-go/aws"
)
type testStruct struct {
Field1 string
Field2 *string
Field3 []byte `sensitive:"true"`
Value []*string
}
func TestStringValue(t *testing.T) {
cases := map[string]struct {
Value interface{}
Expect string
}{
"general": {
Value: testStruct{
Field1: "abc123",
Field2: aws.String("abc123"),
Field3: []byte("don't show me"),
Value: []*string{
aws.String("first"),
aws.String("second"),
},
},
Expect: `{
Field1: "abc123",
Field2: "abc123",
Field3: <sensitive>,
Value: ["first","second"],
}`,
},
}
for d, c := range cases {
t.Run(d, func(t *testing.T) {
actual := StringValue(c.Value)
if e, a := c.Expect, actual; e != a {
t.Errorf("expect:\n%v\nactual:\n%v\n", e, a)
}
})
}
}

View File

@ -15,6 +15,12 @@ type Config struct {
Endpoint string
SigningRegion string
SigningName string
// States that the signing name did not come from a modeled source but
// was derived based on other data. Used by service client constructors
// to determine if the signin name can be overridden based on metadata the
// service has.
SigningNameDerived bool
}
// ConfigProvider provides a generic way for a service client to receive
@ -85,6 +91,6 @@ func (c *Client) AddDebugHandlers() {
return
}
c.Handlers.Send.PushFrontNamed(request.NamedHandler{Name: "awssdk.client.LogRequest", Fn: logRequest})
c.Handlers.Send.PushBackNamed(request.NamedHandler{Name: "awssdk.client.LogResponse", Fn: logResponse})
c.Handlers.Send.PushFrontNamed(LogHTTPRequestHandler)
c.Handlers.Send.PushBackNamed(LogHTTPResponseHandler)
}

View File

@ -1,11 +1,11 @@
package client
import (
"math/rand"
"sync"
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkrand"
)
// DefaultRetryer implements basic retry logic using exponential backoff for
@ -30,25 +30,27 @@ func (d DefaultRetryer) MaxRetries() int {
return d.NumMaxRetries
}
var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())})
// RetryRules returns the delay duration before retrying this request again
func (d DefaultRetryer) RetryRules(r *request.Request) time.Duration {
// Set the upper limit of delay in retrying at ~five minutes
minTime := 30
throttle := d.shouldThrottle(r)
if throttle {
if delay, ok := getRetryDelay(r); ok {
return delay
}
minTime = 500
}
retryCount := r.RetryCount
if retryCount > 13 {
retryCount = 13
} else if throttle && retryCount > 8 {
if throttle && retryCount > 8 {
retryCount = 8
} else if retryCount > 13 {
retryCount = 13
}
delay := (1 << uint(retryCount)) * (seededRand.Intn(minTime) + minTime)
delay := (1 << uint(retryCount)) * (sdkrand.SeededRand.Intn(minTime) + minTime)
return time.Duration(delay) * time.Millisecond
}
@ -60,7 +62,7 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
return *r.Retryable
}
if r.HTTPResponse.StatusCode >= 500 {
if r.HTTPResponse.StatusCode >= 500 && r.HTTPResponse.StatusCode != 501 {
return true
}
return r.IsErrorRetryable() || d.shouldThrottle(r)
@ -68,29 +70,47 @@ func (d DefaultRetryer) ShouldRetry(r *request.Request) bool {
// ShouldThrottle returns true if the request should be throttled.
func (d DefaultRetryer) shouldThrottle(r *request.Request) bool {
if r.HTTPResponse.StatusCode == 502 ||
r.HTTPResponse.StatusCode == 503 ||
r.HTTPResponse.StatusCode == 504 {
return true
switch r.HTTPResponse.StatusCode {
case 429:
case 502:
case 503:
case 504:
default:
return r.IsErrorThrottle()
}
return r.IsErrorThrottle()
return true
}
// lockedSource is a thread-safe implementation of rand.Source
type lockedSource struct {
lk sync.Mutex
src rand.Source
// This will look in the Retry-After header, RFC 7231, for how long
// it will wait before attempting another request
func getRetryDelay(r *request.Request) (time.Duration, bool) {
if !canUseRetryAfterHeader(r) {
return 0, false
}
delayStr := r.HTTPResponse.Header.Get("Retry-After")
if len(delayStr) == 0 {
return 0, false
}
delay, err := strconv.Atoi(delayStr)
if err != nil {
return 0, false
}
return time.Duration(delay) * time.Second, true
}
func (r *lockedSource) Int63() (n int64) {
r.lk.Lock()
n = r.src.Int63()
r.lk.Unlock()
return
}
// Will look at the status code to see if the retry header pertains to
// the status code.
func canUseRetryAfterHeader(r *request.Request) bool {
switch r.HTTPResponse.StatusCode {
case 429:
case 503:
default:
return false
}
func (r *lockedSource) Seed(seed int64) {
r.lk.Lock()
r.src.Seed(seed)
r.lk.Unlock()
return true
}

View File

@ -0,0 +1,189 @@
package client
import (
"net/http"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestRetryThrottleStatusCodes(t *testing.T) {
cases := []struct {
expectThrottle bool
expectRetry bool
r request.Request
}{
{
false,
false,
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 502},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
},
{
true,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 504},
},
},
{
false,
true,
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
},
}
d := DefaultRetryer{NumMaxRetries: 10}
for i, c := range cases {
throttle := d.shouldThrottle(&c.r)
retry := d.ShouldRetry(&c.r)
if e, a := c.expectThrottle, throttle; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
if e, a := c.expectRetry, retry; e != a {
t.Errorf("%d: expected %v, but received %v", i, e, a)
}
}
}
func TestCanUseRetryAfter(t *testing.T) {
cases := []struct {
r request.Request
e bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 200},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 500},
},
false,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429},
},
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503},
},
true,
},
}
for i, c := range cases {
a := canUseRetryAfterHeader(&c.r)
if c.e != a {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestGetRetryDelay(t *testing.T) {
cases := []struct {
r request.Request
e time.Duration
equal bool
ok bool
}{
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 429, Header: http.Header{"Retry-After": []string{"3600"}}},
},
3600 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
120 * time.Second,
true,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"120"}}},
},
1 * time.Second,
false,
true,
},
{
request.Request{
HTTPResponse: &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}},
},
0 * time.Second,
true,
false,
},
}
for i, c := range cases {
a, ok := getRetryDelay(&c.r)
if c.ok != ok {
t.Errorf("%d: expected %v, but received %v", i, c.ok, ok)
}
if (c.e != a) == c.equal {
t.Errorf("%d: expected %v, but received %v", i, c.e, a)
}
}
}
func TestRetryDelay(t *testing.T) {
r := request.Request{}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.HTTPResponse = &http.Response{StatusCode: 500, Header: http.Header{"Retry-After": []string{""}}}
rTemp.RetryCount = i
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
for i := 0; i < 100; i++ {
rTemp := r
rTemp.RetryCount = i
rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{""}}}
a, _ := getRetryDelay(&rTemp)
if a > 5*time.Minute {
t.Errorf("retry delay should never be greater than five minutes, received %d", a)
}
}
}

View File

@ -44,22 +44,57 @@ func (reader *teeReaderCloser) Close() error {
return reader.Source.Close()
}
// LogHTTPRequestHandler is a SDK request handler to log the HTTP request sent
// to a service. Will include the HTTP request body if the LogLevel of the
// request matches LogDebugWithHTTPBody.
var LogHTTPRequestHandler = request.NamedHandler{
Name: "awssdk.client.LogRequest",
Fn: logRequest,
}
func logRequest(r *request.Request) {
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
dumpedBody, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
bodySeekable := aws.IsReaderSeekable(r.Body)
b, err := httputil.DumpRequestOut(r.HTTPRequest, logBody)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg, r.ClientInfo.ServiceName, r.Operation.Name, err))
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
if logBody {
if !bodySeekable {
r.SetReaderBody(aws.ReadSeekCloser(r.HTTPRequest.Body))
}
// Reset the request body because dumpRequest will re-wrap the r.HTTPRequest's
// Body as a NoOpCloser and will not be reset after read by the HTTP
// client reader.
r.ResetBody()
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg, r.ClientInfo.ServiceName, r.Operation.Name, string(dumpedBody)))
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
// LogHTTPRequestHeaderHandler is a SDK request handler to log the HTTP request sent
// to a service. Will only log the HTTP request's headers. The request payload
// will not be read.
var LogHTTPRequestHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogRequestHeader",
Fn: logRequestHeader,
}
func logRequestHeader(r *request.Request) {
b, err := httputil.DumpRequestOut(r.HTTPRequest, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logReqErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logReqMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}
const logRespMsg = `DEBUG: Response %s/%s Details:
@ -72,27 +107,50 @@ const logRespErrMsg = `DEBUG ERROR: Response %s/%s:
%s
-----------------------------------------------------`
// LogHTTPResponseHandler is a SDK request handler to log the HTTP response
// received from a service. Will include the HTTP response body if the LogLevel
// of the request matches LogDebugWithHTTPBody.
var LogHTTPResponseHandler = request.NamedHandler{
Name: "awssdk.client.LogResponse",
Fn: logResponse,
}
func logResponse(r *request.Request) {
lw := &logWriter{r.Config.Logger, bytes.NewBuffer(nil)}
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
if r.HTTPResponse == nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, "request's HTTPResponse is nil"))
return
}
logBody := r.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody)
if logBody {
r.HTTPResponse.Body = &teeReaderCloser{
Reader: io.TeeReader(r.HTTPResponse.Body, lw),
Source: r.HTTPResponse.Body,
}
}
handlerFn := func(req *request.Request) {
body, err := httputil.DumpResponse(req.HTTPResponse, false)
b, err := httputil.DumpResponse(req.HTTPResponse, false)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg, req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(fmt.Sprintf(logRespMsg, req.ClientInfo.ServiceName, req.Operation.Name, string(body)))
if req.Config.LogLevel.Matches(aws.LogDebugWithHTTPBody) {
lw.Logger.Log(fmt.Sprintf(logRespMsg,
req.ClientInfo.ServiceName, req.Operation.Name, string(b)))
if logBody {
b, err := ioutil.ReadAll(lw.buf)
if err != nil {
lw.Logger.Log(fmt.Sprintf(logRespErrMsg,
req.ClientInfo.ServiceName, req.Operation.Name, err))
return
}
lw.Logger.Log(string(b))
}
}
@ -106,3 +164,27 @@ func logResponse(r *request.Request) {
Name: handlerName, Fn: handlerFn,
})
}
// LogHTTPResponseHeaderHandler is a SDK request handler to log the HTTP
// response received from a service. Will only log the HTTP response's headers.
// The response payload will not be read.
var LogHTTPResponseHeaderHandler = request.NamedHandler{
Name: "awssdk.client.LogResponseHeader",
Fn: logResponseHeader,
}
func logResponseHeader(r *request.Request) {
if r.Config.Logger == nil {
return
}
b, err := httputil.DumpResponse(r.HTTPResponse, false)
if err != nil {
r.Config.Logger.Log(fmt.Sprintf(logRespErrMsg,
r.ClientInfo.ServiceName, r.Operation.Name, err))
return
}
r.Config.Logger.Log(fmt.Sprintf(logRespMsg,
r.ClientInfo.ServiceName, r.Operation.Name, string(b)))
}

View File

@ -2,8 +2,18 @@ package client
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/request"
)
type mockCloser struct {
@ -55,3 +65,158 @@ func TestLogWriter(t *testing.T) {
t.Errorf("Expected %q, but received %q", expected, lw.buf.String())
}
}
func TestLogRequest(t *testing.T) {
cases := []struct {
Body io.ReadSeeker
ExpectBody []byte
LogLevel aws.LogLevelType
}{
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
ExpectBody: []byte("body content"),
},
{
Body: aws.ReadSeekCloser(bytes.NewBuffer([]byte("body content"))),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewReader([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
logW := bytes.NewBuffer(nil)
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.SetReaderBody(c.Body)
req.Build()
logRequest(req)
b, err := ioutil.ReadAll(req.HTTPRequest.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
func TestLogResponse(t *testing.T) {
cases := []struct {
Body *bytes.Buffer
ExpectBody []byte
ReadBody bool
LogLevel aws.LogLevelType
}{
{
Body: bytes.NewBuffer([]byte("body content")),
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebug,
ExpectBody: []byte("body content"),
},
{
Body: bytes.NewBuffer([]byte("body content")),
LogLevel: aws.LogDebugWithHTTPBody,
ReadBody: true,
ExpectBody: []byte("body content"),
},
}
for i, c := range cases {
var logW bytes.Buffer
req := request.New(
aws.Config{
Credentials: credentials.AnonymousCredentials,
Logger: &bufLogger{w: &logW},
LogLevel: aws.LogLevel(c.LogLevel),
},
metadata.ClientInfo{
Endpoint: "https://mock-service.mock-region.amazonaws.com",
},
testHandlers(),
nil,
&request.Operation{
Name: "APIName",
HTTPMethod: "POST",
HTTPPath: "/",
},
struct{}{}, nil,
)
req.HTTPResponse = &http.Response{
StatusCode: 200,
Status: "OK",
Header: http.Header{
"ABC": []string{"123"},
},
Body: ioutil.NopCloser(c.Body),
}
logResponse(req)
req.Handlers.Unmarshal.Run(req)
if c.ReadBody {
if e, a := len(c.ExpectBody), c.Body.Len(); e != a {
t.Errorf("%d, expect original body not to of been read", i)
}
}
if logW.Len() == 0 {
t.Errorf("%d, expect HTTP Response headers to be logged", i)
}
b, err := ioutil.ReadAll(req.HTTPResponse.Body)
if err != nil {
t.Fatalf("%d, expect to read SDK request Body", i)
}
if e, a := c.ExpectBody, b; !bytes.Equal(e, a) {
t.Errorf("%d, expect %v body, got %v", i, e, a)
}
}
}
type bufLogger struct {
w *bytes.Buffer
}
func (l *bufLogger) Log(args ...interface{}) {
fmt.Fprintln(l.w, args...)
}
func testHandlers() request.Handlers {
var handlers request.Handlers
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
return handlers
}

View File

@ -3,6 +3,7 @@ package metadata
// ClientInfo wraps immutable data from the client.Client structure.
type ClientInfo struct {
ServiceName string
ServiceID string
APIVersion string
Endpoint string
SigningName string

View File

@ -18,7 +18,7 @@ const UseServiceDefaultRetries = -1
type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig tructure.
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetry configuration to be shared by multiple
// // service clients.
@ -45,8 +45,8 @@ type Config struct {
// that overrides the default generated endpoint for a client. Set this
// to `""` to use the default generated endpoint.
//
// @note You must still provide a `Region` value when specifying an
// endpoint for a client.
// Note: You must still provide a `Region` value when specifying an
// endpoint for a client.
Endpoint *string
// The resolver to use for looking up endpoints for AWS service clients
@ -65,8 +65,8 @@ type Config struct {
// noted. A full list of regions is found in the "Regions and Endpoints"
// document.
//
// @see http://docs.aws.amazon.com/general/latest/gr/rande.html
// AWS Regions and Endpoints
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for AWS
// Regions and Endpoints.
Region *string
// Set this to `true` to disable SSL when sending requests. Defaults
@ -120,9 +120,10 @@ type Config struct {
// will use virtual hosted bucket addressing when possible
// (`http://BUCKET.s3.amazonaws.com/KEY`).
//
// @note This configuration option is specific to the Amazon S3 service.
// @see http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// Amazon S3: Virtual Hosting of Buckets
// Note: This configuration option is specific to the Amazon S3 service.
//
// See http://docs.aws.amazon.com/AmazonS3/latest/dev/VirtualHosting.html
// for Amazon S3: Virtual Hosting of Buckets
S3ForcePathStyle *bool
// Set this to `true` to disable the SDK adding the `Expect: 100-Continue`
@ -151,6 +152,15 @@ type Config struct {
// with accelerate.
S3UseAccelerate *bool
// S3DisableContentMD5Validation config option is temporarily disabled,
// For S3 GetObject API calls, #1837.
//
// Set this to `true` to disable the S3 service client from automatically
// adding the ContentMD5 to S3 Object Put and Upload API calls. This option
// will also disable the SDK from performing object ContentMD5 validation
// on GetObject API calls.
S3DisableContentMD5Validation *bool
// Set this to `true` to disable the EC2Metadata client from overriding the
// default http.Client's Timeout. This is helpful if you do not want the
// EC2Metadata client to create a new http.Client. This options is only
@ -168,7 +178,7 @@ type Config struct {
//
EC2MetadataDisableTimeoutOverride *bool
// Instructs the endpiont to be generated for a service client to
// Instructs the endpoint to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
//
@ -214,6 +224,28 @@ type Config struct {
// Key: aws.String("//foo//bar//moo"),
// })
DisableRestProtocolURICleaning *bool
// EnableEndpointDiscovery will allow for endpoint discovery on operations that
// have the definition in its model. By default, endpoint discovery is off.
//
// Example:
// sess := session.Must(session.NewSession(&aws.Config{
// EnableEndpointDiscovery: aws.Bool(true),
// }))
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput {
// Bucket: aws.String("bucketname"),
// Key: aws.String("/foo/bar/moo"),
// })
EnableEndpointDiscovery *bool
// DisableEndpointHostPrefix will disable the SDK's behavior of prefixing
// request endpoint hosts with modeled information.
//
// Disabling this feature is useful when you want to use local endpoints
// for testing that do not support the modeled host prefix pattern.
DisableEndpointHostPrefix *bool
}
// NewConfig returns a new Config pointer that can be chained with builder
@ -336,6 +368,15 @@ func (c *Config) WithS3Disable100Continue(disable bool) *Config {
func (c *Config) WithS3UseAccelerate(enable bool) *Config {
c.S3UseAccelerate = &enable
return c
}
// WithS3DisableContentMD5Validation sets a config
// S3DisableContentMD5Validation value returning a Config pointer for chaining.
func (c *Config) WithS3DisableContentMD5Validation(enable bool) *Config {
c.S3DisableContentMD5Validation = &enable
return c
}
// WithUseDualStack sets a config UseDualStack value returning a Config
@ -359,6 +400,19 @@ func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
return c
}
// WithEndpointDiscovery will set whether or not to use endpoint discovery.
func (c *Config) WithEndpointDiscovery(t bool) *Config {
c.EnableEndpointDiscovery = &t
return c
}
// WithDisableEndpointHostPrefix will set whether or not to use modeled host prefix
// when making requests.
func (c *Config) WithDisableEndpointHostPrefix(t bool) *Config {
c.DisableEndpointHostPrefix = &t
return c
}
// MergeIn merges the passed in configs into the existing config object.
func (c *Config) MergeIn(cfgs ...*Config) {
for _, other := range cfgs {
@ -435,6 +489,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.S3UseAccelerate = other.S3UseAccelerate
}
if other.S3DisableContentMD5Validation != nil {
dst.S3DisableContentMD5Validation = other.S3DisableContentMD5Validation
}
if other.UseDualStack != nil {
dst.UseDualStack = other.UseDualStack
}
@ -454,6 +512,14 @@ func mergeInConfig(dst *Config, other *Config) {
if other.EnforceShouldRetryCheck != nil {
dst.EnforceShouldRetryCheck = other.EnforceShouldRetryCheck
}
if other.EnableEndpointDiscovery != nil {
dst.EnableEndpointDiscovery = other.EnableEndpointDiscovery
}
if other.DisableEndpointHostPrefix != nil {
dst.DisableEndpointHostPrefix = other.DisableEndpointHostPrefix
}
}
// Copy will return a shallow copy of the Config object. If any additional

View File

@ -1,8 +1,8 @@
// +build !go1.9
package aws
import (
"time"
)
import "time"
// Context is an copy of the Go v1.7 stdlib's context.Context interface.
// It is represented as a SDK interface to enable you to use the "WithContext"
@ -35,37 +35,3 @@ type Context interface {
// functions.
Value(key interface{}) interface{}
}
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View File

@ -1,9 +0,0 @@
// +build go1.7
package aws
import "context"
var (
backgroundCtx = context.Background()
)

11
vendor/github.com/aws/aws-sdk-go/aws/context_1_9.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
// +build go1.9
package aws
import "context"
// Context is an alias of the Go stdlib's context.Context interface.
// It can be used within the SDK's API operation "WithContext" methods.
//
// See https://golang.org/pkg/context on how to use contexts.
type Context = context.Context

View File

@ -39,3 +39,18 @@ func (e *emptyCtx) String() string {
var (
backgroundCtx = new(emptyCtx)
)
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return backgroundCtx
}

View File

@ -0,0 +1,20 @@
// +build go1.7
package aws
import "context"
// BackgroundContext returns a context that will never be canceled, has no
// values, and no deadline. This context is used by the SDK to provide
// backwards compatibility with non-context API operations and functionality.
//
// Go 1.6 and before:
// This context function is equivalent to context.Background in the Go stdlib.
//
// Go 1.7 and later:
// The context returned will be the value returned by context.Background()
//
// See https://golang.org/pkg/context for more information on Contexts.
func BackgroundContext() Context {
return context.Background()
}

24
vendor/github.com/aws/aws-sdk-go/aws/context_sleep.go generated vendored Normal file
View File

@ -0,0 +1,24 @@
package aws
import (
"time"
)
// SleepWithContext will wait for the timer duration to expire, or the context
// is canceled. Which ever happens first. If the context is canceled the Context's
// error will be returned.
//
// Expects Context to always return a non-nil error if the Done channel is closed.
func SleepWithContext(ctx Context, dur time.Duration) error {
t := time.NewTimer(dur)
defer t.Stop()
select {
case <-t.C:
break
case <-ctx.Done():
return ctx.Err()
}
return nil
}

View File

@ -26,7 +26,7 @@ func TestSleepWithContext_Canceled(t *testing.T) {
ctx.Error = expectErr
close(ctx.DoneCh)
err := aws.SleepWithContext(ctx, 1*time.Millisecond)
err := aws.SleepWithContext(ctx, 10*time.Second)
if err == nil {
t.Fatalf("expect error, did not get one")
}

View File

@ -3,12 +3,10 @@ package corehandlers
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"runtime"
"strconv"
"time"
@ -36,18 +34,13 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
if slength := r.HTTPRequest.Header.Get("Content-Length"); slength != "" {
length, _ = strconv.ParseInt(slength, 10, 64)
} else {
switch body := r.Body.(type) {
case nil:
length = 0
case lener:
length = int64(body.Len())
case io.Seeker:
r.BodyStart, _ = body.Seek(0, 1)
end, _ := body.Seek(0, 2)
body.Seek(r.BodyStart, 0) // make sure to seek back to original location
length = end - r.BodyStart
default:
panic("Cannot get length of body, must provide `ContentLength`")
if r.Body != nil {
var err error
length, err = aws.SeekerLen(r.Body)
if err != nil {
r.Error = awserr.New(request.ErrCodeSerialization, "failed to get request body's length", err)
return
}
}
}
@ -60,13 +53,6 @@ var BuildContentLengthHandler = request.NamedHandler{Name: "core.BuildContentLen
}
}}
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
var reStatusCode = regexp.MustCompile(`^(\d{3})`)
// ValidateReqSigHandler is a request handler to ensure that the request's
@ -86,9 +72,9 @@ var ValidateReqSigHandler = request.NamedHandler{
signedTime = r.LastSignedAt
}
// 10 minutes to allow for some clock skew/delays in transmission.
// 5 minutes to allow for some clock skew/delays in transmission.
// Would be improved with aws/aws-sdk-go#423
if signedTime.Add(10 * time.Minute).After(time.Now()) {
if signedTime.Add(5 * time.Minute).After(time.Now()) {
return
}

View File

@ -1,4 +1,4 @@
// +build go1.8
// +build go1.10
package corehandlers_test

View File

@ -2,8 +2,8 @@ package corehandlers_test
import (
"fmt"
"testing"
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
@ -256,7 +256,7 @@ func TestValidateFieldMinParameter(t *testing.T) {
req := testSvc.NewRequest(&request.Operation{}, &c.in, nil)
corehandlers.ValidateParametersHandler.Fn(req)
if e, a := c.err, req.Error; !reflect.DeepEqual(e,a) {
if e, a := c.err, req.Error; !reflect.DeepEqual(e, a) {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}

View File

@ -0,0 +1,37 @@
package corehandlers
import (
"os"
"runtime"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
)
// SDKVersionUserAgentHandler is a request handler for adding the SDK Version
// to the user agent.
var SDKVersionUserAgentHandler = request.NamedHandler{
Name: "core.SDKVersionUserAgentHandler",
Fn: request.MakeAddToUserAgentHandler(aws.SDKName, aws.SDKVersion,
runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
const execEnvVar = `AWS_EXECUTION_ENV`
const execEnvUAKey = `exec-env`
// AddHostExecEnvUserAgentHander is a request handler appending the SDK's
// execution environment to the user agent.
//
// If the environment variable AWS_EXECUTION_ENV is set, its value will be
// appended to the user agent string.
var AddHostExecEnvUserAgentHander = request.NamedHandler{
Name: "core.AddHostExecEnvUserAgentHander",
Fn: func(r *request.Request) {
v := os.Getenv(execEnvVar)
if len(v) == 0 {
return
}
request.AddToUserAgent(r, execEnvUAKey+"/"+v)
},
}

View File

@ -0,0 +1,40 @@
package corehandlers
import (
"net/http"
"os"
"testing"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestAddHostExecEnvUserAgentHander(t *testing.T) {
cases := []struct {
ExecEnv string
Expect string
}{
{ExecEnv: "Lambda", Expect: execEnvUAKey + "/Lambda"},
{ExecEnv: "", Expect: ""},
{ExecEnv: "someThingCool", Expect: execEnvUAKey + "/someThingCool"},
}
for i, c := range cases {
os.Clearenv()
os.Setenv(execEnvVar, c.ExecEnv)
req := &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
},
}
AddHostExecEnvUserAgentHander.Fn(req)
if err := req.Error; err != nil {
t.Fatalf("%d, expect no error, got %v", i, err)
}
if e, a := c.Expect, req.HTTPRequest.Header.Get("User-Agent"); e != a {
t.Errorf("%d, expect %v user agent, got %v", i, e, a)
}
}
}

View File

@ -9,9 +9,7 @@ var (
// providers in the ChainProvider.
//
// This has been deprecated. For verbose error messaging set
// aws.Config.CredentialsChainVerboseErrors to true
//
// @readonly
// aws.Config.CredentialsChainVerboseErrors to true.
ErrNoValidProvidersFoundInChain = awserr.New("NoCredentialProviders",
`no valid providers in chain. Deprecated.
For verbose messaging see aws.Config.CredentialsChainVerboseErrors`,

View File

@ -1,10 +1,10 @@
package credentials
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type secondStubProvider struct {
@ -45,13 +45,23 @@ func TestChainProviderWithNames(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "secondStubProvider", creds.ProviderName, "Expect provider name to match")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "secondStubProvider", creds.ProviderName; e != a {
t.Errorf("Expect provider name to match, %v got, %v", e, a)
}
// Also check credentials
assert.Equal(t, "AKIF", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "NOSECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if e, a := "AKIF", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
@ -71,10 +81,18 @@ func TestChainProviderGet(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestChainProviderIsExpired(t *testing.T) {
@ -85,16 +103,26 @@ func TestChainProviderIsExpired(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired to be true before any Retrieve")
if !p.IsExpired() {
t.Errorf("Expect expired to be true before any Retrieve")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
stubProvider.expired = true
assert.True(t, p.IsExpired(), "Expect return of expired provider")
if !p.IsExpired() {
t.Errorf("Expect return of expired provider")
}
_, err = p.Retrieve()
assert.False(t, p.IsExpired(), "Expect not expired after retrieve")
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
}
func TestChainProviderWithNoProvider(t *testing.T) {
@ -102,12 +130,13 @@ func TestChainProviderWithNoProvider(t *testing.T) {
Providers: []Provider{},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
if e, a := ErrNoValidProvidersFoundInChain, err; e != a {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
@ -122,13 +151,14 @@ func TestChainProviderWithNoValidProvider(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
ErrNoValidProvidersFoundInChain,
err,
"Expect no providers error returned")
if e, a := ErrNoValidProvidersFoundInChain, err; e != a {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}
func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
@ -144,11 +174,13 @@ func TestChainProviderWithNoValidProviderWithVerboseEnabled(t *testing.T) {
},
}
assert.True(t, p.IsExpired(), "Expect expired with no providers")
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
assert.Equal(t,
awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs),
err,
"Expect no providers error returned")
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}

View File

@ -49,6 +49,8 @@
package credentials
import (
"fmt"
"github.com/aws/aws-sdk-go/aws/awserr"
"sync"
"time"
)
@ -64,8 +66,6 @@ import (
// Credentials: credentials.AnonymousCredentials,
// })))
// // Access public S3 buckets.
//
// @readonly
var AnonymousCredentials = NewStaticCredentials("", "", "")
// A Value is the AWS credentials value for individual credential fields.
@ -99,6 +99,14 @@ type Provider interface {
IsExpired() bool
}
// An Expirer is an interface that Providers can implement to expose the expiration
// time, if known. If the Provider cannot accurately provide this info,
// it should not implement this interface.
type Expirer interface {
// The time at which the credentials are no longer valid
ExpiresAt() time.Time
}
// An ErrorProvider is a stub credentials provider that always returns an error
// this is used by the SDK when construction a known provider is not possible
// due to an error.
@ -158,13 +166,19 @@ func (e *Expiry) SetExpiration(expiration time.Time, window time.Duration) {
// IsExpired returns if the credentials are expired.
func (e *Expiry) IsExpired() bool {
if e.CurrentTime == nil {
e.CurrentTime = time.Now
curTime := e.CurrentTime
if curTime == nil {
curTime = time.Now
}
return e.expiration.Before(e.CurrentTime())
return e.expiration.Before(curTime())
}
// A Credentials provides synchronous safe retrieval of AWS credentials Value.
// ExpiresAt returns the expiration time of the credential
func (e *Expiry) ExpiresAt() time.Time {
return e.expiration
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
@ -178,7 +192,8 @@ func (e *Expiry) IsExpired() bool {
type Credentials struct {
creds Value
forceRefresh bool
m sync.Mutex
m sync.RWMutex
provider Provider
}
@ -201,6 +216,17 @@ func NewCredentials(provider Provider) *Credentials {
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) Get() (Value, error) {
// Check the cached credentials first with just the read lock.
c.m.RLock()
if !c.isExpired() {
creds := c.creds
c.m.RUnlock()
return creds, nil
}
c.m.RUnlock()
// Credentials are expired need to retrieve the credentials taking the full
// lock.
c.m.Lock()
defer c.m.Unlock()
@ -234,8 +260,8 @@ func (c *Credentials) Expire() {
// If the Credentials were forced to be expired with Expire() this will
// reflect that override.
func (c *Credentials) IsExpired() bool {
c.m.Lock()
defer c.m.Unlock()
c.m.RLock()
defer c.m.RUnlock()
return c.isExpired()
}
@ -244,3 +270,23 @@ func (c *Credentials) IsExpired() bool {
func (c *Credentials) isExpired() bool {
return c.forceRefresh || c.provider.IsExpired()
}
// ExpiresAt provides access to the functionality of the Expirer interface of
// the underlying Provider, if it supports that interface. Otherwise, it returns
// an error.
func (c *Credentials) ExpiresAt() (time.Time, error) {
c.m.RLock()
defer c.m.RUnlock()
expirer, ok := c.provider.(Expirer)
if !ok {
return time.Time{}, awserr.New("ProviderNotExpirer",
fmt.Sprintf("provider %s does not support ExpiresAt()", c.creds.ProviderName),
nil)
}
if c.forceRefresh {
// set expiration time to the distant past
return time.Time{}, nil
}
return expirer.ExpiresAt(), nil
}

View File

@ -0,0 +1,90 @@
// +build go1.9
package credentials
import (
"fmt"
"strconv"
"sync"
"testing"
"time"
)
func BenchmarkCredentials_Get(b *testing.B) {
stub := &stubProvider{}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, c := range cases {
b.Run(strconv.Itoa(c), func(b *testing.B) {
creds := NewCredentials(stub)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func() {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
}
wg.Done()
}()
}
b.ResetTimer()
wg.Wait()
})
}
}
func BenchmarkCredentials_Get_Expire(b *testing.B) {
p := &blockProvider{}
expRates := []int{10000, 1000, 100}
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, expRate := range expRates {
for _, c := range cases {
b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) {
creds := NewCredentials(p)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
go func(id int) {
for j := 0; j < b.N; j++ {
v, err := creds.Get()
if err != nil {
b.Fatalf("expect no error %v, %v", v, err)
}
// periodically expire creds to cause rwlock
if id == 0 && j%expRate == 0 {
creds.Expire()
}
}
wg.Done()
}(i)
}
b.ResetTimer()
wg.Wait()
})
}
}
}
type blockProvider struct {
creds Value
expired bool
err error
}
func (s *blockProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "blockProvider"
time.Sleep(time.Millisecond)
return s.creds, s.err
}
func (s *blockProvider) IsExpired() bool {
return s.expired
}

View File

@ -1,10 +1,11 @@
package credentials
import (
"math/rand"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/stretchr/testify/assert"
)
type stubProvider struct {
@ -33,17 +34,27 @@ func TestCredentialsGet(t *testing.T) {
})
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect session token to be empty")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.Get()
assert.Equal(t, "provider error", err.(awserr.Error).Code(), "Expected provider error")
if e, a := "provider error", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected provider error, %v got %v", e, a)
}
}
func TestCredentialsExpire(t *testing.T) {
@ -51,15 +62,31 @@ func TestCredentialsExpire(t *testing.T) {
c := NewCredentials(stub)
stub.expired = false
assert.True(t, c.IsExpired(), "Expected to start out expired")
if !c.IsExpired() {
t.Errorf("Expected to start out expired")
}
c.Expire()
assert.True(t, c.IsExpired(), "Expected to be expired")
if !c.IsExpired() {
t.Errorf("Expected to be expired")
}
c.forceRefresh = false
assert.False(t, c.IsExpired(), "Expected not to be expired")
if c.IsExpired() {
t.Errorf("Expected not to be expired")
}
stub.expired = true
assert.True(t, c.IsExpired(), "Expected to be expired")
if !c.IsExpired() {
t.Errorf("Expected to be expired")
}
}
type MockProvider struct {
Expiry
}
func (*MockProvider) Retrieve() (Value, error) {
return Value{}, nil
}
func TestCredentialsGetWithProviderName(t *testing.T) {
@ -68,6 +95,75 @@ func TestCredentialsGetWithProviderName(t *testing.T) {
c := NewCredentials(stub)
creds, err := c.Get()
assert.Nil(t, err, "Expected no error")
assert.Equal(t, creds.ProviderName, "stubProvider", "Expected provider name to match")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := creds.ProviderName, "stubProvider"; e != a {
t.Errorf("Expected provider name to match, %v got %v", e, a)
}
}
func TestCredentialsIsExpired_Race(t *testing.T) {
creds := NewChainCredentials([]Provider{&MockProvider{}})
starter := make(chan struct{})
for i := 0; i < 10; i++ {
go func() {
<-starter
for {
creds.IsExpired()
}
}()
}
close(starter)
time.Sleep(10 * time.Second)
}
func TestCredentialsExpiresAt_NoExpirer(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
_, err := c.ExpiresAt()
if e, a := "ProviderNotExpirer", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected provider error, %v got %v", e, a)
}
}
type stubProviderExpirer struct {
stubProvider
expiration time.Time
}
func (s *stubProviderExpirer) ExpiresAt() time.Time {
return s.expiration
}
func TestCredentialsExpiresAt_HasExpirer(t *testing.T) {
stub := &stubProviderExpirer{}
c := NewCredentials(stub)
// fetch initial credentials so that forceRefresh is set false
_, err := c.Get()
if err != nil {
t.Errorf("Unexpecte error: %v", err)
}
stub.expiration = time.Unix(rand.Int63(), 0)
expiration, err := c.ExpiresAt()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if stub.expiration != expiration {
t.Errorf("Expected matching expiration, %v got %v", stub.expiration, expiration)
}
c.Expire()
expiration, err = c.ExpiresAt()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !expiration.IsZero() {
t.Errorf("Expected distant past expiration, got %v", expiration)
}
}

View File

@ -4,7 +4,6 @@ import (
"bufio"
"encoding/json"
"fmt"
"path"
"strings"
"time"
@ -12,6 +11,7 @@ import (
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// ProviderName provides a name of EC2Role provider
@ -125,7 +125,7 @@ type ec2RoleCredRespBody struct {
Message string
}
const iamSecurityCredsPath = "/iam/security-credentials"
const iamSecurityCredsPath = "iam/security-credentials/"
// requestCredList requests a list of credentials from the EC2 service.
// If there are no credentials, or there is an error making or receiving the request
@ -153,7 +153,7 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) {
// If the credentials cannot be found, or there is an error reading the response
// and error will be returned.
func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCredRespBody, error) {
resp, err := client.GetMetadata(path.Join(iamSecurityCredsPath, credsName))
resp, err := client.GetMetadata(sdkuri.PathJoin(iamSecurityCredsPath, credsName))
if err != nil {
return ec2RoleCredRespBody{},
awserr.New("EC2RoleRequestError",

View File

@ -7,8 +7,6 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
@ -34,7 +32,7 @@ const credsFailRespTmpl = `{
func initTestServer(expireOn string, failAssume bool) *httptest.Server {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/latest/meta-data/iam/security-credentials" {
if r.URL.Path == "/latest/meta-data/iam/security-credentials/" {
fmt.Fprintln(w, "RoleName")
} else if r.URL.Path == "/latest/meta-data/iam/security-credentials/RoleName" {
if failAssume {
@ -59,11 +57,19 @@ func TestEC2RoleProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("Expect session token to match, %v got %v", e, a)
}
}
func TestEC2RoleProviderFailAssume(t *testing.T) {
@ -75,16 +81,30 @@ func TestEC2RoleProviderFailAssume(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err, "Expect error")
if err == nil {
t.Errorf("Expect error")
}
e := err.(awserr.Error)
assert.Equal(t, "ErrorCode", e.Code())
assert.Equal(t, "ErrorMsg", e.Message())
assert.Nil(t, e.OrigErr())
if e, a := "ErrorCode", e.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "ErrorMsg", e.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := e.OrigErr(); v != nil {
t.Errorf("expect nil, got %v", v)
}
assert.Equal(t, "", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "", creds.SessionToken, "Expect session token to match")
if e, a := "", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if e, a := "", creds.SessionToken; e != a {
t.Errorf("Expect session token to match, %v got %v", e, a)
}
}
func TestEC2RoleProviderIsExpired(t *testing.T) {
@ -98,18 +118,26 @@ func TestEC2RoleProviderIsExpired(t *testing.T) {
return time.Date(2014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if v := err; v != nil {
t.Errorf("Expect no error, %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
p.CurrentTime = func() time.Time {
return time.Date(3014, 12, 15, 21, 26, 0, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired.")
}
}
func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
@ -124,18 +152,26 @@ func TestEC2RoleProviderExpiryWindowIsExpired(t *testing.T) {
return time.Date(2014, 12, 15, 0, 51, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error, %v", err)
if v := err; v != nil {
t.Errorf("Expect no error, %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve.")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
p.CurrentTime = func() time.Time {
return time.Date(2014, 12, 16, 0, 55, 37, 0, time.UTC)
}
assert.True(t, p.IsExpired(), "Expect creds to be expired.")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired.")
}
}
func BenchmarkEC3RoleProvider(b *testing.B) {

View File

@ -65,6 +65,10 @@ type Provider struct {
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
AuthorizationToken string
}
// NewProviderClient returns a credentials Provider for retrieving AWS credentials
@ -152,6 +156,9 @@ func (p *Provider) getCredentials() (*getCredentialsOutput, error) {
out := &getCredentialsOutput{}
req := p.Client.NewRequest(op, nil, out)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}
return out, req.Send()
}

View File

@ -11,14 +11,19 @@ import (
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestRetrieveRefreshableCredentials(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/path/to/endpoint", r.URL.Path)
assert.Equal(t, "application/json", r.Header.Get("Accept"))
assert.Equal(t, "else", r.URL.Query().Get("something"))
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "else", r.URL.Query().Get("something"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
@ -39,18 +44,30 @@ func TestRetrieveRefreshableCredentials(t *testing.T) {
)
creds, err := client.Retrieve()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Equal(t, "TOKEN", creds.SessionToken)
assert.False(t, client.IsExpired())
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
assert.True(t, client.IsExpired())
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestRetrieveStaticCredentials(t *testing.T) {
@ -69,12 +86,22 @@ func TestRetrieveStaticCredentials(t *testing.T) {
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
assert.Equal(t, "AKID", creds.AccessKeyID)
assert.Equal(t, "SECRET", creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.False(t, client.IsExpired())
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no SessionToken, got %#v", v)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
}
func TestFailedRetrieveCredentials(t *testing.T) {
@ -94,18 +121,98 @@ func TestFailedRetrieveCredentials(t *testing.T) {
client := endpointcreds.NewProviderClient(*unit.Session.Config, unit.Session.Handlers, server.URL)
creds, err := client.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error, got none")
}
aerr := err.(awserr.Error)
assert.Equal(t, "CredentialsEndpointError", aerr.Code())
assert.Equal(t, "failed to load credentials", aerr.Message())
if e, a := "CredentialsEndpointError", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "failed to load credentials", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
aerr = aerr.OrigErr().(awserr.Error)
assert.Equal(t, "Error", aerr.Code())
assert.Equal(t, "Message", aerr.Message())
if e, a := "Error", aerr.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "Message", aerr.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
assert.True(t, client.IsExpired())
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %#v", v)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}
func TestAuthorizationToken(t *testing.T) {
const expectAuthToken = "Basic abc123"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a {
t.Fatalf("expect %v, got %v", e, a)
}
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
func(p *endpointcreds.Provider) {
p.AuthorizationToken = expectAuthToken
},
)
creds, err := client.Retrieve()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
}

View File

@ -12,14 +12,10 @@ const EnvProviderName = "EnvProvider"
var (
// ErrAccessKeyIDNotFound is returned when the AWS Access Key ID can't be
// found in the process's environment.
//
// @readonly
ErrAccessKeyIDNotFound = awserr.New("EnvAccessKeyNotFound", "AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY not found in environment", nil)
// ErrSecretAccessKeyNotFound is returned when the AWS Secret Access Key
// can't be found in the process's environment.
//
// @readonly
ErrSecretAccessKeyNotFound = awserr.New("EnvSecretNotFound", "AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY not found in environment", nil)
)

View File

@ -1,7 +1,6 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"os"
"testing"
)
@ -14,11 +13,19 @@ func TestEnvProviderRetrieve(t *testing.T) {
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "access", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "access", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderIsExpired(t *testing.T) {
@ -29,12 +36,18 @@ func TestEnvProviderIsExpired(t *testing.T) {
e := EnvProvider{}
assert.True(t, e.IsExpired(), "Expect creds to be expired before retrieve.")
if !e.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve.")
}
_, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.False(t, e.IsExpired(), "Expect creds to not be expired after retrieve.")
if e.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve.")
}
}
func TestEnvProviderNoAccessKeyID(t *testing.T) {
@ -42,8 +55,10 @@ func TestEnvProviderNoAccessKeyID(t *testing.T) {
os.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrAccessKeyIDNotFound, err, "ErrAccessKeyIDNotFound expected, but was %#v error: %#v", creds, err)
_, err := e.Retrieve()
if e, a := ErrAccessKeyIDNotFound, err; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderNoSecretAccessKey(t *testing.T) {
@ -51,8 +66,10 @@ func TestEnvProviderNoSecretAccessKey(t *testing.T) {
os.Setenv("AWS_ACCESS_KEY_ID", "access")
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Equal(t, ErrSecretAccessKeyNotFound, err, "ErrSecretAccessKeyNotFound expected, but was %#v error: %#v", creds, err)
_, err := e.Retrieve()
if e, a := ErrSecretAccessKeyNotFound, err; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestEnvProviderAlternateNames(t *testing.T) {
@ -62,9 +79,17 @@ func TestEnvProviderAlternateNames(t *testing.T) {
e := EnvProvider{}
creds, err := e.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "access", creds.AccessKeyID, "Expected access key ID")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expected secret access key")
assert.Empty(t, creds.SessionToken, "Expected no token")
if e, a := "access", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expected no token, %v", v)
}
}

View File

@ -24,7 +24,7 @@
//
// func() (RetrieveFn func() (key, secret, token string, err error), IsExpiredFn func() bool)
//
// Plugin Implementation Exmaple
// Plugin Implementation Example
//
// The following is an example implementation of a SDK credential provider using
// the plugin provider in this package. See the SDK's example/aws/credential/plugincreds/plugin

View File

@ -0,0 +1,425 @@
/*
Package processcreds is a credential Provider to retrieve `credential_process`
credentials.
WARNING: The following describes a method of sourcing credentials from an external
process. This can potentially be dangerous, so proceed with caution. Other
credential providers should be preferred if at all possible. If using this
option, you should make sure that the config file is as locked down as possible
using security best practices for your operating system.
You can use credentials from a `credential_process` in a variety of ways.
One way is to setup your shared config file, located in the default
location, with the `credential_process` key and the command you want to be
called. You also need to set the AWS_SDK_LOAD_CONFIG environment variable
(e.g., `export AWS_SDK_LOAD_CONFIG=1`) to use the shared config file.
[default]
credential_process = /command/to/call
Creating a new session will use the credential process to retrieve credentials.
NOTE: If there are credentials in the profile you are using, the credential
process will not be used.
// Initialize a session to load credentials.
sess, _ := session.NewSession(&aws.Config{
Region: aws.String("us-east-1")},
)
// Create S3 service client to use the credentials.
svc := s3.New(sess)
Another way to use the `credential_process` method is by using
`credentials.NewCredentials()` and providing a command to be executed to
retrieve credentials:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentials("/path/to/command")
// Create service client value configured for credentials.
svc := s3.New(sess, &aws.Config{Credentials: creds})
You can set a non-default timeout for the `credential_process` with another
constructor, `credentials.NewCredentialsTimeout()`, providing the timeout. To
set a one minute timeout:
// Create credentials using the ProcessProvider.
creds := processcreds.NewCredentialsTimeout(
"/path/to/command",
time.Duration(500) * time.Millisecond)
If you need more control, you can set any configurable options in the
credentials using one or more option functions. For example, you can set a two
minute timeout, a credential duration of 60 minutes, and a maximum stdout
buffer size of 2k.
creds := processcreds.NewCredentials(
"/path/to/command",
func(opt *ProcessProvider) {
opt.Timeout = time.Duration(2) * time.Minute
opt.Duration = time.Duration(60) * time.Minute
opt.MaxBufSize = 2048
})
You can also use your own `exec.Cmd`:
// Create an exec.Cmd
myCommand := exec.Command("/path/to/command")
// Create credentials using your exec.Cmd and custom timeout
creds := processcreds.NewCredentialsCommand(
myCommand,
func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
*/
package processcreds
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
)
const (
// ProviderName is the name this credentials provider will label any
// returned credentials Value with.
ProviderName = `ProcessProvider`
// ErrCodeProcessProviderParse error parsing process output
ErrCodeProcessProviderParse = "ProcessProviderParseError"
// ErrCodeProcessProviderVersion version error in output
ErrCodeProcessProviderVersion = "ProcessProviderVersionError"
// ErrCodeProcessProviderRequired required attribute missing in output
ErrCodeProcessProviderRequired = "ProcessProviderRequiredError"
// ErrCodeProcessProviderExecution execution of command failed
ErrCodeProcessProviderExecution = "ProcessProviderExecutionError"
// errMsgProcessProviderTimeout process took longer than allowed
errMsgProcessProviderTimeout = "credential process timed out"
// errMsgProcessProviderProcess process error
errMsgProcessProviderProcess = "error in credential_process"
// errMsgProcessProviderParse problem parsing output
errMsgProcessProviderParse = "parse failed of credential_process output"
// errMsgProcessProviderVersion version error in output
errMsgProcessProviderVersion = "wrong version in process output (not 1)"
// errMsgProcessProviderMissKey missing access key id in output
errMsgProcessProviderMissKey = "missing AccessKeyId in process output"
// errMsgProcessProviderMissSecret missing secret acess key in output
errMsgProcessProviderMissSecret = "missing SecretAccessKey in process output"
// errMsgProcessProviderPrepareCmd prepare of command failed
errMsgProcessProviderPrepareCmd = "failed to prepare command"
// errMsgProcessProviderEmptyCmd command must not be empty
errMsgProcessProviderEmptyCmd = "command must not be empty"
// errMsgProcessProviderPipe failed to initialize pipe
errMsgProcessProviderPipe = "failed to initialize pipe"
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for.
DefaultDuration = time.Duration(15) * time.Minute
// DefaultBufSize limits buffer size from growing to an enormous
// amount due to a faulty process.
DefaultBufSize = 1024
// DefaultTimeout default limit on time a process can run.
DefaultTimeout = time.Duration(1) * time.Minute
)
// ProcessProvider satisfies the credentials.Provider interface, and is a
// client to retrieve credentials from a process.
type ProcessProvider struct {
staticCreds bool
credentials.Expiry
originalCommand []string
// Expiry duration of the credentials. Defaults to 15 minutes if not set.
Duration time.Duration
// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration
// A string representing an os command that should return a JSON with
// credential information.
command *exec.Cmd
// MaxBufSize limits memory usage from growing to an enormous
// amount due to a faulty process.
MaxBufSize int
// Timeout limits the time a process can run.
Timeout time.Duration
}
// NewCredentials returns a pointer to a new Credentials object wrapping the
// ProcessProvider. The credentials will expire every 15 minutes by default.
func NewCredentials(command string, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: exec.Command(command),
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
// NewCredentialsTimeout returns a pointer to a new Credentials object with
// the specified command and timeout, and default duration and max buffer size.
func NewCredentialsTimeout(command string, timeout time.Duration) *credentials.Credentials {
p := NewCredentials(command, func(opt *ProcessProvider) {
opt.Timeout = timeout
})
return p
}
// NewCredentialsCommand returns a pointer to a new Credentials object with
// the specified command, and default timeout, duration and max buffer size.
func NewCredentialsCommand(command *exec.Cmd, options ...func(*ProcessProvider)) *credentials.Credentials {
p := &ProcessProvider{
command: command,
Duration: DefaultDuration,
Timeout: DefaultTimeout,
MaxBufSize: DefaultBufSize,
}
for _, option := range options {
option(p)
}
return credentials.NewCredentials(p)
}
type credentialProcessResponse struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
SessionToken string
Expiration *time.Time
}
// Retrieve executes the 'credential_process' and returns the credentials.
func (p *ProcessProvider) Retrieve() (credentials.Value, error) {
out, err := p.executeCredentialProcess()
if err != nil {
return credentials.Value{ProviderName: ProviderName}, err
}
// Serialize and validate response
resp := &credentialProcessResponse{}
if err = json.Unmarshal(out, resp); err != nil {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderParse,
fmt.Sprintf("%s: %s", errMsgProcessProviderParse, string(out)),
err)
}
if resp.Version != 1 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderVersion,
errMsgProcessProviderVersion,
nil)
}
if len(resp.AccessKeyID) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissKey,
nil)
}
if len(resp.SecretAccessKey) == 0 {
return credentials.Value{ProviderName: ProviderName}, awserr.New(
ErrCodeProcessProviderRequired,
errMsgProcessProviderMissSecret,
nil)
}
// Handle expiration
p.staticCreds = resp.Expiration == nil
if resp.Expiration != nil {
p.SetExpiration(*resp.Expiration, p.ExpiryWindow)
}
return credentials.Value{
ProviderName: ProviderName,
AccessKeyID: resp.AccessKeyID,
SecretAccessKey: resp.SecretAccessKey,
SessionToken: resp.SessionToken,
}, nil
}
// IsExpired returns true if the credentials retrieved are expired, or not yet
// retrieved.
func (p *ProcessProvider) IsExpired() bool {
if p.staticCreds {
return false
}
return p.Expiry.IsExpired()
}
// prepareCommand prepares the command to be executed.
func (p *ProcessProvider) prepareCommand() error {
var cmdArgs []string
if runtime.GOOS == "windows" {
cmdArgs = []string{"cmd.exe", "/C"}
} else {
cmdArgs = []string{"sh", "-c"}
}
if len(p.originalCommand) == 0 {
p.originalCommand = make([]string, len(p.command.Args))
copy(p.originalCommand, p.command.Args)
// check for empty command because it succeeds
if len(strings.TrimSpace(p.originalCommand[0])) < 1 {
return awserr.New(
ErrCodeProcessProviderExecution,
fmt.Sprintf(
"%s: %s",
errMsgProcessProviderPrepareCmd,
errMsgProcessProviderEmptyCmd),
nil)
}
}
cmdArgs = append(cmdArgs, p.originalCommand...)
p.command = exec.Command(cmdArgs[0], cmdArgs[1:]...)
p.command.Env = os.Environ()
return nil
}
// executeCredentialProcess starts the credential process on the OS and
// returns the results or an error.
func (p *ProcessProvider) executeCredentialProcess() ([]byte, error) {
if err := p.prepareCommand(); err != nil {
return nil, err
}
// Setup the pipes
outReadPipe, outWritePipe, err := os.Pipe()
if err != nil {
return nil, awserr.New(
ErrCodeProcessProviderExecution,
errMsgProcessProviderPipe,
err)
}
p.command.Stderr = os.Stderr // display stderr on console for MFA
p.command.Stdout = outWritePipe // get creds json on process's stdout
p.command.Stdin = os.Stdin // enable stdin for MFA
output := bytes.NewBuffer(make([]byte, 0, p.MaxBufSize))
stdoutCh := make(chan error, 1)
go readInput(
io.LimitReader(outReadPipe, int64(p.MaxBufSize)),
output,
stdoutCh)
execCh := make(chan error, 1)
go executeCommand(*p.command, execCh)
finished := false
var errors []error
for !finished {
select {
case readError := <-stdoutCh:
errors = appendError(errors, readError)
finished = true
case execError := <-execCh:
err := outWritePipe.Close()
errors = appendError(errors, err)
errors = appendError(errors, execError)
if errors != nil {
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderProcess,
errors)
}
case <-time.After(p.Timeout):
finished = true
return output.Bytes(), awserr.NewBatchError(
ErrCodeProcessProviderExecution,
errMsgProcessProviderTimeout,
errors) // errors can be nil
}
}
out := output.Bytes()
if runtime.GOOS == "windows" {
// windows adds slashes to quotes
out = []byte(strings.Replace(string(out), `\"`, `"`, -1))
}
return out, nil
}
// appendError conveniently checks for nil before appending slice
func appendError(errors []error, err error) []error {
if err != nil {
return append(errors, err)
}
return errors
}
func executeCommand(cmd exec.Cmd, exec chan error) {
// Start the command
err := cmd.Start()
if err == nil {
err = cmd.Wait()
}
exec <- err
}
func readInput(r io.Reader, w io.Writer, read chan error) {
tee := io.TeeReader(r, w)
_, err := ioutil.ReadAll(tee)
if err == io.EOF {
err = nil
}
read <- err // will only arrive here when write end of pipe is closed
}

View File

@ -0,0 +1,584 @@
package processcreds_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/processcreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/awstesting"
)
func TestProcessProviderFromSessionCfg(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "tokenDefault", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionWithProfileCfg(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_PROFILE", "non_expire")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "nonDefaultToken", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderNotFromCredProcCfg(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_SDK_LOAD_CONFIG", "1")
os.Setenv("AWS_PROFILE", "not_alone")
if runtime.GOOS == "windows" {
os.Setenv("AWS_CONFIG_FILE", "testdata\\shconfig_win.ini")
} else {
os.Setenv("AWS_CONFIG_FILE", "testdata/shconfig.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionCrd(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "tokenDefault", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderFromSessionWithProfileCrd(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_PROFILE", "non_expire")
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "nonDefaultToken", creds.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderNotFromCredProcCrd(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
os.Setenv("AWS_PROFILE", "not_alone")
if runtime.GOOS == "windows" {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata\\shcred_win.ini")
} else {
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", "testdata/shcred.ini")
}
sess, err := session.NewSession(&aws.Config{
Region: aws.String("region")},
)
if err != nil {
t.Errorf("error getting session: %v", err)
}
creds, err := sess.Config.Credentials.Get()
if err != nil {
t.Errorf("error getting credentials: %v", err)
}
if e, a := "notFromCredProcAccess", creds.AccessKeyID; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := "notFromCredProcSecret", creds.SecretAccessKey; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
func TestProcessProviderBadCommand(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
creds := processcreds.NewCredentials("/bad/process")
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderMoreEmptyCommands(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
creds := processcreds.NewCredentials("")
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderExpectErrors(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "malformed.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderParse {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderParse, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf("%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "wrongversion.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderVersion {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderVersion, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "missingkey.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
}
creds = processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "missingsecret.json"},
string(os.PathSeparator))))
_, err = creds.Get()
if err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderRequired {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderRequired, err)
}
}
func TestProcessProviderTimeout(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
command := "/bin/sleep 2"
if runtime.GOOS == "windows" {
// "timeout" command does not work due to pipe redirection
command = "ping -n 2 127.0.0.1>nul"
}
creds := processcreds.NewCredentialsTimeout(
command,
time.Duration(1)*time.Second)
if _, err := creds.Get(); err == nil || err.(awserr.Error).Code() != processcreds.ErrCodeProcessProviderExecution || err.(awserr.Error).Message() != "credential process timed out" {
t.Errorf("expected %v, got %v", processcreds.ErrCodeProcessProviderExecution, err)
}
}
func TestProcessProviderWithLongSessionToken(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "longsessiontoken.json"},
string(os.PathSeparator))))
v, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
// Text string same length as session token returned by AWS for AssumeRoleWithWebIdentity
e := "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
if a := v.SessionToken; e != a {
t.Errorf("expected %v, got %v", e, a)
}
}
type credentialTest struct {
Version int
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string
Expiration string
}
func TestProcessProviderStatic(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
// static
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
}
}
func TestProcessProviderNotExpired(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
// non-static, not expired
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_expiring.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
_, err = creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
}
func TestProcessProviderExpired(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
// non-static, expired
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_expired.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
_, err = creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if !creds.IsExpired() {
t.Errorf("expected %v, got %v", "expired", "not expired")
}
}
func TestProcessProviderForceExpire(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
// non-static, not expired
// setup test credentials file
exp := &credentialTest{}
exp.Version = 1
exp.AccessKeyID = "accesskey"
exp.SecretAccessKey = "secretkey"
exp.Expiration = time.Now().Add(1 * time.Hour).UTC().Format(time.RFC3339)
b, err := json.Marshal(exp)
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
tmpFile := strings.Join(
[]string{"testdata", "tmp_force_expire.json"},
string(os.PathSeparator))
if err = ioutil.WriteFile(tmpFile, b, 0644); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
defer func() {
if err = os.Remove(tmpFile); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
}()
// get credentials from file
creds := processcreds.NewCredentials(
fmt.Sprintf("%s %s", getOSCat(), tmpFile))
if _, err = creds.Get(); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
// force expire creds
creds.Expire()
if !creds.IsExpired() {
t.Errorf("expected %v, got %v", "expired", "not expired")
}
// renew creds
if _, err = creds.Get(); err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "not expired", "expired")
}
}
func TestProcessProviderAltConstruct(t *testing.T) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
// constructing with exec.Cmd instead of string
myCommand := exec.Command(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
creds := processcreds.NewCredentialsCommand(myCommand, func(opt *processcreds.ProcessProvider) {
opt.Timeout = time.Duration(1) * time.Second
})
_, err := creds.Get()
if err != nil {
t.Errorf("expected %v, got %v", "no error", err)
}
if creds.IsExpired() {
t.Errorf("expected %v, got %v", "static credentials/not expired", "expired")
}
}
func BenchmarkProcessProvider(b *testing.B) {
oldEnv := preserveImportantStashEnv()
defer awstesting.PopEnv(oldEnv)
creds := processcreds.NewCredentials(
fmt.Sprintf(
"%s %s",
getOSCat(),
strings.Join(
[]string{"testdata", "static.json"},
string(os.PathSeparator))))
_, err := creds.Get()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := creds.Get()
if err != nil {
b.Fatal(err)
}
}
}
func preserveImportantStashEnv() []string {
envsToKeep := []string{"PATH"}
if runtime.GOOS == "windows" {
envsToKeep = append(envsToKeep, "ComSpec")
envsToKeep = append(envsToKeep, "SYSTEM32")
}
extraEnv := getEnvs(envsToKeep)
oldEnv := awstesting.StashEnv() //clear env
for key, val := range extraEnv {
os.Setenv(key, val)
}
return oldEnv
}
func getEnvs(envs []string) map[string]string {
extraEnvs := make(map[string]string)
for _, env := range envs {
if val, ok := os.LookupEnv(env); ok && len(val) > 0 {
extraEnvs[env] = val
}
}
return extraEnvs
}
func getOSCat() string {
if runtime.GOOS == "windows" {
return "type"
}
return "cat"
}

View File

@ -0,0 +1,7 @@
{
"Version": 1,
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"SessionToken": "tokenDefault",
"Expiration": "2000-01-01T00:00:00-00:00"
}

View File

@ -0,0 +1,8 @@
{
"Version": 1,
"AccessKeyId": "ASIAXXXXXXXXXXXXXXXX",
"Expiration": "2199-01-01T00:00:00Z",
"SecretAccessKey": "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
"SessionToken":
"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
}

View File

@ -0,0 +1,2 @@
{
"Version": 1

View File

@ -0,0 +1,4 @@
{
"Version": 1,
"AccessKeyId": "accesskey"
}

View File

@ -0,0 +1,4 @@
{
"Version": 1,
"SecretAccessKey": "secretkey"
}

View File

@ -0,0 +1,6 @@
{
"Version": 1,
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"SessionToken": "nonDefaultToken"
}

View File

@ -0,0 +1,10 @@
[default]
credential_process = cat ./testdata/expired.json
[profile non_expire]
credential_process = cat ./testdata/nonexpire.json
[profile not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = cat ./testdata/verybad.json

View File

@ -0,0 +1,10 @@
[default]
credential_process = type .\testdata\expired.json
[profile non_expire]
credential_process = type .\testdata\nonexpire.json
[profile not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = type .\testdata\verybad.json

View File

@ -0,0 +1,10 @@
[default]
credential_process = cat ./testdata/expired.json
[non_expire]
credential_process = cat ./testdata/nonexpire.json
[not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = cat ./testdata/verybad.json

View File

@ -0,0 +1,10 @@
[default]
credential_process = type .\testdata\expired.json
[non_expire]
credential_process = type .\testdata\nonexpire.json
[not_alone]
aws_access_key_id = notFromCredProcAccess
aws_secret_access_key = notFromCredProcSecret
credential_process = type .\testdata\verybad.json

View File

@ -0,0 +1,5 @@
{
"Version":1,
"AccessKeyId":"accesskey",
"SecretAccessKey":"secretkey"
}

View File

@ -0,0 +1,5 @@
{
"Version":1,
"AccessKeyId":"veryBadAccessKeyID",
"SecretAccessKey":"veryBadSecretAccessKey"
}

View File

@ -0,0 +1,3 @@
{
"Version": 2
}

View File

@ -4,9 +4,8 @@ import (
"fmt"
"os"
"github.com/go-ini/ini"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/internal/ini"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
@ -77,36 +76,37 @@ func (p *SharedCredentialsProvider) IsExpired() bool {
// The credentials retrieved from the profile will be returned or error. Error will be
// returned if it fails to read from the file, or the data is invalid.
func loadProfile(filename, profile string) (Value, error) {
config, err := ini.Load(filename)
config, err := ini.OpenFile(filename)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to load shared credentials file", err)
}
iniProfile, err := config.GetSection(profile)
if err != nil {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", err)
iniProfile, ok := config.GetSection(profile)
if !ok {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsLoad", "failed to get profile", nil)
}
id, err := iniProfile.GetKey("aws_access_key_id")
if err != nil {
id := iniProfile.String("aws_access_key_id")
if len(id) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsAccessKey",
fmt.Sprintf("shared credentials %s in %s did not contain aws_access_key_id", profile, filename),
err)
nil)
}
secret, err := iniProfile.GetKey("aws_secret_access_key")
if err != nil {
secret := iniProfile.String("aws_secret_access_key")
if len(secret) == 0 {
return Value{ProviderName: SharedCredsProviderName}, awserr.New("SharedCredsSecret",
fmt.Sprintf("shared credentials %s in %s did not contain aws_secret_access_key", profile, filename),
nil)
}
// Default to empty string if not found
token := iniProfile.Key("aws_session_token")
token := iniProfile.String("aws_session_token")
return Value{
AccessKeyID: id.String(),
SecretAccessKey: secret.String(),
SessionToken: token.String(),
AccessKeyID: id,
SecretAccessKey: secret,
SessionToken: token,
ProviderName: SharedCredsProviderName,
}, nil
}

View File

@ -6,7 +6,6 @@ import (
"testing"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
"github.com/stretchr/testify/assert"
)
func TestSharedCredentialsProvider(t *testing.T) {
@ -14,11 +13,19 @@ func TestSharedCredentialsProvider(t *testing.T) {
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderIsExpired(t *testing.T) {
@ -26,12 +33,18 @@ func TestSharedCredentialsProviderIsExpired(t *testing.T) {
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
assert.True(t, p.IsExpired(), "Expect creds to be expired before retrieve")
if !p.IsExpired() {
t.Errorf("Expect creds to be expired before retrieve")
}
_, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.False(t, p.IsExpired(), "Expect creds to not be expired after retrieve")
if p.IsExpired() {
t.Errorf("Expect creds to not be expired after retrieve")
}
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T) {
@ -40,25 +53,43 @@ func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILE(t *testing.T)
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderWithAWS_SHARED_CREDENTIALS_FILEAbsPath(t *testing.T) {
os.Clearenv()
wd, err := os.Getwd()
assert.NoError(t, err)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", filepath.Join(wd, "example.ini"))
p := SharedCredentialsProvider{}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "token", creds.SessionToken, "Expect session token to match")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "token", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
@ -67,11 +98,19 @@ func TestSharedCredentialsProviderWithAWS_PROFILE(t *testing.T) {
p := SharedCredentialsProvider{Filename: "example.ini", Profile: ""}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
@ -79,11 +118,19 @@ func TestSharedCredentialsProviderWithoutTokenFromProfile(t *testing.T) {
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "no_token"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
@ -91,11 +138,19 @@ func TestSharedCredentialsProviderColonInCredFile(t *testing.T) {
p := SharedCredentialsProvider{Filename: "example.ini", Profile: "with_colon"}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "accessKey", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "secret", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no token")
if e, a := "accessKey", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "secret", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no token, %v", v)
}
}
func TestSharedCredentialsProvider_DefaultFilename(t *testing.T) {

View File

@ -9,8 +9,6 @@ const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
//
// @readonly
ErrStaticCredentialsEmpty = awserr.New("EmptyStaticCreds", "static credentials are empty", nil)
)

View File

@ -1,7 +1,6 @@
package credentials
import (
"github.com/stretchr/testify/assert"
"testing"
)
@ -15,10 +14,18 @@ func TestStaticProviderGet(t *testing.T) {
}
creds, err := s.Retrieve()
assert.Nil(t, err, "Expect no error")
assert.Equal(t, "AKID", creds.AccessKeyID, "Expect access key ID to match")
assert.Equal(t, "SECRET", creds.SecretAccessKey, "Expect secret access key to match")
assert.Empty(t, creds.SessionToken, "Expect no session token")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect no session token, %v", v)
}
}
func TestStaticProviderIsExpired(t *testing.T) {
@ -30,5 +37,7 @@ func TestStaticProviderIsExpired(t *testing.T) {
},
}
assert.False(t, s.IsExpired(), "Expect static credentials to never expire")
if s.IsExpired() {
t.Errorf("Expect static credentials to never expire")
}
}

View File

@ -7,7 +7,6 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
)
type stubSTS struct {
@ -38,18 +37,30 @@ func TestAssumeRoleProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
if e, a := "0123456789", *in.SerialNumber; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "code", *in.TokenCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
}
p := &AssumeRoleProvider{
@ -60,18 +71,30 @@ func TestAssumeRoleProvider_WithTokenCode(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Equal(t, "0123456789", *in.SerialNumber)
assert.Equal(t, "code", *in.TokenCode)
if e, a := "0123456789", *in.SerialNumber; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "code", *in.TokenCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
},
}
p := &AssumeRoleProvider{
@ -84,17 +107,25 @@ func TestAssumeRoleProvider_WithTokenProvider(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Nil(t, err, "Expect no error")
if err != nil {
t.Errorf("expect nil, got %v", err)
}
assert.Equal(t, "roleARN", creds.AccessKeyID, "Expect access key ID to be reflected role ARN")
assert.Equal(t, "assumedSecretAccessKey", creds.SecretAccessKey, "Expect secret access key to match")
assert.Equal(t, "assumedSessionToken", creds.SessionToken, "Expect session token to match")
if e, a := "roleARN", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSecretAccessKey", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "assumedSessionToken", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
t.Errorf("API request should not of been called")
},
}
p := &AssumeRoleProvider{
@ -107,17 +138,25 @@ func TestAssumeRoleProvider_WithTokenProviderError(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error")
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
stub := &stubSTS{
TestInput: func(in *sts.AssumeRoleInput) {
assert.Fail(t, "API request should not of been called")
t.Errorf("API request should not of been called")
},
}
p := &AssumeRoleProvider{
@ -127,11 +166,19 @@ func TestAssumeRoleProvider_MFAWithNoToken(t *testing.T) {
}
creds, err := p.Retrieve()
assert.Error(t, err)
if err == nil {
t.Errorf("expect error")
}
assert.Empty(t, creds.AccessKeyID)
assert.Empty(t, creds.SecretAccessKey)
assert.Empty(t, creds.SessionToken)
if v := creds.AccessKeyID; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SecretAccessKey; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("expect empty, got %v", v)
}
}
func BenchmarkAssumeRoleProvider(b *testing.B) {

119
vendor/github.com/aws/aws-sdk-go/aws/crr/cache.go generated vendored Normal file
View File

@ -0,0 +1,119 @@
package crr
import (
"sync/atomic"
)
// EndpointCache is an LRU cache that holds a series of endpoints
// based on some key. The datastructure makes use of a read write
// mutex to enable asynchronous use.
type EndpointCache struct {
endpoints syncMap
endpointLimit int64
// size is used to count the number elements in the cache.
// The atomic package is used to ensure this size is accurate when
// using multiple goroutines.
size int64
}
// NewEndpointCache will return a newly initialized cache with a limit
// of endpointLimit entries.
func NewEndpointCache(endpointLimit int64) *EndpointCache {
return &EndpointCache{
endpointLimit: endpointLimit,
endpoints: newSyncMap(),
}
}
// get is a concurrent safe get operation that will retrieve an endpoint
// based on endpointKey. A boolean will also be returned to illustrate whether
// or not the endpoint had been found.
func (c *EndpointCache) get(endpointKey string) (Endpoint, bool) {
endpoint, ok := c.endpoints.Load(endpointKey)
if !ok {
return Endpoint{}, false
}
c.endpoints.Store(endpointKey, endpoint)
return endpoint.(Endpoint), true
}
// Has returns if the enpoint cache contains a valid entry for the endpoint key
// provided.
func (c *EndpointCache) Has(endpointKey string) bool {
endpoint, ok := c.get(endpointKey)
_, found := endpoint.GetValidAddress()
return ok && found
}
// Get will retrieve a weighted address based off of the endpoint key. If an endpoint
// should be retrieved, due to not existing or the current endpoint has expired
// the Discoverer object that was passed in will attempt to discover a new endpoint
// and add that to the cache.
func (c *EndpointCache) Get(d Discoverer, endpointKey string, required bool) (WeightedAddress, error) {
var err error
endpoint, ok := c.get(endpointKey)
weighted, found := endpoint.GetValidAddress()
shouldGet := !ok || !found
if required && shouldGet {
if endpoint, err = c.discover(d, endpointKey); err != nil {
return WeightedAddress{}, err
}
weighted, _ = endpoint.GetValidAddress()
} else if shouldGet {
go c.discover(d, endpointKey)
}
return weighted, nil
}
// Add is a concurrent safe operation that will allow new endpoints to be added
// to the cache. If the cache is full, the number of endpoints equal endpointLimit,
// then this will remove the oldest entry before adding the new endpoint.
func (c *EndpointCache) Add(endpoint Endpoint) {
// de-dups multiple adds of an endpoint with a pre-existing key
if iface, ok := c.endpoints.Load(endpoint.Key); ok {
e := iface.(Endpoint)
if e.Len() > 0 {
return
}
}
c.endpoints.Store(endpoint.Key, endpoint)
size := atomic.AddInt64(&c.size, 1)
if size > 0 && size > c.endpointLimit {
c.deleteRandomKey()
}
}
// deleteRandomKey will delete a random key from the cache. If
// no key was deleted false will be returned.
func (c *EndpointCache) deleteRandomKey() bool {
atomic.AddInt64(&c.size, -1)
found := false
c.endpoints.Range(func(key, value interface{}) bool {
found = true
c.endpoints.Delete(key)
return false
})
return found
}
// discover will get and store and endpoint using the Discoverer.
func (c *EndpointCache) discover(d Discoverer, endpointKey string) (Endpoint, error) {
endpoint, err := d.Discover()
if err != nil {
return Endpoint{}, err
}
endpoint.Key = endpointKey
c.Add(endpoint)
return endpoint, nil
}

452
vendor/github.com/aws/aws-sdk-go/aws/crr/cache_test.go generated vendored Normal file
View File

@ -0,0 +1,452 @@
package crr
import (
"net/url"
"reflect"
"testing"
)
func urlParse(uri string) *url.URL {
u, _ := url.Parse(uri)
return u
}
func TestCacheAdd(t *testing.T) {
cases := []struct {
limit int64
endpoints []Endpoint
validKeys map[string]Endpoint
expectedSize int
}{
{
limit: 5,
endpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 5,
},
{
limit: 2,
endpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
expectedSize: 2,
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.endpoints {
cache.Add(endpoint)
}
count := 0
endpoints := map[string]Endpoint{}
cache.endpoints.Range(func(key, value interface{}) bool {
count++
endpoints[key.(string)] = value.(Endpoint)
return true
})
if e, a := c.expectedSize, cache.size; int64(e) != a {
t.Errorf("expected %v, but received %v", e, a)
}
if e, a := c.expectedSize, count; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
for k, ep := range endpoints {
endpoint, ok := c.validKeys[k]
if !ok {
t.Errorf("unrecognized key %q in cache", k)
}
if e, a := endpoint, ep; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}
func TestCacheGet(t *testing.T) {
cases := []struct {
addEndpoints []Endpoint
validKeys map[string]Endpoint
limit int64
}{
{
limit: 5,
addEndpoints: []Endpoint{
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
{
limit: 2,
addEndpoints: []Endpoint{
{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
validKeys: map[string]Endpoint{
"foo": Endpoint{
Key: "foo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://0"),
},
},
},
"bar": Endpoint{
Key: "bar",
Addresses: []WeightedAddress{
{
URL: urlParse("http://1"),
},
},
},
"baz": Endpoint{
Key: "baz",
Addresses: []WeightedAddress{
{
URL: urlParse("http://2"),
},
},
},
"qux": Endpoint{
Key: "qux",
Addresses: []WeightedAddress{
{
URL: urlParse("http://3"),
},
},
},
"moo": Endpoint{
Key: "moo",
Addresses: []WeightedAddress{
{
URL: urlParse("http://4"),
},
},
},
},
},
}
for _, c := range cases {
cache := NewEndpointCache(c.limit)
for _, endpoint := range c.addEndpoints {
cache.Add(endpoint)
}
keys := []string{}
cache.endpoints.Range(func(key, value interface{}) bool {
a := value.(Endpoint)
e, ok := c.validKeys[key.(string)]
if !ok {
t.Errorf("unrecognized key %q in cache", key.(string))
}
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
keys = append(keys, key.(string))
return true
})
for _, key := range keys {
a, ok := cache.get(key)
if !ok {
t.Errorf("expected key to be present: %q", key)
}
e := c.validKeys[key]
if !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
}
}

99
vendor/github.com/aws/aws-sdk-go/aws/crr/endpoint.go generated vendored Normal file
View File

@ -0,0 +1,99 @@
package crr
import (
"net/url"
"sort"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
)
// Endpoint represents an endpoint used in endpoint discovery.
type Endpoint struct {
Key string
Addresses WeightedAddresses
}
// WeightedAddresses represents a list of WeightedAddress.
type WeightedAddresses []WeightedAddress
// WeightedAddress represents an address with a given weight.
type WeightedAddress struct {
URL *url.URL
Expired time.Time
}
// HasExpired will return whether or not the endpoint has expired with
// the exception of a zero expiry meaning does not expire.
func (e WeightedAddress) HasExpired() bool {
return e.Expired.Before(time.Now())
}
// Add will add a given WeightedAddress to the address list of Endpoint.
func (e *Endpoint) Add(addr WeightedAddress) {
e.Addresses = append(e.Addresses, addr)
}
// Len returns the number of valid endpoints where valid means the endpoint
// has not expired.
func (e *Endpoint) Len() int {
validEndpoints := 0
for _, endpoint := range e.Addresses {
if endpoint.HasExpired() {
continue
}
validEndpoints++
}
return validEndpoints
}
// GetValidAddress will return a non-expired weight endpoint
func (e *Endpoint) GetValidAddress() (WeightedAddress, bool) {
for i := 0; i < len(e.Addresses); i++ {
we := e.Addresses[i]
if we.HasExpired() {
e.Addresses = append(e.Addresses[:i], e.Addresses[i+1:]...)
i--
continue
}
return we, true
}
return WeightedAddress{}, false
}
// Discoverer is an interface used to discovery which endpoint hit. This
// allows for specifics about what parameters need to be used to be contained
// in the Discoverer implementor.
type Discoverer interface {
Discover() (Endpoint, error)
}
// BuildEndpointKey will sort the keys in alphabetical order and then retrieve
// the values in that order. Those values are then concatenated together to form
// the endpoint key.
func BuildEndpointKey(params map[string]*string) string {
keys := make([]string, len(params))
i := 0
for k := range params {
keys[i] = k
i++
}
sort.Strings(keys)
values := make([]string, len(params))
for i, k := range keys {
if params[k] == nil {
continue
}
values[i] = aws.StringValue(params[k])
}
return strings.Join(values, ".")
}

29
vendor/github.com/aws/aws-sdk-go/aws/crr/sync_map.go generated vendored Normal file
View File

@ -0,0 +1,29 @@
// +build go1.9
package crr
import (
"sync"
)
type syncMap sync.Map
func newSyncMap() syncMap {
return syncMap{}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
return (*sync.Map)(m).Load(key)
}
func (m *syncMap) Store(key interface{}, value interface{}) {
(*sync.Map)(m).Store(key, value)
}
func (m *syncMap) Delete(key interface{}) {
(*sync.Map)(m).Delete(key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
(*sync.Map)(m).Range(f)
}

View File

@ -0,0 +1,48 @@
// +build !go1.9
package crr
import (
"sync"
)
type syncMap struct {
container map[interface{}]interface{}
lock sync.RWMutex
}
func newSyncMap() syncMap {
return syncMap{
container: map[interface{}]interface{}{},
}
}
func (m *syncMap) Load(key interface{}) (interface{}, bool) {
m.lock.RLock()
defer m.lock.RUnlock()
v, ok := m.container[key]
return v, ok
}
func (m *syncMap) Store(key interface{}, value interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
m.container[key] = value
}
func (m *syncMap) Delete(key interface{}) {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.container, key)
}
func (m *syncMap) Range(f func(interface{}, interface{}) bool) {
for k, v := range m.container {
if !f(k, v) {
return
}
}
}

View File

@ -0,0 +1,110 @@
package crr
import (
"reflect"
"testing"
)
func TestRangeDelete(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Delete(key)
return true
})
expectedMap := map[interface{}]interface{}{}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeStore(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
v := value.(int)
m.Store(key, v+1)
return true
})
expectedMap := map[interface{}]interface{}{
0: 1,
1: 11,
2: 21,
3: 31,
4: 41,
5: 51,
6: 61,
7: 71,
8: 81,
9: 91,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}
func TestRangeGet(t *testing.T) {
m := newSyncMap()
for i := 0; i < 10; i++ {
m.Store(i, i*10)
}
m.Range(func(key, value interface{}) bool {
m.Load(key)
return true
})
expectedMap := map[interface{}]interface{}{
0: 0,
1: 10,
2: 20,
3: 30,
4: 40,
5: 50,
6: 60,
7: 70,
8: 80,
9: 90,
}
actualMap := map[interface{}]interface{}{}
m.Range(func(key, value interface{}) bool {
actualMap[key] = value
return true
})
if e, a := len(expectedMap), len(actualMap); e != a {
t.Errorf("expected map size %d, but received %d", e, a)
}
if e, a := expectedMap, actualMap; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}
}

46
vendor/github.com/aws/aws-sdk-go/aws/csm/doc.go generated vendored Normal file
View File

@ -0,0 +1,46 @@
// Package csm provides Client Side Monitoring (CSM) which enables sending metrics
// via UDP connection. Using the Start function will enable the reporting of
// metrics on a given port. If Start is called, with different parameters, again,
// a panic will occur.
//
// Pause can be called to pause any metrics publishing on a given port. Sessions
// that have had their handlers modified via InjectHandlers may still be used.
// However, the handlers will act as a no-op meaning no metrics will be published.
//
// Example:
// r, err := csm.Start("clientID", ":31000")
// if err != nil {
// panic(fmt.Errorf("failed starting CSM: %v", err))
// }
//
// sess, err := session.NewSession(&aws.Config{})
// if err != nil {
// panic(fmt.Errorf("failed loading session: %v", err))
// }
//
// r.InjectHandlers(&sess.Handlers)
//
// client := s3.New(sess)
// resp, err := client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
//
// // Will pause monitoring
// r.Pause()
// resp, err = client.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
//
// // Resume monitoring
// r.Continue()
//
// Start returns a Reporter that is used to enable or disable monitoring. If
// access to the Reporter is required later, calling Get will return the Reporter
// singleton.
//
// Example:
// r := csm.Get()
// r.Continue()
package csm

67
vendor/github.com/aws/aws-sdk-go/aws/csm/enable.go generated vendored Normal file
View File

@ -0,0 +1,67 @@
package csm
import (
"fmt"
"sync"
)
var (
lock sync.Mutex
)
// Client side metric handler names
const (
APICallMetricHandlerName = "awscsm.SendAPICallMetric"
APICallAttemptMetricHandlerName = "awscsm.SendAPICallAttemptMetric"
)
// Start will start the a long running go routine to capture
// client side metrics. Calling start multiple time will only
// start the metric listener once and will panic if a different
// client ID or port is passed in.
//
// Example:
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
// sess := session.NewSession()
// r.InjectHandlers(sess.Handlers)
//
// svc := s3.New(sess)
// out, err := svc.GetObject(&s3.GetObjectInput{
// Bucket: aws.String("bucket"),
// Key: aws.String("key"),
// })
func Start(clientID string, url string) (*Reporter, error) {
lock.Lock()
defer lock.Unlock()
if sender == nil {
sender = newReporter(clientID, url)
} else {
if sender.clientID != clientID {
panic(fmt.Errorf("inconsistent client IDs. %q was expected, but received %q", sender.clientID, clientID))
}
if sender.url != url {
panic(fmt.Errorf("inconsistent URLs. %q was expected, but received %q", sender.url, url))
}
}
if err := connect(url); err != nil {
sender = nil
return nil, err
}
return sender, nil
}
// Get will return a reporter if one exists, if one does not exist, nil will
// be returned.
func Get() *Reporter {
lock.Lock()
defer lock.Unlock()
return sender
}

View File

@ -0,0 +1,74 @@
package csm
import (
"encoding/json"
"fmt"
"net"
"testing"
)
func startUDPServer(done chan struct{}, fn func([]byte)) (string, error) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
return "", err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return "", err
}
buf := make([]byte, 1024)
go func() {
defer conn.Close()
for {
select {
case <-done:
return
default:
}
n, _, err := conn.ReadFromUDP(buf)
fn(buf[:n])
if err != nil {
panic(err)
}
}
}()
return conn.LocalAddr().String(), nil
}
func TestDifferentParams(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("expected panic with different parameters")
}
}()
Start("clientID2", ":0")
}
var MetricsCh = make(chan map[string]interface{}, 1)
var Done = make(chan struct{})
func init() {
url, err := startUDPServer(Done, func(b []byte) {
m := map[string]interface{}{}
if err := json.Unmarshal(b, &m); err != nil {
panic(fmt.Sprintf("expected no error, but received %v", err))
}
MetricsCh <- m
})
if err != nil {
panic(err)
}
_, err = Start("clientID", url)
if err != nil {
panic(err)
}
}

View File

@ -0,0 +1,40 @@
package csm_test
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
func ExampleStart() {
r, err := csm.Start("clientID", ":31000")
if err != nil {
panic(fmt.Errorf("failed starting CSM: %v", err))
}
sess, err := session.NewSession(&aws.Config{})
if err != nil {
panic(fmt.Errorf("failed loading session: %v", err))
}
r.InjectHandlers(&sess.Handlers)
client := s3.New(sess)
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Pauses monitoring
r.Pause()
client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("bucket"),
Key: aws.String("key"),
})
// Resume monitoring
r.Continue()
}

109
vendor/github.com/aws/aws-sdk-go/aws/csm/metric.go generated vendored Normal file
View File

@ -0,0 +1,109 @@
package csm
import (
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
)
type metricTime time.Time
func (t metricTime) MarshalJSON() ([]byte, error) {
ns := time.Duration(time.Time(t).UnixNano())
return []byte(strconv.FormatInt(int64(ns/time.Millisecond), 10)), nil
}
type metric struct {
ClientID *string `json:"ClientId,omitempty"`
API *string `json:"Api,omitempty"`
Service *string `json:"Service,omitempty"`
Timestamp *metricTime `json:"Timestamp,omitempty"`
Type *string `json:"Type,omitempty"`
Version *int `json:"Version,omitempty"`
AttemptCount *int `json:"AttemptCount,omitempty"`
Latency *int `json:"Latency,omitempty"`
Fqdn *string `json:"Fqdn,omitempty"`
UserAgent *string `json:"UserAgent,omitempty"`
AttemptLatency *int `json:"AttemptLatency,omitempty"`
SessionToken *string `json:"SessionToken,omitempty"`
Region *string `json:"Region,omitempty"`
AccessKey *string `json:"AccessKey,omitempty"`
HTTPStatusCode *int `json:"HttpStatusCode,omitempty"`
XAmzID2 *string `json:"XAmzId2,omitempty"`
XAmzRequestID *string `json:"XAmznRequestId,omitempty"`
AWSException *string `json:"AwsException,omitempty"`
AWSExceptionMessage *string `json:"AwsExceptionMessage,omitempty"`
SDKException *string `json:"SdkException,omitempty"`
SDKExceptionMessage *string `json:"SdkExceptionMessage,omitempty"`
FinalHTTPStatusCode *int `json:"FinalHttpStatusCode,omitempty"`
FinalAWSException *string `json:"FinalAwsException,omitempty"`
FinalAWSExceptionMessage *string `json:"FinalAwsExceptionMessage,omitempty"`
FinalSDKException *string `json:"FinalSdkException,omitempty"`
FinalSDKExceptionMessage *string `json:"FinalSdkExceptionMessage,omitempty"`
DestinationIP *string `json:"DestinationIp,omitempty"`
ConnectionReused *int `json:"ConnectionReused,omitempty"`
AcquireConnectionLatency *int `json:"AcquireConnectionLatency,omitempty"`
ConnectLatency *int `json:"ConnectLatency,omitempty"`
RequestLatency *int `json:"RequestLatency,omitempty"`
DNSLatency *int `json:"DnsLatency,omitempty"`
TCPLatency *int `json:"TcpLatency,omitempty"`
SSLLatency *int `json:"SslLatency,omitempty"`
MaxRetriesExceeded *int `json:"MaxRetriesExceeded,omitempty"`
}
func (m *metric) TruncateFields() {
m.ClientID = truncateString(m.ClientID, 255)
m.UserAgent = truncateString(m.UserAgent, 256)
m.AWSException = truncateString(m.AWSException, 128)
m.AWSExceptionMessage = truncateString(m.AWSExceptionMessage, 512)
m.SDKException = truncateString(m.SDKException, 128)
m.SDKExceptionMessage = truncateString(m.SDKExceptionMessage, 512)
m.FinalAWSException = truncateString(m.FinalAWSException, 128)
m.FinalAWSExceptionMessage = truncateString(m.FinalAWSExceptionMessage, 512)
m.FinalSDKException = truncateString(m.FinalSDKException, 128)
m.FinalSDKExceptionMessage = truncateString(m.FinalSDKExceptionMessage, 512)
}
func truncateString(v *string, l int) *string {
if v != nil && len(*v) > l {
nv := (*v)[:l]
return &nv
}
return v
}
func (m *metric) SetException(e metricException) {
switch te := e.(type) {
case awsException:
m.AWSException = aws.String(te.exception)
m.AWSExceptionMessage = aws.String(te.message)
case sdkException:
m.SDKException = aws.String(te.exception)
m.SDKExceptionMessage = aws.String(te.message)
}
}
func (m *metric) SetFinalException(e metricException) {
switch te := e.(type) {
case awsException:
m.FinalAWSException = aws.String(te.exception)
m.FinalAWSExceptionMessage = aws.String(te.message)
case sdkException:
m.FinalSDKException = aws.String(te.exception)
m.FinalSDKExceptionMessage = aws.String(te.message)
}
}

View File

@ -0,0 +1,54 @@
package csm
import (
"sync/atomic"
)
const (
runningEnum = iota
pausedEnum
)
var (
// MetricsChannelSize of metrics to hold in the channel
MetricsChannelSize = 100
)
type metricChan struct {
ch chan metric
paused int64
}
func newMetricChan(size int) metricChan {
return metricChan{
ch: make(chan metric, size),
}
}
func (ch *metricChan) Pause() {
atomic.StoreInt64(&ch.paused, pausedEnum)
}
func (ch *metricChan) Continue() {
atomic.StoreInt64(&ch.paused, runningEnum)
}
func (ch *metricChan) IsPaused() bool {
v := atomic.LoadInt64(&ch.paused)
return v == pausedEnum
}
// Push will push metrics to the metric channel if the channel
// is not paused
func (ch *metricChan) Push(m metric) bool {
if ch.IsPaused() {
return false
}
select {
case ch.ch <- m:
return true
default:
return false
}
}

View File

@ -0,0 +1,72 @@
package csm
import (
"testing"
)
func TestMetricChanPush(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPauseContinue(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
if !ch.IsPaused() {
t.Errorf("expected to be paused, but did not pause properly")
}
ch.Continue()
if ch.IsPaused() {
t.Errorf("expected to be not paused, but did not continue properly")
}
pushed := ch.Push(metric{})
if !pushed {
t.Errorf("expected metrics to be pushed")
}
if e, a := 1, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanPushWhenPaused(t *testing.T) {
ch := newMetricChan(5)
defer close(ch.ch)
ch.Pause()
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to not be pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}
func TestMetricChanNonBlocking(t *testing.T) {
ch := newMetricChan(0)
defer close(ch.ch)
pushed := ch.Push(metric{})
if pushed {
t.Errorf("expected metrics to be not pushed")
}
if e, a := 0, len(ch.ch); e != a {
t.Errorf("expected %d, but received %d", e, a)
}
}

View File

@ -0,0 +1,26 @@
package csm
type metricException interface {
Exception() string
Message() string
}
type requestException struct {
exception string
message string
}
func (e requestException) Exception() string {
return e.exception
}
func (e requestException) Message() string {
return e.message
}
type awsException struct {
requestException
}
type sdkException struct {
requestException
}

106
vendor/github.com/aws/aws-sdk-go/aws/csm/metric_test.go generated vendored Normal file
View File

@ -0,0 +1,106 @@
// +build go1.7
package csm
import (
"reflect"
"testing"
"github.com/aws/aws-sdk-go/aws"
)
func TestTruncateString(t *testing.T) {
cases := map[string]struct {
Val string
Len int
Expect string
}{
"no change": {
Val: "123456789", Len: 10,
Expect: "123456789",
},
"max len": {
Val: "1234567890", Len: 10,
Expect: "1234567890",
},
"too long": {
Val: "12345678901", Len: 10,
Expect: "1234567890",
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
v := c.Val
actual := truncateString(&v, c.Len)
if e, a := c.Val, v; e != a {
t.Errorf("expect input value not to change, %v, %v", e, a)
}
if e, a := c.Expect, *actual; e != a {
t.Errorf("expect %v, got %v", e, a)
}
})
}
}
func TestMetric_SetException(t *testing.T) {
cases := map[string]struct {
Exc metricException
Expect metric
Final bool
}{
"aws exc": {
Exc: awsException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
AWSException: aws.String("abc"),
AWSExceptionMessage: aws.String("123"),
},
},
"sdk exc": {
Exc: sdkException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
SDKException: aws.String("abc"),
SDKExceptionMessage: aws.String("123"),
},
},
"final aws exc": {
Exc: awsException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
FinalAWSException: aws.String("abc"),
FinalAWSExceptionMessage: aws.String("123"),
},
Final: true,
},
"final sdk exc": {
Exc: sdkException{
requestException{exception: "abc", message: "123"},
},
Expect: metric{
FinalSDKException: aws.String("abc"),
FinalSDKExceptionMessage: aws.String("123"),
},
Final: true,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
var m metric
if c.Final {
m.SetFinalException(c.Exc)
} else {
m.SetException(c.Exc)
}
if e, a := c.Expect, m; !reflect.DeepEqual(e, a) {
t.Errorf("expect:\n%#v\nactual:\n%#v\n", e, a)
}
})
}
}

260
vendor/github.com/aws/aws-sdk-go/aws/csm/reporter.go generated vendored Normal file
View File

@ -0,0 +1,260 @@
package csm
import (
"encoding/json"
"net"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
)
const (
// DefaultPort is used when no port is specified
DefaultPort = "31000"
)
// Reporter will gather metrics of API requests made and
// send those metrics to the CSM endpoint.
type Reporter struct {
clientID string
url string
conn net.Conn
metricsCh metricChan
done chan struct{}
}
var (
sender *Reporter
)
func connect(url string) error {
const network = "udp"
if err := sender.connect(network, url); err != nil {
return err
}
if sender.done == nil {
sender.done = make(chan struct{})
go sender.start()
}
return nil
}
func newReporter(clientID, url string) *Reporter {
return &Reporter{
clientID: clientID,
url: url,
metricsCh: newMetricChan(MetricsChannelSize),
}
}
func (rep *Reporter) sendAPICallAttemptMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
creds, _ := r.Config.Credentials.Get()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Region: r.Config.Region,
Type: aws.String("ApiCallAttempt"),
Version: aws.Int(1),
XAmzRequestID: aws.String(r.RequestID),
AttemptCount: aws.Int(r.RetryCount + 1),
AttemptLatency: aws.Int(int(now.Sub(r.AttemptTime).Nanoseconds() / int64(time.Millisecond))),
AccessKey: aws.String(creds.AccessKeyID),
}
if r.HTTPResponse != nil {
m.HTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetException(getMetricException(awserr))
}
}
m.TruncateFields()
rep.metricsCh.Push(m)
}
func getMetricException(err awserr.Error) metricException {
msg := err.Error()
code := err.Code()
switch code {
case "RequestError",
"SerializationError",
request.CanceledErrorCode:
return sdkException{
requestException{exception: code, message: msg},
}
default:
return awsException{
requestException{exception: code, message: msg},
}
}
}
func (rep *Reporter) sendAPICallMetric(r *request.Request) {
if rep == nil {
return
}
now := time.Now()
m := metric{
ClientID: aws.String(rep.clientID),
API: aws.String(r.Operation.Name),
Service: aws.String(r.ClientInfo.ServiceID),
Timestamp: (*metricTime)(&now),
UserAgent: aws.String(r.HTTPRequest.Header.Get("User-Agent")),
Type: aws.String("ApiCall"),
AttemptCount: aws.Int(r.RetryCount + 1),
Region: r.Config.Region,
Latency: aws.Int(int(time.Now().Sub(r.Time) / time.Millisecond)),
XAmzRequestID: aws.String(r.RequestID),
MaxRetriesExceeded: aws.Int(boolIntValue(r.RetryCount >= r.MaxRetries())),
}
if r.HTTPResponse != nil {
m.FinalHTTPStatusCode = aws.Int(r.HTTPResponse.StatusCode)
}
if r.Error != nil {
if awserr, ok := r.Error.(awserr.Error); ok {
m.SetFinalException(getMetricException(awserr))
}
}
m.TruncateFields()
// TODO: Probably want to figure something out for logging dropped
// metrics
rep.metricsCh.Push(m)
}
func (rep *Reporter) connect(network, url string) error {
if rep.conn != nil {
rep.conn.Close()
}
conn, err := net.Dial(network, url)
if err != nil {
return awserr.New("UDPError", "Could not connect", err)
}
rep.conn = conn
return nil
}
func (rep *Reporter) close() {
if rep.done != nil {
close(rep.done)
}
rep.metricsCh.Pause()
}
func (rep *Reporter) start() {
defer func() {
rep.metricsCh.Pause()
}()
for {
select {
case <-rep.done:
rep.done = nil
return
case m := <-rep.metricsCh.ch:
// TODO: What to do with this error? Probably should just log
b, err := json.Marshal(m)
if err != nil {
continue
}
rep.conn.Write(b)
}
}
}
// Pause will pause the metric channel preventing any new metrics from
// being added.
func (rep *Reporter) Pause() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
rep.close()
}
// Continue will reopen the metric channel and allow for monitoring
// to be resumed.
func (rep *Reporter) Continue() {
lock.Lock()
defer lock.Unlock()
if rep == nil {
return
}
if !rep.metricsCh.IsPaused() {
return
}
rep.metricsCh.Continue()
}
// InjectHandlers will will enable client side metrics and inject the proper
// handlers to handle how metrics are sent.
//
// Example:
// // Start must be called in order to inject the correct handlers
// r, err := csm.Start("clientID", "127.0.0.1:8094")
// if err != nil {
// panic(fmt.Errorf("expected no error, but received %v", err))
// }
//
// sess := session.NewSession()
// r.InjectHandlers(&sess.Handlers)
//
// // create a new service client with our client side metric session
// svc := s3.New(sess)
func (rep *Reporter) InjectHandlers(handlers *request.Handlers) {
if rep == nil {
return
}
handlers.Complete.PushFrontNamed(request.NamedHandler{
Name: APICallMetricHandlerName,
Fn: rep.sendAPICallMetric,
})
handlers.CompleteAttempt.PushFrontNamed(request.NamedHandler{
Name: APICallAttemptMetricHandlerName,
Fn: rep.sendAPICallAttemptMetric,
})
}
// boolIntValue return 1 for true and 0 for false.
func boolIntValue(b bool) int {
if b {
return 1
}
return 0
}

View File

@ -0,0 +1,72 @@
package csm
import (
"net/http"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/request"
)
func TestMaxRetriesExceeded(t *testing.T) {
md := metadata.ClientInfo{
Endpoint: "http://127.0.0.1",
}
cfg := aws.Config{
Region: aws.String("foo"),
Credentials: credentials.NewStaticCredentials("", "", ""),
}
op := &request.Operation{}
cases := []struct {
name string
httpStatusCode int
expectedMaxRetriesValue int
expectedMetrics int
}{
{
name: "max retry reached",
httpStatusCode: http.StatusBadGateway,
expectedMaxRetriesValue: 1,
},
{
name: "status ok",
httpStatusCode: http.StatusOK,
expectedMaxRetriesValue: 0,
},
}
for _, c := range cases {
r := request.New(cfg, md, defaults.Handlers(), client.DefaultRetryer{NumMaxRetries: 2}, op, nil, nil)
reporter := newReporter("", "")
r.Handlers.Send.Clear()
reporter.InjectHandlers(&r.Handlers)
r.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: c.httpStatusCode,
}
})
r.Send()
for {
m := <-reporter.metricsCh.ch
if *m.Type != "ApiCall" {
// ignore non-ApiCall metrics since MaxRetriesExceeded is only on ApiCall events
continue
}
if val := *m.MaxRetriesExceeded; val != c.expectedMaxRetriesValue {
t.Errorf("%s: expected %d, but received %d", c.name, c.expectedMaxRetriesValue, val)
}
break
}
}
}

View File

@ -0,0 +1,414 @@
// +build go1.7
package csm_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"sort"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/csm"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/private/protocol/jsonrpc"
)
func TestReportingMetrics(t *testing.T) {
sess := unit.Session.Copy(&aws.Config{
SleepDelay: func(time.Duration) {},
})
sess.Handlers.Validate.Clear()
sess.Handlers.Sign.Clear()
sess.Handlers.Send.Clear()
reporter := csm.Get()
if reporter == nil {
t.Errorf("expected non-nil reporter")
}
reporter.InjectHandlers(&sess.Handlers)
cases := map[string]struct {
Request *request.Request
ExpectMetrics []map[string]interface{}
}{
"successful request": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = &http.Response{
StatusCode: 200,
Header: http.Header{},
}
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(200),
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(200),
},
},
},
"failed request, no retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = &http.Response{
StatusCode: 400,
Header: http.Header{},
}
req.Retryable = aws.Bool(false)
req.Error = awserr.New("Error", "Message", nil)
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(400),
"AwsException": "Error",
"AwsExceptionMessage": "Error: Message",
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(400),
"FinalAwsException": "Error",
"FinalAwsExceptionMessage": "Error: Message",
"AttemptCount": float64(1),
},
},
},
"failed request, with retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 1}, op, nil, nil)
resps := []*http.Response{
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 500,
Header: http.Header{},
},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = resps[0]
resps = resps[1:]
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(500),
"AwsException": "UnknownError",
"AwsExceptionMessage": "UnknownError: unknown error",
},
{
"Type": "ApiCallAttempt",
"HttpStatusCode": float64(500),
"AwsException": "UnknownError",
"AwsExceptionMessage": "UnknownError: unknown error",
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(500),
"FinalAwsException": "UnknownError",
"FinalAwsExceptionMessage": "UnknownError: unknown error",
"AttemptCount": float64(2),
},
},
},
"success request, with retry": {
Request: func() *request.Request {
md := metadata.ClientInfo{}
op := &request.Operation{Name: "OperationName"}
req := request.New(*sess.Config, md, sess.Handlers, client.DefaultRetryer{NumMaxRetries: 3}, op, nil, nil)
errs := []error{
awserr.New("AWSError", "aws error", nil),
awserr.New("RequestError", "sdk error", nil),
nil,
}
resps := []*http.Response{
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 500,
Header: http.Header{},
},
{
StatusCode: 200,
Header: http.Header{},
},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
req.HTTPResponse = resps[0]
resps = resps[1:]
req.Error = errs[0]
errs = errs[1:]
})
return req
}(),
ExpectMetrics: []map[string]interface{}{
{
"Type": "ApiCallAttempt",
"AwsException": "AWSError",
"AwsExceptionMessage": "AWSError: aws error",
"HttpStatusCode": float64(500),
},
{
"Type": "ApiCallAttempt",
"SdkException": "RequestError",
"SdkExceptionMessage": "RequestError: sdk error",
"HttpStatusCode": float64(500),
},
{
"Type": "ApiCallAttempt",
"AwsException": nil,
"AwsExceptionMessage": nil,
"SdkException": nil,
"SdkExceptionMessage": nil,
"HttpStatusCode": float64(200),
},
{
"Type": "ApiCall",
"FinalHttpStatusCode": float64(200),
"FinalAwsException": nil,
"FinalAwsExceptionMessage": nil,
"FinalSdkException": nil,
"FinalSdkExceptionMessage": nil,
"AttemptCount": float64(3),
},
},
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
ctx, cancelFn := context.WithTimeout(context.Background(), time.Second)
defer cancelFn()
c.Request.Send()
for i := 0; i < len(c.ExpectMetrics); i++ {
select {
case m := <-csm.MetricsCh:
for ek, ev := range c.ExpectMetrics[i] {
if ev == nil {
// must not be set
if _, ok := m[ek]; ok {
t.Errorf("%d, expect %v metric member, not to be set, %v", i, ek, m[ek])
}
continue
}
if _, ok := m[ek]; !ok {
t.Errorf("%d, expect %v metric member, keys: %v", i, ek, keys(m))
}
if e, a := ev, m[ek]; e != a {
t.Errorf("%d, expect %v:%v(%T), metric value, got %v(%T)", i, ek, e, e, a, a)
}
}
case <-ctx.Done():
t.Errorf("timeout waiting for metrics")
return
}
}
var extraMetrics []map[string]interface{}
Loop:
for {
select {
case m := <-csm.MetricsCh:
extraMetrics = append(extraMetrics, m)
default:
break Loop
}
}
if len(extraMetrics) != 0 {
t.Fatalf("unexpected metrics, %#v", extraMetrics)
}
})
}
}
type mockService struct {
*client.Client
}
type input struct{}
type output struct{}
func (s *mockService) Request(i input) *request.Request {
op := &request.Operation{
Name: "foo",
HTTPMethod: "POST",
HTTPPath: "/",
}
o := output{}
req := s.NewRequest(op, &i, &o)
return req
}
func BenchmarkWithCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
r := csm.Get()
r.InjectHandlers(&sess.Handlers)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithCSMNoUDPConnection(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
r := csm.Get()
r.Pause()
r.InjectHandlers(&sess.Handlers)
defer r.Pause()
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func BenchmarkWithoutCSM(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(fmt.Sprintf("{}")))
}))
cfg := aws.Config{
Endpoint: aws.String(server.URL),
}
sess := unit.Session.Copy(&cfg)
c := sess.ClientConfig("id", &cfg)
svc := mockService{
client.New(
*c.Config,
metadata.ClientInfo{
ServiceName: "service",
ServiceID: "id",
SigningName: "signing",
SigningRegion: "region",
Endpoint: server.URL,
APIVersion: "0",
JSONVersion: "1.1",
TargetPrefix: "prefix",
},
c.Handlers,
),
}
svc.Handlers.Sign.PushBackNamed(v4.SignRequestHandler)
svc.Handlers.Build.PushBackNamed(jsonrpc.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed(jsonrpc.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed(jsonrpc.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed(jsonrpc.UnmarshalErrorHandler)
for i := 0; i < b.N; i++ {
req := svc.Request(input{})
req.Send()
}
}
func keys(m map[string]interface{}) []string {
ks := make([]string, 0, len(m))
for k := range m {
ks = append(ks, k)
}
sort.Strings(ks)
return ks
}

View File

@ -9,6 +9,7 @@ package defaults
import (
"fmt"
"net"
"net/http"
"net/url"
"os"
@ -23,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
// A Defaults provides a collection of default values for SDK clients.
@ -72,6 +74,7 @@ func Handlers() request.Handlers {
handlers.Validate.PushBackNamed(corehandlers.ValidateEndpointHandler)
handlers.Validate.AfterEachFn = request.HandlerListStopOnError
handlers.Build.PushBackNamed(corehandlers.SDKVersionUserAgentHandler)
handlers.Build.PushBackNamed(corehandlers.AddHostExecEnvUserAgentHander)
handlers.Build.AfterEachFn = request.HandlerListStopOnError
handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
handlers.Send.PushBackNamed(corehandlers.ValidateReqSigHandler)
@ -90,17 +93,28 @@ func Handlers() request.Handlers {
func CredChain(cfg *aws.Config, handlers request.Handlers) *credentials.Credentials {
return credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: aws.BoolValue(cfg.CredentialsChainVerboseErrors),
Providers: []credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
RemoteCredProvider(*cfg, handlers),
},
Providers: CredProviders(cfg, handlers),
})
}
// CredProviders returns the slice of providers used in
// the default credential chain.
//
// For applications that need to use some other provider (for example use
// different environment variables for legacy reasons) but still fall back
// on the default chain of providers. This allows that default chaint to be
// automatically updated
func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Provider {
return []credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
RemoteCredProvider(*cfg, handlers),
}
}
const (
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
ecsCredsProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
)
// RemoteCredProvider returns a credentials provider for the default remote
@ -110,22 +124,51 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
return localHTTPCredProvider(cfg, handlers, u)
}
if uri := os.Getenv(ecsCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("http://169.254.170.2%s", uri)
if uri := os.Getenv(shareddefaults.ECSCredsProviderEnvVar); len(uri) > 0 {
u := fmt.Sprintf("%s%s", shareddefaults.ECSContainerCredentialsURI, uri)
return httpCredProvider(cfg, handlers, u)
}
return ec2RoleProvider(cfg, handlers)
}
var lookupHostFn = net.LookupHost
func isLoopbackHost(host string) (bool, error) {
ip := net.ParseIP(host)
if ip != nil {
return ip.IsLoopback(), nil
}
// Host is not an ip, perform lookup
addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}
for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() {
return false, nil
}
}
return true, nil
}
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string
parsed, err := url.Parse(u)
if err != nil {
errMsg = fmt.Sprintf("invalid URL, %v", err)
} else if host := aws.URLHostname(parsed); !(host == "localhost" || host == "127.0.0.1") {
errMsg = fmt.Sprintf("invalid host address, %q, only localhost and 127.0.0.1 are valid.", host)
} else {
host := aws.URLHostname(parsed)
if len(host) == 0 {
errMsg = "unable to parse host from local HTTP cred provider URL"
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host)
}
}
if len(errMsg) > 0 {
@ -145,6 +188,7 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
return endpointcreds.NewProviderClient(cfg, handlers, u,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
},
)
}

View File

@ -10,15 +10,46 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/shareddefaults"
)
func TestHTTPCredProvider(t *testing.T) {
origFn := lookupHostFn
defer func() { lookupHostFn = origFn }()
lookupHostFn = func(host string) ([]string, error) {
m := map[string]struct {
Addrs []string
Err error
}{
"localhost": {Addrs: []string{"::1", "127.0.0.1"}},
"actuallylocal": {Addrs: []string{"127.0.0.2"}},
"notlocal": {Addrs: []string{"::1", "127.0.0.1", "192.168.1.10"}},
"www.example.com": {Addrs: []string{"10.10.10.10"}},
}
h, ok := m[host]
if !ok {
t.Fatalf("unknown host in test, %v", host)
return nil, fmt.Errorf("unknown host")
}
return h.Addrs, h.Err
}
cases := []struct {
Host string
Fail bool
Host string
AuthToken string
Fail bool
}{
{"localhost", false}, {"127.0.0.1", false},
{"www.example.com", true}, {"169.254.170.2", true},
{Host: "localhost", Fail: false},
{Host: "actuallylocal", Fail: false},
{Host: "127.0.0.1", Fail: false},
{Host: "127.1.1.1", Fail: false},
{Host: "[::1]", Fail: false},
{Host: "www.example.com", Fail: true},
{Host: "169.254.170.2", Fail: true},
{Host: "localhost", Fail: false, AuthToken: "Basic abc123"},
}
defer os.Clearenv()
@ -26,6 +57,7 @@ func TestHTTPCredProvider(t *testing.T) {
for i, c := range cases {
u := fmt.Sprintf("http://%s/abc/123", c.Host)
os.Setenv(httpProviderEnvVar, u)
os.Setenv(httpProviderAuthorizationEnvVar, c.AuthToken)
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {
@ -50,13 +82,16 @@ func TestHTTPCredProvider(t *testing.T) {
if e, a := u, httpProvider.Client.Endpoint; e != a {
t.Errorf("%d, expect %q endpoint, got %q", i, e, a)
}
if e, a := c.AuthToken, httpProvider.AuthorizationToken; e != a {
t.Errorf("%d, expect %q auth token, got %q", i, e, a)
}
}
}
}
func TestECSCredProvider(t *testing.T) {
defer os.Clearenv()
os.Setenv(ecsCredsProviderEnvVar, "/abc/123")
os.Setenv(shareddefaults.ECSCredsProviderEnvVar, "/abc/123")
provider := RemoteCredProvider(aws.Config{}, request.Handlers{})
if provider == nil {

View File

@ -4,12 +4,12 @@ import (
"encoding/json"
"fmt"
"net/http"
"path"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/internal/sdkuri"
)
// GetMetadata uses the path provided to request information from the EC2
@ -19,7 +19,7 @@ func (c *EC2Metadata) GetMetadata(p string) (string, error) {
op := &request.Operation{
Name: "GetMetadata",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "meta-data", p),
HTTPPath: sdkuri.PathJoin("/meta-data", p),
}
output := &metadataOutput{}
@ -35,7 +35,7 @@ func (c *EC2Metadata) GetUserData() (string, error) {
op := &request.Operation{
Name: "GetUserData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "user-data"),
HTTPPath: "/user-data",
}
output := &metadataOutput{}
@ -56,7 +56,7 @@ func (c *EC2Metadata) GetDynamicData(p string) (string, error) {
op := &request.Operation{
Name: "GetDynamicData",
HTTPMethod: "GET",
HTTPPath: path.Join("/", "dynamic", p),
HTTPPath: sdkuri.PathJoin("/dynamic", p),
}
output := &metadataOutput{}
@ -118,6 +118,10 @@ func (c *EC2Metadata) Region() (string, error) {
return "", err
}
if len(resp) == 0 {
return "", awserr.New("EC2MetadataError", "invalid Region response", nil)
}
// returns region without the suffix. Eg: us-west-2a becomes us-west-2
return resp[:len(resp)-1], nil
}

View File

@ -167,6 +167,25 @@ func TestGetRegion(t *testing.T) {
}
}
func TestGetRegion_invalidResponse(t *testing.T) {
server := initTestServer(
"/latest/meta-data/placement/availability-zone",
"", // no data in response
)
defer server.Close()
c := ec2metadata.New(unit.Session, &aws.Config{Endpoint: aws.String(server.URL + "/latest")})
region, err := c.Region()
if err == nil {
t.Errorf("expected error, got %v", err)
}
if e, a := "", region; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestMetadataAvailable(t *testing.T) {
server := initTestServer(
"/latest/meta-data/instance-id",

View File

@ -1,5 +1,10 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
//
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environment variable is set to true, (case insensitive).
package ec2metadata
import (
@ -7,17 +12,21 @@ import (
"errors"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
)
// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"
// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
@ -63,6 +72,7 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
cfg,
metadata.ClientInfo{
ServiceName: ServiceName,
ServiceID: ServiceName,
Endpoint: endpoint,
APIVersion: "latest",
},
@ -75,6 +85,24 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)
// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
Header: http.Header{},
}
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
nil)
},
})
}
// Add additional options to the service config
for _, option := range opts {
option(svc.Client)

View File

@ -3,21 +3,30 @@ package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)
func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)
assert.Equal(t, 5*time.Second, svc.Config.HTTPClient.Timeout)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
t.Errorf("expect %v, not to equal %v", e, a)
}
if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}
func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
@ -28,18 +37,25 @@ func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)
assert.Nil(t, tr.Dial)
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
if tr == nil {
t.Fatalf("expect transport not to be nil")
}
if tr.Dial != nil {
t.Errorf("expect dial to be nil, was not")
}
}
func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))
assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
@ -63,6 +79,32 @@ func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
runEC2MetadataClients(t, cfg, 100)
}
func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)
os.Setenv("AWS_EC2_METADATA_DISABLED", "true")
svc := ec2metadata.New(unit.Session, &aws.Config{
LogLevel: aws.LogLevel(aws.LogDebugWithHTTPBody),
})
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}
aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}
func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
@ -70,7 +112,9 @@ func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
wg.Done()
}()
}

View File

@ -84,6 +84,8 @@ func decodeV3Endpoints(modelDef modelDefinition, opts DecodeModelOptions) (Resol
custAddEC2Metadata(p)
custAddS3DualStack(p)
custRmIotDataService(p)
custFixAppAutoscalingChina(p)
custFixAppAutoscalingUsGov(p)
}
return ps, nil
@ -94,7 +96,12 @@ func custAddS3DualStack(p *partition) {
return
}
s, ok := p.Services["s3"]
custAddDualstack(p, "s3")
custAddDualstack(p, "s3-control")
}
func custAddDualstack(p *partition, svcName string) {
s, ok := p.Services[svcName]
if !ok {
return
}
@ -102,7 +109,7 @@ func custAddS3DualStack(p *partition) {
s.Defaults.HasDualStack = boxedTrue
s.Defaults.DualStackHostname = "{service}.dualstack.{region}.{dnsSuffix}"
p.Services["s3"] = s
p.Services[svcName] = s
}
func custAddEC2Metadata(p *partition) {
@ -122,6 +129,54 @@ func custRmIotDataService(p *partition) {
delete(p.Services, "data.iot")
}
func custFixAppAutoscalingChina(p *partition) {
if p.ID != "aws-cn" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
const expectHostname = `autoscaling.{region}.amazonaws.com`
if e, a := s.Defaults.Hostname, expectHostname; e != a {
fmt.Printf("custFixAppAutoscalingChina: ignoring customization, expected %s, got %s\n", e, a)
return
}
s.Defaults.Hostname = expectHostname + ".cn"
p.Services[serviceName] = s
}
func custFixAppAutoscalingUsGov(p *partition) {
if p.ID != "aws-us-gov" {
return
}
const serviceName = "application-autoscaling"
s, ok := p.Services[serviceName]
if !ok {
return
}
if a := s.Defaults.CredentialScope.Service; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty credential scope service, got %s\n", a)
return
}
if a := s.Defaults.Hostname; a != "" {
fmt.Printf("custFixAppAutoscalingUsGov: ignoring customization, expected empty hostname, got %s\n", a)
return
}
s.Defaults.CredentialScope.Service = "application-autoscaling"
s.Defaults.Hostname = "autoscaling.{region}.amazonaws.com"
p.Services[serviceName] = s
}
type decodeModelError struct {
awsError
}

View File

@ -115,3 +115,110 @@ func TestDecodeModelOptionsSet(t *testing.T) {
t.Errorf("expect %v options got %v", expect, actual)
}
}
func TestCustFixAppAutoscalingChina(t *testing.T) {
const doc = `
{
"version": 3,
"partitions": [{
"defaults" : {
"hostname" : "{service}.{region}.{dnsSuffix}",
"protocols" : [ "https" ],
"signatureVersions" : [ "v4" ]
},
"dnsSuffix" : "amazonaws.com.cn",
"partition" : "aws-cn",
"partitionName" : "AWS China",
"regionRegex" : "^cn\\-\\w+\\-\\d+$",
"regions" : {
"cn-north-1" : {
"description" : "China (Beijing)"
},
"cn-northwest-1" : {
"description" : "China (Ningxia)"
}
},
"services" : {
"application-autoscaling" : {
"defaults" : {
"credentialScope" : {
"service" : "application-autoscaling"
},
"hostname" : "autoscaling.{region}.amazonaws.com",
"protocols" : [ "http", "https" ]
},
"endpoints" : {
"cn-north-1" : { },
"cn-northwest-1" : { }
}
}
}
}]
}`
resolver, err := DecodeModel(strings.NewReader(doc))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
endpoint, err := resolver.EndpointFor(
"application-autoscaling", "cn-northwest-1",
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := `https://autoscaling.cn-northwest-1.amazonaws.com.cn`, endpoint.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestCustFixAppAutoscalingUsGov(t *testing.T) {
const doc = `
{
"version": 3,
"partitions": [{
"defaults" : {
"hostname" : "{service}.{region}.{dnsSuffix}",
"protocols" : [ "https" ],
"signatureVersions" : [ "v4" ]
},
"dnsSuffix" : "amazonaws.com",
"partition" : "aws-us-gov",
"partitionName" : "AWS GovCloud (US)",
"regionRegex" : "^us\\-gov\\-\\w+\\-\\d+$",
"regions" : {
"us-gov-east-1" : {
"description" : "AWS GovCloud (US-East)"
},
"us-gov-west-1" : {
"description" : "AWS GovCloud (US)"
}
},
"services" : {
"application-autoscaling" : {
"endpoints" : {
"us-gov-east-1" : { },
"us-gov-west-1" : { }
}
}
}
}]
}`
resolver, err := DecodeModel(strings.NewReader(doc))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
endpoint, err := resolver.EndpointFor(
"application-autoscaling", "us-gov-west-1",
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := `https://autoscaling.us-gov-west-1.amazonaws.com`, endpoint.URL; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,141 @@
package endpoints
// Service identifiers
//
// Deprecated: Use client package's EndpointID value instead of these
// ServiceIDs. These IDs are not maintained, and are out of date.
const (
A4bServiceID = "a4b" // A4b.
AcmServiceID = "acm" // Acm.
AcmPcaServiceID = "acm-pca" // AcmPca.
ApiMediatailorServiceID = "api.mediatailor" // ApiMediatailor.
ApiPricingServiceID = "api.pricing" // ApiPricing.
ApiSagemakerServiceID = "api.sagemaker" // ApiSagemaker.
ApigatewayServiceID = "apigateway" // Apigateway.
ApplicationAutoscalingServiceID = "application-autoscaling" // ApplicationAutoscaling.
Appstream2ServiceID = "appstream2" // Appstream2.
AppsyncServiceID = "appsync" // Appsync.
AthenaServiceID = "athena" // Athena.
AutoscalingServiceID = "autoscaling" // Autoscaling.
AutoscalingPlansServiceID = "autoscaling-plans" // AutoscalingPlans.
BatchServiceID = "batch" // Batch.
BudgetsServiceID = "budgets" // Budgets.
CeServiceID = "ce" // Ce.
ChimeServiceID = "chime" // Chime.
Cloud9ServiceID = "cloud9" // Cloud9.
ClouddirectoryServiceID = "clouddirectory" // Clouddirectory.
CloudformationServiceID = "cloudformation" // Cloudformation.
CloudfrontServiceID = "cloudfront" // Cloudfront.
CloudhsmServiceID = "cloudhsm" // Cloudhsm.
Cloudhsmv2ServiceID = "cloudhsmv2" // Cloudhsmv2.
CloudsearchServiceID = "cloudsearch" // Cloudsearch.
CloudtrailServiceID = "cloudtrail" // Cloudtrail.
CodebuildServiceID = "codebuild" // Codebuild.
CodecommitServiceID = "codecommit" // Codecommit.
CodedeployServiceID = "codedeploy" // Codedeploy.
CodepipelineServiceID = "codepipeline" // Codepipeline.
CodestarServiceID = "codestar" // Codestar.
CognitoIdentityServiceID = "cognito-identity" // CognitoIdentity.
CognitoIdpServiceID = "cognito-idp" // CognitoIdp.
CognitoSyncServiceID = "cognito-sync" // CognitoSync.
ComprehendServiceID = "comprehend" // Comprehend.
ConfigServiceID = "config" // Config.
CurServiceID = "cur" // Cur.
DatapipelineServiceID = "datapipeline" // Datapipeline.
DaxServiceID = "dax" // Dax.
DevicefarmServiceID = "devicefarm" // Devicefarm.
DirectconnectServiceID = "directconnect" // Directconnect.
DiscoveryServiceID = "discovery" // Discovery.
DmsServiceID = "dms" // Dms.
DsServiceID = "ds" // Ds.
DynamodbServiceID = "dynamodb" // Dynamodb.
Ec2ServiceID = "ec2" // Ec2.
Ec2metadataServiceID = "ec2metadata" // Ec2metadata.
EcrServiceID = "ecr" // Ecr.
EcsServiceID = "ecs" // Ecs.
ElasticacheServiceID = "elasticache" // Elasticache.
ElasticbeanstalkServiceID = "elasticbeanstalk" // Elasticbeanstalk.
ElasticfilesystemServiceID = "elasticfilesystem" // Elasticfilesystem.
ElasticloadbalancingServiceID = "elasticloadbalancing" // Elasticloadbalancing.
ElasticmapreduceServiceID = "elasticmapreduce" // Elasticmapreduce.
ElastictranscoderServiceID = "elastictranscoder" // Elastictranscoder.
EmailServiceID = "email" // Email.
EntitlementMarketplaceServiceID = "entitlement.marketplace" // EntitlementMarketplace.
EsServiceID = "es" // Es.
EventsServiceID = "events" // Events.
FirehoseServiceID = "firehose" // Firehose.
FmsServiceID = "fms" // Fms.
GameliftServiceID = "gamelift" // Gamelift.
GlacierServiceID = "glacier" // Glacier.
GlueServiceID = "glue" // Glue.
GreengrassServiceID = "greengrass" // Greengrass.
GuarddutyServiceID = "guardduty" // Guardduty.
HealthServiceID = "health" // Health.
IamServiceID = "iam" // Iam.
ImportexportServiceID = "importexport" // Importexport.
InspectorServiceID = "inspector" // Inspector.
IotServiceID = "iot" // Iot.
IotanalyticsServiceID = "iotanalytics" // Iotanalytics.
KinesisServiceID = "kinesis" // Kinesis.
KinesisanalyticsServiceID = "kinesisanalytics" // Kinesisanalytics.
KinesisvideoServiceID = "kinesisvideo" // Kinesisvideo.
KmsServiceID = "kms" // Kms.
LambdaServiceID = "lambda" // Lambda.
LightsailServiceID = "lightsail" // Lightsail.
LogsServiceID = "logs" // Logs.
MachinelearningServiceID = "machinelearning" // Machinelearning.
MarketplacecommerceanalyticsServiceID = "marketplacecommerceanalytics" // Marketplacecommerceanalytics.
MediaconvertServiceID = "mediaconvert" // Mediaconvert.
MedialiveServiceID = "medialive" // Medialive.
MediapackageServiceID = "mediapackage" // Mediapackage.
MediastoreServiceID = "mediastore" // Mediastore.
MeteringMarketplaceServiceID = "metering.marketplace" // MeteringMarketplace.
MghServiceID = "mgh" // Mgh.
MobileanalyticsServiceID = "mobileanalytics" // Mobileanalytics.
ModelsLexServiceID = "models.lex" // ModelsLex.
MonitoringServiceID = "monitoring" // Monitoring.
MturkRequesterServiceID = "mturk-requester" // MturkRequester.
NeptuneServiceID = "neptune" // Neptune.
OpsworksServiceID = "opsworks" // Opsworks.
OpsworksCmServiceID = "opsworks-cm" // OpsworksCm.
OrganizationsServiceID = "organizations" // Organizations.
PinpointServiceID = "pinpoint" // Pinpoint.
PollyServiceID = "polly" // Polly.
RdsServiceID = "rds" // Rds.
RedshiftServiceID = "redshift" // Redshift.
RekognitionServiceID = "rekognition" // Rekognition.
ResourceGroupsServiceID = "resource-groups" // ResourceGroups.
Route53ServiceID = "route53" // Route53.
Route53domainsServiceID = "route53domains" // Route53domains.
RuntimeLexServiceID = "runtime.lex" // RuntimeLex.
RuntimeSagemakerServiceID = "runtime.sagemaker" // RuntimeSagemaker.
S3ServiceID = "s3" // S3.
S3ControlServiceID = "s3-control" // S3Control.
SagemakerServiceID = "api.sagemaker" // Sagemaker.
SdbServiceID = "sdb" // Sdb.
SecretsmanagerServiceID = "secretsmanager" // Secretsmanager.
ServerlessrepoServiceID = "serverlessrepo" // Serverlessrepo.
ServicecatalogServiceID = "servicecatalog" // Servicecatalog.
ServicediscoveryServiceID = "servicediscovery" // Servicediscovery.
ShieldServiceID = "shield" // Shield.
SmsServiceID = "sms" // Sms.
SnowballServiceID = "snowball" // Snowball.
SnsServiceID = "sns" // Sns.
SqsServiceID = "sqs" // Sqs.
SsmServiceID = "ssm" // Ssm.
StatesServiceID = "states" // States.
StoragegatewayServiceID = "storagegateway" // Storagegateway.
StreamsDynamodbServiceID = "streams.dynamodb" // StreamsDynamodb.
StsServiceID = "sts" // Sts.
SupportServiceID = "support" // Support.
SwfServiceID = "swf" // Swf.
TaggingServiceID = "tagging" // Tagging.
TransferServiceID = "transfer" // Transfer.
TranslateServiceID = "translate" // Translate.
WafServiceID = "waf" // Waf.
WafRegionalServiceID = "waf-regional" // WafRegional.
WorkdocsServiceID = "workdocs" // Workdocs.
WorkmailServiceID = "workmail" // Workmail.
WorkspacesServiceID = "workspaces" // Workspaces.
XrayServiceID = "xray" // Xray.
)

View File

@ -35,7 +35,7 @@ type Options struct {
//
// If resolving an endpoint on the partition list the provided region will
// be used to determine which partition's domain name pattern to the service
// endpoint ID with. If both the service and region are unkonwn and resolving
// endpoint ID with. If both the service and region are unknown and resolving
// the endpoint on partition list an UnknownEndpointError error will be returned.
//
// If resolving and endpoint on a partition specific resolver that partition's
@ -206,10 +206,11 @@ func (p Partition) EndpointFor(service, region string, opts ...func(*Options)) (
// enumerating over the regions in a partition.
func (p Partition) Regions() map[string]Region {
rs := map[string]Region{}
for id := range p.p.Regions {
for id, r := range p.p.Regions {
rs[id] = Region{
id: id,
p: p.p,
id: id,
desc: r.Description,
p: p.p,
}
}
@ -240,6 +241,10 @@ type Region struct {
// ID returns the region's identifier.
func (r Region) ID() string { return r.id }
// Description returns the region's description. The region description
// is free text, it can be empty, and it may change between SDK releases.
func (r Region) Description() string { return r.desc }
// ResolveEndpoint resolves an endpoint from the context of the region given
// a service. See Partition.EndpointFor for usage and errors that can be returned.
func (r Region) ResolveEndpoint(service string, opts ...func(*Options)) (ResolvedEndpoint, error) {
@ -284,10 +289,11 @@ func (s Service) ResolveEndpoint(region string, opts ...func(*Options)) (Resolve
func (s Service) Regions() map[string]Region {
rs := map[string]Region{}
for id := range s.p.Services[s.id].Endpoints {
if _, ok := s.p.Regions[id]; ok {
if r, ok := s.p.Regions[id]; ok {
rs[id] = Region{
id: id,
p: s.p,
id: id,
desc: r.Description,
p: s.p,
}
}
}
@ -347,6 +353,10 @@ type ResolvedEndpoint struct {
// The service name that should be used for signing requests.
SigningName string
// States that the signing name for this endpoint was derived from metadata
// passed in, but was not explicitly modeled.
SigningNameDerived bool
// The signing method that should be used for signing requests.
SigningMethod string
}

View File

@ -65,6 +65,10 @@ func TestEnumRegionServices(t *testing.T) {
t.Errorf("expect %q region ID, got %q", e, a)
}
if a, e := r.Description(), "region description"; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
ss := r.Services()
if a, e := len(ss), 1; a != e {
t.Errorf("expect %d services for us-east-1, got %d", e, a)
@ -291,6 +295,9 @@ func TestRegionsForService(t *testing.T) {
if _, ok := expect[id]; !ok {
t.Errorf("expect %s region to be found", id)
}
if a, e := r.Description(), expect[id].desc; a != e {
t.Errorf("expect %q region Description, got %q", e, a)
}
}
}

View File

@ -226,16 +226,20 @@ func (e endpoint) resolve(service, region, dnsSuffix string, defs []endpoint, op
if len(signingRegion) == 0 {
signingRegion = region
}
signingName := e.CredentialScope.Service
var signingNameDerived bool
if len(signingName) == 0 {
signingName = service
signingNameDerived = true
}
return ResolvedEndpoint{
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
URL: u,
SigningRegion: signingRegion,
SigningName: signingName,
SigningNameDerived: signingNameDerived,
SigningMethod: getByPriority(e.SignatureVersions, signerPriority, defaultSigner),
}
}

View File

@ -16,6 +16,10 @@ import (
type CodeGenOptions struct {
// Options for how the model will be decoded.
DecodeModelOptions DecodeModelOptions
// Disables code generation of the service endpoint prefix IDs defined in
// the model.
DisableGenerateServiceIDs bool
}
// Set combines all of the option functions together
@ -39,8 +43,16 @@ func CodeGenModel(modelFile io.Reader, outFile io.Writer, optFns ...func(*CodeGe
return err
}
v := struct {
Resolver
CodeGenOptions
}{
Resolver: resolver,
CodeGenOptions: opts,
}
tmpl := template.Must(template.New("tmpl").Funcs(funcMap).Parse(v3Tmpl))
if err := tmpl.ExecuteTemplate(outFile, "defaults", resolver); err != nil {
if err := tmpl.ExecuteTemplate(outFile, "defaults", v); err != nil {
return fmt.Errorf("failed to execute template, %v", err)
}
@ -166,15 +178,17 @@ import (
"regexp"
)
{{ template "partition consts" . }}
{{ template "partition consts" $.Resolver }}
{{ range $_, $partition := . }}
{{ range $_, $partition := $.Resolver }}
{{ template "partition region consts" $partition }}
{{ end }}
{{ template "service consts" . }}
{{ if not $.DisableGenerateServiceIDs -}}
{{ template "service consts" $.Resolver }}
{{- end }}
{{ template "endpoint resolvers" . }}
{{ template "endpoint resolvers" $.Resolver }}
{{- end }}
{{ define "partition consts" }}

Some files were not shown because too many files have changed in this diff Show More