File: onnx.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (632 lines) | stat: -rw-r--r-- 24,389 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
#include <torch/csrc/jit/passes/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/utils/pybind.h>
#include <sstream>

namespace torch::jit {

void removePrintOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      removePrintOps(b);
    }
    if (it->kind() == prim::Print || it->kind() == aten::warn) {
      for (size_t i = 0; i < it->inputs().size();) {
        auto input = it->inputs().at(i);
        // only handling constants bc of potential side effects
        if (input->uses().size() == 1 &&
            input->node()->kind() == prim::Constant) {
          it->removeInput(i);
          input->node()->destroy();
        } else {
          ++i;
        }
      }
      it.destroyCurrent();
    }
  }
}

void RemovePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
  GRAPH_DUMP("After RemovePrintOps: ", graph);
}

void checkONNXCompatibility(const c10::FunctionSchema& schema) {
  // in ONNX, all inputs are tensors, no support for tensor list
  // so at most one input tensor list is supported
  bool has_tensor_list = false;
  const auto& args = schema.arguments();
  for (const auto& arg : args) {
    if (arg.name() == "_caffe2_preallocated_outputs") {
      continue;
    }
    auto type = arg.type();
    if (type->kind() == TypeKind::OptionalType) {
      type = reinterpret_cast<OptionalType*>(type.get())->getElementType();
      // recursive optional type is not supported
      TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType);
    }
    if (type->kind() == TypeKind::ListType) {
      const auto& elem_type =
          reinterpret_cast<ListType*>(type.get())->getElementType();
      if (elem_type->isSubtypeOf(*TensorType::get())) {
        TORCH_INTERNAL_ASSERT(
            !has_tensor_list,
            "ONNX export supports at most one TensorList as input.");
        has_tensor_list = true;
      }
    }
  }
}

void preprocessCaffe2Ops(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      preprocessCaffe2Ops(b);
    }
    if (it->kind().is_caffe2()) {
      const auto& schema = it->schema();
      checkONNXCompatibility(schema);
      std::vector<Value*> origin_inputs;
      for (Value* v : it->inputs()) {
        origin_inputs.push_back(v);
      }
      it->removeAllInputs();
      const auto& args = schema.arguments();
      size_t origin_inputs_index = 0;
      for (const auto& arg : args) {
        const auto& type = arg.type();
        TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size());
        const auto& origin_input = origin_inputs[origin_inputs_index++];
        if (type->kind() == TypeKind::OptionalType &&
            origin_input->mustBeNone()) {
          continue;
        }
        if (type->isSubtypeOf(*TensorType::get())) {
          it->addInput(origin_input);
        } else if (
            type->kind() == TypeKind::BoolType ||
            type->kind() == TypeKind::IntType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value));
        } else if (type->kind() == TypeKind::FloatType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value));
        } else if (type->kind() == TypeKind::StringType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value));
        } else if (type->kind() == TypeKind::ListType) {
          const auto& list_node = origin_input->node();
          const auto& elem_type = type->castRaw<ListType>()->getElementType();
          TORCH_INTERNAL_ASSERT(
              list_node->kind() == prim::ListConstruct ||
              list_node->kind() == prim::Constant);
          if (elem_type->isSubtypeOf(*TensorType::get())) {
            TORCH_INTERNAL_ASSERT(list_node->kind(), prim::ListConstruct);
            const auto& tensor_list = origin_input->node()->inputs();
            for (const auto& t : tensor_list) {
              it->addInput(t);
            }
          } else if (elem_type->kind() == TypeKind::FloatType) {
            std::vector<double> values;
            if (list_node->kind() == prim::ListConstruct) {
              for (const auto* elem_input : list_node->inputs()) {
                const auto* constant_node = elem_input->node();
                TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
                values.push_back(constant_node->f(attr::value));
              }
            } else { // is a constant list
              values = list_node->fs(attr::value);
            }
            it->fs_(Symbol::attr(arg.name()), values);
          } else {
            throw std::runtime_error(
                "Unhandled scalar arg: " + arg.name() +
                ", type: " + c10::typeKindToString(elem_type->kind()));
          }
        } else {
          throw std::runtime_error(
              "Unsupported input type of arg " + arg.name() +
              " in Caffe2 operator: " + c10::typeKindToString(type->kind()));
        }
      }
    }
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
  preprocessCaffe2Ops(graph->block());
  GRAPH_DUMP("After PreprocessCaffe2Ops: ", graph);
}

