# Copyright (c) 2025 Thomas Goirand
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
try:
    from unittest import mock
except ImportError:
    import mock
import testtools
import uuid

from osc_lib import exceptions as osc_lib_exceptions

from vmmsclient.tests import base
from vmmsclient.tests.v2.fakes import FAKE_MIGRATIONS, make_fake_migration
from vmmsclient.v2 import client


class TestClient(base.TestCase):
    """Unit tests for the request-issuing methods of ``client.Client``.

    Each test swaps one verb method (``post``, ``get`` or ``delete``) of a
    client built around a mocked session for a mock, then checks both the
    value handed back to the caller and the arguments given to the verb.
    """

    def setUp(self):
        super(TestClient, self).setUp()
        self.client = client.Client(
            session=mock.Mock(), service_type='vmms', interface='public')

    @staticmethod
    def _json_response(payload):
        """Build a stand-in response whose ``json()`` yields *payload*."""
        response = mock.Mock()
        response.json.return_value = payload
        return response

    def _patch_verb(self, verb, **mock_kwargs):
        """Return a patcher that replaces ``self.client.<verb>``.

        Extra keyword arguments configure the replacement mock
        (for instance ``return_value`` or ``side_effect``).
        """
        return mock.patch.object(self.client, verb, **mock_kwargs)

    def test_add_vm(self):
        """add_vm() with only an identifier posts it to ``/vms``."""
        migration = make_fake_migration()
        response = self._json_response(migration)

        with self._patch_verb('post', return_value=response) as post:
            returned = self.client.add_vm(vm_identifier='test-vm-id-or-name')

        self.assertEqual(migration, returned)
        # The request body must carry the identifier and nothing else.
        post.assert_called_once_with(
            '/vms', json={'vm_identifier': 'test-vm-id-or-name'})

    def test_add_vm_with_schedule_time(self):
        """add_vm() forwards ``scheduled_time`` in the posted body."""
        migration = make_fake_migration(
            scheduled_time='2025-10-07T22:00:00')
        when = migration['scheduled_time']
        response = self._json_response(migration)

        with self._patch_verb('post', return_value=response) as post:
            returned = self.client.add_vm(
                vm_identifier='test-vm-id-or-name', scheduled_time=when)

        self.assertEqual(migration, returned)
        post.assert_called_once_with(
            '/vms',
            json={
                'vm_identifier': 'test-vm-id-or-name',
                'scheduled_time': when,
            })

    def test_list_vms(self):
        """list_vms() gets ``/vms`` and returns the decoded JSON."""
        response = self._json_response(FAKE_MIGRATIONS)

        with self._patch_verb('get', return_value=response) as get:
            listed = self.client.list_vms()

        self.assertEqual(FAKE_MIGRATIONS, listed)
        get.assert_called_once_with('/vms')

    def test_remove_vm(self):
        """remove_vm() deletes ``/vms/<id>`` and reports a truthy result."""
        target = 'test-migration-id'
        ok_response = mock.Mock(status_code=200)

        with self._patch_verb('delete', return_value=ok_response) as delete:
            outcome = self.client.remove_vm(target)

        self.assertTrue(outcome)
        delete.assert_called_once_with('/vms/{0}'.format(target))

    def test_remove_vm_not_found(self):
        """remove_vm() lets a NotFound raised by ``delete`` propagate."""
        # The patched delete raises instead of returning a response.
        not_found = osc_lib_exceptions.NotFound('Migration not found (404).')

        with self._patch_verb('delete', side_effect=not_found):
            self.assertRaises(
                osc_lib_exceptions.NotFound,
                self.client.remove_vm,
                'non-existent-id')


class TestAddVMCommand(base.TestCommand):
    """Unit tests for ``client.AddVMCommand`` parsing and execution."""

    # Column names the command is expected to report, in order.
    EXPECTED_COLUMNS = (
        'id',
        'vm_id',
        'vm_name',
        'scheduled_time',
        'state',
        'workflow_exec',
        'created_at',
        'updated_at',
    )

    def setUp(self):
        super(TestAddVMCommand, self).setUp()
        self.cmd = client.AddVMCommand(self.app, None)

    def test_get_parser(self):
        """The positional argument is stored as ``vm_identifier``."""
        argv = ['test-vm-id-or-name']
        namespace = self.cmd.get_parser('test').parse_args(argv)

        self.assertEqual('test-vm-id-or-name', namespace.vm_identifier)

    def test_get_parser_with_schedule_time(self):
        """``--schedule-time`` is stored as ``schedule_time``."""
        argv = ['test-vm-id-or-name', '--schedule-time', '2025-10-07T22:00:00']
        namespace = self.cmd.get_parser('test').parse_args(argv)

        self.assertEqual('test-vm-id-or-name', namespace.vm_identifier)
        self.assertEqual('2025-10-07T22:00:00', namespace.schedule_time)

    def test_take_action(self):
        """take_action() calls add_vm() and reports the expected columns."""
        vmms = self.app.client_manager.vmms
        vmms.add_vm.return_value = make_fake_migration()
        args = mock.Mock(
            vm_identifier='test-vm-id-or-name', schedule_time=None)

        columns, _data = self.cmd.take_action(args)

        # The CLI attribute ``schedule_time`` reaches the client as the
        # ``scheduled_time`` keyword.
        vmms.add_vm.assert_called_once_with(
            vm_identifier='test-vm-id-or-name', scheduled_time=None)
        self.assertEqual(self.EXPECTED_COLUMNS, columns)


