#include #include #include #include // removed after using simple type_resolver/obj_loader #include #include #include #include // removed after using simple type_resolver/obj_loader #include #include #include #include #include #include #include #include namespace c10 { TypePtr parseType(const std::string& pythonStr); } // namespace c10 namespace torch::jit { using caffe2::serialize::FileAdapter; using caffe2::serialize::IStreamAdapter; using caffe2::serialize::PyTorchStreamReader; using caffe2::serialize::ReadAdapterInterface; c10::IValue readArchive( const std::string& archive_name, PyTorchStreamReader& stream_reader) { std::optional device; std::shared_ptr compilation_unit = std::make_shared(); // TODO (T90180710): Simplify type_resolver and obj_loader when getting // bytecode version from model auto type_resolver = [&](const c10::QualifiedName& qn) { return typeResolverMobile(qn, compilation_unit); }; std::shared_ptr mobile_compilation_unit = std::make_shared(); auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) { return objLoaderMobile(type, input, *mobile_compilation_unit); }; bool bytecode_tensor_in_constants_archive = (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader)); auto ivalues = torch::jit::readArchiveAndTensors( archive_name, /*pickle_prefix=*/"", /*tensor_prefix=*/ bytecode_tensor_in_constants_archive ? "constants/" : "", type_resolver, obj_loader, device, stream_reader, nullptr); return ivalues; } std::vector get_bytecode_ivalues(PyTorchStreamReader& reader) { return std::move(*readArchive("bytecode", reader).toTuple()).elements().vec(); } /********************** Bytecode **********************/ // Forward declare uint64_t _get_model_bytecode_version( const std::vector& bytecode_ivalues); static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size); uint64_t _get_model_bytecode_version(std::istream& in) { auto orig_pos = in.tellg(); in.seekg(0, in.beg); auto [data, size] = get_stream_content(in); in.seekg(orig_pos, in.beg); return _get_model_bytecode_version_from_bytes(data.get(), size); } uint64_t _get_model_bytecode_version(const std::string& filename) { std::ifstream ifile(filename); return _get_model_bytecode_version(ifile); } uint64_t _get_model_bytecode_version( const std::shared_ptr& rai) { auto [data, size] = get_rai_content(rai.get()); return _get_model_bytecode_version_from_bytes(data.get(), size); } static uint64_t _get_model_bytecode_version_zip( std::shared_ptr rai) { if (!check_zip_file(rai)) { TORCH_CHECK( false, "Failed to open .ptl file please ensure the model was exported for mobile"); } PyTorchStreamReader reader(std::move(rai)); auto bytecode_values = get_bytecode_ivalues(reader); return _get_model_bytecode_version(bytecode_values); } uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) { TORCH_CHECK(data != nullptr, "Pointer to bytes is null."); TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format"); auto format = getFileFormat(data); switch (format) { case FileFormat::FlatbufferFileFormat: { return get_bytecode_version_from_bytes(data); } case FileFormat::ZipFileFormat: { auto rai = std::make_unique(data, size); auto version = _get_model_bytecode_version_zip(std::move(rai)); return version; } default: TORCH_CHECK(false, "Unrecognized data format"); } } uint64_t _get_model_bytecode_version( const std::vector& bytecode_ivalues) { if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) { int64_t model_version = bytecode_ivalues[0].toInt(); TORCH_CHECK( model_version > 0, "Expected model bytecode version > 0 got ", model_version); return static_cast(model_version); } TORCH_CHECK(false, "Failed to get bytecode version."); } /********************** Operator Version **********************/ uint64_t _get_model_operator_version( PyTorchStreamReader& reader); // Forward Declare uint64_t _get_model_operator_version(std::istream& in) { std::unique_ptr rai = std::make_unique(&in); return _get_model_operator_version(std::move(rai)); } uint64_t _get_model_operator_version(const std::string& filename) { std::unique_ptr rai = std::make_unique(filename); return _get_model_operator_version(std::move(rai)); } uint64_t _get_model_operator_version( std::shared_ptr rai) { if (!check_zip_file(rai)) { TORCH_CHECK( false, "Failed to open .ptl file please ensure the model was exported for mobile"); } PyTorchStreamReader reader(std::move(rai)); return _get_model_operator_version(reader); } uint64_t _get_model_operator_version(PyTorchStreamReader& reader) { return reader.version(); } /********************** Operators and Info **********************/ // Forward declare std::unordered_map _get_model_ops_and_info( std::vector bytecode_ivalues); std::unordered_map _get_model_ops_and_info( std::istream& in) { std::unique_ptr rai = std::make_unique(&in); return _get_model_ops_and_info(std::move(rai)); } std::unordered_map _get_model_ops_and_info( const std::string& filename) { std::unique_ptr rai = std::make_unique(filename); return _get_model_ops_and_info(std::move(rai)); } std::unordered_map _get_model_ops_and_info( std::shared_ptr rai) { if (!check_zip_file(rai)) { TORCH_WARN("Failed to open zip file for model ops."); return std::unordered_map{}; } PyTorchStreamReader reader(std::move(rai)); auto bytecode_values = get_bytecode_ivalues(reader); return _get_model_ops_and_info(bytecode_values); } /* A function to retrieve the root (top level) operators of a model and their * corresponding compatibility info. These root operators can call other * operators within them (traced ops), and a root op can call many different * traced ops depending on internal code paths in the root op. These traced ops * are not returned by this function. Those operators are abstracted into the * runtime as an implementation detail (and the traced ops themselves can also * call other operators) making retrieving them difficult and their value from * this api negligible since they will differ between which runtime version the * model is run on. Because of this, there is a false positive this api can't * prevent in a compatibility usecase. All the root ops of a model are present * in a target runtime, but not all the traced ops are which prevents a model * from being able to run. **/ std::unordered_map _get_model_ops_and_info( std::vector bytecode_ivalues) { constexpr uint64_t min_version_with_schema = 6; if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) { TORCH_WARN( "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it"); } std::unordered_map result; if (bytecode_ivalues.empty()) { TORCH_WARN("Failed to get model ops and info."); return result; } // loop over all the functions in the bytecode for (const auto i : c10::irange(1, bytecode_ivalues.size())) { // descend to the operators list const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1]; auto operators = operators_tuple.toTupleRef().elements()[1]; for (auto& op_tuple : operators.toTupleRef().elements()) { const auto& op = op_tuple.toTupleRef().elements(); // grab name std::string op_name = op.at(0).toStringRef(); std::string op_overload_name = op.at(1).toStringRef(); if (!op_overload_name.empty()) { op_name.append("."); op_name.append(op_overload_name); } // grab schema size if (op.size() > 2) { result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()}); } else { // no schema information use default result.emplace(op_name, OperatorInfo{}); } } } return result; } /********************** Get Type Table **********************/ // Forward declare std::unordered_set _get_mobile_model_contained_types( const std::vector& bytecode_ivalues); std::unordered_set _get_mobile_model_contained_types( std::istream& in) { std::unique_ptr rai = std::make_unique(&in); return _get_mobile_model_contained_types(std::move(rai)); } std::unordered_set _get_mobile_model_contained_types( const std::string& filename) { std::unique_ptr rai = std::make_unique(filename); return _get_mobile_model_contained_types(std::move(rai)); } std::unordered_set _get_mobile_model_contained_types( std::shared_ptr rai) { if (!check_zip_file(rai)) { TORCH_CHECK( false, "Failed to open .ptl file please ensure the model was exported for mobile"); } PyTorchStreamReader reader(std::move(rai)); auto bytecode_values = get_bytecode_ivalues(reader); return _get_mobile_model_contained_types(bytecode_values); } // Get deduplicate type table given bytecode, and each string is a atomic type, // like str, Tensor and etc. For example, // input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" // output: {Dict, int, Tuple, Tensor} std::unordered_set _get_mobile_model_contained_types( const std::vector& bytecode_ivalues) { std::unordered_set contained_types; // To avoid parsing same type twice, declare $parsed_type_names_records and // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as // the hash to record which types are parsed. std::unordered_set parsed_type_names_records; for (const auto i : c10::irange(1, bytecode_ivalues.size())) { const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); auto type_table_tuple = method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE]; const auto& type_table = type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements(); // type_table is a list of IValue, and each IValue is a string, // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" std::vector type_name_list; for (const auto& type_definition : type_table) { std::unordered_set type_tokens; std::string type_name = type_definition.toStringRef(); type_name_list.emplace_back(type_name); } at::TypeParser parser(type_name_list); parser.parseList(); contained_types = parser.getContainedTypes(); } return contained_types; } /********************** Compatibility Checker **********************/ ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) { std::unique_ptr rai = std::make_unique(&in); return get(std::move(rai)); } ModelCompatibilityInfo ModelCompatibilityInfo::get( const std::string& filename) { std::unique_ptr rai = std::make_unique(filename); return get(std::move(rai)); } ModelCompatibilityInfo ModelCompatibilityInfo::get( std::shared_ptr rai) { if (!check_zip_file(rai)) { TORCH_CHECK( false, "Failed to open zip file for model compatibility information"); } PyTorchStreamReader reader(std::move(rai)); std::vector bytecode_values = get_bytecode_ivalues(reader); uint64_t model_bytecode_version = _get_model_bytecode_version(bytecode_values); auto model_info = _get_model_ops_and_info(bytecode_values); std::unordered_set type_table = _get_mobile_model_contained_types(bytecode_values); uint64_t operator_version = _get_model_operator_version(reader); return ModelCompatibilityInfo{ model_bytecode_version, model_info, type_table, operator_version}; } ModelCompatCheckResult is_compatible( RuntimeCompatibilityInfo runtime_info, const ModelCompatibilityInfo& model_info) { ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}}; // Check that the models bytecode version is less than or equal to // kMaxSupportedBytecodeVersion from the runtime if (model_info.bytecode_version > runtime_info.min_max_supported_bytecode_version.second) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "model bytecode version " << model_info.bytecode_version << "is greater than the max supported bytecode version in runtimes " << runtime_info.min_max_supported_bytecode_version.second; result.errors.emplace_back(s.str()); } else if ( model_info.bytecode_version < runtime_info.min_max_supported_bytecode_version.first) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "model bytecode version " << model_info.bytecode_version << "is less than the minimum supported bytecode version in runtime " << runtime_info.min_max_supported_bytecode_version.first; result.errors.emplace_back(s.str()); } std::unordered_set supported_type = runtime_info.supported_types; // Check type table for (const auto& type_name : model_info.type_table) { if (supported_type.find(type_name) == supported_type.end()) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "Primitive type: '" << type_name << "' is not supported in current runtime"; result.errors.push_back(s.str()); } } // Check operators std::unordered_map operator_info = model_info.operator_info; for (auto const& op : operator_info) { std::string op_name = op.first; OperatorInfo model_op_info = op.second; // Check if operator not present in runtime if (runtime_info.operator_info.find(op_name) == runtime_info.operator_info.end()) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "Operator '" << op_name << "' missing from runtime (not found)"; result.errors.push_back(s.str()); } else { OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name); // If the runtime op has no schema information its a false alarm and isn't // actually useable if (!runtime_op_info.num_schema_args.has_value()) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "Operator '" << op_name << "' missing from runtime (missing schema)"; result.errors.push_back(s.str()); } else { // Check if the model operator has schema information. If it doesn't // then the model is from a bytecode version < 6 and we are done. If the // model has more args than the runtime, then the runtime can't know // what to do so we aren't compatible. If the runtime has more args than // the model then we can just use default values and be fine. if (model_op_info.num_schema_args.has_value() && (model_op_info.num_schema_args.value() > runtime_op_info.num_schema_args.value())) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "Operator schema for'" << op_name << "' has " << model_op_info.num_schema_args.value() << " args in model but only " << runtime_op_info.num_schema_args.value() << " in the runtime"; result.errors.push_back(s.str()); } } } } // Check Operator Versions if (model_info.operator_version < runtime_info.min_max_supported_opperator_versions.first || model_info.operator_version > runtime_info.min_max_supported_opperator_versions.second) { result.status = ModelCompatibilityStatus::ERROR; std::ostringstream s; s << "Model Operator Version " << model_info.operator_version << "is not within supported version range of the runtime " << runtime_info.min_max_supported_opperator_versions.first << " to " << runtime_info.min_max_supported_opperator_versions.second; result.errors.push_back(s.str()); } return result; } } // namespace torch::jit