from __future__ import unicode_literals

import os
import re
from itertools import chain

from pip._vendor import six

from .click import unstyle
from .logging import log
from .utils import (
    UNSAFE_PACKAGES,
    comment,
    dedup,
    format_requirement,
    get_compile_command,
    key_from_ireq,
)

# Comment line emitted in the output just before a requirement that has no
# hashes while at least one other requirement does (pip's hash-checking mode
# would reject such a file).
MESSAGE_UNHASHED_PACKAGE = comment(
    "# WARNING: pip install will require the following package to be hashed."
    "\n# Consider using a hashable URL like "
    "https://github.com/jazzband/pip-tools/archive/SOMECOMMIT.zip"
)

# Header of the unsafe-packages section when hashes are present but unsafe
# packages are only listed as (unpinned) comments because --allow-unsafe was
# not given.
MESSAGE_UNSAFE_PACKAGES_UNPINNED = comment(
    "# WARNING: The following packages were not pinned, but pip requires them to be"
    "\n# pinned when the requirements file includes hashes. "
    "Consider using the --allow-unsafe flag."
)

# Header of the unsafe-packages section in all other cases.
MESSAGE_UNSAFE_PACKAGES = comment(
    "# The following packages are considered to be unsafe in a requirements file:"
)

# Logged as a warning (not written to the file, hence no comment() wrapper)
# once output is complete, if any of the WARNING lines above was emitted.
MESSAGE_UNINSTALLABLE = (
    "The generated requirements file may be rejected by pip install. "
    "See # WARNING lines for details."
)


# Matches a trailing " (line N)" suffix on a string ``comes_from`` value, so
# that "# via" annotations do not include source line numbers.
strip_comes_from_line_re = re.compile(r" \(line \d+\)$")


def _comes_from_as_string(ireq):
    """Return a printable description of what *ireq* was required by.

    When ``ireq.comes_from`` is a string, it is returned with any trailing
    `` (line N)`` suffix removed. Otherwise it is treated as another
    requirement and that requirement's key is returned.
    """
    origin = ireq.comes_from
    if not isinstance(origin, six.string_types):
        return key_from_ireq(origin)
    return strip_comes_from_line_re.sub("", origin)


