"""
Python reimplementation of the bundle adjustment for the incremental mapper of
C++ with equivalent logic. As a result, one can add customized residuals on top
of the exposed ceres problem from conventional bundle adjustment.
"""

import collections
import copy

import pycolmap
from pycolmap import logging


def solve_bundle_adjustment(
    reconstruction: pycolmap.Reconstruction,
    ba_options: pycolmap.BundleAdjustmentOptions,
    ba_config: pycolmap.BundleAdjustmentConfig,
) -> pycolmap.BundleAdjustmentSummary:
    """Build the default bundle adjuster and return the result of solving it.

    Args:
        reconstruction: Reconstruction handed to the bundle adjuster.
        ba_options: Bundle adjustment options handed to the bundle adjuster.
        ba_config: Bundle adjustment configuration handed to the adjuster.

    Returns:
        The value returned by the adjuster's ``solve()`` call.
    """
    adjuster = pycolmap.create_default_bundle_adjuster(
        ba_options, ba_config, reconstruction
    )
    # This is the hook for customization. To modify the existing Ceres problem
    # or solver options, replace the final return statement with:
    #
    # import pyceres  # The minimal bindings in pycolmap aren't sufficient.
    # adjuster = pycolmap.create_default_ceres_bundle_adjuster(
    #     ba_options, ba_config, reconstruction
    # )
    # solver_options = ba_options.ceres.create_solver_options(
    #     ba_config, adjuster.problem
    # )
    # summary = pyceres.SolverSummary()
    # pyceres.solve(solver_options, adjuster.problem, summary)
    # return summary
    return adjuster.solve()


def adjust_global_bundle(
    mapper: pycolmap.IncrementalMapper,
    mapper_options: pycolmap.IncrementalMapperOptions,
    ba_options: pycolmap.BundleAdjustmentOptions,
) -> bool:
    """Equivalent to mapper.adjust_global_bundle(...)

    Runs one bundle adjustment over the images of all registered frames of
    the mapper's reconstruction.

    Args:
        mapper: Incremental mapper whose reconstruction is adjusted.
        mapper_options: Read for ``fix_existing_frames``, ``constant_rigs``
            and ``constant_cameras``.
        ba_options: Base bundle adjustment options. They are deep-copied
            before any tolerance is tightened, so the caller's object is
            left untouched.

    Returns:
        ``summary.is_solution_usable()`` of the bundle adjustment summary.
    """
    reconstruction = mapper.reconstruction
    assert reconstruction is not None
    reg_frame_ids = reconstruction.reg_frame_ids()
    if len(reg_frame_ids) < 2:
        # NOTE(review): presumably logging.fatal does not return - confirm.
        logging.fatal("At least two images must be registered for global BA")
    custom_ba_options = copy.deepcopy(ba_options)

    # Use stricter convergence criteria for first registered images
    if len(reg_frame_ids) < 10:  # kMinNumRegImagesForFastBA = 10
        custom_ba_options.ceres.solver_options.function_tolerance /= 10
        custom_ba_options.ceres.solver_options.gradient_tolerance /= 10
        custom_ba_options.ceres.solver_options.parameter_tolerance /= 10
        custom_ba_options.ceres.solver_options.max_num_iterations *= 2
        custom_ba_options.ceres.solver_options.max_linear_solver_iterations = (
            200
        )

    # Avoid degeneracies in bundle adjustment
    mapper.observation_manager.filter_observations_with_negative_depth()

    # Configure bundle adjustment: add every camera image of every
    # registered frame, skipping data from other sensor types.
    ba_config = pycolmap.BundleAdjustmentConfig()
    for frame_id in reg_frame_ids:
        frame = reconstruction.frame(frame_id)
        for data_id in frame.data_ids:
            if data_id.sensor_id.type != pycolmap.SensorType.CAMERA:
                continue
            ba_config.add_image(data_id.id)

    # Fix the existing images, if option specified
    if mapper_options.fix_existing_frames:
        # Read the attribute once instead of once per registered frame.
        # NOTE(review): if the binding converts a C++ container on access,
        # every read copies the whole set - confirm against the bindings.
        existing_frame_ids = mapper.existing_frame_ids
        for frame_id in reg_frame_ids:
            if frame_id in existing_frame_ids:
                ba_config.set_constant_rig_from_world_pose(frame_id)

    for rig_id in mapper_options.constant_rigs:
        for sensor_id in reconstruction.rig(rig_id).non_ref_sensors:
            ba_config.set_constant_sensor_from_rig_pose(sensor_id)

    for camera_id in mapper_options.constant_cameras:
        ba_config.set_constant_cam_intrinsics(camera_id)

    # TODO: Add python support for prior positions
    # Fixing the gauge with two cameras leads to a more stable optimization
    # with fewer steps as compared to fixing three points.
    ba_config.fix_gauge(pycolmap.BundleAdjustmentGauge.TWO_CAMS_FROM_WORLD)

    # Run bundle adjustment
    summary = solve_bundle_adjustment(
        reconstruction, custom_ba_options, ba_config
    )
    logging.info("Global Bundle Adjustment")
    logging.info(summary.brief_report())
    return summary.is_solution_usable()


