1 #include <pybind11/detail/common.h>
2 #include <pybind11/pytypes.h>
3 #include <torch/csrc/jit/api/object.h>
4 #include <torch/csrc/jit/python/script_init.h>
5 #include <torch/csrc/utils/pybind.h>
6 
7 #include <caffe2/serialize/versions.h>
8 #include <torch/csrc/Device.h>
9 #include <torch/csrc/DynamicTypes.h>
10 #include <torch/csrc/jit/api/module.h>
11 #include <torch/csrc/jit/frontend/ir_emitter.h>
12 #include <torch/csrc/jit/frontend/sugared_value.h>
13 #include <torch/csrc/jit/mobile/code.h>
14 #include <torch/csrc/jit/mobile/compatibility/backport.h>
15 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
16 #include <torch/csrc/jit/mobile/file_format.h>
17 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
18 #include <torch/csrc/jit/mobile/import.h>
19 #include <torch/csrc/jit/mobile/module.h>
20 #include <torch/csrc/jit/mobile/quantization.h>
21 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
22 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
23 #include <torch/csrc/jit/operator_upgraders/utils.h>
24 #include <torch/csrc/jit/operator_upgraders/version_map.h>
25 #include <torch/csrc/jit/python/module_python.h>
26 #include <torch/csrc/jit/python/python_ivalue.h>
27 #include <torch/csrc/jit/python/python_sugared_value.h>
28 #include <torch/csrc/jit/serialization/export_bytecode.h>
29 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
30 #include <torch/csrc/jit/serialization/import.h>
31 #include <torch/csrc/jit/testing/file_check.h>
32 
33 #include <c10/util/Exception.h>
34 #include <c10/util/intrusive_ptr.h>
35 #include <c10/util/irange.h>
36 #include <torch/csrc/jit/frontend/parser.h>
37 #include <torch/csrc/jit/frontend/tracer.h>
38 #include <torch/csrc/jit/ir/constants.h>
39 #include <torch/csrc/jit/ir/graph_utils.h>
40 #include <torch/csrc/jit/ir/irparser.h>
41 #include <torch/csrc/jit/passes/inliner.h>
42 #include <torch/csrc/jit/passes/shape_analysis.h>
43 #include <torch/csrc/jit/python/pybind_utils.h>
44 #include <torch/csrc/jit/python/python_dict.h>
45 #include <torch/csrc/jit/python/python_list.h>
46 #include <torch/csrc/jit/python/python_tracer.h>
47 #include <torch/csrc/jit/runtime/graph_executor.h>
48 #include <torch/csrc/jit/runtime/instruction.h>
49 #include <torch/csrc/jit/runtime/interpreter.h>
50 #include <torch/csrc/jit/runtime/logging.h>
51 #include <torch/csrc/jit/serialization/export_bytecode.h>
52 #include <torch/csrc/jit/serialization/import_source.h>
53 #include <torch/csrc/jit/serialization/pickle.h>
54 #include <torch/csrc/jit/serialization/python_print.h>
55 #include <torch/csrc/jit/testing/hooks_for_testing.h>
56 
57 #include <torch/csrc/api/include/torch/ordered_dict.h>
58 
59 #include <ATen/ATen.h>
60 #include <ATen/core/function_schema.h>
61 #include <ATen/core/ivalue.h>
62 #include <ATen/core/qualified_name.h>
63 
64 #include <pybind11/functional.h>
65 #include <pybind11/pybind11.h>
66 #include <pybind11/stl.h>
67 #include <pybind11/stl_bind.h>
68 #include <torch/csrc/jit/mobile/train/export_data.h>
#include <cstddef>
#include <memory>
#include <new>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
76 
77 #include <fmt/format.h>
78 
79 namespace torch::jit {
80 
81 using ::c10::Argument;
82 using ::c10::FunctionSchema;
83 
84 using FunctionDefaults = std::unordered_map<std::string, py::object>;
85 using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;
86 
87 namespace {
88 
89 // A resolver that will inspect the outer Python scope to find `name`.
90 struct PythonResolver : public Resolver {
PythonResolvertorch::jit::__anond03ab9480111::PythonResolver91   explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {}
92 
93   /**
94    * While compiling classes, the class type we're compiling will not be
95    * available in Python, since we haven't fowner_ defining the class yet. So
96    * in order to make the class type available to its own methods, we need to
97    * explicitly resolve it.
98    *
99    * @param rcb Python function to resolve a name to its Python object in the
100    *            enclosing scope
101    * @param classname The unqualified classname of the class currently being
102    *                  compiled.
103    * @param classType The class's type.
104    */
PythonResolvertorch::jit::__anond03ab9480111::PythonResolver105   explicit PythonResolver(
106       ResolutionCallback rcb,
107       std::string classname,
108       ClassTypePtr classType)
109       : rcb_(std::move(rcb)),
110         classname_(std::move(classname)),
111         classType_(std::move(classType)) {}
112 
resolveValuetorch::jit::__anond03ab9480111::PythonResolver113   std::shared_ptr<SugaredValue> resolveValue(
114       const std::string& name,
115       GraphFunction& m,
116       const SourceRange& loc) override {
117     pybind11::gil_scoped_acquire ag;
118     py::object obj = rcb_(name);
119     if (obj.is_none()) {
120       return nullptr;
121     }
122     return toSugaredValue(obj, m, loc);
123   }
124 
isNamedTupleClasstorch::jit::__anond03ab9480111::PythonResolver125   static bool isNamedTupleClass(py::object obj) {
126     auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type);
127     return PyObject_IsSubclass(obj.ptr(), tuple_type) &&
128         py::hasattr(obj, "_fields");
129   }
130 
resolveTypeFromObjecttorch::jit::__anond03ab9480111::PythonResolver131   TypePtr resolveTypeFromObject(const py::object& obj, const SourceRange& loc) {
132     if (py::isinstance<ScriptClass>(obj)) {
133       auto script_class = py::cast<ScriptClass>(obj);
134       return script_class.class_type_.type_;
135     }
136 
137     py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj);
138     if (!py::cast<bool>(isClass)) {
139       return nullptr;
140     }
141 
142     if (isNamedTupleClass(obj)) {
143       return registerNamedTuple(obj, loc, rcb_);
144     }
145 
146     auto qualifiedName = c10::QualifiedName(
147         py::cast<std::string>(py::module::import("torch._jit_internal")
148                                   .attr("_qualified_name")(obj)));
149 
150     return get_python_cu()->get_type(qualifiedName);
151   }
152 
resolveTypetorch::jit::__anond03ab9480111::PythonResolver153   TypePtr resolveType(const std::string& name, const SourceRange& loc)
154       override {
155     if (classType_ && name == classname_) {
156       return classType_;
157     }
158     pybind11::gil_scoped_acquire ag;
159     py::object obj = rcb_(name);
160     if (obj.is_none()) {
161       return nullptr;
162     }
163 
164     auto annotation_type =
165         py::module::import("torch.jit.annotations")
166             .attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
167     if (!annotation_type.is_none()) {
168       return py::cast<TypePtr>(annotation_type);
169     }
170     return resolveTypeFromObject(obj, loc);
171   }
172 
173  private:
174   ResolutionCallback rcb_;
175   std::string classname_;
176   ClassTypePtr classType_;
177 };
178 
pythonResolver(const ResolutionCallback & rcb)179 std::shared_ptr<PythonResolver> pythonResolver(const ResolutionCallback& rcb) {
180   return std::make_shared<PythonResolver>(rcb);
181 }
pythonResolver(const ResolutionCallback & rcb,std::string classname,ClassTypePtr classType)182 std::shared_ptr<PythonResolver> pythonResolver(
183     const ResolutionCallback& rcb,
184     std::string classname,
185     ClassTypePtr classType) {
186   return std::make_shared<PythonResolver>(
187       rcb, std::move(classname), std::move(classType));
188 }
189 
checkOverloadDecl(const Decl & new_decl,const Decl & old_decl)190 void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
191   const auto& new_params = new_decl.params();
192   const auto& old_params = old_decl.params();
193 
194   // TODO. same number of parameters not strictly necessary.
195   TORCH_INTERNAL_ASSERT(
196       new_params.size() == old_params.size(),
197       "Overload must have same number of parameters\n",
198       new_decl.range(),
199       old_decl.range());
200   for (const auto i : c10::irange(new_decl.params().size())) {
201     TORCH_INTERNAL_ASSERT(
202         new_params[i].ident().name() == old_params[i].ident().name(),
203         "Overload parameters must have the same names\n",
204         new_params[i].ident(),
205         old_params[i].ident());
206   }
207 }
208 
tryCalculateDefaultParam(const Argument & arg,const py::object & def_value)209 std::optional<IValue> tryCalculateDefaultParam(
210     const Argument& arg,
211     const py::object& def_value) {
212   auto n = arg.N();
213   auto list_type = arg.type()->cast<ListType>();
214   try {
215     if (n && *n > 0 && list_type) {
216       // BroadcastingList, allow default values T for arg types List[T]
217       return toIValue(def_value, list_type->getElementType());
218     } else {
219       return toIValue(def_value, arg.type());
220     }
221   } catch (...) {
222     return std::nullopt;
223   }
224 }
225 
226 // An overloaded function may have a default that does not subtype all overloads
227 // @overload
228 // def foo(x: str)
229 // def foo(x=1)
calcOverloadedFunctionDefaults(const FunctionSchema & schema,const FunctionDefaults & defaults)230 FunctionDefaults calcOverloadedFunctionDefaults(
231     const FunctionSchema& schema,
232     const FunctionDefaults& defaults) {
233   FunctionDefaults updated_defaults;
234   for (const auto& arg : schema.arguments()) {
235     const std::string& arg_name = arg.name();
236     auto value = defaults.find(arg_name);
237     if (value == defaults.end()) {
238       continue;
239     }
240     auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
241     if (maybe_ivalue) {
242       updated_defaults[arg_name] = value->second;
243     }
244   }
245   return updated_defaults;
246 }
247 
248 } // namespace
249 
checkMutableFunctionDefault(const py::object & def_arg)250 bool checkMutableFunctionDefault(const py::object& def_arg) {
251   if (py::isinstance<py::list>(def_arg) || py::isinstance<py::dict>(def_arg)) {
252     return true;
253   }
254   if (py::isinstance<py::tuple>(def_arg)) {
255     auto pytuple = def_arg.cast<py::tuple>();
256     for (py::handle t : pytuple) {
257       py::object obj = py::reinterpret_borrow<py::object>(t);
258       if (checkMutableFunctionDefault(obj)) {
259         return true;
260       }
261     }
262   }
263   return false;
264 }
265 
checkMutableFunctionDefault(const SourceRange & range,const Argument & arg,const py::object & def_arg)266 void checkMutableFunctionDefault(
267     const SourceRange& range,
268     const Argument& arg,
269     const py::object& def_arg) {
270   if (checkMutableFunctionDefault(def_arg) || arg.type()->cast<ClassType>()) {
271     throw(
272         ErrorReport(range)
273         << "Mutable default parameters are not supported because Python binds them to the function"
274         << " and they persist across function calls.\n As a workaround, make the default None and instantiate"
275         << " the default parameter within the body of the function. Found "
276         << def_arg.get_type() << " on parameter " << arg.name());
277   }
278 }
279 
// Returns a copy of `schema` in which:
//  - the name is replaced by `new_name` when one is given, and
//  - every argument that has an entry in `default_args` receives that entry,
//    converted to an IValue, as its default value.
//
// Throws an ErrorReport located at `range` if a default is mutable (see
// checkMutableFunctionDefault) or cannot be converted to the argument's type.
FunctionSchema getSchemaWithNameAndDefaults(
    const SourceRange& range,
    const FunctionSchema& schema,
    const std::optional<std::string>& new_name,
    const FunctionDefaults& default_args) {
  std::vector<Argument> new_args;
  new_args.reserve(schema.arguments().size());
  for (const auto& arg : schema.arguments()) {
    const auto it = default_args.find(arg.name());
    if (it == default_args.end()) {
      // No default supplied for this argument: keep it untouched.
      new_args.push_back(arg);
      continue;
    }
    checkMutableFunctionDefault(range, arg, it->second);
    std::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
    if (!value) {
      ErrorReport error(range);
      error << "Expected a default value of type " << arg.type()->repr_str()
            << " on parameter \"" << arg.name() << "\".";
      if (arg.is_inferred_type()) {
        // Leading space separates this sentence from the previous one.
        error << " Because \"" << arg.name()
              << "\" was not annotated with an explicit type "
              << "it is assumed to be type 'Tensor'.";
      }
      throw ErrorReport(error);
    }
    new_args.emplace_back(
        arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only());
  }
  return FunctionSchema(
      new_name.value_or(schema.name()),
      schema.overload_name(),
      std::move(new_args),
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
}
316 
// Builds the declaration used to compile one overload: the overload's own
// parameters, followed by any extra trailing parameters that only the
// implementation declares.
//
// Requirements (each violation throws):
//  - the overload has no more parameters than the implementation,
//  - the shared leading parameters have identical names,
//  - every extra implementation parameter has an entry in `defaults` and an
//    explicit type annotation.
// The result keeps the overload's source range and return type.
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
    const Decl& overload_decl,
    const Decl& impl_decl,
    const FunctionDefaults& defaults) {
  std::vector<Param> adjusted_params;
  const auto& overload_params = overload_decl.params();
  const auto& impl_params = impl_decl.params();

  // following PEP specification that the following should work:
  // @overload
  // def mouse_event(x1: int, y1: int) -> ClickEvent: ...
  // ...
  // def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2:
  // Optional[int] = None)
  TORCH_CHECK(
      overload_params.size() <= impl_params.size(),
      "Overload should not have more parameters than implementation function",
      overload_decl.range(),
      impl_decl.range());

  adjusted_params.reserve(impl_params.size());

  // Shared leading parameters: taken from the overload declaration.
  for (const auto i : c10::irange(overload_params.size())) {
    auto overload_name = overload_params[i].ident().name();
    auto impl_name = impl_params[i].ident().name();
    if (overload_name != impl_name) {
      throw(
          ErrorReport(overload_decl.range())
          << "Overload parameters must have the same names. "
          << "Found " << overload_name << " and " << impl_name
          << " on argument " << i);
    }
    adjusted_params.push_back(overload_params[i]);
  }

  // Extra trailing parameters: taken from the implementation declaration.
  for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
    if (!defaults.count(impl_params[i].ident().name())) {
      // Trailing space keeps the parameter name from being glued to
      // "argument" in the rendered message.
      throw(
          ErrorReport(impl_decl.range())
          << "Expected to find default parameter on argument "
          << impl_params[i].ident().name()
          << " because it is not defined on the overloaded declaration");
    }
    if (!impl_params[i].type().present()) {
      throw(
          ErrorReport(impl_decl.range())
          << "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
          << " Did not find type for param " << impl_params[i].ident().name());
    }
    adjusted_params.push_back(impl_params[i]);
  }
  return Decl::create(
      overload_decl.range(),
      List<Param>::create(overload_decl.range(), adjusted_params),
      overload_decl.return_type());
}
370 
script_compile_overloaded_function(const c10::QualifiedName & name,const Decl & overload_decl,const Def & implementation_def,const ResolutionCallback & rcb,const FunctionDefaults & implementation_defaults,const py::object & signature)371 static StrongFunctionPtr script_compile_overloaded_function(
372     const c10::QualifiedName& name,
373     const Decl& overload_decl,
374     const Def& implementation_def,
375     const ResolutionCallback& rcb,
376     const FunctionDefaults& implementation_defaults,
377     const py::object& signature) {
378   if (signature.is_none()) {
379     throw(
380         ErrorReport(overload_decl.range())
381         << "Must explicitly add type annotations to overloaded functions");
382   }
383 
384   auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl(
385       overload_decl, implementation_def.decl(), implementation_defaults);
386   auto new_def = implementation_def.withDecl(adjusted_decl);
387   auto cu = get_python_cu();
388   auto defined_functions = cu->define(
389       QualifiedName(name.prefix()),
390       /*properties=*/{},
391       /*propResolvers=*/{},
392       {new_def},
393       {pythonResolver(rcb)},
394       nullptr,
395       true);
396   TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
397   auto& defined = defined_functions[0];
398   FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults(
399       defined->getSchema(), implementation_defaults);
400   defined->setSchema(getSchemaWithNameAndDefaults(
401       new_def.range(),
402       defined->getSchema(),
403       new_def.name().name(),
404       updated_defaults));
405   StrongFunctionPtr ret(std::move(cu), defined);
406   didFinishEmitFunction(ret);
407   return ret;
408 }
409 
script_compile_function(const c10::QualifiedName & name,const Def & def,const FunctionDefaults & defaults,const ResolutionCallback & rcb)410 static StrongFunctionPtr script_compile_function(
411     const c10::QualifiedName& name,
412     const Def& def,
413     const FunctionDefaults& defaults,
414     const ResolutionCallback& rcb) {
415   auto cu = get_python_cu();
416   auto defined_functions = cu->define(
417       QualifiedName(name.prefix()),
418       /*properties=*/{},
419       /*propResolvers=*/{},
420       {def},
421       {pythonResolver(rcb)},
422       nullptr,
423       true);
424   TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
425   auto& defined = defined_functions[0];
426   defined->setSchema(getSchemaWithNameAndDefaults(
427       def.range(), defined->getSchema(), def.name().name(), defaults));
428   StrongFunctionPtr ret(std::move(cu), defined);
429   didFinishEmitFunction(ret);
430   return ret;
431 }
432 
433 struct VISIBILITY_HIDDEN ModuleSelf : public Self {
ModuleSelftorch::jit::ModuleSelf434   ModuleSelf(std::shared_ptr<ConcreteModuleType> concreteType)
435       : Self(), concreteType_(std::move(concreteType)) {}
436 
makeSugaredtorch::jit::ModuleSelf437   std::shared_ptr<SugaredValue> makeSugared(Value* v) const override {
438     v->setType(getClassType());
439     return std::make_shared<ModuleValue>(v, concreteType_);
440   }
441 
getClassTypetorch::jit::ModuleSelf442   ClassTypePtr getClassType() const override {
443     return concreteType_->getJitType()->expect<ClassType>();
444   }
445 
446  private:
447   std::shared_ptr<ConcreteModuleType> concreteType_;
448 };
449 
_propagate_shapes(Graph & graph,std::vector<at::Tensor> inputs,bool with_grad=false)450 static std::shared_ptr<Graph> _propagate_shapes(
451     Graph& graph,
452     std::vector<at::Tensor> inputs,
453     bool with_grad = false) {
454   Stack stack(inputs.begin(), inputs.end());
455   auto retval = graph.copy();
456   setInputTensorTypes(*retval, stack, /*complete=*/false);
457   PropagateInputShapes(retval);
458   return retval;
459 }
460 
_propagate_and_assign_input_shapes(Graph & graph,const std::vector<at::Tensor> & inputs,const std::vector<int> & param_count_list,bool with_grad=false,bool propagate=true)461 static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
462     Graph& graph,
463     const std::vector<at::Tensor>& inputs,
464     const std::vector<int>& param_count_list,
465     bool with_grad = false,
466     bool propagate = true) {
467   auto retval = graph.copy();
468   setInputTensorTypes(
469       *retval, fmap<IValue>(inputs), /*complete=*/true, param_count_list);
470   if (propagate) {
471     PropagateInputShapes(retval);
472   }
473   return retval;
474 }
475 
addFunctionToModule(Module & module,const StrongFunctionPtr & func)476 void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
477   // Make a graph with a fake self argument
478   auto graph = toGraphFunction(*func.function_).graph()->copy();
479   auto v = graph->insertInput(0, "self");
480   v->setType(module._ivalue()->type());
481   const auto name = QualifiedName(*module.type()->name(), "forward");
482   auto method =
483       module._ivalue()->compilation_unit()->create_function(name, graph);
484   module.type()->addMethod(method);
485 }
486 
487 // this is used in our test suite to check that we correctly preserved type tags
ivalue_tags_match(const Module & lhs,const Module & rhs)488 bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
489   struct Work {
490     IValue a;
491     IValue b;
492   };
493   std::unordered_set<const void*> visited;
494   std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}};
495   while (!work.empty()) {
496     Work item = work.back();
497     work.pop_back();
498     if (item.a.isPtrType()) {
499       // uncomment to debug type matching errors
500       // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") "
501       //          << item.a.internalToPointer() << " " << /*item.b <<*/ " ("
502       //          << *item.b.type() << ") " << item.b.internalToPointer() <<
503       //          "\n";
504 
505       if (visited.count(item.a.internalToPointer())) {
506         continue;
507       }
508       visited.emplace(item.a.internalToPointer());
509     }
510     if (!unshapedType(item.b.type())
511              ->isSubtypeOf(unshapedType(item.b.type()))) {
512       // Since named types are saved and loaded in the test suite, we cannot
513       // expect them to be equal. We should still check their slots however.
514       if (!item.a.type()->cast<c10::NamedType>()) {
515         return false;
516       }
517     }
518     // check tags for objects that contain subobjects
519     if (item.a.isObject()) {
520       auto ao = item.a.toObject();
521       auto bo = item.b.toObject();
522       for (size_t i = 0; i < ao->slots().size(); ++i) {
523         work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)});
524       }
525     } else if (item.a.isTuple()) {
526       auto at = item.a.toTuple();
527       auto bt = item.b.toTuple();
528       for (size_t i = 0; i < at->elements().size(); ++i) {
529         work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)});
530       }
531     } else if (item.a.isList()) {
532       auto al = item.a.toList();
533       auto bl = item.b.toList();
534       for (const auto i : c10::irange(al.size())) {
535         work.emplace_back(Work{al.get(i), bl.get(i)});
536       }
537     } else if (item.a.isGenericDict()) {
538       auto ad = item.a.toGenericDict();
539       auto bd = item.b.toGenericDict();
540       for (auto& item : ad) {
541         // Dictionaory keys cannot contain List/Dicts that require tags
542         // so we do not have to check them.
543         // Furthermore without ordered dicts it is expensive to find the
544         // equivalent key
545         work.emplace_back(Work{item.value(), bd.at(item.key())});
546       }
547     } else if (item.a.isFuture()) {
548       auto af = item.a.toFuture();
549       auto bf = item.b.toFuture();
550       af->wait();
551       bf->wait();
552       work.emplace_back(Work{af->value(), bf->value()});
553     }
554   }
555 
556   return true;
557 }
558 
559 // helper used to implement ._parameters, ._buffers, ._modules dicts
560 // inside of script nn.Module
561 template <typename Policy>
562 struct slot_dict_impl {
slot_dict_impltorch::jit::slot_dict_impl563   slot_dict_impl(ModulePtr module) : module_(std::move(module)) {}
containstorch::jit::slot_dict_impl564   bool contains(const std::string& name) const {
565     if (auto slot = module_->type()->findAttributeSlot(name)) {
566       if (Policy::valid(module_->type(), *slot, module_->getSlot(*slot))) {
567         return true;
568       }
569     }
570     return false;
571   }
572 
itemstorch::jit::slot_dict_impl573   std::vector<std::pair<std::string, py::object>> items() const {
574     std::vector<std::pair<std::string, py::object>> result;
575     for (size_t i = 0, N = module_->type()->numAttributes(); i < N; ++i) {
576       if (Policy::valid(module_->type(), i, module_->getSlot(i))) {
577         result.emplace_back(
578             module_->type()->getAttributeName(i),
579             toPyObject(module_->getSlot(i)));
580       }
581     }
582     return result;
583   }
584 
setattrtorch::jit::slot_dict_impl585   void setattr(const std::string& name, py::object value) {
586     const TypePtr& type = module_->type()->getAttribute(name);
587     Module(module_).setattr(name, toIValue(std::move(value), type));
588   }
589 
getattrtorch::jit::slot_dict_impl590   py::object getattr(const std::string& name) {
591     return toPyObject(Module(module_).attr(name));
592   }
593 
bindtorch::jit::slot_dict_impl594   static void bind(const py::module& m, const char* name) {
595     py::class_<slot_dict_impl<Policy>>(m, name)
596         .def(py::init(
597             [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); }))
598         .def("contains", &slot_dict_impl<Policy>::contains)
599         .def("items", &slot_dict_impl<Policy>::items)
600         .def("setattr", &slot_dict_impl<Policy>::setattr)
601         .def("getattr", &slot_dict_impl<Policy>::getattr);
602   }
603 
604  private:
605   ModulePtr module_;
606 };
607 
608 template <typename T>
debugMakeList(const T & list)609 py::list debugMakeList(const T& list) {
610   py::list result;
611   for (const auto& elem : list) {
612     result.append(py::cast(elem));
613   }
614   return result;
615 }
616 template <typename T>
debugMakeNamedList(const T & list)617 py::list debugMakeNamedList(const T& list) {
618   py::list result;
619   for (auto elem : list) {
620     result.append(py::cast(std::make_pair(elem.name, elem.value)));
621   }
622   return result;
623 }
624 template <typename T>
debugMakeSet(const T & list)625 py::set debugMakeSet(const T& list) {
626   py::set result;
627   for (const auto& elem : list) {
628     result.add(py::cast(elem));
629   }
630   return result;
631 }
632 
_jit_debug_module_iterators(Module & module)633 static py::dict _jit_debug_module_iterators(Module& module) {
634   py::dict result;
635   result["children"] = debugMakeList(module.children());
636   result["named_children"] = debugMakeNamedList(module.named_children());
637   result["modules"] = debugMakeList(module.modules());
638   result["named_modules"] = debugMakeNamedList(module.named_modules());
639 
640   result["parameters"] = debugMakeList(module.parameters(false));
641   result["named_parameters"] =
642       debugMakeNamedList(module.named_parameters(false));
643   result["parameters_r"] = debugMakeList(module.parameters(true));
644   result["named_parameters_r"] =
645       debugMakeNamedList(module.named_parameters(true));
646 
647   result["buffers"] = debugMakeList(module.buffers(false));
648   result["named_buffers"] = debugMakeNamedList(module.named_buffers(false));
649   result["buffers_r"] = debugMakeList(module.buffers(true));
650   result["named_buffers_r"] = debugMakeNamedList(module.named_buffers(true));
651 
652   result["named_attributes"] =
653       debugMakeNamedList(module.named_attributes(false));
654   result["named_attributes_r"] =
655       debugMakeNamedList(module.named_attributes(true));
656   return result;
657 }
658 
// Names of Python "magic" (double-underscore) methods, in a fixed order.
// The grouping comments below are only a reading aid; element order is
// significant and must not change.
static constexpr std::array<const char*, 48> magic_method_names = {
    // comparisons
    "__lt__",
    "__le__",
    "__eq__",
    "__ne__",
    "__ge__",
    "__gt__",
    // unary, arithmetic and bitwise operators
    "__not__",
    "__abs__",
    "__add__",
    "__and__",
    "__floordiv__",
    "__index__",
    "__inv__",
    "__invert__",
    "__lshift__",
    "__mod__",
    "__mul__",
    "__matmul__",
    "__neg__",
    "__or__",
    "__pos__",
    "__pow__",
    "__rshift__",
    "__sub__",
    "__truediv__",
    "__xor__",
    // sequence / container protocol
    "__concat__",
    "__contains__",
    "__delitem__",
    "__getitem__",
    "__setitem__",
    // in-place operators
    "__iadd__",
    "__iand__",
    "__iconcat__",
    "__ifloordiv__",
    "__ilshift__",
    "__imod__",
    "__imul__",
    "__imatmul__",
    "__ior__",
    "__ipow__",
    "__irshift__",
    "__isub__",
    "__itruediv__",
    "__ixor__",
    // string conversion and length
    "__str__",
    "__len__",
    "__repr__",
};
673 
674 struct DeepCopyMemoTable {
675   std::shared_ptr<IValue::HashIdentityIValueMap> map;
676 };
677 
// Deep-copies `ivalue`, using a memo table that is stored in the Python
// `memo` dict under the key "__torch_script_memo_table". The table is created
// and inserted on first use, so later calls with the same `memo` reuse it.
IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) {
  constexpr const char* kMemoKey = "__torch_script_memo_table";
  const bool has_table = memo.contains(py::str(kMemoKey));
  if (!has_table) {
    memo[kMemoKey] =
        DeepCopyMemoTable{std::make_shared<IValue::HashIdentityIValueMap>()};
  }
  // The cast yields a copy of the holder; it shares the map with the one kept
  // alive by `memo`.
  auto table = py::cast<DeepCopyMemoTable>(memo[kMemoKey]);
  return ivalue.deepcopy(*table.map);
}
687 
extra_files_from_python(const py::dict & pydict)688 ExtraFilesMap extra_files_from_python(const py::dict& pydict) {
689   ExtraFilesMap r;
690   for (const auto& it : pydict) {
691     r[py::cast<std::string>(it.first)] = "";
692   }
693   return r;
694 }
695 
extra_files_to_python(const ExtraFilesMap & m,const py::dict & pydict)696 void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
697   // py::dict is pointer-like type so it gets modified despite const&
698   for (const auto& it : m) {
699     pydict[py::str(it.first)] = py::bytes(it.second);
700   }
701 }
702 
pyCompilationUnitDefine(CompilationUnit & cu,const std::string & src,const ResolutionCallback * rcb,const uint32_t _frames_up)703 void pyCompilationUnitDefine(
704     CompilationUnit& cu,
705     const std::string& src,
706     const ResolutionCallback* rcb,
707     const uint32_t _frames_up) {
708   if (rcb && *rcb) {
709     cu.define(std::nullopt, src, pythonResolver(*rcb), nullptr);
710   } else {
711     py::object py_default_rcb =
712         py::module::import("torch._jit_internal")
713             .attr("createResolutionCallbackFromFrame")(_frames_up);
714     auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
715     cu.define(std::nullopt, src, pythonResolver(default_rcb), nullptr);
716   }
717 }
718 
719 // This function will copy bytes into a shared_ptr of chars aligned
720 // at kFlatbufferDataAlignmentBytes boundary (currently 16).
721 // This is required because tensors need to be aligned at 16 bytes boundary.
copyStr(const std::string & bytes)722 static std::shared_ptr<char> copyStr(const std::string& bytes) {
723   size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
724       kFlatbufferDataAlignmentBytes;
725 #ifdef _WIN32
726   std::shared_ptr<char> bytes_copy(
727       static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
728       _aligned_free);
729 #elif defined(__APPLE__)
730   void* p;
731   ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
732   TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
733   std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
734 #else
735   std::shared_ptr<char> bytes_copy(
736       static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
737       free);
738 #endif
739   memcpy(bytes_copy.get(), bytes.data(), bytes.size());
740   return bytes_copy;
741 }
742 
initJitScriptBindings(PyObject * module)743 void initJitScriptBindings(PyObject* module) {
744   auto m = py::handle(module).cast<py::module>();
745 
746   // NOLINTNEXTLINE(bugprone-unused-raii)
747   py::class_<c10::Capsule>(m, "Capsule");
748 
749   auto object_class =
750       py::class_<Object>(m, "ScriptObject")
751           .def("_type", [](Object& o) { return o.type(); })
752           .def(
753               "_get_method",
754               [](Object& self, const std::string& name) -> Method {
755                 return self.get_method(name);
756               },
757               py::keep_alive<0, 1>())
758           .def(
759               "setattr",
760               [](Object& self, const std::string& name, py::object value) {
761                 if (self.type()->hasConstant(name)) {
762                   TORCH_CHECK(
763                       false,
764                       "Can't set constant '",
765                       name,
766                       "' which has value:",
767                       self.type()->getConstant(name));
768                 }
769                 TypePtr type = self.type()->getAttribute(name);
770                 try {
771                   auto ivalue = toIValue(std::move(value), type);
772                   self.setattr(name, ivalue);
773                 } catch (std::exception& e) {
774                   throw py::cast_error(c10::str(
775                       "Could not cast attribute '",
776                       name,
777                       "' to type ",
778                       type->repr_str(),
779                       ": ",
780                       e.what()));
781                 }
782               })
783           .def(
784               "getattr",
785               [](Object& self, const std::string& name) {
786                 try {
787                   return toPyObject(self.attr(name));
788                 } catch (const ObjectAttributeError& err) {
789                   throw AttributeError("%s", err.what());
790                 }
791               })
792           .def(
793               "__getattr__",
794               [](Object& self, const std::string& name) -> py::object {
795                 try {
796                   if (name == "__qualname__") {
797                     return py::cast(self.type()->name()->name());
798                   }
799                   if (auto method = self.find_method(name)) {
800                     return py::cast(*method);
801                   }
802                   if (self.has_property(name)) {
803                     auto prop = self.get_property(name);
804                     // wrap the Method into callable PyObject
805                     auto getter_func = py::cast(prop.getter_func);
806                     return getter_func();
807                   }
808                   return toPyObject(self.attr(name));
809                 } catch (const ObjectAttributeError& err) {
810                   throw AttributeError("%s", err.what());
811                 }
812               })
813           .def(
814               "__setattr__",
815               [](Object& self, const std::string& name, py::object value) {
816                 try {
817                   if (self.has_property(name)) {
818                     auto prop = self.get_property(name);
819                     if (!prop.setter_func.has_value()) {
820                       TORCH_CHECK(false, "can't set attribute");
821                     }
822                     // wrap the Method into callable PyObject
823                     auto setter_func = py::cast(prop.setter_func);
824                     setter_func(value);
825                     return;
826                   }
827 
828                   if (self.type()->hasConstant(name)) {
829                     TORCH_CHECK(
830                         false,
831                         "Can't set constant '",
832                         name,
833                         "' which has value:",
834                         self.type()->getConstant(name));
835                   }
836                   TypePtr type = self.type()->getAttribute(name);
837                   auto ivalue = toIValue(std::move(value), type);
838                   self.setattr(name, ivalue);
839                 } catch (const ObjectAttributeError& err) {
840                   throw AttributeError("%s", err.what());
841                 }
842               })
843           .def(
844               "hasattr",
845               [](Object& self, const std::string& name) {
846                 return self.hasattr(name);
847               })
848           .def(
849               "_has_method",
850               [](Object& self, const std::string& name) {
851                 return bool(self.find_method(name));
852               })
853           .def(
854               "_method_names",
855               [](Object& self) {
856                 return fmap(self.get_methods(), [](const Method& method) {
857                   return method.name();
858                 });
859               })
860           .def(
861               "_properties", [](Object& self) { return self.get_properties(); })
862           .def("__copy__", &Object::copy)
863           .def(
864               "__hash__",
865               [](const Object& self) {
866                 // Similar to Tensor's `__hash__`, which is `id()`.
867                 return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
868               })
869           .def(py::pickle(
870               [](const Object& self)
871                   -> std::tuple<py::object, std::string> { // __getstate__
872                 if (auto getstate_method = self.find_method("__getstate__")) {
873                   auto object_state = toPyObject((*getstate_method)(Stack{}));
874                   TORCH_INTERNAL_ASSERT(self.type()->name());
875                   return std::make_tuple(
876                       object_state, self.type()->name()->qualifiedName());
877                 }
878                 std::stringstream err;
879                 err << "Tried to serialize object ";
880                 if (auto qualname = self.type()->name()) {
881                   err << qualname->qualifiedName() << " ";
882                 }
883                 err << "which does not have a __getstate__ method defined!";
884                 throw std::runtime_error(err.str());
885               },
886               [](const std::tuple<py::object, std::string>& state_tup)
887                   -> Object {
888                 auto [state, qualname] = state_tup;
889                 auto class_type = getCustomClass(qualname);
890                 TORCH_CHECK(
891                     class_type,
892                     "Tried to deserialize class ",
893                     qualname,
894                     " which is not known to the runtime. "
895                     "If this is a custom C++ class, make "
896                     "sure the appropriate code is linked.");
897 
898                 auto self = Object(c10::ivalue::Object::create(
899                     c10::StrongTypePtr(
900                         std::shared_ptr<torch::jit::CompilationUnit>(),
901                         class_type),
902                     1));
903                 if (auto setstate_method = self.find_method("__setstate__")) {
904                   auto setstate_schema =
905                       setstate_method->function().getSchema();
906                   TORCH_INTERNAL_ASSERT(
907                       setstate_schema.arguments().size() == 2,
908                       "__setstate__ method for class ",
909                       class_type->repr_str(),
910                       " must have exactly 2 arguments!");
911                   auto state_type = setstate_schema.arguments().at(1).type();
912                   (*setstate_method)(Stack{toIValue(state, state_type)});
913                   return self;
914                 }
915                 std::stringstream err;
916                 err << "Tried to deserialize object ";
917                 if (auto qualname = class_type->name()) {
918                   err << qualname->qualifiedName() << " ";
919                 }
920                 err << "which does not have a __setstate__ method defined!";
921                 throw std::runtime_error(err.str());
922               }));
923 
924   py::class_<Object::Property>(m, "ScriptObjectProperty")
925       .def_property_readonly(
926           "name", [](const Object::Property& self) { return self.name; })
927       .def_property_readonly(
928           "getter",
929           [](const Object::Property& self) { return self.getter_func; })
930       .def_property_readonly("setter", [](const Object::Property& self) {
931         return self.setter_func;
932       });
933 
934   // Special case __str__ and __repr__ to make sure we can print Objects/Modules
935   // regardless of if the user defined __str__/__repr__
936   using MagicMethodImplType = std::function<py::object(
937       const Object& self, py::args args, py::kwargs kwargs)>;
938 
939   std::unordered_map<std::string, MagicMethodImplType> special_magic_methods;
940   special_magic_methods.emplace(
941       "__str__",
942       [](const Object& self,
943          const py::args& args,
944          const py::kwargs& kwargs) -> py::object {
945         auto method = self.find_method("__str__");
946         if (!method) {
947           return py::str("ScriptObject <" + self.type()->str() + ">");
948         }
949         return invokeScriptMethodFromPython(*method, args, kwargs);
950       });
951 
952   special_magic_methods.emplace(
953       "__repr__",
954       [](const Object& self,
955          const py::args& args,
956          const py::kwargs& kwargs) -> py::object {
957         auto method = self.find_method("__repr__");
958         if (!method) {
959           std::stringstream ss;
960           ss << std::hex << static_cast<const void*>(&self);
961           return py::str("<torch.ScriptObject object at " + ss.str() + ">");
962         }
963         return invokeScriptMethodFromPython(*method, args, kwargs);
964       });
965 
966   for (const char* mm_name : magic_method_names) {
967     if (special_magic_methods.count(mm_name)) {
968       object_class.def(mm_name, special_magic_methods[mm_name]);
969     } else {
970       object_class.def(
971           mm_name,
972           [mm_name](
973               const Object& self,
974               const py::args& args,
975               const py::kwargs& kwargs) {
976             auto method = self.find_method(mm_name);
977             if (!method) {
978               std::string msg = fmt::format(
979                   "'{}' is not implemented for {}",
980                   mm_name,
981                   self.type()->str());
982               throw c10::NotImplementedError(msg);
983             }
984             return invokeScriptMethodFromPython(*method, args, kwargs);
985           });
986     }
987   }
988 
989   // NOLINTNEXTLINE(bugprone-unused-raii)
990   py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
991 
992   py::class_<UpgraderEntry>(m, "_UpgraderEntry")
993       .def(py::init<int, std::string, std::string>())
994       .def_property_readonly(
995           "bumped_at_version",
996           [](const UpgraderEntry& self) { return self.bumped_at_version; })
997       .def_property_readonly(
998           "upgrader_name",
999           [](const UpgraderEntry& self) { return self.upgrader_name; })
1000       .def_property_readonly("old_schema", [](const UpgraderEntry& self) {
1001         return self.old_schema;
1002       });
1003 
1004   py::class_<UpgraderRange>(m, "_UpgraderRange")
1005       .def(py::init<int, int>())
1006       .def_property_readonly(
1007           "min_version",
1008           [](const UpgraderRange& self) { return self.min_version; })
1009       .def_property_readonly("max_version", [](const UpgraderRange& self) {
1010         return self.max_version;
1011       });
1012 
1013   object_class.def(
1014       "__deepcopy__", [](const Object& self, const py::dict& memo) {
1015         return Object(
1016             pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1017       });
1018 
1019   // Used by torch.package to save ScriptModule objects in unified format.
1020   py::class_<ScriptModuleSerializer>(m, "ScriptModuleSerializer")
1021       .def(py::init<caffe2::serialize::PyTorchStreamWriter&>())
1022       .def("serialize", &ScriptModuleSerializer::serialize_unified_format)
1023       .def(
1024           "write_files",
1025           &ScriptModuleSerializer::writeFiles,
1026           py::arg("code_dir") = ".data/ts_code/code/")
1027       .def(
1028           "storage_context",
1029           &ScriptModuleSerializer::storage_context,
1030           pybind11::return_value_policy::reference_internal);
1031 
1032   // Used by torch.package to coordinate sharing of storages between eager
1033   // and ScriptModules.
1034   py::class_<
1035       SerializationStorageContext,
1036       std::shared_ptr<SerializationStorageContext>>(
1037       m, "SerializationStorageContext")
1038       .def("has_storage", &SerializationStorageContext::hasStorage)
1039       .def("get_or_add_storage", &SerializationStorageContext::getOrAddStorage);
1040 
1041   // torch.jit.ScriptModule is a subclass of this C++ object.
1042   // Methods here are prefixed with _ since they should not be
1043   // public.
1044   py::class_<Module, Object>(m, "ScriptModule")
1045       .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>())
1046       .def(
1047           "save",
1048           [](Module& m,
1049              const std::string& filename,
1050              const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1051             m.save(filename, _extra_files);
1052           },
1053           py::arg("filename"),
1054           py::arg("_extra_files") = ExtraFilesMap())
1055       .def(
1056           "save_to_buffer",
1057           [](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1058             std::ostringstream buf;
1059             m.save(buf, _extra_files);
1060             return py::bytes(buf.str());
1061           },
1062           py::arg("_extra_files") = ExtraFilesMap())
1063       .def(
1064           "_save_for_mobile",
1065           [](Module& m,
1066              const std::string& filename,
1067              const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1068              bool _save_mobile_debug_info = false,
1069              bool _use_flatbuffer = false) {
1070             m._save_for_mobile(
1071                 filename,
1072                 _extra_files,
1073                 _save_mobile_debug_info,
1074                 _use_flatbuffer);
1075           },
1076           py::arg("filename"),
1077           py::arg("_extra_files") = ExtraFilesMap(),
1078           py::arg("_save_mobile_debug_info") = false,
1079           py::arg("_use_flatbuffer") = false)
1080       .def(
1081           "_save_to_buffer_for_mobile",
1082           [](Module& m,
1083              const ExtraFilesMap& _extra_files = ExtraFilesMap(),
1084              bool _save_mobile_debug_info = false,
1085              bool _use_flatbuffer = false) {
1086             std::ostringstream buf;
1087             m._save_for_mobile(
1088                 buf, _extra_files, _save_mobile_debug_info, _use_flatbuffer);
1089             return py::bytes(buf.str());
1090           },
1091           py::arg("_extra_files") = ExtraFilesMap(),
1092           py::arg("_save_mobile_debug_info") = false,
1093           py::arg("_use_flatbuffer") = false)
1094       .def("_set_optimized", &Module::set_optimized)
1095       .def(
1096           "dump",
1097           &Module::dump,
1098           py::arg("code") = true,
1099           py::arg("attrs") = true,
1100           py::arg("params") = true)
1101       .def(
1102           "dump_to_str",
1103           &Module::dump_to_str,
1104           py::arg("code") = true,
1105           py::arg("attrs") = true,
1106           py::arg("params") = true)
1107       .def(
1108           "_replicate_for_data_parallel",
1109           [](Module& module) {
1110             const ModulePtr& obj = module._ivalue();
1111             auto copy = c10::ivalue::Object::create(
1112                 c10::StrongTypePtr(obj->compilation_unit(), obj->type()),
1113                 obj->slots().size());
1114             for (size_t i = 0; i < obj->slots().size(); ++i) {
1115               copy->setSlot(i, obj->getSlot(i));
1116             }
1117             return Module(std::move(copy));
1118           })
1119       .def(
1120           "get_debug_state",
1121           [](Module& self) {
1122             if (auto m = self.find_method("forward")) {
1123               return m->get_executor().getDebugState();
1124             }
1125             throw std::runtime_error(
1126                 "Attempted to call get_debug_state on a Module without a compiled forward()");
1127           })
1128       .def(
1129           "_define",
1130           [](Module& m,
1131              std::shared_ptr<ConcreteModuleType> concreteType,
1132              const std::string& script,
1133              const ResolutionCallback& rcb) {
1134             const auto self = ModuleSelf(std::move(concreteType));
1135             m._ivalue()->compilation_unit()->define(
1136                 m.type()->name(), script, pythonResolver(rcb), &self);
1137             didFinishEmitModule(m);
1138           })
1139       .def(
1140           "_register_attribute",
1141           [](Module& m,
1142              const std::string& name,
1143              const TypePtr& type,
1144              py::handle value) {
1145             m.register_attribute(name, type, toIValue(value, type));
1146           })
1147       .def(
1148           "_create_method_from_trace",
1149           [](Module& self,
1150              const std::string& name,
1151              const py::function& func,
1152              const py::tuple& input_tuple,
1153              const py::function& var_name_lookup_fn,
1154              bool strict,
1155              bool force_outplace,
1156              const std::vector<std::string>& argument_names,
1157              bool store_inputs) {
1158             // prereq: Module's buffers and parameters are unique
1159             // this was ensured in python before calling this function
1160             auto typed_inputs = toTraceableStack(input_tuple);
1161 
1162             std::shared_ptr<Graph> graph =
1163                 std::get<0>(tracer::createGraphByTracing(
1164                     func,
1165                     typed_inputs,
1166                     var_name_lookup_fn,
1167                     strict,
1168                     force_outplace,
1169                     &self,
1170                     argument_names));
1171             const auto method_name = QualifiedName(*self.type()->name(), name);
1172             auto fn = self._ivalue()->compilation_unit()->create_function(
1173                 method_name, graph);
1174             self.type()->addMethod(fn);
1175             if (store_inputs) {
1176               self.store_traced_inputs(name, typed_inputs);
1177             }
1178             didFinishEmitModule(self);
1179           },
1180           py::arg("name"),
1181           py::arg("func"),
1182           py::arg("input_tuple"),
1183           py::arg("var_name_lookup_fn"),
1184           py::arg("strict"),
1185           py::arg("force_outplace"),
1186           py::arg("argument_names") = std::vector<std::string>(),
1187           py::arg("store_inputs"))
1188       .def(
1189           "_create_method_from_trace_with_dict",
1190           [](Module& self,
1191              const std::string& name,
1192              const py::function& func,
1193              const py::dict& input_dict,
1194              const py::function& var_name_lookup_fn,
1195              bool strict,
1196              bool force_outplace,
1197              const std::vector<std::string>& argument_names,
1198              bool store_inputs) {
1199             // prereq: Module's buffers and parameters are unique
1200             // this was ensured in python before calling this function
1201             auto typed_inputs = toTraceableStack(input_dict);
1202 
1203             std::shared_ptr<Graph> graph =
1204                 std::get<0>(tracer::createGraphByTracingWithDict(
1205                     func,
1206                     input_dict,
1207                     typed_inputs,
1208                     var_name_lookup_fn,
1209                     strict,
1210                     force_outplace,
1211                     &self,
1212                     argument_names));
1213             const auto method_name = QualifiedName(*self.type()->name(), name);
1214             auto fn = self._ivalue()->compilation_unit()->create_function(
1215                 method_name, graph);
1216             if (store_inputs) {
1217               self.store_traced_inputs(name, typed_inputs);
1218             }
1219             self.type()->addMethod(fn);
1220             didFinishEmitModule(self);
1221           },
1222           py::arg("name"),
1223           py::arg("func"),
1224           py::arg("input_dict"),
1225           py::arg("var_name_lookup_fn"),
1226           py::arg("strict"),
1227           py::arg("force_outplace"),
1228           py::arg("argument_names") = std::vector<std::string>(),
1229           py::arg("store_inputs"))
1230       .def(
1231           "_get_forward_hooks",
1232           [](const Module& m) {
1233             std::vector<StrongFunctionPtr> funcs;
1234             for (auto& hook : m.type()->getForwardHooks()) {
1235               funcs.emplace_back(m.type()->compilation_unit(), hook);
1236             }
1237             return funcs;
1238           })
1239       .def(
1240           "_get_forward_pre_hooks",
1241           [](const Module& m) {
1242             std::vector<StrongFunctionPtr> funcs;
1243             for (auto& pre_hook : m.type()->getForwardPreHooks()) {
1244               funcs.emplace_back(m.type()->compilation_unit(), pre_hook);
1245             }
1246             return funcs;
1247           })
1248       .def(
1249           "_retrieve_traced_inputs",
1250           [](const Module& m) {
1251             return ScriptDict(m.retrieve_traced_inputs());
1252           })
1253       .def_property_readonly(
1254           "code",
1255           [](Module& self) {
1256             std::vector<at::IValue> constants;
1257             PrintDepsTable deps;
1258             PythonPrint pp(constants, deps);
1259             pp.printNamedType(self.type());
1260             return pp.str();
1261           })
1262       .def_property_readonly(
1263           "code_with_constants",
1264           [](Module& self) {
1265             std::vector<at::IValue> constants;
1266             PrintDepsTable deps;
1267             PythonPrint pp(constants, deps);
1268             pp.printNamedType(self.type());
1269             std::map<std::string, at::IValue> consts;
1270             int i = 0;
1271             for (auto const& constant : constants) {
1272               consts["c" + std::to_string(i)] = constant;
1273               i += 1;
1274             }
1275             return std::make_tuple(pp.str(), std::move(consts));
1276           })
1277       .def("apply", &Module::apply)
1278       .def("__copy__", &Module::copy)
1279       .def(
1280           "__hash__",
1281           [](const Module& self) {
1282             // Similar to Tensor's `__hash__`, which is `id()`.
1283             return std::hash<c10::ivalue::Object*>{}(self._ivalue().get());
1284           })
1285       .def(
1286           "__eq__",
1287           [](const Module& self, const py::object& other) {
1288             // TODO: call UDF if it exists
1289             if (!py::isinstance<Module>(other)) {
1290               return false;
1291             }
1292             return self._ivalue().get() ==
1293                 py::cast<Module>(other)._ivalue().get();
1294           })
1295       .def(
1296           "__deepcopy__",
1297           [](const Module& self, const py::dict& memo) {
1298             return Module(
1299                 pyIValueDeepcopy(IValue(self._ivalue()), memo).toObject());
1300           })
1301       .def("children", &Module::children)
1302       .def_property_readonly("qualified_name", [](const Module& self) {
1303         return self.type()->name()->qualifiedName();
1304       });
1305 
1306   py::class_<mobile::Module>(m, "LiteScriptModule")
1307       .def(py::init<
1308            c10::intrusive_ptr<c10::ivalue::Object>,
1309            std::shared_ptr<mobile::CompilationUnit>>())
1310       .def(
1311           "find_method",
1312           [](mobile::Module& m, const std::string& method_name) {
1313             auto method = m.find_method(method_name);
1314             return method != std::nullopt;
1315           },
1316           py::arg("method_name"))
1317       .def(
1318           "run_method",
1319           [](mobile::Module& m,
1320              const std::string& method_name,
1321              const py::tuple& input_tuple) {
1322             Stack stack;
1323             for (auto& input : input_tuple) {
1324               stack.push_back(toTypeInferredIValue(input));
1325             }
1326             return m.get_method(method_name)(stack);
1327           },
1328           py::arg("method_name"),
1329           py::arg("input_tuple"))
1330       .def(
1331           "forward",
1332           [](mobile::Module& m, const py::tuple& input_tuple) {
1333             Stack stack;
1334             for (auto& input : input_tuple) {
1335               stack.push_back(toTypeInferredIValue(input));
1336             }
1337             return m.get_method("forward")(stack);
1338           },
1339           py::arg("input_tuple"));
1340 
1341   slot_dict_impl<detail::ParameterPolicy>::bind(m, "ParameterDict");
1342   slot_dict_impl<detail::BufferPolicy>::bind(m, "BufferDict");
1343   slot_dict_impl<detail::ModulePolicy>::bind(m, "ModuleDict");
1344 
1345   py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport")
1346       .def(py::init<SourceRange>())
1347       .def("what", &ErrorReport::what)
1348       .def_static("call_stack", ErrorReport::current_call_stack);
1349 
1350   py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
1351       m, "CompilationUnit")
1352       .def(
1353           py::init([](const std::string& lang, const uint32_t _frames_up) {
1354             auto cu = std::make_shared<CompilationUnit>();
1355             if (!lang.empty()) {
1356               pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
1357             }
1358             return cu;
1359           }),
1360           py::arg("lang") = "",
1361           py::arg("_frames_up") = 0)
1362 
1363       .def(
1364           "find_function",
1365           [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1366             auto fn = self->find_function(QualifiedName(name));
1367             if (fn) {
1368               return std::optional<StrongFunctionPtr>(
1369                   StrongFunctionPtr(std::move(self), fn));
1370             } else {
1371               return std::optional<StrongFunctionPtr>(std::nullopt);
1372             }
1373           })
1374       .def(
1375           "__getattr__",
1376           [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1377             auto fn = self->find_function(QualifiedName(name));
1378             if (fn) {
1379               return StrongFunctionPtr(std::move(self), fn);
1380             } else {
1381               throw AttributeError(
1382                   "'CompilationUnit' has no attribute '%s'", name.c_str());
1383             }
1384           })
1385       .def(
1386           "get_functions",
1387           [](const std::shared_ptr<CompilationUnit>& self) {
1388             auto raw_functions = self->get_functions();
1389             std::vector<StrongFunctionPtr> functions;
1390             functions.reserve(raw_functions.size());
1391             for (auto fn : raw_functions) {
1392               if (fn) {
1393                 functions.emplace_back(self, fn);
1394               }
1395             }
1396             return functions;
1397           })
1398       .def("set_optimized", &CompilationUnit::set_optimized)
1399       .def(
1400           "define",
1401           pyCompilationUnitDefine,
1402           py::arg("src"),
1403           py::arg("rcb") = nullptr,
1404           py::arg("_frames_up") = 0)
1405       .def(
1406           "create_function",
1407           [](std::shared_ptr<CompilationUnit>& self,
1408              const std::string& qualified_name,
1409              std::shared_ptr<Graph> graph,
1410              bool should_mangle) {
1411             Function* fn = self->create_function(
1412                 qualified_name, std::move(graph), should_mangle);
1413             return StrongFunctionPtr(std::move(self), fn);
1414           },
1415           py::arg("qualified_name"),
1416           py::arg("graph"),
1417           py::arg("should_mangle") = false)
1418       .def(
1419           "get_interface",
1420           [](const std::shared_ptr<CompilationUnit>& self,
1421              const std::string& name) { return self->get_interface(name); })
1422       .def(
1423           "get_class",
1424           [](const std::shared_ptr<CompilationUnit>& self,
1425              const std::string& name) { return self->get_class(name); })
1426       .def(
1427           "drop_all_functions",
1428           [](const std::shared_ptr<CompilationUnit>& self) {
1429             self->drop_all_functions();
1430           });
1431 
1432   py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr())
1433       .def(
1434           "__call__",
1435           [](py::args args, const py::kwargs& kwargs) {
1436             HANDLE_TH_ERRORS
1437             // see: [pybind11 varargs]
1438             auto strongPtr = py::cast<StrongFunctionPtr>(args[0]);
1439             Function& callee = *strongPtr.function_;
1440             py::object result = invokeScriptFunctionFromPython(
1441                 callee, tuple_slice(std::move(args), 1), kwargs);
1442             return result;
1443             END_HANDLE_TH_ERRORS_PYBIND
1444           })
1445       .def(
1446           "save",
1447           [](const StrongFunctionPtr& self,
1448              const std::string& filename,
1449              const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1450             Module module("__torch__.PlaceholderModule");
1451             // [issue 27343]
1452             // Modules have 'training' attributes by default, but due to
1453             // https://github.com/pytorch/pytorch/issues/27343, functions end
1454             // up having a training attribute when they are loaded. This adds
1455             // a fake 'training' attribute that shouldn't be used, but prevents
1456             // jitter on saving and loading. Once that issue is fixed this can
1457             // be deleted.
1458             module.register_attribute("training", BoolType::get(), true);
1459             addFunctionToModule(module, self);
1460             module.save(filename, _extra_files);
1461           },
1462           py::arg("filename"),
1463           py::arg("_extra_files") = ExtraFilesMap())
1464       .def(
1465           "save_to_buffer",
1466           [](const StrongFunctionPtr& self,
1467              const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
1468             std::ostringstream buf;
1469             Module module("__torch__.PlaceholderModule");
1470             // see [issue 27343]
1471             module.register_attribute("training", BoolType::get(), true);
1472             addFunctionToModule(module, self);
1473             module.save(buf, _extra_files);
1474             return py::bytes(buf.str());
1475           },
1476           py::arg("_extra_files") = ExtraFilesMap())
1477       .def_property_readonly(
1478           "graph",
1479           [](const StrongFunctionPtr& self) {
1480             return toGraphFunction(*self.function_).graph();
1481           })
1482       .def_property_readonly(
1483           "inlined_graph",
1484           [](const StrongFunctionPtr& self) {
1485             auto g = toGraphFunction(*self.function_).graph()->copy();
1486             Inline(*g);
1487             return g;
1488           })
1489       .def_property_readonly(
1490           "schema",
1491           [](const StrongFunctionPtr& self) {
1492             return self.function_->getSchema();
1493           })
1494       .def_property_readonly(
1495           "code",
1496           [](const StrongFunctionPtr& self) {
1497             std::vector<at::IValue> constants;
1498             PrintDepsTable deps;
1499 
1500             PythonPrint pp(constants, deps);
1501             pp.printFunction(*self.function_);
1502             return pp.str();
1503           })
1504       .def(
1505           "get_debug_state",
1506           [](const StrongFunctionPtr& self) {
1507             return toGraphFunction(*self.function_)
1508                 .get_executor()
1509                 .getDebugState();
1510           })
1511       .def(
1512           "_debug_flush_compilation_cache",
1513           [](const StrongFunctionPtr& self) {
1514             toGraphFunction(*self.function_)
1515                 .get_executor()
1516                 .debugFlushCompilationCache();
1517           })
1518       .def_property_readonly(
1519           "name",
1520           [](const StrongFunctionPtr& self) { return self.function_->name(); })
1521       .def(
1522           "_set_ignore_amp",
1523           [](StrongFunctionPtr& self, bool ignore) {
1524             auto fn = self.function_;
1525             TORCH_INTERNAL_ASSERT(fn->isGraphFunction());
1526             GraphFunction& g_fn = toGraphFunction(*fn);
1527             g_fn._set_ignore_amp(ignore);
1528           })
1529       .def_property_readonly(
1530           "qualified_name",
1531           [](const StrongFunctionPtr& self) {
1532             return self.function_->qualname().qualifiedName();
1533           })
1534       .def_property_readonly("__doc__", [](const StrongFunctionPtr& self) {
1535         return self.function_->doc_string();
1536       });
1537 
1538   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
1539       .def(
1540           "__call__",
1541           [](py::args args, const py::kwargs& kwargs) {
1542             // see: [pybind11 varargs]
1543             HANDLE_TH_ERRORS
1544             Method& method = py::cast<Method&>(args[0]);
1545 
1546             return invokeScriptMethodFromPython(
1547                 method, tuple_slice(std::move(args), 1), kwargs);
1548             END_HANDLE_TH_ERRORS_PYBIND
1549           })
1550       .def_property_readonly("graph", &Method::graph)
1551       .def_property_readonly(
1552           "inlined_graph",
1553           [](const Method& self) {
1554             auto g = toGraphFunction(self.function()).graph()->copy();
1555             Inline(*g);
1556             return g;
1557           })
1558       .def_property_readonly(
1559           "schema", [](Method& m) { return m.function().getSchema(); })
1560       .def_property_readonly("name", &Method::name)
1561       .def_property_readonly(
1562           "code",
1563           [](Method& self) {
1564             std::vector<at::IValue> constants;
1565             PrintDepsTable deps;
1566             PythonPrint pp(constants, deps);
1567             pp.printMethod(self.function());
1568             return pp.str();
1569           })
1570       .def(
1571           "_debug_flush_compilation_cache",
1572           [](Method& self) {
1573             return self.get_executor().debugFlushCompilationCache();
1574           })
1575       .def_property_readonly(
1576           "code_with_constants",
1577           [](Method& self) {
1578             std::vector<at::IValue> constants;
1579             PrintDepsTable deps;
1580             PythonPrint pp(constants, deps);
1581             pp.printMethod(self.function());
1582             std::map<std::string, at::IValue> consts;
1583             int i = 0;
1584             for (auto const& constant : constants) {
1585               consts["c" + std::to_string(i)] = constant;
1586               i += 1;
1587             }
1588             return std::make_tuple(pp.str(), std::move(consts));
1589           })
1590       .def_property_readonly("owner", &Method::owner)
1591       .def_property_readonly("raw_owner", [](const Method& self) {
1592         return Object(self.raw_owner());
1593       });
  // Module-level bindings for upgrader-related helpers. Each Python name is
  // the C++ function name with a leading underscore, and the function is
  // bound directly with no wrapper.
  m.def("_generate_upgraders_graph", &generate_upgraders_graph);
  m.def(
      "_calculate_package_version_based_on_upgraders",
      &calculate_package_version_based_on_upgraders);
  m.def("_get_version_calculator_flag", &get_version_calculator_flag);
