File: aws_sigv4.rb

package info (click to toggle)
ruby-httpx 1.7.2-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,816 kB
  • sloc: ruby: 12,209; makefile: 4
file content (239 lines) | stat: -rw-r--r-- 8,140 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# frozen_string_literal: true

module HTTPX
  module Plugins
    #
    # This plugin adds AWS Sigv4 authentication.
    #
    # https://docs.aws.amazon.com/IAM/latest/UserGuide/signing-elements.html
    #
    # https://gitlab.com/os85/httpx/wikis/AWS-SigV4
    #
    module AWSSigV4
      # Access key id (+username+), secret key (+password+) and optional session
      # token (+security_token+) consumed by Signer.
      Credentials = Struct.new(:username, :password, :security_token)

      # Signs requests using the AWS sigv4 signing.
      #
      # Signing mutates the request. It fills in the "x-<field>-date" header when
      # missing, and optionally the payload checksum and security token headers.
      # It then sets the "authorization" header.
      class Signer
        # @param service [String] service name embedded in the credential scope.
        # @param region [String] region embedded in the credential scope.
        # @param credentials [#username, #password, #security_token, nil] when nil, a
        #   Credentials struct is built from +username+, +password+ and +security_token+.
        # @param username [String, nil] access key id (used only without +credentials+).
        # @param password [String, nil] secret key (used only without +credentials+).
        # @param security_token [String, nil] session token (used only without +credentials+).
        # @param provider_prefix [String] scheme prefix; "aws" yields "aws4" / "AWS4".
        # @param header_provider_field [String] infix of the "x-<field>-*" headers.
        # @param unsigned_headers [Array<String>] header names excluded from signing
        #   (stored downcased). "authorization", "x-amzn-trace-id" and "expect" are
        #   always excluded.
        # @param apply_checksum_header [Boolean] whether to set the
        #   "x-<field>-content-<algorithm>" header on the request.
        # @param algorithm [String] OpenSSL digest name used for hashing and HMAC.
        def initialize(
          service:,
          region:,
          credentials: nil,
          username: nil,
          password: nil,
          security_token: nil,
          provider_prefix: "aws",
          header_provider_field: "amz",
          unsigned_headers: [],
          apply_checksum_header: true,
          algorithm: "SHA256"
        )
          @credentials = credentials || Credentials.new(username, password, security_token)
          @service = service
          @region = region

          @unsigned_headers = Set.new(unsigned_headers.map(&:downcase))
          @unsigned_headers << "authorization"
          @unsigned_headers << "x-amzn-trace-id"
          @unsigned_headers << "expect"

          @apply_checksum_header = apply_checksum_header
          @provider_prefix = provider_prefix
          @header_provider_field = header_provider_field

          @algorithm = algorithm
        end

        # Computes the sigv4 signature for +request+ and stores it in the request's
        # "authorization" header.
        #
        # Reads +verb+, +canonical_path+, +canonical_query+, +headers+, +body+ and
        # +authority+ from +request+.
        # @raise [Error] when the payload hash must be computed from a body that is
        #   unbounded, or IO-like without +rewind+.
        def sign!(request)
          lower_provider_prefix = "#{@provider_prefix}4"
          upper_provider_prefix = lower_provider_prefix.upcase

          downcased_algorithm = @algorithm.downcase

          datetime = (request.headers["x-#{@header_provider_field}-date"] ||= Time.now.utc.strftime("%Y%m%dT%H%M%SZ"))
          date = datetime[0, 8]

          # a caller-supplied payload hash header is trusted as-is
          content_hashed = request.headers["x-#{@header_provider_field}-content-#{downcased_algorithm}"] || hexdigest(request.body)

          request.headers["x-#{@header_provider_field}-content-#{downcased_algorithm}"] ||= content_hashed if @apply_checksum_header
          request.headers["x-#{@header_provider_field}-security-token"] ||= @credentials.security_token if @credentials.security_token

          signature_headers = request.headers.each.reject do |k, _|
            @unsigned_headers.include?(k)
          end
          # aws sigv4 needs to declare the host, regardless of protocol version
          signature_headers << ["host", request.authority] unless request.headers.key?("host")
          signature_headers.sort_by!(&:first)

          signed_headers = signature_headers.map(&:first).join(";")

          canonical_headers = signature_headers.map do |k, v|
            # eliminate whitespace between value fields, unless it's a quoted value
            "#{k}:#{v.start_with?("\"") && v.end_with?("\"") ? v : v.gsub(/\s+/, " ").strip}\n"
          end.join

          # canonical request
          creq = "#{request.verb}" \
                 "\n#{request.canonical_path}" \
                 "\n#{request.canonical_query}" \
                 "\n#{canonical_headers}" \
                 "\n#{signed_headers}" \
                 "\n#{content_hashed}"

          credential_scope = "#{date}" \
                             "/#{@region}" \
                             "/#{@service}" \
                             "/#{lower_provider_prefix}_request"

          algo_line = "#{upper_provider_prefix}-HMAC-#{@algorithm}"
          # string to sign
          sts = "#{algo_line}" \
                "\n#{datetime}" \
                "\n#{credential_scope}" \
                "\n#{OpenSSL::Digest.new(@algorithm).hexdigest(creq)}"

          # signature: derive the signing key by chaining HMACs over the scope parts
          k_date = hmac("#{upper_provider_prefix}#{@credentials.password}", date)
          k_region = hmac(k_date, @region)
          k_service = hmac(k_region, @service)
          k_credentials = hmac(k_service, "#{lower_provider_prefix}_request")
          sig = hexhmac(k_credentials, sts)

          credential = "#{@credentials.username}/#{credential_scope}"
          # apply signature
          request.headers["authorization"] =
            "#{algo_line} " \
            "Credential=#{credential}, " \
            "SignedHeaders=#{signed_headers}, " \
            "Signature=#{sig}"
        end

        private

        # Hex-encoded digest (using the configured algorithm) of a request body.
        #
        # * IO-like bodies exposing +to_path+ (files, pathnames) are hashed from disk.
        # * Other IO-like bodies are copied to a tempfile, hashed, then rewound.
        # * Anything else must yield its chunks from +each+. Every chunk is fed to
        #   the digest as it arrives, so the whole body is hashed without buffering it.
        #
        # @raise [Error] for IO-like bodies without +rewind+, or enumerable bodies
        #   reporting +unbounded_body?+.
        def hexdigest(value)
          digest = OpenSSL::Digest.new(@algorithm)

          if value.respond_to?(:read)
            if value.respond_to?(:to_path)
              # files, pathnames
              digest.file(value.to_path).hexdigest
            else
              # e.g. gzipped request bodies
              raise Error, "request body must be rewindable" unless value.respond_to?(:rewind)

              buffer = Tempfile.new("httpx", encoding: Encoding::BINARY, mode: File::RDWR)
              begin
                IO.copy_stream(value, buffer)
                buffer.flush

                digest.file(buffer.to_path).hexdigest
              ensure
                value.rewind
                buffer.close
                buffer.unlink
              end
            end
          else
            # error on endless generators
            raise Error, "hexdigest for endless enumerators is not supported" if value.unbounded_body?

            # The signature must cover the entire payload, so hash every chunk
            # incrementally. Breaking out of the iteration would both truncate the
            # payload and discard the accumulated result.
            value.each { |chunk| digest.update(chunk) }
            digest.hexdigest
          end
        end

        # Raw (binary) HMAC of +value+ keyed by +key+.
        def hmac(key, value)
          OpenSSL::HMAC.digest(OpenSSL::Digest.new(@algorithm), key, value)
        end

        # Hex-encoded HMAC of +value+ keyed by +key+.
        def hexhmac(key, value)
          OpenSSL::HMAC.hexdigest(OpenSSL::Digest.new(@algorithm), key, value)
        end
      end

      # Plugin hook: loads the standard libraries this plugin needs. That is Set
      # (used for the signer's unsigned-header list), the SHA-2 digest library, and
      # the CGI escaping used when building canonical paths.
      def self.load_dependencies(*)
        require "set"
        require "digest/sha2"
        require "cgi/escape"
      end

      # Plugin hook: loads the :expect plugin into +klass+ alongside this one.
      def self.configure(klass)
        klass.plugin(:expect)
      end

      # adds support for the following options:
      #
      # :sigv4_signer :: instance of HTTPX::Plugins::AWSSigV4::Signer used to sign requests,
      #                  or a Hash of the keyword arguments accepted by Signer.new.
      module OptionsMethods
        private

        # Coerces the +:sigv4_signer+ option into a Signer.
        #
        # Signer#initialize takes keyword arguments only, so a Hash must be
        # double-splatted. Passing it positionally raises ArgumentError on Ruby 3+.
        def option_sigv4_signer(value)
          value.is_a?(Signer) ? value : Signer.new(**value)
        end
      end

      module InstanceMethods
        # Builds a Signer from +options+ (forwarded unchanged to Signer.new). It
        # returns whatever +with+ produces when given that signer as the
        # +:sigv4_signer+ option.
        def aws_sigv4_authentication(**options)
          signer = Signer.new(**options)
          with(sigv4_signer: signer)
        end

        # Builds the request via +super+ and signs it in place. Signing happens only
        # when the request has no "authorization" header yet and its options carry a
        # +sigv4_signer+. The request is always returned.
        def build_request(*)
          request = super
          # an explicitly provided authorization header takes precedence over signing
          signer = request.options.sigv4_signer unless request.headers.key?("authorization")
          signer.sign!(request) if signer
          request
        end
      end

      module RequestMethods
        # Canonical URI component of the sigv4 canonical request.
        #
        # The request path is used ("/" when empty). Every run of non-slash
        # characters goes through CGI.escape, with "+" rewritten to "%20" and "%7E"
        # back to "~". Existing percent-escapes in the path are escaped again
        # ("%20" becomes "%2520").
        # NOTE(review): AWS documents single encoding for S3 paths — confirm this
        # double escaping is intended for the services in use.
        def canonical_path
          path = uri.path.dup
          path << "/" if path.empty?
          path.gsub(%r{[^/]+}) { |part| CGI.escape(part.encode("UTF-8")).gsub("+", "%20").gsub("%7E", "~") }
        end

        # Canonical query string of the sigv4 canonical request.
        #
        # From: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html#create-canonical-request
        # Sort the parameter names by character code point in ascending order.
        # Parameters with duplicate names should be sorted by value.
        #
        # Ties on name and value are broken by original position, so the result is
        # deterministic on every Ruby (JRuby's sort may swap elements that compare
        # equal). Parameters without "=" (or empty ones) are treated as having an
        # empty value / name instead of making the sort raise on a nil comparison.
        # NOTE(review): AWS asks for "name=" on valueless parameters; that
        # normalization is left disabled here:
        # params = params.map { |p| p.match(/=/) ? p : p + '=' }
        def canonical_query
          params = query.split("&")
          params.each_with_index.sort_by do |param, index|
            name, value = param.split("=", 2)
            [name.to_s, value.to_s, index]
          end.map(&:first).join("&")
        end
      end
    end
    # Registers AWSSigV4 under the :aws_sigv4 plugin name.
    register_plugin :aws_sigv4, AWSSigV4
  end
end