def iterative_global_refinement(
    mapper: pycolmap.IncrementalMapper,
    max_num_refinements: int,
    max_refinement_change: float,
    mapper_options: pycolmap.IncrementalMapperOptions,
    ba_options: pycolmap.BundleAdjustmentOptions,
    tri_options: pycolmap.IncrementalTriangulatorOptions,
    normalize_reconstruction: bool = True,
) -> bool:
    """Equivalent to mapper.iterative_global_refinement(...)

    Completes/merges tracks and retriangulates once, then alternates global
    bundle adjustment with track completion/merging and point filtering.

    Args:
        mapper: Incremental mapper to refine.
        max_num_refinements: Upper bound on the number of refinement rounds.
        max_refinement_change: The loop stops once the ratio of changed
            observations to observations counted before the round falls
            below this value.
        mapper_options: Passed to the global bundle adjustment and to
            ``mapper.filter_points``.
        ba_options: Passed to the global bundle adjustment.
        tri_options: Passed to the track completion and retriangulation.
        normalize_reconstruction: Whether to call
            ``reconstruction.normalize()`` after each bundle adjustment.

    Returns:
        False as soon as a global bundle adjustment reports an unusable
        solution, True otherwise.
    """
    reconstruction = mapper.reconstruction
    mapper.complete_and_merge_tracks(tri_options)
    num_retriangulated = mapper.retriangulate(tri_options)
    logging.verbose(
        1, f"=> Retriangulated observations: {num_retriangulated}"
    )
    for _ in range(max_num_refinements):
        # Count before the round; this is the denominator of the ratio.
        num_before = reconstruction.compute_num_observations()

        # mapper.adjust_global_bundle(mapper_options, ba_options)
        usable = adjust_global_bundle(mapper, mapper_options, ba_options)
        if not usable:
            return False

        if normalize_reconstruction:
            reconstruction.normalize()

        num_changed = mapper.complete_and_merge_tracks(tri_options)
        num_changed += mapper.filter_points(mapper_options)

        if num_before > 0:
            changed = num_changed / num_before
        else:
            changed = 0
        logging.verbose(1, f"=> Changed observations: {changed:.6f}")
        if changed < max_refinement_change:
            break
    return True


