File: request.rb

package info (click to toggle)
ruby-mongo 2.21.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 14,764 kB
  • sloc: ruby: 108,806; makefile: 5; sh: 2
file content (284 lines) | stat: -rw-r--r-- 10,653 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
# frozen_string_literal: true
# rubocop:todo all

# Copyright (C) 2020 MongoDB Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Registers Net::HTTP for lazy loading: 'net/http' is only required the
# first time the Net::HTTP constant is referenced, rather than when this
# file is loaded.
module Net
  autoload :HTTP, 'net/http'
end

module Mongo
  module Auth
    class Aws

      # Helper class for working with AWS requests.
      #
      # The primary purpose of this class is to produce the canonical AWS
      # STS request and calculate the signed headers and signature for it.
      #
      # @api private
      class Request

        # The body of the STS GetCallerIdentity request.
        #
        # This is currently the only request that this class supports making.
        STS_REQUEST_BODY = "Action=GetCallerIdentity&Version=2011-06-15".freeze

        # The timeout, in seconds, to use for validating credentials via STS.
        VALIDATE_TIMEOUT = 10

        # The maximum permitted size of the host, in bytes.
        MAX_HOST_BYTESIZE = 255

        # Constructs the request.
        #
        # @note By overriding the time, it is possible to create reproducible
        #   requests (in other words, replay a request).
        #
        # @param [ String ] access_key_id The access key id.
        # @param [ String ] secret_access_key The secret access key.
        # @param [ String ] session_token The session token for temporary
        #   credentials.
        # @param [ String ] host The value of Host HTTP header to use.
        # @param [ String ] server_nonce The server nonce binary string.
        # @param [ Time ] time The time of the request.
        #
        # @raise [ Error::InvalidServerAuthResponse ] If access_key_id,
        #   secret_access_key, host or server_nonce is nil or empty.
        # @raise [ Error::InvalidServerAuthHost ] If the host is longer than
        #   255 bytes.
        def initialize(access_key_id:, secret_access_key:, session_token: nil,
          host:, server_nonce:, time: Time.now
        )
          @access_key_id = access_key_id
          @secret_access_key = secret_access_key
          @session_token = session_token
          @host = host
          @server_nonce = server_nonce
          @time = time

          # session_token and time are deliberately absent: the former is
          # optional and the latter has a default.
          required = {
            access_key_id: access_key_id,
            secret_access_key: secret_access_key,
            host: host,
            server_nonce: server_nonce,
          }
          required.each do |name, value|
            if value.nil? || value.empty?
              raise Error::InvalidServerAuthResponse, "Value for '#{name}' is required"
            end
          end

          # The limit applies to the encoded size of the host, so bytes are
          # counted rather than characters; a host made of multibyte
          # characters can exceed the limit with fewer than 256 characters.
          if host.bytesize > MAX_HOST_BYTESIZE
            raise Error::InvalidServerAuthHost, "Value for 'host' is too long: #{@host}"
          end
        end

        # @return [ String ] access_key_id The access key id.
        attr_reader :access_key_id

        # @return [ String ] secret_access_key The secret access key.
        attr_reader :secret_access_key

        # @return [ String ] session_token The session token for temporary
        #   credentials.
        attr_reader :session_token

        # @return [ String ] host The value of Host HTTP header to use.
        attr_reader :host

        # @return [ String ] server_nonce The server nonce binary string.
        attr_reader :server_nonce

        # @return [ Time ] time The time of the request.
        attr_reader :time

        # @return [ String ] formatted_time ISO8601-formatted time of the
        #   request, as would be used in X-Amz-Date header. The time is
        #   converted to UTC before formatting, and the result is memoized.
        def formatted_time
          @formatted_time ||= @time.getutc.strftime('%Y%m%dT%H%M%SZ')
        end

        # @return [ String ] formatted_date YYYYMMDD formatted date of the request.
        def formatted_date
          # The date is the first eight characters of the formatted time.
          formatted_time[0, 8]
        end

        # Derives the region from the host: the second dot-separated
        # component of the host is the region, and hosts that have a single
        # component (as well as the global sts.amazonaws.com endpoint) map
        # to us-east-1.
        #
        # @return [ String ] region The region of the host, derived from the host.
        #
        # @raise [ Error::InvalidServerAuthHost ] If the host begins or ends
        #   with a period, or contains an empty component.
        def region
          # Common case
          return 'us-east-1' if host == 'sts.amazonaws.com'

          if host.start_with?('.')
            raise Error::InvalidServerAuthHost, "Host begins with a period: #{host}"
          end
          # String#split drops trailing empty components, hence a trailing
          # period must be checked for explicitly.
          if host.end_with?('.')
            raise Error::InvalidServerAuthHost, "Host ends with a period: #{host}"
          end

          parts = host.split('.')
          if parts.any?(&:empty?)
            raise Error::InvalidServerAuthHost, "Host has an empty component: #{host}"
          end

          parts.length == 1 ? 'us-east-1' : parts[1]
        end

        # Returns the scope of the request, per the AWS signature V4 specification.
        #
        # @return [ String ] The scope.
        def scope
          "#{formatted_date}/#{region}/sts/aws4_request"
        end

        # Returns the hash containing the headers of the calculated canonical
        # request.
        #
        # The x-amz-security-token header is only present when a session
        # token was provided.
        #
        # @note The keys of the hash are not necessarily ordered
        #   lexicographically. Use #headers_to_sign to obtain the headers
        #   in the form that is used for signing.
        #
        # @return [ <Hash> ] headers The headers.
        def headers
          result = {
            # Content-Length is measured in bytes, not characters.
            'content-length' => STS_REQUEST_BODY.bytesize.to_s,
            'content-type' => 'application/x-www-form-urlencoded',
            'host' => host,
            'x-amz-date' => formatted_time,
            'x-mongodb-gs2-cb-flag' => 'n',
            # Strict encoding produces the base64 text without line breaks.
            'x-mongodb-server-nonce' => Base64.strict_encode64(server_nonce),
          }
          result['x-amz-security-token'] = session_token if session_token
          result
        end

        # Returns the hash containing the headers of the calculated canonical
        # request that should be signed, in a ready to sign form.
        #
        # The differences between #headers and this method is this method:
        #
        # - Lowercases all header names.
        # - Orders the headers lexicographically in the hash.
        #
        # Every header returned by #headers is included in the result; no
        # headers are filtered out by this method.
        #
        # @return [ <Hash> ] headers The headers.
        def headers_to_sign
          # Build the headers once; #headers creates a new hash (and
          # re-encodes the nonce) on every call.
          all_headers = headers
          all_headers.keys.sort_by(&:downcase).each_with_object({}) do |key, signed|
            signed[key.downcase] = all_headers[key]
          end
        end

        # Returns semicolon-separated list of names of signed headers, per
        # the AWS signature V4 specification.
        #
        # @return [ String ] The signed header list.
        def signed_headers_string
          headers_to_sign.keys.join(';')
        end

        # Returns the canonical request used during calculation of AWS V4
        # signature.
        #
        # The request is always a POST to the "/" path with an empty query
        # string and STS_REQUEST_BODY as the payload.
        #
        # @return [ String ] The canonical request.
        def canonical_request
          signed = headers_to_sign
          serialized_headers = signed.map { |name, value| "#{name}:#{value}" }.join("\n")
          # Same value as #signed_headers_string, derived from the hash that
          # was already computed above.
          signed_names = signed.keys.join(';')
          hashed_payload = Digest::SHA256.hexdigest(STS_REQUEST_BODY)
          "POST\n/\n\n" +
            # There are two newlines after serialized headers because the
            # signature V4 specification treats each header as containing the
            # terminating newline, and there is an additional newline
            # separating headers from the signed header names.
            "#{serialized_headers}\n\n" +
            "#{signed_names}\n" +
            hashed_payload
        end

        # Returns the calculated signature of the canonical request, per
        # the AWS signature V4 specification.
        #
        # @return [ String ] The signature, as a lowercase hex string.
        def signature
          hashed_canonical_request = Digest::SHA256.hexdigest(canonical_request)
          string_to_sign = "AWS4-HMAC-SHA256\n" +
            "#{formatted_time}\n" +
            "#{scope}\n" +
            hashed_canonical_request
          # All of the intermediate HMAC operations are not hex-encoded.
          # Each step uses the previous step's output as the key.
          signing_key = [formatted_date, region, 'sts', 'aws4_request'].reduce(
            "AWS4#{secret_access_key}"
          ) do |key, data|
            hmac(key, data)
          end
          # Only the final HMAC operation is hex-encoded.
          hmac_hex(signing_key, string_to_sign)
        end

        # Returns the value of the Authorization header, per the AWS
        # signature V4 specification.
        #
        # @return [ String ] Authorization header value.
        def authorization
          "AWS4-HMAC-SHA256 Credential=#{access_key_id}/#{scope}, SignedHeaders=#{signed_headers_string}, Signature=#{signature}"
        end

        # Validates the credentials and the constructed request components
        # by sending a real STS GetCallerIdentity request.
        #
        # The request is sent over TLS to port 443 of the host, with all of
        # the headers from #headers plus the Authorization and Accept headers.
        #
        # @return [ Hash ] GetCallerIdentity result.
        #
        # @raise [ Error::CredentialCheckError ] If the request takes longer
        #   than VALIDATE_TIMEOUT seconds or the response status is not 200.
        def validate!
          sts_request = Net::HTTP::Post.new("https://#{host}").tap do |req|
            headers.each do |k, v|
              req[k] = v
            end
            req['authorization'] = authorization
            req['accept'] = 'application/json'
            req.body = STS_REQUEST_BODY
          end
          http = Net::HTTP.new(host, 443)
          http.use_ssl = true
          # NOTE(review): only the request itself is covered by
          # VALIDATE_TIMEOUT; establishing the connection in #start is
          # subject to Net::HTTP's own timeouts.
          http.start do
            resp = Timeout.timeout(VALIDATE_TIMEOUT, Error::CredentialCheckError, 'GetCallerIdentity request timed out') do
              http.request(sts_request)
            end
            # NOTE(review): the body is parsed before the status code is
            # examined, therefore a response whose body is not JSON raises
            # JSON::ParserError regardless of the status.
            payload = JSON.parse(resp.body)
            if resp.code != '200'
              aws_code = payload.fetch('Error').fetch('Code')
              aws_message = payload.fetch('Error').fetch('Message')
              msg = "Credential check for user #{access_key_id} failed with HTTP status code #{resp.code}: #{aws_code}: #{aws_message}"
              msg += '.' unless msg.end_with?('.')
              msg += " Please check that the credentials are valid, and if they are temporary (i.e. use the session token) that the session token is provided and not expired"
              raise Error::CredentialCheckError, msg
            end
            payload.fetch('GetCallerIdentityResponse').fetch('GetCallerIdentityResult')
          end
        end

        private

        # Calculates the HMAC-SHA256 of the data using the given key.
        #
        # @param [ String ] key The key.
        # @param [ String ] data The data to authenticate.
        #
        # @return [ String ] The raw (binary) digest.
        def hmac(key, data)
          OpenSSL::HMAC.digest("SHA256", key, data)
        end

        # Calculates the HMAC-SHA256 of the data using the given key.
        #
        # @param [ String ] key The key.
        # @param [ String ] data The data to authenticate.
        #
        # @return [ String ] The hex-encoded digest.
        def hmac_hex(key, data)
          OpenSSL::HMAC.hexdigest("SHA256", key, data)
        end

      end
    end
  end
end