#include <torch/csrc/jit/passes/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>

namespace torch::jit {

void removePrintOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      removePrintOps(b);
    }
    if (it->kind() == prim::Print || it->kind() == aten::warn) {
      for (size_t i = 0; i < it->inputs().size();) {
        auto input = it->inputs().at(i);
        // only handling constants because of potential side effects
        if (input->uses().size() == 1 &&
            input->node()->kind() == prim::Constant) {
          it->removeInput(i);
          input->node()->destroy();
        } else {
          ++i;
        }
      }
      it.destroyCurrent();
    }
  }
}

void RemovePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
  GRAPH_DUMP("After RemovePrintOps: ", graph);
}

void checkONNXCompatibility(const c10::FunctionSchema& schema) {
  // In ONNX, all inputs are tensors and there is no native tensor-list type,
  // so at most one input tensor list is supported.
  bool has_tensor_list = false;
  const auto& args = schema.arguments();
  for (const auto& arg : args) {
    if (arg.name() == "_caffe2_preallocated_outputs") {
      continue;
    }
    auto type = arg.type();
    if (type->kind() == TypeKind::OptionalType) {
      type = reinterpret_cast<OptionalType*>(type.get())->getElementType();
      // recursive optional type is not supported
      TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType);
    }
    if (type->kind() == TypeKind::ListType) {
      const auto& elem_type =
          reinterpret_cast<ListType*>(type.get())->getElementType();
      if (elem_type->isSubtypeOf(*TensorType::get())) {
        TORCH_INTERNAL_ASSERT(
            !has_tensor_list,
            "ONNX export supports at most one TensorList as input.");
        has_tensor_list = true;
      }
    }
  }
}

void preprocessCaffe2Ops(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      preprocessCaffe2Ops(b);
    }
    if (it->kind().is_caffe2()) {
      const auto& schema = it->schema();
      checkONNXCompatibility(schema);
      std::vector<Value*> origin_inputs;
      for (Value* v : it->inputs()) {
        origin_inputs.push_back(v);
      }
      it->removeAllInputs();
      const auto& args = schema.arguments();
      size_t origin_inputs_index = 0;
      for (const auto& arg : args) {
        const auto& type = arg.type();
        TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size());
        const auto& origin_input = origin_inputs[origin_inputs_index++];
        if (type->kind() == TypeKind::OptionalType &&
            origin_input->mustBeNone()) {
          continue;
        }
        if (type->isSubtypeOf(*TensorType::get())) {
          it->addInput(origin_input);
        } else if (
            type->kind() == TypeKind::BoolType ||
            type->kind() == TypeKind::IntType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value));
        } else if (type->kind() == TypeKind::FloatType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value));
        } else if (type->kind() == TypeKind::StringType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value));
        } else if (type->kind() == TypeKind::ListType) {
          const auto& list_node = origin_input->node();
          const auto& elem_type = type->castRaw<ListType>()->getElementType();
          TORCH_INTERNAL_ASSERT(
              list_node->kind() == prim::ListConstruct ||
              list_node->kind() == prim::Constant);
          if (elem_type->isSubtypeOf(*TensorType::get())) {
            TORCH_INTERNAL_ASSERT(list_node->kind() == prim::ListConstruct);
            const auto& tensor_list = origin_input->node()->inputs();
            for (const auto& t : tensor_list) {
              it->addInput(t);
            }
          } else if (elem_type->kind() == TypeKind::FloatType) {
            std::vector<double> values;
            if (list_node->kind() == prim::ListConstruct) {
              for (const auto* elem_input : list_node->inputs()) {
                const auto* constant_node = elem_input->node();
                TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
                values.push_back(constant_node->f(attr::value));
              }
            } else {
              // is a constant list
              values = list_node->fs(attr::value);
            }
            it->fs_(Symbol::attr(arg.name()), values);
          } else {
            throw std::runtime_error(
                "Unhandled scalar arg: " + arg.name() +
                ", type: " + c10::typeKindToString(elem_type->kind()));
          }
        } else {
          throw std::runtime_error(
              "Unsupported input type of arg " + arg.name() +
              " in Caffe2 operator: " + c10::typeKindToString(type->kind()));
        }
      }
    }
  }
  EliminateDeadCode(
      block,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
  preprocessCaffe2Ops(graph->block());
  GRAPH_DUMP("After PreprocessCaffe2Ops: ", graph);
}

// Transform PythonOps into Nodes that match ONNX semantics.
std::shared_ptr<Graph> ToONNX(
    std::shared_ptr<Graph>& graph,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  ConstantValueMap::ClearMaps();
  auto new_graph = std::make_shared<Graph>(graph->current_scope());
  py::dict env;
  // Kept identical to values in env. Used for constant-time existence check.
  py::set values_in_env;
  try {
    BlockToONNX(
        graph->block(),
        new_graph->block(),
        operator_export_type,
        env,
        values_in_env);
  } catch (std::runtime_error& ex) {
    ONNX_LOG(
        "ONNX graph being constructed during exception:\n",
        new_graph->toString());
    throw;
  }
  GRAPH_DUMP("after ToONNX: ", new_graph);
  ConstantValueMap::ClearMaps();
  return new_graph;
}

// BlockToONNX.
// is_sub_block = true means the old_block (aten graph) is in a sub block
// (e.g., the sub block of an If node), and we want to convert it into its
// parent block in the onnx graph. In this case, we don't register the
// inputs/outputs or eliminate the dead code.
py::dict BlockToONNX(
    Block* old_block,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env,
    bool is_sub_block) {
  torch::autograd::SymbolicContext ctx{};
  ctx.block = new_block;

  GRAPH_DEBUG(
      "BlockToONNX: graph of old block: ",
      old_block->owningGraph()->toString());

  // Initialize context and environment
  if (!is_sub_block) {
    for (auto input : old_block->inputs()) {
      auto n = ctx.block->addInput()->copyMetadata(input);
      auto py_n = py::cast(n);
      env[py::cast(input)] = py_n;
      values_in_env.add(py_n);
    }
  }

  // Determine if all inputs are static. This is used for each node to
  // determine whether or not to propagate shapes.
  if (!is_sub_block) {
    bool static_input_shape = AllGraphInputsStatic(ctx.block->owningGraph());
    ConstantValueMap::SetAllGraphInputsStatic(static_input_shape);
  }

  // Finally, visit all nodes in the graph
  for (auto node : old_block->nodes()) {
    NodeToONNX(node, ctx.block, operator_export_type, env, values_in_env);
  }

  if (is_sub_block) {
    return env;
  }

  for (auto output : old_block->outputs()) {
    auto py_value = env[py::cast(output)];
    Value* value = py_value.cast<Value*>();
    ctx.block->registerOutput(value);
  }

  // Run dce to clean-up unused functional and inplace ops.
  EliminateDeadCode(
      ctx.block,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
  return py::dict();
}

bool ConstantFoldCondition(torch::jit::Value* output) {
  auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
      ConstantValueMap::HasValue(output->debugName());
  auto reliable_value =
      ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
  return fold_condition && reliable_value;
}

void NodeToONNX(
    Node* old_node,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    py::dict& env,
    py::set& values_in_env) {
  py::object onnx = py::module::import("torch.onnx");
  py::object onnx_globals = py::module::import("torch.onnx._globals");
  py::object onnx_registration =
      py::module::import("torch.onnx._internal.registration");

  // Setup all the lambda helper functions.

  // Returns a node that n maps to in the new graph
  auto envFn = [&env](Value* n) -> Value* {
    auto py_n = py::cast(n);
    TORCH_CHECK(env.contains(py_n), "Dangling node reference");
    auto py_value = env[py_n];
    TORCH_CHECK(!py_value.is_none(), "Unused node was subsequently used");
    Value* value = py_value.cast<Value*>();
    return value;
  };

  // Put the new outputs in our environment map, and copy the type from the
  // input graph if they were not set by the symbolic. This is called only
  // with results of symbolic call (not for nodes that are just cloned).
  auto setOutputs = [&](const std::string& op_name,
                        Node* node,
                        const value_list& outputs) {
    auto old_outputs = node->outputs();
    // Count all outputs, excluding Handles
    auto num_old_outputs = old_outputs.size();
    if (outputs.size() != num_old_outputs) {
      std::ostringstream ss;
      ss << "symbolic for " << op_name
         << " produced an incorrect number of outputs (expected ";
      ss << num_old_outputs << ", but got " << outputs.size() << ")";
      throw std::runtime_error(ss.str());
    }
    // For const node, it does not need params_dict info, so set it to {}.
    const ParamMap empty_params_dict = {};
    auto opset_version = py::cast<int>(
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
    for (const auto i : c10::irange(num_old_outputs)) {
      auto old = old_outputs[i];
      if (outputs[i]) {
        bool exist_in_env = values_in_env.contains(py::cast(outputs[i]));
        // Update the ONNX value debug name with the ATen value debug name if
        // it exists. Skip if the ONNX value already exists in the
        // environment. This implies the op is a noop, and the value is owned
        // by another node created elsewhere.
        if (old->hasDebugName() && !exist_in_env) {
          auto old_name = outputs[i]->debugName();
          auto new_name = old->debugNameBase();
          Value* found_value = nullptr;
          bool exists = false;
          // In this scope, we fetch debug_names as a const reference and then
          // construct an iterator exist_name based on it. This iterator will
          // be corrupted if the underlying map of debug_names changes. This
          // will happen as a side-effect of setDebugName. For these reasons,
          // we make an explicit scope for exist_name and make sure that
          // setDebugName is never called within this scope.
          {
            const auto& debug_names = new_block->owningGraph()->debugNames();
            auto exist_name = debug_names.find(new_name);
            exists = exist_name != debug_names.end();
            if (exists) {
              found_value = exist_name->second;
            }
          }
          outputs[i]->setDebugName(new_name);
          if (exists) {
            found_value->setDebugName(new_name);
          }
          ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName());
        }
        // Allow symbolic() to skip specifying the type of the return node.
        // Unfortunately, they are on the hook for all internal nodes
        // (though in practice, the types are not computed.)
        //
        // If onnx shape inference is turned on, the new outputs will have
        // types inferred, and they will be merged with the old types.
        if (ConstantFoldCondition(outputs[i])) {
          // Create a const node if the node output value is in
          // ConstantValueMap.
          auto value =
              ConstantValueMap::GetValue(outputs[i]->debugName()).value();
          Node* const_node =
              new_block->owningGraph()->create(c10::onnx::Constant);
          const_node->t_(attr::value, value);
          const_node->output()->setType(TensorType::create(value));

          // Copy over source location and scope information to all nodes
          // created by the symbolic
          const_node->copyMetadata(node);

          new_block->appendNode(const_node);
          ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
          auto py_output = py::cast(const_node->output());
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        } else {
          // An update in ConstantValueMap is also needed here, since
          // the user setType can be only accessed in this step, and it
          // should be reliable.
          MergeInferredTypeAndSetMap(
              outputs[i], old->type(), outputs[i]->type());
          // non ONNX node with no type given will throw out the warnings here.
          UpdateReliable(
              outputs[i],
              AreInputsReliableOrStatic(outputs[i]->node()),
              /*no_type_warning=*/true);
          // For the node type that does not have ComputeConstant logic, it may
          // have reliable shape but its shape is not in ConstantValueMap. So
          // we need to update ConstantValueMap.
          UpdateShapeConstantIfReliable(outputs[i]);

          // Copy over source location and scope information to all nodes
          // created by the symbolic.
          // Do not set metadata if outputs[i] is already in env.
          if (!exist_in_env) {
            outputs[i]->node()->copyMetadata(node);
          }
          auto py_output = py::cast(outputs[i]);
          env[py::cast(old)] = py_output;
          values_in_env.add(py_output);
        }
      } else {
        // Null output means that the ONNX op doesn't have outputs
        // corresponding to certain PyTorch outputs
        env[py::cast(old)] = py::none();
        if (!old->uses().empty()) {
          std::ostringstream ss;
          ss << "symbolic for " << op_name << " returned None for the output "
             << i;
          ss << " (indicating conversion for that particular output is not supported), ";
          ss << "but the network uses this output later";
          // TODO: Say what actually used it
          throw std::runtime_error(ss.str());
        }
      }
    }
  };

  // Clone the node and add it to the new graph
  auto cloneNode = [&](Node* node) {
    auto n_ = new_block->appendNode(
        new_block->owningGraph()->createClone(node, envFn));
    for (const auto i : c10::irange(node->outputs().size())) {
      // n_->outputs()[i]->setType(node->outputs()[i]->type());
      auto py_output = py::cast(n_->output(i));
      env[py::cast(node->output(i))] = py_output;
      values_in_env.add(py_output);
    }
  };

  // Inline the prim::PythonOp sub-block nodes and append them to the onnx
  // graph
  auto inlineAutograd = [&](Node* PythonOpNode) {
    for (auto subblock : PythonOpNode->blocks()) {
      for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
        auto py_value = env[py::cast(PythonOpNode->inputs()[i])];
        env[py::cast(subblock->inputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
      for (auto* node : subblock->nodes()) {
        NodeToONNX(node, new_block, operator_export_type, env, values_in_env);
      }
      for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
        auto py_value = env[py::cast(subblock->outputs()[i])];
        env[py::cast(PythonOpNode->outputs()[i])] = py_value;
        values_in_env.add(py_value);
      }
    }
  };

  // Cast output of symbolic() python implementation
  auto processSymbolicOutput = [&](const std::string& op_name,
                                   Node* n,
                                   const py::object& raw_output) {
    if (raw_output.ptr() == Py_None) {
      cloneNode(n);
      return;
    }

    // Cast the outputs back to C++ and put them in the new graph
    std::vector<Value*> outputs;
    try {
      if (py::isinstance<Value>(raw_output)) {
        outputs = value_list{py::cast<Value*>(raw_output)};
      } else {
        outputs = py::cast<std::vector<Value*>>(raw_output);
      }
    } catch (const std::exception& ex) {
      std::ostringstream ss;
      ss << "Error casting results of symbolic for " << op_name
         << ": expected to return list of op nodes, instead received type '"
         << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
      throw std::runtime_error(ss.str());
    }

    setOutputs(op_name, n, outputs);
  };

  auto callPySymbolicFunction = [&](Node* n) {
    // The idea is to delegate as much of the actual argument massaging to
    // Python as possible

    py::tuple py_inputs(n->inputs().size());
    Py_ssize_t input_nr = 0;
    for (auto* input : n->inputs()) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    Graph* g = new_block->owningGraph();

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*g, n->scope());

    // IMPORTANT: NEVER pass raw pointers of smart-pointer-managed objects to
    // Python. Check #87343 for details.
    py::list new_nodes = py::list();
    py::object raw_output = onnx.attr("_run_symbolic_function")(
        g->shared_from_this(),
        new_block,
        n,
        py_inputs,
        env,
        values_in_env,
        new_nodes,
        operator_export_type);

    // Find new nodes that have been created by _run_symbolic_function and
    // propagate metadata
    for (py::handle py_node : new_nodes) {
      Node* node = py_node.cast<Node*>();
      node->copyMetadata(n);
    }

    // TODO: Assert it's an ATen identifier???
    // (Sometimes it's not...)
    processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
    GRAPH_DUMP("after processSymbolicOutput: ", g);
  };

  auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
    // Test if there is a symbolic function; bail if there is not
    auto pyobj = py::handle(op->pyobj.get());
    auto func = op->autogradFunction();
    if (func) {
      pyobj = func->get();
    }

    py::object opset_version =
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
    // NOTE(justinchuby): Call the internal registry to register the symbolic
    // method defined in the module.
    bool is_registered_op =
        onnx_registration.attr("registry")
            .attr("is_registered_op")("prim::PythonOp", opset_version)
            .cast<bool>();
    py::bool_ is_autograd_inlining_enabled =
        py::cast<bool>(onnx_globals.attr("GLOBALS").attr("autograd_inlining"));
    if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
      // Inline the subgraph within the prim::PythonOp unless either of these
      // conditions is satisfied:
      // 1. The torch.autograd.Function class of this node object has a
      //    `symbolic` method defined.
      // 2. A custom export symbolic is registered for prim::PythonOp.
      if ((operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
           operator_export_type ==
               ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) &&
          (py::cast<bool>(is_autograd_inlining_enabled))) {
        try {
          inlineAutograd(op);
        } catch (const std::exception& ex) {
          TORCH_WARN(
              "Unable to inline PythonOp: ",
              op->name(),
              " due to the following exception\n",
              ex.what(),
              "prim::PythonOp will be exported as is and without being inlined\n",
              "Try exporting with the following alternatives: \n",
              "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
              "2) Register a symbolic method for the prim::PythonOp ",
              op->name());
          cloneNode(op);
        }
      } else {
        cloneNode(op);
      }
      return;
    }

    // Prepare args for Python. First one is the graph, and is followed
    // by regular args, with Variables replaced by corresponding nodes.
    Py_ssize_t input_nr = 0;
    py::tuple py_symbolic_args(op->cconv.size());
    auto inputs = op->inputs();
    auto node_it = inputs.begin();
    auto scalar_it = op->scalar_args.begin();
    for (auto arg_type : op->cconv) {
      py::object obj;
      if (arg_type == 'c') {
        TORCH_CHECK(
            scalar_it != op->scalar_args.end(),
            "expected too many scalar args");
        obj = py::reinterpret_borrow<py::object>(
            py::handle((scalar_it++)->get()));
      } else if (arg_type == 'd') {
        TORCH_CHECK(node_it != inputs.end(), "expected too many inputs");
        obj = py::cast(envFn(*node_it++));
      } else {
        throw std::runtime_error("unexpected calling convention");
      }
      py_symbolic_args[input_nr++] = obj;
    }

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());

    if (py::hasattr(pyobj, "symbolic")) {
      // Call the symbolic function.
      // Use a little trampoline function so we can give good error messages
      // upon argument mismatch.
      // Register as a custom operator.
      // TODO: Find a more elegant way to do this without having to touch
      // internal Python modules.
      // TODO(justinchuby): Define a namespace for these Python Ops.
      onnx_registration.attr("registry")
          .attr("register")(
              "::" + op->name(),
              opset_version,
              pyobj.attr("symbolic"),
              /* custom */ true);

      // IMPORTANT: NEVER pass raw pointers of smart-pointer-managed objects to
      // Python. Check #87343 for details.
      py::object raw_output = onnx.attr("_run_symbolic_method")(
          new_block->owningGraph()->shared_from_this(),
          op->name(),
          pyobj.attr("symbolic"),
          py_symbolic_args);

      processSymbolicOutput(op->name(), op, raw_output);
    } else {
      TORCH_INTERNAL_ASSERT(is_registered_op);
      Node* n = static_cast<Node*>(op);
      n->s_(attr::name, op->name());
      // Call the symbolic function.
      // IMPORTANT: NEVER pass raw pointers of smart-pointer-managed objects to
      // Python. Check #87343 for details.
      py::list new_nodes = py::list();
      py::object raw_output = onnx.attr("_run_symbolic_function")(
          new_block->owningGraph()->shared_from_this(),
          new_block,
          n,
          py_symbolic_args,
          env,
          values_in_env,
          new_nodes,
          operator_export_type);

      processSymbolicOutput(op->kind().toUnqualString(), n, raw_output);
    }
  };

  auto k = old_node->kind();
  if (k.is_caffe2()) {
    // Pass on Caffe2 operator, since we already preprocess it
    cloneNode(old_node);
  } else if (k == prim::PythonOp) {
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
  } else {
    callPySymbolicFunction(old_node);
  }
}

} // namespace torch::jit
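// Usage note (a sketch, not part of this translation unit): ToONNX is the
// entry point of this pass and is typically reached from the Python export
// path, where it is assumed to be bound as torch._C._jit_pass_onnx and
// called during graph preparation, roughly:
//
//   graph = torch._C._jit_pass_onnx(graph, operator_export_type)
//
// The binding name and call site above are assumptions based on how this
// pass is commonly wired up; consult the pybind registration for the
// authoritative entry point.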