1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
|
#!/usr/bin/python3
from __future__ import print_function
import logging
import socket
import sys
import time
import xml.etree.ElementTree as ET
import libvirt
# Python 2 has no built-in TimeoutError, so alias socket.timeout there; on
# Python 3 use the real TimeoutError. Callers can always catch
# ``TimeoutException`` regardless of interpreter version.
TimeoutException = TimeoutError if sys.version_info[0] >= 3 else socket.timeout

# Attach a do-nothing handler to this module's named logger (the usual
# pattern for library modules). Note that the functions below log through
# the root-level ``logging.debug``/``logging.info`` helpers, not through this
# named logger.
logging.getLogger(__name__).addHandler(logging.NullHandler())
class SnapshotStack(object):
    """LIFO stack of libvirt snapshots taken of a single domain.

    Snapshots made through :meth:`create` are pushed onto ``snapshot_stack``.
    They can later be reverted to (and either deleted or kept for reuse),
    deleted without reverting, or all unwound at once with :meth:`clear`.
    """

    SNAPSHOT_BASE = ("<domainsnapshot>"
                     " <name>{name}</name>"
                     " <description>"
                     " Full snapshot by Automatus"
                     " </description>"
                     "</domainsnapshot>")
    CREATE_FLAGS = libvirt.VIR_DOMAIN_SNAPSHOT_CREATE_ATOMIC
    REVERT_FLAGS = libvirt.VIR_DOMAIN_SNAPSHOT_REVERT_FORCE

    def __init__(self, domain):
        # ``domain`` is the libvirt domain every snapshot is taken of.
        self.domain = domain
        # Snapshot objects in creation order; the newest is at the end.
        self.snapshot_stack = []

    def create(self, snapshot_name):
        """Snapshot the domain as *snapshot_name*, push it, and return it.

        The snapshot XML is built from ``SNAPSHOT_BASE`` and created with
        ``CREATE_FLAGS``.
        """
        logging.debug("Creating snapshot '{0}'".format(snapshot_name))
        new_snapshot = self.domain.snapshotCreateXML(
            self.SNAPSHOT_BASE.format(name=snapshot_name), self.CREATE_FLAGS)
        self.snapshot_stack.append(new_snapshot)
        return new_snapshot

    def revert_forced(self, snapshot):
        """Revert to *snapshot* (which need not be on top), then delete it.

        The snapshot is deleted from libvirt before it is removed from the
        stack, so a snapshot that is not on the stack is still deleted and
        ``list.remove`` then raises ``ValueError``.
        """
        logging.debug(
            "Forced revert of snapshot '{0}'".format(snapshot.getName()))
        self.domain.revertToSnapshot(snapshot, self.REVERT_FLAGS)
        snapshot.delete()
        self.snapshot_stack.remove(snapshot)
        logging.debug('Revert successful')

    def revert(self, delete=True):
        """Revert the domain to the newest snapshot on the stack.

        :param delete: when true (hard revert), delete the snapshot
            afterwards; when false (soft revert), push it back onto the stack
            so it can be reverted to again.

        An empty stack is logged as an error and otherwise ignored.
        """
        if not self.snapshot_stack:
            logging.error("No snapshot in stack anymore")
            return
        top = self.snapshot_stack.pop()
        self.domain.revertToSnapshot(top, self.REVERT_FLAGS)
        if not delete:
            # Soft revert: keep the snapshot for another use.
            logging.debug("Soft revert of snapshot '{0}' successful".format(
                top.getName()))
            self.snapshot_stack.append(top)
            return
        logging.debug("Hard revert of snapshot '{0}' successful".format(
            top.getName()))
        top.delete()

    def delete(self, snapshot=None):
        """Delete a snapshot without reverting to it.

        Meant to be paired with ``revert(delete=False)``. When *snapshot* is
        given, it is removed from the stack (``ValueError`` if absent).
        Otherwise the newest snapshot is popped (``IndexError`` if the stack
        is empty).
        """
        if not snapshot:
            snapshot = self.snapshot_stack.pop()
        else:
            self.snapshot_stack.remove(snapshot)
        snapshot.delete()
        logging.debug("Snapshot '{0}' deleted successfully".format(
            snapshot.getName()))

    def clear(self):
        """Revert to and delete every snapshot, newest first, emptying the stack."""
        logging.debug('Reverting all created snapshots in reverse order')
        while self.snapshot_stack:
            newest = self.snapshot_stack.pop()
            logging.debug("Reverting of snapshot '{0}'".format(
                newest.getName()))
            self.domain.revertToSnapshot(newest, self.REVERT_FLAGS)
            newest.delete()
            logging.debug('Revert successful')
        logging.info('All snapshots reverted successfully')
