"""
Columnwise Column Generation Functions

Authors: Antony Phillips,  Dr Stuart Mitchell  2008
"""

# Import PuLP modeler functions
from pulp import *


class Pattern:
    """
    Information on a specific pattern in the SpongeRoll Problem

    A pattern records how many pieces of each length option are cut from
    one roll. Every instance created bumps the class-wide ``numPatterns``
    counter.
    """

    # Cost term of one pattern in the objective coefficient
    cost = 1
    # Value credited per unit of trim (left-over length) in the objective
    trimValue = 0.04
    # Length of one uncut roll
    totalRollLength = 20
    # Available piece lengths, kept as strings because they double as dict keys
    lenOpts = ["5", "7", "9"]
    # Number of Pattern instances created so far
    numPatterns = 0

    def __init__(self, name, lengths=None):
        """
        Create a pattern.

        name    -- label of the pattern; returned by ``str(pattern)``
        lengths -- number of pieces cut per entry of ``lenOpts``, in the same
                   order. ``None`` (the default) means no pieces of any
                   length.
        """
        self.name = name
        if lengths is None:
            # zip() cannot iterate None, so give the default a usable meaning:
            # an empty pattern with zero pieces of every length.
            lengths = [0] * len(self.lenOpts)
        self.lengthsdict = dict(zip(self.lenOpts, lengths))
        Pattern.numPatterns += 1

    def __str__(self):
        return self.name

    def trim(self):
        """Return the roll length left over after cutting this pattern."""
        used = sum(
            int(length) * int(count) for length, count in self.lengthsdict.items()
        )
        return Pattern.totalRollLength - used


def createMaster():
    """
    Build the master LP of the SpongeRoll problem in columnwise form.

    Returns the tuple ``(prob, obj, constraints)``:
      prob        -- the LpProblem "MasterSpongeRollProblem" (minimisation)
      obj         -- the LpConstraintVar used as the objective row
      constraints -- dict mapping each entry of Pattern.lenOpts to its
                     ">= demand" LpConstraintVar row
    """
    # Length: [Demand, SalePrice]
    demandAndPrice = {
        "5": [150, 0.25],
        "7": [200, 0.33],
        "9": [300, 0.40],
    }
    rollDemand, surplusPrice = splitDict(demandAndPrice)

    prob = LpProblem("MasterSpongeRollProblem", LpMinimize)

    # The objective is held as a row object so that columns can be attached
    # to it later.
    obj = LpConstraintVar("Obj")
    prob.setObjective(obj)

    # One demand row per length option, registered with the problem
    constraints = {}
    for length in Pattern.lenOpts:
        demandRow = LpConstraintVar(
            "Min" + str(length), LpConstraintGE, rollDemand[length]
        )
        constraints[length] = demandRow
        prob += demandRow

    # One surplus variable per length option. Each is created only for its
    # column expression `e`, which ties it to the objective row and to its
    # demand row; the variable objects themselves are not needed afterwards.
    for length in Pattern.lenOpts:
        LpVariable(
            "Surplus " + length,
            lowBound=0,
            upBound=None,
            cat=LpContinuous,
            e=-surplusPrice[length] * obj - constraints[length],
        )

    return prob, obj, constraints