1599   m.def(
1600       "_compile_graph_to_code_table",
1601       [](const std::string& name, const std::shared_ptr<Graph>& graph) {
1602         CompilationOptions options;
1603         GraphFunction jitFunc(name, graph, nullptr);
1604         auto mobileFunc = convertJitFunctionToMobileFunction(jitFunc, options);
1605         return convertMobileFunctionToCodeTable(*mobileFunc, options);
1606       });
1607   m.def(
1608       "_jit_script_compile",
1609       [](const std::string& qualname,
1610          const Def& def,
1611          const ResolutionCallback& rcb,
1612          const FunctionDefaults& defaults) {
1613         C10_LOG_API_USAGE_ONCE("torch.script.compile");
1614         const auto name = c10::QualifiedName(qualname);
1615         TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
1616         return script_compile_function(name, def, defaults, rcb);
1617       });
1618   m.def(
1619       "_jit_script_compile_overload",
1620       [](const std::string& qualname,
1621          const Decl& overload_decl,
1622          const Def& implementation_def,
1623          const ResolutionCallback& rcb,
1624          const FunctionDefaults& implementation_defaults,
1625          const py::object& signature) {
1626         const auto name = c10::QualifiedName(qualname);
1627         return script_compile_overloaded_function(
1628             name,
1629             overload_decl,
1630             implementation_def,
1631             rcb,
1632             implementation_defaults,
1633             signature);
1634       });
1635   m.def(
1636       "_replace_overloaded_method_decl",
1637       [](const Decl& overload_decl,
1638          const Def& implementation_def,
1639          const std::string& new_name) {
1640         checkOverloadDecl(overload_decl, implementation_def.decl());
1641         return implementation_def.withDecl(overload_decl).withName(new_name);
1642       });
1643   m.def(
1644       "_create_function_from_trace",
1645       [](const std::string& qualname,
1646          const py::function& func,
1647          const py::tuple& input_tuple,
1648          const py::function& var_name_lookup_fn,
1649          bool strict,
1650          bool force_outplace,
1651          const std::vector<std::string>& argument_names) {
1652         auto typed_inputs = toTraceableStack(input_tuple);
1653         std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing(
1654             func,
1655             typed_inputs,
1656             var_name_lookup_fn,
1657             strict,
1658             force_outplace,
1659             /*self=*/nullptr,
1660             argument_names));
1661 
1662         auto cu = get_python_cu();
1663         auto name = c10::QualifiedName(qualname);
1664         auto result = cu->create_function(
1665             std::move(name), std::move(graph), /*shouldMangle=*/true);
1666         StrongFunctionPtr ret(std::move(cu), result);
1667         didFinishEmitFunction(ret);
1668         return ret;
1669       },
1670       py::arg("name"),
1671       py::arg("func"),
1672       py::arg("input_tuple"),
1673       py::arg("var_name_lookup_fn"),
1674       py::arg("strict"),
1675       py::arg("force_outplace"),
1676       py::arg("argument_names") = std::vector<std::string>());
1677 
1678   m.def(
1679       "_create_function_from_trace_with_dict",
1680       [](const std::string& qualname,
1681          const py::function& func,
1682          const py::dict& input_dict,
1683          const py::function& var_name_lookup_fn,
1684          bool strict,
1685          bool force_outplace,
1686          const std::vector<std::string>& argument_names) {
1687         auto typed_inputs = toTraceableStack(input_dict);
1688         std::shared_ptr<Graph> graph =
1689             std::get<0>(tracer::createGraphByTracingWithDict(
1690                 func,
1691                 input_dict,
1692                 typed_inputs,
1693                 var_name_lookup_fn,
1694                 strict,
1695                 force_outplace,
1696                 /*self=*/nullptr,
1697                 argument_names));
1698 
1699         auto cu = get_python_cu();
1700         auto name = c10::QualifiedName(qualname);
1701         auto result = cu->create_function(
1702             std::move(name), std::move(graph), /*shouldMangle=*/true);
1703         StrongFunctionPtr ret(std::move(cu), result);
1704         didFinishEmitFunction(ret);
1705         return ret;
1706       },
1707       py::arg("name"),
1708       py::arg("func"),
1709       py::arg("input_dict"),
1710       py::arg("var_name_lookup_fn"),
1711       py::arg("strict"),
1712       py::arg("force_outplace"),
1713       py::arg("argument_names") = std::vector<std::string>());
1714 
1715   m.def(
1716       "_jit_script_class_compile",
1717       [](const std::string& qualifiedName,
1718          const ClassDef& classDef,
1719          const ClassMethodDefaults& defaults,
1720          const ResolutionCallback& rcb) {
1721         C10_LOG_API_USAGE_ONCE("torch.script.class");
1722         if (classDef.superclass().present()) {
1723           throw(
1724               ErrorReport(classDef.range())
1725               << "Torchscript does not support class inheritance.");
1726         }
1727         auto cu = get_python_cu();
1728         auto classname = c10::QualifiedName(qualifiedName);
1729         if (cu->get_type(classname) != nullptr) {
1730           classname = cu->mangle(classname);
1731         }
1732 
1733         auto classType = ClassType::create(
1734             classname,
1735             cu,
1736             /* is_module = */ false,
1737             /* doc_string = */ "",
1738             getUnresolvedClassAttributes(classDef));
1739         cu->register_type(classType);
1740         std::vector<ResolverPtr> methodRcbs, propRcbs;
1741         std::vector<Def> methodDefs;
1742         std::vector<Property> props;
1743 
1744         for (const auto& def : classDef.body()) {
1745           if (def.kind() != TK_DEF) {
1746             throw(
1747                 ErrorReport(def.range())
1748                 << "Currently class bodies can only contain method "
1749                    "definitions. File an issue on GitHub if you want "
1750                    "something else!");
1751           }
1752           methodDefs.emplace_back(def);
1753           methodRcbs.push_back(
1754               pythonResolver(rcb, classDef.name().name(), classType));
1755         }
1756 
1757         // Gather definitions for property getters and setters as well as
1758         // corresponding resolution callbacks.
1759         if (classDef.properties().present()) {
1760           for (const auto& prop : classDef.properties().get()) {
1761             props.emplace_back(prop);
1762             propRcbs.push_back(
1763                 pythonResolver(rcb, classDef.name().name(), classType));
1764           }
1765         }
1766 
1767         const auto self = SimpleSelf(classType);
1768         cu->define(classname, props, propRcbs, methodDefs, methodRcbs, &self);
1769 
1770         // Stitch in default arguments for methods. Properties don't need to be
1771         // considered since there is no way to invoke setters without passing in
1772         // a value.
1773         auto defs_it = methodDefs.begin();
1774         while (defs_it != methodDefs.end()) {
1775           auto def_name = (*defs_it).name().name();
1776           // If the method is not in the defaults map, assume there are
1777           // no default arguments for it.
1778           auto default_it = defaults.find(def_name);
1779           if (default_it == defaults.end()) {
1780             continue;
1781           }
1782 
1783           const auto method_name =
1784               QualifiedName(classname, (*defs_it).name().name());
1785           auto& method = cu->get_function(method_name);
1786           method.setSchema(getSchemaWithNameAndDefaults(
1787               defs_it->range(),
1788               method.getSchema(),
1789               std::nullopt,
1790               default_it->second));
1791           ++defs_it;
1792         }
1793         return classType;
1794       });
1795   m.def(
1796       "_jit_script_interface_compile",
1797       [](const std::string& qualifiedName,
1798          const ClassDef& classDef,
1799          const ResolutionCallback& rcb,
1800          bool is_module) {
1801         auto cu = get_python_cu();
1802         auto className = c10::QualifiedName(qualifiedName);
1803         if (cu->get_type(className) != nullptr) {
1804           className = cu->mangle(className);
1805         }
1806 
1807         get_python_cu()->define_interface(
1808             className, classDef, pythonResolver(rcb), is_module);
1809         return className.qualifiedName();
1810       });
1811 
  // Expose ErrorReport::CallStack to Python as "CallStack". It is
  // constructible from a string and a SourceRange, and instances accept
  // dynamically added attributes.
  py::class_<torch::jit::ErrorReport::CallStack>(
      m, "CallStack", py::dynamic_attr())
      .def(py::init<const std::string&, const SourceRange&>());
