1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
|
import sympy
import math
class Expression:
    """Base class of every term the solver works with.

    The ``+`` and ``*`` operators (from either side) build ``Sum`` and
    ``Product`` nodes out of their two operands.
    """

    def __init__(self):
        """Create an expression node; the base class keeps no state."""

    def __add__(self, other):
        """Return a ``Sum`` node for ``self + other``."""
        operands = [self, other]
        return Sum(operands)

    def __radd__(self, other):
        """Return a ``Sum`` node for ``other + self``."""
        operands = [other, self]
        return Sum(operands)

    def __mul__(self, other):
        """Return a ``Product`` node for ``self * other``."""
        factors = [self, other]
        return Product(factors)

    def __rmul__(self, other):
        """Return a ``Product`` node for ``other * self``."""
        factors = [other, self]
        return Product(factors)
class Variable(Expression):
    """Leaf expression that stands for an unknown value."""

    def __init__(self, id, name, integer=True):
        """Store the identifier, display name and integer flag.

        Args:
            id: Key used for equality and hashing, and as the sympy symbol name.
            name: Text returned by ``str()`` for this variable.
            integer: Forwarded to ``sympy.Symbol`` as its ``integer`` assumption.
        """
        super().__init__()
        self.id, self.name, self.integer = id, name, integer

    def __iter__(self):
        """Yield this node only; a variable has no children."""
        yield self

    def __eq__(self, other):
        """Return whether ``other`` is a ``Variable`` with the same ``id``."""
        if not isinstance(other, Variable):
            return False
        return self.id == other.id

    def __hash__(self):
        """Hash by ``id`` so that equal variables hash alike."""
        return hash(self.id)

    def __str__(self):
        """Return the display name."""
        return f"{self.name}"

    def sympy(self):
        """Return the sympy symbol representing this variable."""
        symbol = sympy.Symbol(self.id, integer=self.integer)
        return symbol
class Constant(Expression):
    """Leaf expression wrapping a fixed value."""

    def __init__(self, value):
        """Store ``value`` as the constant's payload."""
        super().__init__()
        self.value = value

    def __iter__(self):
        """Yield this node only; a constant has no children."""
        yield self

    def __eq__(self, other):
        """Return whether ``other`` is a ``Constant`` holding an equal value."""
        if not isinstance(other, Constant):
            return False
        return self.value == other.value

    def __hash__(self):
        """Hash by the wrapped value so that equal constants hash alike."""
        return hash(self.value)

    def __str__(self):
        """Return the wrapped value as text."""
        return str(self.value)

    def sympy(self):
        """Return the raw value; sympy is given the value directly."""
        return self.value
class Sum(Expression):
    """Sum of child expressions.

    Children given as plain ``int`` are wrapped in ``Constant`` via ``to_term``.
    """

    @staticmethod
    def maybe(children):
        """Build the simplest expression equal to the sum of ``children``.

        Children are normalized with ``to_term`` first, exactly as the
        constructor does, so plain ints are wrapped before the checks below
        and therefore take part in constant folding.

        Args:
            children: Iterable of ``Expression`` or ``int`` operands.

        Returns:
            ``Constant(0)`` for no children, the (normalized) child itself for
            a single child, a folded ``Constant`` when every child is a
            ``Constant``, and a ``Sum`` node otherwise.

        Raises:
            TypeError: From ``to_term`` if a child is neither ``int`` nor
                ``Expression``.
        """
        terms = [to_term(c) for c in children]
        if not terms:
            return Constant(0)
        if len(terms) == 1:
            return terms[0]
        if all(isinstance(t, Constant) for t in terms):
            return Constant(sum(t.value for t in terms))
        return Sum(terms)

    def __init__(self, children):
        """Create a sum node over ``children`` (ints are wrapped)."""
        Expression.__init__(self)
        self.children = [to_term(c) for c in children]

    def __iter__(self):
        """Yield this node, then every node of each child subtree in order."""
        yield self
        for child in self.children:
            yield from child

    def __eq__(self, other):
        """Return whether ``other`` is a ``Sum`` with equal children in the same order."""
        return isinstance(other, Sum) and self.children == other.children

    def __hash__(self):
        """Hash by the ordered children, consistent with ``__eq__``."""
        return hash(tuple(self.children))

    def __str__(self):
        """Return the children joined by ``" + "``."""
        return " + ".join(str(c) for c in self.children)

    def sympy(self):
        """Return the sum of the children's sympy representations."""
        return sum([c.sympy() for c in self.children])
