File: test_net_spec.py

package info (click to toggle)
caffe-contrib 1.0.0+git20180821.99bd997-2
  • links: PTS, VCS
  • area: contrib
  • in suites: buster
  • size: 16,244 kB
  • sloc: cpp: 61,579; python: 5,783; makefile: 586; sh: 562
file content (89 lines) | stat: -rw-r--r-- 3,756 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import tempfile
import unittest

import caffe
from caffe import layers as L
from caffe import params as P

def lenet(batch_size):
    """Build LeNet as a ``NetSpec`` so every layer and blob has an explicit name.

    Inputs come from a two-top ``DummyData`` layer: an image blob of shape
    ``(batch_size, 1, 28, 28)`` and a label blob of shape
    ``(batch_size, 1, 1, 1)``.

    Args:
        batch_size: leading dimension of both dummy input blobs.

    Returns:
        The net definition produced by ``NetSpec.to_proto()``.
    """
    image_shape = {'dim': [batch_size, 1, 28, 28]}
    label_shape = {'dim': [batch_size, 1, 1, 1]}

    def xavier():
        # A fresh filler spec for each layer that needs one.
        return {'type': 'xavier'}

    def max_pool(bottom):
        # 2x2, stride-2 max pooling, as used after both conv layers.
        return L.Pooling(bottom, kernel_size=2, stride=2, pool=P.Pooling.MAX)

    spec = caffe.NetSpec()
    spec.data, spec.label = L.DummyData(
        shape=[image_shape, label_shape],
        transform_param={'scale': 1.0 / 255},
        ntop=2,
    )
    spec.conv1 = L.Convolution(spec.data, kernel_size=5, num_output=20,
                               weight_filler=xavier())
    spec.pool1 = max_pool(spec.conv1)
    spec.conv2 = L.Convolution(spec.pool1, kernel_size=5, num_output=50,
                               weight_filler=xavier())
    spec.pool2 = max_pool(spec.conv2)
    spec.ip1 = L.InnerProduct(spec.pool2, num_output=500,
                              weight_filler=xavier())
    # In-place ReLU: the test suite checks its bottom and top coincide.
    spec.relu1 = L.ReLU(spec.ip1, in_place=True)
    spec.ip2 = L.InnerProduct(spec.relu1, num_output=10,
                              weight_filler=xavier())
    spec.loss = L.SoftmaxWithLoss(spec.ip2, spec.label)
    return spec.to_proto()

def anon_lenet(batch_size):
    """Build the same LeNet graph as ``lenet`` without a ``NetSpec``.

    No layer or blob is given an explicit name, so every name in the
    resulting proto is auto-generated by ``to_proto()``.

    Args:
        batch_size: leading dimension of both dummy input blobs.

    Returns:
        The net definition produced by calling ``to_proto()`` on the loss top.
    """
    def xavier():
        # A fresh filler spec for each layer that needs one.
        return {'type': 'xavier'}

    def max_pool(bottom):
        # 2x2, stride-2 max pooling, as used after both conv layers.
        return L.Pooling(bottom, kernel_size=2, stride=2, pool=P.Pooling.MAX)

    images, labels = L.DummyData(
        shape=[{'dim': [batch_size, 1, 28, 28]},
               {'dim': [batch_size, 1, 1, 1]}],
        transform_param={'scale': 1.0 / 255},
        ntop=2,
    )
    features = max_pool(L.Convolution(images, kernel_size=5, num_output=20,
                                      weight_filler=xavier()))
    features = max_pool(L.Convolution(features, kernel_size=5, num_output=50,
                                      weight_filler=xavier()))
    hidden = L.InnerProduct(features, num_output=500, weight_filler=xavier())
    # In-place ReLU: the test suite checks its bottom and top coincide.
    hidden = L.ReLU(hidden, in_place=True)
    scores = L.InnerProduct(hidden, num_output=10, weight_filler=xavier())
    return L.SoftmaxWithLoss(scores, labels).to_proto()

def silent_net():
    """Build a net whose two dummy blobs are each consumed by a Silence layer.

    Both ``Silence`` layers are created with ``ntop=0``, so they contribute
    no tops of their own.

    Returns:
        The net definition produced by ``NetSpec.to_proto()``.
    """
    first, second = L.DummyData(shape={'dim': 3}, ntop=2)
    spec = caffe.NetSpec()
    spec.data = first
    spec.data2 = second
    spec.silence_data = L.Silence(first, ntop=0)
    spec.silence_data2 = L.Silence(second, ntop=0)
    return spec.to_proto()

class TestNetSpec(unittest.TestCase):
    """Build nets with pycaffe's net_spec API and check Caffe can load them."""

    def load_net(self, net_proto):
        """Write ``net_proto`` to a temporary prototxt and load it as a net.

        Args:
            net_proto: a net definition whose ``str()`` is prototxt text,
                e.g. the result of ``NetSpec.to_proto()``.

        Returns:
            A ``caffe.Net`` constructed in the ``caffe.TEST`` phase.
        """
        # delete=False keeps the file on disk after the ``with`` block closes
        # (and flushes) it, so caffe can open it by name.
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            f.write(str(net_proto))
        try:
            # caffe.Net parses the prototxt during construction, so the file
            # is no longer needed once this call returns.
            return caffe.Net(f.name, caffe.TEST)
        finally:
            # delete=False means tempfile won't clean up; do it here so each
            # call doesn't leave a stray file behind.
            os.remove(f.name)

    def test_lenet(self):
        """Construct and build the Caffe version of LeNet."""

        net_proto = lenet(50)
        # check that relu (layer 6) is in-place: its bottom and top coincide
        self.assertEqual(list(net_proto.layer[6].bottom),
                         list(net_proto.layer[6].top))
        net = self.load_net(net_proto)
        # check that all layers are present
        self.assertEqual(len(net.layers), 9)

        # now check the version with automatically-generated layer names
        net_proto = anon_lenet(50)
        self.assertEqual(list(net_proto.layer[6].bottom),
                         list(net_proto.layer[6].top))
        net = self.load_net(net_proto)
        self.assertEqual(len(net.layers), 9)

    def test_zero_tops(self):
        """Test net construction for top-less layers."""

        net_proto = silent_net()
        net = self.load_net(net_proto)
        # every blob is consumed by a Silence layer, so forward() returns
        # no outputs
        self.assertEqual(len(net.forward()), 0)

    def test_type_error(self):
        """Test that a TypeError is raised when a Function input isn't a Top."""
        data = L.DummyData(ntop=2)  # data is a 2-tuple of Tops
        # The alternation accepts both "<type 'tuple'>" and "<class 'tuple'>"
        # reprs of the tuple type.
        r = r"^Silence input 0 is not a Top \(type is <(type|class) 'tuple'>\)$"
        # assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
        # use it only where assertRaisesRegex is unavailable.
        assert_raises_regex = (getattr(self, 'assertRaisesRegex', None)
                               or self.assertRaisesRegexp)
        with assert_raises_regex(TypeError, r):
            L.Silence(data, ntop=0)  # should raise: data is a tuple, not a Top
        L.Silence(*data, ntop=0)  # shouldn't raise: each elt of data is a Top