File: model_compatibility.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (456 lines) | stat: -rw-r--r-- 17,750 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

// Forward declaration only: the definition of c10::parseType is not in this
// file, and nothing in the code shown here calls it directly.
namespace c10 {
// Takes a type string and yields a TypePtr (per the signature); presumably the
// string is a Python-style type annotation — confirm against the definition.
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {

using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

// Unpickles the archive called `archive_name` from `stream_reader` through
// readArchiveAndTensors and returns the resulting IValue.
//
// Types and objects are resolved with the mobile helpers
// (typeResolverMobile / objLoaderMobile), each backed by a compilation unit
// that lives only for the duration of this call. No device is specified.
c10::IValue readArchive(
    const std::string& archive_name,
    PyTorchStreamReader& stream_reader) {
  // TODO (T90180710): Simplify type_resolver and obj_loader when getting
  // bytecode version from model
  auto jit_cu = std::make_shared<CompilationUnit>();
  auto resolve_type = [&](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, jit_cu);
  };

  auto mobile_cu = std::make_shared<mobile::CompilationUnit>();
  auto load_object = [&](at::StrongTypePtr type, IValue input) {
    return objLoaderMobile(type, input, *mobile_cu);
  };

  // For the "bytecode" archive, tensors are read from under "constants/"
  // whenever isTensorInBytecodeArchive reports that they are not stored in the
  // bytecode archive itself.
  const bool tensors_in_constants_archive =
      archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader);
  const char* tensor_prefix = tensors_in_constants_archive ? "constants/" : "";

  c10::optional<at::Device> device;
  return torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/tensor_prefix,
      resolve_type,
      load_object,
      device,
      stream_reader,
      nullptr);
}

// Reads the "bytecode" archive from `reader` and returns the elements of its
// top-level tuple as a vector. The elements are moved out of the freshly read
// tuple rather than copied.
std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
  c10::IValue bytecode_archive = readArchive("bytecode", reader);
  auto bytecode_tuple = std::move(bytecode_archive).toTuple();
  return std::move(*bytecode_tuple).elements().vec();
}

/********************** Bytecode **********************/

// Forward declarations: both functions are defined further down in this file
// but are called by overloads that appear before their definitions.
// Extracts the bytecode version from already-unpickled bytecode values.
uint64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues);
// Declared `static` here; the later definition omits the keyword and takes its
// internal linkage from this declaration.
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);

// Returns the bytecode version of the model held in stream `in`.
//
// The stream is rewound to its beginning, its content is loaded into memory
// with get_stream_content, and the read position the caller had on entry is
// put back before the bytes are inspected.
uint64_t _get_model_bytecode_version(std::istream& in) {
  const auto saved_pos = in.tellg();
  in.seekg(0, in.beg);

  std::shared_ptr<char> buffer;
  size_t num_bytes = 0;
  std::tie(buffer, num_bytes) = get_stream_content(in);

  // Restore the caller's position before doing any parsing.
  in.seekg(saved_pos, in.beg);
  return _get_model_bytecode_version_from_bytes(buffer.get(), num_bytes);
}

uint64_t _get_model_bytecode_version(const std::string& filename) {
  std::ifstream ifile(filename);
  return _get_model_bytecode_version(ifile);
}

// Returns the bytecode version of the model exposed through the read adapter
// `rai`. The adapter's content is loaded into memory with get_rai_content and
// then inspected as raw bytes.
uint64_t _get_model_bytecode_version(
    std::shared_ptr<ReadAdapterInterface> rai) {
  std::shared_ptr<char> buffer;
  size_t num_bytes = 0;
  std::tie(buffer, num_bytes) = get_rai_content(rai.get());
  return _get_model_bytecode_version_from_bytes(buffer.get(), num_bytes);
}

uint64_t _get_model_bytecode_version_zip(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file please ensure the model was exported for mobile");
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_bytecode_version(bytecode_values);
}

