1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
|
import sys
import os
import io
import difflib
import tempfile
import os.path as op
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
import pickle
import pickletools
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor
import psutil
from cloudpickle import dumps
from subprocess import TimeoutExpired
# Shorthand for the standard-library unpickler used by the helpers in this file.
loads = pickle.loads
# Default value of the ``timeout`` argument of the subprocess helpers; it is
# forwarded to ``Popen.communicate`` / ``check_output`` (i.e. seconds).
TIMEOUT = 60
# Module-level global read by the function built in make_local_function().
TEST_GLOBALS = "a test value"
def make_local_function():
    """Build and return a nested function ``g``.

    ``g`` takes one argument that it does not use, asserts that the
    module-level ``TEST_GLOBALS`` equals ``"a test value"`` and returns
    ``sum(range(10))``.
    """

    def g(x):
        # this function checks that the globals are correctly handled and that
        # the builtins are available
        assert TEST_GLOBALS == "a test value"
        return sum(range(10))

    return g
def _make_cwd_env():
    """Build the working directory and environment used by child processes.

    Returns a ``(folder, env)`` tuple. ``folder`` is the normalized parent
    of the directory containing this file. ``env`` is a copy of
    ``os.environ`` whose ``PYTHONPATH`` lists ``<folder>/tests`` first and
    ``<folder>`` second.
    """
    here = op.dirname(__file__)
    repo_root = op.normpath(op.join(here, ".."))
    tests_dir = repo_root + os.sep + "tests"
    child_env = dict(os.environ)
    child_env["PYTHONPATH"] = tests_dir + os.pathsep + repo_root
    return repo_root, child_env
def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT, add_env=None):
    """Retrieve pickle string of an object generated by a child Python process

    Pickle the input data into a buffer, send it to a subprocess via
    stdin, expect the subprocess to unpickle, re-pickle that data back
    and send it back to the parent process via stdout. The bytes written
    by the child are returned as-is (they are not unpickled here).

    Parameters
    ----------
    input_data : object
        Object pickled with ``dumps`` and sent to the child.
    protocol : int or None
        Pickle protocol used by the parent; it is also forwarded to the
        child as ``str(protocol)`` after ``--protocol``.
    timeout : float
        Passed to ``Popen.communicate``.
    add_env : dict or None
        Extra environment variables merged into the child's environment.

    Raises
    ------
    RuntimeError
        If the child exits with a non-zero status, writes anything to
        stderr, or does not finish before ``timeout``.

    >>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
    b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'
    """
    # Pickle first: if the input cannot be pickled we fail before any child
    # process has been spawned, so no child is left blocked on its stdin.
    pickle_string = dumps(input_data, protocol=protocol)

    # run then pickle_echo(protocol=protocol) in __main__:

    # Protect stderr from any warning, as we will assume an error will happen
    # if it is not empty. A concrete example is pytest using the imp module,
    # which is deprecated in python 3.8
    cmd = [sys.executable, "-W ignore", __file__, "--protocol", str(protocol)]
    cwd, env = _make_cwd_env()
    if add_env:
        env.update(add_env)
    proc = Popen(
        cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env, bufsize=4096
    )
    try:
        out, err = proc.communicate(pickle_string, timeout=timeout)
    except TimeoutExpired as e:
        # Kill the stuck child and collect whatever it produced so far so
        # that it can be reported in the error message.
        proc.kill()
        out, err = proc.communicate()
        message = "\n".join([out.decode("utf-8"), err.decode("utf-8")])
        raise RuntimeError(message) from e
    if proc.returncode != 0 or len(err):
        message = "Subprocess returned %d: " % proc.returncode
        message += err.decode("utf-8")
        raise RuntimeError(message)
    return out
def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT, add_env=None):
    """Round-trip an object through a child Python process

    The input is handed to ``subprocess_pickle_string`` (which pickles it,
    lets a child process unpickle and re-pickle it, and returns the child's
    output bytes); those bytes are then unpickled here and the resulting
    object is returned.

    >>> subprocess_pickle_echo([1, 'a', None], protocol=2)
    [1, 'a', None]
    """
    return loads(
        subprocess_pickle_string(
            input_data, protocol=protocol, timeout=timeout, add_env=add_env
        )
    )