class Product(Expression):
    """Product of child expressions.

    Children given as plain ``int`` are wrapped in ``Constant`` via ``to_term``.
    """

    @staticmethod
    def maybe(children):
        """Build the simplest expression equal to the product of ``children``.

        Children are normalized with ``to_term`` first, exactly as the
        constructor does, so plain ints are wrapped before the checks below
        and therefore take part in constant folding.

        Args:
            children: Iterable of ``Expression`` or ``int`` factors.

        Returns:
            ``Constant(1)`` for no children, the (normalized) child itself for
            a single child, a folded ``Constant`` when every child is a
            ``Constant``, and a ``Product`` node otherwise.

        Raises:
            TypeError: From ``to_term`` if a child is neither ``int`` nor
                ``Expression``.
        """
        factors = [to_term(c) for c in children]
        if not factors:
            return Constant(1)
        if len(factors) == 1:
            return factors[0]
        if all(isinstance(f, Constant) for f in factors):
            return Constant(math.prod(f.value for f in factors))
        return Product(factors)

    def __init__(self, children):
        """Create a product node over ``children`` (ints are wrapped)."""
        Expression.__init__(self)
        self.children = [to_term(c) for c in children]

    def __iter__(self):
        """Yield this node, then every node of each child subtree in order."""
        yield self
        for child in self.children:
            yield from child

    def __eq__(self, other):
        """Return whether ``other`` is a ``Product`` with equal children in the same order."""
        return isinstance(other, Product) and self.children == other.children

    def __hash__(self):
        """Hash by the ordered children, consistent with ``__eq__``."""
        return hash(tuple(self.children))

    def __str__(self):
        """Return the children joined by ``" * "``.

        ``Sum`` children are parenthesized; without that, ``(a + b) * c``
        would print as ``a + b * c``, which reads as a different expression.
        """
        return " * ".join(
            f"({c})" if isinstance(c, Sum) else str(c) for c in self.children
        )

    def sympy(self):
        """Return the product of the children's sympy representations."""
        return math.prod([c.sympy() for c in self.children])
def to_term(x):
    """Coerce ``x`` into an ``Expression``.

    Args:
        x: An ``int`` or an existing ``Expression``.

    Returns:
        ``Constant(x)`` when ``x`` is an ``int``; otherwise ``x`` unchanged.

    Raises:
        TypeError: If ``x`` is neither an ``int`` nor an ``Expression``.
    """
    if isinstance(x, int):
        return Constant(x)
    if isinstance(x, Expression):
        return x
    raise TypeError(f"Expected Expression, got {type(x)}")
class SolveException(Exception):
    """Error raised by ``solve`` when the equations cannot be solved consistently."""

    def __init__(self, message):
        """Create the exception carrying ``message`` as its only argument."""
        Exception.__init__(self, message)
