Description: trict error checking in DER decoding of integers and sequences
 This fixes CVE-2019-14853 and CVE-2019-14859.
From: Hubert Kario <hkario@redhat.com>
Origin: https://github.com/warner/python-ecdsa/pull/124
Last-Update: 2019-10-21


--- a/ecdsa/__init__.py
+++ b/ecdsa/__init__.py
@@ -1,9 +1,12 @@
 __all__ = ["curves", "der", "ecdsa", "ellipticcurve", "keys", "numbertheory",
            "test_pyecdsa", "util", "six"]
-from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError
+from .keys import SigningKey, VerifyingKey, BadSignatureError, BadDigestError,\
+        MalformedPointError
 from .curves import NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1
+from .der import UnexpectedDER
 
 _hush_pyflakes = [SigningKey, VerifyingKey, BadSignatureError, BadDigestError,
+                  MalformedPointError, UnexpectedDER,
                   NIST192p, NIST224p, NIST256p, NIST384p, NIST521p, SECP256k1]
 del _hush_pyflakes
 
--- a/ecdsa/der.py
+++ b/ecdsa/der.py
@@ -60,10 +60,15 @@
     return tag, body, rest
 
 def remove_sequence(string):
+    if not string:
+        raise UnexpectedDER("Empty string does not encode a sequence")
     if not string.startswith(b("\x30")):
-        n = string[0] if isinstance(string[0], integer_types) else ord(string[0])
-        raise UnexpectedDER("wanted sequence (0x30), got 0x%02x" % n)
+        n = string[0] if isinstance(string[0], integer_types) else \
+                ord(string[0])
+        raise UnexpectedDER("wanted type 'sequence' (0x30), got 0x%02x" % n)
     length, lengthlength = read_length(string[1:])
+    if length > len(string) - 1 - lengthlength:
+        raise UnexpectedDER("Length longer than the provided buffer")
     endseq = 1+lengthlength+length
     return string[1+lengthlength:endseq], string[endseq:]
 
@@ -96,14 +101,33 @@
     return tuple(numbers), rest
 
 def remove_integer(string):
+    if not string:
+        raise UnexpectedDER("Empty string is an invalid encoding of an "
+                            "integer")
     if not string.startswith(b("\x02")):
-        n = string[0] if isinstance(string[0], integer_types) else ord(string[0])
-        raise UnexpectedDER("wanted integer (0x02), got 0x%02x" % n)
+        n = string[0] if isinstance(string[0], integer_types) \
+                else ord(string[0])
+        raise UnexpectedDER("wanted type 'integer' (0x02), got 0x%02x" % n)
     length, llen = read_length(string[1:])
+    if length > len(string) - 1 - llen:
+        raise UnexpectedDER("Length longer than provided buffer")
+    if length == 0:
+        raise UnexpectedDER("0-byte long encoding of integer")
     numberbytes = string[1+llen:1+llen+length]
     rest = string[1+llen+length:]
-    nbytes = numberbytes[0] if isinstance(numberbytes[0], integer_types) else ord(numberbytes[0])
-    assert nbytes < 0x80 # can't support negative numbers yet
+    msb = numberbytes[0] if isinstance(numberbytes[0], integer_types) \
+            else ord(numberbytes[0])
+    if not msb < 0x80:
+        raise UnexpectedDER("Negative integers are not supported")
+    # check if the encoding is the minimal one (DER requirement)
+    if length > 1 and not msb:
+        # leading zero byte is allowed if the integer would have been
+        # considered a negative number otherwise
+        smsb = numberbytes[1] if isinstance(numberbytes[1], integer_types) \
+                else ord(numberbytes[1])
+        if smsb < 0x80:
+            raise UnexpectedDER("Invalid encoding of integer, unnecessary "
+                                "zero padding bytes")
     return int(binascii.hexlify(numberbytes), 16), rest
 
 def read_number(string):