// Returns the bytecode version of a model held in memory.
//
// `data` must hold at least kFileFormatHeaderSize bytes so the file format can
// be detected. Flatbuffer content is handed to the
// get_flatbuffer_bytecode_version hook, which must be non-null. Zip content is
// read through an in-memory adapter. Anything else is rejected.
uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
  const auto detected_format = getFileFormat(data);

  if (detected_format == FileFormat::FlatbufferFileFormat) {
    // The hook is null when no flatbuffer implementation has been installed.
    TORCH_CHECK(
        get_flatbuffer_bytecode_version != nullptr,
        "Flatbuffer input file but the build hasn't enabled flatbuffer");
    return get_flatbuffer_bytecode_version(data);
  }

  if (detected_format == FileFormat::ZipFileFormat) {
    auto memory_adapter =
        std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
    return _get_model_bytecode_version_zip(std::move(memory_adapter));
  }

  TORCH_CHECK(false, "Unrecognized data format");
}

// Extracts the bytecode version from unpickled bytecode values.
//
// The version is the first element and must be an integer greater than zero;
// otherwise the call fails via TORCH_CHECK.
uint64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues) {
  const bool has_version_entry =
      !bytecode_ivalues.empty() && bytecode_ivalues[0].isInt();
  TORCH_CHECK(has_version_entry, "Failed to get bytecode version.");

  const int64_t version = bytecode_ivalues[0].toInt();
  TORCH_CHECK(
      version > 0, "Expected model bytecode version > 0 got ", version);
  return static_cast<uint64_t>(version);
}

/********************** Operator Version **********************/

// Overload taking an already-open PyTorchStreamReader. It is defined after the
// adapter-based overload below, which calls it, hence the early declaration.
uint64_t _get_model_operator_version(
    PyTorchStreamReader& reader); // Forward Declare

uint64_t _get_model_operator_version(std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_operator_version(std::move(rai));
}

uint64_t _get_model_operator_version(const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_operator_version(std::move(rai));
}

uint64_t _get_model_operator_version(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file please ensure the model was exported for mobile");
  }
  PyTorchStreamReader reader(std::move(rai));
  return _get_model_operator_version(reader);
}

// Returns the operator version of an opened model, which is simply the
// version value reported by the stream reader.
uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
  const uint64_t operator_version = reader.version();
  return operator_version;
}

/********************** Operators and Info **********************/

// Forward declaration: the overload working on unpickled bytecode values is
// defined after the adapter-based overload below, which calls it.
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues);

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_WARN("Failed to open zip file for model ops.");
    return std::unordered_map<std::string, OperatorInfo>{};
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_ops_and_info(bytecode_values);
}

/* A function to retrieve the root (top level) operators of a model and their
 * corresponding compatibility info. These root operators can call other
 * operators within them (traced ops), and a root op can call many different
 * traced ops depending on internal code paths in the root op. These traced ops
 * are not returned by this function. Those operators are abstracted into the
 * runtime as an implementation detail (and the traced ops themselves can also
 * call other operators) making retrieving them difficult and their value from
 * this api negligible since they will differ between which runtime version the
 * model is run on. Because of this, there is a false positive this api can't
 * prevent in a compatibility usecase. All the root ops of a model are present
 * in a target runtime, but not all the traced ops are which prevents a model
 * from being able to run.
 *
 * bytecode_ivalues: the unpickled bytecode values; element 0 is the bytecode
 * version and every later element is read as one method entry.
 *
 * Returns a map from operator name ("name" or "name.overload") to its
 * OperatorInfo. Empty input yields a warning and an empty map. Non-empty
 * input whose first element is not a positive integer version still fails
 * inside _get_model_bytecode_version.
 **/
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues) {
  std::unordered_map<std::string, OperatorInfo> result;
  // This guard has to come before the version lookup: the lookup throws on an
  // empty vector, which would make the warning path below unreachable.
  if (bytecode_ivalues.empty()) {
    TORCH_WARN("Failed to get model ops and info.");
    return result;
  }

  constexpr uint64_t min_version_with_schema = 6;
  if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
    TORCH_WARN(
        "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
  }

  // loop over all the functions in the bytecode (index 0 is the version)
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    // descend to the operators list; the references stay valid because they
    // point into bytecode_ivalues, which outlives the loop
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    const auto& operators_tuple =
        method_tuple.at(1).toTupleRef().elements()[1];
    const auto& operators = operators_tuple.toTupleRef().elements()[1];
    for (const auto& op_tuple : operators.toTupleRef().elements()) {
      const auto& op = op_tuple.toTupleRef().elements();

      // grab name, appending ".overload" when an overload name is present
      std::string op_name = op.at(0).toStringRef();
      const std::string& op_overload_name = op.at(1).toStringRef();
      if (!op_overload_name.empty()) {
        op_name.append(".");
        op_name.append(op_overload_name);
      }

      // grab schema size
      if (op.size() > 2) {
        result.emplace(
            op_name, OperatorInfo{static_cast<int>(op.at(2).toInt())});
      } else { // no schema information use default
        result.emplace(op_name, OperatorInfo{});
      }
    }
  }
  return result;
}