def addPatterns(obj, constraints, newPatterns):
    """
    Validate new cutting patterns and add one column per pattern to the
    master problem.

    obj         -- the objective row (LpConstraintVar) of the master problem
    constraints -- dict of demand rows keyed by the entries of Pattern.lenOpts
    newPatterns -- iterable of patterns; each pattern is a sequence giving the
                   number of pieces per entry of Pattern.lenOpts, in order

    Raises ValueError if the pieces of a pattern add up to more than
    Pattern.totalRollLength. Validation happens before any variable is
    created, so an invalid pattern adds no columns at all.

    Returns None; the new columns are attached to `obj` and `constraints`.
    """
    # Pattern objects created in this function call
    Patterns = []
    for counts in newPatterns:

        # The new patterns are checked to see that their length does not
        # exceed the total roll length
        lsum = sum(
            count * int(length) for count, length in zip(counts, Pattern.lenOpts)
        )
        if lsum > Pattern.totalRollLength:
            # Must be a real exception class: `raise "text"` is itself a
            # TypeError in Python 3 and would hide this message.
            raise ValueError("Length Options too large for Roll")

        # The number of rolls of each length in each new pattern is printed
        print("P" + str(Pattern.numPatterns), "=", counts)

        # The patterns are instantiated as Pattern objects
        Patterns.append(Pattern("P" + str(Pattern.numPatterns), counts))

    # The pattern variables are created. Each is built only for its column
    # expression, which ties it to the objective row and the demand rows.
    for patt in Patterns:
        LpVariable(
            "Pattern " + patt.name,
            0,
            None,
            LpContinuous,
            (patt.cost - Pattern.trimValue * patt.trim()) * obj
            + lpSum(
                [constraints[l] * patt.lengthsdict[l] for l in Pattern.lenOpts]
            ),
        )


def masterSolve(prob, relax=True):
    """
    Solve the master problem with CBC and return its results.

    prob  -- the master LpProblem
    relax -- when True (default) the problem is solved as it stands; when
             False every variable of `prob` is first switched to integer

    Returns:
      relax=True  -- dict mapping each entry of Pattern.lenOpts to the dual
                     value (``.pi``) of its "Min<length>" constraint
      relax=False -- tuple ``(objective value, {variable name: value})``
    """
    # Unrelaxes the Integer Constraint
    if not relax:
        for v in prob.variables():
            v.cat = LpInteger

    # The problem is solved and rounded
    prob.solve(PULP_CBC_CMD())
    prob.roundSolution()

    if relax:
        # Constraint names are derived from lenOpts, the same way they are
        # built when the master problem is created, so the lookup stays valid
        # if the length options change.
        return {
            length: prob.constraints["Min" + str(length)].pi
            for length in Pattern.lenOpts
        }

    # A dictionary of variable values and the objective value are returned
    varsdict = {v.name: v.varValue for v in prob.variables()}
    return value(prob.objective), varsdict


def subSolve(duals):
    """
    Solve the pattern-generating subproblem for the given dual values.

    duals -- dict mapping each entry of Pattern.lenOpts to the dual value of
             its demand constraint in the master problem

    Returns a list holding at most one new pattern. A pattern is a list of
    integer piece counts, one per entry of Pattern.lenOpts in order. The list
    is empty when the subproblem objective (the reduced cost of the best
    pattern) is not below -1e-5, i.e. no pattern would improve the master LP.
    """
    # The variable 'prob' is created
    prob = LpProblem("SubProb", LpMinimize)

    # The problem variables are created: pieces cut per length, and the trim
    rollVars = LpVariable.dicts("Roll Length", Pattern.lenOpts, 0, None, LpInteger)
    trim = LpVariable("Trim", 0, None, LpInteger)

    # The objective function is entered: the reduced cost of a new pattern
    prob += (Pattern.cost - Pattern.trimValue * trim) - lpSum(
        [rollVars[i] * duals[i] for i in Pattern.lenOpts]
    ), "Objective"

    # The conservation of length constraint is entered
    prob += (
        lpSum([rollVars[i] * int(i) for i in Pattern.lenOpts]) + trim
        == Pattern.totalRollLength,
        "lengthEquate",
    )

    # The problem is solved
    prob.solve()

    # The variable values are rounded
    prob.roundSolution()

    newPatterns = []
    # Check if there are more patterns which would reduce the master LP
    # objective function further
    if value(prob.objective) < -(10 ** -5):
        # Read the counts straight from the variable dict, keyed by lenOpts,
        # rather than by hard-coded variable names, so the result follows
        # whatever length options are configured.
        newPatterns.append([int(rollVars[i].varValue) for i in Pattern.lenOpts])

    return newPatterns
