# test_bundle.py -- tests for bundle
# Copyright (C) 2020 Jelmer Vernooij <jelmer@jelmer.uk>
#
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#

"""Tests for bundle support."""

import os
import tempfile
from io import BytesIO

from dulwich.bundle import Bundle, create_bundle_from_repo, read_bundle, write_bundle
from dulwich.object_format import DEFAULT_OBJECT_FORMAT
from dulwich.objects import Blob, Commit, Tree
from dulwich.pack import PackData, write_pack_objects
from dulwich.repo import MemoryRepo

from . import TestCase


class BundleTests(TestCase):
    def setUp(self) -> None:
        """Create a scratch directory for the test and schedule its removal.

        The path is stored on ``self.tempdir``. Removal uses ``os.rmdir``,
        which only deletes an empty directory, so a test that leaves files
        behind is reported as a cleanup error rather than silently wiped.
        """
        super().setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(os.rmdir, self.tempdir)

    def test_bundle_repr(self) -> None:
        """Test the Bundle.__repr__ method.

        Populates every bundle field and checks that the repr mentions the
        version, capabilities, prerequisites and references.
        """
        bundle = Bundle()
        self.addCleanup(bundle.close)
        bundle.version = 3
        bundle.capabilities = {"foo": "bar"}
        # Prerequisite comments are bytes, matching what read_bundle returns
        # and what the write_bundle tests in this class pass in.
        bundle.prerequisites = [(b"cc" * 20, b"comment")]
        bundle.references = {b"refs/heads/master": b"ab" * 20}

        # Attach an empty pack so the bundle is fully populated.
        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        pack_buf.seek(0)
        bundle.pack_data = PackData.from_file(
            pack_buf, object_format=DEFAULT_OBJECT_FORMAT
        )
        self.addCleanup(bundle.pack_data.close)

        # Only stable fragments are checked; the full repr layout is not pinned.
        rep = repr(bundle)
        self.assertIn("Bundle(version=3", rep)
        self.assertIn("capabilities={'foo': 'bar'}", rep)
        self.assertIn("prerequisites=[(", rep)
        self.assertIn("references={", rep)

    def test_bundle_equality(self) -> None:
        """Test the Bundle.__eq__ method.

        Two bundles built from the same field values must compare equal;
        changing any single field (version, capabilities, prerequisites or
        references) must make them unequal, as must comparing against a
        non-Bundle object.
        """

        def make_bundle(
            *, version=3, capabilities=None, prerequisites=None, references=None
        ):
            """Return a bundle with baseline fields, overriding those given."""
            bundle = Bundle()
            self.addCleanup(bundle.close)
            bundle.version = version
            bundle.capabilities = (
                {"foo": "bar"} if capabilities is None else capabilities
            )
            # Prerequisite comments are bytes, matching what read_bundle
            # returns elsewhere in this test class.
            bundle.prerequisites = (
                [(b"cc" * 20, b"comment")] if prerequisites is None else prerequisites
            )
            bundle.references = (
                {b"refs/heads/master": b"ab" * 20} if references is None else references
            )
            # Each bundle gets its own (empty) pack backed by its own buffer.
            pack_buf = BytesIO()
            write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
            pack_buf.seek(0)
            bundle.pack_data = PackData.from_file(
                pack_buf, object_format=DEFAULT_OBJECT_FORMAT
            )
            return bundle

        baseline = make_bundle()

        # Identical field values -> equal.
        self.assertEqual(baseline, make_bundle())

        # One differing field at a time -> not equal.
        variants = {
            "version": make_bundle(version=2),
            "capabilities": make_bundle(capabilities={"different": "value"}),
            "prerequisites": make_bundle(prerequisites=[(b"dd" * 20, b"different")]),
            "references": make_bundle(
                references={b"refs/heads/different": b"ab" * 20}
            ),
        }
        for changed_field, variant in variants.items():
            # The message names the field so a failure points at the culprit.
            self.assertNotEqual(baseline, variant, changed_field)

        # Test inequality with different type
        self.assertNotEqual(baseline, "not a bundle")

    def test_read_bundle_v2(self) -> None:
        """Check that a v2 bundle header is parsed into the expected fields."""
        prereq_sha = b"cc" * 20
        head_sha = b"ab" * 20
        header = b"".join(
            [
                b"# v2 git bundle\n",
                b"-" + prereq_sha + b" prerequisite comment\n",
                head_sha + b" refs/heads/master\n",
                b"\n",
            ]
        )

        # An empty pack follows the blank line that terminates the header.
        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        stream = BytesIO(header + pack_buf.getvalue())

        bundle = read_bundle(stream)
        self.addCleanup(bundle.close)

        self.assertEqual(2, bundle.version)
        self.assertEqual({}, bundle.capabilities)
        self.assertEqual([(prereq_sha, b"prerequisite comment")], bundle.prerequisites)
        self.assertEqual({b"refs/heads/master": head_sha}, bundle.references)

    def test_read_bundle_v3(self) -> None:
        """Check that a v3 bundle, including its capability lines, is parsed."""
        prereq_sha = b"cc" * 20
        head_sha = b"ab" * 20
        header = b"".join(
            [
                b"# v3 git bundle\n",
                b"@capability1\n",
                b"@capability2=value2\n",
                b"-" + prereq_sha + b" prerequisite comment\n",
                head_sha + b" refs/heads/master\n",
                b"\n",
            ]
        )

        # An empty pack follows the blank line that terminates the header.
        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        stream = BytesIO(header + pack_buf.getvalue())

        bundle = read_bundle(stream)
        self.addCleanup(bundle.close)

        expected_capabilities = {"capability1": None, "capability2": "value2"}
        self.assertEqual(3, bundle.version)
        self.assertEqual(expected_capabilities, bundle.capabilities)
        self.assertEqual([(prereq_sha, b"prerequisite comment")], bundle.prerequisites)
        self.assertEqual({b"refs/heads/master": head_sha}, bundle.references)

    def test_read_bundle_invalid_format(self) -> None:
        """Check that input without a bundle signature line is rejected."""
        stream = BytesIO(b"invalid bundle format\n")

        # read_bundle is expected to raise AssertionError on an unknown header.
        self.assertRaises(AssertionError, read_bundle, stream)

    def test_write_bundle_v2(self) -> None:
        """Check the header lines emitted when writing a v2 bundle."""
        prereq_sha = b"cc" * 20
        head_sha = b"ab" * 20

        bundle = Bundle()
        self.addCleanup(bundle.close)
        bundle.version = 2
        bundle.capabilities = {}
        bundle.prerequisites = [(prereq_sha, b"prerequisite comment")]
        bundle.references = {b"refs/heads/master": head_sha}

        # Back the bundle with an empty pack.
        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        pack_buf.seek(0)
        bundle.pack_data = PackData.from_file(
            pack_buf, object_format=DEFAULT_OBJECT_FORMAT
        )

        out = BytesIO()
        write_bundle(out, bundle)
        out.seek(0)

        # Only the textual header is verified; the trailing pack bytes are not.
        expected_header = (
            b"# v2 git bundle\n",
            b"-" + prereq_sha + b" prerequisite comment\n",
            head_sha + b" refs/heads/master\n",
            b"\n",
        )
        for expected_line in expected_header:
            self.assertEqual(expected_line, out.readline())

    def test_write_bundle_v3(self) -> None:
        """Check the header lines, capabilities included, of a written v3 bundle."""
        prereq_sha = b"cc" * 20
        head_sha = b"ab" * 20

        bundle = Bundle()
        self.addCleanup(bundle.close)
        bundle.version = 3
        bundle.capabilities = {"capability1": None, "capability2": "value2"}
        bundle.prerequisites = [(prereq_sha, b"prerequisite comment")]
        bundle.references = {b"refs/heads/master": head_sha}

        # Back the bundle with an empty pack.
        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        pack_buf.seek(0)
        bundle.pack_data = PackData.from_file(
            pack_buf, object_format=DEFAULT_OBJECT_FORMAT
        )

        out = BytesIO()
        write_bundle(out, bundle)
        out.seek(0)

        # Only the textual header is verified; the trailing pack bytes are not.
        expected_header = (
            b"# v3 git bundle\n",
            b"@capability1\n",
            b"@capability2=value2\n",
            b"-" + prereq_sha + b" prerequisite comment\n",
            head_sha + b" refs/heads/master\n",
            b"\n",
        )
        for expected_line in expected_header:
            self.assertEqual(expected_line, out.readline())

    def test_write_bundle_auto_version(self) -> None:
        """Check the format version chosen when ``bundle.version`` is None."""
        scenarios = (
            # Capabilities present: the v3 signature line is expected.
            ({"capability1": "value1"}, b"# v3 git bundle\n"),
            # No capabilities: the v2 signature line is expected.
            ({}, b"# v2 git bundle\n"),
        )

        for capabilities, expected_signature in scenarios:
            bundle = Bundle()
            self.addCleanup(bundle.close)
            bundle.version = None
            bundle.capabilities = capabilities
            bundle.prerequisites = [(b"cc" * 20, b"prerequisite comment")]
            bundle.references = {b"refs/heads/master": b"ab" * 20}

            pack_buf = BytesIO()
            write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
            pack_buf.seek(0)
            bundle.pack_data = PackData.from_file(
                pack_buf, object_format=DEFAULT_OBJECT_FORMAT
            )

            out = BytesIO()
            write_bundle(out, bundle)
            out.seek(0)
            self.assertEqual(expected_signature, out.readline())

    def test_write_bundle_invalid_version(self) -> None:
        """Check that writing a bundle with an unsupported version fails."""
        bundle = Bundle()
        self.addCleanup(bundle.close)
        bundle.capabilities = {}
        bundle.prerequisites = []
        bundle.references = {}
        bundle.version = 4  # Invalid version

        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        pack_buf.seek(0)
        bundle.pack_data = PackData.from_file(
            pack_buf, object_format=DEFAULT_OBJECT_FORMAT
        )

        # write_bundle is expected to raise AssertionError for this version.
        self.assertRaises(AssertionError, write_bundle, BytesIO(), bundle)

    def test_roundtrip_bundle(self) -> None:
        """Write a v3 bundle to disk, read it back and compare with the original."""
        original = Bundle()
        self.addCleanup(original.close)
        original.version = 3
        original.capabilities = {"foo": None}
        original.references = {b"refs/heads/master": b"ab" * 20}
        original.prerequisites = [(b"cc" * 20, b"comment")]

        pack_buf = BytesIO()
        write_pack_objects(pack_buf.write, [], object_format=DEFAULT_OBJECT_FORMAT)
        pack_buf.seek(0)
        original.pack_data = PackData.from_file(
            pack_buf, object_format=DEFAULT_OBJECT_FORMAT
        )

        with tempfile.TemporaryDirectory() as td:
            bundle_path = os.path.join(td, "foo")

            with open(bundle_path, "wb") as sink:
                write_bundle(sink, original)

            with open(bundle_path, "rb") as source:
                restored = read_bundle(source)
                self.addCleanup(restored.close)

                # Compared while the source file is still open.
                self.assertEqual(original, restored)

    def test_create_bundle_from_repo(self) -> None:
        """Check that a bundle built from a repository carries its refs and objects."""
        repo = MemoryRepo()
        self.addCleanup(repo.close)
        store = repo.object_store

        # Populate the repository: blob -> tree -> commit -> branch ref.
        blob = Blob.from_string(b"Hello world")
        store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        store.add_object(tree)

        identity = b"Test User <test@example.com>"
        timestamp = 1234567890
        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = identity
        commit.committer = identity
        commit.commit_time = timestamp
        commit.author_time = timestamp
        commit.commit_timezone = 0
        commit.author_timezone = 0
        store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id

        bundle = create_bundle_from_repo(repo)
        self.addCleanup(bundle.close)

        # Header-level contents.
        self.assertEqual(bundle.references, {b"refs/heads/master": commit.id})
        self.assertEqual(bundle.prerequisites, [])
        self.assertEqual(bundle.capabilities, {})
        self.assertIsNotNone(bundle.pack_data)

        # Every object created above must be present in the pack.
        packed_ids = {
            unpacked.sha().hex().encode("ascii")
            for unpacked in bundle.pack_data.iter_unpacked()
        }
        for stored in (blob, tree, commit):
            self.assertIn(stored.id, packed_ids)

    def test_create_bundle_with_prerequisites(self) -> None:
        """Test creating a bundle with prerequisites.

        The prerequisite id passed to ``create_bundle_from_repo`` must show up
        as the first element of the single entry in ``bundle.prerequisites``.
        """
        repo = MemoryRepo()
        # Release the repository at teardown, as test_create_bundle_from_repo does.
        self.addCleanup(repo.close)

        # Create some objects
        blob = Blob.from_string(b"Hello world")
        repo.object_store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        repo.object_store.add_object(tree)

        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = commit.committer = b"Test User <test@example.com>"
        commit.commit_time = commit.author_time = 1234567890
        commit.commit_timezone = commit.author_timezone = 0
        repo.object_store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id

        # Create bundle with prerequisites
        prereq_id = b"aa" * 20  # hex string like other object ids
        bundle = create_bundle_from_repo(repo, prerequisites=[prereq_id])
        self.addCleanup(bundle.close)

        # Verify prerequisites are included
        self.assertEqual(len(bundle.prerequisites), 1)
        self.assertEqual(bundle.prerequisites[0][0], prereq_id)

    def test_create_bundle_with_specific_refs(self) -> None:
        """Test creating a bundle with specific refs.

        The repository has two branches pointing at the same commit; only the
        one named in ``refs`` may appear in ``bundle.references``.
        """
        repo = MemoryRepo()
        # Release the repository at teardown, as test_create_bundle_from_repo does.
        self.addCleanup(repo.close)

        # Create objects and refs
        blob = Blob.from_string(b"Hello world")
        repo.object_store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        repo.object_store.add_object(tree)

        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = commit.committer = b"Test User <test@example.com>"
        commit.commit_time = commit.author_time = 1234567890
        commit.commit_timezone = commit.author_timezone = 0
        repo.object_store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id
        repo.refs[b"refs/heads/feature"] = commit.id

        # Create bundle with only master ref
        from dulwich.refs import Ref

        bundle = create_bundle_from_repo(repo, refs=[Ref(b"refs/heads/master")])
        self.addCleanup(bundle.close)

        # Verify only master ref is included
        self.assertEqual(len(bundle.references), 1)
        self.assertIn(b"refs/heads/master", bundle.references)
        self.assertNotIn(b"refs/heads/feature", bundle.references)

    def test_create_bundle_with_capabilities(self) -> None:
        """Test creating a bundle with capabilities.

        The capabilities mapping and the explicit version passed to
        ``create_bundle_from_repo`` must be reflected on the resulting bundle.
        """
        repo = MemoryRepo()
        # Release the repository at teardown, as test_create_bundle_from_repo does.
        self.addCleanup(repo.close)

        # Create minimal objects
        blob = Blob.from_string(b"Hello world")
        repo.object_store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        repo.object_store.add_object(tree)

        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = commit.committer = b"Test User <test@example.com>"
        commit.commit_time = commit.author_time = 1234567890
        commit.commit_timezone = commit.author_timezone = 0
        repo.object_store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id

        # Create bundle with capabilities
        capabilities = {"object-format": "sha1"}
        bundle = create_bundle_from_repo(repo, capabilities=capabilities, version=3)
        self.addCleanup(bundle.close)

        # Verify capabilities are included
        self.assertEqual(bundle.capabilities, capabilities)
        self.assertEqual(bundle.version, 3)

    def test_create_bundle_with_hex_bytestring_prerequisite(self) -> None:
        """Test creating a bundle with prerequisite as 40-byte hex bytestring.

        The prerequisite here is the ``id`` of a blob that is never added to
        the repository's object store.
        """
        repo = MemoryRepo()
        # Release the repository at teardown, as test_create_bundle_from_repo does.
        self.addCleanup(repo.close)

        # Create minimal objects
        blob = Blob.from_string(b"Hello world")
        repo.object_store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        repo.object_store.add_object(tree)

        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = commit.committer = b"Test User <test@example.com>"
        commit.commit_time = commit.author_time = 1234567890
        commit.commit_timezone = commit.author_timezone = 0
        repo.object_store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id

        # Create another blob to use as prerequisite
        prereq_blob = Blob.from_string(b"prerequisite")

        # Use blob.id directly (40-byte hex bytestring)
        bundle = create_bundle_from_repo(repo, prerequisites=[prereq_blob.id])
        self.addCleanup(bundle.close)

        # Verify the prerequisite was added correctly
        self.assertEqual(len(bundle.prerequisites), 1)
        self.assertEqual(bundle.prerequisites[0][0], prereq_blob.id)

    def test_create_bundle_with_hex_bytestring_prerequisite_simple(self) -> None:
        """Test creating a bundle with a literal 40-byte hex prerequisite.

        Unlike the blob-derived variant, the prerequisite is a hand-written
        hex bytestring that is not the id of any object built by this test.
        """
        repo = MemoryRepo()
        # Release the repository at teardown, as test_create_bundle_from_repo does.
        self.addCleanup(repo.close)

        # Create minimal objects
        blob = Blob.from_string(b"Hello world")
        repo.object_store.add_object(blob)

        tree = Tree()
        tree.add(b"hello.txt", 0o100644, blob.id)
        repo.object_store.add_object(tree)

        commit = Commit()
        commit.tree = tree.id
        commit.message = b"Initial commit"
        commit.author = commit.committer = b"Test User <test@example.com>"
        commit.commit_time = commit.author_time = 1234567890
        commit.commit_timezone = commit.author_timezone = 0
        repo.object_store.add_object(commit)

        repo.refs[b"refs/heads/master"] = commit.id

        # Use a 40-byte hex bytestring as prerequisite
        prereq_hex = b"aa" * 20

        bundle = create_bundle_from_repo(repo, prerequisites=[prereq_hex])
        self.addCleanup(bundle.close)

        # Verify the prerequisite was added correctly
        self.assertEqual(len(bundle.prerequisites), 1)
        self.assertEqual(bundle.prerequisites[0][0], prereq_hex)
