#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# graph_tool -- a general graph manipulation python module
#
# Copyright (C) 2006-2025 Tiago de Paula Peixoto <tiago@skewed.de>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .. import _prop, Graph, libcore, _get_rng, PropertyMap

import numpy as np

from .. import group_vector_property, ungroup_vector_property

from .. dl_import import dl_import
dl_import("from . import libgraph_tool_inference as libinference")

from . blockmodel import *

@entropy_state_signature
class OverlapBlockState(BlockState):
    r"""The overlapping stochastic block model state of a given graph.

    Parameters
    ----------
    g : :class:`~graph_tool.Graph`
        Graph to be modelled.
    b : :class:`~graph_tool.VertexPropertyMap` or :class:`numpy.ndarray` (optional, default: ``None``)
        Initial block labels on the vertices or half-edges. If not supplied, it
        will be randomly sampled.
        If the value passed is a vertex property map, it will be assumed to be a
        non-overlapping partition of the vertices. If it is an edge property
        map, it should contain a vector for each edge, with the block labels at
        each end point (sorted according to their vertex index, in the case of
        undirected graphs, otherwise from source to target). If the value is an
        :class:`numpy.ndarray`, it will be assumed to correspond directly to a
        partition of the list of half-edges.
    B : ``int`` (optional, default: ``None``)
        Number of blocks (or vertex groups). If not supplied it will be obtained
        from the parameter ``b``.
    recs : list of :class:`~graph_tool.EdgePropertyMap` instances (optional, default: ``[]``)
        List of real or discrete-valued edge covariates.
    rec_types : list of edge covariate types (optional, default: ``[]``)
        List of types of edge covariates. The possible types are:
        ``"real-exponential"``, ``"real-normal"``, ``"discrete-geometric"``,
        ``"discrete-poisson"`` or ``"discrete-binomial"``.
    rec_params : list of ``dict`` (optional, default: ``[]``)
        Model hyperparameters for edge covariates. This should a list of
        ``dict`` instances. See :class:`~graph_tool.inference.BlockState` for
        more details.
    clabel : :class:`~graph_tool.VertexPropertyMap` (optional, default: ``None``)
        Constraint labels on the vertices. If supplied, vertices with different
        label values will not be clustered in the same group.
    deg_corr : ``bool`` (optional, default: ``True``)
        If ``True``, the degree-corrected version of the blockmodel ensemble will
        be assumed, otherwise the traditional variant will be used.
    dense_bg : ``bool`` (optional, default: ``False``)
        If ``True`` a dense matrix is used for the block graph, otherwise a
        sparse matrix will be used.
    entropy_args: ``dict`` (optional, default: ``{}``)
        Override default arguments for :meth:`~OverlapBlockState.entropy()`
        method and releated operations.
    """

    def __init__(self, g, b=None, B=None, recs=[], rec_types=[], rec_params=[],
                 clabel=None, pclabel=None, deg_corr=True, dense_bg=False,
                 entropy_args={}, **kwargs):
        EntropyState.__init__(self, entropy_args=entropy_args)

        kwargs = kwargs.copy()

        # determine if there is a base graph, and overlapping structure
        self.base_g = kwargs.pop("base_g", None)

        # overlapping information
        node_index = kwargs.pop("node_index", None)
        node_in_degs = kwargs.pop("node_in_degs", None)
        node_out_degs = kwargs.pop("node_out_degs", None)
        half_edges = kwargs.pop("half_edges", None)
        eindex = kwargs.pop("eindex", None)

        if node_index is not None and self.base_g is None:
            raise ValueError("Must specify base graph if node_index is specified...")

        if clabel is None:
            clabel = pclabel

        if b is None:
            b = clabel

        if B is None and b is None:
            B = 1

        # create overlapping structure
        if node_index is None:
            # keep base graph
            self.base_g = g

            if len(recs) == 0:
                rec = self.base_g.new_ep("vector<double>")
            else:
                recs = [x.copy("double") for x in recs]
                rec = group_vector_property(recs)

            # substitute provided graph by its half-edge graph
            g, b, node_index, half_edges, eindex, rec = \
                                                half_edge_graph(g, b, B, rec)

            if len(recs) > 0:
                recs = ungroup_vector_property(rec, range(len(recs)))

        # create half edges set if absent
        if half_edges is None:
            half_edges = self.base_g.new_vertex_property("vector<int64_t>")
            libinference.get_nodeset_overlap(g._Graph__graph,
                                             _prop("v", g, node_index),
                                             _prop("v", self.base_g, half_edges))

        self.overlap = True
        self.node_index = node_index
        self.half_edges = half_edges
        self.eindex = eindex

        # configure the main graph and block model parameters
        self.g = g

        self.deg_corr = deg_corr

        self.is_edge_weighted = False
        self.is_vertex_weighted = False
        self.is_weighted = False

        if b is None:
            # create a random partition into B blocks.
            B = min((B, self.g.num_vertices()))
            ba = np.random.randint(0, B, self.g.num_vertices())
            ba[:B] = np.random(B)        # avoid empty blocks
            if B < self.g.num_vertices():
                np.random.shuffle(ba)
            b = g.new_vertex_property("int")
            b.fa = ba
            self.b = b
        else:
            # if a partition is available, we will incorporate it.
            # in the overlapping case
            # at this point, *b* must correspond to the partition of
            # *half-edges*
            if isinstance(b, np.ndarray):
                self.b = g.new_vertex_property("int")
                self.b.fa = b
            else:
                b = b.copy(value_type="int")
                b = g.own_property(b)
                self.b = b
            if B is None:
                B = int(self.b.fa.max()) + 1

        if self.b.fa.max() >= B:
            raise ValueError("Maximum value of b is larger or equal to B!")

        self.rec = [self.g.own_property(p) for p in recs]
        for i in range(len(self.rec)):
            if self.rec[i].value_type() != "double":
                self.rec[i] = self.rec[i].copy("double")
        self.drec = kwargs.pop("drec", None)
        if self.drec is None:
            self.drec = []
            for rec in self.rec:
                self.drec.append(self.g.new_ep("double", rec.fa ** 2))
        else:
            self.drec = [self.g.own_property(p) for p in self.drec]

        rec_types = list(rec_types)
        rec_params = list(rec_params)

        # if len(rec_params) < len(rec_types):
        #     rec_params += [{} for i in range((len(rec_types) -
        #                                       len(rec_params)))]

        if len(self.rec) > 0 and rec_types[0] != libinference.rec_type.count:
            rec_types.insert(0, libinference.rec_type.count)
            rec_params.insert(0, {})
            self.rec.insert(0, self.g.new_ep("double", 1))
            self.drec.insert(0, self.g.new_ep("double"))

        # Construct block-graph
        self.bg = get_block_graph(g, B, self.b, rec=self.rec, drec=self.drec)
        self.bg.set_fast_edge_removal()

        self.mrs = self.bg.ep["count"]
        self.wr = self.bg.vp["count"]

        self.mrp = self.bg.degree_property_map("out", weight=self.mrs)

        if g.is_directed():
            self.mrm = self.bg.degree_property_map("in", weight=self.mrs)
        else:
            self.mrm = self.mrp

        if pclabel is not None:
            if isinstance(pclabel, PropertyMap):
                self.pclabel = self.g.own_property(pclabel).copy("int")
            else:
                self.pclabel = self.g.new_vp("int")
                self.pclabel.fa = pclabel
        else:
            self.pclabel = self.g.new_vp("int")

        if clabel is not None:
            if isinstance(clabel, PropertyMap):
                self.clabel = self.g.own_property(clabel).copy("int")
            else:
                self.clabel = self.g.new_vp("int")
                self.clabel.fa = clabel
        elif self.pclabel.fa.max() > 0:
            self.clabel = self.pclabel
        else:
            self.clabel = self.g.new_vp("int")

        self.bclabel = self.get_bclabel()
        self.hclabel = self.bg.new_vp("int")

        BlockState._init_recs(self, self.rec, rec_types, rec_params)
        self.recdx = libcore.Vector_double(len(self.rec))
        self.Lrecdx = kwargs.pop("Lrecdx", None)
        if self.Lrecdx is None:
            self.Lrecdx = libcore.Vector_double(len(self.rec)+1)
            self.Lrecdx[0] = -1
        self.Lrecdx.resize(len(self.rec)+1)
        self.epsilon = kwargs.pop("epsilon", None)
        if self.epsilon is None:
            self.epsilon = libcore.Vector_double(len(self.rec))
            for i in range(len(self.rec)):
                idx = self.rec[i].a != 0
                if np.any(idx):
                    self.epsilon[i] = abs(self.rec[i].a[idx]).min() / 10

        self.dense_bg = dense_bg
        self.use_hash = not self.dense_bg

        self.bfield = self.g.new_vp("vector<double>")
        self.Bfield = Vector_double()

        self._abg = self.bg._get_any()
        self._state = libinference.make_overlap_block_state(self)

        if deg_corr:
            init_q_cache(max((self.get_E(), self.get_N())) + 1)

        self._coupled_state = None

        vweight = kwargs.pop("vweight", "unity")
        eweight = kwargs.pop("eweight", "unity")

        if vweight != "unity":
            kwargs["vweight"] = vweight
        if eweight != "unity":
            kwargs["eweight"] = eweight

        if len(kwargs) > 0:
            warnings.warn("unrecognized keyword arguments: " +
                          str(list(kwargs.keys())))

    def __repr__(self):
        return "<OverlapBlockState object with %d blocks,%s%s for graph %s, at 0x%x>" % \
            (self.get_B(), " degree corrected," if self.deg_corr else "",
             ((" with %d edge covariate%s," % (len(self.rec_types) - 1,
                                               "s" if len(self.rec_types) > 2 else ""))
              if len(self.rec_types) > 0 else ""),
             str(self.base_g), id(self))

    def __copy__(self):
        return self.copy()

    def copy(self, g=None, b=None, B=None, deg_corr=None, clabel=None,
             pclabel=None, **kwargs):
        r"""Copies the block state. The parameters override the state properties, and
         have the same meaning as in the constructor. If ``overlap=False`` an
         instance of :class:`~graph_tool.inference.BlockState` is returned. This
         is by default a shallow copy."""

        state = OverlapBlockState(self.g if g is None else g,
                                  b=self.b if b is None else b,
                                  B=(self.get_B() if b is None else None) if B is None else B,
                                  clabel=self.clabel.fa if clabel is None else clabel,
                                  pclabel=self.pclabel if pclabel is None else pclabel,
                                  deg_corr=self.deg_corr if deg_corr is None else deg_corr,
                                  recs=kwargs.pop("recs", self.rec),
                                  drec=kwargs.pop("drec", self.drec),
                                  rec_types=kwargs.pop("rec_types", self.rec_types),
                                  rec_params=kwargs.pop("rec_params",
                                                        self.rec_params),
                                  half_edges=kwargs.get("half_edges", self.half_edges),
                                  node_index=kwargs.get("node_index", self.node_index),
                                  eindex=kwargs.get("eindex", self.eindex),
                                  dense_bg=kwargs.get("dense_bg", self.dense_bg),
                                  base_g=kwargs.get("base_g", self.base_g),
                                  Lrecdx=kwargs.pop("Lrecdx", self.Lrecdx.copy()),
                                  epsilon=kwargs.pop("epsilon",
                                                     self.epsilon.copy()),
                                  **dmask(kwargs, ["half_edges", "node_index",
                                                   "eindex", "base_g", "drec",
                                                   "dense_bg"]))
        if self._coupled_state is not None:
            state._couple_state(state.get_block_state(b=state.get_bclabel(),
                                                      vweight="nonempty",
                                                      copy_bg=False,
                                                      Lrecdx=state.Lrecdx),
                                self._coupled_state[1])
        return state

    def __getstate__(self):
        state = EntropyState.__getstate__(self)
        state = dict(state,
                     g=self.g,
                     b=self.b,
                     B=self.get_B(),
                     clabel=np.array(self.clabel.fa),
                     deg_corr=self.deg_corr,
                     recs=self.rec,
                     drec=self.drec,
                     rec_types=list(self.rec_types),
                     rec_params=self.rec_params,
                     half_edges=self.half_edges,
                     node_index=self.node_index,
                     eindex=self.eindex,
                     dense_bg=self.dense_bg,
                     base_g=self.base_g)
        return state

    def __setstate__(self, state):
        self.__init__(**state)

    def get_E(self):
        r"Returns the total number of edges."
        return self.g.num_edges()

    def get_N(self):
        r"Returns the total number of nodes."
        return self.base_g.num_vertices()

    def get_B(self):
        r"Returns the total number of blocks."
        return self.bg.num_vertices()

    def get_nonempty_B(self):
        r"Returns the total number of nonempty blocks."
        return int((self.wr.a > 0).sum())

    def get_edge_blocks(self):
        r"""Returns an edge property map which contains the block labels pairs for each
        edge."""
        be = self.base_g.new_edge_property("vector<int>")
        self._state.get_be_overlap(self.base_g._Graph__graph,
                                    _prop("e", self.base_g, be))
        return be

    def get_overlap_blocks(self):
        r"""Returns the mixed membership of each vertex.

        Returns
        -------
        bv : :class:`~graph_tool.VertexPropertyMap`
           A vector-valued vertex property map containing the block memberships
           of each node.
        bc_in : :class:`~graph_tool.VertexPropertyMap`
           The labelled in-degrees of each node, i.e. how many in-edges belong
           to each group, in the same order as the ``bv`` property above.
        bc_out : :class:`~graph_tool.VertexPropertyMap`
           The labelled out-degrees of each node, i.e. how many out-edges belong
           to each group, in the same order as the ``bv`` property above.
        bc_total : :class:`~graph_tool.VertexPropertyMap`
           The labelled total degrees of each node, i.e. how many incident edges
           belong to each group, in the same order as the ``bv`` property above.

        """
        bv = self.base_g.new_vertex_property("vector<int>")
        bc_in = self.base_g.new_vertex_property("vector<int>")
        bc_out = self.base_g.new_vertex_property("vector<int>")
        bc_total = self.base_g.new_vertex_property("vector<int>")
        self._state.get_bv_overlap(self.base_g._Graph__graph,
                                    _prop("v", self.base_g, bv),
                                    _prop("v", self.base_g, bc_in),
                                    _prop("v", self.base_g, bc_out),
                                    _prop("v", self.base_g, bc_total))
        return bv, bc_in, bc_out, bc_total

    def get_nonoverlap_blocks(self):
        r"""Returns a scalar-valued vertex property map with the block mixture
        represented as a single number."""

        bv = self.get_overlap_blocks()[0]
        b = self.base_g.new_vertex_property("int")
        self._state.get_overlap_split(self.base_g._Graph__graph,
                                       _prop("v", self.base_g, bv),
                                       _prop("v", self.base_g, b))
        return b

    def get_majority_blocks(self):
        r"""Returns a scalar-valued vertex property map with the majority block
        membership of each node."""

        bv = self.get_overlap_blocks()
        bv, bc = bv[0], bv[-1]
        b = self.base_g.new_vertex_property("int")
        self._state.get_maj_overlap(self.base_g._Graph__graph,
                                     _prop("v", self.base_g, bv),
                                     _prop("v", self.base_g, bc),
                                     _prop("v", self.base_g, b))
        return b

    def get_bclabel(self, clabel=None):
        r"""Returns a :class:`~graph_tool.VertexPropertyMap` corresponding to constraint
        labels for the block graph."""

        bclabel = self.bg.new_vertex_property("int")
        reverse_map(self.b, bclabel)
        if clabel is None:
            clabel = self.clabel
        pmap(bclabel, clabel)
        return bclabel

    @copy_state_wrap
    def _entropy(self, adjacency=True, dl=True, partition_dl=True,
                 degree_dl=True, degree_dl_kind="distributed", edges_dl=True,
                 dense=False, multigraph=True, deg_entropy=True, recs=True,
                 recs_dl=True, beta_dl=1., Bfield=False, exact=True, **kwargs):
        r"""Calculate the entropy associated with the current block partition.

        Parameters
        ----------
        adjacency : ``bool`` (optional, default: ``True``)
            If ``True``, the adjacency term of the description length will be
            included.
        dl : ``bool`` (optional, default: ``True``)
            If ``True``, the description length for the parameters will be
            included.
        partition_dl : ``bool`` (optional, default: ``True``)
            If ``True``, and ``dl == True`` the partition description length
            will be included.
        degree_dl : ``bool`` (optional, default: ``True``)
            If ``True``, and ``dl == True`` the degree sequence description
            length will be included (for degree-corrected models).
        degree_dl_kind : ``str`` (optional, default: ``"distributed"``)
            This specifies the prior used for the degree sequence. It must be
            one of: ``"uniform"``, ``"distributed"`` (default) or ``"entropy"``.
        edges_dl : ``bool`` (optional, default: ``True``)
            If ``True``, and ``dl == True`` the edge matrix description length
            will be included.
        dense : ``bool`` (optional, default: ``False``)
            If ``True``, the "dense" variant of the entropy will be computed.
        multigraph : ``bool`` (optional, default: ``True``)
            If ``True``, the multigraph entropy will be used.
        deg_entropy : ``bool`` (optional, default: ``True``)
            If ``True``, the degree entropy term that is independent of the
            network partition will be included (for degree-corrected models).
        recs : ``bool`` (optional, default: ``True``)
            If ``True``, the likelihood for real or discrete-valued edge
            covariates is computed.
        recs_dl : ``bool`` (optional, default: ``True``)
            If ``True``, and ``dl == True`` the edge covariate description
            length will be included.
        beta_dl : ``double`` (optional, default: ``1.``)
            Prior inverse temperature.
        Bfield : ``bool`` (optional, default: ``False``)
            If True, the ``Bfield`` parameter passed to the construtor will be
            taken into account.
        exact : ``bool`` (optional, default: ``True``)
            If ``True``, the exact expressions will be used. Otherwise,
            Stirling's factorial approximation will be used for some terms.

        Notes
        -----

        The "entropy" of the state is minus the log-likelihood of the
        microcanonical SBM, that includes the generated graph
        :math:`\boldsymbol{A}` and the model parameters :math:`\boldsymbol{\theta}`,

        .. math::

           \mathcal{S} &= - \ln P(\boldsymbol{A},\boldsymbol{\theta}) \\
                       &= - \ln P(\boldsymbol{A}|\boldsymbol{\theta}) - \ln P(\boldsymbol{\theta}).

        This value is also called the `description length
        <https://en.wikipedia.org/wiki/Minimum_description_length>`_ of the data,
        and it corresponds to the amount of information required to describe it
        (in `nats <https://en.wikipedia.org/wiki/Nat_(unit)>`_).

        For the traditional blockmodel (``deg_corr == False``), the model
        parameters are :math:`\boldsymbol{\theta} = \{\boldsymbol{e},
        \boldsymbol{b}\}`, where :math:`\boldsymbol{e}` is the matrix of edge
        counts between blocks, and :math:`\boldsymbol{b}` is the `overlapping`
        partition of the nodes into blocks. For the degree-corrected blockmodel
        (``deg_corr == True``), we have an additional set of parameters, namely
        the `labelled` degree sequence :math:`\boldsymbol{k}`.

        The model likelihood :math:`P(\boldsymbol{A}|\theta)` is given
        analogously to the non-overlapping case, as described in
        :meth:`graph_tool.inference.BlockState.entropy`.

        If ``dl == True``, the description length :math:`\mathcal{L} = -\ln
        P(\boldsymbol{\theta})` of the model will be returned as well. The
        edge-count prior :math:`P(\boldsymbol{e})` is described in described in
        :meth:`~graph_tool.inference.BlockState.entropy`. For the
        overlapping partition :math:`P(\boldsymbol{b})`, we have

        .. math::

           -\ln P(\boldsymbol{b}) = \ln\left(\!\!{D \choose N}\!\!\right) + \sum_d \ln {\left(\!\!{{B\choose d}\choose n_d}\!\!\right)} + \ln N! - \sum_{\vec{b}}\ln n_{\vec{b}}!,

        where :math:`d \equiv |\vec{b}|_1 = \sum_rb_r` is the mixture
        size, :math:`n_d` is the number of nodes in a mixture of size :math:`d`,
        :math:`D` is the maximum value of :math:`d`, :math:`n_{\vec{b}}` is the
        number of nodes in mixture :math:`\vec{b}`.


        For the degree-corrected model we need to specify the prior
        :math:`P(\boldsymbol{k})` for the `labelled` degree sequence as well:

        .. math::

            -\ln P(\boldsymbol{k}) = \sum_r\ln\left(\!\!{m_r \choose e_r}\!\!\right) - \sum_{\vec{b}}\ln P(\boldsymbol{k}|{\vec{b}}),

        where :math:`m_r` is the number of non-empty mixtures which contain type
        :math:`r`, and :math:`P(\boldsymbol{k}|{\vec{b}})` is the likelihood of
        the labelled degree sequence inside mixture :math:`\vec{b}`. For this
        term we have three options:

        1. ``degree_dl_kind == "uniform"``

            .. math::

                P(\boldsymbol{k}|\vec{b}) = \prod_r\left(\!\!{n_{\vec{b}}\choose e^r_{\vec{b}}}\!\!\right)^{-1}.

        2. ``degree_dl_kind == "distributed"``

            .. math::

                P(\boldsymbol{k}|\vec{b}) = \prod_{\vec{b}}\frac{\prod_{\vec{k}}\eta_{\vec{k}}^{\vec{b}}!}{n_{\vec{b}}!} \prod_r q(e_{\vec{b}}^r - n_{\vec{b}}, n_{\vec{b}})

            where :math:`n^{\vec{b}}_{\vec{k}}` is the number of nodes in
            mixture :math:`\vec{b}` with labelled degree :math:`\vec{k}`, and
            :math:`q(n,m)` is the number of `partitions
            <https://en.wikipedia.org/wiki/Partition_(number_theory)>`_ of
            integer :math:`n` into at most :math:`m` parts.

        3. ``degree_dl_kind == "entropy"``

            .. math::

                P(\boldsymbol{k}|\vec{b}) = \prod_{\vec{b}}\exp\left(-n_{\vec{b}}H(\boldsymbol{k}_{\vec{b}})\right)

            where :math:`H(\boldsymbol{k}_{\vec{b}}) =
            -\sum_{\vec{k}}p_{\vec{b}}(\vec{k})\ln p_{\vec{b}}(\vec{k})` is the
            entropy of the labelled degree distribution inside mixture
            :math:`\vec{b}`.

            Note that, differently from the other two choices, this represents
            only an approximation of the description length. It is meant to be
            used only for comparison purposes, and should be avoided in practice.


        For the directed case, the above expressions are duplicated for the in-
        and out-degrees.

        """

        eargs = self._get_entropy_args(locals())

        S = self._state.entropy(eargs, kwargs.pop("propagate", False))

        kwargs.pop("test", None)
        if len(kwargs) > 0:
            raise ValueError("unrecognized keyword arguments: " +
                             str(list(kwargs.keys())))
        return S

    def _clear_egroups(self):
        self._state.clear_egroups()

    def _mcmc_sweep_dispatch(self, mcmc_state):
        dS, nattempts, nmoves = \
                    libinference.overlap_mcmc_sweep(mcmc_state, self._state,
                                                    _get_rng())
        if self.__bundled:
            ret = libinference.overlap_mcmc_bundled_sweep(mcmc_state,
                                                          self._state,
                                                          _get_rng())
            dS += ret[0]
            nattempts += ret[1]
            nmoves += ret[2]
        del self.__bundled
        return dS, nattempts, nmoves

    def _mcmc_sweep_parallel_dispatch(states, mcmc_states):
        return libinference.overlap_mcmc_sweep_parallel(mcmc_states,
                                                        [s._state for s in states],
                                                        _get_rng())

    def mcmc_sweep(self, bundled=False, **kwargs):
        r"""Perform sweeps of a Metropolis-Hastings rejection sampling MCMC to sample
        network partitions. If ``bundled == True``, the half-edges incident of
        the same node that belong to the same group are moved together. All
        remaining parameters are passed to
        :meth:`graph_tool.inference.BlockState.mcmc_sweep`."""
        self.__bundled = bundled
        return BlockState.mcmc_sweep(self, **kwargs)

    def _multiflip_mcmc_sweep_dispatch(self, mcmc_state):
        return libinference.overlap_multiflip_mcmc_sweep(mcmc_state,
                                                         self._state,
                                                         _get_rng())

    def _multiflip_mcmc_sweep_parallel_dispatch(states, mcmc_states):
        return libinference.overlap_multiflip_mcmc_sweep_parallel(mcmc_states,
                                                                  [s._state for s in states],
                                                                  _get_rng())

    def _get_bclabel(self):
        return self.bclabel

    def _multilevel_mcmc_sweep_dispatch(self, mcmc_state):
        return libinference.overlap_multilevel_mcmc_sweep(mcmc_state, self._state,
                                                          _get_rng())

    def _multilevel_mcmc_sweep_parallel_dispatch(states, mcmc_states):
        return libinference.overlap_multilevel_mcmc_sweep_parallel(mcmc_states,
                                                                   [s._state for s in states],
                                                                   _get_rng())

    def _multicanonical_sweep_dispatch(self, multicanonical_state):
        if multicanonical_state.multiflip:
            return libinference.overlap_multicanonical_sweep(multicanonical_state,
                                                             self._state,
                                                             _get_rng())
        else:
            return libinference.overlap_multicanonical_multiflip_sweep(multicanonical_state,
                                                                       self._state,
                                                                       _get_rng())

    def _exhaustive_sweep_dispatch(self, exhaustive_state, callback, hist):
        if callback is not None:
            return libinference.overlap_exhaustive_sweep(exhaustive_state,
                                                         self._state, callback)
        else:
            if hist is None:
                return libinference.overlap_exhaustive_sweep_iter(exhaustive_state,
                                                                  self._state)
            else:
                return libinference.overlap_exhaustive_dens(exhaustive_state,
                                                            self._state,
                                                            hist[0], hist[1],
                                                            hist[2])

    def _gibbs_sweep_dispatch(self, gibbs_state):
        return libinference.gibbs_overlap_sweep(gibbs_state, self._state,
                                                _get_rng())

    def _gibbs_sweep_parallel_dispatch(states, gibbs_states):
        return libinference.overlap_gibbs_sweep_parallel(gibbs_states,
                                                         [s._state for s in states],
                                                         _get_rng())

    def _merge_sweep_dispatch(self, merge_state):
        return libinference.vacate_overlap_sweep(merge_state, self._state,
                                                 _get_rng())

    def draw(self, **kwargs):
        r"""Convenience wrapper to :func:`~graph_tool.draw.graph_draw` that
        draws the state of the graph as colors on the vertices and edges."""

        bv, bc_in, bc_out, bc_total = self.get_overlap_blocks()
        if self.deg_corr:
            pie_fractions = bc_total.copy("vector<double>")
        else:
            pie_fractions = self.base_g.new_vp("vector<double>",
                                               vals=[ones(len(bv[v])) for v
                                                     in self.base_g.vertices()])

        gradient = kwargs.get("edge_gradient",
                              get_block_edge_gradient(self.base_g,
                                                      self.get_edge_blocks(),
                                                      cmap=kwargs.get("ecmap",
                                                                      None)))
        from graph_tool.draw import graph_draw
        return graph_draw(self.base_g,
                          vertex_shape=kwargs.get("vertex_shape", "pie"),
                          vertex_pie_colors=kwargs.get("vertex_pie_colors",
                                                       bv),
                          vertex_pie_fractions=kwargs.get("vertex_pie_fractions",
                                                          pie_fractions),
                          edge_gradient=gradient,
                          **dmask(kwargs, ["vertex_shape", "vertex_pie_colors",
                                           "vertex_pie_fractions",
                                           "edge_gradient"]))


