File: network_tf2.py

package info (click to toggle)
bart 0.9.00-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 9,040 kB
  • sloc: ansic: 116,116; python: 1,329; sh: 726; makefile: 639; javascript: 589; cpp: 106
file content (89 lines) | stat: -rw-r--r-- 3,110 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import bart_tf

def real_from_complex_weights(wgh):
    """Expand a complex convolution kernel into an equivalent real kernel.

    Args:
        wgh: rank-6 tensor laid out as
            [depth, height, width, in_channels, out_channels, 2]. Index 0
            of the last axis is taken as the real part and index 1 as the
            imaginary part.

    Returns:
        A rank-5 tensor of shape
        [depth, height, width, 2 * in_channels, 2 * out_channels]. On both
        channel axes the real/imaginary component is the fastest-varying
        index, so each (in, out) pair becomes the 2x2 block
        [[re, im], [-im, re]] (rows: input component, columns: output
        component), i.e. a complex multiplication written with real numbers.

    Raises:
        ValueError: if ``wgh.shape`` does not unpack into exactly six dims.
    """
    import tensorflow as tf

    depth, height, width, n_in, n_out, _ = wgh.shape
    part_size = [depth, height, width, n_in, n_out, 1]

    # Separate the two components stored along the trailing axis.
    real = tf.slice(wgh, begin=[0, 0, 0, 0, 0, 0], size=part_size)
    imag = tf.slice(wgh, begin=[0, 0, 0, 0, 0, 1], size=part_size)

    # Insert a singleton axis after in_channels so the input-side component
    # can be concatenated there (axis 4); axis 6 is the output-side component.
    expanded = [depth, height, width, n_in, 1, n_out, 1]
    real = tf.reshape(real, expanded)
    imag = tf.reshape(imag, expanded)

    from_real_input = tf.concat([real, imag], 6)
    from_imag_input = tf.concat([-imag, real], 6)
    blocks = tf.concat([from_real_input, from_imag_input], 4)

    # Fold the component axes into their neighbouring channel axes.
    return tf.reshape(blocks, [depth, height, width, 2 * n_in, 2 * n_out])


def tf2_generate_resnet(path, model):
    """Build a small complex-valued residual block and export it via bart_tf.

    Args:
        path: directory prefix of the export target.
        model: name appended to ``path`` (joined with "/") to form the
            target handed to ``bart_tf.tf2_export_module``.
    """

    import tensorflow as tf
    import numpy as np

    class ComplexConv3D(tf.Module):
        """3D convolution on complex data stored as a trailing (re, im) axis.

        The complex kernel is created lazily on the first call, because the
        number of input channels is read from the input itself.
        """

        def __init__(self, filters, kernel_size, dummy_dim = False):
            """Store the layer configuration.

            Args:
                filters: number of complex output channels.
                kernel_size: sequence of three spatial kernel extents.
                dummy_dim: if true, the stored weight carries an extra
                    leading axis of size 1 that is dropped before use.
            """
            super().__init__()
            self.filters = filters
            self.kernel_size = kernel_size
            self.is_built = False
            self.dummy_dim = dummy_dim

        def __call__(self, input):
            """Convolve ``input`` (shape [..., channels, 2]) with the kernel.

            Returns a tensor shaped like the convolution output with its
            channel axis split back into (filters, 2).
            """

            if not self.is_built:
                n_in = input.shape[-2]

                kernel_shape = list(self.kernel_size) + [n_in, self.filters, 2]
                if self.dummy_dim:
                    kernel_shape = [1] + kernel_shape

                # NOTE(review): the in-channel count is added outside the
                # product; kept exactly as found -- confirm whether a
                # Glorot-style fan_in + fan_out was intended.
                fan = np.prod(self.kernel_size) * self.filters + n_in
                stddev = np.sqrt(1 / fan)

                self.conv_weight = tf.Variable(tf.random.normal(kernel_shape, stddev=stddev), name='w')
                self.is_built = True

            kernel = self.conv_weight
            if self.dummy_dim:
                # Drop the leading size-1 axis of the stored weight.
                kernel = tf.reshape(kernel, kernel.shape[1:])
            kernel = real_from_complex_weights(kernel)

            # Merge (channels, 2) into one real channel axis for conv3d.
            in_dims = tf.shape(input)
            merged_dims = tf.concat([in_dims[:-2], [2 *in_dims[-2]]], 0)

            merged = tf.reshape(input, merged_dims)
            convolved = tf.nn.conv3d(merged, kernel, [1] * 5, "SAME")

            # Split the real channel axis back into (channels, 2).
            out_dims = tf.shape(convolved)
            split_dims = tf.concat([out_dims[:-1], [out_dims[-1] // 2, 2]], 0)

            return tf.reshape(convolved, split_dims)

    class ResBlock(tf.Module):
        """Three complex convolutions with ReLUs in between and a skip connection."""

        def __init__(self):
            super().__init__()
            self.conv1 = ComplexConv3D(8, (1, 3, 3))
            # The second layer stores its weight with a leading dummy axis.
            self.conv2 = ComplexConv3D(8, (1, 3, 3), dummy_dim=True)
            self.conv3 = ComplexConv3D(1, (1, 3, 3))

        def __call__(self, input):
            """Return ``input`` plus the output of the convolution stack."""

            # Insert a singleton channel axis in front of the last axis.
            in_dims = tf.shape(input)
            with_channel = tf.concat([in_dims[:-1], tf.constant([1]), in_dims[-1:]], 0)

            hidden = tf.reshape(input, with_channel)
            hidden = tf.nn.relu(self.conv1(hidden))
            hidden = tf.nn.relu(self.conv2(hidden))
            hidden = self.conv3(hidden)

            # conv3 has a single filter, so the result folds back onto the
            # input's shape for the residual sum.
            return input + tf.reshape(hidden, tf.shape(input))

    # [32, 32, 1] is presumably the input dimensions used for tracing --
    # see bart_tf.tf2_export_module to confirm.
    export_target = path + "/" + model
    bart_tf.tf2_export_module(ResBlock(), [32, 32, 1], export_target, trace_complex=False)

# Script entry: export the residual block as "tf2_resnet" under "./".
tf2_generate_resnet(path="./", model="tf2_resnet")