1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
# fmt: off
"""A Module to safely parse/evaluate Mathematical Expressions"""
import ast
import math
import operator as op
from numpy import int64
# Upper bound on the magnitude of the numbers the evaluator will handle;
# capping it prevents denial-of-service (DoS) attacks via huge computations
max_value = 1e17
# Redefine mathematical operations to prevent DoS attacks
def add(a, b):
    """Return ``a + b`` after checking that neither operand is too large.

    Raises:
        ValueError: if the magnitude of ``a`` or ``b`` exceeds ``max_value``.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.add(a, b)
def sub(a, b):
    """Return ``a - b`` after checking that neither operand is too large.

    Raises:
        ValueError: if the magnitude of ``a`` or ``b`` exceeds ``max_value``.
    """
    if abs(a) > max_value or abs(b) > max_value:
        raise ValueError((a, b))
    return op.sub(a, b)
def mul(a, b):
    """Return ``a * b`` unless the product would be too large.

    The size of the product is estimated by summing the base-10 logarithms
    of the operand magnitudes, so the multiplication itself is only carried
    out once it is known to stay within ``max_value``.

    Raises:
        ValueError: if the estimated magnitude of the product exceeds
            ``max_value``.
    """
    # A zero operand always gives zero, and log10(0) is undefined anyway.
    both_nonzero = not (a == 0.0 or b == 0.0)
    if both_nonzero and (
        math.log10(abs(a)) + math.log10(abs(b)) > math.log10(max_value)
    ):
        raise ValueError((a, b))
    return op.mul(a, b)
def div(a, b):
    """Return ``a / b`` (true division) unless the quotient would be too large.

    Raises:
        ValueError: if ``b`` is zero, or if the estimated magnitude of the
            quotient exceeds ``max_value``.
    """
    if b == 0.0:
        raise ValueError((a, b))
    if not a == 0.0:
        # Estimate the size of the quotient from the difference of logarithms.
        quotient_log = math.log10(abs(a)) - math.log10(abs(b))
        if quotient_log > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)
def power(a, b):
    """Return ``a ** b`` unless the result would be too large.

    Special cases handled before the size estimate:

    * ``a == 0``: the result is ``0.0`` for positive ``b`` and ``1.0`` for
      ``b == 0``; a negative ``b`` is a division by zero and is rejected.
    * ``abs(a) == 1``: the magnitude of the result is always 1, and the
      size estimate below cannot be used because a logarithm in base 1 is
      undefined (``math.log(x, 1)`` raises ``ZeroDivisionError``).

    Raises:
        ValueError: if ``a`` is zero and ``b`` is negative, or if
            ``b / log_|a|(max_value) >= 1``, i.e. the result would reach
            ``max_value``.
    """
    if a == 0.0:
        if b < 0:
            # 0 ** negative is a division by zero; mirror div()'s behaviour
            raise ValueError((a, b))
        return 1.0 if b == 0 else 0.0
    if abs(a) == 1:
        # Result magnitude is 1, so no limit check is needed (or possible)
        return op.pow(a, b)
    if b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)
def exp(a):
    """Return ``e ** a`` unless the result would exceed ``max_value``.

    Raises:
        ValueError: if ``a`` is greater than ``log(max_value)``.
    """
    largest_exponent = math.log(max_value)
    if a > largest_exponent:
        raise ValueError(a)
    return math.exp(a)
# Whitelist of supported operators: maps each allowed AST operator node type
# to the function that implements it. Add, Sub, Mult, Div and Pow use the
# size-limited functions defined above; the remaining operators use the
# standard `operator` functions directly (so a zero divisor in Mod or
# FloorDiv raises ZeroDivisionError).
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
    ast.Mod: op.mod,
    ast.FloorDiv: op.floordiv,
}
# Whitelist of functions that may be called from an expression, keyed by the
# name used in the expression. Most entries are the `math` function of the
# same name; the entries added afterwards map to a different callable.
allowed_math_fxn = {
    fxn_name: getattr(math, fxn_name)
    for fxn_name in (
        "sin",
        "cos",
        "tan",
        "asin",
        "acos",
        "atan",
        "atan2",
        "hypot",
        "sinh",
        "cosh",
        "tanh",
        "asinh",
        "acosh",
        "atanh",
        "radians",
        "degrees",
        "sqrt",
        "log",
        "log10",
        "log2",
        "fmod",
    )
}
allowed_math_fxn.update(
    {
        "abs": math.fabs,
        "ceil": math.ceil,
        "floor": math.floor,
        "round": round,  # the builtin, not a math function
        "exp": exp,  # the size-limited version defined in this module
    }
)
def get_function(node):
    """Return the name of the function invoked by a call node.

    Both a bare call (``sin(x)``) and an attribute call (``math.sin(x)``)
    are accepted; in the attribute form only the final attribute name is
    returned.

    Raises:
        TypeError: if ``node.func`` is neither a name nor an attribute.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        return callee.attr
    if isinstance(callee, ast.Name):
        return callee.id
    raise TypeError("node.func is of the wrong type")
