# coding: utf-8
# typed: strict
# frozen_string_literal: true

require 'digest/md5'

class PDF::Reader

  # Decrypts data using the AESV2 algorithm defined in the PDF spec. Requires
  # a decryption key, which is usually generated by PDF::Reader::StandardKeyBuilder.
  #
  class AesV2SecurityHandler

    #: (String) -> void
    def initialize(key)
      @encrypt_key = key
    end

    # 7.6.2 General Encryption Algorithm
    #
    # Algorithm 1: Encryption of data using the RC4 or AES algorithms
    #
    # version == 4 and CFM == AESV2 (AES in CBC mode)
    #
    # Used to decrypt PDF streams (buf). The input must be a whole number of
    # 16-byte blocks; anything else is an error. The first 16 bytes are the
    # initialization vector, so input of exactly 16 bytes (or no bytes at all)
    # decrypts to an empty string.
    #
    # buf - a String to decrypt; it is always handled as raw bytes, whatever
    #       encoding it happens to be tagged with
    # ref - a PDF::Reader::Reference for the object to decrypt
    #
    # Returns the plaintext as a binary (ASCII-8BIT) String.
    # Raises PDF::Reader::MalformedPDFError if buf is not a multiple of 16 bytes.
    #
    #: (String, PDF::Reader::Reference) -> String
    def decrypt(buf, ref)
      unless (buf.bytesize % 16).zero?
        raise PDF::Reader::MalformedPDFError, "Ciphertext not a multiple of 16"
      end
      # Nothing (or only the IV) means there is no ciphertext to decrypt. OpenSSL
      # would reject an empty IV / empty update, so short-circuit here.
      return "".b if buf.bytesize <= 16

      begin
        internal_decrypt(buf, ref)
      rescue OpenSSL::Cipher::CipherError
        # A CipherError here is most likely a bad-padding failure, so retry with
        # padding disabled. That always "succeeds" and would return garbage if the
        # key were wrong, but the user-provided key has already been confirmed
        # before this class is used, so the plaintext can be trusted. Any
        # trailing padding bytes are left in place on this path.
        internal_decrypt(buf, ref, padding: false)
      end
    end

    private

    # Runs AES-CBC decryption of buf with the per-object key for ref.
    # The first 16 bytes of buf are the IV; the rest is ciphertext.
    # With padding: false, OpenSSL neither checks nor strips block padding.
    #
    #: (String, PDF::Reader::Reference, ?padding: bool) -> String
    def internal_decrypt(buf, ref, padding: true)
      obj_key = object_key(ref)
      # Use at most 16 bytes of the MD5 output as the AES key. obj_key is the
      # encryption key plus 9 extra bytes, so for any key of 7+ bytes this is
      # always 16 (AES-128).
      key_len = obj_key.bytesize < 16 ? obj_key.bytesize : 16
      cipher = OpenSSL::Cipher.new("AES-#{key_len * 8}-CBC")
      cipher.decrypt
      cipher.padding = 0 unless padding
      cipher.key = Digest::MD5.digest(obj_key)[0, key_len]
      # Slice by bytes, not characters, to agree with the bytesize checks in
      # #decrypt even when buf is tagged with a multibyte encoding.
      cipher.iv = buf.byteslice(0, 16)
      cipher.update(buf.byteslice(16..-1)) + cipher.final
    end

    # Builds the MD5 input for the per-object key (Algorithm 1, a and b). It is
    # the encryption key, then the low 3 bytes of the object number, then the
    # low 2 bytes of the generation number (both low-order byte first), then
    # the literal "sAlT".
    #
    #: (PDF::Reader::Reference) -> String
    def object_key(ref)
      # .b returns a binary copy, so each Integer appended below is one raw
      # byte rather than a UTF-8 encoded codepoint.
      obj_key = @encrypt_key.b
      3.times { |i| obj_key << ((ref.id >> (i * 8)) & 0xFF) }
      2.times { |i| obj_key << ((ref.gen >> (i * 8)) & 0xFF) }
      obj_key << "sAlT" # Algorithm 1, b)
    end

  end
end