def adjust_local_bundle(
    mapper: pycolmap.IncrementalMapper,
    mapper_options: pycolmap.IncrementalMapperOptions,
    ba_options: pycolmap.BundleAdjustmentOptions,
    tri_options: pycolmap.IncrementalTriangulatorOptions,
    image_id: int,
    point3D_ids: set[int],
) -> pycolmap.LocalBundleAdjustmentReport:
    """Equivalent to mapper.adjust_local_bundle(...)

    Runs bundle adjustment on the frames connected to ``image_id`` (as
    returned by ``mapper.find_local_bundle``), then merges and completes the
    refined tracks and filters points. If the local bundle is empty, only
    the filtering steps run.

    Args:
        mapper: Incremental mapper whose reconstruction is adjusted.
        mapper_options: Read for ``fix_existing_frames``, ``constant_rigs``,
            ``constant_cameras`` and the ``filter_*`` thresholds.
        ba_options: Bundle adjustment options passed to the solver.
        tri_options: Triangulation options for track merging/completion.
        image_id: Image around which the local bundle is built.
        point3D_ids: Candidate 3D points; those with ``error == -1.0`` or a
            track length of at most 15 are added as variable points.

    Returns:
        A report with the numbers of adjusted, merged, completed and
        filtered observations.
    """
    reconstruction = mapper.reconstruction
    assert reconstruction is not None
    report = pycolmap.LocalBundleAdjustmentReport()

    # Find images that have most 3D points with given image in common
    local_bundle = mapper.find_local_bundle(mapper_options, image_id)
    image_ids = set()

    # Do the bundle adjustment only if there is any connected images
    if local_bundle:
        ba_config = pycolmap.BundleAdjustmentConfig()
        ba_config.fix_gauge(pycolmap.BundleAdjustmentGauge.THREE_POINTS)

        # Insert the images of all local frames: first the frame of the
        # given image, then the frames of the connected images.
        frame_ids = set()
        for bundle_image_id in (image_id, *local_bundle):
            bundle_image = reconstruction.image(bundle_image_id)
            frame_ids.add(bundle_image.frame_id)
            assert bundle_image.frame is not None
            for data_id in bundle_image.frame.image_ids:
                ba_config.add_image(data_id.id)

        # Fix the existing images, if options specified
        if mapper_options.fix_existing_frames:
            existing_frame_ids = mapper.existing_frame_ids
            for frame_id in frame_ids:
                if frame_id in existing_frame_ids:
                    ba_config.set_constant_rig_from_world_pose(frame_id)

        # Fix rig poses, if not all frames within the local bundle.
        num_frames_per_rig = collections.Counter(
            reconstruction.frame(frame_id).rig_id for frame_id in frame_ids
        )
        constant_rigs = mapper_options.constant_rigs
        num_reg_frames_per_rig = mapper.num_reg_frames_per_rig
        for rig_id, num_frames_local in num_frames_per_rig.items():
            if (
                rig_id in constant_rigs
                or num_frames_local < num_reg_frames_per_rig[rig_id]
            ):
                for sensor_id in reconstruction.rig(rig_id).non_ref_sensors:
                    ba_config.set_constant_sensor_from_rig_pose(sensor_id)

        # Fix camera intrinsics, if not all images within local bundle.
        # The loop variable must not be called `image_id`: a Python loop
        # variable outlives its loop, and the `image_id` parameter is still
        # needed for complete_image() below.
        num_images_per_camera = collections.Counter(
            reconstruction.images[config_image_id].camera_id
            for config_image_id in ba_config.images
        )
        constant_cameras = mapper_options.constant_cameras
        num_reg_images_per_camera = mapper.num_reg_images_per_camera
        for camera_id, num_images_local in num_images_per_camera.items():
            if (
                camera_id in constant_cameras
                or num_images_local < num_reg_images_per_camera[camera_id]
            ):
                ba_config.set_constant_cam_intrinsics(camera_id)

        # Make sure, we refine all new and short-track 3D points, no matter if
        # they are fully contained in the local image set or not. Do not include
        # long track 3D points as they are usually already very stable and
        # adding to them to bundle adjustment and track merging/completion would
        # slow down the local bundle adjustment significantly.
        kMaxTrackLength = 15
        variable_point3D_ids = set()
        for point3D_id in point3D_ids:
            point3D = reconstruction.point3D(point3D_id)
            if (
                point3D.error == -1.0
                or point3D.track.length() <= kMaxTrackLength
            ):
                ba_config.add_variable_point(point3D_id)
                variable_point3D_ids.add(point3D_id)

        # Adjust the local bundle
        summary = solve_bundle_adjustment(reconstruction, ba_options, ba_config)
        logging.info("Local Bundle Adjustment")
        logging.info(summary.brief_report())

        image_ids = ba_config.images
        # NOTE(review): the division by two presumably reflects two residuals
        # per observation - confirm.
        report.num_adjusted_observations = int(summary.num_residuals / 2)
        # Merge refined tracks with other existing points
        report.num_merged_observations = mapper.triangulator.merge_tracks(
            tri_options, variable_point3D_ids
        )
        # Complete tracks that may have failed to triangulate before refinement
        # of camera pose and calibration in bundle adjustment. This may avoid
        # that some points are filtered and helps for subsequent image
        # registrations.
        report.num_completed_observations = mapper.triangulator.complete_tracks(
            tri_options, variable_point3D_ids
        )
        report.num_completed_observations += mapper.triangulator.complete_image(
            tri_options, image_id
        )

    # Filtering runs even when no bundle adjustment was performed; in that
    # case `image_ids` is the empty set.
    report.num_filtered_observations = (
        mapper.observation_manager.filter_points3D_in_images(
            mapper_options.filter_max_reproj_error,
            mapper_options.filter_min_tri_angle,
            image_ids,
        )
    )
    report.num_filtered_observations += (
        mapper.observation_manager.filter_points3D(
            mapper_options.filter_max_reproj_error,
            mapper_options.filter_min_tri_angle,
            point3D_ids,
        )
    )
    return report