/********************** Get Type Table **********************/

// Forward declaration: the overload working on unpickled bytecode values is
// defined after the adapter-based overload below, which calls it.
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues);

std::unordered_set<std::string> _get_mobile_model_contained_types(
    std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_mobile_model_contained_types(std::move(rai));
}

std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_mobile_model_contained_types(std::move(rai));
}

std::unordered_set<std::string> _get_mobile_model_contained_types(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file please ensure the model was exported for mobile");
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_mobile_model_contained_types(bytecode_values);
}

// Get deduplicate type table given bytecode, and each string is a atomic type,
// like str, Tensor and etc. For example,
// input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// output: {Dict, int, Tuple, Tensor}
//
// bytecode_ivalues: the unpickled bytecode values; element 0 is skipped and
// every later element is read as one method entry carrying a type table.
//
// Returns the union of the contained types over ALL methods. Each method's
// result is merged into the set, so earlier methods are not discarded when a
// later one is parsed.
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues) {
  std::unordered_set<std::string> contained_types;
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    const auto& type_table_tuple =
        method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
    const auto& type_table =
        type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();

    // type_table is a list of IValue, and each IValue is a string,
    // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
    std::vector<std::string> type_name_list;
    type_name_list.reserve(type_table.size());
    for (const auto& type_definition : type_table) {
      type_name_list.emplace_back(type_definition.toStringRef());
    }

    at::TypeParser parser(type_name_list);
    parser.parseList();
    // Merge instead of assign: assignment would keep only the last method's
    // types.
    const auto method_types = parser.getContainedTypes();
    contained_types.insert(method_types.begin(), method_types.end());
  }

  return contained_types;
}

/********************** Compatibility Checker **********************/

// Builds the compatibility info of the model in stream `in` by wrapping the
// stream in an IStreamAdapter and deferring to the adapter overload.
ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
  return get(std::make_unique<IStreamAdapter>(&in));
}

// Builds the compatibility info of the model file `filename` by wrapping it
// in a FileAdapter and deferring to the adapter overload.
ModelCompatibilityInfo ModelCompatibilityInfo::get(
    const std::string& filename) {
  return get(std::make_unique<FileAdapter>(filename));
}

// Builds the compatibility info of the zip-format model read through `rai`.
//
// The bytecode archive is read once and reused to derive the bytecode
// version, the root operators and the contained types. The operator version
// comes from the stream reader itself. Fails via TORCH_CHECK when
// check_zip_file rejects the adapter's content.
ModelCompatibilityInfo ModelCompatibilityInfo::get(
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
  TORCH_CHECK(
      check_zip_file(rai),
      "Failed to open zip file for model compatibility information");

  PyTorchStreamReader zip_reader(std::move(rai));
  std::vector<IValue> bytecode = get_bytecode_ivalues(zip_reader);

  uint64_t bytecode_version = _get_model_bytecode_version(bytecode);
  auto ops_and_info = _get_model_ops_and_info(bytecode);
  std::unordered_set<std::string> contained_types =
      _get_mobile_model_contained_types(bytecode);
  uint64_t op_version = _get_model_operator_version(zip_reader);

  return ModelCompatibilityInfo{
      bytecode_version, ops_and_info, contained_types, op_version};
}

