File: deep_equals.py

package info (click to toggle)
python-libcst 1.4.0-1.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,928 kB
  • sloc: python: 76,235; makefile: 10; sh: 2
file content (56 lines) | stat: -rw-r--r-- 1,719 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Provides the implementation of `CSTNode.deep_equals`.
"""

from dataclasses import fields
from typing import Sequence

from libcst._nodes.base import CSTNode


def deep_equals(a: object, b: object) -> bool:
    """
    Recursively compares two arbitrary values for structural equality.

    Dispatch rules, checked in order:

    - two `CSTNode` values are compared field by field;
    - two `Sequence` values, neither of which is a `str` or `bytes`, are
      compared element by element;
    - every other pairing falls back to the plain `==` operator.
    """
    if isinstance(a, CSTNode) and isinstance(b, CSTNode):
        return _deep_equals_cst_node(a, b)

    # Strings and bytes are sequences too, but they must be compared as atomic
    # values rather than walked one character/byte at a time.
    if isinstance(a, (str, bytes)) or isinstance(b, (str, bytes)):
        return a == b

    if isinstance(a, Sequence) and isinstance(b, Sequence):
        return _deep_equals_sequence(a, b)

    return a == b


def _deep_equals_sequence(a: Sequence[object], b: Sequence[object]) -> bool:
    """
    Element-wise comparison helper used by `CSTNode.deep_equals`.

    The concrete container type is deliberately ignored: the public API only
    ever exposes `Sequence[]` (never `List[]`, `Tuple[]`, or `Iterable[]`), so
    two sequences are considered equal whenever they hold deeply-equal values
    in the same order.
    """
    # The very same object is trivially equal to itself.
    if a is b:
        return True
    # Differing lengths can never match, and ruling them out first means the
    # pairwise walk below cannot silently drop trailing elements.
    if len(a) != len(b):
        return False
    for left, right in zip(a, b):
        if not deep_equals(left, right):
            return False
    return True


def _deep_equals_cst_node(a: "CSTNode", b: "CSTNode") -> bool:
    """
    Field-by-field comparison helper used by `CSTNode.deep_equals`.

    Two nodes are equal when they are the same object, or when they share the
    exact same type and every dataclass field that participates in comparison
    is deeply equal between them.
    """
    # An object always matches itself (and necessarily has the same type).
    if a is b:
        return True
    if type(a) is not type(b):
        return False
    # Fields declared with compare=False (metadata and other hidden fields)
    # are skipped.
    return all(
        deep_equals(getattr(a, spec.name), getattr(b, spec.name))
        for spec in fields(a)
        if spec.compare is True
    )