def connect_domain(hypervisor, domain_name):
    """Open a libvirt connection and look up a domain by name.

    :param hypervisor: connection URI passed to ``libvirt.open``.
    :param domain_name: name of the domain to look up on that connection.
    :returns: the libvirt domain object, or ``None`` when the connection
        cannot be opened or no such domain exists (an error is logged in
        both cases).
    """
    try:
        conn = libvirt.open(hypervisor)
    except libvirt.libvirtError:
        # libvirt-python signals a failed open by raising rather than by
        # returning None; funnel both into the same log-and-return path.
        conn = None
    if conn is None:
        logging.error('Failed to open connection to the hypervisor')
        return None
    try:
        dom = conn.lookupByName(domain_name)
    except libvirt.libvirtError:
        logging.error("Failed to find domain '{0}'".format(domain_name))
        # Nothing will use this connection any more; don't leak it.
        conn.close()
        return None
    return dom
def determine_ip(domain):
    """Return the guest's IPv4 address as reported by the QEMU guest agent.

    Waits up to two minutes for *domain* to reach the RUNNING state, then
    queries interface addresses through the guest agent. If that query fails,
    the existing guest-agent channel (if any) is detached, a fresh one is
    attached, and the query is retried once.

    Only the interface whose hardware address equals the first ``<mac>``
    element in the domain XML is considered.

    :param domain: libvirt domain object.
    :returns: the first IPv4 address of the matching interface, or ``None``
        when the domain XML has no MAC address or no IPv4 address is found.
    :raises libvirt.libvirtError: if the retried guest-agent query fails.
    """
    GUEST_AGENT_NAME = 'org.qemu.guest_agent.0'
    GUEST_AGENT_XML = ("<channel type='unix'>"
                       " <source mode='bind'/>"
                       " <target type='virtio'"
                       " name='org.qemu.guest_agent.0'"
                       " state='connected'/>"
                       "</channel>")

    # wait for machine until it gets to RUNNING state,
    # because it isn't possible to determine IP in e.g. PAUSED state
    must_end = time.time() + 120  # wait max. 2 minutes
    while time.time() < must_end:
        if domain.state()[0] == libvirt.VIR_DOMAIN_RUNNING:
            break
        time.sleep(1)
    else:
        # Loop ran out without ``break``: still try, but leave a trace.
        logging.warning("Domain '{0}' is not running after 120s; trying to "
                        "fetch its IP address anyway".format(domain.name()))

    mac_node = ET.fromstring(domain.XMLDesc()).find('.//mac')
    if mac_node is None:
        logging.error("Domain '{0}' has no MAC address; cannot determine "
                      "its IP address".format(domain.name()))
        return None
    domain_mac = mac_node.attrib['address']

    logging.debug('Fetching IP address of the domain')
    try:
        ifaces = domain.interfaceAddresses(
            libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_AGENT, 0)
    except libvirt.libvirtError:
        # guest agent is not connected properly
        # let's try to reattach the guest-agent device. Only the channel that
        # actually carries the guest agent is detached, so unrelated unix
        # channels are left alone.
        domain_xml = ET.fromstring(domain.XMLDesc())
        for channel in domain_xml.iter('channel'):
            target = channel.find('target')
            if (channel.attrib.get('type') == 'unix'
                    and target is not None
                    and target.attrib.get('name') == GUEST_AGENT_NAME):
                domain.detachDevice(ET.tostring(channel, encoding='unicode'))
                break
        domain.attachDevice(GUEST_AGENT_XML)
        time.sleep(1)
        # now it should be ok
        ifaces = domain.interfaceAddresses(
            libvirt.VIR_DOMAIN_INTERFACE_ADDRESSES_SRC_AGENT, 0)

    # get IPv4 address of the guest
    for iface in ifaces.values():
        if iface['hwaddr'] != domain_mac or not iface['addrs']:
            continue
        for ipaddr in iface['addrs']:
            if ipaddr['type'] == libvirt.VIR_IP_ADDR_TYPE_IPV4:
                logging.debug('IP address is {0}'.format(ipaddr['addr']))
                return ipaddr['addr']
    logging.warning("No IPv4 address reported for MAC '{0}'".format(
        domain_mac))
    return None