// Checks whether a model described by `model_info` can run on a runtime
// described by `runtime_info`.
//
// Four independent checks are made, and every failure is recorded (the checks
// do not stop at the first problem):
//   1. the model bytecode version lies within the runtime's supported range;
//   2. every type in the model's type table is supported by the runtime;
//   3. every model operator exists in the runtime with schema information,
//      and the model does not pass more schema args than the runtime knows;
//   4. the model operator version lies within the runtime's supported range.
//
// Returns a result whose status is OK with no errors when all checks pass, or
// ERROR with one human-readable message per failed check otherwise.
ModelCompatCheckResult is_compatible(
    RuntimeCompatibilityInfo runtime_info,
    ModelCompatibilityInfo model_info) {
  ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};

  // Marks the result as failed and appends one error message.
  const auto report_error = [&result](const std::string& message) {
    result.status = ModelCompatibilityStatus::ERROR;
    result.errors.push_back(message);
  };

  // Check that the models bytecode version is within the
  // [min, max] supported bytecode version range of the runtime
  const auto& bytecode_range = runtime_info.min_max_supported_bytecode_version;
  if (model_info.bytecode_version > bytecode_range.second) {
    std::ostringstream s;
    s << "model bytecode version " << model_info.bytecode_version
      << " is greater than the max supported bytecode version in runtimes "
      << bytecode_range.second;
    report_error(s.str());
  } else if (model_info.bytecode_version < bytecode_range.first) {
    std::ostringstream s;
    s << "model bytecode version " << model_info.bytecode_version
      << " is less than the minimum supported bytecode version in runtime "
      << bytecode_range.first;
    report_error(s.str());
  }

  // Check type table (looked up in place; no copy of the runtime set needed)
  const auto& supported_types = runtime_info.supported_types;
  for (const auto& type_name : model_info.type_table) {
    if (supported_types.find(type_name) == supported_types.end()) {
      std::ostringstream s;
      s << "Primitive type: '" << type_name
        << "' is not supported in current runtime";
      report_error(s.str());
    }
  }

  // Check operators
  const auto& runtime_ops = runtime_info.operator_info;
  for (const auto& op : model_info.operator_info) {
    const std::string& op_name = op.first;
    const OperatorInfo& model_op_info = op.second;

    // Check if operator not present in runtime (single lookup, reused below)
    const auto runtime_op = runtime_ops.find(op_name);
    if (runtime_op == runtime_ops.end()) {
      std::ostringstream s;
      s << "Operator '" << op_name << "' missing from runtime (not found)";
      report_error(s.str());
      continue;
    }
    const OperatorInfo& runtime_op_info = runtime_op->second;

    // If the runtime op has no schema information its a false alarm and isn't
    // actually useable
    if (!runtime_op_info.num_schema_args.has_value()) {
      std::ostringstream s;
      s << "Operator '" << op_name
        << "' missing from runtime (missing schema)";
      report_error(s.str());
      continue;
    }

    // Check if the model operator has schema information. If it doesn't
    // then the model is from a bytecode version < 6 and we are done. If the
    // model has more args than the runtime, then the runtime can't know
    // what to do so we aren't compatible. If the runtime has more args than
    // the model then we can just use default values and be fine.
    if (model_op_info.num_schema_args.has_value() &&
        (model_op_info.num_schema_args.value() >
         runtime_op_info.num_schema_args.value())) {
      std::ostringstream s;
      s << "Operator schema for '" << op_name << "' has "
        << model_op_info.num_schema_args.value()
        << " args in model but only "
        << runtime_op_info.num_schema_args.value() << " in the runtime";
      report_error(s.str());
    }
  }

  // Check Operator Versions
  const auto& op_version_range =
      runtime_info.min_max_supported_opperator_versions;
  if (model_info.operator_version < op_version_range.first ||
      model_info.operator_version > op_version_range.second) {
    std::ostringstream s;
    s << "Model Operator Version " << model_info.operator_version
      << " is not within supported version range of the runtime "
      << op_version_range.first << " to " << op_version_range.second;
    report_error(s.str());
  }

  return result;
}

} // namespace jit
} // namespace torch