def solve(equations):
    """Solve a system of equations between expressions.

    Simple facts are handled before sympy is involved: ``variable == constant``
    equations fix values directly, and ``variable == variable`` equations
    merge variables into equivalence classes. Only what remains after
    substituting those results is handed to sympy.

    Args:
        equations: Iterable of ``(lhs, rhs)`` pairs; each side is an
            ``Expression`` or an ``int`` (wrapped via ``to_term``).

    Returns:
        Dict mapping the ``id`` of each original variable whose value could
        be determined to that value. Undetermined variables are omitted.

    Raises:
        SolveException: If contradictory values are found, or sympy returns
            no solution, several solutions, or an unexpected result.
        TypeError: From ``to_term`` if a side is neither ``int`` nor
            ``Expression``.
    """
    # Normalize both sides, then drop trivially true and duplicate equations.
    pairs = [(to_term(lhs), to_term(rhs)) for lhs, rhs in equations]
    nontrivial = [pair for pair in pairs if pair[0] != pair[1]]
    equations = list(set(nontrivial))

    # Every variable occurring anywhere in the equations, keyed by id.
    variables = {}
    for pair in equations:
        for side in pair:
            for node in side:  # iterating a term walks its whole subtree
                if isinstance(node, Variable):
                    variables[node.id] = node

    # Find equivalence classes of variables to speed up sympy solver #####

    # Constant definitions: variable id -> constant value.
    constants = {}
    for lhs, rhs in equations:
        if isinstance(lhs, Constant) and isinstance(rhs, Constant):
            if lhs.value != rhs.value:
                raise SolveException(
                    f"Found contradictory values {lhs.value} != {rhs.value} in input equation"
                )
            continue
        # Orient "variable == constant" regardless of which side is which.
        if isinstance(lhs, Variable) and isinstance(rhs, Constant):
            var, const = lhs, rhs
        elif isinstance(lhs, Constant) and isinstance(rhs, Variable):
            var, const = rhs, lhs
        else:
            continue
        if constants.get(var.id, const.value) != const.value:
            raise SolveException(
                f"Found contradictory values { {constants[var.id], const.value} } for "
                f"expression '{var.name}'"
            )
        constants[var.id] = const.value

    # Equivalence classes: variable id -> shared set of equivalent ids.
    classes = {var_id: {var_id} for var_id in variables}
    for lhs, rhs in equations:
        if not (isinstance(lhs, Variable) and isinstance(rhs, Variable)):
            continue
        assert lhs.id in classes and rhs.id in classes
        target = classes[lhs.id]
        absorbed = classes[rhs.id]
        # Point every member of the right-hand class at the left-hand class.
        for member in absorbed:
            classes[member] = target
            target.add(member)

    def names_of(members):
        # Display names of all variables in one equivalence class.
        return {variables[var_id].name for var_id in members}

    # Per class: use its constant if one is known, else a fresh class variable.
    origvar_to_solvevar = {}  # original id -> Variable or Constant
    distinct_classes = {}
    for members in classes.values():
        distinct_classes[id(members)] = members  # dedupe shared sets by identity
    for members in distinct_classes.values():
        class_constants = {constants[m] for m in members if m in constants}
        if class_constants:
            if len(class_constants) != 1:
                names = names_of(members)
                if len(names) == 1:
                    raise SolveException(
                        f"Found contradictory values {class_constants} for expression "
                        f"'{next(iter(names))}'"
                    )
                raise SolveException(
                    f"Found contradictory values {class_constants} for equivalent "
                    f"expressions {names}"
                )
            representative = Constant(next(iter(class_constants)))
        else:
            representative = Variable(
                f"Class-{id(members)}",
                f"Equivalent expressions {names_of(members)}",
            )
        for member in members:
            assert member not in origvar_to_solvevar
            origvar_to_solvevar[member] = representative

    def substitute(term):
        # Rebuild ``term`` with every variable replaced by its class representative.
        if isinstance(term, Constant):
            return term
        if isinstance(term, Variable) and term.id in origvar_to_solvevar:
            return origvar_to_solvevar[term.id]
        if isinstance(term, Sum):
            return Sum.maybe([substitute(child) for child in term.children])
        if isinstance(term, Product):
            return Product.maybe([substitute(child) for child in term.children])
        raise AssertionError()

    # Apply the substitution; keep only equations that still carry information.
    remaining = []
    for original_lhs, original_rhs in equations:
        lhs = substitute(original_lhs)
        rhs = substitute(original_rhs)
        if isinstance(lhs, Constant) and isinstance(rhs, Constant):
            if lhs.value != rhs.value:
                raise SolveException(
                    f"Found contradictory values {lhs.value} != {rhs.value} "
                    "for same equivalence class"
                )
            continue
        if lhs != rhs:
            remaining.append((lhs, rhs))
    equations = remaining

    # Solve remaining equations using sympy #####
    solved = {}  # class variable id -> value
    if equations:
        sympy_equations = [sympy.Eq(lhs.sympy(), rhs.sympy()) for lhs, rhs in equations]
        trivially_true = all(eq.is_Boolean and bool(eq) for eq in sympy_equations)
        if not trivially_true:
            result = sympy.solve(sympy_equations, set=True, manual=True)
            if result == []:
                pass  # nothing determined; ``solved`` stays empty
            elif isinstance(result, tuple) and len(result) == 2:
                symbols, candidates = result
                if len(candidates) == 0:
                    raise SolveException("Sympy returned no solutions")
                if len(candidates) > 1:
                    raise SolveException("Sympy returned multiple possible solutions")
                values = next(iter(candidates))
                # Keep only symbols that resolved to a concrete number.
                solved = {
                    str(symbol): int(value)
                    for symbol, value in zip(symbols, values)
                    if value.is_number
                }
            else:
                raise SolveException("Sympy returned unexpected result")

    # Determine values for original variables in equivalence classes
    orig_solutions = {}
    for var_id, representative in origvar_to_solvevar.items():
        if isinstance(representative, Constant):
            orig_solutions[var_id] = representative.value
            continue
        if not isinstance(representative, Variable):
            raise AssertionError()
        if representative.id in solved:
            orig_solutions[var_id] = solved[representative.id]
    return orig_solutions
|