#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace { bool reportSourceLocation(size_t file_size) { if (file_size < 512ull * 1024) { return true; } const char* enable_env = std::getenv("PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION"); bool flag = true; if (enable_env == nullptr || std::strcmp(enable_env, "0") == 0 || std::strcmp(enable_env, "FALSE") == 0 || std::strcmp(enable_env, "false") == 0) { flag = false; } return flag; } } // namespace namespace torch::jit { using FunctionTable = std::unordered_map; using ValueTable = std::unordered_map; using TypeTable = std::unordered_map; using AttributeMap = std::unordered_map; using ListAttributeMap = std::unordered_map>; struct Refinement { Refinement(std::string identifier, TypePtr type) : identifier_(std::move(identifier)), type_(std::move(type)) {} const std::string& identifier() const { return identifier_; } TypePtr type() const { return type_; } private: std::string identifier_; TypePtr type_; }; struct RefinementSet { // When a comparison like x is None is made, we associate type refinements // with its true value and its false value. 
If a boolean that has refinements // associated with it is used in a conditional of an if statement, the true // and false refinements are inserted into the corresponding blocks using Refinements = std::vector; RefinementSet(Refinements true_refinements, Refinements false_refinements) : true_refinements_(std::move(true_refinements)), false_refinements_(std::move(false_refinements)) {} RefinementSet(Refinement single) : RefinementSet({std::move(single)}, {}) {} RefinementSet(Refinement single_true, Refinement single_false) : RefinementSet( Refinements({std::move(single_true)}), Refinements({std::move(single_false)})) {} RefinementSet() = default; // empty RefinementSet And(const RefinementSet& rhs) const { // if the result of an AND is true, both a & b had to be true, // so we take the union of a.true_refinements and b.true_refinements. // if the result is false, either a or b could have been false, // so we take their intersection. return RefinementSet( unionSet(true_refinements_, rhs.true_refinements_), intersectSet(false_refinements_, rhs.false_refinements_)); } RefinementSet Or(const RefinementSet& rhs) const { // if the result of an OR is true, either a & b could have been true, // so we take the intersection of a.true_refinements & b.true_refinements. // if the result is false, both a and b had to be false, // so we take their union. 
return RefinementSet( intersectSet(true_refinements_, rhs.true_refinements_), unionSet(false_refinements_, rhs.false_refinements_)); } RefinementSet Not() const { return RefinementSet(false_refinements_, true_refinements_); } const std::vector activeRefinements() const { return true_refinements_; } private: static bool sameVar(const Refinement& a, const Refinement& b) { return a.identifier() == b.identifier(); } static Refinements unionSet(const Refinements& a, const Refinements& b) { Refinements result = a; for (const Refinement& r : b) { auto it = std::find_if(result.begin(), result.end(), [&](const Refinement& e) { return e.identifier() == r.identifier(); }); if (it == result.end()) { result.push_back(r); } else if (*it->type() != *r.type()) { // we only keep refinements when they exactly match one // refinement type, for instance, we do not attempt to refine: // isinstance(x, float) and isinstance(x, int) result.erase(it); } } return result; } static Refinements intersectSet(const Refinements& a, const Refinements& b) { Refinements result; for (const Refinement& r : a) { auto it = std::find_if(b.begin(), b.end(), [&](const Refinement& e) { return e.identifier() == r.identifier(); }); if (it != b.end() && r.type() == it->type()) { result.push_back(r); } } return result; } Refinements true_refinements_; Refinements false_refinements_; }; struct CondValue { CondValue( Value* value, RefinementSet refinements, std::optional static_if) : value_(value), refinements_(std::move(refinements)), static_if_(static_if) {} CondValue( Graph& g, const SourceRange& loc, bool static_value, RefinementSet refinements) : value_(g.insertConstant(static_value, loc)), refinements_(std::move(refinements)), static_if_(static_value) {} Value* value() const { return value_; } const RefinementSet& refinements() const { return refinements_; } std::optional staticIf() const { return static_if_; } private: Value* value_; RefinementSet refinements_; std::optional static_if_; // certain 
  // expressions cause us to emit a static if statement.
  // This value is present if this is the case.
  // This is not equivalent to value_ being a constant:
  // it is possible for value_ to be constant but for
  // the expression that produced it to not trigger the
  // static if behavior, e.g. use of a variable assigned
  // to a constant.
};

enum NoneStatus { ALWAYS, MAYBE, NEVER };

static NoneStatus canBeNone(Value* v) {
  if (v->node()->mustBeNone()) {
    return ALWAYS;
  }
  if (v->type()->kind() == OptionalType::Kind ||
      (v->type()->kind() == UnionType::Kind &&
       v->type()->expect<UnionType>()->canHoldType(*NoneType::get()))) {
    return MAYBE;
  }
  return NEVER;
}

static Value* asSimple(const SugaredValuePtr& value) {
  if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
    return sv->getValue();
  }
  return nullptr;
}

static std::shared_ptr<MagicMethod> makeMagic(
    const std::string& name,
    const SugaredValuePtr& base) {
  return std::make_shared<MagicMethod>(name, base);
}

// Auxiliary data structure for desugaring variable binding into our always
// explicitly scoped language as we descend down nested control structures in
// the frontend (which themselves don't introduce scopes)
//
// The Environment keeps track of two tables, one for values which are not first
// class and a type table for values which are. When a first-class value
// is set in the environment, we emit a prim::Store which sets the
// name of the variable to the appropriate type, and when a first-class value is
// referenced we emit a prim::Load that generates a value of the appropriate
// type.
// // a = 1 // print(a) // becomes: // = prim::Store[name="a"](%a.1) // %a : int = prim::Load[name="a"]() // prim::Print(%a) struct Environment { Environment( GraphFunction& method, ResolverPtr resolver, Block* b, std::shared_ptr next = nullptr) : method(method), resolver(std::move(resolver)), b(b), next(std::move(next)) {} // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) GraphFunction& method; ResolverPtr resolver; std::unordered_map> error_messages; Block* b; std::shared_ptr next; // set type error in the lowest environment. if the variable is used after an // error has been set, then we will use the more informative error message void setVariableTypeError( const std::string& name, std::function msg) { auto runner = this; while (runner->next) { runner = runner->next.get(); } runner->error_messages[name] = std::move(msg); } // see if type error has been set for a variable std::optional findVariableTypeError(const std::string& name) { auto runner = this; while (runner->next) { runner = runner->next.get(); } auto msg = runner->error_messages.find(name); if (msg != runner->error_messages.end()) { return msg->second(); } else { return std::nullopt; } } SugaredValuePtr insertLoad(const std::string& name, const TypePtr& type) { auto g = b->owningGraph(); auto load = g->insertNode(g->createLoad(name, type)); if (meaningfulName(name)) { load->output()->setDebugName(name); } return std::make_shared(load->output()); } // note: type is not always the same as v->type(), e.g. 
// type: Optional[Tensor] // v->type(): Tensor void insertStore( const std::string& name, const SourceRange& loc, Value* v, TypePtr type) { auto g = b->owningGraph(); g->insertNode(g->createStore(name, v))->setSourceRange(loc); type_table[name] = std::move(type); } SugaredValuePtr findInThisFrame(const std::string& name) { auto it = value_table.find(name); if (it != value_table.end()) { return it->second; } auto it2 = type_table.find(name); if (it2 != type_table.end()) { return insertLoad(name, it2->second); } return nullptr; } SugaredValuePtr findInParentFrame(const std::string& name) { return next ? next->findInAnyFrame(name) : nullptr; } void setType(const std::string& name, TypePtr type) { type_table[name] = std::move(type); } SugaredValuePtr findInAnyFrame(const std::string& name) { for (auto runner = this; runner; runner = runner->next.get()) { if (auto r = runner->findInThisFrame(name)) { return r; } } return nullptr; } Block* block() { return b; } void setVar(const SourceRange& loc, const std::string& name, Value* value) { setSugaredVar( loc, name, std::make_shared(value), /*annotated_type=*/nullptr); } void setSugaredVar( const SourceRange& loc, const std::string& name, SugaredValuePtr value, const TypePtr& annotated_type) { Value* as_simple_value = asSimple(value); if (as_simple_value && !as_simple_value->hasDebugName() && meaningfulName(name) && // note: if the value wasn't defined in this block, we might be giving a // name only used inside this block to a value outside of this. this is // not normally helpful for debugging and causes import/export jitter. as_simple_value->node()->owningBlock() == block()) { as_simple_value->setDebugName(name); } // prevent re-assignment involving any sugared values // any reassignment like: // a = ... // while ... // a = .. 
// requires 'a' to be first-class in the graph since its value depends on // control flow if (auto parent = findInParentFrame(name)) { if (annotated_type) { throw( ErrorReport(loc) << "Attempting to declare and annotate the type of variable '" << name << "' but it is already defined in an outer block"); } if (!as_simple_value) { throw( ErrorReport(loc) << "Cannot re-assign '" << name << "' to a value of type " << value->kind() << " because " << name << " is not a first-class value. Only reassignments to first-class values are allowed"); } Value* simple_parent = asSimple(parent); if (!simple_parent) { throw( ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() << " and " << name << " is not a first-class value. Only reassignments to first-class values are allowed"); } auto parent_type = unshapedType(simple_parent->type()); as_simple_value = tryConvertToType( loc, *b->owningGraph(), parent_type, as_simple_value, /*allow_conversions=*/true); std::stringstream why_not; if (!as_simple_value->type()->isSubtypeOfExt(*parent_type, &why_not)) { auto error = ErrorReport(loc); error << "Variable '" << name << "' previously had type " << simple_parent->type()->repr_str() << " but is now being assigned to a value of type " << as_simple_value->type()->repr_str(); // Special-cased error msg if we're trying to assign to a tensor list. if (simple_parent->type()->kind() == TypeKind::ListType && as_simple_value->type()->kind() == TypeKind::ListType) { error << "\nEmpty lists default to List[Tensor]. 
Add a variable "
                   "annotation to the assignment to create an empty list "
                   "of another type (torch.jit.annotate(List[T], []) where T "
                   "is the type of elements in the list for Python 2)";
        }
        error << "\n" << why_not.str();
        throw ErrorReport(error);
      }
    }
    if (as_simple_value) {
      if (annotated_type &&
          !as_simple_value->type()->isSubtypeOf(*annotated_type)) {
        throw(
            ErrorReport(loc)
            << "Variable '" << name << "' is annotated with type "
            << annotated_type->repr_str()
            << " but is being assigned to a value of type "
            << as_simple_value->type()->repr_str());
      }
      auto value_store_type =
          annotated_type ? annotated_type : as_simple_value->type();
      insertStore(name, loc, as_simple_value, value_store_type);
    } else {
      value_table[name] = std::move(value);
    }
  }

  SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
    return getSugaredVar(ident.name(), ident.range(), required);
  }

  Value* getVar(const Ident& ident) {
    return getSugaredVar(ident)->asValue(ident.range(), method);
  }

  void throwVarNotFoundError(
      const std::string& ident,
      const SourceRange& range) {
    // check if this value was not emitted in an if statement because of a
    // type mismatch.
if it was, then we print a more informative error msg if (auto msg = findVariableTypeError(ident)) { throw(ErrorReport(range) << *msg << "and was used here"); } throw(ErrorReport(range) << "undefined value " << ident); } SugaredValuePtr getSugaredVar( const std::string& ident, const SourceRange& range, bool required = true) { auto retval = findInAnyFrame(ident); if (!retval) { static std::unordered_map globals = { {"print", std::make_shared()}, {"tuple", SpecialFormValue::create(prim::TupleConstruct)}, {"float", makeMagic( "__float__", std::make_shared(FloatType::get(), aten::Float))}, {"complex", makeMagic( "__complex__", std::make_shared(ComplexType::get(), aten::Complex))}, {"int", makeMagic( "__int__", std::make_shared(IntType::get(), aten::Int))}, {"bool", makeMagic( "__bool__", std::make_shared(BoolType::get(), aten::Bool))}, {"str", makeMagic( "__str__", std::make_shared(StringType::get(), aten::str))}, {"getattr", SpecialFormValue::create(prim::GetAttr)}, {"hasattr", SpecialFormValue::create(prim::HasAttr)}, {"isinstance", SpecialFormValue::create(prim::isinstance)}, // todo(zach): remove when we can correctly export torch.full via ONNX // or we have implicit conversion that can convert numbers to tensors {"_to_tensor", std::make_shared(TensorType::get(), prim::NumToTensor)}, {"len", makeMagic( "__len__", std::make_shared(aten::len, std::nullopt))}, {"hex", makeMagic( "__hex__", std::make_shared(aten::hex, std::nullopt))}, {"oct", makeMagic( "__oct__", std::make_shared(aten::oct, std::nullopt))}, {"round", makeMagic( "__round__", std::make_shared(aten::round, std::nullopt))}, {"hash", std::make_shared(aten::hash, std::nullopt)}, {"id", std::make_shared(prim::id, std::nullopt)}, {"min", std::make_shared(prim::min, std::nullopt)}, {"max", std::make_shared(prim::max, std::nullopt)}, {"abs", std::make_shared(prim::abs, std::nullopt)}, {"all", std::make_shared(aten::all, std::nullopt)}, {"any", std::make_shared(aten::any, std::nullopt)}, {"divmod", 
std::make_shared(aten::divmod, std::nullopt)}, {"sum", std::make_shared(aten::sum, std::nullopt)}, {"list", SpecialFormValue::create(prim::list)}, {"dict", SpecialFormValue::create(prim::dict)}, {"ord", std::make_shared(aten::ord, std::nullopt)}, {"chr", std::make_shared(aten::chr, std::nullopt)}, {"bin", std::make_shared(aten::bin, std::nullopt)}, {"pow", std::make_shared(aten::pow, std::nullopt)}, {"range", SpecialFormValue::create(prim::range)}, {"zip", SpecialFormValue::create(prim::zip)}, {"enumerate", SpecialFormValue::create(prim::enumerate)}, {"rangelist", std::make_shared(prim::rangelist, std::nullopt)}, {"sorted", std::make_shared(aten::sorted, std::nullopt)}, // Only AssertionError is bound so that we can use it from emitAssert, // all other exceptions should be resolved at the Python level {"AssertionError", std::make_shared("AssertionError")}, }; auto it = globals.find(ident); if (it != globals.end()) { retval = it->second; } } if (!retval) { if (auto type = resolver->resolveType(ident, range)) { if (auto tuple_type = type->cast()) { retval = std::make_shared(tuple_type); } } } if (!retval) { retval = resolver->resolveValue(ident, method, range); } if (!retval) { if (auto type = resolver->resolveType(ident, range)) { if (auto class_type = type->cast()) { retval = std::make_shared(class_type); } } } if (!retval && required) { throwVarNotFoundError(ident, range); } return retval; } Value* getVar(const std::string& ident, const SourceRange& range) { return getSugaredVar(ident, range)->asValue(range, method); } void removeVar(const Ident& ident, bool check_if_removed = false) { bool removed = false; for (auto runner = this; runner; runner = runner->next.get()) { auto a = runner->value_table.erase(ident.name()); auto b = runner->type_table.erase(ident.name()); removed = a || b; } if (check_if_removed && !removed) { throwVarNotFoundError(ident.name(), ident.range()); } } std::vector definedVariables() { std::vector result; for (auto& kv : type_table) { 
result.push_back(kv.first); } return result; } private: TypeTable type_table; ValueTable value_table; }; template static Value* materializeConstant( T val, Graph& graph, const SourceRange& r, std::unordered_map& map) { auto existing_constant = map.find(val); if (existing_constant != map.end()) { return existing_constant->second; } WithInsertPoint guard(graph.block()->nodes().front()); auto new_constant = graph.insertConstant(val, r); map[val] = new_constant; return new_constant; } inline bool isSupportedListElementType(const TypePtr& type) { return type->isSubtypeOf(*TensorType::get()) || type->isSubtypeOf(*NumberType::get()); } // Information for each def being emitted. // Defs can be nested to support closures so we need a stack of this information // Currently records information about the functions return type. struct DefContext { TypePtr declared_return_type_; // nullptr if not annotated TypePtr merged_return_type_; // nullptr if a Return has not been seen yet }; enum class LoopStatus { NOT_IN_LOOP, IN_LOOP, IN_UNROLLED_LOOP }; struct WithLoopStatus { WithLoopStatus(LoopStatus* prev, LoopStatus new_status) : prev_ptr_(prev), prev_value_(*prev) { *prev = new_status; } ~WithLoopStatus() { *prev_ptr_ = prev_value_; } private: LoopStatus* prev_ptr_; LoopStatus prev_value_; }; struct to_ir { to_ir( const Def& def, ResolverPtr resolver_, const Self* self, GraphFunction& method) // method being constructed : method(method), graph(method.graph()), resolver(std::move(resolver_)), typeParser_(resolver), environment_stack(nullptr) { AT_ASSERT(resolver); pushFrame(graph->block(), /*starts_def=*/true); // Type annotations exclude explicitly typing the "self" parameter, so in // the case that this is a method with self we expect one fewer parameter // annotation than the number of parameters this Def takes. 
    if (self && def.decl().params().empty()) {
      throw(
          ErrorReport(def.decl().params().range())
          << "methods must have a self argument");
    }
    method.setSchema(emitDef(def, self, graph->block()));

    // At this point, we might have received a graph that is compiled with
    // old operator schemas that might not exist in the system anymore.
    // Therefore, we replace such ops with their valid upgraders.
    ReplaceOldOperatorsWithUpgraders(graph);

    // NB ORDERING: SSA conversion has to occur before
    // lifting of closures and forks, this way closures are converted
    // to SSA while part of their original graph, and closures are ready to
    // be inlined into forked closures
    ConvertToSSA(graph);

    // Convert loops with an iter and body condition specified to
    // Python-recognizable while loops. We do this so they can be exported,
    // and run the pass early to avoid jitter. Like conversion to SSA,
    // it only needs to run once.
    CanonicalizeModifiedLoops(graph);

    // Convert Ops to a Normalized Form
    NormalizeOps(graph);

    runCleanupPasses(graph);
  }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  GraphFunction& method;
  std::shared_ptr<Graph> graph;
  ResolverPtr resolver;
  std::unordered_map<int64_t, Value*, std::hash<int64_t>> integral_constants;
  std::unordered_map<double, Value*, std::hash<double>> fp_constants;
  std::unordered_map<
      c10::complex<double>,
      Value*,
      c10::hash<c10::complex<double>>>
      complex_constants;
  std::unordered_set<Block*> exit_blocks;
  ScriptTypeParser typeParser_;
  LoopStatus loop_status_ = LoopStatus::NOT_IN_LOOP;

  // Singly-linked list of environments. The top element contains a member
  // `next` that points to the most immediate enclosing scope's value.
std::shared_ptr environment_stack; std::vector def_stack_; size_t temp_name_count_ = 0; std::string createTempName(const std::string& prefix) { return prefix + std::to_string(temp_name_count_++); } void pushFrame(Block* b, bool starts_def = false) { if (starts_def) { def_stack_.emplace_back(); } environment_stack = std::make_shared(method, resolver, b, environment_stack); } std::shared_ptr popFrame(bool ends_def = false) { auto old_frame = environment_stack; environment_stack = environment_stack->next; if (ends_def) { def_stack_.pop_back(); } return old_frame; } // If the graph might not return, add an implicit None return at the end void handleMaybeNoReturn(const Def& def, Block* block) { auto decl_ret = def_stack_.back().declared_return_type_; if (exit_blocks.count(block) == 0) { auto decl_ret = def_stack_.back().declared_return_type_; if (decl_ret && decl_ret != NoneType::get()) { throw( ErrorReport(def.range()) << "Function was not annotated as having type None, but does not " << "return along all paths"); } WithInsertPoint b(*block->nodes().end()); emitReturn(Return::create( def.range(), Expr(Compound::create(TK_NONE, def.range(), {})))); } else { // if we haven't seen any return statements, but the graph block exits // (the function always throws) then we accept the declared return type if // it exists or set it to none if (def_stack_.back().merged_return_type_ == nullptr) { def_stack_.back().merged_return_type_ = decl_ret != nullptr ? 
decl_ret : NoneType::get(); } } } FunctionSchema emitDef(const Def& def, const Self* self, Block* block) { auto schema = typeParser_.parseSchemaFromDef(def, bool(self)); // TODO need guards on init returning none if (schema.returns().size() == 1) { def_stack_.back().declared_return_type_ = schema.returns().at(0).type(); } std::vector arguments = emitFormalArguments(def, self, schema, block); // body auto stmts_list = def.statements(); emitStatements(stmts_list.begin(), stmts_list.end()); handleMaybeNoReturn(def, block); std::vector returns = {emitOutput(def.range(), schema, block)}; return {def.name().name(), "", std::move(arguments), std::move(returns)}; } // see [setstate type] static TypePtr getTypeForSetStateArg(const Def& def, const Self* self) { TORCH_CHECK(self, "Expected __setstate__ to have a `self` argument"); auto getstate = self->getClassType()->findMethod("__getstate__"); if (!getstate) { throw( ErrorReport(def.range()) << "`__setstate__` defined but not `__getstate__`. " << "You must have both defined on a ScriptModule " << "to customize serialization.\n" << "Did you forget to use `@torch.jit.export`?"); } getstate->ensure_defined(); return self->getClassType() ->getMethod("__getstate__") .getSchema() .returns() .at(0) .type(); } // see [setstate type] static bool shouldDeriveSetStateType( const Def& def, const FunctionSchema& schema) { const bool noTypeAnnotations = std::all_of( schema.arguments().begin(), schema.arguments().end(), [](const Argument& arg) { return arg.is_inferred_type(); }); bool shouldInfer = def.name().name() == "__setstate__" && noTypeAnnotations; if (!shouldInfer) { return false; } // Do some additional basic validation that the __setstate__ func is // well-formed TORCH_INTERNAL_ASSERT(def.name().name() == "__setstate__"); const auto numDeclParams = def.decl().params().size(); if (numDeclParams != 2) { throw( ErrorReport(def.range()) << "Expected 2 arguments for `__setstate__`, got: " << numDeclParams); } return true; } 
std::vector emitFormalArguments( const Def& def, const Self* self, const FunctionSchema& schema, Block* block) { std::vector arguments; // for schema // inputs auto it = def.decl().params().begin(); auto end = def.decl().params().end(); auto expected_annotation_size = def.decl().params().size(); if (self) { expected_annotation_size--; } if (schema.arguments().size() != expected_annotation_size) { throw( ErrorReport(def.decl().params().range()) << "Number of type annotations for" << " function parameters (" << schema.arguments().size() << ")" << " does not match the number of parameters on the function (" << expected_annotation_size << ")!"); } if (self) { AT_ASSERT(it != end); const auto& name = (*it).ident().name(); Value* new_input = block->addInput()->setDebugName(name); environment_stack->setSugaredVar( (*it).ident().range(), name, self->makeSugared(new_input), /*annotated_type=*/nullptr); arguments.emplace_back(name, new_input->type()); ++it; } // [setstate type] // __setstate__ is special, because if the user leaves it un-annotated we // will derive the type for `state` from the output type of __getstate__. // This is necessary so that we can allow submodules to appear in `state`. 
bool shouldDeriveType = shouldDeriveSetStateType(def, schema); size_t arg_annotation_idx = 0; for (; it != end; ++it) { auto& name = (*it).ident().name(); // Add the input to the graph Value* new_input = block->addInput(); if (meaningfulName(name)) { new_input->setDebugName(name); } // Record the type for the schema and set the Type on the Value* auto arg = schema.arguments().at(arg_annotation_idx++); if (shouldDeriveType) { TORCH_INTERNAL_ASSERT(schema.arguments().size() == 1); const auto& inferredStateType = getTypeForSetStateArg(def, self); arg = arg.cloneWithType(inferredStateType); } arguments.push_back(arg); new_input->setType(arguments.back().type()); // NB: set type of new_input before setVar call so the Store is // typed appropriately environment_stack->setVar((*it).ident().range(), name, new_input); } return arguments; } Argument emitOutput( const SourceRange& range, const FunctionSchema& schema, Block* block) { // handleMaybeNoReturn ensures that merged_return_type_ is always set auto ret_type = def_stack_.back().merged_return_type_; TORCH_INTERNAL_ASSERT(ret_type); // in the ConvertToSSA pass, prim::ReturnStmts are lowered so that the // correct return value is set. Until then, we have a correctly-typed // placeholder return value. This is needed so that closures & graphs // are correctly typed. auto placeholder_return = graph->insertNode(graph->createUninitialized(ret_type))->output(); block->registerOutput(placeholder_return); return Argument("", def_stack_.back().merged_return_type_); } void emitStatements(const List& statements) { return emitStatements(statements.begin(), statements.end()); } // XXX: Right now closures are not generically implemented and are only used // as an intermediate form for special tasks, like defining gradients or // forked functions. // // There are several unfinished aspects that make them unusable generally // 1. We do not have a type, ivalue, operator to represent prim::Closure, so // closure_node has type None // 2. 
There is no export logic for it yet, so it cannot be // exported/python_printed // 3. There is nothing preventing the assignment of already existing variables // inside the closures // the changes to those variables will just get forgotten. // 4. There is no parsing support in frontend.py, this is intentional since it // prevents people from accidentally using this feature. // // This function leaves in the graph something like: // // %2 : None = prim::Closure() // block0(): // %1 : Tensor = prim::DoSomething(%0) // -> (%1) // // A separate pass is required to erase this closure and replace it with // something actually executable (see liftClosure and inlineForkedClosure). std::shared_ptr emitClosure( const std::function& emit_body) { Node* closure_node = graph->insertNode(graph->create(prim::Closure, 1)); // it is not a real thing yet, so just say the type is None closure_node->output()->setType(NoneType::get()); Block* block = closure_node->addBlock(); WithLoopStatus loop_guard(&loop_status_, LoopStatus::NOT_IN_LOOP); { WithInsertPoint guard(block); pushFrame(block, /*starts_def=*/true); emit_body(block); popFrame(/*ends_def=*/true); } return std::make_shared(closure_node->output()); } void emitClosure(const Def& def) { // invoked once the closure block is set as the environment auto emit_body = [&](Block* closure_block) { emitDef( def, nullptr, closure_block); // ignore schema return, we just wont use it for now // since we never create a Method for the closure }; auto closure_value = emitClosure(emit_body); environment_stack->setSugaredVar( def.name().range(), def.name().name(), closure_value, /*annotated_type=*/nullptr); } void checkBreakContinue( const SourceRange& loc, const std::string& stmt_name) { if (loop_status_ == LoopStatus::NOT_IN_LOOP) { throw( ErrorReport(loc) << "SyntaxError: '" << stmt_name << "'" << " outside loop"); } else if (loop_status_ == LoopStatus::IN_UNROLLED_LOOP) { throw( ErrorReport(loc) << "Because we emit iteration over modulelists 
or tuples as " "unrolled loops, we do not support break or continue inside the body of these loops"); } } void emitBreak(const Break& stmt) { checkBreakContinue(stmt.range(), "break"); auto break_node = graph->create(prim::BreakStmt, {}, 0)->setSourceRange(stmt.range()); graph->insertNode(break_node); } void emitContinue(const Continue& stmt) { checkBreakContinue(stmt.range(), "continue"); auto continue_node = graph->create(prim::ContinueStmt, {}, 0)->setSourceRange(stmt.range()); graph->insertNode(continue_node); } void emitDelete(const Delete& stmt) { for (const auto& target : stmt.targets()) { if (target.kind() == TK_SUBSCRIPT) { Subscript subscript(target); const List& subscript_exprs = subscript.subscript_exprs(); if (subscript_exprs[0].kind() == TK_SLICE_EXPR) { throw( ErrorReport(target.range()) << "del statements only support deletion at a single index, " "slicing is not supported" " (see https://github.com/pytorch/pytorch/issues/31430)"); } const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1); const SourceRange& val_range = subscript.value().range(); Value* idx = emitExpr(subscript_exprs[0]); Value* val = sv->asValue(val_range, method); // If val is a class instance, this is a method call to a type-specific // implementation of del defined in a __delitem__ method. if (auto cls = val->type()->cast()) { if (!cls->findMethod("__delitem__")) { throw( ErrorReport(target.range()) << "Class does not define __delitem__"); } // Use MethodValue to call the method to handle recursion. 
MethodValue(val, "__delitem__") .call(stmt.range(), method, {idx}, {}, 0); } else { auto node = graph->create(aten::Delete, {val, idx}, 0) ->setSourceRange(target.range()); graph->insertNode(node); } } else if (target.kind() == TK_VAR) { Var var(target); environment_stack->removeVar(var.name(), /*check_if_removed=*/true); } else { throw( ErrorReport(target.range()) << "del statements are only supported for deleting" " list and dict items and variables"); } } } void emitReturn(const Return& stmt) { TypePtr declared_return_type = def_stack_.back().declared_return_type_; // nullptr if not annotated auto actual_return = emitExpr(stmt.expr(), declared_return_type); // result type is annotated, every return must convert to that type if (declared_return_type) { // this guard skips implicit conversion from None -> Tensor for the return // type. otherwise forgetting a return a function returning a tensor will // cause a None to be converted to a tensor. if (!(actual_return->type()->isSubtypeOf(*TensorType::get()) && actual_return->type()->isSubtypeOf(*NoneType::get()))) { actual_return = tryConvertToType( stmt.range(), *graph, declared_return_type, actual_return, /*allow_conversions=*/true); } if (!actual_return->type()->isSubtypeOf(*declared_return_type)) { throw( ErrorReport(stmt.range()) << "Return value was annotated as having type " << declared_return_type->repr_str() << " but is actually of type " << actual_return->type()->repr_str()); } } else { declared_return_type = def_stack_.back().merged_return_type_; if (!declared_return_type) { declared_return_type = actual_return->type(); } auto merged_return_type = unifyTypes(declared_return_type, actual_return->type()); if (!merged_return_type) { throw( ErrorReport(stmt.range()) << "Previous return statement returned a value of type " << declared_return_type->repr_str() << " but this return statement returns a value of type " << actual_return->type()->repr_str()); } declared_return_type = merged_return_type.value(); } 
AT_ASSERT(declared_return_type); def_stack_.back().merged_return_type_ = declared_return_type; // If the annotated return type is Any and the result type is not Any, // cast the result to Any to facilitate type unification between return // statements on different code paths (e.g. different branches of an if, // body and containing scope of a loop). if (declared_return_type == AnyType::get() && actual_return->type() != AnyType::get()) { actual_return = graph->insertUncheckedCast(actual_return, declared_return_type); } graph->insertNode(graph->create(prim::ReturnStmt, {actual_return}, 0)); exit_blocks.insert(environment_stack->block()); } void emitStatements( List::const_iterator begin, List::const_iterator end) { for (; begin != end; ++begin) { auto stmt = *begin; ErrorReport::CallStack::update_pending_range(stmt.range()); switch (stmt.kind()) { case TK_IF: emitIf(If(stmt)); break; case TK_WHILE: emitWhile(While(stmt)); break; case TK_FOR: emitFor(For(stmt)); break; case TK_ASSIGN: emitAssignment(Assign(stmt)); break; case TK_AUG_ASSIGN: emitAugAssignment(AugAssign(stmt)); break; case TK_EXPR_STMT: { auto expr = ExprStmt(stmt).expr(); emitSugaredExpr(expr, 0); } break; case TK_RAISE: emitRaise(Raise(stmt)); break; case TK_ASSERT: emitAssert(Assert(stmt)); break; case TK_RETURN: { emitReturn(Return(stmt)); } break; case TK_CONTINUE: { emitContinue(Continue(stmt)); } break; case TK_BREAK: { emitBreak(Break(stmt)); } break; case TK_PASS: // Emit nothing for pass break; case TK_DEF: emitClosure(Def(stmt)); break; case TK_DELETE: emitDelete(Delete(stmt)); break; case TK_WITH: emitWith(With(stmt)); break; default: throw( ErrorReport(stmt) << "Unrecognized statement kind " << kindToString(stmt.kind())); } // Found an exit statement in this block. The remaining statements aren't // reachable so we don't emit them. 
if (exit_blocks.count(environment_stack->block())) return; } } RefinementSet findIsNoneRefinements( const Expr& lhs, Value* lhs_value, const Expr& rhs, Value* rhs_value, int tok) { if (rhs.kind() != TK_NONE && lhs.kind() == TK_NONE) { // make 'None is var' into 'var is None' return findIsNoneRefinements(rhs, rhs_value, lhs, lhs_value, tok); } if (rhs.kind() != TK_NONE || lhs.kind() != TK_VAR) { return {}; } // statement must be var {is, is not} None const std::string& name = Var(lhs).name().name(); // While it should in theory be possible to specialize // the `x is None` to know x has type NoneType, we have previously // not done this. Unfortunately, doing this will make the type None // propagate further in all loaded models. The handling of // unwrap_optional will fail in these cases since export did // not expect that the input would be none and an unannotated None. // To enable this, we need to (1) implement a real casting operator // annotated(T, X) that stays in the graph and does the cast // and (2) only enable this OPTIONAL_NONE when loading newer // graphs because it is incompatible with older graphs. 
    // Refinement none(name, RefinementKind::OPTIONAL_NONE);
    if (const auto optional_type = lhs_value->type()->cast<OptionalType>()) {
      Refinement present(name, optional_type->getElementType());
      if (tok == TK_IS) {
        return RefinementSet({}, {present});
      } else { // TK_ISNOT
        return RefinementSet({present}, {});
      }
    }
    if (const auto union_type = lhs_value->type()->cast<UnionType>()) {
      std::vector<TypePtr> to_subtract{NoneType::get()};
      std::optional<TypePtr> remaining =
          union_type->subtractTypeSet(to_subtract);
      std::vector<Refinement> all_present;
      if (remaining) {
        Refinement present{name, *remaining};
        all_present.push_back(std::move(present));
      }
      if (tok == TK_IS) {
        return RefinementSet({}, all_present);
      } else { // TK_ISNOT
        return RefinementSet(all_present, {});
      }
    }
    return RefinementSet();
  }

  CondValue emitCondExpr(const Expr& expr) {
    switch (expr.kind()) {
      case TK_AND:
      case TK_OR: {
        auto binop = BinOp(expr);
        return emitShortCircuitLogical(
            binop.range(), binop.lhs(), binop.rhs(), expr.kind() == TK_OR);
      }
      case TK_NOT: {
        CondValue v = emitCondExpr(Expr(expr.tree()->trees()[0]));
        Value* result = emitBuiltinCall(
            expr.range(), *graph, aten::__not__, {v.value()}, {});
        std::optional<bool> static_if;
        if (v.staticIf()) {
          static_if = !*v.staticIf();
        }
        return CondValue(result, v.refinements().Not(), static_if);
      } break;
      case TK_IS:
      case TK_ISNOT: {
        // Metaprogramming on the AST for `is` / `is not`: where the result
        // is statically known, emit a constant condition instead of a
        // runtime comparison.
        auto cond_op = BinOp(expr);
        Value* lhs_val = emitExpr(cond_op.lhs());
        Value* rhs_val = emitExpr(cond_op.rhs());

        auto lhs_none = canBeNone(lhs_val);
        auto rhs_none = canBeNone(rhs_val);

        // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
        //
        // AA                     -> statically IS always holds,
        //                           IS_NOT never holds
        // AN, NA                 -> statically IS_NOT always holds,
        //                           IS never holds
        // MA, MM, MN, NM, NN, AM -> cannot prove anything statically
        bool its_is = expr.kind() == TK_IS;
        if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
          return CondValue(*graph, expr.range(), its_is, {});
        } else if (
            (lhs_none == ALWAYS && rhs_none == NEVER) ||
            (lhs_none == NEVER &&
rhs_none == ALWAYS)) { // lhs_val/rhs_val with A/M: only emit never_none_branch return CondValue(*graph, expr.range(), !its_is, {}); } else { auto kind = getNodeKind(expr.kind(), expr.get()->trees().size()); Value* cond_value = emitBuiltinCall( expr.get()->range(), *method.graph(), kind, {lhs_val, rhs_val}, {}); auto refinements = RefinementSet(findIsNoneRefinements( cond_op.lhs(), lhs_val, cond_op.rhs(), rhs_val, expr.kind())); return CondValue(cond_value, refinements, std::nullopt); } } break; default: { if (expr.kind() == TK_APPLY) { auto apply = Apply(expr); auto callee = Apply(expr).callee(); if (callee.kind() == TK_VAR) { if (Var(callee).name().name() == "isinstance") { checkApplyNumInputs(apply, 2); return emitIsInstance(apply.inputs()[0], apply.inputs()[1]); } if (Var(callee).name().name() == "hasattr") { checkApplyNumInputs(apply, 2); return emitHasAttr(apply.inputs()[0], apply.inputs()[1]); } } auto sv = emitSugaredExpr(apply.callee(), 1); auto loc = apply.callee().range(); if (auto special_form = dynamic_cast(sv.get())) { if (special_form->form() == prim::isinstance) { checkApplyNumInputs(apply, 2); return emitIsInstance(apply.inputs()[0], apply.inputs()[1]); } } } auto expr_out = emitToBool(expr.range(), emitExpr(expr)); std::optional static_if = std::nullopt; auto kind = expr_out->node()->kind(); if (kind == aten::is_scripting) { static_if = true; } else if (kind == aten::has_torch_function) { static_if = false; } // MetaCompile on boolean literals and constants if (auto maybe_ivalue = toIValue(expr_out)) { static_if = maybe_ivalue->toBool(); } return CondValue(expr_out, RefinementSet({}), static_if); } break; } } std::shared_ptr emitSingleIfBranch( Block* b, const List& branch, const RefinementSet& refinements) { pushFrame(b); WithInsertPoint guard(b); insertRefinements(branch.range(), refinements); emitStatements(branch); return popFrame(); } Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) { return graph->create(kind, 
        n_outputs)->setSourceRange(loc);
  }

  Value* emitTernaryIf(
      const TernaryIf& expr,
      const TypePtr& type_hint = nullptr) {
    CondValue cond_value = emitCondExpr(expr.cond());
    // If the cond expr is a static value, then we metacompile the `if`
    // statement and only emit the true or the false branch
    if (cond_value.staticIf()) {
      if (*cond_value.staticIf()) {
        return emitExpr(expr.true_expr(), type_hint);
      } else {
        return emitExpr(expr.false_expr(), type_hint);
      }
    }
    auto true_expr = [&] { return emitExpr(expr.true_expr(), type_hint); };
    auto false_expr = [&] { return emitExpr(expr.false_expr(), type_hint); };
    return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
  }

  template <class F1, class F2, class F3>
  void refineAndSetUnionTypeHintOrPopulateCandidatesVector(
      const TypePtr& type_hint,
      TypePtr* refined_type_hint_ptr,
      std::vector<TypePtr>* all_candidates,
      const std::string& match_repr,
      const Expr& src,
      const F1& type_match,
      const F2& do_if_match,
      const F3& do_if_anytype,
      bool is_dict_constructor = false) {
    if (auto union_type_hint = (*refined_type_hint_ptr)->cast<UnionType>()) {
      // `candidate_types` holds all types in the Union annotation that
      // satisfy `type_match` (e.g. all List types)
      std::vector<TypePtr> candidate_types;

      std::copy_if(
          union_type_hint->containedTypes().begin(),
          union_type_hint->containedTypes().end(),
          std::back_inserter(candidate_types),
          [&](TypePtr type_ptr) { return type_match(type_ptr); });

      if (!is_dict_constructor && candidate_types.empty()) {
        throw(
            ErrorReport(src)
            << "Expected an Union type annotation "
            << "with an inner " << match_repr << " type, but got "
            << (*refined_type_hint_ptr)->repr_str());
      } else if (candidate_types.size() == 1) {
        // The Union only had a single type of the container we want to
        // match, so we can unconditionally refine it to that type
        (*refined_type_hint_ptr) = candidate_types[0];
      } else {
        // We can't refine the Union yet, since it contains multiple
        // types of the container we want to match, but we do at least
        // have a list of possible types (e.g.
`Union[List[int], // List[str], float, str]` -> candidates={List[int], List[str]}) (*all_candidates) = std::move(candidate_types); } } else if ( auto optional_type_hint = (*refined_type_hint_ptr)->cast()) { (*refined_type_hint_ptr) = optional_type_hint->getElementType(); } // This case handles code like `dict([(x, y), (a, b)])` that would // otherwise fail the following error checks if (is_dict_constructor) { return; } // If we had any annotation that was NOT a Union that can hold more // than one type of the container we want to match if (all_candidates->empty()) { if (type_match(*refined_type_hint_ptr)) { do_if_match(); } else if ((*refined_type_hint_ptr)->kind() == AnyType::Kind) { do_if_anytype(); } else { throw( ErrorReport(src) << "Expected an annotation of type " << match_repr << " but got " << type_hint->repr_str()); } } } void refineAndSetListTypeHintFromCandidatesVector( const std::vector& all_candidates, const TypePtr& type_hint, TypePtr* refined_type_hint_ptr, const TypePtr& unified_elem_type, const Expr& src) { TypePtr greatest_elem_type = nullptr; std::for_each( all_candidates.begin(), all_candidates.end(), [&](const TypePtr& candidate) { auto candidate_elem_type = candidate->expect()->getElementType(); if (unified_elem_type->isSubtypeOf(candidate_elem_type)) { if (!greatest_elem_type) { greatest_elem_type = candidate_elem_type; } else { greatest_elem_type = *(unifyTypes(greatest_elem_type, candidate_elem_type)); } } }); if (!greatest_elem_type) { std::stringstream vector_repr; for (size_t i = 0; i < all_candidates.size(); ++i) { if (i > 0 && all_candidates.size() > 2) { vector_repr << ", "; } if (i != 0 && i == all_candidates.size() - 1) { vector_repr << " or "; } vector_repr << all_candidates[i]->repr_str(); } throw( ErrorReport(src) << "Union type annotation `" << type_hint->repr_str() << "` can hold " << vector_repr.str() << ", but none of " << "those types match the types of the given list " << "elements, which were unified to " << 
unified_elem_type->repr_str()); } else { (*refined_type_hint_ptr) = ListType::create(greatest_elem_type); ; } } void refineAndSetDictTypeHintFromCandidatesVector( const std::vector& all_candidates, const TypePtr& type_hint, TypePtr* refined_type_hint_ptr, const TypePtr& known_key_type, const TypePtr& known_value_type, const Expr& src) { TypePtr candidate_key_type = nullptr; TypePtr candidate_value_type = nullptr; TypePtr candidate = nullptr; for (const auto& current_candidate : all_candidates) { auto current_key_type = current_candidate->expect()->getKeyType(); auto current_value_type = current_candidate->expect()->getValueType(); if (known_key_type->isSubtypeOf(current_key_type) && known_value_type->isSubtypeOf(current_value_type)) { if (!candidate || (candidate_key_type->isSubtypeOf(current_key_type) && candidate_value_type->isSubtypeOf(current_value_type))) { candidate_key_type = current_key_type; candidate_value_type = current_value_type; candidate = current_candidate; } } } if (!candidate) { std::stringstream vector_repr; for (size_t i = 0; i < all_candidates.size(); ++i) { if (i > 0 && all_candidates.size() > 2) { vector_repr << ", "; } if (i != 0 && i == all_candidates.size() - 1) { vector_repr << " or "; } vector_repr << all_candidates[i]->repr_str(); } throw( ErrorReport(src) << "Union type annotation `" << type_hint->repr_str() << "` can hold " << vector_repr.str() << ", but none of " << "those dict types can hold the types of the given" << " keys and values, which were unified to Dict[" << known_key_type->repr_str() << ", " << known_value_type->repr_str()); } else { (*refined_type_hint_ptr) = candidate; } } Value* emitListComprehension(const ListComp& lc, const TypePtr& type_hint) { const auto loc = lc.range(); const auto targets_list = List::create(lc.range(), {lc.target()}); const auto itrs = List::create(lc.range(), {lc.iter()}); // If there is no type hint, and this is emitted over an iterable that is // unrolled and of length 0, then we emit a List 
of tensors Value* list_value = graph->insertNode(graph->create(prim::ListConstruct, 1)) ->output() ->setType(ListType::ofTensors()); TypePtr refined_type_hint = type_hint; std::vector all_candidates = {}; if (refined_type_hint) { auto do_if_type_match = [&]() { list_value->setType(refined_type_hint); }; auto type_match = [&](const TypePtr& t) { return t->isSubtypeOf(AnyListType::get()); }; refineAndSetUnionTypeHintOrPopulateCandidatesVector( type_hint, &refined_type_hint, &all_candidates, "List", lc, type_match, do_if_type_match, do_if_type_match); } bool seen_first_elem = false; // A list comprehension introduces its own scope Node* n = graph->insertNode(create(prim::ComprehensionScope, lc.range(), 0)); auto* comprehension_block = n->addBlock(); pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); auto emit_body = [&]() { Value* out = emitExpr(lc.elt()); // If we didn't have a type annotation, the type of the list would // be set to `Tensor`. We don't want to unify this default type // with the actual elements in the list, so let the type begin as // the first element in the list if (!seen_first_elem) { list_value->setType(ListType::create(out->type())); seen_first_elem = true; } const auto elem_type_hint = refined_type_hint && refined_type_hint->kind() == ListType::Kind ? refined_type_hint->cast()->getElementType() : nullptr; std::optional unified_elem_type = unifyTypes( list_value->type()->expect()->getElementType(), out->type(), /*default_to_union=*/true, elem_type_hint); // Case: The list comprehension generated heterogenous values, // and we don't have a type hint to suggest that this is what the // user expected if (!type_hint && (*unified_elem_type)->isUnionType()) { TORCH_WARN( "List consists of heterogeneous types, which means", " that it has been typed as containing ", (*unified_elem_type)->repr_str(), ". 
To use any of the " "values in this List, it will be necessary to add an " "`assert isinstance` statement before first use to trigger " "type refinement. The first non-matching element was typed", " as ", out->type()->repr_str(), ", while the elements " " before it were ", list_value->type() ->expect() ->getElementType() ->repr_str(), "\n", lc.range().str()); } // Case: We had an annotation that we were able to narrow down to // a single ListType, but the most recently generated element in // the list comprehension doesn't match that annotation if (all_candidates.empty() && refined_type_hint && !(*unified_elem_type) ->isSubtypeOf(*refined_type_hint->expectRef() .getElementType())) { throw( ErrorReport(lc) << "List type annotation `" << refined_type_hint->repr_str() << "` did not match the types of the given list elements," << " which were unified to " << (*unified_elem_type)->repr_str()); } if (!all_candidates.empty()) { // If we had a Union type annotation that could hold more than // one different type of `List` refineAndSetListTypeHintFromCandidatesVector( all_candidates, type_hint, &refined_type_hint, *unified_elem_type, lc); } else if (!refined_type_hint) { refined_type_hint = ListType::create(*unified_elem_type); } list_value->setType(refined_type_hint); out->setType(refined_type_hint->expect()->getElementType()); NamedValue self = NamedValue(loc, "self", list_value); NamedValue input = NamedValue(loc, "", out); emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self); }; emitFor(targets_list, itrs, loc, emit_body); popFrame(); return list_value; } Value* emitDictComprehension(const DictComp& dc, const TypePtr& type_hint) { const auto loc = dc.range(); const auto targets_list = List::create(dc.range(), {dc.target()}); const auto itrs = List::create(dc.range(), {dc.iter()}); Value* dict_value = graph->insertNode(graph->create(prim::DictConstruct, 1))->output(); // Set the default type to be Dict[str, Tensor] 
dict_value->setType(DictType::create(StringType::get(), TensorType::get())); TypePtr refined_type_hint = type_hint; TypePtr annotated_union_type = type_hint && type_hint->isUnionType() ? type_hint : nullptr; std::vector all_candidates = {}; if (refined_type_hint) { auto type_match = [&](const TypePtr& t) { return t->kind() == DictType::Kind; }; auto do_if_match = [&]() { dict_value->setType(refined_type_hint); }; refineAndSetUnionTypeHintOrPopulateCandidatesVector( type_hint, &refined_type_hint, &all_candidates, "Dict", dc, type_match, do_if_match, do_if_match); } TypePtr first_generated_key_type = nullptr; TypePtr first_generated_value_type = nullptr; // A dict comprehension introduces its own scope. No variable assigned // may leak into the rest of the graph Node* n = graph->insertNode(create(prim::ComprehensionScope, dc.range(), 0)); auto* comprehension_block = n->addBlock(); pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); auto emit_body = [&]() { auto k = emitExpr(dc.key()); auto v = emitExpr(dc.value()); // If we didn't have a type annotation, the type of the dict would // be set to `(str, Tensor)`. We don't want to unify this default // type with the actual elements in the dict, so let the type // begin as the first element in the dict if (k->type()->kind() == UnionType::Kind) { throw( ErrorReport(dc) << "Dicts may only contain homogeneous keys, but the type of " << "the first generated key was " << k->type()->repr_str()); } else if ( first_generated_key_type && first_generated_key_type != k->type()) { // Values can be heterogenous, so we only need to check that the // key types are all the same throw( ErrorReport(dc) << "Dicts may only contain homogeneous keys. 
Expected " << "dict comprehension to generate type " << first_generated_key_type->repr_str() << ", but got " << k->type()->repr_str()); } else { dict_value->setType(DictType::create(k->type(), v->type())); first_generated_key_type = k->type(); first_generated_value_type = v->type(); } // If we had any annotation OTHER THAN a Union that can hold more // than one type of Dict if (refined_type_hint && all_candidates.empty()) { DictTypePtr dict_type_hint = refined_type_hint->expect(); std::stringstream ss; std::stringstream err; bool is_key_subtype = k->type()->isSubtypeOfExt(*dict_type_hint->getKeyType(), &ss); if (!is_key_subtype) { err << "Dict type annotation `" << dict_type_hint->repr_str() << "` did not match the " << "type of an actual key type `" << k->type()->repr_str() << "`\n" << ss.str(); } ss.str(std::string()); bool is_value_subtype = v->type()->isSubtypeOfExt(*dict_type_hint->getValueType(), &ss); if (!is_value_subtype) { err << "Dict type annotation `" << dict_type_hint->repr_str() << "` did not match the " << "type of an actual value type `" << v->type()->repr_str() << "`\n" << ss.str(); } if (!is_key_subtype || !is_value_subtype) { throw(ErrorReport(dc) << err.str()); } } const TypePtr value_type_hint = refined_type_hint && refined_type_hint->kind() == DictType::Kind ? refined_type_hint->expect()->getValueType() : nullptr; std::optional unified_value_type = unifyTypes( first_generated_value_type, v->type(), /*default_to_union=*/true, value_type_hint); if (!type_hint && (*unified_value_type)->isUnionType()) { TORCH_WARN( "Dict values consist of heterogeneous types, which means", " that they have been typed as being ", (*unified_value_type)->repr_str(), ". To use any of the " "values in this dict, it will be necessary to add an " "`assert isinstance` statement before first use to trigger " "type refinement. 
The first non-matching element was typed", " as ", v->type()->repr_str(), ", while the elements " " before it were ", first_generated_value_type->repr_str(), "\n", dc.range().str()); } if (type_hint) { if (type_hint->kind() == DictType::Kind) { dict_value->setType(type_hint); k->setType(type_hint->expect()->getKeyType()); v->setType(type_hint->expect()->getValueType()); } else { if (!all_candidates.empty()) { refineAndSetDictTypeHintFromCandidatesVector( all_candidates, type_hint, &refined_type_hint, k->type(), *unified_value_type, dc); } dict_value->setType(refined_type_hint); k->setType(refined_type_hint->expect()->getKeyType()); v->setType(refined_type_hint->expect()->getValueType()); } } else { dict_value->setType(DictType::create(k->type(), *unified_value_type)); } NamedValue self = NamedValue(loc, "self", dict_value); NamedValue input_k = NamedValue(loc, "", k); NamedValue input_v = NamedValue(loc, "", v); emitBuiltinCall( loc, *graph, aten::_set_item, {self, input_k, input_v}, {}); }; emitFor(targets_list, itrs, loc, emit_body); popFrame(); if (annotated_union_type) { Node* n = graph->insertNode(graph->create(prim::unchecked_cast, {dict_value})); n->output()->setType(std::move(annotated_union_type)); dict_value = n->output(); } return dict_value; } // Insert subtyping refinements void insertRefinements(const SourceRange& loc, const RefinementSet& ref) { for (const Refinement& r : ref.activeRefinements()) { Value* v = environment_stack->getVar(r.identifier(), loc); Value* new_v = graph->insertUncheckedCast(v, r.type()); environment_stack->setVar(loc, r.identifier(), new_v); } } CondValue emitShortCircuitLogical( const SourceRange& loc, const Expr& first_expr, const Expr& second_expr, bool is_or) { CondValue lhs = emitCondExpr(first_expr); // if the continue expr in the short circuit is not evaluated, // than the const expression is False if the short circuit // is an `and` and True if the short circuit is an `or`. 
// `False and expr` -> False, `True or expr` -> True // // inserting it as a constant makes optimization easier // if it's an OR the first expr is emitted in the true branch // and the second expr in the false branch, if it's an AND the opposite auto get_const_expr = [&] { return graph->insertConstant(is_or, loc); }; std::optional rhs; auto get_continue_expr = [&] { rhs = emitCondExpr(second_expr); return rhs->value(); }; // if this is an OR, eval second expression if first expr is False // If this is an AND, eval second expression if first expr is True Value* new_result = nullptr; std::optional refinements; std::optional static_if; if (is_or) { new_result = emitIfExpr(loc, lhs, get_const_expr, get_continue_expr); refinements = lhs.refinements().Or(rhs->refinements()); if ((lhs.staticIf() && *lhs.staticIf()) || (rhs->staticIf() && *rhs->staticIf())) { static_if = true; } else if (lhs.staticIf() && rhs->staticIf()) { static_if = *lhs.staticIf() || *rhs->staticIf(); } } else { new_result = emitIfExpr(loc, lhs, get_continue_expr, get_const_expr); refinements = lhs.refinements().And(rhs->refinements()); if (((lhs.staticIf() && !*lhs.staticIf()) || (rhs->staticIf() && !*rhs->staticIf()))) { static_if = false; } else if (lhs.staticIf() && rhs->staticIf()) { static_if = *lhs.staticIf() && *rhs->staticIf(); } } return CondValue(new_result, std::move(*refinements), static_if); } Value* emitIfExpr( const SourceRange& range, const CondValue& cond_value, const std::function& true_expr, const std::function& false_expr) { Node* n = graph->insertNode(create(prim::If, range, 0)); n->addInput(cond_value.value()); auto* true_block = n->addBlock(); auto* false_block = n->addBlock(); auto emit_if_expr = [this, &range]( Block* b, const RefinementSet& refinements, const std::function& expr_value) { pushFrame(b); WithInsertPoint guard(b); insertRefinements(range, refinements); Value* out_val = expr_value(); b->registerOutput(out_val); popFrame(); }; emit_if_expr(true_block, 
        cond_value.refinements(), true_expr);
    emit_if_expr(false_block, cond_value.refinements().Not(), false_expr);

    auto true_type = true_block->outputs().at(0)->type();
    auto false_type = false_block->outputs().at(0)->type();
    auto unified = unifyTypes(true_type, false_type);
    if (!unified) {
      throw(
          ErrorReport(range)
          << "if-expression's true branch has type " << true_type->repr_str()
          << " but false branch has type " << false_type->repr_str());
    }

    // Add op outputs
    auto expr_value = n->addOutput()->setType(*unified); // Resulting value

    return expr_value;
  }

  Value* emitToBool(const SourceRange& loc, Value* v) {
    Value* out = nullptr;
    try {
      auto bool_cast = environment_stack->getSugaredVar("bool", loc);
      out = asSimple(bool_cast->call(loc, method, {v}, {}, 0));
    } catch (...) {
      throw(
          ErrorReport(loc) << "Could not cast value of type "
                           << v->type()->repr_str() << " to bool");
    }
    if (!out) {
      throw(
          ErrorReport(loc) << "Could not cast value of type "
                           << v->type()->repr_str() << " to bool");
    }
    // The bool cast is not responsible for checking the output type, so
    // verify here that the result really is a bool
    if (!out->type()->isSubtypeOf(*BoolType::get())) {
      throw(
          ErrorReport(loc)
          << "expected a bool expression for condition but found "
          << out->type()->repr_str());
    }
    return out;
  }

  void emitIfElseBlocks(
      const SourceRange& loc,
      const CondValue& cond_value,
      const List<Stmt>& trueBranch,
      const List<Stmt>& falseBranch) {
    // This is a static if statement: that is, it contains a subset
    // of operators where we are willing to specialize the if statement
    // to be only the true or false branch when the condition is statically
    // known. This is used to metaprogram modules: for instance, when a
    // submodule is absent, an `is None` check can be used to ensure that
    // accesses to the None value, which would error, are not compiled.
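    // Illustration only (TorchScript; `self.sub` is a hypothetical attribute,
    // assumed here to be None on the instance being scripted):
    //
    //   def forward(self, x):
    //       if self.sub is not None:
    //           x = self.sub(x)  # not compiled: condition is statically False
    //       return x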
if (cond_value.staticIf()) { if (*cond_value.staticIf()) { insertRefinements(loc, cond_value.refinements()); emitStatements(trueBranch); } else { insertRefinements(loc, cond_value.refinements().Not()); emitStatements(falseBranch); } return; } Node* n = graph->insertNode(create(prim::If, loc, 0)); n->addInput(cond_value.value()); auto* true_block = n->addBlock(); auto* false_block = n->addBlock(); // Emit both blocks once to get the union of all mutated values auto save_true = emitSingleIfBranch(true_block, trueBranch, cond_value.refinements()); auto save_false = emitSingleIfBranch( false_block, falseBranch, cond_value.refinements().Not()); bool true_exits = exit_blocks.count(true_block); bool false_exits = exit_blocks.count(false_block); if (true_exits && false_exits) { exit_blocks.insert(n->owningBlock()); } // In python, every variable assigned in an if statement escapes // the scope of the if statement (all variables are scoped to the function). // Script is a subset of python: we consider variables to be in scope // as long as there is a definition of the variable along all paths // through the if statement // ---- // if ...: // a = // else: // ... // ... = a # error, a is not defined along all paths // ---- // if ...: // a = // else: // a = // ... = a # OK, a is defined along all paths // ---- // a = ... // if ...: // a = // ... = a # OK, a is defined along all paths // if ...: // a = // else: // return // ... = a # OK, a is always defined // ordered set, because we want deterministic graph output std::set mutated_variables; // When we access either the true or false environment, // we need to set the insertion point so the prim::Load is inserted // into the right block. 
    // If a variable is only defined in one branch, save an error in case it
    // is used later
    for (auto& v : save_true->definedVariables()) {
      {
        WithInsertPoint insert(false_block);
        if (save_false->findInAnyFrame(v) || false_exits) {
          mutated_variables.insert(v);
        } else {
          if (reportSourceLocation(loc.source()->size())) {
            ErrorReport error(loc);
            environment_stack->setVariableTypeError(v, [=]() -> std::string {
              error << v << " is not defined in the false branch";
              return error.what();
            });
          } else {
            environment_stack->setVariableTypeError(v, [=]() -> std::string {
              std::stringstream ss;
              ss << v << " is not defined in the false branch. "
                 << "The source info is eliminated due to the source file is too large. "
                 << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 "
                 << "as env var";
              return ss.str();
            });
          }
        }
      }
    }
    for (auto& v : save_false->definedVariables()) {
      {
        WithInsertPoint insert(true_block);
        if (save_true->findInAnyFrame(v) || true_exits) {
          mutated_variables.insert(v);
        } else {
          if (reportSourceLocation(loc.source()->size())) {
            ErrorReport error(loc);
            environment_stack->setVariableTypeError(v, [=]() -> std::string {
              error << v << " is not defined in the true branch";
              return error.what();
            });
          } else {
            environment_stack->setVariableTypeError(v, [=]() -> std::string {
              std::stringstream ss;
              // This loop reports variables missing from the *true* branch,
              // so the fallback message must name the true branch as well
              ss << v << " is not defined in the true branch. "
                 << "The source info is eliminated due to the source file is too large. 
" << "To get it back, please set PYTORCH_JIT_ENABLE_LARGE_SOURCE_LOCATION=1 " << "as env var"; return ss.str(); }); } } } } // Register outputs in each block for (const auto& x : mutated_variables) { Value* tv = nullptr; Value* fv = nullptr; { WithInsertPoint insert(true_block); if (!true_exits) { tv = save_true->getVar(x, loc); } } { WithInsertPoint insert(false_block); if (!false_exits) { fv = save_false->getVar(x, loc); } } // if both branches exit don't emit any variables // if one branch exits then we allow the all variables in the other branch // to escape scope since they are well-defined if (true_exits && false_exits) { continue; } else if (true_exits) { tv = graph->createUninitialized(fv->type()) ->insertBefore(true_block->return_node()) ->output(); graph->createStore(x, tv)->insertBefore(true_block->return_node()); } else if (false_exits) { fv = graph->createUninitialized(tv->type()) ->insertBefore(false_block->return_node()) ->output(); graph->createStore(x, fv)->insertBefore(false_block->return_node()); } SugaredValuePtr maybe_sugared_x = environment_stack->findInAnyFrame(x); TypePtr full_type = nullptr; if (maybe_sugared_x) { Value* maybe_simple = asSimple(maybe_sugared_x); if (maybe_simple) { full_type = maybe_simple->type(); } } // Try to unify the types. If we found a type annotation earlier // in the environment, and if that type annotation is some form // of union, then we need to tell `unifyTypes` not to throw an // error if the branched return types we found are heterogenous bool default_to_union = full_type && (full_type->kind() == UnionType::Kind || full_type->kind() == OptionalType::Kind || full_type->kind() == NumberType::Kind); auto unified = unifyTypes( tv->type(), fv->type(), /*default_to_union=*/default_to_union); // We allow variables to be set to different types in each branch // as long as that variable is not already in scope or if that // variable does not get used later. 
Here, we save the error so // that the error message will be more informative in the case // that is used later. When `a` is accessed in `(a + 1)`, the // error will get printed: // if cond: // a = 1 // else: // a = tensor // b = a + 1 // if (!unified) { ErrorReport error(loc); error << "Type mismatch: " << x << " is set to type " << tv->type()->repr_str() << " in the true branch" << " and type " << fv->type()->repr_str() << " in the false branch"; if (save_true->findInParentFrame(x) || save_false->findInParentFrame(x)) { throw ErrorReport(error); } else { environment_stack->setVariableTypeError( x, [=]() -> std::string { return error.what(); }); continue; } } environment_stack->setType(x, *unified); } } CondValue emitHasAttr(const Expr& objExpr, const Expr& attrExpr) { auto obj = emitSugaredExpr(objExpr, 1); if (attrExpr.kind() != TK_STRINGLITERAL) { throw( ErrorReport(attrExpr) << "hasattr's second argument must be a string literal"); } const std::string& name = StringLiteral(attrExpr).text(); const bool hasAttr = obj->hasAttr(objExpr.range(), method, name); return CondValue(*graph, objExpr.range(), hasAttr, {}); } CondValue emitIsInstance(const Expr& obj, const Expr& classinfo) { Value* lhs_val = emitExpr(obj); std::vector lhs_types; std::vector rhs_types; std::function gather_rhs = [&](const Expr& expr) { if (expr.kind() == TK_TUPLE_LITERAL) { for (Expr e : TupleLiteral(expr).inputs()) { gather_rhs(e); } return; } TypePtr type = typeParser_.parseTypeFromExpr(expr); rhs_types.emplace_back(type); }; lhs_types.push_back(lhs_val->type()); gather_rhs(classinfo); standardizeVectorForUnion(&lhs_types); standardizeVectorForUnion(&rhs_types); RefinementSet refinement; TypePtr unified_true = nullptr; TypePtr unified_false = nullptr; std::vector isinstance_types; std::vector not_isinstance_types; std::vector true_refinements; std::vector false_refinements; bool all_lhs_subtype_some_rhs = true; // We can discard any rhs types that we know statically would be // impossible. 
For example, if we had: // // def fn(x: Optional[str]): // if isinstance(x, (List[str], str, int)): // ... // // then `x` would be `str` in the true branch and `None` in the // false branch, not `(List[str], str, int)` in the true branch // and `None` in the false branch for (const TypePtr& lhs_type : lhs_types) { if (lhs_type == AnyType::get()) { isinstance_types.insert( isinstance_types.end(), rhs_types.begin(), rhs_types.end()); not_isinstance_types.emplace_back(AnyType::get()); // Edge case: we can still say that all lhs types subtype some // rhs type if `lhs` is `Any` and `rhs` is `Any` if (isinstance_types.size() != 1 || isinstance_types[0] != AnyType::get()) { all_lhs_subtype_some_rhs = false; } break; } auto get_smaller_type = [&](const TypePtr& t1, const TypePtr& t2) -> TypePtr { if (t1->isSubtypeOf(*t2)) { return t1; } else if (t2->isSubtypeOf(*t1)) { return t2; } else { return nullptr; } }; TypePtr found_refinement = nullptr; for (const TypePtr& rhs_type : rhs_types) { TypePtr maybe_smaller_type = get_smaller_type(lhs_type, rhs_type); if (!maybe_smaller_type) { continue; } else if (*maybe_smaller_type == *lhs_type) { // Cover the case that we have something like // lhs = `List[str]` and rhs = `list` found_refinement = lhs_type; } else if (*maybe_smaller_type == *rhs_type) { // We want the narrowest possible type found_refinement = found_refinement ? *(unifyTypes(found_refinement, rhs_type)) : rhs_type; } } if (found_refinement) { if (*found_refinement == *lhs_type) { all_lhs_subtype_some_rhs &= true; } isinstance_types.push_back(found_refinement); } else { // If the lhs couldn't be a subtype of the rhs (or couldn't // be "refined" to itself, as in the `List[str]` and `list` // case above), then we add `lhs_type` to the false branch // refinements. 
This is because the type can still be itself // if the `isinstance` check is false not_isinstance_types.push_back(lhs_type); all_lhs_subtype_some_rhs = false; } } // For use with `unifyTypeList` std::stringstream nowhere; // Get a single type for the true and false branches if (!isinstance_types.empty()) { unified_true = *unifyTypeList(isinstance_types, nowhere, /*default_to_union=*/true); } if (obj.kind() == TK_VAR && unified_true) { std::string ident = Var(obj).name().name(); true_refinements = {Refinement(ident, unified_true)}; } // Get a single type for the true and false branches if (!not_isinstance_types.empty()) { unified_false = *unifyTypeList( not_isinstance_types, nowhere, /*default_to_union=*/true); } if (obj.kind() == TK_VAR && unified_false) { std::string ident = Var(obj).name().name(); false_refinements = {Refinement(ident, unified_false)}; } refinement = RefinementSet(true_refinements, false_refinements); bool is_statically_false = isinstance_types.empty(); // If the statement is statically true if (all_lhs_subtype_some_rhs) { return CondValue(*graph, obj.range(), true, std::move(refinement)); } if (is_statically_false) { return CondValue(*graph, obj.range(), false, std::move(refinement)); } // check maybe true/false at runtime, need an actual op Value* result = graph->insertNode(graph->createIsInstance(lhs_val, rhs_types)) ->output(); return CondValue(result, std::move(refinement), std::nullopt); } void emitIf(const If& stmt) { Expr cond = stmt.cond(); CondValue cond_value = emitCondExpr(cond); emitIfElseBlocks( stmt.range(), cond_value, stmt.trueBranch(), stmt.falseBranch()); } // *********************** Loop Operators ************************************ // Emits a loop operator with the form: // Loop(max_trip_count) // block0(loop_counter) { // // } // block1 { // // -> (condition) // } // For loops will have an empty loop condition block with condition set to // true. In the convert to ssa pass, the loop condition will be correctly // inlined,
and inputs and outputs added so that the loop conforms to the // semantics specified at // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Loop void emitLoopCommon( const SourceRange& range, const std::function& emit_body, const SugaredValuePtr& iter_val, std::optional> targets, std::optional cond) { Value* max_trip_count_val = nullptr; if (iter_val != nullptr) { max_trip_count_val = iter_val->len(range, method); } else { max_trip_count_val = materializeConstant( std::numeric_limits::max(), *graph, range, integral_constants); } Node* n = graph->insertNode(create(prim::Loop, range, 0)); auto* body_block = n->addBlock(); { Block* condition_block = n->addBlock(); pushFrame(condition_block); Value* out = nullptr; if (cond) { WithInsertPoint insert(condition_block); out = emitToBool(cond.value().range(), emitExpr(cond.value())); } else { WithInsertPoint insert(n); out = graph->insertConstant(true, range); } condition_block->registerOutput(out); popFrame(); } n->addInput(max_trip_count_val); WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_LOOP); Value* trip_count = body_block->addInput()->setType(IntType::get()); // Iteration num { pushFrame(body_block); WithInsertPoint guard(body_block); // if the FOR iters and targets are present, emit FOR target assignments if (iter_val != nullptr && targets) { Value* cur_elem = iter_val->getitem(range, method, trip_count) ->asValue(range, method); SugaredValuePtr sv = std::make_shared(cur_elem); List target_exprs = targets.value(); validateAssignLhsExpr(target_exprs, range); // if target exprs are more than 1, it means iteration unpacking on LHS // we create Tuple literal to wrap those target exprs for assignments if (target_exprs.size() > 1) { Expr tl = TupleLiteral::create(range, target_exprs); target_exprs = List::create(range, {tl}); } emitExprsAssign(target_exprs, {sv}, range, /*n_binders=*/1); } emit_body(); popFrame(); } } void emitUnrolledLoop( const SourceRange& loc, const std::function& emit_body, const 
SugaredValuePtr& iterable, const List& targets) { auto static_len = iterable->staticLen(); TORCH_INTERNAL_ASSERT( static_len, "Unrolled loop iter should have static length"); int64_t len = *static_len; WithLoopStatus loop_guard(&loop_status_, LoopStatus::IN_UNROLLED_LOOP); // In order to support ModuleLists which return different types, // as with an nn.Sequential which has a module that returns a Dict and then // a module which returns a Tensor, // we do not push a new environment frame because if we did all intermediary // values would have to subtype the input type. for (const auto i : c10::irange(len)) { auto index = materializeConstant(i, *method.graph(), loc, integral_constants); auto sugared_value = iterable->getitem(loc, method, index); emitExprsAssign( targets, {sugared_value}, targets.range(), /*n_binders=*/1); emit_body(); } } void emitFor( const List& targets, const List& itrs, const SourceRange& loc, const std::function& emit_body) { if (itrs.size() != 1) { throw(ErrorReport(loc) << "List of iterables is not supported currently"); } // Emit loop information for builtinFunction values like range(), zip(), // enumerate() or SimpleValue like List, Tensor, Dict, etc. SugaredValuePtr sv = emitSugaredExpr(itrs[0], 1); SugaredValuePtr iterable = sv->iter(loc, method); // We unroll the loop for iterables that contain ModuleLists so that we can // compile heterogeneous module lists.
if (!iterable->shouldEmitUnrolled()) { emitLoopCommon(loc, emit_body, iterable, targets, {}); } else { emitUnrolledLoop(loc, emit_body, iterable, targets); } } void emitFor(const For& stmt) { auto emit_body = [&]() { emitStatements(stmt.body()); }; emitFor(stmt.targets(), stmt.itrs(), stmt.range(), emit_body); } void emitWhile(const While& stmt) { auto cond = stmt.cond(); auto emit_body = [&]() { emitStatements(stmt.body()); }; emitLoopCommon(stmt.range(), emit_body, nullptr, {}, cond); } void emitWith(const With& stmt) { auto targets = stmt.targets(); // Keep a stack of entered objects so they can be exited // in the right order. std::stack entered; for (const auto& target : targets) { Expr e = target.target(); auto* rhs = emitExpr(e); auto* n = graph->insertNode(graph->create(prim::Enter, {rhs})); entered.push(rhs); if (rhs->type()->kind() != TypeKind::ClassType) { throw( ErrorReport(e.range()) << "With item expression must return an object"); } auto rhsClass = rhs->type()->expect(); auto* enterMethod = rhsClass->findMethod("__enter__"); auto* exitMethod = rhsClass->findMethod("__exit__"); if (!enterMethod || !exitMethod) { throw( ErrorReport(e.range()) << "Object returned by with item expression does not define __enter__ and __exit__ methods"); } // Check the schema of __enter__. auto& enterSchema = enterMethod->getSchema(); if (enterSchema.arguments().size() != 1) { throw( ErrorReport(e.range()) << "__enter__ must have only one argument and one return value"); } // Check the schema of __exit__. 
auto& exitSchema = exitMethod->getSchema(); if (exitSchema.arguments().size() != 4) { throw(ErrorReport(e.range()) << "__exit__ must have four arguments"); } else { for (unsigned i = 1; i < 4; ++i) { if (exitSchema.arguments().at(i).type() != AnyType::get()) { throw( ErrorReport(e.range()) << "argument " << i << " of __exit__ must have Any type; TorchScript does not currently support passing exception type, value, or traceback to the __exit__ function."); } } } // Set the output of the enter node to be the return type of __enter__. n->output(0)->setType(enterSchema.returns().at(0).type()); // Set i = e.__enter__() so that references to i in the body of the with // will resolve correctly. if (target.var().present()) { Var i = target.var().get(); environment_stack->setVar(i.range(), i.name().name(), n->output(0)); } } emitStatements(stmt.body()); // Insert all the corresponding prim::Exit nodes. while (!entered.empty()) { auto* input = entered.top(); entered.pop(); auto* n = graph->create(prim::Exit); graph->insertNode(n); n->addInput(input); } } // Currently we do not support assigning exceptions to variables, // a = Exception("hi") // raise a // // We ignore the expression following raise void emitRaise(const Raise& raise) { auto sv = emitSugaredExpr(raise.expr(), 1); Value* error_message = nullptr; Value* qualified_class_name = nullptr; if (auto exception_instance = std::dynamic_pointer_cast(sv)) { // The typical case, an instance of the exception class was thrown: // raise RuntimeError("error") error_message = exception_instance->getValue(); qualified_class_name = exception_instance->getQualifiedClassName(); } else if ( auto exception_class = std::dynamic_pointer_cast(sv)) { // A bare exception was thrown so add an empty message. e.g. // raise RuntimeError error_message = insertConstant(*graph, "", raise.range()); } else { // The raise was not followed by an exception (i.e. 
it was something like // `raise "error"` instead of `raise RuntimeError("error")`) throw( ErrorReport(raise.range()) << "exceptions must derive from BaseException"); } if (!error_message->type()->isSubtypeOf(*StringType::get())) { error_message = graph->insert(aten::str, {error_message}); } graph->insert( prim::RaiseException, {error_message, qualified_class_name}, {}, raise.range()); exit_blocks.insert(environment_stack->block()); } // emit assertions as an if branch so that assertions will reuse the // message void emitAssert(const Assert& stmt) { CondValue cond_value = emitCondExpr(stmt.test()); List true_branch = List::create(stmt.range(), {}); // Create an `AssertionError("the_message")` call auto message = (stmt.msg().present()) ? stmt.msg().get() : StringLiteral::create(stmt.range(), ""); auto callee = Var::create( stmt.range(), Ident::create(stmt.range(), "AssertionError")); auto apply = Apply::create( stmt.range(), callee, List::create(stmt.range(), {message}), List::create(stmt.range(), {})); List false_branch = List::create(stmt.range(), {Raise::create(stmt.range(), apply)}); emitIfElseBlocks(stmt.range(), cond_value, true_branch, false_branch); } // Validate that the `lhs` Expr's in an assignment statement are valid. That // is: // // 1) All lhs Expr's are either Var, Tuple or Starred nodes // 2) There is at most one Starred node in the lhs Expr // 3) A Starred node can only appear when there is another non-Starred lhs // Expr. Concretely this means that `*abc = func()` is illegal. Unpacking // all outputs into a tuple is covered by `abc = func()`.
bool validateAssignLhsExpr(const List& lhs, const SourceRange& r) { size_t num_normal_assign = 0; size_t num_starred = 0; for (const auto& assignee : lhs) { if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT || assignee.kind() == TK_TUPLE_LITERAL || assignee.kind() == '.') { num_normal_assign++; } else if (assignee.kind() == TK_STARRED) { num_starred++; } else { throw( ErrorReport(assignee) << "lhs of assignment must be a variable, " << "subscript, or starred expression"); } } if (num_starred > 1) { throw( ErrorReport(r) << "Only one starred expression is allowed on the lhs"); } if (num_starred > 0 && num_normal_assign == 0) { throw( ErrorReport(r) << "A Starred expression may only appear on the " << "lhs within the presence of another non-starred" << " expression"); } return num_starred; } // Get the appropriate builtin op for this augmented assignment // If the RHS is a tensor, return the corresponding ATen in-place op // If it's a list of scalars, then return the corresponding list augment op Symbol getAugOp(const AugAssign& stmt, const TypePtr& type) { bool use_inplace_op = type->isSubtypeOf(*TensorType::get()) || type->kind() == TypeKind::ListType; switch (stmt.aug_op()) { case '+': return use_inplace_op ? aten::add_ : aten::add; case '-': return use_inplace_op ? aten::sub_ : aten::sub; case '/': return use_inplace_op ? aten::div_ : aten::div; case '*': return use_inplace_op ? aten::mul_ : aten::mul; case '%': return use_inplace_op ? aten::fmod_ : aten::fmod; case '|': return use_inplace_op ? aten::bitwise_or : aten::__or__; case '&': return use_inplace_op ? aten::bitwise_and : aten::__and__; case '^': return use_inplace_op ? aten::bitwise_xor : aten::__xor__; case TK_LSHIFT: return use_inplace_op ? aten::__ilshift__ : aten::__lshift__; case TK_RSHIFT: return use_inplace_op ? 
aten::__irshift__ : aten::__rshift__; case TK_POW: return aten::pow; default: throw( ErrorReport(stmt) << "Unknown augmented assignment: " << kindToString(stmt.aug_op())); } } // Get a pair of // since the out of place method is called if the in place method is not // present std::pair getAugMagicMethod(const AugAssign& stmt) { switch (stmt.aug_op()) { case '+': return std::make_pair(std::string("__iadd__"), std::string("__add__")); case '-': return std::make_pair(std::string("__isub__"), std::string("__sub__")); case '/': return std::make_pair( std::string("__itruediv__"), std::string("__truediv__")); case '*': return std::make_pair(std::string("__imul__"), std::string("__mul__")); case '%': return std::make_pair(std::string("__imod__"), std::string("__mod__")); default: throw( ErrorReport(stmt) << "Unknown augmented assignment: " << kindToString(stmt.aug_op())); } } // Emit nodes for augmented assignments like `+=` void emitAugAssignment(const AugAssign& stmt) { switch (stmt.lhs().kind()) { case TK_VAR: { emitAugAssignmentToVar(stmt); } break; case '.': { emitAugAssignmentToSelectVar(stmt); } break; case TK_SUBSCRIPT: { emitAugAssignmentToSubscript(stmt); } break; default: throw( ErrorReport(stmt.lhs()) << "unexpected expression on " << "left-hand side of augmented assignment"); } } // This will be called when there is a class param or module buffer // mutation which make the LHS of the expr be a select expression // // Example like: // class A(Module): // def __init__(): // self.register_buffer("running_var", torch.zeros(1)) // // def forward(): // self.num_batches += 1 void emitAugAssignmentToSelectVar(const AugAssign& stmt) { const auto lhs = Select(stmt.lhs()); auto lhsSugaredVar = emitSugaredExpr(lhs.value(), 1); const auto lhsValue = lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name()) ->asValue(lhs.range(), method); auto result = emitAugAssignmentHelper(stmt, lhsValue); lhsSugaredVar->setAttr(stmt.range(), method, lhs.selector().name(), result); 
} void emitAugAssignmentToVar(const AugAssign& stmt) { const auto lhs = Var(stmt.lhs()); auto lhsValue = emitExpr(lhs); auto result = emitAugAssignmentHelper(stmt, lhsValue); environment_stack->setVar(lhs.range(), lhs.name().name(), result); } Value* emitAugAssignmentHelper(const AugAssign& stmt, Value* lhs) { if (lhs->type()->kind() == TypeKind::ClassType) { // Call `__iadd__` so updates happen in place on class types // https://docs.python.org/3/reference/datamodel.html#object.__iadd__ std::string in_place_method_name; std::string out_of_place_method_name; std::tie(in_place_method_name, out_of_place_method_name) = getAugMagicMethod(stmt); const auto rhs = emitExpr(stmt.rhs()); // Determine whether to use __iadd__ or __add__ (use __add__ only if // __iadd__ is not present) auto type = lhs->type()->expect(); std::string magic_method_name; if (type->findMethod(in_place_method_name)) { magic_method_name = in_place_method_name; } else if (type->findMethod(out_of_place_method_name)) { magic_method_name = out_of_place_method_name; } else { throw( ErrorReport(stmt.range()) << "Cannot emit inplace op on " << type->repr_str() << " since it does not define an " << in_place_method_name << " or " << out_of_place_method_name << " method"); } // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if // __iadd__ is not present return MethodValue(lhs, magic_method_name) .call(stmt.range(), method, {rhs}, {}, 0) ->asValue(stmt.range(), method); } else { const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs())) .value(*method.graph()); return emitBuiltinCall( stmt.range(), *method.graph(), getAugOp(stmt, lhs->type()), /*args=*/{lhs, rhs}, /*kwargs=*/{}, /*self=*/std::nullopt); } } void emitAugAssignmentGeneric( const AugAssign& stmt, const Subscript& lhs, Value* sliceable) { // Get the idx to augment const auto subscriptExprs = lhs.subscript_exprs(); const TypePtr type = sliceable->type(); if (subscriptExprs.size() != 1) { throw( ErrorReport(subscriptExprs) << 
"Sliced expression not yet supported for " << type->repr_str() << " augmented assignment. " << "File a bug if you want this"); } TypePtr elemType = nullptr; if (const ListTypePtr listType = type->cast()) { elemType = listType->getElementType(); } else if (const DictTypePtr dictType = type->cast()) { elemType = dictType->getKeyType(); } if (elemType == nullptr) { throw( ErrorReport(lhs) << type->repr_str() << " does not support augmented assignment."); } const auto idxValue = emitExpr(subscriptExprs[0]); const auto containerArg = NamedValue(lhs.value().range(), type->str(), sliceable); const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue); const auto valueArg = NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs())); const auto getItem = graph->insert( aten::__getitem__, {containerArg, idxArg}, {}, stmt.range()); const auto augmentedItem = graph->insert( getAugOp(stmt, elemType), {getItem, valueArg}, {}, stmt.range()); graph->insert( aten::_set_item, {containerArg, idxArg, augmentedItem}, {}, stmt.range()); } void emitAugAssignmentToSubscript(const AugAssign& stmt) { // Process the base list value const auto lhs = Subscript(stmt.lhs()); const auto sliceable = emitExpr(lhs.value()); if (sliceable->type()->isSubtypeOf(*TensorType::get())) { // If it's a tensor, just fully evaluate the subscript operation and emit // an in-place assignment auto [sliced, tensorIndices] = emitIntAndSliceIndexing( lhs.range(), sliceable, lhs.subscript_exprs()); const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced); const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs())); if (tensorIndices.empty()) { // Common case: we only tried to index with int and slices. Emit the // correct augmented assignment op to the sliced value emitBuiltinCall( stmt.range(), *method.graph(), getAugOp(stmt, sliceable->type()), {rhs}, {}, slicedArg); } else { // Special case: we tried to do "advanced indexing". 
Lower this expr // into `index` and `index_put_` ops with tensorindices of Tensor?[] const auto indices = graph ->insertNode(graph->createList( OptionalType::ofTensor(), tensorIndices)) ->output(); const auto indexed = graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range()); const auto augmented = emitBuiltinCall( stmt.range(), *method.graph(), getAugOp(stmt, sliceable->type()), {rhs}, {}, indexed); graph->insert( aten::index_put_, {slicedArg, indices, augmented}, {}, stmt.range()); } } else { emitAugAssignmentGeneric(stmt, lhs, sliceable); } } NamedValue emitValueToTensor( const NamedValue& value, const NamedValue& matchTypeOf) { // Add implicit conversion of int/float/complex/bool/number types to tensors // Used in emitSubscriptAssign to convert: // `tensor(...)[x] = 99` to `tensor(...)[x] = tensor(99)` // Mirrors the `valueToTensor` behavior in python_variable_indexing.cpp const auto kind = value.type()->kind(); if (kind == c10::TypeKind::NumberType || kind == c10::TypeKind::IntType || kind == c10::TypeKind::BoolType || kind == c10::TypeKind::FloatType || kind == c10::TypeKind::ComplexType) { auto dtype = graph->insert(prim::dtype, {matchTypeOf}, {}); auto device = graph->insert(prim::device, {matchTypeOf}, {}); auto converted = graph->insert( aten::tensor, {value}, {NamedValue("dtype", dtype), NamedValue("device", device)}); return NamedValue(value.loc(), converted); } return value; } // Emit mutating assignments like `foo[0] = bar` void emitSubscriptAssign( const SourceRange& stmtRange, const Subscript& lhs, const Expr& rhs) { emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs))); } void emitSubscriptAssign( const SourceRange& stmtRange, const Subscript& lhs, const NamedValue& rhs) { // First check the base value.
auto sliceable = emitExpr(lhs.value()); // If it's a tensor, copy the RHS data into it if (sliceable->type()->isSubtypeOf(*TensorType::get())) { // Handle multi-dimensional slicing: first emit int/slice indexing // TODO: the Python equivalent code has special-cased copy_to // broadcasting to match NumPy semantics (see PR#4853). We can't // replicate that without knowing the size of the Tensor; so really that // code should be moved into the aten function auto [sliced, tensorIndices] = emitIntAndSliceIndexing( lhs.range(), sliceable, lhs.subscript_exprs()); const auto slicedArg = NamedValue(lhs.range(), sliced); // rhs must be a tensor, implicitly convert int/float/complex/bool const auto convertedRhs = emitValueToTensor(rhs, slicedArg); if (tensorIndices.empty()) { // Common case: we only tried to index with int and slices. Copy the // RHS into the resulting tensor. graph->insert(aten::copy_, {slicedArg, convertedRhs}, {}, stmtRange); } else { // Special case: we tried to do "advanced indexing" with a tensor. // Dispatch to `aten::index_put_` with tensorindices of Tensor?[] const auto indices = graph ->insertNode(graph->createList( OptionalType::ofTensor(), tensorIndices)) ->output(); graph->insert( aten::index_put_, {slicedArg, indices, convertedRhs}, {}, stmtRange); } // Otherwise, this is a list or a classtype. // Dispatch to aten::_set_item to both select and assign } else { const auto subscript = lhs.subscript_exprs(); if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) { throw( ErrorReport(subscript) << "Sliced expression not yet supported for" << " subscripted assignment. 
" << "File a bug if you want this"); } if (sliceable->type()->isSubtypeOf(*AnyTupleType::get())) { throw( ErrorReport(lhs) << sliceable->type()->repr_str() << " does not support subscripted assignment"); } std::vector args; args.emplace_back(lhs.value().range(), "self", sliceable); args.emplace_back( lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0])); args.push_back(rhs); makeMagic( "__setitem__", std::make_shared(aten::_set_item, std::nullopt)) ->call(stmtRange, method, args, {}, 0); } } void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) { size_t n_binders = tl.inputs().size(); bool starred_unpack = validateAssignLhsExpr(tl.inputs(), tl.range()); if (starred_unpack) n_binders--; auto output = emitSugaredExpr(rhs, n_binders); emitTupleAssign(tl, output, rhs.range(), n_binders, starred_unpack); } void emitTupleAssign( const TupleLiteral& tl, const SugaredValuePtr& rhs_output, const SourceRange& rhs_loc, size_t n_binders, bool starred_unpack) { auto outputs = rhs_output->asTuple( rhs_loc, method, starred_unpack ? std::nullopt : std::optional{n_binders}); if (outputs.size() < n_binders) { throw( ErrorReport(tl) << "need " << (starred_unpack ? 
"at least " : "") << n_binders << " values to unpack but found only " << outputs.size()); } if (outputs.size() > n_binders && !starred_unpack) { throw( ErrorReport(tl) << "too many values to unpack: need " << n_binders << " but found " << outputs.size()); } emitExprsAssign(tl.inputs(), outputs, rhs_loc, n_binders); } void emitExprsAssign( const List& lhs_exprs, const at::ArrayRef outputs, const SourceRange& rhs_loc, size_t n_binders) { size_t i = 0; for (auto assignee : lhs_exprs) { switch (assignee.kind()) { case TK_SUBSCRIPT: emitSubscriptAssign( rhs_loc, Subscript(assignee), NamedValue(rhs_loc, outputs.at(i)->asValue(rhs_loc, method))); i++; break; case TK_VAR: environment_stack->setSugaredVar( assignee.range(), Var(assignee).name().name(), outputs.at(i), /*annotated_type=*/nullptr); i++; break; case TK_STARRED: { auto var = Starred(assignee).expr(); if (var.kind() != TK_VAR) { throw( ErrorReport(var) << "Cannot pack a tuple into a non-variable"); } size_t n_matched = outputs.size() - n_binders; ArrayRef> outputs_ref = outputs; auto values = fmap( outputs_ref.slice(i, n_matched), [&](const std::shared_ptr& v) { return v->asValue(assignee.range(), method); }); auto tup = graph->insertNode(graph->createTuple(values))->output(); environment_stack->setVar(var.range(), Var(var).name().name(), tup); i += n_matched; } break; case TK_TUPLE_LITERAL: { // recursively emit tuple assignments on tuple literal input TupleLiteral sub_tl = TupleLiteral(assignee); size_t sub_n_binders = sub_tl.inputs().size(); bool sub_starred_unpack = validateAssignLhsExpr(sub_tl.inputs(), sub_tl.range()); if (sub_starred_unpack) sub_n_binders--; emitTupleAssign( sub_tl, outputs.at(i), rhs_loc, sub_n_binders, sub_starred_unpack); i++; } break; case '.': { emitSelectAssign(assignee, outputs.at(i), rhs_loc); i++; } break; default: throw( ErrorReport(assignee) << "unexpected expression on the left-hand side"); } } } void emitAssignment(const Assign& stmt) { if (stmt.lhs_list().size() == 1) { 
return emitSingleAssignment(stmt); } // multiple assign & annotated type not supported in python TORCH_INTERNAL_ASSERT(stmt.lhs_list().size() > 1 && !stmt.type().present()); // a = b = expr() // the semantics of multiple assignment is that expr() is emitted once, then // from left to right the assignments are made const auto tmp_name = createTempName("$tmp_assign_"); environment_stack->setSugaredVar( stmt.rhs().range(), tmp_name, emitSugaredExpr(stmt.rhs().get(), 1), /*annotated_type=*/nullptr); auto ident = Var::create( stmt.rhs().range(), Ident::create(stmt.rhs().range(), tmp_name)); for (auto expr : stmt.lhs_list()) { emitSingleAssignment(Assign::create( stmt.range(), List::create(expr.range(), {expr}), Maybe::create(stmt.rhs().range(), ident), Maybe::create(stmt.range()))); } } void emitSingleAssignment(const Assign& stmt) { if (!stmt.rhs().present()) { throw( ErrorReport(stmt.range()) << "For an assignment, expected an expression on the right-hand side"); } const Expr& rhs = stmt.rhs().get(); switch (stmt.lhs().kind()) { case TK_VAR: { auto v = Var(stmt.lhs()); TypePtr type = nullptr; if (stmt.type().present()) { type = typeParser_.parseTypeFromExpr(stmt.type().get()); } auto rhs_sugared_val = emitSugaredExpr(rhs, 1, type); // START BC HACK // // For old serialized quantized RNN modules, switch // quantized::linear_prepack to quantized::linear_prepack_legacy. We // changed linear_prepack to return a TorchBind class and not a // cpp_custom_type_hack tensor anymore, but the old serialized models // are tightly coupled with the type_hack version. If we still create a // Tensor here, then the quantized_lstm.legacy overload can kick in in // forward_impl(), and the module will still run correctly. 
if (method.qualname() == "__torch__.torch.nn.quantized.dynamic.modules.rnn.PackedParameter.__setstate__") { if (auto sv = std::dynamic_pointer_cast(rhs_sugared_val)) { Node* rhs_node = sv->getValue()->node(); if (rhs_node->kind() == Symbol::fromQualString("quantized::linear_prepack")) { std::vector inputs; for (Value* i : rhs_node->inputs()) { inputs.emplace_back(i); } Value* new_val = rhs_node->owningGraph()->insert( Symbol::fromQualString("quantized::linear_prepack_legacy"), inputs, {}, rhs_node->sourceRange()); rhs_sugared_val = std::make_shared(new_val); } } } // END BC HACK environment_stack->setSugaredVar( v.range(), v.name().name(), std::move(rhs_sugared_val), /*annotated_type=*/type); } break; case TK_TUPLE_LITERAL: emitTupleAssign(TupleLiteral(stmt.lhs()), rhs); break; case '.': emitSelectAssign(stmt); break; case TK_SUBSCRIPT: emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), rhs); break; default: throw( ErrorReport(stmt.lhs()) << "unexpected expression on left-hand side of assignment"); } } void emitSelectAssign(const Assign& stmt) { if (!stmt.rhs().present()) { throw(ErrorReport(stmt.range()) << "Expected RHS for assignment"); } TypePtr type_hint = nullptr; if (stmt.type().present()) { type_hint = typeParser_.parseTypeFromExpr(stmt.type().get()); } const auto lhs = Select(stmt.lhs()); auto lhsObject = emitSugaredExpr(lhs.value(), 1); const auto rhsValue = emitSugaredExpr(stmt.rhs().get(), 1, type_hint) ->asValue(stmt.rhs().range(), method); lhsObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue); } void emitSelectAssign( const Expr& lhs, const SugaredValuePtr& rhs, const SourceRange& loc) { const auto lhs_select = Select(lhs); auto lhs_sv = emitSugaredExpr(lhs_select.value(), 1); const auto rhs_value = rhs->asValue(loc, method); lhs_sv->setAttr(loc, method, lhs_select.selector().name(), rhs_value); } NodeKind getNodeKind(int kind, size_t ninputs) { switch (kind) { case '+': return aten::add; case '-': return aten::sub; case 
TK_UNARY_MINUS: return aten::neg; case '*': return aten::mul; case TK_POW: return aten::pow; case '@': return aten::matmul; case TK_STARRED: return prim::Starred; case '/': return aten::div; case '%': return aten::remainder; case TK_NE: return aten::ne; case TK_EQ: return aten::eq; case '<': return aten::lt; case '>': return aten::gt; case TK_LE: return aten::le; case TK_GE: return aten::ge; case TK_AND: return aten::__and__; case TK_OR: return aten::__or__; case TK_IS: return aten::__is__; case TK_ISNOT: return aten::__isnot__; case TK_NOT: return aten::__not__; case TK_FLOOR_DIV: return aten::floordiv; case TK_LSHIFT: return aten::__lshift__; case TK_RSHIFT: return aten::__rshift__; case '&': return aten::__and__; case '|': return aten::__or__; case '^': return aten::__xor__; case TK_IN: return aten::__contains__; default: throw std::runtime_error("unknown kind " + std::to_string(kind)); } } std::string getOperatorOverload(int kind, size_t ninputs) { switch (kind) { case '+': return "__add__"; case '-': return "__sub__"; case TK_UNARY_MINUS: return "__neg__"; case '~': return "__invert__"; case '*': return "__mul__"; case TK_POW: return "__pow__"; case '/': return "__truediv__"; case '%': return "__mod__"; case TK_NE: return "__ne__"; case TK_EQ: return "__eq__"; case '<': return "__lt__"; case '>': return "__gt__"; case TK_LE: return "__le__"; case TK_GE: return "__ge__"; case '&': return "__and__"; case '|': return "__or__"; case '^': return "__xor__"; case TK_IN: return "__contains__"; case TK_LSHIFT: return "__lshift__"; case TK_RSHIFT: return "__rshift__"; default: throw std::runtime_error("unknown kind " + std::to_string(kind)); } } std::vector getNamedValues( const TreeList& trees, bool maybe_unpack) { std::vector values; for (const auto& tree : trees) { if (maybe_unpack && tree->kind() == TK_STARRED) { auto starred = Starred(tree); auto entries = emitSugaredExpr(starred.expr(), 1) ->asTuple(starred.range(), method); for (const auto& entry : entries) { 
values.emplace_back( tree->range(), entry->asValue(starred.range(), method)); } } else { values.emplace_back(tree->range(), emitExpr(Expr(tree))); } } return values; } std::vector getNamedValues( const List& trees, bool maybe_unpack) { return getNamedValues(trees.tree()->trees(), maybe_unpack); } std::vector getValues(const TreeList& trees, bool maybe_unpack) { return toValues(*graph, getNamedValues(trees, maybe_unpack)); } std::vector getValues(const List& trees, bool maybe_unpack) { return getValues(trees.tree()->trees(), maybe_unpack); } std::vector emitAttributes(const List& attributes) { return fmap(attributes, [&](const Attribute& attr) { return NamedValue( attr.range(), attr.name().name(), emitExpr(attr.value())); }); } void checkApplyNumInputs(const Apply& apply, size_t expected_inputs) { const SourceRange& loc = apply.range(); if (apply.inputs().size() != expected_inputs) { throw( ErrorReport(loc) << Var(apply.callee()).name().name() << " expected exactly " << expected_inputs << " arguments but found " << apply.inputs().size()); } if (!apply.attributes().empty()) { throw( ErrorReport(loc) << Var(apply.callee()).name().name() << " takes no keyword arguments"); } } void checkApplyNumInputsRange( const Apply& apply, size_t min_expected_inputs, size_t max_expected_inputs) { const SourceRange& loc = apply.range(); size_t position_arg_size = apply.inputs().size(); if (position_arg_size < min_expected_inputs || position_arg_size > max_expected_inputs) { throw( ErrorReport(loc) << Var(apply.callee()).name().name() << " expected to have number of arguments between " << min_expected_inputs << " and " << max_expected_inputs << " but found " << position_arg_size); } if (!apply.attributes().empty()) { throw( ErrorReport(loc) << Var(apply.callee()).name().name() << " takes no keyword arguments"); } } std::shared_ptr emitApplyExpr( Apply& apply, size_t n_binders, const TypePtr& type_hint = nullptr) { auto sv = emitSugaredExpr(apply.callee(), 1); auto loc = 
apply.callee().range(); if (auto special_form = dynamic_cast(sv.get())) { return emitApplySpecialForm(special_form->form(), apply, sv, type_hint); } auto args = getNamedValues(apply.inputs(), true); auto kwargs = emitAttributes(apply.attributes()); return sv->call(loc, method, args, kwargs, n_binders); } // this function handles expressions that look like apply statements // but have special evaluation rules for the arguments. // when adding a new case, only add a special form if it cannot be expressed // using the standard SugaredValue::call function, which enforces normal // evaluation order. std::shared_ptr emitApplySpecialForm( Symbol form, Apply& apply, const std::shared_ptr& sv, const TypePtr& type_hint = nullptr) { switch (form) { case prim::fork: { auto& trees = apply.inputs().tree()->trees(); if (trees.empty()) { throw( ErrorReport(apply) << "Expected at least one argument to fork()"); } auto forked = emitSugaredExpr(Expr(trees[0]), 1); TreeList sliced_trees(trees.begin() + 1, trees.end()); auto args = getNamedValues(sliced_trees, true); auto kwargs = emitAttributes(apply.attributes()); return emitForkExpr(apply.range(), forked, args, kwargs); } case prim::awaitable: { auto tree = apply.inputs().tree(); if (!tree || tree->trees().empty()) { throw( ErrorReport(apply) << "Expected at least one argument to awaitable()"); } auto& trees = tree->trees(); auto awaited = emitSugaredExpr(Expr(trees[0]), 1); TreeList sliced_trees(trees.begin() + 1, trees.end()); auto args = getNamedValues(sliced_trees, true); auto kwargs = emitAttributes(apply.attributes()); return emitAwaitableExpr(apply.range(), awaited, args, kwargs); } case prim::annotate: { checkApplyNumInputs(apply, 2); TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); Value* expr = tryConvertToType( apply.range(), *graph, type, emitExpr(apply.inputs()[1], type), /*allow_conversions=*/true); std::stringstream why_not; if (!expr->type()->isSubtypeOfExt(*type, &why_not)) { throw( 
ErrorReport(apply.inputs()) << "expected an expression of type " << type->repr_str() << " but found " << expr->type()->repr_str() << "\n" << why_not.str()); } // None is a subtype of Optional[T], but we want to remember what T is // after annotation so that variables assigned to this None will still // get the right type. To do this, we make a None constant that // has the type Optional[T] if ((type->kind() == OptionalType::Kind || (type->kind() == UnionType::Kind && type->expect()->canHoldType(*NoneType::get()))) && expr->type()->isSubtypeOf(*NoneType::get())) { Node* none = graph->createNone(); none->output()->setType(type); graph->insertNode(none); expr = none->output(); } return std::make_shared(expr); } case prim::rpc_async: case prim::rpc_sync: case prim::rpc_remote: { return emitRpcExpr(apply, form); } case prim::unchecked_cast: { checkApplyNumInputs(apply, 2); TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); Value* v = emitExpr(apply.inputs()[1]); // avoid generating nested unchecked_casts because they are already // inserted during serialization if (v->node()->kind() != prim::unchecked_cast || *v->type() != *type) { v = graph->insertUncheckedCast(v, type); } return std::make_shared(v); } break; case prim::GetAttr: { checkApplyNumInputsRange(apply, 2, 3); auto obj = emitSugaredExpr(apply.inputs()[0], 1); auto selector = apply.inputs()[1]; if (selector.kind() != TK_STRINGLITERAL) { throw( ErrorReport(apply) << "getattr's second argument must be a string literal"); } const std::string& name = StringLiteral(selector).text(); if (apply.inputs().size() == 2) { return obj->attr(apply.range(), method, name); } else { // 3 inputs form of getattr, the third argument is the default value // to return when attribute is not found if (obj->hasAttr(apply.range(), method, name)) { return obj->attr(apply.range(), method, name); } else { // attribute not found, just default val (3rd arg) return emitSugaredExpr(apply.inputs()[2], 1); } } } break; case 
prim::Uninitialized: { checkApplyNumInputs(apply, 1); TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); auto out = graph->insertNode(graph->createUninitialized(type)) ->setSourceRange(apply.range()); return std::make_shared(out->output()); } case prim::TupleConstruct: { checkApplyNumInputs(apply, 1); auto arg = emitSugaredExpr(apply.inputs()[0], 1); auto inputs = arg->asTuple(apply.range(), method); auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) { return sv->asValue(apply.range(), method); }); return std::make_shared( graph->insertNode(graph->createTuple(inp_values))->output()); } case prim::LegacyTypedConstructor: { // see legacy_tensor_generic_ctor_new // These legacy constructors do not follow schemas that can be // typed in native_functions.yaml / JIT type signature and are handled // here. Only the two common cases are handled initially: // "new(IntArrayRef size, *, Device? device=None)", // "new(PyObject* data, *, Device? device=None)", // Note: device argument is unused in the kernel auto args = getValues(apply.inputs(), true); auto kwargs = emitAttributes(apply.attributes()); auto get_base_error_msg = [&]() { std::stringstream base_error_msg; base_error_msg << "Legacy Tensor Constructor only supports two schemas in TorchScript: \n"; base_error_msg << "'new(IntArrayRef size, *, Device? device=None)',\n"; base_error_msg << "'new(PyObject* data, *, Device? device=None)\n'"; return base_error_msg; }; if (kwargs.size() == 1 && kwargs[0].name() != "device") { throw( ErrorReport(apply) << get_base_error_msg().str() << "Got kwarg " << kwargs[0].name()); } if (kwargs.size() > 1) { throw( ErrorReport(apply) << get_base_error_msg().str() << "Got multiple kwargs\n"); } auto dtype = dynamic_cast(sv.get())->dtype(); auto dtype_ivalue = graph->insertConstant(dtype); // supporting "new(IntArrayRef size, *, Device? device=None)", through // empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? // layout=None, Device? 
device=None, bool? pin_memory=None, // MemoryFormat? memory_format=None) -> Tensor bool all_ints = std::all_of(args.begin(), args.end(), [](Value* v) { return v->type()->cast(); }); if (args.empty()) { // empty inputs == torch.tensor([], dtype=....) auto inp_list = graph->insertNode(graph->createList(IntType::get(), {})) ->output(); return std::make_shared(graph->insert( aten::tensor, {inp_list}, {NamedValue(apply.range(), "dtype", dtype_ivalue)})); } else if (all_ints) { auto inp_list = graph->insertNode(graph->createList(IntType::get(), args)) ->output(); return std::make_shared(graph->insert( aten::empty, {inp_list}, {NamedValue(apply.range(), "dtype", dtype_ivalue)})); } else if (args.size() == 1) { return std::make_shared(graph->insert( aten::tensor, {args[0]}, {NamedValue(apply.range(), "dtype", dtype_ivalue)})); } else { throw( ErrorReport(apply) << get_base_error_msg().str() << "Got multiple positional arguments that were not all integers"); } } case prim::isinstance: { checkApplyNumInputs(apply, 2); auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]); return std::make_shared(result.value()); } case prim::tolist: { auto select = Select(apply.callee()); auto value = select.value(); auto operand = emitSugaredExpr(value, 1); if (!type_hint) { throw( ErrorReport(apply) << "Expected type hint for result of tolist()"); } return std::make_shared(graph->insertToList( operand->asValue(value.range(), method), type_hint)); } case prim::HasAttr: { checkApplyNumInputs(apply, 2); const auto result = emitHasAttr(apply.inputs()[0], apply.inputs()[1]); return std::make_shared(result.value()); } break; // This represents the "__new__" method on classes // because it takes a ClassValue as input. // So if we see: // Foo.__new__(Foo) // Foo is a ClassValue, calling `attr("__new__")` will return a // CreateObject special form. 
      case prim::CreateObject: {
        if (apply.inputs().size() != 1) {
          throw(ErrorReport(apply) << "Only one argument to __new__ allowed");
        }
        auto arg = emitSugaredExpr(apply.inputs()[0], 1);
        auto class_arg = dynamic_cast<ClassValue*>(arg.get());
        if (!class_arg) {
          throw(
              ErrorReport(apply)
              << "Expected class value as argument to __new__, got "
              << arg->kind() << " instead");
        }
        auto createNode =
            graph->insertNode(graph->createObject(class_arg->type_));
        createNode->setSourceRange(apply.range());
        return std::make_shared<SimpleValue>(createNode->output());
      }
      // We construct the iterable tree here using the IterableTree
      // SugaredValue. The tree consists of SimpleValue, RangeValue or
      // IterableTree nodes. SimpleValues (List, Dict, etc.) and RangeValues
      // become tree leaves, since we can get the loop information from
      // len() and get_item(). IterableTrees such as zip() and enumerate() are
      // modeled as a combination of leaves, and we emit an IterableTree value
      // to record the tree information.
      case prim::range: {
        std::vector<Value*> input_vals =
            getValues(apply.inputs(), /*maybe_unpack=*/true);
        return std::make_shared<RangeValue>(apply.range(), method, input_vals);
      }
      case prim::enumerate: {
        const SourceRange& loc = apply.range();
        auto inputs = apply.inputs();
        auto input_size = inputs.size();
        auto attributes = apply.attributes();
        auto attribute_size = attributes.size();
        // enumerate(x) can be rewritten as subtrees:
        // IterableTree(RangeValue(0, math.inf), SimpleValue(x))
        Value* start_index = nullptr;
        if (input_size == 0) {
          throw(
              ErrorReport(loc)
              << "enumerate expected at least 1 arguments, got 0");
        }
        if (input_size == 2) {
          start_index = emitSugaredExpr(inputs[1], 1)->asValue(loc, method);
        }
        auto arg_size = input_size + attribute_size;
        if (arg_size > 2) {
          throw(
              ErrorReport(loc)
              << "enumerate expected at most 2 arguments, got " << arg_size);
        }
        if (attribute_size == 1) {
          if (attributes[0].name().name() != "start") {
            throw(
                ErrorReport(loc)
                << "enumerate expected kwarg name 'start', got '"
                << attributes[0].name().name() << "'");
          }
          start_index =
              emitSugaredExpr(attributes[0].value(), 1)->asValue(loc, method);
        }
        std::vector<Value*> range_inputs;
        if (start_index != nullptr) {
          range_inputs.emplace_back(start_index);
        }
        Value* end = materializeConstant(
            std::numeric_limits<int64_t>::max(),
            *graph,
            loc,
            integral_constants);
        range_inputs.emplace_back(end);
        SugaredValuePtr expr_sv = emitSugaredExpr(inputs[0], 1);
        auto iterable_value = expr_sv->iter(loc, method);
        // range should have the same static length as the other iterable
        std::optional<int64_t> iter_static_len = iterable_value->staticLen();
        SugaredValuePtr range_sv = std::make_shared<RangeValue>(
            loc, method, range_inputs, iter_static_len);
        auto tree = std::make_shared<IterableTree>();
        tree->addChild(loc, method, range_sv);
        tree->addChild(loc, method, iterable_value);
        return tree;
      }
      case prim::zip: {
        // zip(x, y) can be rewritten as subtrees:
        // IterableTree(IterableTree(x), IterableTree(y))
        auto inputs = apply.inputs();
        if (inputs.empty()) {
          throw(
              ErrorReport(apply) << "zip expected at least 1 arguments, got 0");
        }
        auto iterable_tree = std::make_shared<IterableTree>();
        for (Expr expr : inputs) {
          auto iterable = emitSugaredExpr(expr, 1)->iter(apply.range(), method);
          iterable_tree->addChild(apply.range(), method, iterable);
        }
        return iterable_tree;
      }
      case prim::list: {
        return emitApplySpecialFormForList(apply, type_hint);
      }
      case prim::dict: {
        return emitApplySpecialFormForDict(apply, type_hint);
      }
      case aten::index: {
        const SourceRange& loc = apply.range();
        auto select = Select(apply.callee());
        auto self = emitSugaredExpr(select.value(), 1)->asValue(loc, method);
        auto inputs = apply.inputs();
        if (inputs.size() != 1) {
          throw(
              ErrorReport(apply)
              << "__getitem__ expected exactly 1 arguments, got "
              << inputs.size());
        }
        auto input =
            emitSugaredExpr(apply.inputs()[0], 1)->asValue(loc, method);
        if (input->type()->kind() == TypeKind::TupleType) {
          return std::make_shared<SimpleValue>(
              emitIndex(loc, self, createTupleUnpack(input)));
        }
        return std::make_shared<SimpleValue>(emitIndex(loc, self, {input}));
      }
default: TORCH_INTERNAL_ASSERT(false, "unknown special form: ", form); } } std::shared_ptr emitApplySpecialFormForList( Apply& apply, const TypePtr& type_hint = nullptr) { if (apply.inputs().empty()) { TypePtr type = type_hint ? type_hint : ListType::ofTensors(); if (!type->cast()) { throw( ErrorReport(apply.range()) << "Expected list type annotation for list(), found " << type_hint->repr_str()); } return std::make_shared( graph ->insertNode(graph->createList( type->expectRef().getElementType(), {})) ->output()); } // list(iter) desugars to [_elem for _elem in iter] checkApplyNumInputs(apply, 1); auto iter_input = emitSugaredExpr(apply.inputs()[0], 1); // aten::list builtin op is registered for List and Str input // dispatch to the builtin op to avoid perf slowdown on existing uses if (auto simple = asSimple(iter_input)) { if (simple->type()->cast() || simple->type()->cast()) { return std::make_shared(emitBuiltinCall( apply.range(), *method.graph(), aten::list, {simple}, {})); } } const std::string& iter_name = createTempName("$_iter"); environment_stack->setSugaredVar( apply.range(), iter_name, iter_input, /*annotated_type=*/nullptr); const std::string& elem_name = createTempName("$_elem"); auto ident = Var::create(apply.range(), Ident::create(apply.range(), elem_name)); auto iter = Var::create(apply.range(), Ident::create(apply.range(), iter_name)); auto lc = ListComp::create(apply.range(), ident, ident, iter); return std::make_shared(emitListComprehension(lc, type_hint)); } std::shared_ptr emitApplySpecialFormForDict( Apply& apply, const TypePtr& type_hint = nullptr) { auto check_type_assignment_error = [&](const TypePtr& key_type, const TypePtr& value_type, const TypePtr& annotated_dict_type) { std::stringstream ss; std::stringstream err; auto annotated_k_type = annotated_dict_type->expect()->getKeyType(); auto annotated_v_type = annotated_dict_type->expect()->getValueType(); const auto is_key_subtype = key_type == annotated_k_type; const auto is_value_subtype 
= value_type->isSubtypeOfExt(annotated_v_type, &ss); if (!is_key_subtype) { err << "Generated key type " << key_type->repr_str() << " did not match the annotated key type, which was " << annotated_k_type->repr_str() << "\n"; } if (!is_value_subtype) { err << "Generated value type " << value_type->repr_str() << " did not match the annotated value type, which was " << annotated_v_type->repr_str() << "\n" << ss.str(); } if (!is_key_subtype || !is_value_subtype) { throw(ErrorReport(apply) << err.str()); } }; auto add_kwargs = [&](Value* dc_value) { NamedValue self = NamedValue(apply.range(), "self", dc_value); for (const auto& kwarg : apply.attributes()) { auto name = StringLiteral::create(kwarg.range(), kwarg.name().name()); auto k = emitExpr(name); auto v = emitExpr(kwarg.value()); NamedValue input_k = NamedValue(kwarg.range(), "", k); NamedValue input_v = NamedValue(kwarg.range(), "", v); check_type_assignment_error(k->type(), v->type(), dc_value->type()); emitBuiltinCall( kwarg.range(), *graph, aten::_set_item, {self, input_k, input_v}, {}); } }; auto treat_as_empty_container = [&]() { // true if `dict()` if (apply.inputs().empty() && !apply.attributes().empty()) { return true; } // true if `dict({})` if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_DICT_LITERAL) { auto dict_lit = DictLiteral(apply.inputs()[0]); return dict_lit.key_inputs().empty() && dict_lit.value_inputs().empty(); } // true if `dict([])` if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_LIST_LITERAL) { auto list_lit = ListLiteral(apply.inputs()[0]); return list_lit.inputs().empty(); } return false; }; TypePtr annotated_union_type = type_hint && type_hint->isUnionType() ? 
type_hint : nullptr; auto add_union_cast = [&](Value* result) { Node* n = graph->insertNode(graph->create(prim::unchecked_cast, {result})); n->output()->setType(std::move(annotated_union_type)); result = n->output(); }; TypePtr refined_type_hint = type_hint; std::vector all_candidates = {}; auto type_match = [&](const TypePtr& t) { return t->kind() == DictType::Kind; }; if (type_hint && type_hint->kind() != DictType::Kind) { refineAndSetUnionTypeHintOrPopulateCandidatesVector( type_hint, &refined_type_hint, &all_candidates, "Dict", apply, type_match, [] {}, [] {}, /*is_dict_constructor=*/true); } if (!all_candidates.empty()) { throw( ErrorReport(apply) << "There are multiple candidate " << "Dict types in the Union type annotation `" << type_hint->repr_str() << "`, and full type inference is not yet supported for the " << "`dict()` constructor."); } // If possible, just cast what we have to a Dict and add the // kwargs by hand. This is not only the simplest solution; it also // hits cases like `dict(dict([1, 2, 3]))` or `dict(x)` (where `x` // is some previously-defined variable) if (!apply.inputs().empty()) { // TODO(@ansley): Fix this! We have a weird situation where the // dict constructor may be handed an internal container literal // or comprehension, in which case we'd throw an error because // the lhs type wouldn't match the rhs type (the compiler wouldn't // be able to tell that this was part of a nested expression). We // used to get around this by simply not passing `type_hint`, but // 1) that's bad, and 2) we actually need `type_hint` for // inference now that Union has been introduced. 
std::shared_ptr iter_input; try { iter_input = emitSugaredExpr(apply.inputs()[0], 1, type_hint); } catch (const ErrorReport&) { iter_input = emitSugaredExpr(apply.inputs()[0], 1); } if (auto simple = asSimple(iter_input)) { if (simple->type()->cast()) { auto dc_value = emitBuiltinCall( apply.range(), *method.graph(), aten::dict, {simple}, {}); add_kwargs(dc_value); if (annotated_union_type) { add_union_cast(dc_value); } return std::make_shared(dc_value); } } } // If we have a call with an empty container, or if we have a // call with kwargs only if (treat_as_empty_container()) { auto expr_list = List::create(apply.range(), {}); apply = Apply::create( apply.range(), apply.callee(), expr_list, apply.attributes()); } // If we have a completely empty call to dict() if (apply.inputs().empty() && apply.attributes().empty()) { if (!refined_type_hint) { refined_type_hint = DictType::create(StringType::get(), TensorType::get()); } else if (!all_candidates.empty()) { throw( ErrorReport(apply.range()) << "Cannot determine the type " << "of an empty dict given the Union annotation `" << type_hint->repr_str() << "`, which contains multiple " << "candidate Dict types "); } TORCH_CHECK( refined_type_hint->kind() == DictType::Kind, "Expected a type annotation " "of Dict for dict constructor dict(), got ", type_hint->str()); return std::make_shared( graph ->insertNode(graph->createDict( refined_type_hint->expect()->getKeyType(), refined_type_hint->expect()->getValueType(), {}, {})) ->output()); } // Special-case logic for if we have a dict comprehension if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_DICT_COMP) { auto dc = DictComp(apply.inputs()[0]); auto dc_value = emitDictComprehension(dc, refined_type_hint); add_kwargs(dc_value); return std::make_shared(dc_value); } // We can't feasibly register all possible key x value // combinations of new prim ops for the case that we use the // constructor with a dict literal. 
It makes much more sense // to transform the dict literal into a list of tuples so that // we can use the existing constructors if (!apply.inputs().empty() && apply.inputs()[0].kind() == TK_DICT_LITERAL) { auto dict_lit = DictLiteral(apply.inputs()[0]); std::vector zipped; zipped.reserve(dict_lit.key_inputs().size()); TORCH_INTERNAL_ASSERT( dict_lit.key_inputs().size() == dict_lit.value_inputs().size()); for (auto key_it = dict_lit.key_inputs().begin(), val_it = dict_lit.value_inputs().begin(); key_it != dict_lit.key_inputs().end(); ++key_it, ++val_it) { auto tuple_inputs = List::create(apply.range(), {*key_it, *val_it}); auto tuple = TupleLiteral::create(apply.range(), tuple_inputs); zipped.push_back(tuple); } auto ll_values = List::create(apply.range(), zipped); auto ll = ListLiteral::create(apply.range(), ll_values); auto expr_list = List::create(apply.range(), {ll}); // Change `apply` to a new Apply node holding a list of // tuples apply = Apply::create( apply.range(), apply.callee(), expr_list, apply.attributes()); } // If we have kwargs to include, we'll take a similar approach // to the above logic and standardize the Apply node if (!apply.attributes().empty() && (apply.inputs().empty() || apply.inputs()[0].kind() == TK_LIST_LITERAL)) { std::vector exprs; // Gather all the existing tuples in the input iterable if (!apply.inputs().empty()) { auto tuple_list = ListLiteral(apply.inputs()[0]).inputs(); for (const auto& tuple : tuple_list) { exprs.push_back(tuple); } } // Create tuples out of each kwarg and gather them as well for (const auto& attr : apply.attributes()) { auto k = StringLiteral::create(apply.range(), attr.name().name()); auto v = attr.value(); auto tuple_inputs = List::create(apply.range(), {k, v}); auto tuple = TupleLiteral::create(apply.range(), tuple_inputs); exprs.push_back(tuple); } auto expr_list = List::create(apply.range(), {exprs}); auto ll = ListLiteral::create(apply.range(), expr_list); auto new_inputs = List::create(apply.range(), 
{ll}); auto new_kwargs = List::create(apply.range(), {}); apply = Apply::create(apply.range(), apply.callee(), new_inputs, new_kwargs); } checkApplyNumInputs(apply, 1); auto iter_input = emitSugaredExpr(apply.inputs()[0], 1); const std::string& iter_name = createTempName("$_iter"); const std::string& key_name = createTempName("$_key"); const std::string& value_name = createTempName("$_value"); auto key = Var::create(apply.range(), Ident::create(apply.range(), key_name)); auto value = Var::create(apply.range(), Ident::create(apply.range(), value_name)); auto target = TupleLiteral::create( apply.range(), List::create(apply.range(), {key, value})); auto iter = Var::create(apply.range(), Ident::create(apply.range(), iter_name)); environment_stack->setSugaredVar( apply.range(), iter_name, iter_input, /*annotated_type=*/nullptr); auto dc = DictComp::create(apply.range(), key, value, target, iter); auto result = emitDictComprehension(dc, refined_type_hint); add_kwargs(result); if (annotated_union_type) { add_union_cast(result); } return std::make_shared(result); } Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) { // Push the source range of a call in case compiling this function // triggers an error ErrorReport::CallStack::update_pending_range(tree.range()); Value* out_val = emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method); // AnyType is the only user-exposed type which we don't unify to from // its subtypes, so we add a cast for use cases like // x : Any = 1 if cond else "str" if (type_hint == AnyType::get() && out_val->type() != AnyType::get()) { out_val = graph->insertUncheckedCast(out_val, type_hint); } return out_val; } NodeKind reverseComparision(NodeKind kind) { if (kind == aten::lt) { return aten::gt; } else if (kind == aten::le) { return aten::ge; } else if (kind == aten::gt) { return aten::lt; } else if (kind == aten::ge) { return aten::le; } throw std::runtime_error( "reverseComparision: unsupported NodeKind. 
File a bug"); } // any expression that can produce a SugaredValue is handled here // expressions that only return a single Value* are handled in emitSimpleExpr // type_hint is set if there is a type that this value is expected to be // e.g. a : List[int] = [] // or a = torch.jit.annotate(List[int], []) // the caller is responsible for checking that the result matches type_hint // emitSugaredExpr is free to ignore it. std::shared_ptr emitSugaredExpr( const Expr& tree, size_t n_binders, const TypePtr& type_hint = nullptr) { switch (tree.kind()) { case TK_VAR: { return environment_stack->getSugaredVar(Var(tree).name()); } case '.': { auto select = Select(tree); auto sv = emitSugaredExpr(select.value(), 1); return sv->attr(select.range(), method, select.selector().name()); } case TK_APPLY: { auto apply = Apply(tree); return emitApplyExpr(apply, n_binders, type_hint); } break; case TK_SUBSCRIPT: { return emitSubscript(Subscript(tree), type_hint); } break; default: return std::make_shared(emitSimpleExpr(tree, type_hint)); } } Value* emitUnaryOp( const TreeRef& tree, const std::string& magicMethod, const c10::Symbol& opSymbol) { const auto& inputs = tree->trees(); auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false); auto val = asSimple(makeMagic( magicMethod, std::make_shared(opSymbol, std::nullopt)) ->call(tree->range(), method, named_values, {}, 0)); // if we emitted the unary op and not some other overloaded function, // then try to constantfold if (val->node()->kind() != opSymbol) { return val; } auto maybe_out_stack = runNodeIfInputsAreConstant(val->node()); if (!maybe_out_stack) { return val; } TORCH_INTERNAL_ASSERT(maybe_out_stack->size() == 1); return graph->insertConstant(maybe_out_stack->at(0), tree->range()); } /** * Emit a fork expression, of the form: * torch.jit.fork(forked, *args, **kwargs) */ std::shared_ptr emitForkExpr( SourceRange loc, const std::shared_ptr& forked, at::ArrayRef args, at::ArrayRef kwargs) { auto g = method.graph(); 
TypePtr out_type; auto fork_node = g->insertNode(method.graph()->create(prim::forkClosure, 1)) ->setSourceRange(loc); // We create a fork by emitting a closure and setting the closure output // into the fork input. If a closure doesn't already exist, we create one. { WithInsertPoint insert(fork_node); if (ClosureValue* sv = dynamic_cast(forked.get())) { Value* closure_output = sv->asValue(loc, method); Block* closure_block = closure_output->node()->blocks().at(0); TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1); out_type = closure_block->outputs().at(0)->type(); fork_node->addInput(closure_output); } else { auto emit_closure_body = [&](Block* closure_block) { auto fn_sugared_output = forked->call(loc, method, args, kwargs, 1); auto fn_simple_output = fn_sugared_output->asValue(loc, method); closure_block->registerOutput(fn_simple_output); out_type = fn_simple_output->type(); }; auto closure_value = emitClosure(emit_closure_body); fork_node->addInput(closure_value->asValue(loc, method)); } } Value* node_output = fork_node->output()->setType(FutureType::create(out_type)); return std::make_shared(node_output); } std::shared_ptr emitAwaitableExpr( SourceRange loc, const std::shared_ptr& awaited, at::ArrayRef args, at::ArrayRef kwargs) { auto g = method.graph(); TypePtr out_type{}; auto await_node = g->insertNode(method.graph()->create(prim::awaitableClosure, 1)) ->setSourceRange(loc); { WithInsertPoint insert(await_node); if (auto sv = dynamic_cast(awaited.get())) { Value* closure_output = sv->asValue(loc, method); Block* closure_block = closure_output->node()->blocks().at(0); TORCH_INTERNAL_ASSERT(closure_block->outputs().size() == 1); out_type = closure_block->outputs().at(0)->type(); await_node->addInput(closure_output); } else { auto emit_closure_body = [&](Block* closure_block) { auto fn_sugared_output = awaited->call(loc, method, args, kwargs, 1); auto fn_simple_output = fn_sugared_output->asValue(loc, method); 
          closure_block->registerOutput(fn_simple_output);
          out_type = fn_simple_output->type();
        };
        auto closure_value = emitClosure(emit_closure_body);
        await_node->addInput(closure_value->asValue(loc, method));
      }
    }
    Value* node_output =
        await_node->output()->setType(AwaitType::create(out_type));
    return std::make_shared<SimpleValue>(node_output);
  }

  std::shared_ptr<SugaredValue> emitRpcExpr(
      const Apply& apply,
      Symbol rpc_op) {
    // TODO: This is a temporary approach to enable calling user functions
    // through RPC in TorchScript.
    // Ideally, a function value in JIT IR is a first-class citizen and
    // the RPC C++ entry API can take c10::Function directly.
    size_t rpcMinInputs = 2;
    size_t rpcMaxInputs = 5;
    std::string op_name = rpc_op.toUnqualString();
    if (apply.inputs().size() < rpcMinInputs ||
        apply.inputs().size() > rpcMaxInputs) {
      throw(
          ErrorReport(apply)
          << "Possible forms of call to " << op_name << "(..) are\n"
          << op_name
          << "(dst_worker_name, user_callable, args, kwargs, timeout)\n"
          << op_name << "(dst_worker_name, user_callable, args, kwargs)\n"
          << op_name << "(dst_worker_name, user_callable, args)\n"
          << op_name << "(dst_worker_name, user_callable)\n"
          << "Now the number of arguments is " << apply.inputs().size());
    }
    if (!apply.attributes().empty()) {
      throw(
          ErrorReport(apply)
          << op_name << "(dst_worker_name, user_callable, args, kwargs)"
          << " does not support kwargs yet");
    }
    // TODO: Make rpc_op(..) support taking kwargs,
    // like rpc_async(to="worker1", func=my_func, args=(), kwargs={})

    auto& input_trees = apply.inputs().tree()->trees();
    Value* dst_worker_name_value = emitExpr(Expr(input_trees[0]));
    std::shared_ptr<SugaredValue> user_callable_sugared_value =
        emitSugaredExpr(Expr(input_trees[1]), 1);
    TORCH_CHECK(
        user_callable_sugared_value->kind() == "function",
        "user_callable should be a FunctionValue, it's now a ",
        user_callable_sugared_value->kind());
    // NB: This should be done using `std::dynamic_pointer_cast`
    // and assert `user_callable_function_value != nullptr`. But somehow on
    // macOS std::dynamic_pointer_cast always returns
    // `user_callable_function_value` as a `nullptr`, even if
    // `user_callable_sugared_value->kind() == "function"`.
    std::shared_ptr<FunctionValue> user_callable_function_value =
        std::static_pointer_cast<FunctionValue>(user_callable_sugared_value);
    // If `kwargs` is an empty dict, users are allowed to not pass `kwargs`.
    // If `args` and `kwargs` are an empty tuple and an empty dict,
    // respectively, users are allowed to not pass `args` and `kwargs`.

    TreeList args_kwargs_timeout_trees(
        input_trees.begin() + 2, input_trees.end());

    // Get user callable.
    const auto& callablePtrs = user_callable_function_value->callees();
    TORCH_INTERNAL_ASSERT(
        callablePtrs.size() == 1,
        "User-provided callable size should be 1. Now it's ",
        callablePtrs.size());
    Function* callablePtr = callablePtrs.at(0);

    const auto& functionSchema = callablePtr->getSchema();
    const SourceRange& loc = apply.range();
    auto graphPtr = method.graph();

    // Match FunctionSchema.
    std::vector<NamedValue> args;
    std::vector<NamedValue> kwargs;
    // Get args and kwargs as `NamedValue`s.
    // Similar to getNamedValues(..) and emitAttributes(..).
    if (!args_kwargs_timeout_trees.empty()) {
      // Unroll args from a Var that is known to be a Tuple.
      auto& args_tree = args_kwargs_timeout_trees[0];
      auto entry_sugared_values = emitSugaredExpr(Expr(args_tree), 1)
                                      ->asTuple(args_tree->range(), method);
      args.reserve(entry_sugared_values.size());
      for (const auto& entry_sugared_value : entry_sugared_values) {
        args.emplace_back(
            args_tree->range(),
            entry_sugared_value->asValue(args_tree->range(), method));
      }
      // NB: Can't do schema check on kwargs, given the RPC API is
      // rpc_op(to, user_callable, args, kwargs),
      // users can construct kwargs = {"first" + "_arg" : 1}.
      // Notice the key is determined at run time.
      // We cannot do it at compile time unless one day the RPC API is
      // rpc_op(to, user_callable, arg_0, arg_1, kwarg_0="foo",
      // kwarg_1="bar")
    }
    matchSchema(functionSchema, loc, *graphPtr, args, kwargs);

    // Insert the QualifiedName into the graph as a constant input IR Value.
    const auto& qualname = callablePtr->qualname();
    IValue userCallableQualNameIValue(qualname.qualifiedName());
    Value* userCallableQualNameValue =
        graphPtr->insertConstant(userCallableQualNameIValue, loc);

    // Insert the corresponding RPC node into the graph.
    Node* rpc_node =
        graphPtr->insertNode(graphPtr->create(rpc_op, 1))->setSourceRange(loc);
    {
      WithInsertPoint insert(rpc_node);
      rpc_node->addInput(dst_worker_name_value);
      rpc_node->addInput(userCallableQualNameValue);

      for (const auto& tree : args_kwargs_timeout_trees) {
        rpc_node->addInput(emitExpr(Expr(tree)));
      }
    }
    Value* rpc_node_output = rpc_node->output();

    // Set output type from FunctionSchema and corresponding rpc_op.
    const std::vector<Argument>& returns = functionSchema.returns();
    TORCH_INTERNAL_ASSERT(returns.size() == 1);
    TypePtr output_type = nullptr;
    if (rpc_op == prim::rpc_async) {
      // rpc_async returns FutureType of the functionSchema's return type
      output_type = FutureType::create(returns[0].type());
    } else if (rpc_op == prim::rpc_sync) {
      // rpc_sync returns the functionSchema's return type
      output_type = returns[0].type();
    } else if (rpc_op == prim::rpc_remote) {
      // rpc_remote returns RRefType of the functionSchema's return type
      output_type = RRefType::create(returns[0].type());
    } else {
      throw(
          ErrorReport(apply)
          << rpc_op.toDisplayString() << " is not supported in TorchScript!");
    }
    rpc_node_output->setType(output_type);
    return std::make_shared<SimpleValue>(rpc_node_output);
  }

  Value* emitBinaryOp(const TreeRef& tree) {
    const auto& inputs = tree->trees();
    auto kind = getNodeKind(tree->kind(), inputs.size());
    auto overload = getOperatorOverload(tree->kind(), inputs.size());
    auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
    if (tree->kind() ==
        TK_IN) {
      // For `in` the arguments are in reverse order (the object being
      // checked is second)
      std::iter_swap(named_values.begin() + 0, named_values.begin() + 1);
    }

    // If this is adding two tuples, we deal with it here.
    // The reason is we can't specify the length of tuples
    // when registering custom aten::add.
    if (named_values[0].type()->kind() == TupleType::Kind &&
        named_values[1].type()->kind() == TupleType::Kind &&
        kind == aten::add) {
      auto first_tuple = createTupleUnpack(named_values[0].value(*graph)).vec();
      auto second_tuple =
          createTupleUnpack(named_values[1].value(*graph)).vec();
      first_tuple.insert(
          first_tuple.end(), second_tuple.begin(), second_tuple.end());
      return graph->insertNode(graph->createTuple(first_tuple))->output();
    }

    return asSimple(
        makeMagic(
            overload, std::make_shared<BuiltinFunction>(kind, std::nullopt))
            ->call(tree->range(), method, named_values, {}, 0));
  }

  Value* emitListLiteral(const ListLiteral& ll, const TypePtr& type_hint) {
    auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);

    // Empty List Literals that are not assigned to variables
    // may match to any list type in schema matching,
    // but still default to List[Tensor] if assigned to a variable
    // or returned from a function.
    // Restricting empty list matching to temporary values
    // avoids difficult-to-handle cases such as
    // a = []
    // b = a
    // if cond:
    //    b.append(2)
    // else:
    //    a.append("hi")
    // This is also the same behavior that C++ allows with {}
    // (cannot assign to a variable typed as auto).
    // These nodes will be removed in a later pass after initial compilation.
    if (values.empty() && type_hint == nullptr) {
      auto node = graph->insertNode(graph->create(prim::EmptyListLiteral));
      node->output()->setType(ListType::ofTensors());
      return node->output();
    }

    // Determine the element type of the list. If we have a type hint
    // of `List[T]`, use `T`. If the list is non-empty, find the
    // greatest common supertype of all the list elements (defaulting to
    // `Any` as a catch-all supertype).
    // Assume `[]` is `List[Tensor]`
    TypePtr inferred_elem_type = TensorType::get();

    TypePtr refined_type_hint = type_hint;

    // If `type_hint` is a Union/Optional, we're going to change it to
    // be the type of the rhs List, so we need to store the original
    // UnionType for later. `nullptr` means that we don't need to emit
    // an `unchecked_cast` node (either because we don't have a type
    // hint or because the type hint wasn't a Union)
    TypePtr annotated_union_type =
        refined_type_hint && refined_type_hint->isUnionType()
        ? refined_type_hint
        : nullptr;

    // This is used in the case that we have a Union annotation that
    // contains multiple Lists
    std::vector<TypePtr> all_candidates = {};

    if (refined_type_hint) {
      auto do_if_type_match = [&]() {
        auto list_type_hint = refined_type_hint->cast<ListType>();
        inferred_elem_type = list_type_hint->getElementType();
      };

      auto type_match = [&](const TypePtr& t) {
        return t->isSubtypeOf(AnyListType::get());
      };

      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "List",
          ll,
          type_match,
          do_if_type_match,
          do_if_type_match);

      if (!all_candidates.empty() && values.empty()) {
        throw(
            ErrorReport(ll)
            << "Cannot assign an empty list to a "
            << "variable annotated to be type " << refined_type_hint->repr_str()
            << " because there are multiple possible List "
            << "type candidates in the Union annotation");
      }
    }

    if (!values.empty()) {
      auto types = fmap(values, [](const Value* v) { return v->type(); });

      std::stringstream nowhere; // never used

      // We don't want to use `inferred_elem_type` as the final argument to
      // `unifyTypeList` because there's a chance that `inferred_elem_type`
      // is the Tensor default
      const auto elem_type_hint =
          refined_type_hint && refined_type_hint->kind() == ListType::Kind
          ?
          refined_type_hint->cast<ListType>()->getElementType()
          : nullptr;

      std::optional<TypePtr> unified_elem_type = unifyTypeList(
          types, nowhere, /*default_to_union=*/true, elem_type_hint);

      if (!refined_type_hint &&
          (*unified_elem_type)->kind() == UnionType::Kind) {
        TORCH_WARN(
            "List consists of heterogeneous types, which means",
            " that it has been typed as containing ",
            (*unified_elem_type)->repr_str(),
            ". To use any of the "
            "values in this List, it will be necessary to add an "
            "`assert isinstance` statement before first use to trigger "
            "type refinement.\n",
            ll.range().str());
      }

      if (all_candidates.empty() && refined_type_hint &&
          !(*unified_elem_type)->isSubtypeOf(*inferred_elem_type)) {
        throw(
            ErrorReport(ll)
            << "List type annotation `" << refined_type_hint->repr_str()
            << "` did not match the types of the given list elements,"
            << " which were unified to " << (*unified_elem_type)->repr_str());
      }

      if (!all_candidates.empty()) {
        refineAndSetListTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            *unified_elem_type,
            ll);
        inferred_elem_type =
            refined_type_hint->expect<ListType>()->getElementType();
      }

      // We only want to set `inferred_elem_type` if we don't have a type
      // hint, to allow for the case that `*unified_elem_type` is a subtype
      // of `type_hint`
      if (!refined_type_hint) {
        inferred_elem_type = *unified_elem_type;
      }
    }

    Node* result =
        graph->insertNode(graph->createList(inferred_elem_type, values));
    if (annotated_union_type) {
      Node* n = graph->insertNode(
          graph->create(prim::unchecked_cast, {result->output()}));
      n->output()->setType(std::move(annotated_union_type));
      result = n;
    }

    return result->output();
  }

  Value* emitDictLiteral(DictLiteral dl, const TypePtr& type_hint) {
    auto key_trees = dl.key_inputs().tree()->trees();
    auto value_trees = dl.value_inputs().tree()->trees();

    AT_ASSERT(key_trees.size() == value_trees.size());

    std::vector<Value*> keys, values;
    TypePtr rhs_value_type;

    for (const auto i : c10::irange(key_trees.size())) {
      keys.push_back(emitExpr(Expr(key_trees[i])));
      values.push_back(emitExpr(Expr(value_trees[i])));

      if (i == 0) {
        rhs_value_type = values[i]->type();
      } else {
        if (keys[i - 1]->type()->kind() != keys[i]->type()->kind()) {
          throw(
              ErrorReport(key_trees[i])
              << "Dict keys must contain "
              << "only a single type. Expected: "
              << keys[i - 1]->type()->repr_str() << " but found "
              << keys[i]->type()->repr_str() << " instead");
        }
        rhs_value_type = *(unifyTypes(
            rhs_value_type, values[i]->type(), /*default_to_union=*/true));
      }
    }

    TypePtr refined_type_hint = type_hint;

    TypePtr annotated_union_type =
        type_hint && type_hint->isUnionType() ? type_hint : nullptr;

    std::vector<TypePtr> all_candidates = {};

    auto default_refined_type_hint_setter = [&]() {
      if (keys.empty()) {
        refined_type_hint =
            DictType::create(StringType::get(), TensorType::get());
      } else {
        refined_type_hint =
            DictType::create(keys.at(0)->type(), rhs_value_type);
        if (rhs_value_type->kind() == UnionType::Kind) {
          TORCH_WARN(
              "Dict values consist of heterogeneous types, which means",
              " that the dict has been typed as containing ",
              refined_type_hint->repr_str(),
              ". To use any of the values in this Dict, it will be "
              "necessary to add an `assert isinstance` statement before "
              "first use to trigger type refinement.\n",
              dl.range().str());
        }
      }
    };

    if (type_hint) {
      auto type_match = [&](const TypePtr& t) {
        return t->kind() == DictType::Kind;
      };

      refineAndSetUnionTypeHintOrPopulateCandidatesVector(
          type_hint,
          &refined_type_hint,
          &all_candidates,
          "Dict",
          dl,
          type_match,
          [] {},
          default_refined_type_hint_setter);

      if (!all_candidates.empty() && values.empty()) {
        throw(
            ErrorReport(dl)
            << "Cannot assign an empty dict to a "
            << "variable annotated to be type " << type_hint->repr_str()
            << " because there are multiple possible Dict "
            << "type candidates in the Union annotation");
      }
    } else {
      default_refined_type_hint_setter();
    }

    // We must have either a) specific key/value types already, or b) a
    // list of possible candidates
    TORCH_INTERNAL_ASSERT(!all_candidates.empty() || refined_type_hint);

    if (!values.empty()) {
      if (!all_candidates.empty()) {
        refineAndSetDictTypeHintFromCandidatesVector(
            all_candidates,
            type_hint,
            &refined_type_hint,
            keys[0]->type(),
            rhs_value_type,
            dl);
      }

      if (refined_type_hint->expect<DictType>()->getKeyType() !=
          keys.at(0)->type()) {
        throw(
            ErrorReport(dl)
            << "Type annotation was inferred to be `"
            << refined_type_hint->repr_str()
            << "`, but the type of keys given by the dict literal is "
            << keys.at(0)->type()->repr_str());
      }

      if (!rhs_value_type->isSubtypeOf(
              refined_type_hint->expect<DictType>()->getValueType())) {
        throw(
            ErrorReport(dl)
            << "Type annotation was inferred to be `"
            << refined_type_hint->repr_str()
            << "`, but the type of values given by the dict literal is "
            << rhs_value_type->repr_str());
      }
    }

    Node* result = graph->insertNode(graph->createDict(
        refined_type_hint->expect<DictType>()->getKeyType(),
        refined_type_hint->expect<DictType>()->getValueType(),
        keys,
        values));
    if (annotated_union_type) {
      Node* n = graph->insertNode(
          graph->create(prim::unchecked_cast, {result->output()}));
      n->output()->setType(std::move(annotated_union_type));
      result = n;
    }
    return result->output();
  }

  Value* emitSimpleExpr(
      const TreeRef& tree,
      const TypePtr& type_hint = nullptr) {
    switch (tree->kind()) {
      case TK_FLOOR_DIV:
      case '@': {
        const auto& inputs = tree->trees();
        auto kind = getNodeKind(tree->kind(), inputs.size());
        auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
        return emitBuiltinCall(
            tree->range(), *method.graph(), kind, named_values, {});
      }
      case '%': {
        auto lhs = emitSugaredExpr(Expr(tree->tree(0)), 0)
                       ->asValue(tree->tree(0)->range(), method);

        auto const& lhs_type = lhs->type();
        if (lhs_type == StringType::get()) {
          auto values = getValues(tree->trees(), /*maybe_unpack=*/false);
          auto node = graph->create(aten::percentFormat, values, 1)
                          ->setSourceRange(tree->range());
          Value* output = graph->insertNode(node)->output();
          output->setType(StringType::get());
          return output;
        } else {
          return emitBinaryOp(tree);
        }
      }
      case TK_IN:
      case TK_POW:
      case TK_NE:
      case TK_EQ:
      case '<':
      case '>':
      case TK_LE:
      case TK_GE:
      case '*':
      case '/':
      case '+':
      case '-':
      case '&':
      case '|':
      case '^':
      case TK_LSHIFT:
      case TK_RSHIFT:
        return emitBinaryOp(tree);
      case TK_IS:
      case TK_ISNOT:
      case TK_AND:
      case TK_OR:
      case TK_NOT: {
        return emitCondExpr(Expr(tree)).value();
      }
      case TK_UNARY_MINUS: {
        return emitUnaryOp(tree, "__neg__", aten::neg);
      }
      case '~': {
        return emitUnaryOp(tree, "__invert__", aten::bitwise_not);
      }
      case TK_STARRED: {
        throw(
            ErrorReport(tree)
            << "Unexpected starred expansion. File a bug report");
      }
      case TK_CONST: {
        return emitConst(Const(tree));
      } break;
      case TK_TRUE: {
        return graph->insertConstant(true, tree->range());
      } break;
      case TK_FALSE: {
        return graph->insertConstant(false, tree->range());
      } break;
      case TK_NONE: {
        return graph->insertConstant(IValue(), tree->range());
      } break;
      case TK_IF_EXPR: {
        return emitTernaryIf(TernaryIf(tree), type_hint);
      } break;
      case TK_STRINGLITERAL: {
        return emitStringLiteral(StringLiteral(tree));
      } break;
      case TK_LIST_LITERAL: {
        auto ll = ListLiteral(tree);
        return emitListLiteral(ll, type_hint);
      } break;
      case TK_TUPLE_LITERAL: {
        auto ll = TupleLiteral(tree);
        auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
        return graph->insertNode(graph->createTuple(values))->output();
      } break;
      case TK_DICT_LITERAL: {
        auto dc = DictLiteral(tree);
        return emitDictLiteral(dc, type_hint);
      } break;
      case TK_LIST_COMP: {
        auto lc = ListComp(tree);
        return emitListComprehension(lc, type_hint);
      } break;
      case TK_DICT_COMP: {
        auto dc = DictComp(tree);
        return emitDictComprehension(dc, type_hint);
      } break;
      default:
        throw(ErrorReport(tree) << "Cannot emit expr for: " << tree);
    }
  }

  Value* emitConst(const Const& c) {
    if (c.isFloatingPoint())
      return materializeConstant(
          c.asFloatingPoint(), *graph, c.range(), fp_constants);
    else if (c.isComplex())
      return materializeConstant(
          c.asComplex(), *graph, c.range(), complex_constants);
    else
      return materializeConstant(
          c.asIntegral(), *graph, c.range(), integral_constants);
  }

  Value* emitStringLiteral(const StringLiteral& c) {
    return insertConstant(*graph, c.text(), c.range());
  }

  // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
  Value* emitSelect(
      const SourceRange& loc,
      Value* input,
      Value* dim,
      Value* index) {
    return emitBuiltinCall(loc, *graph, aten::select, {input, dim, index}, {});
  }

  Value* emitSliceOp(
      const SourceRange& loc,
      Value* sliceable,
      Value* dim,
      Value* start,
      Value* end,
      Value* step) {
    std::vector<NamedValue> args;
    args.reserve(5);
    args.emplace_back(loc, "self",
        sliceable);

    // XXX: If list slicing becomes more complicated or stops using
    // aten::slice, we should separate it from this function.
    if (dim) {
      AT_ASSERT(sliceable->type()->isSubtypeOf(*TensorType::get()));

      args.emplace_back(dim);
    } else {
      AT_ASSERT(!sliceable->type()->isSubtypeOf(*TensorType::get()));
    }

    if (sliceable->type()->cast<TupleType>()) {
      std::vector<std::optional<NamedValue>> tuple_args;
      // Since we are only dealing with tuple slicing, we try to keep
      // tuple args separate for now.
      tuple_args.reserve(3);

      start ? tuple_args.emplace_back(start)
            : tuple_args.emplace_back(std::nullopt);
      end ? tuple_args.emplace_back(end)
          : tuple_args.emplace_back(std::nullopt);
      step ? tuple_args.emplace_back(step)
           : tuple_args.emplace_back(std::nullopt);

      return emitTupleSlice(loc, args[0], tuple_args);
    }

    // Handle cases like x[0:2], where no step is given. x[0:2:] is already
    // handled on the Python side.
    if (!step) {
      step = graph->insertConstant(1, loc);
    }

    args.emplace_back(loc, "start", start);
    args.emplace_back(loc, "end", end);
    args.emplace_back(loc, "step", step);
    return emitBuiltinCall(loc, *graph, aten::slice, args, {});
  }

  // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
  // 1)
  Value* emitSlice(
      const SourceRange& loc,
      Value* input,
      Value* dim, // Only used for tensor slicing
      const SliceExpr& slice) {
    Value* start = nullptr;
    Value* end = nullptr;
    Value* step = nullptr;
    if (slice.start().present()) {
      start = emitExpr(Expr(slice.start().get()));
    }
    if (slice.end().present()) {
      end = emitExpr(Expr(slice.end().get()));
    }
    if (slice.step().present()) {
      step = emitExpr(Expr(slice.step().get()));
    }
    return emitSliceOp(loc, input, dim, start, end, step);
  }

  Value* emitUnsqueeze(const SourceRange& loc, Value* input, Value* dim_val) {
    return emitBuiltinCall(loc, *graph, aten::unsqueeze, {input, dim_val}, {});
  }

  Value* emitIndex(
      const SourceRange& loc,
      Value* input,
      at::ArrayRef<Value*> indices) {
    // NB: the index of aten::index should be a type of List[Optional[Tensor]],
    // this is to support the case like t[:, :, 1] where : here
    // indicates a None/undefined tensor (optional tensor)
    auto* index =
        graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
            ->output();
    return emitBuiltinCall(loc, *graph, aten::index, {input, index}, {});
  }

  // Emits multidimensional slicing with int and slice indices.
  // Returns:
  // - Value*: the input after it has been indexed by int and slice indices.
  // - vector<Value*>: A list of tensor Value* indices that have not been
  //   applied yet.
  //   Should be NULL at indices where sliceable (post-slicing) isn't indexed by
  //   a tensor.
  std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    // Overall, to handle indexing (other than Tensors), we need to handle a
    // couple of different things. For example, for x[1:3, None, 4], each of
    // these index types (slice, None, and integer) changes the number of
    // dimensions differently. Slicing doesn't change the number of
    // dimensions, None adds a dimension, and an integer removes a dimension.
    // As these indexing operations are applied left to right, the actual
    // dimension that each one is applied to depends on the previous
    // operations. Ellipsis indexing adds another wrinkle: an ellipsis selects
    // any remaining unspecified dimensions. Thus, for indices following an
    // ellipsis, the actual dimension an indexing operation is applied to
    // depends on the operations to its right.
    // Thus, we do two passes, one from left to right up until the ellipsis,
    // and one from right to left.

    std::vector<Value*> tensor_indices;

    auto insert_value_for_dim = [&](int64_t dim) {
      return graph->insertConstant(dim, loc);
    };
    std::vector<int64_t> dims(subscript_exprs.size());
    std::vector<std::optional<Value*>> exprs(
        subscript_exprs.size(), std::nullopt);

    auto handle_indexing = [&](const Expr& subscript_expr,
                               size_t expr_idx,
                               int64_t dim,
                               bool is_reverse = false) {
      dims[expr_idx] = dim;

      // Slice expression case, does not represent a single index.
      if (subscript_expr.kind() == TK_SLICE_EXPR) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      // Slice object case, does not represent a single index.
      auto subscript_sv = emitSugaredExpr(subscript_expr, 1);
      if (dynamic_cast<SliceValue*>(subscript_sv.get())) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim + 1;
        }
      }

      TypePtr type_hint;
      if (subscript_expr.kind() == TK_NONE) {
        type_hint = NoneType::get();
      }
      auto index = emitExpr(subscript_expr, type_hint);

      // Accept list as subscript but convert it to a Tensor
      // since it's equivalent to indexing with Tensor.
      // The list can be a list literal or list variable.
      // Advanced indexing using list:
      // @torch.jit.script
      // def f(x):
      //   return x[[0, 1, 5]]  # or
      //   return x[[0, 1], [0, 1]]  # or
      //   return x[[[0, 1], [0, 1]], [[0, 1], [0, 1]]]  # or
      //   ls = [0, 1]
      //   return x[ls]
      // Statements above are equivalent to advanced indexing using Tensor:
      // @torch.jit.script
      // def f(x):
      //   return x[torch.tensor([0, 1, 5])]  # or
      //   return x[torch.tensor([0, 1]), torch.tensor([0, 1])]  # or
      //   return x[torch.tensor([[0, 1], [0, 1]]),
      //            torch.tensor([[0, 1], [0, 1]])]  # or
      //   ls = [0, 1]
      //   return x[torch.tensor(ls)]
      if (index->type()->kind() == c10::TypeKind::ListType) {
        // Always create index tensor as LongTensor.
        // This is to match PyTorch eager frontend behavior which accepts
        // indexing with float list.
        index = graph->insert(
            aten::tensor, {index}, {NamedValue("dtype", c10::kLong)});
      }

      exprs[expr_idx] = index;
      if (index->type()->isSubtypeOf(*NoneType::get())) {
        if (is_reverse) {
          return dim;
        } else {
          return dim + 1;
        }
      } else if (index->type() == IntType::get()) {
        if (is_reverse) {
          return dim - 1;
        } else {
          return dim;
        }
      } else if (index->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        if (is_reverse) {
          throw(
              ErrorReport(loc)
              << "Ellipses followed by tensor indexing is currently not supported");
        } else {
          return dim + 1;
        }
      } else {
        throw(
            ErrorReport(loc)
            << "Unsupported operation: indexing tensor with unsupported index type '"
            << index->type()->repr_str()
            << "'. Only ints, slices, lists and tensors are supported");
      }
    };

    // Forward pass: handle every index to the left of the ellipsis (if any).
    size_t idx = 0;
    int64_t dim = 0;
    for (; idx < subscript_exprs.size(); idx++) {
      auto subscript_expr = subscript_exprs[idx];
      if (subscript_expr.kind() == TK_DOTS) {
        break;
      }
      dim = handle_indexing(subscript_expr, idx, dim, /*is_reverse=*/false);
    }
    // Reverse pass: handle every index to the right of the ellipsis, using
    // negative dimensions counted from the end.
    int64_t rdim = -1;
    for (size_t rev_idx = subscript_exprs.size() - 1; rev_idx > idx;
         rev_idx--) {
      auto subscript_expr = subscript_exprs[rev_idx];
      if (subscript_expr.kind() == TK_DOTS) {
        throw(
            ErrorReport(loc)
            << "An index can only have a single ellipsis ('...')");
      }
      rdim =
          handle_indexing(subscript_expr, rev_idx, rdim, /*is_reverse=*/true);
    }
    for (const auto i : c10::irange(exprs.size())) {
      if (!exprs[i].has_value()) {
        if (subscript_exprs[i].kind() == TK_SLICE_EXPR) {
          sliceable = emitSlice(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              SliceExpr(subscript_exprs[i]));
          continue;
        }

        if (subscript_exprs[i].kind() == TK_DOTS) {
          continue;
        }

        auto subscript_sv = emitSugaredExpr(subscript_exprs[i], 1);
        if (const auto slice_value =
                dynamic_cast<SliceValue*>(subscript_sv.get())) {
          sliceable = emitSliceOp(
              loc,
              sliceable,
              insert_value_for_dim(dims[i]),
              slice_value->start(),
              slice_value->stop(),
              slice_value->step());
        }

        continue;
      }
      auto expr = exprs[i].value();
      if (expr->type()->isSubtypeOf(*NoneType::get())) {
        sliceable =
            emitUnsqueeze(loc, sliceable, insert_value_for_dim(dims[i]));
      } else if (expr->type() == IntType::get()) {
        sliceable =
            emitSelect(loc, sliceable, insert_value_for_dim(dims[i]), expr);
      } else if (expr->type()->isSubtypeOf(*OptionalType::ofTensor())) {
        tensor_indices.resize(dims[i] + 1);
        tensor_indices[dims[i]] = expr;
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Trying to process index type that we don't support.");
      }
    }
    // at::index takes in a List[Optional[Tensor]] where some dims can be None.
    // Create a None node with optional tensor output type and pass it to
    // at::index.
    for (auto& index : tensor_indices) {
      if (index == nullptr) {
        index = graph->insertNode(graph->createNone())->output();
      }
    }
    return std::make_pair(sliceable, tensor_indices);
  }

  // Desugars multidim slicing into slice/select/index/unsqueeze calls.
  //
  // XXX: Errors in user code are not elegantly reported.
  // Let's say someone were to do the following:
  //   @torch.jit.script
  //   def fn(x):
  //       return x[0, 1]
  //   fn(torch.randn(5))
  // Because we desugar this into two aten::select ops, the error message
  // complains about aten::select failing rather than there "not being
  // enough dimensions to index".
  //
  // The strategy is to slice and select the tensor for int and slices first
  // in one pass and then apply at::index on the result of the
  // slicing/selecting. Call the tensor after we've applied slice / select the
  // `sliced`. tensor_indices should have the same size as sliced.dim():
  // - tensor_indices[i] = NULL if we should not index `sliced` at dim i
  // - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
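  //
  // Worked example (a sketch derived from the rules above, not checked
  // against an actual emitted graph): `x[1:3, None, 4]` contains no tensor
  // indices, so it is expected to desugar to roughly
  //   aten::slice(x, /*dim=*/0, /*start=*/1, /*end=*/3, /*step=*/1)
  //   aten::unsqueeze(<prev>, /*dim=*/1)
  //   aten::select(<prev>, /*dim=*/2, /*index=*/4)
  // The slice keeps dim 0 and advances to dim 1, `None` inserts a new dim 1
  // and advances to dim 2, and the integer selects along dim 2.
  // `tensor_indices` stays empty, so no aten::index node is emitted.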
  Value* emitMultidimSlicing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    if (!sliceable->type()->isSubtypeOf(*TensorType::get())) {
      throw(
          ErrorReport(loc)
          << "Unsupported operation: attempted to use multidimensional "
          << "indexing on a non-tensor type");
    }

    std::vector<Value*> tensor_indices;
    std::tie(sliceable, tensor_indices) =
        emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);

    if (tensor_indices.empty()) {
      // XXX: Might need to at::alias this when we support mutability
      return sliceable;
    }

    return emitIndex(loc, sliceable, tensor_indices);
  }

  // Desugars slice syntactic sugar tensor[begin:end] -> tensor.slice(begin,
  // end).
  Value* emitBasicSlice(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    AT_ASSERT(subscript_exprs.size() == 1);
    AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
    auto slice_exp = SliceExpr(subscript_exprs[0]);
    Value* maybe_dim = nullptr;
    if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
      // If the sliceable object is a tensor, specify a default dimension
      maybe_dim = graph->insertConstant(0, loc);
    }
    return emitSlice(loc, sliceable, maybe_dim, slice_exp);
  }

  int64_t getAdjTupleIndex(
      const SourceRange& loc,
      const TupleTypePtr& tuple_type,
      int64_t input_index,
      bool allow_out_of_bounds) {
    // Convert a negative index to its non-negative equivalent to simplify
    // the logic at runtime
    int64_t adj_index = input_index;
    int64_t tuple_len = static_cast<int64_t>(tuple_type->elements().size());
    if (input_index < 0) {
      adj_index = tuple_len + input_index;
    }
    if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
      throw(
          ErrorReport(loc) << "Tuple index out of range. Tuple is length "
                           << tuple_len << " and index is " << input_index);
    }
    return adj_index;
  }

  // When a list is marked const in a module, it gets converted to a tuple.
  // As a result, indexing into a Tuple which contains only one type
  // is quite common.
  // Since indexing will likely be done in a for loop,
  // we do not want to incur the overhead of converting the tuple to a list
  // on each iteration.
  Value* emitTupleIndex(
      const SourceRange& loc,
      Value* tuple_val,
      Value* idx_val) {
    auto tuple_typ = tuple_val->type()->cast<TupleType>();
    auto elems = tuple_typ->elements();
    TypePtr output_type;
    if (idx_val->type() != IntType::get()) {
      throw(ErrorReport(loc) << "tuple index must be an integer");
    }
    auto idx = toIValue(idx_val);
    if (!idx) {
      if (elems.empty() ||
          !convertibleToList(tuple_typ, ListType::create(elems[0]))) {
        throw(
            ErrorReport(loc)
            << "Cannot index into a " << tuple_typ->repr_str()
            << " with a non-integer literal because we cannot resolve the output type");
      }
      output_type = elems[0];
    } else {
      auto adj_index = getAdjTupleIndex(
          loc, tuple_typ, idx->toInt(), /*allow_out_of_bounds=*/false);
      output_type = elems[adj_index];
    }
    return graph
        ->insertNode(graph->createTupleIndex(tuple_val, idx_val, output_type))
        ->output();
  }

  int64_t getSliceInd(Value* idx_val, const SourceRange& loc) {
    auto ivalue = toIValue(idx_val);
    if (ivalue && ivalue->isInt()) {
      return ivalue->to<int64_t>();
    } else {
      throw(
          ErrorReport(loc) << "tuple slice indices must be integer constants");
    }
  }

  Value* emitTupleSlice(
      const SourceRange& loc,
      const NamedValue& tuple_val,
      const std::vector<std::optional<NamedValue>>& tuple_args) {
    auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
    auto tuple_len = tuple_type->elements().size();
    auto beg_val = tuple_args[0];
    auto end_val = tuple_args[1];
    auto step = tuple_args[2];

    int64_t step_size = 1;
    if (step) {
      auto val = toIValue(step->value(*graph));
      TORCH_CHECK(val->isInt(), "Step size should always be an integer");
      step_size = val->to<int64_t>();
    }

    int64_t beg = std::numeric_limits<int64_t>::max();
    if (beg_val) {
      beg = getAdjTupleIndex(
          loc,
          tuple_type,
          getSliceInd(beg_val->value(*graph), loc),
          /*allow_out_of_bounds=*/true);
    }

    int64_t end = std::numeric_limits<int64_t>::max();
    if (end_val) {
      end = getAdjTupleIndex(
          loc,
          tuple_type,
          getSliceInd(end_val->value(*graph), loc),
          /*allow_out_of_bounds=*/true);
    }

    int64_t num_values
        = slice_indices_adjust(
            static_cast<int64_t>(tuple_len), &beg, &end, step_size);

    return graph
        ->insertNode(graph->createTupleSlice(
            tuple_val.value(*graph), beg, step_size, num_values))
        ->output();
  }

  std::shared_ptr<SugaredValue> emitSubscript(
      const Subscript& subscript,
      TypePtr type_hint = nullptr) {
    const SugaredValuePtr sv = emitSugaredExpr(subscript.value(), 1);
    const List<Expr>& subscript_exprs = subscript.subscript_exprs();
    const SourceRange& range = subscript.range();
    const SourceRange& val_range = subscript.value().range();
    if (subscript_exprs.size() != 1) {
      return std::make_shared<SimpleValue>(emitMultidimSlicing(
          range, sv->asValue(val_range, method), subscript_exprs));
    }
    if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
      // TODO @wconstab refactor using Symbol instead of string compare
      if (sv->kind() == "module") {
        // Slicing isn't currently implemented for Sequential/ModuleList,
        // but is implemented for Tuples, so a quick workaround is to
        // convert to a tuple of Modules for slicing support.
        auto s_tuple_val =
            sv->asTupleValue(val_range, method)->asValue(val_range, method);
        const SliceExpr& slice = SliceExpr(subscript_exprs[0]);
        std::vector<std::optional<NamedValue>> tuple_args;
        tuple_args.reserve(3);
        if (slice.start().present()) {
          auto begin = NamedValue(
              val_range, "begin", emitExpr(Expr(slice.start().get())));
          tuple_args.emplace_back(begin);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.end().present()) {
          auto end =
              NamedValue(val_range, "end", emitExpr(Expr(slice.end().get())));
          tuple_args.emplace_back(end);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }

        if (slice.step().present()) {
          auto step =
              NamedValue(val_range, "step", emitExpr(Expr(slice.step().get())));
          tuple_args.emplace_back(step);
        } else {
          tuple_args.emplace_back(std::nullopt);
        }
        auto tupleSliceValue =
            emitTupleSlice(val_range, s_tuple_val, tuple_args);
        return std::make_shared<SimpleValue>(tupleSliceValue);
      } else {
        return std::make_shared<SimpleValue>(emitBasicSlice(
            range, sv->asValue(val_range, method), subscript_exprs));
      }
    } else {
      AT_ASSERT(subscript_exprs.size() == 1);
      Value* sliceable = sv->asValue(val_range, method);

      // In case of subscript expression being a Python Slice object.
      auto subscript_sv = emitSugaredExpr(subscript_exprs[0], 1);
      if (const auto slice_value =
              dynamic_cast<SliceValue*>(subscript_sv.get())) {
        Value* dim = nullptr;
        // aten::slice.tensor needs an additional `dim` input.
        if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
          dim = method.graph()->insertConstant(0, val_range);
        }

        Value* sliced = emitSliceOp(
            val_range,
            sliceable,
            dim,
            slice_value->start(),
            slice_value->stop(),
            slice_value->step());
        return std::make_shared<SimpleValue>(sliced);
      }

      // If the subscript is not a slice object, then it must be convertible
      // to a normal value.
      // Desugars gather syntactic sugar foo[i]
      Value* idx = subscript_sv->asValue(val_range, method);
      if (sliceable->type()->cast<TupleType>()) {
        return std::make_shared<SimpleValue>(
            emitTupleIndex(range, sv->asValue(val_range, method), idx));
      } else if (sliceable->type()->isSubtypeOf(*TensorType::get())) {
        return std::make_shared<SimpleValue>(
            emitMultidimSlicing(range, sliceable, subscript_exprs));
      } else {
        return sv->getitem(range, method, idx, std::move(type_hint));
      }
    }
  }
};

struct FunctionResolver : public Resolver {
  explicit FunctionResolver(
      Resolver* otherResolver,
      const std::unordered_map<std::string, Function*>& functionTable)
      : otherResolver_(otherResolver), functionTable_(functionTable) {}

  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction& m,
      const SourceRange& loc) override {
    auto it = functionTable_.find(name);
    if (it != functionTable_.end()) {
      return std::make_shared<FunctionValue>(it->second);
    }
    return otherResolver_->resolveValue(name, m, loc);
  }

  TypePtr resolveType(const std::string& name, const SourceRange& loc)
      override {
    return otherResolver_->resolveType(name, loc);
  }

 private:
  Resolver* otherResolver_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::unordered_map<std::string, Function*>& functionTable_;
};

CompilationUnit::CompilationUnit(const std::string& source)
    : CompilationUnit() {
  // calls
  // the define with the native resolver to generate the graph for functions
  define(std::nullopt, source, nativeResolver(), nullptr);
}

// This pair represents a pair of functions (getter and setter) obtained from
// compiling a Property.
struct CompilationUnit::PropertyPair
    : public std::pair<std::unique_ptr<Function>, std::unique_ptr<Function>> {
  PropertyPair(
      std::unique_ptr<Function> getter,
      std::unique_ptr<Function> setter) {
    TORCH_INTERNAL_ASSERT(getter, "Property pair must have defined getter");
    this->first = std::move(getter);
    this->second = std::move(setter);
  }

  std::unique_ptr<Function>& getGetter() {
    return this->first;
  }

  std::unique_ptr<Function>& getSetter() {
    return this->second;
  }
};

CompilationUnit::PropertyPair CompilationUnit::define_property(
    const std::optional<c10::QualifiedName>& prefix,
    const Property& prop,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle) const {
  // self must be defined because properties are features of classes and
  // modules.
  TORCH_INTERNAL_ASSERT(self);

  // Compile the getter function.
  std::unique_ptr<Function> getter_fn = define(
      prefix, prop.getter(), resolver, self, function_table, shouldMangle);

  // Compile the setter function if it exists.
  std::unique_ptr<Function> setter_fn = nullptr;
  if (prop.setter().present()) {
    setter_fn = define(
        prefix,
        prop.setter().get(),
        resolver,
        self,
        function_table,
        shouldMangle);
  }

  // Add the property to the class type definition.
  self->getClassType()->addProperty(
      prop.name().name(), getter_fn.get(), setter_fn.get());

  return PropertyPair(std::move(getter_fn), std::move(setter_fn));
}

std::unique_ptr<Function> CompilationUnit::define(
    const std::optional<c10::QualifiedName>& prefix,
    const Def& def,
    const ResolverPtr& resolver,
    const Self* self,
    const std::unordered_map<std::string, Function*>& function_table,
    bool shouldMangle,
    CompilationUnit::FunctionType type,
    std::optional<size_t> operator_set_version) const {
  TORCH_INTERNAL_ASSERT(resolver);
  auto _resolver = resolver;
  if (!self) {
    // If self is defined, then these are methods and do not go into the
    // global namespace. Otherwise, they get defined together, so we add them
    // to the function table so that the functions can see each other.
    _resolver =
        std::make_shared<FunctionResolver>(resolver.get(), function_table);
  }
  auto creator = [def, _resolver, self](GraphFunction& method) {
    // Store the function name so that it can be referenced if there is an error
    // while compiling this function
    std::string call_name = method.qualname().name();
    if (self) {
      auto atoms = method.qualname().atoms();
      // There should be at least a ClassName.method_name
      TORCH_INTERNAL_ASSERT(atoms.size() >= 2);
      call_name = atoms.at(atoms.size() - 2) + "." + atoms.at(atoms.size() - 1);
    }
    ErrorReport::CallStack call(call_name, def.range());
    to_ir(def, _resolver, self, method);
  };
  auto name = prefix ? QualifiedName(*prefix, def.name().name())
                     : QualifiedName(def.name().name());
  if (shouldMangle) {
    // If `shouldMangle` is set, we should generate a unique name for this
    // function if there is already an existing one.
    if (find_function(name)) {
      name = mangle(name);
    }
  }

  auto graph = std::make_shared<Graph>();
  graph->set_op_version(operator_set_version);

  auto fn = std::make_unique<GraphFunction>(std::move(name), graph, creator);
  if (self) {
    // Register this as a method on `self`'s type
    if (type == CompilationUnit::FunctionType::Hook) {
      self->getClassType()->addForwardHook(fn.get());
    } else if (type == CompilationUnit::FunctionType::PreHook) {
      self->getClassType()->addForwardPreHook(fn.get());
    } else {
      self->getClassType()->addMethod(fn.get());
    }
  }
  return fn;
}

std::vector<Function*> CompilationUnit::define(
    const std::optional<c10::QualifiedName>& prefix,
    const std::vector<Property>& properties,
    const std::vector<ResolverPtr>& propResolvers,
    const std::vector<Def>& definitions,
    const std::vector<ResolverPtr>& defResolvers,
    const Self* self,
    bool shouldMangle,
    std::optional<size_t> operator_set_version) {
  TORCH_INTERNAL_ASSERT(definitions.size() == defResolvers.size());
  TORCH_INTERNAL_ASSERT(properties.size() == propResolvers.size());
  std::vector<Function*> functions;
  std::unordered_map<std::string, Function*> function_table;

  // Records fn in function_table and functions, and registers it with
  // register_function. This is done several times below, so this lambda helps
  // avoid repeating code.
  auto record_function = [&](std::unique_ptr fn) {
    function_table[fn->name()] = fn.get();
    functions.emplace_back(fn.get());
    this->register_function(std::move(fn));
  };

  for (const auto i : c10::irange(properties.size())) {
    PropertyPair property_fns = define_property(
        prefix,
        properties[i],
        propResolvers[i],
        self,
        function_table,
        shouldMangle);

    auto& getter_fn = property_fns.getGetter();
    auto& setter_fn = property_fns.getSetter();

    record_function(std::move(getter_fn));

    if (setter_fn) {
      record_function(std::move(setter_fn));
    }
  }

  for (const auto i : c10::irange(definitions.size())) {
    auto fn = define(
        prefix,
        definitions[i],
        defResolvers[i],
        self,
        function_table,
        shouldMangle,
        CompilationUnit::FunctionType::Method,
        operator_set_version);

    record_function(std::move(fn));
  }

  // We need to compile `__init__` first, since it can determine what attributes
  // are available to other methods. So reorder the definitions accordingly.
  for (auto& kv : function_table) {
    if (kv.first == "__init__") {
      kv.second->ensure_defined();
    }
  }

  for (Function* function : functions) {
    function->ensure_defined();
  }

  return functions;
}

void CompilationUnit::define_hooks(
    const std::optional& prefix,
    const std::vector& hookDefs,
    const std::vector& hookResolvers,
    const std::vector& preHookDefs,
    const std::vector& preHookResolvers,
    const Self* self,
    bool shouldMangle) {
  TORCH_INTERNAL_ASSERT(hookDefs.size() == hookResolvers.size());
  TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookResolvers.size());
  std::vector functions;
  std::unordered_map function_table;

  // check hook for name collisions and redefinition
  auto check_collisions = [&](const Def& hook) -> Function* {
    auto name = prefix ? QualifiedName(*prefix, hook.name().name()).name()
                       : QualifiedName(hook.name().name()).name();
    // check if hook is already defined for this module
    auto found_hook = function_table.find(name);
    auto existing_hook = found_hook != function_table.end() ?
        found_hook->second : nullptr;
    // check if hook name is already defined on module as method
    if (existing_hook == nullptr) {
      TORCH_CHECK(
          self->getClassType()->findMethod(name) == nullptr &&
              self->getClassType()->findHook(name) == nullptr,
          "Can't define hook: ",
          name,
          " on class: ",
          self->getClassType()->repr_str(),
          " because a method or hook with that name already exists.");
    }
    return existing_hook;
  };

  // build_schema for checking
  auto build_schema = [&](const Def& hook_def,
                          const ResolverPtr& hook_res) -> FunctionSchema {
    ScriptTypeParser typeParser(hook_res);
    FunctionSchema schema =
        typeParser.parseSchemaFromDef(hook_def, true /* skip_self*/);
    // need to add self as the first because we skipped it
    std::vector arguments;
    arguments.emplace_back(
        hook_def.decl().params()[0].ident().name(), self->getClassType());
    arguments.insert(
        arguments.end(), schema.arguments().begin(), schema.arguments().end());
    return schema.cloneWithArguments(arguments);
  };

  // define hooks
  for (const auto i : c10::irange(hookDefs.size())) {
    // check to see if already defined this hook
    auto existing_fn = check_collisions(hookDefs[i]);
    if (existing_fn != nullptr) {
      // add it to class type again so it's called
      self->getClassType()->addForwardHook(existing_fn);
      continue;
    }
    // define hook
    auto fn = define(
        prefix,
        hookDefs[i],
        hookResolvers[i],
        self,
        function_table,
        shouldMangle,
        CompilationUnit::FunctionType::Hook);

    function_table[fn->name()] = fn.get();
    functions.emplace_back(fn.get());
    this->register_function(std::move(fn));
    self->getClassType()->checkForwardHookSchema(
        i, build_schema(hookDefs[i], hookResolvers[i]));
    functions.back()->ensure_defined();
  }

  // define pre_hooks
  for (const auto i : c10::irange(preHookDefs.size())) {
    // check to see if already defined this hook
    auto existing_fn = check_collisions(preHookDefs[i]);
    if (existing_fn != nullptr) {
      // add it to class type again so it's called
      self->getClassType()->addForwardPreHook(existing_fn);
      continue;
    }
    // define pre_hook
    auto fn =
        define(
            prefix,
            preHookDefs[i],
            preHookResolvers[i],
            self,
            function_table,
            shouldMangle,
            CompilationUnit::FunctionType::PreHook);

    function_table[fn->name()] = fn.get();
    functions.emplace_back(fn.get());
    this->register_function(std::move(fn));
    self->getClassType()->checkForwardPreHookSchema(
        i, build_schema(preHookDefs[i], preHookResolvers[i]));
    functions.back()->ensure_defined();
  }
}

std::vector CompilationUnit::define(
    const std::optional& prefix,
    const std::string& source,
    const ResolverPtr& resolver,
    const Self* self) {
  Parser p(std::make_shared(source, "", 1));
  std::vector definitions;
  std::vector resolvers;
  while (p.lexer().cur().kind != TK_EOF) {
    auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
    definitions.push_back(def);
    resolvers.push_back(resolver);
  }
  return define(
      prefix,
      /*properties=*/{},
      /*propResolvers=*/{},
      definitions,
      resolvers,
      self);
}

static void eraseListLiterals(std::shared_ptr& graph) {
  DepthFirstGraphNodeIterator it(graph);

  for (auto next_node = it.next(); next_node != nullptr;) {
    Node* node = next_node;
    next_node = it.next();

    if (node->kind() == prim::EmptyListLiteral) {
      if (node->hasUses()) {
        TORCH_INTERNAL_ASSERT(
            node->output()->type()->isSubtypeOf(ListType::ofTensors()));

        auto li = graph->createList(TensorType::get(), {});
        li->insertBefore(node);
        node->replaceAllUsesWith(li);
      }
      node->destroy();
    }
  }
}

void runCleanupPasses(std::shared_ptr& to_clean) {
  liftClosures(to_clean);
  inlineForkedClosures(to_clean);

  if (getInlineEverythingMode()) {
    Inline(*to_clean);
  }

  // these exist temporarily in initial compilation
  eraseListLiterals(to_clean);

  // remove any uses of tuples that we inserted that are not needed
  LowerSimpleTuples(to_clean);

  // full constant propagation runs ops with mutable inputs if it can
  // prove that the inputs are not mutated anywhere in the graph.
  // if a mutating node is removed in the graph (e.g.
  // constant prop inlined a constant if) then the next time constant prop is
  // run it might be able to run nodes it was not able to previously, and the
  // graph may change (jitter). So we run only constant prop with immutable
  // types here, because successive runs of immutable constant prop do not
  // change the graph.
  ConstantPropagationImmutableTypes(to_clean);

  // Constant Pooling pass must be after ConstantPropagation, which can create
  // new constants that need to be pooled.
  ConstantPooling(to_clean);

  // For jitter
  CanonicalizeOutputs(to_clean);

  // Annotate aten::warns so that each has its unique ID. This enables us to
  // mimic Python behavior of emitting each warning only once.
  AnnotateWarns(to_clean);
}

// We consider _N, where N is a number, to be a non-meaningful name
// and do not record it as a unique name. This allows python printing to
// be able to export and import more consistently named graphs
bool meaningfulName(const std::string& name) {
  if (name.empty())
    return false;
  if (name[0] == '$')
    return false;
  if (name[0] != '_')
    return true;
  for (const auto i : c10::irange(1, name.size())) {
    if (!isdigit(name[i]))
      return true;
  }
  return false;
}

void CompilationUnit::define_interface(
    const c10::QualifiedName& qualifiedName,
    const ClassDef& classDef,
    ResolverPtr rcb,
    bool is_module) {
  ScriptTypeParser typeParser(std::move(rcb));
  InterfaceTypePtr iface =
      InterfaceType::create(c10::QualifiedName(qualifiedName), is_module);
  for (const Stmt& stmt : classDef.body()) {
    if (stmt.kind() != TK_DEF) {
      throw(
          ErrorReport(stmt)
          << "interface declarations can only contain method definitions");
    }
    auto method_def = Def(stmt);
    if (!method_def.decl().return_type().present()) {
      throw(
          ErrorReport(method_def)
          << "interface declarations must have a return type annotated.");
    }
    FunctionSchema schema =
        typeParser.parseSchemaFromDef(method_def, /* skip_self*/ true);
    // need to add self as the first because we skipped it
    std::vector arguments;
    arguments.emplace_back(method_def.decl().params()[0].ident().name(), iface);
    arguments.insert(
        arguments.end(), schema.arguments().begin(), schema.arguments().end());
    iface->addMethod(schema.cloneWithArguments(std::move(arguments)));
    // we need to make sure everything but the last element is just string
    // literals (aka comments) unless there is "pass" in between
    auto stmts_size = method_def.statements().size();
    for (size_t i = 0; i < stmts_size - 1; i++) {
      auto cur_statement = method_def.statements()[i];
      if (cur_statement.kind() == TK_EXPR_STMT) {
        auto expr = ExprStmt(cur_statement).expr();
        if (expr.kind() != TK_STRINGLITERAL) {
          throw(
              ErrorReport(method_def.range())
              << "interfaces declarations should only contain a single 'pass' statement.");
        }
      }
      // if we see a "pass", we just stop there
      if (cur_statement.kind() == TK_PASS) {
        this->register_type(iface);
        return;
      }
    }

    if (method_def.statements()[stmts_size - 1].kind() != TK_PASS) {
      throw(
          ErrorReport(method_def.range())
          << "interfaces declarations should contain 'pass' statement.");
    }
  }
  this->register_type(iface);
}

} // namespace torch::jit