File: test_io.py

package info (click to toggle)
caffe-contrib 1.0.0+git20180821.99bd997-2
  • links: PTS, VCS
  • area: contrib
  • in suites: buster
  • size: 16,244 kB
  • sloc: cpp: 61,579; python: 5,783; makefile: 586; sh: 562
file content (56 lines) | stat: -rw-r--r-- 1,694 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
import unittest

import caffe

class TestBlobProtoToArray(unittest.TestCase):
    """Tests for ``caffe.io.blobproto_to_array``.

    Covers a blob described by the legacy num/channels/height/width fields,
    a blob described by ``shape.dim``, a multi-element blob with no shape
    information (expected to raise ``ValueError``), and a single-element
    blob with no shape information.
    """

    @staticmethod
    def _make_blob(data):
        """Return a new ``BlobProto`` whose ``data`` holds ``data.flatten()``."""
        blob = caffe.proto.caffe_pb2.BlobProto()
        blob.data.extend(list(data.flatten()))
        return blob

    def test_old_format(self):
        # Distinct small integers (rather than all zeros) so that a wrong
        # element order is detected, not just a wrong shape.
        data = np.arange(100, dtype=np.float64).reshape(10, 10)
        blob = self._make_blob(data)
        shape = (1, 1, 10, 10)
        blob.num, blob.channels, blob.height, blob.width = shape

        arr = caffe.io.blobproto_to_array(blob)
        self.assertEqual(arr.shape, shape)
        np.testing.assert_array_equal(arr.reshape(data.shape), data)

    def test_new_format(self):
        # Distinct values so the content check is meaningful.
        data = np.arange(100, dtype=np.float64).reshape(10, 10)
        blob = self._make_blob(data)
        blob.shape.dim.extend(list(data.shape))

        arr = caffe.io.blobproto_to_array(blob)
        self.assertEqual(arr.shape, data.shape)
        np.testing.assert_array_equal(arr, data)

    def test_no_shape(self):
        # With no shape fields set, a 100-element blob must be rejected.
        data = np.zeros((10, 10))
        blob = self._make_blob(data)

        with self.assertRaises(ValueError):
            caffe.io.blobproto_to_array(blob)

    def test_scalar(self):
        # ``(1,)`` is the one-element shape tuple; ``(1)`` would be a bare int.
        data = np.ones((1,)) * 123
        blob = self._make_blob(data)

        arr = caffe.io.blobproto_to_array(blob)
        # A single value with no shape information should come back as a
        # 0-d array holding that value.
        self.assertEqual(arr.shape, ())
        self.assertEqual(arr, 123)


class TestArrayToDatum(unittest.TestCase):
    """Tests for ``caffe.io.array_to_datum``."""

    def test_label_none_size(self):
        """Omitting ``label`` must give a strictly smaller serialized Datum."""
        def serialized_size(datum):
            # Number of bytes in the protobuf wire encoding of ``datum``.
            return len(datum.SerializeToString())

        labelled = caffe.io.array_to_datum(np.ones((10, 10, 3)), label=1)
        unlabelled = caffe.io.array_to_datum(np.ones((10, 10, 3)))

        # Leaving the label unset should produce a smaller serialized message.
        labelled_size = serialized_size(labelled)
        unlabelled_size = serialized_size(unlabelled)
        self.assertGreater(labelled_size, unlabelled_size)