File: test_draw.py

package info (click to toggle)
caffe-contrib 1.0.0+git20180821.99bd997-2
  • links: PTS, VCS
  • area: contrib
  • in suites: buster
  • size: 16,244 kB
  • sloc: cpp: 61,579; python: 5,783; makefile: 586; sh: 562
file content (37 lines) | stat: -rw-r--r-- 1,114 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import unittest

from google.protobuf import text_format

import caffe.draw
from caffe.proto import caffe_pb2

def getFilenames(root_dir=None, subdirs=('models', 'examples')):
    """Yield the paths of Net prototxt files found in the source tree.

    Recursively walks each directory named in *subdirs* under *root_dir*
    and yields the absolute path of every ``*.prototxt`` file whose file
    name does not contain ``'solver'`` (solver prototxts are excluded; only
    Net prototxts are wanted).

    Args:
        root_dir: Directory that contains *subdirs*. Defaults to the
            directory three levels above this file (expected to be the
            source tree root).
        subdirs: Names of the directories under *root_dir* to search.

    Yields:
        Absolute paths (str) of candidate Net prototxt files.

    Raises:
        AssertionError: if *root_dir* or one of *subdirs* does not exist.
            Because this is a generator, the check happens on iteration,
            not when the function is called.
    """
    if root_dir is None:
        root_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
    root_dir = os.path.abspath(root_dir)
    assert os.path.exists(root_dir), root_dir

    for subdir in subdirs:
        search_dir = os.path.join(root_dir, subdir)
        assert os.path.exists(search_dir), search_dir
        for cwd, _, filenames in os.walk(search_dir):
            for filename in filenames:
                # Filter on the base name only: matching 'solver' against the
                # full path would silently skip every file whenever a parent
                # directory (e.g. the checkout location) contains that word.
                if filename.endswith('.prototxt') and 'solver' not in filename:
                    yield os.path.join(cwd, filename)


class TestDraw(unittest.TestCase):
    """Smoke test: every Net prototxt in the source tree can be drawn."""

    def test_draw_net(self):
        """Parse each Net prototxt and render it with ``caffe.draw.draw_net``.

        Fails if any file fails to parse or draw. It also fails if no
        prototxt files were found at all, so the test cannot pass
        vacuously.
        """
        num_drawn = 0
        for filename in getFilenames():
            net = caffe_pb2.NetParameter()
            with open(filename) as infile:
                text_format.Merge(infile.read(), net)
            # NOTE(review): 'LR' presumably selects a left-to-right layout;
            # see caffe.draw.draw_net to confirm.
            caffe.draw.draw_net(net, 'LR')
            num_drawn += 1
        self.assertGreater(
            num_drawn, 0,
            'no Net prototxt files found under models/ or examples/')


if __name__ == '__main__':
    # Running this file directly executes the test cases defined above.
    unittest.main()