# frozen_string_literal: true

require "openssl"
require "stringio"

# Provide HMAC-based Extract-and-Expand Key Derivation Function (HKDF) for Ruby.
class HKDF
  # Default hash algorithm to use for HMAC.
  DEFAULT_ALGOTIHM = "SHA256"
  # Default buffer size for reading source IO.
  DEFAULT_READ_SIZE = 512 * 1024

  # Create a new HKDF instance with the provided +source+ key material, which may be
  # a String or an IO-like object that responds to +read+.
  #
  # Options:
  # - +algorithm:+ hash function to use (defaults to SHA-256)
  # - +info:+ optional context and application-specific information
  # - +salt:+ optional salt value (a non-secret random value); when omitted or empty,
  #   a string of +HashLen+ zero bytes is used, as RFC 5869 specifies
  # - +read_size:+ buffer size when reading from a source IO
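  #
  # Example (a sketch; the key material, salt, info strings, and file name are
  # placeholders):
  #
  #   hkdf = HKDF.new("input key material", salt: "random salt", info: "session keys")
  #   encryption_key = hkdf.read(32) # output bytes 0...32
  #   mac_key = hkdf.read(32)        # output bytes 32...64
  #
  #   # An IO also works as +source+; it is consumed +read_size+ bytes at a time.
  #   File.open("ikm.bin", "rb") { |io| HKDF.new(io, info: "file key").read_hex(16) }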
  def initialize(source, options = {})
    source = StringIO.new(source) if source.is_a?(String)

    algorithm = options.fetch(:algorithm, DEFAULT_ALGOTIHM)
    @digest = OpenSSL::Digest.new(algorithm)
    # Store +info+ as binary so it concatenates safely with the binary HMAC output
    # in #generate_blocks, even when it contains non-ASCII characters.
    @info = options.fetch(:info, "").b

    salt = options[:salt]
    salt = 0.chr * @digest.digest_length if salt.nil? || salt.empty?
    read_size = options.fetch(:read_size, DEFAULT_READ_SIZE)

    @prk = generate_prk(salt, source, read_size)
    @position = 0
    @blocks = [""]
  end

  # Returns the hash algorithm this instance was configured with.
  def algorithm
    @digest.name
  end

  # Maximum number of bytes that can be derived: 255 * HashLen, per RFC 5869.
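  # For example, SHA-256 produces 32-byte blocks and SHA-512 produces 64-byte blocks
  # ("secret" is placeholder key material):
  #
  #   HKDF.new("secret").max_length                      # => 8160  (255 * 32)
  #   HKDF.new("secret", algorithm: "SHA512").max_length # => 16320 (255 * 64)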
  def max_length
    @max_length ||= @digest.digest_length * 255
  end

  # Move the reading position to an arbitrary byte offset. Raises +RangeError+ if
  # +position+ is negative or greater than +#max_length+.
  def seek(position)
    raise RangeError, "cannot seek before 0" if position.negative?
    raise RangeError, "cannot seek past #{max_length}" if position > max_length

    @position = position
  end

  # Adjust reading position back to the beginning.
  def rewind
    seek(0)
  end

  # Read the next +length+ bytes from the stream. Will raise +RangeError+ if you attempt to read beyond +#max_length+.
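  #
  # Output is deterministic, so +#seek+ and +#read+ together give random access to
  # the derived stream. A sketch with placeholder key material:
  #
  #   hkdf = HKDF.new("input key material")
  #   hkdf.read(16)           # skip bytes 0...16
  #   second = hkdf.read(16)  # bytes 16...32
  #   hkdf.seek(16)
  #   hkdf.read(16) == second # => true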
  def read(length)
    raise ArgumentError, "length must be non-negative" if length.negative?

    new_position = length + @position
    raise RangeError, "requested #{length} bytes, only #{max_length - @position} available" if new_position > max_length

    generate_blocks(new_position)
    start = @position
    @position = new_position
    @blocks.join.slice(start, length)
  end

  # Read the next +length+ bytes from the stream and return them hex encoded. Will raise +RangeError+ if you attempt to
  # read beyond +#max_length+.
  def read_hex(length)
    read(length).unpack1("H*")
  end

  # :nodoc:
  def inspect
    "#{to_s[0..-2]} algorithm=#{@digest.name.inspect} info=#{@info.inspect}>"
  end

  private

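  # HKDF-Extract (RFC 5869, Section 2.2): PRK = HMAC-Hash(salt, IKM). The input
  # key material is streamed into the HMAC +read_size+ bytes at a time.
  def generate_prk(salt, source, read_size)
    hmac = OpenSSL::HMAC.new(salt, @digest)
    while (block = source.read(read_size))
      hmac.update(block)
    end
    hmac.digest
  end
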
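  # HKDF-Expand (RFC 5869, Section 2.3): T(n) = HMAC-Hash(PRK, T(n-1) | info | n),
  # with T(0) empty. @blocks caches T(0)..T(k), so only blocks not yet generated
  # are computed to cover the first +length+ bytes of output.
  def generate_blocks(length)
    start = @blocks.size
    block_count = (length.to_f / @digest.digest_length).ceil
    start.upto(block_count) do |n|
      @blocks << OpenSSL::HMAC.digest(@digest, @prk, @blocks[n - 1] + @info + n.chr)
    end
  end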
end
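
# Self-check: a sketch added for illustration, not part of the HKDF API.
# Running this file directly (for example `ruby hkdf.rb`; the file name is an
# assumption) checks RFC 5869, Appendix A, Test Case 1 (SHA-256). The expected
# PRK and OKM hex strings are copied from the RFC. The inline extract/expand
# below is an independent reference written for this check. If the HKDF class
# is loaded, HKDF#read_hex is compared against the same vector.
if __FILE__ == $PROGRAM_NAME
  require "openssl"

  ikm  = ["0b" * 22].pack("H*")                    # 22 bytes of 0x0b
  salt = ["000102030405060708090a0b0c"].pack("H*") # 13 bytes
  info = ["f0f1f2f3f4f5f6f7f8f9"].pack("H*")       # 10 bytes
  length = 42
  expected_prk = "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5"
  expected_okm = "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"

  # Reference HKDF-Extract: PRK = HMAC-SHA256(salt, IKM).
  prk = OpenSSL::HMAC.digest("SHA256", salt, ikm)

  # Reference HKDF-Expand: T(n) = HMAC-SHA256(PRK, T(n-1) | info | n), T(0) empty,
  # concatenated until at least +length+ bytes exist, then truncated.
  okm = String.new
  previous = String.new
  counter = 1
  while okm.bytesize < length
    previous = OpenSSL::HMAC.digest("SHA256", prk, previous + info + counter.chr)
    okm << previous
    counter += 1
  end
  okm = okm.byteslice(0, length)

  abort "reference PRK mismatch" unless prk.unpack1("H*") == expected_prk
  abort "reference OKM mismatch" unless okm.unpack1("H*") == expected_okm

  # Cross-check the HKDF class (skipped when this snippet runs on its own).
  if defined?(HKDF)
    actual = HKDF.new(ikm, salt: salt, info: info).read_hex(length)
    abort "HKDF#read_hex mismatch: #{actual}" unless actual == expected_okm
  end
  puts "RFC 5869 Test Case 1: OK"
end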