// Transform PythonOps into Nodes that match ONNX semantics.
std::shared_ptr<Graph> ToONNX(
    std::shared_ptr<Graph>& graph,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  ConstantValueMap::ClearMaps();
  auto new_graph = std::make_shared<Graph>(graph->current_scope());
  py::dict env;
  // Kept identical to values in env. Used for constant-time existance check.
  py::set values_in_env;
  try {
    BlockToONNX(
        graph->block(),
        new_graph->block(),
        operator_export_type,
        env,
        values_in_env);
  } catch (std::runtime_error& ex) {
    ONNX_LOG(
        "ONNX graph being constructed during exception:\n",
        new_graph->toString());
    throw;
  }
  GRAPH_DUMP("after ToONNX: ", new_graph);
  ConstantValueMap::ClearMaps();
  return new_graph;
}

// BlockToONNX.
// is_sub_block = true means the old_block (aten graph) is in the sub block
// (e.g., if sub block), and we want to convert it into its parent block in onnx
// graph. In this case, we don't register the input/output or eliminate the dead
// code.
py::dict BlockToONNX(
    Block* old_block,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env,
    bool is_sub_block) {
  torch::autograd::SymbolicContext ctx{};
  ctx.block = new_block;

  GRAPH_DEBUG(
      "BlockToONNX: graph of old block: ",
      old_block->owningGraph()->toString());

  // Initialize context and environment
  if (!is_sub_block) {
    for (auto input : old_block->inputs()) {
      auto n = ctx.block->addInput()->copyMetadata(input);
      auto py_n = py::cast(n);
      env[py::cast(input)] = py_n;
      values_in_env.add(py_n);
    }
  }

  // Determine if all inputs are static. This is used for each node to
  // determine whether or not to propagate shapes.
  if (!is_sub_block) {
    bool static_input_shape = AllGraphInputsStatic(ctx.block->owningGraph());
    ConstantValueMap::SetAllGraphInputsStatic(static_input_shape);
  }

  // Finally, visit all nodes in the graph
  for (auto node : old_block->nodes()) {
    NodeToONNX(node, ctx.block, operator_export_type, env, values_in_env);
  }

  if (is_sub_block) {
    return env;
  }

  for (auto output : old_block->outputs()) {
    auto py_value = env[py::cast(output)];
    Value* value = py_value.cast<Value*>();
    ctx.block->registerOutput(value);
  }
  // Run dce to clean-up unused functional and inplace ops.
  EliminateDeadCode(
      ctx.block,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);

  return py::dict();
}

bool ConstantFoldCondition(torch::jit::Value* output) {
  auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
      ConstantValueMap::HasValue(output->debugName());
  auto reliable_value =
      ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
  return fold_condition && reliable_value;
}