1815 
1816   m.def("_parse_source_def", [](const std::string& src) {
1817     Parser p(std::make_shared<Source>(src));
1818     return Def(p.parseFunction(/*is_method=*/true));
1819   });
1820   m.def("parse_type_comment", [](const std::string& comment) {
1821     Parser p(std::make_shared<Source>(comment));
1822     return Decl(p.parseTypeComment());
1823   });
1824 
  // Direct bindings for upgrader-map and operator-version-map helpers. Each
  // C++ function is bound with no wrapper.
  m.def("_get_upgraders_map_size", &get_upgraders_map_size);
  m.def("_dump_upgraders_map", &dump_upgraders_map);

  // Test-only hooks for the upgraders map.
  m.def("_test_only_populate_upgraders", &test_only_populate_upgraders);
  m.def("_test_only_remove_upgraders", &test_only_remove_upgraders);

  m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
  m.def("_get_max_operator_version", &getMaxOperatorVersion);
  m.def("_get_operator_version_map", &get_operator_version_map);
  m.def("_get_upgraders_entry_map", &get_upgraders_entry_map);
  m.def("_get_upgrader_ranges", &getUpgradersRangeForOp);
  // Test-only hooks for the operator version map.
  m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
  m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
1838   m.def(
1839       "import_ir_module",
1840       [](std::shared_ptr<CompilationUnit> cu,
1841          const std::string& filename,
1842          py::object map_location,
1843          const py::dict& extra_files,
1844          bool restore_shapes = false) {
1845         std::optional<at::Device> optional_device;
1846         if (!map_location.is_none()) {
1847           AT_ASSERT(THPDevice_Check(map_location.ptr()));
1848           optional_device =
1849               reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1850         }
1851         ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1852         auto ret = import_ir_module(
1853             std::move(cu),
1854             filename,
1855             optional_device,
1856             extra_files_map,
1857             /*load_debug_files*/ true,
1858             restore_shapes);
1859         extra_files_to_python(extra_files_map, extra_files);
1860         return ret;
1861       });
1862   m.def(
1863       "_import_ir_module_from_package",
1864       [](std::shared_ptr<CompilationUnit> cu,
1865          std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
1866          std::shared_ptr<torch::jit::DeserializationStorageContext>
1867              storage_context,
1868          py::object map_location,
1869          const std::string& ts_id) {
1870         std::optional<at::Device> optional_device;
1871         if (!map_location.is_none()) {
1872           AT_ASSERT(THPDevice_Check(map_location.ptr()));
1873           optional_device =
1874               reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1875         }
1876         return import_ir_module(
1877             std::move(cu),
1878             std::move(reader),
1879             std::move(storage_context),
1880             optional_device,
1881             ts_id);
1882       });
1883   m.def(
1884       "import_ir_module_from_buffer",
1885       [](std::shared_ptr<CompilationUnit> cu,
1886          const std::string& buffer,
1887          py::object map_location,
1888          const py::dict& extra_files,
1889          bool restore_shapes = false) {
1890         std::istringstream in(buffer);
1891         std::optional<at::Device> optional_device;
1892         if (!map_location.is_none()) {
1893           AT_ASSERT(THPDevice_Check(map_location.ptr()));
1894           optional_device =
1895               reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1896         }
1897         ExtraFilesMap extra_files_map = extra_files_from_python(extra_files);
1898         auto ret = import_ir_module(
1899             std::move(cu),
1900             in,
1901             optional_device,
1902             extra_files_map,
1903             /*load_debug_files*/ true,
1904             restore_shapes);
1905         extra_files_to_python(extra_files_map, extra_files);
1906         return ret;
1907       });
1908   m.def(
1909       "_load_for_lite_interpreter",
1910       [](const std::string& filename, py::object map_location) {
1911         std::optional<at::Device> optional_device;
1912         if (!map_location.is_none()) {
1913           AT_ASSERT(THPDevice_Check(map_location.ptr()));
1914           optional_device =
1915               reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1916         }
1917         return _load_for_mobile(filename, optional_device);
1918       });
1919   m.def(
1920       "_load_for_lite_interpreter_from_buffer",
1921       [](const std::string& buffer, py::object map_location) {
1922         std::istringstream in(buffer);
1923         std::optional<at::Device> optional_device;
1924         if (!map_location.is_none()) {
1925           AT_ASSERT(THPDevice_Check(map_location.ptr()));
1926           optional_device =
1927               reinterpret_cast<THPDevice*>(map_location.ptr())->device;
1928         }
1929         return _load_for_mobile(in, optional_device);
1930       });
1931   m.def(
1932       "_backport_for_mobile",
1933       [](const std::string& filename_input,
1934          const std::string& filename_output,
1935          const int64_t version) {
1936         return _backport_for_mobile(filename_input, filename_output, version);
1937       });
1938   m.def(
1939       "_backport_for_mobile_from_buffer",
1940       [](const std::string& buffer_input,
1941          const std::string& filename_output,
1942          const int64_t version) {
1943         std::istringstream in(buffer_input);
1944         return _backport_for_mobile(in, filename_output, version);
1945       });
1946   m.def(
1947       "_backport_for_mobile_to_buffer",
1948       [](const std::string& filename_input, const int64_t version) {
1949         std::ostringstream buffer_output;
1950         bool success =
1951             _backport_for_mobile(filename_input, buffer_output, version);
1952         return success ? py::bytes(buffer_output.str()) : py::bytes("");
1953       });
1954   m.def(
1955       "_backport_for_mobile_from_buffer_to_buffer",
1956       [](const std::string& buffer_input, const int64_t version) {
1957         std::istringstream in(buffer_input);
1958         std::ostringstream buffer_output;
1959         bool success = _backport_for_mobile(in, buffer_output, version);
1960         return success ? py::bytes(buffer_output.str()) : py::bytes("");
1961       });
1962   m.def("_get_model_bytecode_version", [](const std::string& filename) {
1963     return _get_model_bytecode_version(filename);
1964   });
1965   m.def(
1966       "_get_model_extra_files",
1967       [](const std::string& filename, const py::dict& py_extra_files) {
1968         std::optional<at::Device> optional_device;
1969         ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1970         _load_for_mobile(filename, optional_device, cpp_extra_files);
1971         extra_files_to_python(cpp_extra_files, py_extra_files);
1972 
1973         return py_extra_files;
1974       });
1975   m.def(
1976       "_get_model_bytecode_version_from_buffer", [](const std::string& buffer) {
1977         std::istringstream in(buffer);
1978         return _get_model_bytecode_version(in);
1979       });
1980   m.def(
1981       "_get_model_extra_files_from_buffer",
1982       [](const std::string& buffer, const py::dict& py_extra_files) {
1983         std::optional<at::Device> optional_device;
1984         ExtraFilesMap cpp_extra_files = ExtraFilesMap();
1985         std::istringstream in(buffer);
1986         _load_for_mobile(in, optional_device, cpp_extra_files);
1987         extra_files_to_python(cpp_extra_files, py_extra_files);
1988 
1989         return py_extra_files;
1990       });
1991   m.def("_get_mobile_model_contained_types", [](const std::string& filename) {
1992     return _get_mobile_model_contained_types(filename);
1993   });
1994   m.def(
1995       "_get_mobile_model_contained_types_from_buffer",
1996       [](const std::string& buffer) {
1997         std::istringstream in(buffer);
1998         return _get_mobile_model_contained_types(in);
1999       });
2000   m.def("_nn_module_to_mobile", [](const Module& module) {
2001     CompilationOptions options;
2002     return jitModuleToMobile(module, options);
2003   });
  // Expose OperatorInfo to Python with a single read-only attribute,
  // `num_schema_args`.
  py::class_<OperatorInfo>(m, "OperatorInfo")
      .def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
