# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from awscli.customizations.emr import exceptions, sshutils
from awscli.testutils import mock, unittest
class TestSSHUtils(unittest.TestCase):
    """Unit tests for the EMR ``sshutils`` helpers.

    Each test replaces the ``emrutils`` module referenced by ``sshutils``
    with a mock. As a result, cluster state, the EMR client and ``which``
    lookups are all faked by the test itself.
    """

    @mock.patch('awscli.customizations.emr.sshutils.emrutils')
    def test_validate_and_find_master_dns_waits(self, emrutils):
        """A STARTING cluster is waited on, then its master DNS is looked up."""
        emrutils.get_cluster_state.return_value = 'STARTING'
        session = mock.Mock()
        client = mock.Mock()
        emrutils.get_client.return_value = client

        sshutils.validate_and_find_master_dns(session, None, 'cluster-id')

        # We should have:
        # 1. Waited (via the 'cluster_running' waiter) for the cluster to run.
        client.get_waiter.assert_called_with('cluster_running')
        client.get_waiter.return_value.wait.assert_called_with(
            ClusterId='cluster-id'
        )
        # 2. Looked up the master public DNS. assert_called() checks the same
        #    thing as assertTrue(.called) but gives a descriptive failure.
        emrutils.find_master_dns.assert_called()

    @mock.patch('awscli.customizations.emr.sshutils.emrutils')
    def test_cluster_in_terminated_states(self, emrutils):
        """A TERMINATED cluster makes the lookup raise ClusterTerminatedError."""
        emrutils.get_cluster_state.return_value = 'TERMINATED'
        with self.assertRaises(exceptions.ClusterTerminatedError):
            sshutils.validate_and_find_master_dns(
                mock.Mock(), None, 'cluster-id'
            )

    @mock.patch('awscli.customizations.emr.sshutils.emrutils')
    def test_ssh_scp_key_file_format(self, emrutils):
        """Key files with and without an extension pass ssh/scp validation.

        ``which`` is faked to return a path only for ``ssh`` and ``scp``.
        Every other program name yields ``None``. The test passes if neither
        validator raises for either key file name.
        """
        def which_side_effect(program):
            if program in ('ssh', 'scp'):
                return '/some/path'
            # Explicit "not found" for any other program name.
            return None

        emrutils.which.side_effect = which_side_effect

        # Same call order as before: ssh then scp for each key file. subTest
        # reports each key file separately if validation starts raising.
        for key_file in ('key.abc', 'key'):
            with self.subTest(key_file=key_file):
                sshutils.validate_ssh_with_key_file(key_file)
                sshutils.validate_scp_with_key_file(key_file)