1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
|
import logging
import re
import socket
import OpenSSL
import time
import sys
from datetime import datetime
from datetime import tzinfo
from datetime import timedelta
from http.client import HTTPConnection
from urllib.parse import urlparse
from cryptography import x509
# Number of seconds in one day.
SEC_PER_DAY = 60 * 60 * 24
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class MDCertUtil(object):
    """Utility class for inspecting certificates in test cases.

    Wraps a PyOpenSSL ``X509`` object
    (https://pyopenssl.org/en/stable/index.html) and provides helpers to read
    its serial number, validity period, subject CN, key length, SAN list and
    must-staple extension, and to check it against a private key.
    """

    @classmethod
    def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None):
        """Connect to a TLS server and return its certificate.

        :param host_ip: IPv4 address to connect to (an AF_INET socket is used).
        :param host_port: port; anything ``int()`` accepts.
        :param host_name: SNI host name sent during the handshake.
        :param tls: if given, the only protocol version left enabled
            (1.0, 1.1, 1.2 or 1.3); all other versions are disabled.
        :param ciphers: optional OpenSSL cipher list string.
        :return: an ``MDCertUtil`` wrapping the peer certificate.
        """
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        if tls is not None:
            # Disable every protocol version except the requested one.
            if tls != 1.0:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1)
            if tls != 1.1:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_1)
            if tls != 1.2:
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_2)
            # OP_NO_TLSv1_3 is only present when the binding supports TLSv1.3.
            if tls != 1.3 and hasattr(OpenSSL.SSL, "OP_NO_TLSv1_3"):
                ctx.set_options(OpenSSL.SSL.OP_NO_TLSv1_3)
        if ciphers is not None:
            ctx.set_cipher_list(ciphers)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            connection = OpenSSL.SSL.Connection(ctx, sock)
            connection.connect((host_ip, int(host_port)))
            connection.setblocking(1)
            connection.set_tlsext_host_name(host_name.encode('utf-8'))
            connection.do_handshake()
            peer_cert = connection.get_peer_certificate()
        finally:
            # Only the certificate is needed; always release the socket,
            # including when connect/handshake fails.
            sock.close()
        return MDCertUtil(None, cert=peer_cert)

    @classmethod
    def parse_pem_cert(cls, text):
        """Parse a PEM certificate given as ``str`` and wrap it."""
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, text.encode('utf-8'))
        return MDCertUtil(None, cert=cert)

    @classmethod
    def get_plain(cls, url, timeout):
        """GET ``url`` over plain HTTP, retrying until ``timeout`` seconds pass.

        ``timeout`` is used both as the per-connection socket timeout and as
        the overall retry deadline. Only the URL path is requested.

        :return: the response body as bytes, or None if no attempt succeeded.
        """
        server = urlparse(url)
        try_until = time.time() + timeout
        while time.time() < try_until:
            conn = None
            try:
                conn = HTTPConnection(server.hostname, server.port, timeout=timeout)
                conn.request('GET', server.path)
                return conn.getresponse().read()
            except OSError:
                # Server not (yet) reachable: retry after a short pause.
                log.debug("connect error: %s", sys.exc_info()[0])
            except Exception:
                log.error("Unexpected error: %s", sys.exc_info()[0])
            finally:
                if conn is not None:
                    conn.close()
            # Pause on every failure so the loop never spins.
            time.sleep(.1)
        log.error("Unable to contact server after %d sec", timeout)
        return None

    def __init__(self, cert_path, cert=None):
        """Load a certificate from ``cert_path`` and/or use ``cert`` directly.

        :param cert_path: file path, or a URL starting with "http" fetched via
            ``get_plain``; the data is parsed as PEM, then as DER (ASN.1).
            May be None when ``cert`` is given.
        :param cert: an already loaded certificate object; when given it
            takes precedence over anything loaded from ``cert_path``.
        :raises: the last parse error if ``cert_path`` could not be parsed
            and no ``cert`` was given; ValueError if neither was supplied.
        """
        self.cert_path = cert_path
        self.cert = None
        self.error = None
        if cert_path is not None:
            if cert_path.startswith("http"):
                cert_data = self.get_plain(cert_path, 1)
            else:
                cert_data = MDCertUtil._load_binary_file(cert_path)
            for file_type in (OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_ASN1):
                try:
                    self.cert = OpenSSL.crypto.load_certificate(file_type, cert_data)
                    break
                except Exception as error:
                    # Remember why parsing failed; try the next format.
                    self.error = error
        if cert is not None:
            self.cert = cert
        if self.cert is None:
            if self.error is not None:
                raise self.error
            raise ValueError("no certificate: neither cert_path nor cert given")

    def get_issuer(self):
        """Return the issuer name object of the certificate."""
        return self.cert.get_issuer()

    def get_serial(self):
        """Return the serial number as upper-case hex without leading zeros.

        The string representation of a serial number is not unique (some
        tools pad with leading 0s to align with word boundaries), so a
        canonical form is produced here.
        """
        return format(self.cert.get_serial_number(), "X")

    @staticmethod
    def _get_serial(cert) -> int:
        """Extract an integer serial from various certificate representations.

        Accepts a ``cryptography`` certificate, an ``MDCertUtil``, a PyOpenSSL
        ``X509``, a hex string or an int; anything else yields 0.
        """
        if isinstance(cert, x509.Certificate):
            return cert.serial_number
        if isinstance(cert, (MDCertUtil, OpenSSL.crypto.X509)):
            return cert.get_serial_number()
        if isinstance(cert, str):
            # assume a hex number
            return int(cert, 16)
        if isinstance(cert, int):
            return cert
        return 0

    def get_serial_number(self):
        """Return the certificate serial number as an int."""
        return self._get_serial(self.cert)

    def same_serial_as(self, other):
        """Return True if ``other`` (any form ``_get_serial`` accepts) has our serial."""
        return self._get_serial(self.cert) == self._get_serial(other)

    def get_not_before(self):
        """Return the notBefore time as a timezone-aware datetime."""
        return self._parse_tsp(self.cert.get_notBefore())

    def get_not_after(self):
        """Return the notAfter time as a timezone-aware datetime."""
        return self._parse_tsp(self.cert.get_notAfter())

    def get_cn(self):
        """Return the subject common name."""
        return self.cert.get_subject().CN

    def get_key_length(self):
        """Return the public key size in bits."""
        return self.cert.get_pubkey().bits()

    def get_san_list(self):
        """Return the Subject Alternative Name entries.

        Entries are taken from OpenSSL's text dump; the "DNS:" prefix is
        removed, other entries (e.g. IP addresses) keep their prefix.
        Returns an empty list when the extension is absent.
        """
        text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
        m = re.search(r"X509v3 Subject Alternative Name:(\s+critical)?\s*(.*)", text)
        if not m:
            return []

        def _strip_prefix(entry):
            entry = entry.strip()
            return entry[len("DNS:"):] if entry.startswith("DNS:") else entry

        return [_strip_prefix(entry) for entry in m.group(2).split(",")]

    def get_must_staple(self):
        """Return True if the certificate carries the TLS Feature
        (status_request, "must-staple") extension, OID 1.3.6.1.5.5.7.1.24."""
        text = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_TEXT, self.cert).decode("utf-8")
        m = re.search(r"1\.3\.6\.1\.5\.5\.7\.1\.24:\s*\n\s*0....", text)
        if not m:
            # Newer openssl versions print this differently
            m = re.search(r"TLS Feature:\s*\n\s*status_request\s*\n", text)
        return m is not None

    @classmethod
    def validate_privkey(cls, privkey_path, passphrase=None):
        """Load the PEM private key at ``privkey_path`` and return ``PKey.check()``.

        :param passphrase: optional passphrase for an encrypted key.
        """
        privkey_data = cls._load_binary_file(privkey_path)
        if passphrase:
            privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data, passphrase)
        else:
            privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
        return privkey.check()

    def validate_cert_matches_priv_key(self, privkey_path):
        """Verify that the PEM private key at ``privkey_path`` matches this cert.

        Returns nothing; ``Context.check_privatekey()`` raises on mismatch.
        """
        privkey_data = MDCertUtil._load_binary_file(privkey_path)
        privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, privkey_data)
        context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        context.use_privatekey(privkey)
        context.use_certificate(self.cert)
        context.check_privatekey()

    # --------- _utils_ ---------

    def astr(self, s):
        """Decode UTF-8 bytes to str."""
        return s.decode('utf-8')

    def _parse_tsp(self, tsp):
        """Parse a PyOpenSSL timestamp (bytes) into an aware datetime.

        The first 14 bytes are read as YYYYMMDDhhmmss. A following
        "+hhmm"/"-hhmm" suffix sets the UTC offset; any other suffix
        (e.g. "Z") is treated as offset 0.
        """
        s = ("%s-%s-%s %s:%s:%s" % (self.astr(tsp[0:4]), self.astr(tsp[4:6]), self.astr(tsp[6:8]),
                                    self.astr(tsp[8:10]), self.astr(tsp[10:12]), self.astr(tsp[12:14])))
        timestamp = datetime.strptime(s, '%Y-%m-%d %H:%M:%S')
        offset_minutes = 0
        m = re.match(r"([+\-])(\d{2})(\d{2})", self.astr(tsp[14:]))
        if m:
            # Apply the sign to hours AND minutes together, so that e.g.
            # "+0530" is +330 min and "-0030" is -30 min.
            sign = -1 if m.group(1) == "-" else 1
            offset_minutes = sign * (60 * int(m.group(2)) + int(m.group(3)))
        return timestamp.replace(tzinfo=self.FixedOffset(offset_minutes))

    @classmethod
    def _load_binary_file(cls, path):
        """Return the raw bytes of the file at ``path``."""
        with open(path, mode="rb") as file:
            return file.read()

    class FixedOffset(tzinfo):
        """A tzinfo with a constant UTC offset given in minutes and no DST."""

        def __init__(self, offset):
            self.__offset = timedelta(minutes=offset)

        def utcoffset(self, dt):
            return self.__offset

        def tzname(self, dt):
            return None

        def dst(self, dt):
            return timedelta(0)
|