1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: package/deploy"]
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf
from torch.package import PackageExporter, PackageImporter
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
class TestResources(PackageTestCase):
"""Tests for access APIs for packaged resources."""
def test_resource_reader(self):
"""Test compliance with the get_resource_reader importlib API."""
buffer = BytesIO()
with PackageExporter(buffer) as pe:
# Layout looks like:
# package
# ├── one/
# │ ├── a.txt
# │ ├── b.txt
# │ ├── c.txt
# │ └── three/
# │ ├── d.txt
# │ └── e.txt
# └── two/
# ├── f.txt
# └── g.txt
pe.save_text("one", "a.txt", "hello, a!")
pe.save_text("one", "b.txt", "hello, b!")
pe.save_text("one", "c.txt", "hello, c!")
pe.save_text("one.three", "d.txt", "hello, d!")
pe.save_text("one.three", "e.txt", "hello, e!")
pe.save_text("two", "f.txt", "hello, f!")
pe.save_text("two", "g.txt", "hello, g!")
buffer.seek(0)
importer = PackageImporter(buffer)
reader_one = importer.get_resource_reader("one")
with self.assertRaises(FileNotFoundError):
reader_one.resource_path("a.txt")
self.assertTrue(reader_one.is_resource("a.txt"))
self.assertEqual(reader_one.open_resource("a.txt").getbuffer(), b"hello, a!")
self.assertFalse(reader_one.is_resource("three"))
reader_one_contents = list(reader_one.contents())
self.assertSequenceEqual(
reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
)
reader_two = importer.get_resource_reader("two")
self.assertTrue(reader_two.is_resource("f.txt"))
self.assertEqual(reader_two.open_resource("f.txt").getbuffer(), b"hello, f!")
reader_two_contents = list(reader_two.contents())
self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])
reader_one_three = importer.get_resource_reader("one.three")
self.assertTrue(reader_one_three.is_resource("d.txt"))
self.assertEqual(
reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
)
reader_one_three_contenst = list(reader_one_three.contents())
self.assertSequenceEqual(reader_one_three_contenst, ["d.txt", "e.txt"])
self.assertIsNone(importer.get_resource_reader("nonexistent_package"))
def test_package_resource_access(self):
"""Packaged modules should be able to use the importlib.resources API to access
resources saved in the package.
"""
mod_src = dedent(
"""\
import importlib.resources
import my_cool_resources
def secret_message():
return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
"""
)
buffer = BytesIO()
with PackageExporter(buffer) as pe:
pe.save_source_string("foo.bar", mod_src)
pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")
buffer.seek(0)
importer = PackageImporter(buffer)
self.assertEqual(
importer.import_module("foo.bar").secret_message(), "my sekrit plays"
)
def test_importer_access(self):
buffer = BytesIO()
with PackageExporter(buffer) as he:
he.save_text("main", "main", "my string")
he.save_binary("main", "main_binary", "my string".encode("utf-8"))
src = dedent(
"""\
import importlib
import torch_package_importer as resources
t = resources.load_text('main', 'main')
b = resources.load_binary('main', 'main_binary')
"""
)
he.save_source_string("main", src, is_package=True)
buffer.seek(0)
hi = PackageImporter(buffer)
m = hi.import_module("main")
self.assertEqual(m.t, "my string")
self.assertEqual(m.b, "my string".encode("utf-8"))
def test_resource_access_by_path(self):
"""
Tests that packaged code can used importlib.resources.path.
"""
buffer = BytesIO()
with PackageExporter(buffer) as he:
he.save_binary("string_module", "my_string", "my string".encode("utf-8"))
src = dedent(
"""\
import importlib.resources
import string_module
with importlib.resources.path(string_module, 'my_string') as path:
with open(path, mode='r', encoding='utf-8') as f:
s = f.read()
"""
)
he.save_source_string("main", src, is_package=True)
buffer.seek(0)
hi = PackageImporter(buffer)
m = hi.import_module("main")
self.assertEqual(m.s, "my string")
if __name__ == "__main__":
run_tests()
|