def limit(max_=None):
    """Return a decorator that limits the magnitude of returned values.

    The decorated function's return value is checked with ``abs()``:

    * values that do not support ``abs()`` are returned unchanged;
    * values whose magnitude exceeds ``max_`` raise ``ValueError``;
    * ``int`` results are converted to ``numpy.int64`` before being returned.

    Args:
        max_: largest allowed magnitude. ``None`` (the default) disables the
            magnitude check; the ``int`` -> ``int64`` conversion still
            applies.

    Raises (from the wrapper):
        ValueError: if the magnitude of the result exceeds ``max_``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            try:
                mag = abs(ret)
            except TypeError:
                # The result has no magnitude, so the limit is not applicable
                return ret
            # Comparing against None would raise TypeError, so treat None
            # as "no limit" instead.
            if max_ is not None and mag > max_:
                raise ValueError(ret)
            if isinstance(ret, int):
                ret = int64(ret)
            return ret
        return wrapper
    return decorator
@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate an AST node of a parsed mathematical expression.

    Supported nodes are numeric constants, binary and unary operations whose
    operator is listed in ``operators``, calls to functions listed in
    ``allowed_math_fxn`` (positional arguments only) and the names ``pi``,
    ``e`` and ``tau`` (case-insensitive). Every intermediate result passes
    through the ``limit`` decorator.

    Raises:
        TypeError: for unsupported node types, unknown names, or calls that
            use keyword arguments.
        KeyError: for an operator missing from ``operators`` or a function
            missing from ``allowed_math_fxn``.
    """
    if isinstance(node, ast.Constant) and isinstance(node.value, (float, int)):
        return node.value
    if isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    if isinstance(node, ast.Call):  # function or module.function
        if node.keywords:
            # Keyword arguments used to be dropped silently, which changed
            # the result (e.g. round(x, ndigits=2) behaved like round(x)).
            raise TypeError(
                "Keyword arguments are not supported in function calls")
        func = get_function(node)
        evaled_args = [_eval(arg) for arg in node.args]
        return allowed_math_fxn[func](*evaled_args)
    if isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        if name == "e":
            return math.e
        if name == "tau":
            return math.pi * 2.0
        raise TypeError(
            "Found a str in the expression, either param_dct/the "
            "expression has a mistake in the parameter names or "
            "attempting to parse non-mathematical code")
    raise TypeError(node)
def _substitute_params(expression, param_dct):
    """Replace whole-token occurrences of each key of param_dct by its value."""
    import re

    alternatives = []
    # Longest keys first so that e.g. "x1" is preferred over "x".
    for key in sorted(param_dct, key=len, reverse=True):
        if not key:
            continue  # an empty key would match at every position
        pattern = re.escape(key)
        # Guard only the edges made of identifier characters, so that key
        # "a" is not replaced inside "tan" and key "e" not inside "1e5".
        if re.match(r"\w", key[0]):
            pattern = r"(?<!\w)" + pattern
        if re.match(r"\w", key[-1]):
            pattern += r"(?!\w)"
        alternatives.append(pattern)
    if not alternatives:
        return expression
    # A single pass keeps substituted values from being substituted again;
    # the parentheses keep signed values intact (x**2 with x=-2 gives 4).
    return re.compile("|".join(alternatives)).sub(
        lambda match: "(" + str(param_dct[match.group(0)]) + ")", expression)


def eval_expression(expression, param_dct=None):
    """Parse and evaluate a mathematical expression.

    Variables named in ``param_dct`` are replaced by their (parenthesised)
    values before the expression is parsed with ``ast`` and evaluated.

    Args:
        expression: the expression to evaluate, as a string of at most
            10000 characters.
        param_dct: optional mapping of variable name (str) to value. Only
            whole names are replaced, not substrings of longer names.

    Returns:
        The value of the expression.

    Raises:
        TypeError: if ``expression`` is not a string.
        ValueError: if the expression is too long or contains ``()``.
    """
    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")
    expression_rep = expression.strip()
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")
    if param_dct:
        expression_rep = _substitute_params(expression_rep, param_dct)
    return _eval(ast.parse(expression_rep, mode="eval").body)
|