def iterative_local_refinement(
    mapper: pycolmap.IncrementalMapper,
    max_num_refinements: int,
    max_refinement_change: float,
    mapper_options: pycolmap.IncrementalMapperOptions,
    ba_options: pycolmap.BundleAdjustmentOptions,
    tri_options: pycolmap.IncrementalTriangulatorOptions,
    image_id: int,
) -> None:
    """Equivalent to mapper.iterative_local_refinement(...)

    Repeats the local bundle adjustment around ``image_id`` until the ratio
    of changed to adjusted observations drops below
    ``max_refinement_change`` or ``max_num_refinements`` rounds have run,
    then clears the mapper's set of modified 3D points.

    Args:
        mapper: Incremental mapper to refine.
        max_num_refinements: Upper bound on the number of refinement rounds.
        max_refinement_change: Threshold on the changed-observation ratio.
        mapper_options: Passed to the local bundle adjustment.
        ba_options: Base bundle adjustment options. They are deep-copied
            because the loss function type is overwritten after the first
            round.
        tri_options: Passed to the local bundle adjustment.
        image_id: Image around which the local bundle is built.
    """
    round_ba_options = copy.deepcopy(ba_options)
    for _ in range(max_num_refinements):
        # report = mapper.adjust_local_bundle(
        #     mapper_options,
        #     round_ba_options,
        #     tri_options,
        #     image_id,
        #     mapper.get_modified_points3D(),
        # )
        modified_point3D_ids = mapper.get_modified_points3D()
        report = adjust_local_bundle(
            mapper,
            mapper_options,
            round_ba_options,
            tri_options,
            image_id,
            modified_point3D_ids,
        )

        num_merged = report.num_merged_observations
        num_completed = report.num_completed_observations
        num_filtered = report.num_filtered_observations
        logging.verbose(1, f"=> Merged observations: {num_merged}")
        logging.verbose(1, f"=> Completed observations: {num_completed}")
        logging.verbose(1, f"=> Filtered observations: {num_filtered}")

        # With nothing adjusted, the ratio is defined as zero.
        num_adjusted = report.num_adjusted_observations
        changed = (
            (num_merged + num_completed + num_filtered) / num_adjusted
            if num_adjusted > 0
            else 0.0
        )
        logging.verbose(1, f"=> Changed observations: {changed:.6f}")
        if changed < max_refinement_change:
            break

        # Only use robust cost function for first iteration
        round_ba_options.ceres.loss_function_type = (
            pycolmap.LossFunctionType.TRIVIAL
        )
    mapper.clear_modified_points3D()
