File: eval.py

package info (click to toggle)
python-xarray 2026.01.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 13,676 kB
  • sloc: python: 120,278; makefile: 269
file content (138 lines) | stat: -rw-r--r-- 4,790 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
Expression evaluation for Dataset.eval().

This module provides AST-based expression evaluation to support N-dimensional
arrays (N > 2), which pd.eval() doesn't support. See GitHub issue #11062.

We retain logical operator transformation ('and'/'or'/'not' to '&'/'|'/'~',
and chained comparisons) for consistency with query(), which still uses
pd.eval(). We don't migrate query() to this implementation because:
- query() typically works fine (expressions usually compare 1D coordinates)
- pd.eval() with numexpr is faster and well-tested for query's use case
"""

from __future__ import annotations

import ast
import builtins
from typing import Any

# Whitelisted builtins exposed to eval expressions.
# Expressions run with an empty __builtins__ dict, so the handful of
# builtins we want to keep available are re-exported here explicitly.
# Every entry is looked up on the ``builtins`` module so that no name in
# this module can accidentally shadow it.
EVAL_BUILTINS: dict[str, Any] = {
    # --- numeric helpers and reductions ---
    "abs": builtins.abs,
    "min": builtins.min,
    "max": builtins.max,
    "round": builtins.round,
    "len": builtins.len,
    "sum": builtins.sum,
    "pow": builtins.pow,
    "any": builtins.any,
    "all": builtins.all,
    # --- constructors for basic types ---
    "int": builtins.int,
    "float": builtins.float,
    "bool": builtins.bool,
    "str": builtins.str,
    "list": builtins.list,
    "tuple": builtins.tuple,
    "dict": builtins.dict,
    "set": builtins.set,
    "slice": builtins.slice,
    # --- iteration utilities ---
    "range": builtins.range,
    "zip": builtins.zip,
    "enumerate": builtins.enumerate,
    "map": builtins.map,
    "filter": builtins.filter,
}


class LogicalOperatorTransformer(ast.NodeTransformer):
    """Transform operators for consistency with query().

    query() uses pd.eval() which transforms these operators automatically.
    We replicate that behavior here so syntax that works in query() also
    works in eval().

    Transformations:
    1. 'and'/'or'/'not' -> '&'/'|'/'~'
    2. 'a < b < c' -> '(a < b) & (b < c)'

    These constructs fail on arrays in standard Python because they call
    __bool__(), which is ambiguous for multi-element arrays.
    """

    def visit_BoolOp(self, node: ast.BoolOp) -> ast.AST:
        # Transform: a and b -> a & b, a or b -> a | b
        self.generic_visit(node)
        op: ast.BitAnd | ast.BitOr
        if isinstance(node.op, ast.And):
            op = ast.BitAnd()
        elif isinstance(node.op, ast.Or):
            op = ast.BitOr()
        else:
            return node

        # BoolOp can have multiple values: a and b and c
        # Transform to chained BinOp: (a & b) & c
        result = node.values[0]
        for value in node.values[1:]:
            result = ast.BinOp(left=result, op=op, right=value)
        return ast.fix_missing_locations(result)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.AST:
        # Transform: not a -> ~a
        self.generic_visit(node)
        if isinstance(node.op, ast.Not):
            return ast.fix_missing_locations(
                ast.UnaryOp(op=ast.Invert(), operand=node.operand)
            )
        return node

    def visit_Compare(self, node: ast.Compare) -> ast.AST:
        # Transform chained comparisons: 1 < x < 5 -> (1 < x) & (x < 5)
        # Python's chained comparisons use short-circuit evaluation at runtime,
        # which calls __bool__ on intermediate results. This fails for arrays.
        # We transform to bitwise AND which works element-wise.
        self.generic_visit(node)

        if len(node.ops) == 1:
            # Simple comparison, no transformation needed
            return node

        # Build individual comparisons and chain with BitAnd
        # For: a < b < c < d
        # We need: (a < b) & (b < c) & (c < d)
        comparisons = []
        left = node.left
        for op, comparator in zip(node.ops, node.comparators, strict=True):
            comp = ast.Compare(left=left, ops=[op], comparators=[comparator])
            comparisons.append(comp)
            left = comparator

        # Chain with BitAnd: (a < b) & (b < c) & ...
        result: ast.Compare | ast.BinOp = comparisons[0]
        for comp in comparisons[1:]:
            result = ast.BinOp(left=result, op=ast.BitAnd(), right=comp)
        return ast.fix_missing_locations(result)


def validate_expression(tree: ast.AST) -> None:
    """Reject expression ASTs that use constructs eval() does not support.

    Walks every node of ``tree`` (in ``ast.walk`` order) and raises on the
    first offending node. The restrictions are meant to mirror what
    pd.eval() refuses, for consistency.

    Raises
    ------
    ValueError
        If the tree contains a lambda expression, or an attribute access
        whose attribute name starts with an underscore.
    """
    lambda_message = (
        "Lambda expressions are not allowed in eval(). "
        "Use direct operations on data variables instead."
    )
    for candidate in ast.walk(tree):
        # pd.eval: "Only named functions are supported"
        if isinstance(candidate, ast.Lambda):
            raise ValueError(lambda_message)

        if not isinstance(candidate, ast.Attribute):
            continue

        # Private and dunder attributes are both off limits.
        attr_name = candidate.attr
        if attr_name.startswith("_"):
            raise ValueError(
                f"Access to private attributes is not allowed: '{attr_name}'"
            )