1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]
import filecmp
import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch
from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
from torch.distributed.elastic.multiprocessing.errors.handlers import get_error_handler
def raise_exception_fn():
    """Unconditionally raise ``RuntimeError("foobar")``.

    The tests call this inside a ``try`` block so they have a real, caught
    exception to pass to ``ErrorHandler.record_exception``.
    """
    error = RuntimeError("foobar")
    raise error
class GetErrorHandlerTest(unittest.TestCase):
    """Tests for the ``get_error_handler`` factory."""

    def test_get_error_handler(self):
        # assertIsInstance names the actual type on failure, whereas
        # assertTrue(isinstance(...)) only reports "False is not true".
        self.assertIsInstance(get_error_handler(), ErrorHandler)
class ErrorHandlerTest(unittest.TestCase):
    """Tests for ``ErrorHandler``.

    Each test gets a fresh temporary directory (``self.test_dir``) and a
    default error-file path inside it (``self.test_error_file``). The error
    file the handler writes to is selected through the
    ``TORCHELASTIC_ERROR_FILE`` environment variable, which the tests set or
    remove under ``patch.dict`` so the process environment is restored when
    each ``with`` block exits.
    """

    def setUp(self):
        self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
        self.test_error_file = os.path.join(self.test_dir, "error.json")

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @patch("faulthandler.enable")
    def test_initialize(self, fh_enable_mock):
        # faulthandler.enable is mocked; only check that initialize() calls it
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    @patch("faulthandler.enable", side_effect=RuntimeError)
    def test_initialize_error(self, fh_enable_mock):
        # makes sure that initialize handles errors gracefully: the mocked
        # faulthandler.enable raises, and initialize() must not propagate it
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    def test_record_exception(self):
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh = ErrorHandler()
            eh.initialize()

            # record from inside the except block, while the exception is
            # actually being handled
            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

        with open(self.test_error_file) as fp:
            err = json.load(fp)

        # error file content example:
        # {
        #   "message": {
        #     "message": "RuntimeError: foobar",
        #     "extraInfo": {
        #       "py_callstack": "Traceback (most recent call last):\n  <... OMITTED ...>",
        #       "timestamp": "1605774851"
        #     }
        #   }
        # }
        self.assertIsNotNone(err["message"]["message"])
        self.assertIsNotNone(err["message"]["extraInfo"]["py_callstack"])
        self.assertIsNotNone(err["message"]["extraInfo"]["timestamp"])

    def test_record_exception_no_error_file(self):
        # make sure record does not fail when no error file is specified in
        # env vars. ``patch.dict(os.environ, {})`` alone would leave an
        # inherited TORCHELASTIC_ERROR_FILE in place, so remove it explicitly;
        # patch.dict restores the original environment on exit.
        with patch.dict(os.environ):
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh = ErrorHandler()
            eh.initialize()
            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

    def test_dump_error_file(self):
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("foobar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh.dump_error_file(src_error_file)
            # shallow=False compares file contents; the default shallow mode
            # can report two files equal from matching size/mtime alone
            self.assertTrue(
                filecmp.cmp(src_error_file, self.test_error_file, shallow=False)
            )

        # just validate that dump_error_file works when my error file is not
        # set -- it should just log an error with src_error_file pretty
        # printed. Remove any inherited TORCHELASTIC_ERROR_FILE first so this
        # call cannot write to a real error file outside the test directory.
        with patch.dict(os.environ):
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh.dump_error_file(src_error_file)

    def test_dump_error_file_overwrite_existing(self):
        dst_error_file = os.path.join(self.test_dir, "dst_error.json")
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        # pre-populate the destination with a different error than the source
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.record_exception(RuntimeError("foo"))
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("bar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.dump_error_file(src_error_file)
            # "foo" and "bar" have equal length, so the two records may be the
            # same size; compare contents so a stale, un-overwritten dst file
            # cannot pass on matching stat signatures alone
            self.assertTrue(
                filecmp.cmp(src_error_file, dst_error_file, shallow=False)
            )
|