class OutputWriter(object):
    """Render a resolved requirement set as a pip-compile output file.

    ``_iter_lines`` produces the output in this order:

    1. an optional header comment;
    2. pip option flags (index URLs, find-links, trusted hosts, format
       controls);
    3. the regular pinned requirements;
    4. a section for packages listed in ``UNSAFE_PACKAGES``.

    ``write`` logs every line and, unless ``dry_run`` is set, also writes it
    to ``dst_file``.
    """

    def __init__(
        self,
        src_files,
        dst_file,
        click_ctx,
        dry_run,
        emit_header,
        emit_index_url,
        emit_trusted_host,
        annotate,
        generate_hashes,
        default_index_url,
        index_urls,
        trusted_hosts,
        format_control,
        allow_unsafe,
        find_links,
        emit_find_links,
    ):
        """Store the output configuration verbatim (no validation is done).

        :param dst_file: destination file object. :meth:`write` passes it
            ``bytes``, so it must be opened in binary mode.
        :param click_ctx: click context handed to ``get_compile_command``
            to reproduce the invoking command line in the header.
        :param dry_run: when true, lines are only logged, never written.
        :param default_index_url: index URL that is left out of the emitted
            ``--index-url`` / ``--extra-index-url`` flags.
        :param allow_unsafe: when true, unsafe packages are pinned like any
            other requirement; otherwise they are only listed as comments.

        The ``emit_*`` and ``annotate`` flags toggle the matching parts of
        the output. The remaining arguments supply the data those parts
        render.
        """
        self.src_files = src_files
        self.dst_file = dst_file
        self.click_ctx = click_ctx
        self.dry_run = dry_run
        self.emit_header = emit_header
        self.emit_index_url = emit_index_url
        self.emit_trusted_host = emit_trusted_host
        self.annotate = annotate
        self.generate_hashes = generate_hashes
        self.default_index_url = default_index_url
        self.index_urls = index_urls
        self.trusted_hosts = trusted_hosts
        self.format_control = format_control
        self.allow_unsafe = allow_unsafe
        self.find_links = find_links
        self.emit_find_links = emit_find_links

    def _sort_key(self, ireq):
        """Order editable requirements first, then by lower-cased requirement."""
        # ``not editable`` is False for editables, and False sorts before True.
        return (not ireq.editable, str(ireq.req).lower())

    def write_header(self):
        """Yield the "autogenerated by pip-compile" banner if ``emit_header``.

        The command shown is the ``CUSTOM_COMPILE_COMMAND`` environment
        variable when it is set and non-empty. Otherwise it is rebuilt from
        the click context.
        """
        if self.emit_header:
            yield comment("#")
            yield comment("# This file is autogenerated by pip-compile")
            yield comment("# To update, run:")
            yield comment("#")
            compile_command = os.environ.get(
                "CUSTOM_COMPILE_COMMAND"
            ) or get_compile_command(self.click_ctx)
            yield comment("#    {}".format(compile_command))
            yield comment("#")

    def write_index_options(self):
        """Yield ``--index-url`` / ``--extra-index-url`` lines if enabled.

        The first deduplicated URL becomes ``--index-url`` and the rest
        become ``--extra-index-url``. A URL equal to ``default_index_url``
        (ignoring trailing slashes on the URL) is omitted. Positions are
        counted before skipping, so when the default index is first, every
        remaining URL is emitted as an extra index.
        """
        if self.emit_index_url:
            for index, index_url in enumerate(dedup(self.index_urls)):
                if index_url.rstrip("/") == self.default_index_url:
                    continue
                flag = "--index-url" if index == 0 else "--extra-index-url"
                yield "{} {}".format(flag, index_url)

    def write_trusted_hosts(self):
        """Yield one ``--trusted-host`` line per unique host, if enabled."""
        if self.emit_trusted_host:
            for trusted_host in dedup(self.trusted_hosts):
                yield "--trusted-host {}".format(trusted_host)

    def write_format_controls(self):
        """Yield sorted ``--no-binary`` lines, then ``--only-binary`` lines."""
        for nb in dedup(sorted(self.format_control.no_binary)):
            yield "--no-binary {}".format(nb)
        for ob in dedup(sorted(self.format_control.only_binary)):
            yield "--only-binary {}".format(ob)

    def write_find_links(self):
        """Yield one ``--find-links`` line per unique location, if enabled."""
        if self.emit_find_links:
            for find_link in dedup(self.find_links):
                yield "--find-links {}".format(find_link)

    def write_flags(self):
        """Yield all pip option lines, then a blank separator if any were emitted."""
        emitted = False
        for line in chain(
            self.write_index_options(),
            self.write_find_links(),
            self.write_trusted_hosts(),
            self.write_format_controls(),
        ):
            emitted = True
            yield line
        if emitted:
            yield ""

    def _iter_lines(self, results, unsafe_requirements=None, markers=None, hashes=None):
        """Yield every logical output line (a line may contain embedded newlines).

        :param results: resolved requirements. Entries whose name is in
            ``UNSAFE_PACKAGES`` are excluded from the regular section.
        :param unsafe_requirements: requirements for the unsafe section. When
            empty, they are taken from ``results`` by name.
        :param markers: mapping of ``key_from_ireq(ireq)`` to an environment
            marker for that requirement.
        :param hashes: mapping of requirement to its hashes.

        After the last line, ``MESSAGE_UNINSTALLABLE`` is logged as a warning
        if any WARNING line was emitted. Since this is a generator, that only
        happens once it is fully consumed.
        """
        # default values
        unsafe_requirements = unsafe_requirements or []
        markers = markers or {}
        hashes = hashes or {}

        # Check for unhashed or unpinned packages if at least one package does have
        # hashes, which will trigger pip install's --require-hashes mode.
        warn_uninstallable = False
        has_hashes = any(hashes.values())

        yielded = False

        for line in self.write_header():
            yield line
            yielded = True
        for line in self.write_flags():
            yield line
            yielded = True

        unsafe_requirements = (
            {r for r in results if r.name in UNSAFE_PACKAGES}
            if not unsafe_requirements
            else unsafe_requirements
        )
        packages = {r for r in results if r.name not in UNSAFE_PACKAGES}

        if packages:
            packages = sorted(packages, key=self._sort_key)
            for ireq in packages:
                if has_hashes and not hashes.get(ireq):
                    yield MESSAGE_UNHASHED_PACKAGE
                    warn_uninstallable = True
                line = self._format_requirement(
                    ireq, markers.get(key_from_ireq(ireq)), hashes=hashes
                )
                yield line
            yielded = True

        if unsafe_requirements:
            unsafe_requirements = sorted(unsafe_requirements, key=self._sort_key)
            yield ""
            yielded = True
            if has_hashes and not self.allow_unsafe:
                # Unsafe packages will only be listed as (unpinned) comments,
                # which pip rejects once the file includes hashes.
                yield MESSAGE_UNSAFE_PACKAGES_UNPINNED
                warn_uninstallable = True
            else:
                yield MESSAGE_UNSAFE_PACKAGES

            for ireq in unsafe_requirements:
                ireq_key = key_from_ireq(ireq)
                if not self.allow_unsafe:
                    yield comment("# {}".format(ireq_key))
                    continue
                # Pinned unsafe packages are subject to hash-checking mode just
                # like regular packages, so flag any that lack hashes.
                if has_hashes and not hashes.get(ireq):
                    yield MESSAGE_UNHASHED_PACKAGE
                    warn_uninstallable = True
                yield self._format_requirement(
                    ireq, marker=markers.get(ireq_key), hashes=hashes
                )

        # Yield even when there's no real content, so that blank files are written
        if not yielded:
            yield ""

        if warn_uninstallable:
            log.warning(MESSAGE_UNINSTALLABLE)

    def write(self, results, unsafe_requirements, markers, hashes):
        """Log every output line and, unless ``dry_run``, write it to ``dst_file``.

        Arguments are forwarded unchanged to :meth:`_iter_lines`.
        """
        for line in self._iter_lines(results, unsafe_requirements, markers, hashes):
            log.info(line)
            if not self.dry_run:
                self.dst_file.write(self._encode_for_file(unstyle(line)))

    @staticmethod
    def _encode_for_file(text):
        """Encode one logical output line as UTF-8 bytes for the file.

        Logical lines such as annotated requirements embed ``"\\n"`` between
        their physical lines. Every physical line is terminated with
        ``os.linesep``, so the file gets consistent line endings on every
        platform instead of a CRLF/LF mix on Windows.
        """
        # Normalize any CRLF first so it is not turned into CR + os.linesep.
        physical_lines = text.replace("\r\n", "\n").split("\n")
        return "".join(
            physical_line + os.linesep for physical_line in physical_lines
        ).encode("utf-8")

    def _format_requirement(self, ireq, marker=None, hashes=None):
        """Return the pinned line for *ireq*, plus a "# via" annotation if enabled.

        :param marker: environment marker to attach, or None.
        :param hashes: mapping of requirement to hashes. Only the entry for
            *ireq* is used.

        With ``annotate`` set, the origin(s) of the requirement are appended
        on following line(s), joined with ``"\\n"``.
        """
        ireq_hashes = (hashes if hashes is not None else {}).get(ireq)

        line = format_requirement(ireq, marker=marker, hashes=ireq_hashes)

        if not self.annotate:
            return line

        # Annotate what packages or reqs-ins this package is required by.
        # ``_source_ireqs``, when present, lists the requirements this one was
        # built from; collect the origin of each of them.
        required_by = set()
        if hasattr(ireq, "_source_ireqs"):
            required_by |= {
                _comes_from_as_string(src_ireq)
                for src_ireq in ireq._source_ireqs
                if src_ireq.comes_from
            }
        elif ireq.comes_from:
            required_by.add(_comes_from_as_string(ireq))
        if required_by:
            required_by = sorted(required_by)
            if len(required_by) == 1:
                source = required_by[0]
                annotation = "    # via " + source
            else:
                annotation_lines = ["    # via"]
                for source in required_by:
                    annotation_lines.append("    #   " + source)
                annotation = "\n".join(annotation_lines)
            line = "{}\n{}".format(line, comment(annotation))
        return line