def half_edge_graph(g, b=None, B=None, rec=None):
    r"""Generate a half-edge graph, where each half-edge is represented by a node,
    and an edge connects the half-edges like in the original graph."""

    E = g.num_edges()

    b_array = None
    if b is None:
        # if no partition is given, obtain a random one.
        ba = np.random.randint(0, B, 2 * E)
        ba[:B] = np.arange(B)        # avoid empty blocks
        if B < len(ba):
            np.random.shuffle(ba)
        b = ba

    if isinstance(b, np.ndarray):
        # if given an array, assume it corresponds to the *final* half-edge
        # partitions
        b_array = b
        b = g.new_vertex_property("int")

    if b.key_type() == "v":
        # If a vertex partition is given, we convert it into a
        # non-overlapping edge partition
        be = g.new_edge_property("vector<int>")
        b = b.copy("int")
        libinference.get_be_from_b_overlap(g._Graph__graph,
                                           _prop("e", g, be),
                                           _prop("v", g, b))
        b = be
    else:
        # If an half-edge partition is provided, we incorporate it
        b = b.copy(value_type="vector<int32_t>")

    if B is None:
        if b_array is None:
            bs, bt = ungroup_vector_property(b, [0, 1])
            B = int(max(bs.fa.max(), bt.fa.max())) + 1
        else:
            B = b_array.max() + 1

    bs, bt = ungroup_vector_property(b, [0, 1])

    if bs.fa.max() >= B or bt.fa.max() >= B or (b_array is not None and b_array.max() >= B):
        raise ValueError("Maximum value of b is larger or equal to B!")

    eg = Graph(directed=g.is_directed())
    node_index = eg.new_vertex_property("int64_t")
    half_edges = g.new_vertex_property("vector<int64_t>")
    be = eg.new_vertex_property("int")
    eindex = eg.new_edge_property("int64_t")
    erec = eg.new_edge_property("vector<double>")

    if rec is None:
        rec_ = g.new_edge_property("vector<double>")
    else:
        rec_ = g.own_property(rec)

    # create half-edge graph
    libinference.get_eg_overlap(g._Graph__graph,
                                eg._Graph__graph,
                                _prop("e", g, b),
                                _prop("v", eg, be),
                                _prop("v", eg, node_index),
                                _prop("v", g, half_edges),
                                _prop("e", eg, eindex),
                                _prop("e", g, rec_),
                                _prop("e", eg, erec))

    if b_array is not None:
        be.a = b_array

    if rec is None:
        erec = None

    return eg, be, node_index, half_edges, eindex, erec

