1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377
|
# Owner(s): ["oncall: package/deploy"]
import importlib
import importlib.machinery
import importlib.util
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

import torch.nn

from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Covers the PackageExporter pattern actions:
    - mock()
    - extern()
    - deny()
    - intern()
    """

    def test_extern(self):
        """Externed modules should resolve to the host's own module objects."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)

        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # package_a itself was not externed, so the importer's copy is a
        # different module object, while its externed submodule is shared.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Same as test_extern, but the externed modules are selected via globs."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)

        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # package_a was saved (interned) explicitly, so only its externed
        # submodule is shared with the host interpreter.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()
        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        # Match the error text as test_deny does, so an unrelated PackagingError
        # cannot make this test pass.
        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Mocked modules import fine, but calling into them raises NotImplementedError."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Same as test_mock, but the mocked modules are selected via globs."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Saving a pickle that depends on a mocked module should raise a PackagingError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Interning package_a while mocking everything else should export without error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer: the failed
        # export above may already have written partial data into the old one.
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""

        def create_module(name):
            # Build a module object whose __file__ ends in ".so" so the exporter
            # treats it as a C extension module.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            """Importer that serves only the fake extension modules above."""

            def __init__(self):
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing C extension modules should still let us access them, especially those found in torch._C."""
        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")
if __name__ == "__main__":
    # Entry point when this file is executed directly instead of being
    # collected by an external test runner.
    run_tests()
|