File: formatter.py

package info (click to toggle)
mujoco 2.2.2-3.2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 39,796 kB
  • sloc: ansic: 28,947; cpp: 28,897; cs: 14,241; python: 10,465; xml: 5,104; sh: 93; makefile: 34
file content (149 lines) | stat: -rw-r--r-- 4,756 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for formatting AST node as Python code."""

import contextlib
import dataclasses
from typing import Any, Iterable, Mapping, Sequence

INDENT_WIDTH = 4
MAX_LINE_WIDTH = 80
SIMPLE_TYPES = frozenset([int, float, str, bool, bytes, type(None)])


def format_as_python_code(obj: Any) -> str:
  """Renders `obj` (an AST node structure) as indented Python source text.

  Args:
    obj: A simple value, dataclass instance, mapping, or sequence, possibly
      nested.

  Returns:
    The formatted code as a single newline-joined string.
  """
  builder = _Formatter()
  builder.add(obj)
  rendered = str(builder)
  return rendered


def _is_all_simple(seq: Iterable[Any]) -> bool:
  """Returns True iff every element's exact type is listed in SIMPLE_TYPES.

  The check uses `type()` rather than `isinstance`, so subclasses of the simple
  types do not count as simple. An empty iterable is considered all-simple.
  """
  return not any(type(item) not in SIMPLE_TYPES for item in seq)


class _Formatter:
  """A helper for pretty-printing AST nodes as Python code.

  Output is accumulated as a list of lines. Nested values are rendered
  recursively through `add`. Each construct is first tried on a single line,
  and is only broken across several lines when it contains non-simple values
  or would exceed MAX_LINE_WIDTH.
  """

  def __init__(self):
    # Whitespace prepended to every newly started line.
    self._line_prefix = ''
    # Output lines produced so far.
    self._lines = []
    # When True, the next `_add_line` call continues the last line instead of
    # starting a new one (used after a `name=` keyword-argument prefix).
    self._add_to_last_line = False

  @contextlib.contextmanager
  def _indent(self, width: int = INDENT_WIDTH):
    """Temporarily deepens the line prefix by `width` spaces."""
    saved_prefix = self._line_prefix
    self._line_prefix += ' ' * width
    try:
      yield
    finally:
      # Restore the exact saved prefix. Slicing with [:-width] would wipe the
      # whole prefix when width == 0, and skipping the restore on an exception
      # would leave the formatter permanently over-indented.
      self._line_prefix = saved_prefix

  @contextlib.contextmanager
  def _append_at_end(self, s: str):
    """Appends `s` to the last emitted line once the managed body finishes.

    This attaches trailing punctuation (e.g. ',' or '),') to whatever the body
    rendered last, however many lines that was. If the body raises, nothing is
    appended.
    """
    yield
    self._lines[-1] += s

  def _add_line(self, line: str, no_break: bool = False):
    """Emits `line` as a new prefixed line, or continues the previous one.

    Args:
      line: Text to emit.
      no_break: If True, the *next* call appends to this line instead of
        starting a new one.
    """
    if self._add_to_last_line:
      self._lines[-1] += line
    else:
      self._lines.append(self._line_prefix + line)
    self._add_to_last_line = no_break

  def _add_dict(self, obj: Mapping[Any, Any]):
    """Adds a mapping as `dict([(key, value), ...])` to the formatted output."""
    self._add_line('dict([')
    with self._indent():
      for k, v in obj.items():

        # Try to fit everything into a single line first.
        if _is_all_simple((k, v)):
          single_line = f'({k!r}, {v!r}),'
          if len(self._line_prefix) + len(single_line) <= MAX_LINE_WIDTH:
            self._add_line(single_line)
            continue

        # Use repr() for the key, exactly as the single-line branch does, so
        # non-str keys keep their type and quotes/backslashes get escaped.
        self._add_line(f'({k!r},')
        with self._append_at_end('),'):
          # Indent the value by one column so it lines up after the '('.
          with self._indent(1):
            self.add(v)

    self._add_line('])')

  def _add_dataclass(self, obj: Any):
    """Adds a dataclass instance as a constructor call with keyword arguments.

    Fields whose value equals their declared default are omitted. So are fields
    declared with init=False, because they cannot be passed to the constructor.
    Fields that use default_factory have no comparable default (it is MISSING),
    so they are always emitted.
    """
    kv_pairs = []
    for field in dataclasses.fields(obj):
      if not field.init:
        continue
      value = getattr(obj, field.name)
      if value != field.default:
        kv_pairs.append((field.name, value))

    class_name = obj.__class__.__name__

    # Try to fit everything into a single line first.
    if _is_all_simple(value for _, value in kv_pairs):
      args = ', '.join(f'{name}={value!r}' for name, value in kv_pairs)
      single_line = f'{class_name}({args})'
      if len(self._line_prefix) + len(single_line) <= MAX_LINE_WIDTH:
        self._add_line(single_line)
        return

    self._add_line(class_name + '(')
    with self._indent():
      for name, value in kv_pairs:
        # The value continues on the same line as `name=`.
        self._add_line(name + '=', no_break=True)
        with self._append_at_end(','):
          self.add(value)
    self._add_line(')')

  def _add_sequence(self, obj: Sequence[Any]) -> None:
    """Adds a sequence (list, tuple or similar) to the formatted output."""
    if isinstance(obj, tuple):
      # Also covers namedtuples, whose repr() begins with the class name.
      open_token, close_token = '(', ')'
    elif isinstance(obj, list):
      open_token, close_token = '[', ']'
    else:
      # Other sequence types: take the brackets from repr(). For lists and
      # tuples this is skipped, so that the whole nested subtree is not
      # repr()'d at every nesting level.
      default_str = repr(obj)
      open_token, close_token = default_str[0], default_str[-1]

    # Try to fit everything into a single line first.
    if _is_all_simple(obj):
      single_line = (
          f"{open_token}{', '.join(repr(o) for o in obj)}{close_token}")
      # A one-element tuple needs a trailing comma to stay a tuple.
      if close_token == ')' and len(obj) == 1:
        single_line = f'{single_line[:-1]},)'
      if len(self._line_prefix) + len(single_line) <= MAX_LINE_WIDTH:
        self._add_line(single_line)
        return

    self._add_line(open_token)
    with self._indent():
      for v in obj:
        with self._append_at_end(','):
          self.add(v)
    self._add_line(close_token)

  def add(self, obj: Any) -> None:
    """Adds an object to the formatted output.

    Args:
      obj: A simple value (see SIMPLE_TYPES), a dataclass instance, a mapping,
        or a sequence, possibly nested.

    Raises:
      NotImplementedError: If `obj` is none of the supported kinds.
    """
    if _is_all_simple((obj,)):
      self._add_line(repr(obj))
    elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
      # is_dataclass() is also True for dataclass *classes*. Only instances
      # carry field values that can be formatted.
      self._add_dataclass(obj)
    elif isinstance(obj, Mapping):
      self._add_dict(obj)
    elif isinstance(obj, Sequence):
      self._add_sequence(obj)
    else:
      raise NotImplementedError(
          f'cannot format object of type {type(obj).__name__}')

  def __str__(self):
    """Joins the output lines, tagging over-long ones with a pylint pragma."""
    lines = []
    for line in self._lines:
      if len(line) > MAX_LINE_WIDTH:
        lines.append(f'{line}  # pylint: disable=line-too-long')
      else:
        lines.append(line)
    return '\n'.join(lines)