File: inventory.py

package info (click to toggle)
pyinfra 0.2.2+git20161227.ec708ef-1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 11,804 kB
  • ctags: 677
  • sloc: python: 5,944; sh: 71; makefile: 11
file content (196 lines) | stat: -rw-r--r-- 5,644 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# pyinfra
# File: pyinfra/api/inventory.py
# Desc: represents a pyinfra inventory

import six

from pyinfra import logger

from .host import Host
from .attrs import AttrData
from .exceptions import NoHostError, NoGroupError


class Inventory(object):
    '''
    Represents a collection of target hosts. Stores and provides access to group data,
    host data and default data for these hosts.

    Args:
        names_data: tuple of ``(names, data)``; each entry of ``names`` is either a
            hostname or a ``(hostname, host_data)`` tuple
        ssh_user: default SSH user
        ssh_port: default SSH port
        ssh_key: default SSH key filename
        ssh_key_password: default password for the SSH key
        ssh_password: default SSH password
        **groups: map of group names -> ``(names, data)``; group names may also be
            ``(hostname, host_data)`` tuples, whose data is merged into that host's data
    '''

    # Not assigned by __init__; presumably attached by the deploy state later.
    # When it exposes a non-empty ``active_hosts``, ``__iter__`` is restricted to it.
    state = None

    def __init__(
        self, names_data,
        ssh_user=None, ssh_port=None, ssh_key=None,
        ssh_key_password=None, ssh_password=None, **groups
    ):
        names, data = names_data

        self.connected_hosts = set()
        self.groups = {}
        self.host_data = {}
        self.group_data = {}

        # In CLI mode these are --user, --key, etc
        override_data = {
            'ssh_user': ssh_user,
            'ssh_key': ssh_key,
            'ssh_key_password': ssh_key_password,
            'ssh_port': ssh_port,
            'ssh_password': ssh_password
        }
        # Strip None values so only explicitly-given overrides are kept
        override_data = {
            key: value
            for key, value in six.iteritems(override_data)
            if value is not None
        }

        self.override_data = AttrData(override_data)

        self.data = AttrData(data)

        # Build host data. Per-host dicts are copied so that merging group-level
        # host data below never mutates the caller's inventory definitions.
        for name in names:
            if isinstance(name, tuple):
                self.host_data[name[0]] = dict(name[1])
            else:
                self.host_data[name] = {}

        # Loop groups and build map of name -> groups
        names_to_groups = {}
        for group_name, (group_names, group_data) in six.iteritems(groups):
            self.groups[group_name] = []
            self.group_data[group_name] = AttrData(group_data)

            for name in group_names:
                # Extract any data; it is merged into this inventory's own copy of
                # the host's data (a fresh dict if the host has none yet)
                if isinstance(name, tuple):
                    self.host_data.setdefault(name[0], {}).update(name[1])
                    name = name[0]

                names_to_groups.setdefault(name, []).append(group_name)

        # Now we've got host data, convert -> AttrData
        self.host_data = {
            name: AttrData(d)
            for name, d in six.iteritems(self.host_data)
        }

        # Actually make Host instances. Only names listed in ``names_data`` become
        # hosts; a name that appears only in a group gets host data but no Host.
        hosts = {}
        for name in names:
            name = name[0] if isinstance(name, tuple) else name

            # Create the Host (groups is None when the host is in no group)
            host = Host(self, name, names_to_groups.get(name))
            hosts[name] = host

            # Push into any groups
            for group_name in names_to_groups.get(name, []):
                self.groups[group_name].append(host)

        self.hosts = hosts

    def __getitem__(self, key):
        '''
        Get individual hosts from the inventory by name.

        Raises:
            NoHostError: when no host is named ``key``.
        '''

        if key in self.hosts:
            return self.hosts[key]

        raise NoHostError('No such host: {0}'.format(key))

    def __getattr__(self, key):
        '''
        Get groups (lists of hosts) from the inventory by name.

        Raises:
            AttributeError: for dunder names, or when ``groups`` has not been set up
                yet (e.g. an instance created by ``copy``/``pickle`` before its state
                is restored).
            NoGroupError: when no group matches ``key``.
        '''

        # Read ``groups`` straight from the instance dict: going through
        # ``self.groups`` while it is missing would re-enter __getattr__ and recurse
        # until RecursionError.
        groups = self.__dict__.get('groups')
        if groups is None:
            raise AttributeError(key)

        if key in groups:
            return groups[key]

        # TODO: remove at some point
        # COMPAT: this provides compatibility with 0.1 where inventory group names _had_
        # to be defined in caps, but names were lowered before being added to the inventory.
        # Now group names in caps will be left as-is, so check for that too.
        if key.upper() in groups:
            logger.warning(
                'Accessing groups defined in CAPS with lowercase is deprecated '
                'and will be removed in 0.3, please use the name as-is'
            )
            return groups[key.upper()]

        # Protocol probes (``__getstate__``, ``__setstate__``, ...) must get an
        # AttributeError so hasattr/getattr-with-default, copy and pickle work.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)

        raise NoGroupError('No such group: {0}'.format(key))

    def __len__(self):
        '''
        Returns the number of hosts in the inventory, connected or not.
        '''

        return len(self.hosts)

    def __iter__(self):
        '''
        Iterates over inventory hosts. Uses active hosts only when they exist - in that
        sense can be seen as the "active" list of hosts during a deploy.
        '''

        for host in self.hosts.values():
            if not self.state or not self.state.active_hosts:
                yield host

            elif host.name in self.state.active_hosts:
                yield host

    def get_data(self):
        '''
        Get the base/all data attached to this inventory.
        '''

        return self.data

    def get_override_data(self):
        '''
        Get override data for this inventory.
        '''

        return self.override_data

    def get_host_data(self, hostname):
        '''
        Get data for a single host in this inventory.

        Raises:
            KeyError: when ``hostname`` has no data in this inventory.
        '''

        return self.host_data[hostname]

    def get_group_data(self, group):
        '''
        Get data for a single group in this inventory. Returns an empty plain dict
        for unknown groups.
        '''

        return self.group_data.get(group, {})

    def get_groups_data(self, groups):
        '''
        Gets aggregated data from a list of groups. Vars are collected in order so, for
        any groups which define the same var twice, the last group's value will hold.
        Unknown groups contribute no data.
        '''

        data = {}

        for group in groups:
            # Look up the stored AttrData directly: the ``{}`` fallback of
            # get_group_data has no ``.dict()`` method and would raise here.
            group_data = self.group_data.get(group)
            if group_data is not None:
                data.update(group_data.dict())

        return AttrData(data)