File: test_license.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (50 lines) | stat: -rw-r--r-- 1,825 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Owner(s): ["module: unknown"]

import glob
import io
import os
import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests


# create_bundled comes from the source tree's third_party package; when that
# import fails the name is set to None, which the tests below use as the
# "not in a source tree" signal.
try:
    from third_party.build_bundled import create_bundled
except ImportError:
    create_bundled = None

# Relative path of the bundled-licenses file (resolved against the current
# working directory when opened).
license_file = 'third_party/LICENSES_BUNDLED.txt'
# Text that the installed LICENSE file is expected to contain.
starting_txt = 'The Pytorch repository and source distributions bundle'
# Directory two levels above torch/__init__.py, i.e. the parent of the torch package.
site_packages = os.path.dirname(os.path.dirname(torch.__file__))
# All "torch-*dist-info" entries found next to the torch package (may be empty).
distinfo = glob.glob(os.path.join(site_packages, 'torch-*dist-info'))

class TestLicense(TestCase):
    """Tests for the bundled third-party license text.

    One test compares the checked-in bundle file with a freshly generated one;
    the other checks the LICENSE file of an installed torch distribution.
    """

    @unittest.skipIf(not create_bundled, "can only be run in a source tree")
    def test_license_for_wheel(self):
        """Check that ``license_file`` matches what ``create_bundled`` generates.

        The bundle is regenerated into an in-memory buffer and compared with
        the text read from ``license_file``.

        Raises:
            AssertionError: if the two texts differ.
        """
        current = io.StringIO()
        create_bundled('third_party', current)
        with open(license_file) as fid:
            src_tree = fid.read()
        if src_tree != current.getvalue():
            raise AssertionError(
                f'the contents of "{license_file}" do not '
                'match the current state of the third_party files. Use '
                '"python third_party/build_bundled.py" to regenerate it')

    @unittest.skipIf(not distinfo, "no installation in site-package to test")
    def test_distinfo_license(self):
        """Check that the installed LICENSE contains the third-party bundle.

        If run when pytorch is installed via a wheel, the license will be in
        site-packages/torch-*dist-info/LICENSE (or, for wheels built with
        PEP 639 support, site-packages/torch-*dist-info/licenses/LICENSE).
        Make sure it contains the third party bundle of licenses.

        Raises:
            AssertionError: if more than one dist-info directory is found or
                the license text does not contain ``starting_txt``.
            FileNotFoundError: if no LICENSE file exists in either location.
        """
        if len(distinfo) > 1:
            raise AssertionError('Found too many "torch-*dist-info" directories '
                                 f'in "{site_packages}", expected only one')

        # Legacy location first, then the PEP 639 "licenses/" subdirectory.
        candidates = (
            os.path.join(distinfo[0], 'LICENSE'),
            os.path.join(distinfo[0], 'licenses', 'LICENSE'),
        )
        # If neither exists, fall back to the legacy path so that open() still
        # raises FileNotFoundError for the same path it always did.
        license_path = next(
            (path for path in candidates if os.path.exists(path)),
            candidates[0],
        )
        with open(license_path) as fid:
            txt = fid.read()
        # assertIn reports both operands on failure, unlike assertTrue(a in b).
        self.assertIn(starting_txt, txt)

# When executed as a script, hand control to run_tests (imported from
# torch.testing._internal.common_utils).
if __name__ == '__main__':
    run_tests()