class TestListVMsCommand(base.TestCommand):
    """Unit tests for ``client.ListVMsCommand`` parsing and execution."""

    # Columns expected in the default listing, in order.
    SHORT_COLUMNS = (
        'id',
        'vm_id',
        'vm_name',
        'scheduled_time',
        'state',
        'workflow_exec',
    )
    # Columns expected when the long format is requested.
    LONG_COLUMNS = SHORT_COLUMNS + ('created_at', 'updated_at')

    def setUp(self):
        super(TestListVMsCommand, self).setUp()
        self.cmd = client.ListVMsCommand(self.app, None)

    def _run_listing(self, state, long_format):
        """Run take_action() with list_vms() mocked to FAKE_MIGRATIONS.

        Returns the ``(columns, data)`` pair produced by the command.
        """
        self.app.client_manager.vmms.list_vms.return_value = FAKE_MIGRATIONS
        args = mock.Mock(state=state, long=long_format)
        return self.cmd.take_action(args)

    def test_take_action(self):
        """Without a state filter, every fake migration is listed."""
        columns, rows = self._run_listing(state=None, long_format=False)

        list_vms = self.app.client_manager.vmms.list_vms
        list_vms.assert_called_once_with(state=None)
        self.assertEqual(self.SHORT_COLUMNS, columns)
        self.assertEqual(len(FAKE_MIGRATIONS), len(list(rows)))

    def test_take_action_with_state_filter(self):
        """The requested state is handed to list_vms() unchanged."""
        columns, _rows = self._run_listing(state='ERROR', long_format=False)

        list_vms = self.app.client_manager.vmms.list_vms
        list_vms.assert_called_once_with(state='ERROR')
        self.assertEqual(self.SHORT_COLUMNS, columns)

    def test_take_action_long_format(self):
        """The long format appends the two timestamp columns."""
        columns, rows = self._run_listing(state=None, long_format=True)

        list_vms = self.app.client_manager.vmms.list_vms
        list_vms.assert_called_once_with(state=None)
        self.assertEqual(self.LONG_COLUMNS, columns)
        self.assertEqual(len(FAKE_MIGRATIONS), len(list(rows)))

    def test_get_parser(self):
        """The parser accepts both ``--state`` and ``--long``."""
        parser = self.cmd.get_parser('test')

        with_state = parser.parse_args(['--state', 'ERROR'])
        self.assertEqual('ERROR', with_state.state)

        with_long = parser.parse_args(['--long'])
        self.assertTrue(with_long.long)


class TestRemoveVMCommand(base.TestCommand):
    """Unit tests for ``client.RemoveVMCommand`` parsing and execution."""

    def setUp(self):
        super(TestRemoveVMCommand, self).setUp()
        self.cmd = client.RemoveVMCommand(self.app, None)

    def test_get_parser(self):
        """The positional argument is stored as ``migration_id``."""
        namespace = self.cmd.get_parser('test').parse_args(
            ['test-migration-id'])

        self.assertEqual('test-migration-id', namespace.migration_id)

    def test_take_action_success(self):
        """take_action() passes the migration id to remove_vm()."""
        remove_vm = self.app.client_manager.vmms.remove_vm
        remove_vm.return_value = True
        args = mock.Mock(migration_id='test-migration-id')

        self.cmd.take_action(args)

        remove_vm.assert_called_once_with('test-migration-id')

    def test_take_action_failure(self):
        """A NotFound from remove_vm() must surface as CommandError."""
        remove_vm = self.app.client_manager.vmms.remove_vm
        remove_vm.side_effect = osc_lib_exceptions.NotFound(
            "Migration not found")
        args = mock.Mock(migration_id='test-migration-id')

        self.assertRaises(
            osc_lib_exceptions.CommandError, self.cmd.take_action, args)