def _read_all_bytes(stream_in, chunk_size=4096):
    """Read ``stream_in`` until end-of-file and return everything as bytes.

    Parameters
    ----------
    stream_in : binary file-like object
        Anything exposing ``read(n)`` that returns bytes.
    chunk_size : int
        Maximum number of bytes requested per ``read`` call.

    Returns
    -------
    bytes
        The concatenation of every chunk read, ``b""`` for an empty stream.
    """
    chunks = []
    while True:
        data = stream_in.read(chunk_size)
        # Only an empty read signals end-of-file: a stream is allowed to
        # return fewer bytes than requested while more data is still coming.
        if not data:
            break
        chunks.append(data)
    # Join once at the end instead of repeatedly concatenating bytes.
    return b"".join(chunks)
def pickle_echo(stream_in=None, stream_out=None, protocol=None):
    """Unpickle an object read from a stream and write it back re-pickled

    ``stream_in`` defaults to ``sys.stdin`` and ``stream_out`` to
    ``sys.stdout``. When a stream exposes a ``buffer`` attribute, that
    attribute is used instead so that bytes are read and written. Both
    streams are closed by this function.
    """
    source = sys.stdin if stream_in is None else stream_in
    sink = sys.stdout if stream_out is None else stream_out

    # Force the use of bytes streams under Python 3
    source = getattr(source, "buffer", source)
    sink = getattr(sink, "buffer", sink)

    payload = _read_all_bytes(source)
    source.close()
    sink.write(dumps(loads(payload), protocol=protocol))
    sink.close()
def call_func(payload, protocol):
    """Remote function call that uses cloudpickle to transport everything

    ``payload`` is unpickled into a ``(func, args, kwargs)`` triple and the
    function is called. Whatever it returns is pickled with ``protocol``
    and returned. If the call raises, the exception instance itself is
    pickled and returned instead of being propagated.
    """
    target, positional, named = loads(payload)
    try:
        outcome = target(*positional, **named)
    except BaseException as exc:
        # The exception object becomes the result so that it can be shipped
        # back to the caller.
        outcome = exc
    return dumps(outcome, protocol=protocol)
class _Worker:
    """Run functions synchronously in a one-process ``ProcessPoolExecutor``."""

    def __init__(self, protocol=None):
        self.protocol = protocol
        self.pool = ProcessPoolExecutor(max_workers=1)
        # Submit a trivial task and wait for it: start the worker process
        warmup = self.pool.submit(id, 42)
        warmup.result()

    def run(self, func, *args, **kwargs):
        """Synchronous remote function call

        The call is pickled with ``dumps``, executed through ``call_func``
        in the pool and the unpickled result is returned. If that result
        is an exception instance, it is raised instead.
        """
        request = dumps((func, args, kwargs), protocol=self.protocol)
        future = self.pool.submit(call_func, request, self.protocol)
        outcome = loads(future.result())
        if isinstance(outcome, BaseException):
            raise outcome
        return outcome

    def memsize(self):
        """Return the RSS reported by psutil for the worker, or 0 if none.

        Raises ``RuntimeError`` when the pool tracks more than one process.
        """
        # Entries may be process objects or bare pids; snapshot them first.
        pids = [getattr(entry, "pid", entry) for entry in tuple(self.pool._processes)]
        if not pids:
            return 0
        if len(pids) > 1:
            raise RuntimeError("Unexpected number of workers: %d" % len(pids))
        return psutil.Process(pids[0]).memory_info().rss

    def close(self):
        """Shut the pool down, waiting for it to finish."""
        self.pool.shutdown(wait=True)
@contextmanager
def subprocess_worker(protocol=None):
    """Context manager yielding a ``_Worker`` and closing it on exit.

    Parameters
    ----------
    protocol : int or None
        Pickle protocol handed to the ``_Worker``.

    Yields
    ------
    _Worker
        The freshly created worker.
    """
    worker = _Worker(protocol=protocol)
    try:
        yield worker
    finally:
        # Always shut the pool down, even when the body of the ``with``
        # statement raises; otherwise the worker process would be leaked.
        worker.close()