void NodeToONNX(
    Node* old_node,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env) {
  py::object onnx = py::module::import("torch.onnx");
  py::object onnx_globals = py::module::import("torch.onnx._globals");
  py::object onnx_registration =
      py::module::import("torch.onnx._internal.registration");

  // Setup all the lambda helper functions.

  // Returns a node that n maps to in the new graph
  auto envFn = [&env](Value* n) -> Value* {
    auto py_n = py::cast(n);
    TORCH_CHECK(env.contains(py_n), "Dangling node reference");
    auto py_value = env[py_n];
    TORCH_CHECK(!py_value.is_none(), "Unused node was subsequently used");
    Value* value = py_value.cast<Value*>();
    return value;
  };

  // Put the new outputs in our environment map, and copy the type from the
  // input graph if they were not set by the symbolic. This is called only
  // with results of symbolic call (not for nodes that are just cloned).
  auto setOutputs = [&](const std::string& op_name,
                        Node* node,
                        const value_list& outputs) {
    auto old_outputs = node->outputs();
    // Count all outputs, excluding Handles
    auto num_old_outputs = old_outputs.size();
    if (outputs.size() != num_old_outputs) {
      std::ostringstream ss;
      ss << "symbolic for " << op_name
         << " produced an incorrect number of outputs (expected ";
      ss << num_old_outputs << ", but got " << outputs.size() << ")";
      throw std::runtime_error(ss.str());
    }
    // For const node, it does not need params_dict info, so set it to {}.
    const ParamMap empty_params_dict = {};
    auto opset_version = py::cast<int>(
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
    for (const auto i : c10::irange(num_old_outputs)) {
      auto old = old_outputs[i];
      if (outputs[i]) {
        bool exist_in_env = values_in_env.contains(py::cast(outputs[i]));
        // Update ONNX value debug name with ATen value debug name if existed.
        // Skip if ONNX value already exist in environment.
        // This implies the op is a noop, and the value is owned by
        // other node created elsewhere.
        if (old->hasDebugName() && !exist_in_env) {
          auto old_name = outputs[i]->debugName();
          auto new_name = old->debugNameBase();
          Value* found_value = nullptr;
          bool exists = false;
          // In this scope, we fetch debug_names as a const reference and then
          // construct an iterator exist_name based on it. This iterator will
          // be corrupted if the underlying map of debug_names changes. This
          // will happen as a side-effect of setDebugName. For these reasons,
          // we make an explicit scope for exist_name and make sure that
          // setDebugName is never called with this scope.
          {
            const auto& debug_names = new_block->owningGraph()->debugNames();
            auto exist_name = debug_names.find(new_name);
            exists = exist_name != debug_names.end();
            if (exists) {
              found_value = exist_name->second;
            }
          }
          outputs[i]->setDebugName(new_name);
          if (exists) {
            found_value->setDebugName(new_name);
          }
          ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName());
        }
        // Allow symbolic() to skip specifying the type of the return node.
        // Unfortunately, they are on the hook for all internal nodes
        // (though in practice, the types are not computed.)
        //
        // If onnx shape inference is turned on, the new outputs will have
        // types inferred, and they will be merged with the old types.
        if (ConstantFoldCondition(outputs[i])) {
          // Create a const node if the node output value is in
          // ConstantValueMap.
          auto value =
              ConstantValueMap::GetValue(outputs[i]->debugName()).value();
          Node* const_node =
              new_block->owningGraph()->create(c10::onnx::Constant);
          const_node->t_(attr::value, value);
          const_node->output()->setType(TensorType::create(value));

          // Copy over source location and scope information to all nodes
          // created by the symbolic
          const_node->copyMetadata(node);
          new_block->appendNode(const_node);
          ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
          auto py_output = py::cast(const_node->output());
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        } else {
          // An update in ConstantValueMap is also needed here, since
          // the user setType can be only accessed in this step, and it
          // should be reliable.
          MergeInferredTypeAndSetMap(
              outputs[i], old->type(), outputs[i]->type());
          // non ONNX node with no type given will throw out the warnings here.
          UpdateReliable(
              outputs[i],
              AreInputsReliableOrStatic(outputs[i]->node()),
              /*no_type_warning=*/true);
          // For the node type that does not have ComputeConstant logic, it may
          // have reliable shape but its shape is not in ConstantValueMap. So we
          // need to update ConstantValueMap.
          UpdateShapeConstantIfReliable(outputs[i]);

          // Copy over source location and scope information to all nodes
          // created by the symbolic
          // Do not set metadata if outputs[i] is already in env.
          if (!exist_in_env) {
            outputs[i]->node()->copyMetadata(node);
          }
          auto py_output = py::cast(outputs[i]);
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        }
      } else {
        // Null output means that the ONNX op doesn't have outputs corresponding
        // to certain PyTorch outputs
        env[py::cast(old)] = py::none();
        if (!old->uses().empty()) {
          std::ostringstream ss;
          ss << "symbolic for " << op_name << " returned None for the output "
             << i;
          ss << " (indicating conversion for that particular output is not supported), ";
          ss << "but the network uses this output later";
          // TODO: Say what actually used it
          throw std::runtime_error(ss.str());
        }
      }
    }
  };

  // Clone the node and add it to the new graph
  auto cloneNode = [&](Node* node) {
    auto n_ = new_block->appendNode(
        new_block->owningGraph()->createClone(node, envFn));
    for (const auto i : c10::irange(node->outputs().size())) {
      // n_->outputs()[i]->setType(node->outputs()[i]->type());
      auto py_output = py::cast(n_->output(i));
      env[py::cast(node->output(i))] = py_output;
      values_in_env.add(py_output);
    }
  };

  // Inline the prim::PythonOp sub-block nodes and append them to the onnx graph
  auto inlineAutograd = [&](Node* PythonOpNode) {
    for (auto subblock : PythonOpNode->blocks()) {
      for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
        auto py_value = env[py::cast(PythonOpNode->inputs()[i])];
        env[py::cast(subblock->inputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
      for (auto* node : subblock->nodes()) {
        NodeToONNX(node, new_block, operator_export_type, env, values_in_env);
      }
      for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
        auto py_value = env[py::cast(subblock->outputs()[i])];
        env[py::cast(PythonOpNode->outputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
    }
  };

  // Cast output of symbolic() python implementation
  auto processSymbolicOutput = [&](const std::string& op_name,
                                   Node* n,
                                   const py::object& raw_output) {
    if (raw_output.ptr() == Py_None) {
      cloneNode(n);
      return;
    }
    // Cast the outputs back to C++ and put them in the new graph
    std::vector<Value*> outputs;
    try {
      if (py::isinstance<Value>(raw_output)) {
        outputs = value_list{py::cast<Value*>(raw_output)};
      } else {
        outputs = py::cast<std::vector<Value*>>(raw_output);
      }
    } catch (const std::exception& ex) {
      std::ostringstream ss;
      ss << "Error casting results of symbolic for " << op_name
         << ": expected to return list of op nodes, instead received type ''"
         << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
      throw std::runtime_error(ss.str());
    }

    setOutputs(op_name, n, outputs);
  };

  auto callPySymbolicFunction = [&](Node* n) {
    // The idea is delegate as much of the actual argument massaging to
    // Python as possible

    py::tuple py_inputs(n->inputs().size());
    Py_ssize_t input_nr = 0;
    for (auto* input : n->inputs()) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    Graph* g = new_block->owningGraph();

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*g, n->scope());

    // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
    // Python. Check #87343 for details.
    py::list new_nodes = py::list();
    py::object raw_output = onnx.attr("_run_symbolic_function")(
        g->shared_from_this(),
        new_block,
        n,
        py_inputs,
        env,
        values_in_env,
        new_nodes,
        operator_export_type);

    // Find new nodes that have been created by _run_symbolic_function and
    // propagate metadata
    for (py::handle py_node : new_nodes) {
      Node* node = py_node.cast<Node*>();
      node->copyMetadata(n);
    }

    // TODO: Assert it's an ATen identifier???
    // (Sometimes it's not...)
    processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
    GRAPH_DUMP("after processSymbolicOutput: ", g);
  };

  auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
    // Test if there is a symbolic function; bail if there is not
    auto pyobj = py::handle(op->pyobj.get());
    auto func = op->autogradFunction();
    if (func) {
      pyobj = func->get();
    }

    py::object opset_version =
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
    // NOTE(justinchuby): Call the internal registry to register the symbolic
    // method defined in the module.
    bool is_registered_op =
        onnx_registration.attr("registry")
            .attr("is_registered_op")("prim::PythonOp", opset_version)
            .cast<bool>();
    py::bool_ is_autograd_inlining_enabled =
        py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
    if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
      // Inline the subgraph within the prim::PythonOp unless
      // either of these conditions are satisfied
      // 1. The torch.autograd.Function class of this node object has `symbolic`
      // method defined.
      // 2. Custom export symbolic is registered for prim::PythonOp.
      if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
           operator_export_type ==
               ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
          (py::cast<bool>(is_autograd_inlining_enabled))) {
        try {
          inlineAutograd(op);
        } catch (const std::exception& ex) {
          TORCH_WARN(
              "Unable to inline PythonOp: ",
              op->name(),
              " due to the following exception\n",
              ex.what(),
              "prim::PythonOp will be exported as is and without being inlined\n",
              "Try exporting with the following alternatives: \n",
              "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
              "2) Register a symbolic method for the prim::PythonOp ",
              op->name());
          cloneNode(op);
        }
      } else {
        cloneNode(op);
      }
      return;
    }

    // Prepare args for Python. First one is the graph, and is followed
    // by regular args, with Variables replaced by corresponding nodes.
    Py_ssize_t input_nr = 0;
    py::tuple py_symbolic_args(op->cconv.size());
    auto inputs = op->inputs();
    auto node_it = inputs.begin();
    auto scalar_it = op->scalar_args.begin();
    for (auto arg_type : op->cconv) {
      py::object obj;
      if (arg_type == 'c') {
        TORCH_CHECK(
            scalar_it != op->scalar_args.end(),
            "expected too many scalar args");
        obj = py::reinterpret_borrow<py::object>(
            py::handle((scalar_it++)->get()));
      } else if (arg_type == 'd') {
        TORCH_CHECK(node_it != inputs.end(), "expected too many inputs");
        obj = py::cast(envFn(*node_it++));
      } else {
        throw std::runtime_error("unexpected calling convention");
      }
      py_symbolic_args[input_nr++] = obj;
    }

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());

    if (py::hasattr(pyobj, "symbolic")) {
      // Call the symbolic function
      // Use a little trampoline function so we can give good error messages
      // upon argument mismatch
      // Register as a custom operator
      // TODO: Find a more elegant way to do this without having to touch
      // internal Python modules.
      // TODO(justinchuby): Define a namespace for these Python Ops.
      onnx_registration.attr("registry")
          .attr("register")(
              "::" + op->name(),
              opset_version,
              pyobj.attr("symbolic"),
              /* custom */ true);

      // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
      // Python. Check #87343 for details.
      py::object raw_output = onnx.attr("_run_symbolic_method")(
          new_block->owningGraph()->shared_from_this(),
          op->name(),
          pyobj.attr("symbolic"),
          py_symbolic_args);

      processSymbolicOutput(op->name(), op, raw_output);
    } else {
      TORCH_INTERNAL_ASSERT(is_registered_op);
      Node* n = static_cast<Node*>(op);
      n->s_(attr::name, op->name());
      // Call symbolic function
      // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to
      // Python. Check #87343 for details.
      py::list new_nodes = py::list();
      py::object raw_output = onnx.attr("_run_symbolic_function")(
          new_block->owningGraph()->shared_from_this(),
          new_block,
          n,
          py_symbolic_args,
          env,
          values_in_env,
          new_nodes,
          operator_export_type);

      processSymbolicOutput(op->kind().toUnqualString(), n, raw_output);
    }
  };

  auto k = old_node->kind();
  if (k.is_caffe2()) {
    // Pass on Caffe2 operator, since we already preprocess it
    cloneNode(old_node);
  } else if (k == prim::PythonOp) {
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
  } else {
    callPySymbolicFunction(old_node);
  }
}

} // namespace torch::jit