# This file is part of beets.
# Copyright 2023, Max Rumpf.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Plugin to rewrite fields based on a given query."""

import re
import shlex
from collections import defaultdict

import confuse

from beets.dbcore import AndQuery, query_from_strings
from beets.dbcore.types import MULTI_VALUE_DSV
from beets.library import Album, Item
from beets.plugins import BeetsPlugin
from beets.ui import UserError


def rewriter(field, simple_rules, advanced_rules):
    """Return a template-field getter that rewrites ``field``.

    The returned callable takes an item (anything exposing a
    ``_values_fixed`` mapping) and yields a replacement value for
    ``field``:

    * ``simple_rules`` is a list of ``(compiled_regex, replacement)`` pairs.
      They are tried first, in order, against the lowercased field value.
      The first regex whose ``match`` succeeds wins.
    * ``advanced_rules`` is a list of ``(query, replacement)`` pairs.
      They are tried next, in order. The first query whose
      ``match(item)`` is true wins.
    * If nothing matches, the stored field value is returned unchanged.
    """

    def fieldfunc(item):
        original = item._values_fixed[field]

        if simple_rules:
            # Lowercase only when a simple rule actually needs it. The stored
            # value is not necessarily a string (presumably a list for
            # multi-valued fields), and those may carry advanced rules only.
            lowered = original.lower()
            for regex, substitute in simple_rules:
                if regex.match(lowered):
                    return substitute

        for query, substitute in advanced_rules:
            if query.match(item):
                return substitute

        # No rule fired: keep the stored value.
        return original

    return fieldfunc


