File: python_dict.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (128 lines) | stat: -rw-r--r-- 3,345 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once

#include <ATen/core/Dict.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/utils/pybind.h>

namespace torch {
namespace jit {

/// Entry point for setting up the ScriptDict bindings on `module`.
/// NOTE(review): only the declaration lives in this header; presumably this
/// registers ScriptDict and its iterator classes with Python -- confirm
/// against the definition in the implementation file.
void initScriptDictBindings(PyObject* module);

/// An iterator over the keys of ScriptDict. This is used to support
/// .keys() and iteration.
class ScriptDictKeyIterator final {
 public:
  ScriptDictKeyIterator(
      c10::impl::GenericDict::iterator iter,
      c10::impl::GenericDict::iterator end)
      : iter_(std::move(iter)), end_(std::move(end)) {}
  IValue next();

 private:
  c10::impl::GenericDict::iterator iter_;
  c10::impl::GenericDict::iterator end_;
};

/// An iterator over the key-value pairs of ScriptDict. This is used to support
/// .items().
class ScriptDictIterator final {
 public:
  ScriptDictIterator(
      c10::impl::GenericDict::iterator iter,
      c10::impl::GenericDict::iterator end)
      : iter_(std::move(iter)), end_(std::move(end)) {}
  IValue next();

 private:
  c10::impl::GenericDict::iterator iter_;
  c10::impl::GenericDict::iterator end_;
};

/// A wrapper around c10::Dict that can be exposed in Python via pybind
/// with an API identical to the Python dictionary class. This allows
/// dictionaries to have reference semantics across the Python/TorchScript
/// boundary.
class ScriptDict final {
 public:
  // Constructor.
  ScriptDict(IValue data) : dict_(AnyType::get(), AnyType::get()) {
    TORCH_INTERNAL_ASSERT(data.isGenericDict());
    dict_ = data.toGenericDict();
  }

  // Get the type of the dictionary.
  DictTypePtr type() const {
    return DictType::create(dict_.keyType(), dict_.valueType());
  }

  // Return a string representation that can be used
  // to reconstruct the instance.
  std::string repr() const {
    std::ostringstream s;
    s << '{';
    bool f = false;
    for (auto const& kv : dict_) {
      if (f) {
        s << ", ";
      }
      s << kv.key() << ": " << kv.value();
      f = true;
    }
    s << '}';
    return s.str();
  }

  // Return an iterator over the keys of the dictionary.
  ScriptDictKeyIterator iter() const {
    auto begin = dict_.begin();
    auto end = dict_.end();
    return ScriptDictKeyIterator(begin, end);
  }

  // Return an iterator over the key-value pairs of the dictionary.
  ScriptDictIterator items() const {
    auto begin = dict_.begin();
    auto end = dict_.end();
    return ScriptDictIterator(begin, end);
  }

  // Interpret the dictionary as a boolean; empty means false, non-empty means
  // true.
  bool toBool() const {
    return !(dict_.empty());
  }

  // Get the value for the given key. Throws std::out_of_range if the key does
  // not exist.
  IValue getItem(const IValue& key) {
    return dict_.at(key);
  };

  // Set the value for the given key.
  void setItem(const IValue& key, const IValue& value) {
    dict_.insert_or_assign(key, value);
  };

  // Check whether the dictionary contains the given key.
  bool contains(const IValue& key) {
    return dict_.contains(key);
  }

  // Delete the given key from the dictionary.
  bool delItem(const IValue& key) {
    return dict_.erase(key);
  }

  // Get the size of the dictionary.
  int64_t len() const {
    return dict_.size();
  }

  // A c10::Dict instance that holds the actual data.
  c10::impl::GenericDict dict_;
};

} // namespace jit
} // namespace torch