1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
|
package awsauth
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
// location caches the outcome of probing for the EC2 instance metadata
// service so the probe is not repeated once checked has been set.
type location struct {
	// ec2 is true when the metadata service was reachable;
	// checked is true once the probe has been performed.
	ec2, checked bool
}

// loc holds the cached probe result; it is nil until onEC2 first initialises it.
var loc *location
// serviceAndRegion parses an AWS endpoint hostname to find out which service
// and region a request is addressed to.
// See http://docs.aws.amazon.com/general/latest/gr/rande.html
//
// The defaults are service "s3" and region "us-east-1"; they are kept unless
// the hostname suggests otherwise. Hostnames are split on "." and handled as:
//
//	service.amazonaws.com              -> service, us-east-1
//	s3-region.amazonaws.com            -> s3, region
//	bucket.s3.amazonaws.com            -> s3, us-east-1
//	bucket.s3-region.amazonaws.com     -> s3, region
//	service.region.amazonaws.com       -> service, region
//	bucket.s3.region.amazonaws.com     -> s3, region
//	name.region.service.amazonaws.com  -> service, region (e.g. es domains)
//
// The legacy region name "external-1" is mapped to "us-east-1".
func serviceAndRegion(host string) (service string, region string) {
	// These are the defaults if the hostname doesn't suggest something else.
	region = "us-east-1"
	service = "s3"

	parts := strings.Split(host, ".")
	switch len(parts) {
	case 4:
		// bucket.s3.amazonaws.com, bucket.s3-region.amazonaws.com or
		// service.region.amazonaws.com
		switch {
		case parts[1] == "s3":
			// Global S3 virtual host: keep the default region.
		case strings.HasPrefix(parts[1], "s3-"):
			region = parts[1][len("s3-"):]
		default:
			service = parts[0]
			region = parts[1]
		}
	case 5:
		if parts[1] == "s3" {
			// Virtual-hosted S3 with an explicit region:
			// bucket.s3.region.amazonaws.com
			region = parts[2]
		} else {
			// name.region.service.amazonaws.com
			service = parts[2]
			region = parts[1]
		}
	default:
		// Either service.amazonaws.com or s3-region.amazonaws.com
		if strings.HasPrefix(parts[0], "s3-") {
			region = parts[0][len("s3-"):]
		} else {
			service = parts[0]
		}
	}

	if region == "external-1" {
		region = "us-east-1"
	}
	return service, region
}
// newKeys builds a set of credentials from the process environment.
//
// The access key ID is read from envAccessKeyID, falling back to envAccessKey.
// The secret is read from envSecretAccessKey, falling back to envSecretKey.
// The security token is read from envSecurityToken.
//
// If the key ID or the secret is still empty and the process is on EC2, the
// environment values are replaced by the instance's IAM role credentials.
// Afterwards, if the credentials' expired() method reports them stale and
// the process is on EC2, the role credentials are fetched again.
func newKeys() (newCredentials Credentials) {
	// fromEnv returns the value of primary, or of fallback when primary is
	// unset or empty.
	fromEnv := func(primary, fallback string) string {
		if value := os.Getenv(primary); value != "" {
			return value
		}
		return os.Getenv(fallback)
	}

	newCredentials = Credentials{
		AccessKeyID:     fromEnv(envAccessKeyID, envAccessKey),
		SecretAccessKey: fromEnv(envSecretAccessKey, envSecretKey),
		SecurityToken:   os.Getenv(envSecurityToken),
	}

	incomplete := newCredentials.AccessKeyID == "" || newCredentials.SecretAccessKey == ""
	if incomplete && onEC2() {
		newCredentials = *getIAMRoleCredentials()
	}

	// Refresh role credentials whenever expired() says they are stale.
	if newCredentials.expired() && onEC2() {
		newCredentials = *getIAMRoleCredentials()
	}
	return newCredentials
}
// chooseKeys returns the first explicitly supplied credentials. When none
// were passed, it builds a fresh set from the environment via newKeys.
// Any credentials after the first are ignored.
func chooseKeys(cred []Credentials) Credentials {
	if len(cred) == 0 {
		return newKeys()
	}
	return cred[0]
}
// onEC2 reports whether the program appears to be running on an EC2 instance.
// It decides by whether a TCP connection to the instance metadata service
// (169.254.169.254:80) can be opened within 100ms.
//
// The answer is cached in loc, so the probe is made only once. The cache is
// not guarded against concurrent callers.
func onEC2() bool {
	if loc == nil {
		loc = new(location)
	}
	if loc.checked {
		return loc.ec2
	}

	conn, err := net.DialTimeout("tcp", "169.254.169.254:80", 100*time.Millisecond)
	reachable := err == nil
	if reachable {
		conn.Close()
	}
	loc.ec2 = reachable
	loc.checked = true
	return loc.ec2
}
// getIAMRoleList asks the EC2 instance metadata service for the IAM role
// names available to this instance, one per line of the response body.
//
// It is best-effort: a transport error, a non-200 status, or a read error
// part-way through the body all yield a nil slice. Callers treat nil as
// "no role available". Blank lines are skipped.
func getIAMRoleList() []string {
	const rolesURL = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"

	// The zero http.Client never times out; bound the wait so an unresponsive
	// metadata endpoint cannot block credential lookup indefinitely.
	client := &http.Client{Timeout: 5 * time.Second}
	response, err := client.Get(rolesURL)
	if err != nil {
		return nil
	}
	defer response.Body.Close()

	// A non-200 response (e.g. 404 when no role is attached) carries an error
	// page, not role names.
	if response.StatusCode != http.StatusOK {
		return nil
	}

	var roles []string
	scanner := bufio.NewScanner(response.Body)
	for scanner.Scan() {
		if role := strings.TrimSpace(scanner.Text()); role != "" {
			roles = append(roles, role)
		}
	}
	// A read error means the list may be truncated; do not trust it.
	if scanner.Err() != nil {
		return nil
	}
	return roles
}
// getIAMRoleCredentials fetches credentials for the first IAM role listed by
// the EC2 instance metadata service. It decodes the JSON response into a
// Credentials value.
//
// It is best-effort and never returns nil. It returns a pointer to empty
// Credentials when:
//   - no role is available,
//   - the request fails or the service answers with a non-200 status, or
//   - the body cannot be read or decoded.
func getIAMRoleCredentials() *Credentials {
	roles := getIAMRoleList()
	if len(roles) == 0 {
		return &Credentials{}
	}

	// Use the first role in the list.
	const rolesURL = "http://169.254.169.254/latest/meta-data/iam/security-credentials/"
	roleURL := rolesURL + roles[0]

	// The zero http.Client never times out; bound the wait.
	client := &http.Client{Timeout: 5 * time.Second}
	response, err := client.Get(roleURL)
	if err != nil {
		return &Credentials{}
	}
	defer response.Body.Close()

	// Only a 200 response carries the credentials document.
	if response.StatusCode != http.StatusOK {
		return &Credentials{}
	}

	body, err := ioutil.ReadAll(response.Body)
	if err != nil {
		return &Credentials{}
	}

	var credentials Credentials
	if err := json.Unmarshal(body, &credentials); err != nil {
		return &Credentials{}
	}
	return &credentials
}
// augmentRequestQuery merges the query parameters already on request.URL into
// values. It then writes the result back as the request's raw query, encoded
// by url.Values.Encode (keys in sorted order).
//
// Merge rules:
//   - For a key present in both, the request's own values win.
//   - Every value of a repeated request parameter (e.g. ?a=1&a=2) is kept.
//   - A non-nil values map is modified in place; a nil map is treated as
//     empty.
//
// The same request is returned for chaining.
func augmentRequestQuery(request *http.Request, values url.Values) *http.Request {
	if values == nil {
		values = url.Values{}
	}
	for key, array := range request.URL.Query() {
		// Replace the whole slice: calling Set once per value would keep only
		// the last value of a repeated parameter.
		values[key] = array
	}
	request.URL.RawQuery = values.Encode()
	return request
}
// hmacSHA256 computes HMAC-SHA256 of content under key and returns the raw,
// unencoded digest.
func hmacSHA256(key []byte, content string) []byte {
	signer := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented never to return an error.
	signer.Write([]byte(content))
	return signer.Sum(make([]byte, 0, sha256.Size))
}
// hmacSHA1 computes HMAC-SHA1 of content under key and returns the raw,
// unencoded digest.
func hmacSHA1(key []byte, content string) []byte {
	signer := hmac.New(sha1.New, key)
	// hash.Hash.Write is documented never to return an error.
	signer.Write([]byte(content))
	return signer.Sum(make([]byte, 0, sha1.Size))
}
// hashSHA256 returns the lowercase hexadecimal SHA-256 digest of content.
func hashSHA256(content []byte) string {
	digest := sha256.Sum256(content)
	return fmt.Sprintf("%x", digest[:])
}
// hashMD5 returns the standard (padded) base64 encoding of content's raw MD5
// digest.
func hashMD5(content []byte) string {
	digest := md5.Sum(content)
	return base64.StdEncoding.EncodeToString(digest[:])
}
// readAndReplaceBody reads the whole request body and returns it. It then
// installs an in-memory copy as request.Body, so the request can still be
// sent afterwards. A nil body yields an empty slice and leaves the request
// untouched.
//
// The original body is closed once read. It is being replaced, so nothing
// else (not even the http.Client transport) would ever close it otherwise.
// Reading is best-effort: on a read error the bytes read so far are used as
// the payload.
func readAndReplaceBody(request *http.Request) []byte {
	if request.Body == nil {
		return []byte{}
	}
	payload, _ := ioutil.ReadAll(request.Body) // best-effort; see doc comment
	// Close errors are irrelevant here: the data has already been consumed.
	request.Body.Close()
	request.Body = ioutil.NopCloser(bytes.NewReader(payload))
	return payload
}
// concat joins the given strings, placing delim between consecutive
// elements. It returns "" when no strings are given.
func concat(delim string, str ...string) string {
	var joined []byte
	for i, piece := range str {
		if i > 0 {
			joined = append(joined, delim...)
		}
		joined = append(joined, piece...)
	}
	return string(joined)
}
// now returns the current time converted to UTC. It is a package-level
// variable rather than a function, presumably so tests can substitute a
// fixed clock (TODO confirm).
var now = func() time.Time {
	current := time.Now()
	return current.UTC()
}
// normuri percent-encodes each "/"-separated segment of uri with
// encodePathFrag. The slashes themselves are preserved.
func normuri(uri string) string {
	segments := strings.Split(uri, "/")
	for idx, segment := range segments {
		segments[idx] = encodePathFrag(segment)
	}
	return strings.Join(segments, "/")
}

// encodePathFrag percent-encodes every byte of s for which shouldEscape
// reports true, using uppercase hex digits (e.g. ' ' -> "%20"). All other
// bytes are copied unchanged. Multi-byte UTF-8 sequences are encoded byte by
// byte.
func encodePathFrag(s string) string {
	const hexDigits = "0123456789ABCDEF"
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		b := s[i]
		if !shouldEscape(b) {
			out = append(out, b)
			continue
		}
		out = append(out, '%', hexDigits[b>>4], hexDigits[b&0x0F])
	}
	return string(out)
}

// shouldEscape reports whether c must be percent-encoded. Only ASCII letters,
// digits and the characters '-', '_', '.', '~' are left as-is.
func shouldEscape(c byte) bool {
	switch {
	case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		return false
	case c == '-', c == '_', c == '.', c == '~':
		return false
	default:
		return true
	}
}
// normquery encodes v as a query string via url.Values.Encode (keys sorted).
// Spaces are written as "%20" rather than "+", as Amazon requires.
func normquery(v url.Values) string {
	// Encode has already escaped any literal '+' as "%2B", so every '+' left
	// in its output stands for a space.
	return strings.NewReplacer("+", "%20").Replace(v.Encode())
}
|