class AdvancedRewritePlugin(BeetsPlugin):
    """Plugin to rewrite fields based on a given query.

    Rewrite rules are read from this plugin's configuration in ``loaded``.
    A rewriter getter (see ``rewriter``) is registered as a template field
    under the name of every field that has at least one rule.
    """

    # Rewrites of these item fields are also applied to the corresponding
    # album field (e.g. rewriting ``artist`` also rewrites ``albumartist``).
    _CORRESPONDING_ALBUM_FIELDS = {
        "artist": "albumartist",
        "artists": "albumartists",
        "artist_sort": "albumartist_sort",
        "artists_sort": "albumartists_sort",
    }

    class _FieldRules:
        """Rewrite rules collected for a single field."""

        def __init__(self):
            # (compiled regex, replacement) pairs; tried before advanced rules.
            self.simple = []
            # (query, replacement) pairs.
            self.advanced = []

    def __init__(self):
        """Register ``loaded`` to run on the ``pluginload`` event.

        The rewrite rules themselves are parsed in ``loaded``, not here.
        """
        super().__init__()
        self.register_listener("pluginload", self.loaded)

    def loaded(self):
        """Parse the configured rewrite rules and install the rewriters.

        The configuration is a list of rules. Each rule is one of:

        * a *simple* rule: a single-entry mapping
          ``"<field> <regex>": <replacement>``;
        * an *advanced* rule: a mapping with a ``match`` query string and a
          ``replacements`` mapping from field name to a string or a list of
          strings.

        Rules are grouped by field. Each group is turned into a getter by
        ``rewriter``, which is registered in ``template_fields``. It is also
        registered in ``album_template_fields`` when the field exists on
        albums.

        Raises:
            UserError: if a rule is malformed, names an unknown field, or
                gives a replacement whose shape does not fit the field.
        """
        template = confuse.Sequence(
            confuse.OneOf(
                [
                    confuse.MappingValues(str),
                    {
                        "match": str,
                        "replacements": confuse.MappingValues(
                            confuse.OneOf([str, confuse.Sequence(str)]),
                        ),
                    },
                ]
            )
        )

        # Gather all the rewrite rules, grouped by target field name.
        rules = defaultdict(self._FieldRules)
        for rule in self.config.get(template):
            if "match" not in rule:
                self._add_simple_rule(rule, rules)
            else:
                self._add_advanced_rule(rule, rules)

        # Replace each template field with the new rewriter function.
        for fieldname, fieldrules in rules.items():
            getter = rewriter(fieldname, fieldrules.simple, fieldrules.advanced)
            self.template_fields[fieldname] = getter
            if fieldname in Album._fields:
                self.album_template_fields[fieldname] = getter

    def _target_fields(self, fieldname):
        """Return ``fieldname`` plus its corresponding album field, if any."""
        album_fieldname = self._CORRESPONDING_ALBUM_FIELDS.get(fieldname)
        if album_fieldname is None:
            return [fieldname]
        return [fieldname, album_fieldname]

    def _add_simple_rule(self, rule, rules):
        """Validate a simple rewrite rule and add it to ``rules``.

        ``rule`` is a mapping of ``"<field> <regex>"`` to a replacement
        string. ``rules`` maps field names to ``_FieldRules``.

        Raises:
            UserError: if the mapping has more than one entry, the key lacks
                a pattern, or the field is unknown.
        """
        if len(rule) != 1:
            raise UserError(
                "Simple rewrites must have only one rule, "
                "but found multiple entries. "
                "Did you forget to prepend a dash (-)?"
            )
        key, value = next(iter(rule.items()))
        try:
            fieldname, pattern = key.split(None, 1)
        except ValueError:
            raise UserError(
                f"Invalid simple rewrite specification {key}"
            ) from None
        if fieldname not in Item._fields:
            raise UserError(f"invalid field name {fieldname} in rewriter")
        # NOTE(review): ``rewriter`` lowercases the stored value for simple
        # rules. For multi-valued fields (e.g. ``artists``) that value is
        # presumably a list, so such rules look like they cannot work.
        # Confirm, and consider rejecting them here.

        # Pass values as arguments: the beets logger applies str.format to
        # the message, so braces inside a user regex (e.g. ``a{2}``) must not
        # end up in the format string itself.
        self._log.debug(
            "adding simple rewrite '{}' → '{}' for field {}",
            pattern,
            value,
            fieldname,
        )
        # Match case-insensitively with the regex flag instead of lowercasing
        # the pattern. Lowercasing changes the meaning of escapes such as
        # \D, \S, \W and \B (they become \d, \s, \w and \b), and turns \Z
        # into \z, which older Pythons reject as a bad escape. The rewriter
        # matches against the lowercased value, which an IGNORECASE pattern
        # still matches.
        regex = re.compile(pattern, re.IGNORECASE)
        for target in self._target_fields(fieldname):
            rules[target].simple.append((regex, value))

    def _add_advanced_rule(self, rule, rules):
        """Validate an advanced rewrite rule and add it to ``rules``.

        ``rule`` has a ``match`` query string (split with shell-like quoting)
        and a ``replacements`` mapping of field names to replacement values.
        A string replacement for a multi-valued field is wrapped in a
        one-element list. ``rules`` maps field names to ``_FieldRules``.

        Raises:
            UserError: if there are no replacements, a field is unknown, or
                a replacement's shape does not fit its field.
        """
        replacements = rule["replacements"]
        if not replacements:
            raise UserError(
                "Advanced rewrites must have at least one replacement"
            )
        query = query_from_strings(
            AndQuery,
            Item,
            prefixes={},
            query_parts=shlex.split(rule["match"]),
        )
        for fieldname, replacement in replacements.items():
            if fieldname not in Item._fields:
                raise UserError(f"Invalid field name {fieldname} in rewriter")
            self._log.debug(
                "adding advanced rewrite to '{}' for field {}",
                replacement,
                fieldname,
            )
            is_multi_valued = Item._fields[fieldname] is MULTI_VALUE_DSV
            if isinstance(replacement, list):
                if not is_multi_valued:
                    raise UserError(
                        f"Field {fieldname} is not a multi-valued field "
                        f"but a list was given: {', '.join(replacement)}"
                    )
            elif isinstance(replacement, str):
                if is_multi_valued:
                    replacement = [replacement]
            else:
                # Defensive: the confuse template only admits str or list.
                raise UserError(
                    f"Invalid type of replacement {replacement} "
                    f"for field {fieldname}"
                )

            for target in self._target_fields(fieldname):
                rules[target].advanced.append((query, replacement))