def augmented_graph(g, b, node_index, eweight=None):
    r"""Generates an augmented graph from the half-edge graph ``g`` partitioned
    according to ``b``, where each half-edge belonging to a different group
    inside each node forms a new node."""

    node_map = g.new_vertex_property("int")
    br_b = libcore.Vector_int32_t()
    br_ni = libcore.Vector_int32_t()
    libinference.get_augmented_overlap(g._Graph__graph,
                                       _prop("v", g, b),
                                       _prop("v", g, node_index),
                                       _prop("v", g, node_map),
                                       br_b, br_ni)


    au, vcount, ecount = condensation_graph(g, node_map,
                                            eweight=eweight)[:3]
    anidx = au.vertex_index.copy("int")
    libinference.vector_map(anidx.a, br_ni.a)

    ab = au.vertex_index.copy("int")
    libinference.vector_map(ab.a, br_b.a)

    return au, ab, anidx, ecount, node_map

def get_block_edge_gradient(g, be, cmap=None):
    r"""Get edge gradients corresponding to the block membership at the endpoints of
    the edges given by the ``be`` edge property map.

    Parameters
    ----------
    g : :class:`~graph_tool.Graph`
        The graph.
    be : :class:`~graph_tool.EdgePropertyMap`
        Vector-valued edge property map with the block membership at each
        endpoint.
    cmap : :class:`matplotlib.colors.Colormap` (optional, default: ``default_cm``)
        Color map used to construct the gradient.

    Returns
    -------
    cp : :class:`~graph_tool.EdgePropertyMap`
       A vector-valued edge property map containing a color gradient.
    """

    if cmap is None:
        from .. draw import default_cm
        cmap = default_cm

    cp = g.new_edge_property("vector<double>")
    rg = [np.inf, -np.inf]
    for e in g.edges():
        s, t = be[e]
        rg[0] = min((s, rg[0]))
        rg[0] = min((t, rg[0]))
        rg[1] = max((s, rg[1]))
        rg[1] = max((t, rg[1]))

    for e in g.edges():
        if int(e.source()) < int(e.target()) or g.is_directed():
            s, t = be[e]
        else:
            t, s = be[e]
        cs = cmap((s - rg[0]) / max((rg[1] - rg[0], 1)))
        ct = cmap((t - rg[0]) / max((rg[1] - rg[0], 1)))
        cp[e] = [0] + list(cs) + [1] + list(ct)
    return cp