2006   m.def("_get_model_ops_and_info", [](const std::string& filename) {
2007     return _get_model_ops_and_info(filename);
2008   });
2009   m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
2010     std::istringstream in(buffer);
2011     return _get_model_ops_and_info(in);
2012   });
2013   m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
2014     return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
2015   });
2016   m.def(
2017       "_quantize_ondevice_ptq_dynamic",
2018       [](mobile::Module& m, const std::string& method_name) {
2019         mobile::quantization::PTQQuanizationHelper ptq_helper;
2020         ptq_helper.quantize_dynamic(m, method_name);
2021       });
2022 
  // Emit-hook accessors, bound directly to the C++ functions.
  m.def("_jit_set_emit_hooks", setEmitHooks);
  m.def("_jit_get_emit_hooks", getEmitHooks);
  // Calls _clear_python_cu() on the Python compilation unit.
  m.def("_jit_clear_class_registry", []() {
    get_python_cu()->_clear_python_cu();
  });
  // Debug toggles and shape-propagation helpers, bound directly to the C++
  // functions.
  m.def(
      "_debug_set_autodiff_subgraph_inlining",
      debugSetAutodiffSubgraphInlining);
  m.def("_debug_set_fusion_group_inlining", debugSetFusionGroupInlining);
  m.def("_debug_get_fusion_group_inlining", getFusionGroupInlining);
  m.def("_propagate_shapes", _propagate_shapes);
  m.def(
      "_propagate_and_assign_input_shapes", _propagate_and_assign_input_shapes);
  // Zero-argument wrapper around lastExecutedOptimizedGraph(); the trailing
  // string literal is the Python docstring.
  m.def(
      "_last_executed_optimized_graph",
      []() { return lastExecutedOptimizedGraph(); },
      "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
2040   m.def(
2041       "_create_function_from_graph",
2042       [](const std::string& qualname, std::shared_ptr<Graph> graph) {
2043         // TODO this should go in the global Python CU
2044         auto cu = std::make_shared<CompilationUnit>();
2045         c10::QualifiedName name(qualname);
2046         auto fn = cu->create_function(std::move(name), std::move(graph));
2047         return StrongFunctionPtr(std::move(cu), fn);
2048       });
  // Direct binding of ivalue_tags_match.
  m.def("_ivalue_tags_match", ivalue_tags_match);
2050   m.def("_ivalue_debug_python_object", [](py::object py_obj) {
2051     // convert to IValue first, IValue will incref via py::object
2052     IValue pyobj_ivalue = toIValue(std::move(py_obj), PyObjectType::get());
2053     // convert back to PyObject by borrowing the reference, which also
2054     // incref, after the return of this function, IValue is out of scope
2055     // which decref, so the return value is original refcount + 1
2056     py::object ret = toPyObject(pyobj_ivalue);
2057     return ret;
2058   });
  // Direct binding of _jit_debug_module_iterators.
  m.def("_jit_debug_module_iterators", _jit_debug_module_iterators);
2060 
2061   py::class_<testing::FileCheck>(m, "FileCheck")
2062       .def(py::init<>())
2063       .def("check", &testing::FileCheck::check)
2064       .def("check_not", &testing::FileCheck::check_not)
2065       .def("check_same", &testing::FileCheck::check_same)
2066       .def("check_next", &testing::FileCheck::check_next)
2067       .def("check_count", &testing::FileCheck::check_count)
2068       .def("check_dag", &testing::FileCheck::check_dag)
2069       .def(
2070           "check_source_highlighted",
2071           &testing::FileCheck::check_source_highlighted)
2072       .def("check_regex", &testing::FileCheck::check_regex)
2073       .def(
2074           "check_count",
2075           [](testing::FileCheck& f,
2076              const std::string& str,
2077              size_t count,
2078              bool exactly) { return f.check_count(str, count, exactly); },
2079           "Check Count",
2080           py::arg("str"),
2081           py::arg("count"),
2082           py::arg("exactly") = false)
2083       .def(
2084           "run",
2085           [](testing::FileCheck& f, const std::string& str) {
2086             return f.run(str);
2087           })
2088       .def(
2089           "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); })
2090       .def(
2091           "run",
2092           [](testing::FileCheck& f,
2093              const std::string& input,
2094              const std::string& output) { return f.run(input, output); },
2095           "Run",
2096           py::arg("checks_file"),
2097           py::arg("test_file"))
2098       .def(
2099           "run",
2100           [](testing::FileCheck& f, const std::string& input, const Graph& g) {
2101             return f.run(input, g);
2102           },
2103           "Run",
2104           py::arg("checks_file"),
2105           py::arg("graph"));
2106 
2107   m.def(
2108       "_logging_set_logger",
2109       [](logging::LoggerBase* logger) { return logging::setLogger(logger); },
2110       py::return_value_policy::reference);
2111   m.def("_set_graph_executor_optimize", [](bool optimize) {
2112     setGraphExecutorOptimize(optimize);
2113   });
2114 
2115   m.def(
2116       "_get_graph_executor_optimize",
2117       [](std::optional<bool> new_setting = std::nullopt) {
2118         bool old_value = getGraphExecutorOptimize();
2119         if (new_setting) {
2120           setGraphExecutorOptimize(*new_setting);
2121         }
2122         return old_value;
2123       },
2124       py::arg("new_settings") = nullptr);
2125 
  // Direct binding of torch::jit::enableMobileInterfaceCallExport.
  m.def(
      "_enable_mobile_interface_call_export",
      &torch::jit::enableMobileInterfaceCallExport);
2129 
2130   m.def("_create_module_with_type", [](const ClassTypePtr& type) {
2131      return Module(get_python_cu(), type);
2132    }).def("_create_object_with_type", [](const ClassTypePtr& type) {
2133     return Object(get_python_cu(), type);
2134   });
2135 
2136   m.def("_export_opnames", [](Module& sm) {
2137     return debugMakeList(torch::jit::export_opnames(sm));
2138   });
2139 
2140   py::class_<
2141       ConcreteModuleTypeBuilder,
2142       std::shared_ptr<ConcreteModuleTypeBuilder>>(
2143       m, "ConcreteModuleTypeBuilder")
2144       .def(py::init<py::object>())
2145       .def(
2146           "add_constant",
2147           [](ConcreteModuleTypeBuilder& self,
2148              std::string name,
2149              py::object value) {
2150             self.addConstant(std::move(name), std::move(value));
2151           })
2152       .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
2153       .def(
2154           "add_function_attribute",
2155           &ConcreteModuleTypeBuilder::addFunctionAttribute)
2156       .def(
2157           "add_builtin_function",
2158           &ConcreteModuleTypeBuilder::addBuiltinFunction)
2159       .def("add_forward_hook", &ConcreteModuleTypeBuilder::addForwardHook)
2160       .def(
2161           "add_forward_pre_hook", &ConcreteModuleTypeBuilder::addForwardPreHook)
2162       .def("add_module", &ConcreteModuleTypeBuilder::addModule)
2163       .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
2164       .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
2165       .def(
2166           "add_failed_attribute",
2167           &ConcreteModuleTypeBuilder::addFailedAttribute)
2168       .def(
2169           "add_ignored_attribute",
2170           &ConcreteModuleTypeBuilder::addIgnoredAttribute)
2171       .def(
2172           "add_ignored_attributes",
2173           [](ConcreteModuleTypeBuilder& self,
2174              const std::vector<std::string>& names) {
2175             for (auto& name : names) {
2176               self.addIgnoredAttribute(name);
2177             }
2178           })
2179       .def(
2180           "set_module_dict",
2181           [](ConcreteModuleTypeBuilder& self) {
2182             self.setIterableModuleKind(IterableModuleKind::DICT);
2183           })
2184       .def("build", &ConcreteModuleTypeBuilder::build)
2185       .def(
2186           "equals",
2187           [](const ConcreteModuleTypeBuilder& self,
2188              const ConcreteModuleTypeBuilder& other) {
2189             return self.equals(other);
2190           })
2191       .def(
2192           "set_module_list",
2193           [](ConcreteModuleTypeBuilder& self) {
2194             self.setIterableModuleKind(IterableModuleKind::LIST);
2195           })
2196       .def(
2197           "set_parameter_list",
2198           [](ConcreteModuleTypeBuilder& self) {
2199             self.setIterableModuleKind(IterableModuleKind::PARAMLIST);
2200           })
2201       .def("set_parameter_dict", [](ConcreteModuleTypeBuilder& self) {
2202         self.setIterableModuleKind(IterableModuleKind::PARAMDICT);
2203       });
2204 
2205   py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
2206       m, "ConcreteModuleType")
2207       .def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
2208       .def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
2209       .def_static("from_jit_type", &ConcreteModuleType::fromJitType)
2210       .def("get_constants", &ConcreteModuleType::getConstantsPy)
2211       .def("get_attributes", &ConcreteModuleType::getAttributesPy)
2212       .def("get_modules", &ConcreteModuleType::getModulesPy)
2213       .def("dump", &ConcreteModuleType::dump)
2214       .def("is_ignored_attribute", &ConcreteModuleType::isIgnoredAttribute)
2215       .def(
2216           "equals",
2217           [](const ConcreteModuleType& self, const ConcreteModuleType& other) {
2218             return self.equals(other);
2219           })
2220       .def(
2221           "equals",
2222           [](const ConcreteModuleType& self,
2223              const ConcreteModuleTypeBuilder& other) {
2224             return self.equals(other);
2225           })
2226       .def(
2227           "_create_methods_and_properties",
2228           [](std::shared_ptr<ConcreteModuleType> concreteType,
2229              const std::vector<Property>& properties,
2230              const std::vector<ResolutionCallback>& propertyRcbs,
2231              const std::vector<Def>& methodDefs,
2232              const std::vector<ResolutionCallback>& methodRcbs,
2233              const std::vector<FunctionDefaults>& defaults) {
2234             TORCH_INTERNAL_ASSERT(methodDefs.size() == methodRcbs.size());
2235             TORCH_INTERNAL_ASSERT(properties.size() == propertyRcbs.size());
2236 
2237             std::vector<ResolverPtr> methodResolvers, propertyResolvers;
2238             methodResolvers.reserve(methodRcbs.size());
2239             for (auto& callback : methodRcbs) {
2240               methodResolvers.push_back(pythonResolver(callback));
2241             }
2242 
2243             propertyResolvers.reserve(propertyRcbs.size());
2244             for (auto& callback : propertyRcbs) {
2245               propertyResolvers.push_back(pythonResolver(callback));
2246             }
2247 
2248             const auto& selfType =
2249                 concreteType->getJitType()->expect<ClassType>();
2250             const auto& prefix = selfType->name().value();
2251             const auto self = ModuleSelf(std::move(concreteType));
2252             auto cu = selfType->compilation_unit();
2253             cu->define(
2254                 prefix,
2255                 properties,
2256                 propertyResolvers,
2257                 methodDefs,
2258                 methodResolvers,
2259                 &self);
2260             // Stitch in default arguments for each Def if provided
2261             auto defaults_it = defaults.begin();
2262             auto defs_it = methodDefs.begin();
2263             while (defs_it != methodDefs.end()) {
2264               const auto method_name =
2265                   QualifiedName(prefix, (*defs_it).name().name());
2266               auto& method = cu->get_function(method_name);
2267               method.setSchema(getSchemaWithNameAndDefaults(
2268                   defs_it->range(),
2269                   method.getSchema(),
2270                   std::nullopt,
2271                   *defaults_it));
2272               ++defs_it;
2273               ++defaults_it;
2274             }
2275           })
2276       .def(
2277           "_create_hooks",
2278           [](std::shared_ptr<ConcreteModuleType> concreteType,
2279              const std::vector<Def>& hookDefs,
2280              const std::vector<ResolutionCallback>& hookRcbs,
2281              const std::vector<Def>& preHookDefs,
2282              const std::vector<ResolutionCallback>& preHookRcbs) {
2283             TORCH_INTERNAL_ASSERT(hookDefs.size() == hookRcbs.size());
2284             TORCH_INTERNAL_ASSERT(preHookDefs.size() == preHookRcbs.size());
2285 
2286             std::vector<ResolverPtr> hookResolvers, preHookResolvers;
2287 
2288             hookResolvers.reserve(hookRcbs.size());
2289             for (auto& callback : hookRcbs) {
2290               hookResolvers.push_back(pythonResolver(callback));
2291             }
2292 
2293             preHookResolvers.reserve(preHookRcbs.size());
2294             for (auto& callback : preHookRcbs) {
2295               preHookResolvers.push_back(pythonResolver(callback));
2296             }
2297 
2298             const auto& selfType =
2299                 concreteType->getJitType()->expect<ClassType>();
2300             const auto& prefix = selfType->name().value();
2301             const auto self = ModuleSelf(std::move(concreteType));
2302             auto cu = selfType->compilation_unit();
2303             cu->define_hooks(
2304                 prefix,
2305                 hookDefs,
2306                 hookResolvers,
2307                 preHookDefs,
2308                 preHookResolvers,
2309                 &self);
2310           });
2311 
2312   m.def(
2313       "_resolve_type",
2314       [](const std::string& name,
2315          const SourceRange& range,
2316          const ResolutionCallback& rcb) {
2317         return pythonResolver(rcb)->resolveType(name, range);
2318       });
2319   m.def(
2320       "_resolve_type_from_object",
2321       [](const py::object& obj,
2322          const SourceRange& range,
2323          const ResolutionCallback& rcb) {
2324         return pythonResolver(rcb)->resolveTypeFromObject(obj, range);
2325       });
2326 
2327   m.def(
2328       "_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); });
2329 
  // Exposes setShouldUseFormatWithStringTable to Python unchanged (bound
  // directly by function pointer, so its C++ signature is used as-is).
  m.def(
      "_set_should_use_format_with_string_table",
      setShouldUseFormatWithStringTable);
2333 
  // Registers logging::LoggerBase with Python as "LoggerBase" (held by
  // std::shared_ptr). No constructor or methods are exposed here; it is named
  // as the base class in the LockingLogger and NoopLogger bindings below. The
  // py::class_ temporary is intentionally discarded, hence the NOLINT.
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>(
      m, "LoggerBase");
  // Python enum for logging::LockingLogger::AggregationType with members SUM
  // and AVG; export_values() additionally publishes the members in the
  // enclosing module scope.
  py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType")
      .value("SUM", logging::LockingLogger::AggregationType::SUM)
      .value("AVG", logging::LockingLogger::AggregationType::AVG)
      .export_values();
  // Binds logging::LockingLogger as a Python subclass of LoggerBase:
  // default-constructible, with set_aggregation_type and get_counter_val
  // forwarding to the corresponding C++ member functions.
  py::class_<
      logging::LockingLogger,
      logging::LoggerBase,
      std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger")
      .def(py::init<>())
      .def("set_aggregation_type", &logging::LockingLogger::setAggregationType)
      .def("get_counter_val", &logging::LockingLogger::getCounterValue);
  // Binds logging::NoopLogger as a Python subclass of LoggerBase; only the
  // default constructor is exposed.
  py::class_<
      logging::NoopLogger,
      logging::LoggerBase,
      std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger")
      .def(py::init<>());
2353   m.def("_jit_is_script_object", [](const py::object& obj) {
2354     return py::isinstance<Object>(obj);
2355   });
2356 
2357   m.def("_get_file_format", [](const std::string& path) {
2358     switch (getFileFormat(path)) {
2359       case FileFormat::FlatbufferFileFormat:
2360         return "flatbuffer";
2361       case FileFormat::ZipFileFormat:
2362         return "zipfile";
2363       default:
2364         return "invalid";
2365     }
2366   });
2367 
2368   m.def(
2369       "_save_parameters",
2370       [](const std::map<std::string, at::Tensor>& map,
2371          const std::string& filename,
2372          bool use_flatbuffer = false) {
2373         _save_parameters(map, filename, use_flatbuffer);
2374       });
2375 
2376   m.def("_load_mobile_module_from_file", [](const std::string& filename) {
2377     return torch::jit::load_mobile_module_from_file(filename);
2378   });
2379   m.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
2380     auto bytes_copy = copyStr(bytes);
2381     return torch::jit::parse_and_initialize_mobile_module(
2382         bytes_copy, bytes.size());
2383   });
2384   m.def("_load_jit_module_from_file", [](const std::string& filename) {
2385     ExtraFilesMap extra_files = ExtraFilesMap();
2386     return torch::jit::load_jit_module_from_file(filename, extra_files);
2387   });
2388   m.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
2389     auto bytes_copy = copyStr(bytes);
2390     ExtraFilesMap extra_files = ExtraFilesMap();
2391     return torch::jit::parse_and_initialize_jit_module(
2392         bytes_copy, bytes.size(), extra_files);
2393   });
2394   m.def(
2395       "_save_mobile_module",
2396       [](const torch::jit::mobile::Module& module,
2397          const std::string& filename,
2398          const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2399         return torch::jit::save_mobile_module(module, filename, _extra_files);
2400       });
2401   m.def(
2402       "_save_jit_module",
2403       [](const torch::jit::Module& module,
2404          const std::string& filename,
2405          const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2406         return torch::jit::save_jit_module(module, filename, _extra_files);
2407       });
2408   m.def(
2409       "_save_mobile_module_to_bytes",
2410       [](const torch::jit::mobile::Module& module,
2411          const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2412         auto detached_buffer =
2413             torch::jit::save_mobile_module_to_bytes(module, _extra_files);
2414         return py::bytes(
2415             reinterpret_cast<char*>(detached_buffer->data()),
2416             detached_buffer->size());
2417       });
2418   m.def(
2419       "_save_jit_module_to_bytes",
2420       [](const torch::jit::Module& module,
2421          const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
2422         auto detached_buffer =
2423             torch::jit::save_jit_module_to_bytes(module, _extra_files);
2424         return py::bytes(
2425             reinterpret_cast<char*>(detached_buffer->data()),
2426             detached_buffer->size());
2427       });
2428   m.def("_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
2429     py::gil_scoped_acquire acquire;
2430     py::dict result;
2431     mobile::ModuleInfo minfo =
2432         torch::jit::get_module_info_from_flatbuffer(&flatbuffer_content[0]);
2433     result["bytecode_version"] = minfo.bytecode_version;
2434     result["operator_version"] = minfo.operator_version;
2435     result["function_names"] = minfo.function_names;
2436     result["type_names"] = minfo.type_names;
2437     result["opname_to_num_args"] = minfo.opname_to_num_args;
2438     return result;
2439   });
2440 
2441   m.def("_pickle_save", [](const IValue& v) {
2442     auto bytes = torch::jit::pickle_save(v);
2443     return py::bytes(bytes.data(), bytes.size());
2444   });
2445 
2446   m.def("_pickle_load_obj", [](const py::bytes& bytes) {
2447     // https://github.com/pybind/pybind11/issues/2517
2448     std::string buffer = bytes;
2449     return torch::jit::pickle_load_obj(buffer);
2450   });
2451 
  // Finally, run the ScriptDict and ScriptList binding initialisers on the
  // same `module` object.
  initScriptDictBindings(module);
  initScriptListBindings(module);
2454 }
2455 
2456 } // namespace torch::jit
2457