@@ -133,6 +157,8 @@
     return int2byte(0x80|llen) + s
 
 def read_length(string):
+    if not string:
+        raise UnexpectedDER("Empty string can't encode valid length value")
     num = string[0] if isinstance(string[0], integer_types) else ord(string[0])
     if not (num & 0x80):
         # short form
@@ -140,8 +166,14 @@
     # else long-form: b0&0x7f is number of additional base256 length bytes,
     # big-endian
     llen = num & 0x7f
+    if not llen:
+        raise UnexpectedDER("Invalid length encoding, length of length is 0")
     if llen > len(string)-1:
-        raise UnexpectedDER("ran out of length bytes")
+        raise UnexpectedDER("Length of length longer than provided buffer")
+    # verify that the encoding is minimal possible (DER requirement)
+    msb = string[1] if isinstance(string[1], integer_types) else ord(string[1])
+    if not msb or llen == 1 and msb < 0x80:
+        raise UnexpectedDER("Not minimal encoding of length")
     return int(binascii.hexlify(string[1:1+llen]), 16), 1+llen
 
 def remove_bitstring(string):
--- a/ecdsa/keys.py
+++ b/ecdsa/keys.py
@@ -3,10 +3,11 @@
 from . import ecdsa
 from . import der
 from . import rfc6979
+from . import ellipticcurve
 from .curves import NIST192p, find_curve
 from .util import string_to_number, number_to_string, randrange
 from .util import sigencode_string, sigdecode_string
-from .util import oid_ecPublicKey, encoded_oid_ecPublicKey
+from .util import oid_ecPublicKey, encoded_oid_ecPublicKey, MalformedSignature
 from six import PY3, b
 from hashlib import sha1
 
@@ -15,6 +16,11 @@
 class BadDigestError(Exception):
     pass
 
+
+class MalformedPointError(AssertionError):
+    pass
+
+
 class VerifyingKey:
     def __init__(self, _error__please_use_generate=None):
         if not _error__please_use_generate:
@@ -33,17 +39,21 @@
     def from_string(klass, string, curve=NIST192p, hashfunc=sha1,
                     validate_point=True):
         order = curve.order
-        assert len(string) == curve.verifying_key_length, \
-               (len(string), curve.verifying_key_length)
+        if len(string) != curve.verifying_key_length:
+            raise MalformedPointError(
+                "Malformed encoding of public point. Expected string {0} bytes"
+                " long, received {1} bytes long string".format(
+                    curve.verifying_key_length, len(string)))
         xs = string[:curve.baselen]
         ys = string[curve.baselen:]
-        assert len(xs) == curve.baselen, (len(xs), curve.baselen)
-        assert len(ys) == curve.baselen, (len(ys), curve.baselen)
+        if len(xs) != curve.baselen:
+            raise MalformedPointError("Unexpected length of encoded x")
+        if len(ys) != curve.baselen:
+            raise MalformedPointError("Unexpected length of encoded y")
         x = string_to_number(xs)
         y = string_to_number(ys)
-        if validate_point:
-            assert ecdsa.point_is_valid(curve.generator, x, y)
-        from . import ellipticcurve
+        if validate_point and not ecdsa.point_is_valid(curve.generator, x, y):
+            raise MalformedPointError("Point does not lie on the curve")
         point = ellipticcurve.Point(curve.curve, x, y, order)
         return klass.from_public_point(point, curve, hashfunc)
 
@@ -65,13 +75,18 @@
         if empty != b(""):
             raise der.UnexpectedDER("trailing junk after DER pubkey objects: %s" %
                                     binascii.hexlify(empty))
-        assert oid_pk == oid_ecPublicKey, (oid_pk, oid_ecPublicKey)
+        if oid_pk != oid_ecPublicKey:
+            raise der.UnexpectedDER(
+                "Unexpected OID in encoding, received {0}, expected {1}"
+                .format(oid_pk, oid_ecPublicKey))
         curve = find_curve(oid_curve)
         point_str, empty = der.remove_bitstring(point_str_bitstring)
         if empty != b(""):
             raise der.UnexpectedDER("trailing junk after pubkey pointstring: %s" %
                                     binascii.hexlify(empty))