def start_domain(domain):
    """Start *domain* unless it is already active.

    After starting it, the function blocks for a fixed 30 seconds to give
    the guest time to boot. An already-active domain is left untouched and
    the call returns immediately.
    """
    if domain.isActive():
        return
    logging.debug("Starting domain '{0}'".format(domain.name()))
    domain.create()
    logging.debug('Waiting 30s for domain to start')
    time.sleep(30)
def reboot_domain(domain, domain_ip, ssh_port, timeout=300,
                  connection_timeout=5):
    """Power-cycle *domain* and wait until its SSH port accepts connections.

    The domain is asked to shut down. If it is still active after *timeout*
    seconds, it is forcibly destroyed. It is then started again, and TCP
    connections to ``domain_ip:ssh_port`` are attempted once per second until
    one succeeds.

    :param domain: libvirt domain object.
    :param domain_ip: address used to reach the guest.
    :param ssh_port: TCP port that must start accepting connections.
    :param timeout: seconds to wait for the shutdown, and separately for the
        port to come up after the restart.
    :param connection_timeout: timeout in seconds for each individual
        connection attempt.
    :raises TimeoutException: if the port does not accept a connection
        within *timeout* seconds of starting the domain.
    """
    logging.debug("Shutting down domain '{0}'".format(domain.name()))
    domain.shutdown()

    # Wait until domain shuts down; fall back to a hard power-off.
    logging.debug("Waiting for domain to shutdown (max. {0}s)".format(timeout))
    end_time = time.time() + timeout
    while domain.isActive():
        time.sleep(1)
        if time.time() >= end_time:
            str_err = ("Timeout reached: '{0}' domain failed to shutdown. "
                       "Forcing the shutdown...".format(domain.name()))
            logging.warning(str_err)
            domain.destroy()

    logging.debug("Starting domain '{0}'".format(domain.name()))
    domain.create()

    # Wait until SSH (on ssh_port) starts accepting TCP connections.
    logging.debug("Waiting for domain to boot (max. {0}s)".format(timeout))
    end_time = time.time() + timeout
    while True:
        try:
            ssh_socket = socket.create_connection(
                (domain_ip, ssh_port), timeout=connection_timeout)
        except (OSError, socket.error):
            # Both names are listed for Python 2, where socket.error is not
            # an OSError subclass.
            time.sleep(1)
            if time.time() >= end_time:
                str_err = ("Timeout reached: '{0}' ({1}:{2}) domain does not "
                           "accept connections.".format(
                               domain.name(), domain_ip, ssh_port))
                logging.debug(str_err)
                raise TimeoutException(str_err)
        else:
            # Only reachability matters; drop the probe connection at once.
            ssh_socket.close()
            break
|