File: ecdsa.rb

package info (click to toggle)
ruby-ssh-data 1.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 184 kB
  • sloc: ruby: 1,483; makefile: 4
file content (155 lines) | stat: -rw-r--r-- 4,586 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
module SSHData
  module PublicKey
    # An SSH ECDSA public key on one of the NIST curves (nistp256, nistp384,
    # nistp521), wrapped around an OpenSSL::PKey::EC for signature verification.
    class ECDSA < Base
      attr_reader :curve, :public_key_bytes, :openssl

      NISTP256 = "nistp256".freeze
      NISTP384 = "nistp384".freeze
      NISTP521 = "nistp521".freeze

      # SSH curve name => OpenSSL curve name.
      OPENSSL_CURVE_NAME_FOR_CURVE = {
        NISTP256 => "prime256v1",
        NISTP384 => "secp384r1",
        NISTP521 => "secp521r1",
      }.freeze

      # OpenSSL curve name => SSH curve name (inverse of the table above).
      CURVE_FOR_OPENSSL_CURVE_NAME = {
        "prime256v1" => NISTP256,
        "secp384r1"  => NISTP384,
        "secp521r1"  => NISTP521,
      }.freeze

      # SSH curve name => OpenSSL digest class used to sign/verify on that curve.
      DIGEST_FOR_CURVE = {
        NISTP256 => OpenSSL::Digest::SHA256,
        NISTP384 => OpenSSL::Digest::SHA384,
        NISTP521 => OpenSSL::Digest::SHA512,
      }.freeze

      # Convert an SSH encoded ECDSA signature to DER encoding for verification with
      # OpenSSL.
      #
      # sig - A binary String signature from an SSH packet: the mpint r
      #       immediately followed by the mpint s.
      #
      # Returns a binary String signature, as expected by OpenSSL.
      # Raises DecodeError if bytes remain after r and s.
      def self.openssl_signature(sig)
        r, rlen = Encoding.decode_mpint(sig, 0)
        s, slen = Encoding.decode_mpint(sig, rlen)

        if rlen + slen != sig.bytesize
          raise DecodeError, "unexpected trailing data"
        end

        OpenSSL::ASN1::Sequence.new([
          OpenSSL::ASN1::Integer.new(r),
          OpenSSL::ASN1::Integer.new(s)
        ]).to_der
      end

      # Convert a DER encoded ECDSA signature, as generated by OpenSSL, to SSH
      # encoding.
      #
      # sig - A binary String signature, as generated by OpenSSL.
      #
      # Returns a binary String signature, as found in an SSH packet.
      # Raises DecodeError if sig is not a DER SEQUENCE of exactly two INTEGERs.
      def self.ssh_signature(sig)
        a1 = begin
          OpenSSL::ASN1.decode(sig)
        rescue OpenSSL::ASN1::ASN1Error
          # Malformed DER (or trailing bytes) would otherwise escape as an
          # OpenSSL error rather than the DecodeError used for all other bad input.
          raise DecodeError, "bad asn1 signature"
        end

        # Only constructed encodings carry an Array value; a primitive encoding
        # with the SEQUENCE tag has a String value, on which `count` would fail.
        seq = a1.value
        if a1.tag_class != :UNIVERSAL || a1.tag != OpenSSL::ASN1::SEQUENCE || !seq.is_a?(Array) || seq.count != 2
          raise DecodeError, "bad asn1 signature"
        end

        r, s = seq
        if r.tag_class != :UNIVERSAL || r.tag != OpenSSL::ASN1::INTEGER || s.tag_class != :UNIVERSAL || s.tag != OpenSSL::ASN1::INTEGER
          raise DecodeError, "bad asn1 signature"
        end

        [Encoding.encode_mpint(r.value), Encoding.encode_mpint(s.value)].join
      end

      # Validate that algo is one of the supported ECDSA algorithms and that it
      # names the given curve (i.e. algo == "ecdsa-sha2-<curve>").
      #
      # algo  - The String SSH key algorithm.
      # curve - The String SSH curve name.
      #
      # Returns nothing. Raises DecodeError on mismatch.
      def self.check_algorithm!(algo, curve)
        unless [ALGO_ECDSA256, ALGO_ECDSA384, ALGO_ECDSA521].include?(algo)
          raise DecodeError, "bad algorithm: #{algo.inspect}"
        end

        unless algo == "ecdsa-sha2-#{curve}"
          raise DecodeError, "bad curve: #{curve.inspect}"
        end
      end

      # algo       - The String SSH key algorithm (e.g. "ecdsa-sha2-nistp256").
      # curve      - The String SSH curve name; must match algo.
      # public_key - The binary String encoded EC point.
      #
      # Raises DecodeError if algo/curve are inconsistent or unsupported, or if
      # OpenSSL cannot build a key from public_key.
      def initialize(algo:, curve:, public_key:)
        self.class.check_algorithm!(algo, curve)

        @curve = curve
        @public_key_bytes = public_key

        @openssl = begin
          OpenSSL::PKey::EC.new(asn1.to_der)
        rescue ArgumentError, OpenSSL::PKey::PKeyError
          # OpenSSL::PKey::EC.new may report unusable key material either as an
          # ArgumentError or as a PKeyError subclass (e.g. ECError); normalize
          # both to DecodeError.
          raise DecodeError, "bad key data"
        end

        super(algo: algo)
      end

      # Verify an SSH signature.
      #
      # signed_data - The String message that the signature was calculated over.
      # signature   - The binary String signature with SSH encoding.
      #
      # Returns boolean.
      # Raises DecodeError if the signature's algorithm does not match this
      # key's curve or the embedded signature has trailing data.
      def verify(signed_data, signature)
        sig_algo, ssh_sig, _ = Encoding.decode_signature(signature)
        if sig_algo != "ecdsa-sha2-#{curve}"
          raise DecodeError, "bad signature algorithm: #{sig_algo.inspect}"
        end

        openssl_sig = self.class.openssl_signature(ssh_sig)
        openssl.verify(digest.new, openssl_sig, signed_data)
      end

      # RFC4253 binary encoding of the public key.
      #
      # Returns a binary String.
      def rfc4253
        Encoding.encode_fields(
          [:string, algo],
          [:string, curve],
          [:string, public_key_bytes],
        )
      end

      # Is this public key equal to another public key?
      #
      # other - Another SSHData::PublicKey::Base instance to compare with.
      #
      # Returns boolean.
      def ==(other)
        super && other.curve == curve && other.public_key_bytes == public_key_bytes
      end

      # The digest algorithm to use with this key's curve.
      #
      # Returns an OpenSSL::Digest subclass (e.g. OpenSSL::Digest::SHA256), not
      # an instance; call `.new` on it to get a digest object.
      def digest
        DIGEST_FOR_CURVE[curve]
      end

      private

      # Build the SubjectPublicKeyInfo-style ASN.1 structure that OpenSSL
      # accepts for an EC public key: the algorithm identifier
      # (id-ecPublicKey plus the curve OID) followed by the raw point bytes.
      #
      # Returns an OpenSSL::ASN1::Sequence.
      # Raises DecodeError if the curve has no OpenSSL name.
      def asn1
        name = OPENSSL_CURVE_NAME_FOR_CURVE.fetch(curve) do
          raise DecodeError, "unknown curve: #{curve.inspect}"
        end

        OpenSSL::ASN1::Sequence.new([
          OpenSSL::ASN1::Sequence.new([
            OpenSSL::ASN1::ObjectId.new("id-ecPublicKey"),
            OpenSSL::ASN1::ObjectId.new(name),
          ]),
          OpenSSL::ASN1::BitString.new(public_key_bytes),
        ])
      end
    end
  end
end