-        assert point_str.startswith(b("\x00\x04"))
+        if not point_str.startswith(b("\x00\x04")):
+            raise der.UnexpectedDER(
+                    "Unsupported or invalid encoding of pubcli key")
         return klass.from_string(point_str[2:], curve)
 
     def to_string(self):
@@ -106,11 +121,14 @@
                                  "for your digest (%d)" % (self.curve.name,
                                                            8*len(digest)))
         number = string_to_number(digest)
-        r, s = sigdecode(signature, self.pubkey.order)
+        try:
+            r, s = sigdecode(signature, self.pubkey.order)
+        except (der.UnexpectedDER, MalformedSignature) as e:
+            raise BadSignatureError("Malformed formatting of signature", e)
         sig = ecdsa.Signature(r, s)
         if self.pubkey.verifies(number, sig):
             return True
-        raise BadSignatureError
+        raise BadSignatureError("Signature verification failed")
 
 class SigningKey:
     def __init__(self, _error__please_use_generate=None):
@@ -134,7 +152,10 @@
         self.default_hashfunc = hashfunc
         self.baselen = curve.baselen
         n = curve.order
-        assert 1 <= secexp < n
+        if not 1 <= secexp < n:
+            raise MalformedPointError(
+                "Invalid value for secexp, expected integer between 1 and {0}"
+                .format(n))
         pubkey_point = curve.generator*secexp
         pubkey = ecdsa.Public_key(curve.generator, pubkey_point)
         pubkey.order = n
@@ -146,7 +167,10 @@
 
     @classmethod
     def from_string(klass, string, curve=NIST192p, hashfunc=sha1):
-        assert len(string) == curve.baselen, (len(string), curve.baselen)
+        if len(string) != curve.baselen:
+            raise MalformedPointError(
+                "Invalid length of private key, received {0}, expected {1}"
+                .format(len(string), curve.baselen))
         secexp = string_to_number(string)
         return klass.from_secret_exponent(secexp, curve, hashfunc)
 
--- a/ecdsa/util.py
+++ b/ecdsa/util.py
@@ -216,18 +216,38 @@
     return sigencode_der(r, s, order)
 
 
+class MalformedSignature(Exception):
+    pass
+
+
 def sigdecode_string(signature, order):
     l = orderlen(order)
-    assert len(signature) == 2*l, (len(signature), 2*l)
+    if not len(signature) == 2 * l:
+        raise MalformedSignature(
+                "Invalid length of signature, expected {0} bytes long, "
+                "provided string is {1} bytes long"
+                .format(2 * l, len(signature)))
     r = string_to_number_fixedlen(signature[:l], order)
     s = string_to_number_fixedlen(signature[l:], order)
     return r, s
 
 def sigdecode_strings(rs_strings, order):
+    if not len(rs_strings) == 2:
+        raise MalformedSignature(
+                "Invalid number of strings provided: {0}, expected 2"
+                .format(len(rs_strings)))
     (r_str, s_str) = rs_strings
     l = orderlen(order)
-    assert len(r_str) == l, (len(r_str), l)
-    assert len(s_str) == l, (len(s_str), l)
+    if not len(r_str) == l:
+        raise MalformedSignature(
+                "Invalid length of first string ('r' parameter), "
+                "expected {0} bytes long, provided string is {1} bytes long"
+                .format(l, len(r_str)))
+    if not len(s_str) == l:
+        raise MalformedSignature(
+                "Invalid length of second string ('s' parameter), "
+                "expected {0} bytes long, provided string is {1} bytes long"
+                .format(l, len(s_str)))
     r = string_to_number_fixedlen(r_str, order)
     s = string_to_number_fixedlen(s_str, order)
     return r, s