def assert_run_python_script(source_code, timeout=TIMEOUT):
    """Utility to help check pickleability of objects defined in __main__

    The script provided in the source code should return 0 and not print
    anything on stderr or stdout.

    Parameters
    ----------
    source_code : str
        Python source written to a temporary file and run with
        ``sys.executable``.
    timeout : float
        Passed to ``check_output``.

    Raises
    ------
    RuntimeError
        If the script exits with a non-zero status or does not finish
        before ``timeout``.
    AssertionError
        If the script exits normally but produced any output.
    """
    fd, source_file = tempfile.mkstemp(suffix="_src_test_cloudpickle.py")
    os.close(fd)
    try:
        with open(source_file, "wb") as f:
            f.write(source_code.encode("utf-8"))
        cmd = [sys.executable, "-W ignore", source_file]
        cwd, env = _make_cwd_env()
        kwargs = {
            "cwd": cwd,
            "stderr": STDOUT,  # merge stderr into the captured output
            "env": env,
            "timeout": timeout,
        }
        # If coverage is running, pass the config file to the subprocess
        coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
        if coverage_rc:
            env["COVERAGE_PROCESS_START"] = coverage_rc
        try:
            out = check_output(cmd, **kwargs)
        except CalledProcessError as e:
            raise RuntimeError(
                "script errored with output:\n%s" % e.output.decode("utf-8")
            ) from e
        except TimeoutExpired as e:
            # ``e.output`` is None when nothing was captured before the
            # timeout; fall back to empty bytes so that the RuntimeError is
            # raised instead of an AttributeError on ``None.decode``.
            partial_output = e.output or b""
            raise RuntimeError(
                "script timeout, output so far:\n%s" % partial_output.decode("utf-8")
            ) from e
        if out != b"":
            raise AssertionError(out.decode("utf-8"))
    finally:
        # Remove the temporary script whatever the outcome.
        os.unlink(source_file)
def check_deterministic_pickle(a, b):
    """Check that two pickle output are bitwise equal.

    If it is not the case, raise an ``AssertionError`` whose message holds
    the diff between the disassembled pickle payloads (truncated to 1500
    characters).

    This helper is useful to investigate non-deterministic pickling.

    Parameters
    ----------
    a, b : bytes
        Pickle payloads to compare.

    Raises
    ------
    AssertionError
        If ``a != b``.
    """
    if a == b:
        return
    # Always build the diff: raising earlier on the disassembled text would
    # produce an AssertionError without any explanation.
    full_diff = "".join(
        difflib.context_diff(_disassembled_lines(a), _disassembled_lines(b))
    )
    if len(full_diff) > 1500:
        full_diff = full_diff[:1494] + " [...]"
    raise AssertionError("Pickle payloads are not bitwise equal:\n" + full_diff)


def _disassembled_lines(payload):
    """Return the ``pickletools`` disassembly of ``payload`` as diffable lines."""
    with io.StringIO() as out:
        pickletools.dis(pickletools.optimize(payload), out)
        text = out.getvalue()
    # Remove the 11 first characters of each line to remove the bytecode offset
    # of each object, which is different on each line for very small differences,
    # making the diff very hard to read.
    # Every line keeps a trailing newline so that the lines emitted by
    # difflib do not run into each other once joined.
    return [ll[11:] + "\n" for ll in text.splitlines()]
def _parse_protocol_arg(argv):
    """Return the pickle protocol given after ``--protocol`` in ``argv``.

    ``subprocess_pickle_string`` forwards the protocol as ``str(protocol)``,
    so a ``None`` protocol arrives as the literal string ``"None"``; it is
    mapped back to ``None`` instead of being fed to ``int``. ``None`` is
    also returned when the flag or its value is missing.
    """
    try:
        raw = argv[argv.index("--protocol") + 1]
    except (ValueError, IndexError):
        return None
    return None if raw == "None" else int(raw)


if __name__ == "__main__":
    protocol = _parse_protocol_arg(sys.argv)
    pickle_echo(protocol=protocol)
|