#include #include #include #include #include namespace torch::jit { namespace { static constexpr int defaultPrecision = 6; // IValue tags are intentionally private, so we need additional logic to cast // the IValue type to the specified format. void addFormattedArg( char key, const IValue& ival, std::stringstream& ss, int precision = defaultPrecision) { // TODO: Implement precision-based formatting std::stringstream tmp; switch (key) { case 'd': case 'i': TORCH_CHECK( ival.isScalar(), "%", key, " requires a number for formatting, but got ", ival.tagKind()); if (ival.isInt()) { ss << ival.toInt(); } else { ss << static_cast(ival.toDouble()); } break; case 'e': case 'E': TORCH_CHECK( ival.isScalar(), "%", key, " requires a number for formatting, but got ", ival.tagKind()); tmp << std::setprecision(precision) << std::scientific; if (key == 'E') { tmp << std::uppercase; } if (ival.isInt()) { tmp << static_cast(ival.toInt()); } else { tmp << static_cast(ival.toDouble()); } ss << tmp.str(); break; case 'f': case 'F': TORCH_CHECK( ival.isScalar(), "%", key, " requires a number for formatting, but got ", ival.tagKind()); tmp << std::setprecision(precision) << std::fixed; if (ival.isInt()) { tmp << static_cast(ival.toInt()); } else { tmp << static_cast(ival.toDouble()); } ss << tmp.str(); break; case 'c': TORCH_CHECK( ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1), "%", key, " requires an int or char for formatting, but got ", ival.tagKind()); if (ival.isInt()) { ss << static_cast(ival.toInt()); } else { ss << ival.toStringRef(); } break; case 's': if (ival.isString()) { ss << ival.toStringRef(); } else { ss << ival; } break; default: TORCH_CHECK( false, "The specifier %", key, " is not supported in TorchScript format strings"); } } } // namespace void tupleUnpack(Stack& stack) { auto tuple = pop(stack).toTuple(); stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end()); } void format(Stack& stack, size_t num_inputs) { TORCH_CHECK( num_inputs > 0 && num_inputs <= stack.size(), "Invalid number of inputs for format string: ", num_inputs); // static const std::regex unsupported_options("\\{(.*?)\\}"); auto format = peek(stack, 0, num_inputs).toStringRef(); // // Temporally comment out the warning message because of // // "StdRegexIsAwful" internal Lint error, to prevent sev // // of std::regex from PT mobile. // if (std::regex_search(format, unsupported_options)) { // TORCH_WARN("Format options are not supported."); // } auto args = last(stack, num_inputs - 1); std::stringstream ss; for (size_t begin = 0, used_args = 0; true; ++used_args) { size_t loc = format.find("{}", begin); if (loc == std::string::npos) { ss << format.substr(begin); break; } ss << format.substr(begin, loc - begin); if (used_args >= args.size()) { AT_ERROR("Too few arguments for format string: ", format); } ss << args[used_args]; begin = loc + 2; } drop(stack, num_inputs); push(stack, ss.str()); } void einsum(Stack& stack, size_t num_inputs) { TORCH_CHECK( num_inputs >= 2, "einsum(): must specify the equation string and at least one operand, ", "or at least one operand and its subscripts list"); const auto args = last(stack, num_inputs); // Convert the subscript list format which is an interleaving of operand and // its subscripts list with an optional output subscripts list at the end // (see documentation for more details on this) to the equation string // format by creating the equation string from the subscripts list and // grouping the input operands into a tensorlist (List[Tensor]). std::stringstream ss; auto parse_sublist = [&ss](const c10::List& l, size_t arg_num) { for (const auto i : c10::irange(l.size())) { TORCH_CHECK( l[i] >= 0 && l[i] < 52, "einsum(): expected subscript ", i, " in argument ", arg_num, " to be within the range [0, 52), but got ", l[i]); if (l[i] < 26) { ss << static_cast(l[i] + 'A'); } else { ss << static_cast(l[i] - 26 + 'a'); } } }; // Parse subscripts for input operands for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) { TORCH_CHECK( args[i].isIntList(), "einsum(): expected List[int] in argument ", i, ", but got ", args[i].type()->repr_str()); parse_sublist(args[i].toIntList(), i); if (i + 2 < num_inputs) { ss << ','; } } // Parse optional output subscripts (provided if #args is odd) if (num_inputs % 2 == 1) { TORCH_CHECK( args.back().isIntList(), "einsum(): expected List[int] in argument ", num_inputs - 1, ", but got ", args.back().type()->repr_str()); ss << "->"; parse_sublist(args.back().toIntList(), num_inputs - 1); } const auto equation = ss.str(); std::vector operands; // Parse input operands const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs; for (auto i = decltype(num_inputs){0}; i < end; i += 2) { TORCH_CHECK( args[i].isTensor(), "einsum(): expected Tensor in argument ", i, ", but got ", args[i].type()->repr_str()); operands.emplace_back(args[i].toTensor()); } drop(stack, num_inputs); push(stack, at::einsum(equation, operands)); } void percentFormat(Stack& stack, size_t num_inputs) { auto format_str = peek(stack, 0, num_inputs).toStringRef(); auto args = last(stack, num_inputs - 1)[0]; size_t args_size = 1; // assumed size if (args.isTuple()) { args_size = args.toTupleRef().elements().size(); } std::stringstream ss; size_t used_args = 0; size_t begin = 0; while (true) { size_t percent_idx = format_str.find('%', begin); if (percent_idx == std::string::npos) { ss << format_str.substr(begin); break; } size_t format_idx = percent_idx + 1; TORCH_CHECK( percent_idx < format_str.length() - 1, "Incomplete format specifier"); ss << format_str.substr(begin, percent_idx - begin); if (format_str.at(format_idx) == '%') { ss << '%'; begin = percent_idx + 2; // skip the `%` and the format specifier continue; } TORCH_CHECK(used_args < args_size, "Too few arguments for format string"); char key = format_str.at(format_idx); IValue arg; if (args.isTuple()) { arg = args.toTupleRef().elements()[used_args]; } else { arg = args; } addFormattedArg(key, arg, ss); begin = percent_idx + 2; ++used_args; } TORCH_CHECK(used_args == args_size, "Too many arguments for format string"); drop(stack, num_inputs); push(stack, ss.str()); } void listUnpack(Stack& stack, size_t num_outputs) { auto list = pop(stack).toList(); TORCH_CHECK( list.size() == num_outputs, "Expected ", num_outputs, " elements in a list but found ", list.size()); stack.insert(stack.end(), list.begin(), list.end()); } void tupleConstruct(Stack& stack, size_t num_inputs) { if (num_inputs > stack.size()) { TORCH_CHECK(false, "Invalid number of inputs: ", num_inputs); } switch (num_inputs) { case 0: stack.emplace_back(c10::ivalue::Tuple::create()); break; case 1: stack.back() = c10::ivalue::Tuple::create(std::move(stack.back())); break; case 2: { auto tuple = c10::ivalue::Tuple::create( std::move(stack[stack.size() - 2]), std::move(stack[stack.size() - 1])); stack.pop_back(); stack.back() = std::move(tuple); break; } case 3: { auto tuple = c10::ivalue::Tuple::create( std::move(stack[stack.size() - 3]), std::move(stack[stack.size() - 2]), std::move(stack[stack.size() - 1])); stack.pop_back(); stack.pop_back(); stack.back() = std::move(tuple); break; } default: { std::vector elems{ std::make_move_iterator(stack.end() - num_inputs), std::make_move_iterator(stack.end())}; drop(stack, num_inputs - 1); stack.back() = c10::ivalue::Tuple::create(std::move(elems)); break; } } } void namedTupleConstruct( Stack& stack, c10::TypePtr tuple_type, size_t num_inputs) { std::vector elems{ std::make_move_iterator(stack.end() - num_inputs), std::make_move_iterator(stack.end())}; drop(stack, num_inputs); push( stack, c10::ivalue::Tuple::createNamed(std::move(elems), std::move(tuple_type))); } void listConstruct( Stack& stack, const c10::Type& list_type, size_t num_inputs) { // Structuring the implementation this way allows NRVO to avoid // move-constructing vals on its way onto the stack. Moving a List // isn't free. auto makeList = [](Stack& stack, const c10::Type& list_type, size_t num_inputs) { c10::List vals(list_type.containedType(0)); vals.reserve(num_inputs); for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { vals.push_back(std::move(stack[i])); } drop(stack, num_inputs); return vals; }; stack.emplace_back(makeList(stack, list_type, num_inputs)); } void dictConstruct( Stack& stack, const c10::Type& dict_type, size_t num_inputs) { auto vals = c10::impl::GenericDict( dict_type.containedType(0), dict_type.containedType(1)); vals.reserve(num_inputs / 2); // loop from the bottom of the stack to ensure the dictConstruct preserve // the inputs order. auto inputs = last(stack, num_inputs); for (size_t i = 0; i < num_inputs; i += 2) { auto key = inputs[i]; auto val = inputs[i + 1]; vals.insert_or_assign(std::move(key), std::move(val)); } drop(stack, num_inputs); push(stack, std::move(vals)); } void createObject( Stack& stack, const at::ClassTypePtr& type, bool as_weak_ref) { if (as_weak_ref) { c10::WeakTypePtr weak(type->compilation_unit(), type); auto userObj = c10::ivalue::Object::create( c10::WeakOrStrongTypePtr(weak), type->numAttributes()); push(stack, std::move(userObj)); } else { auto userObj = c10::ivalue::Object::create( c10::StrongTypePtr(type->compilation_unit(), type), type->numAttributes()); push(stack, std::move(userObj)); } } void isinstance(Stack& stack, at::ArrayRef types) { at::TypePtr ty = pop(stack).type(); for (const at::TypePtr& candidate : types) { if (ty->isSubtypeOf(*candidate)) { push(stack, true); return; } } push(stack, false); } void tupleSlice(Stack& stack, size_t begin, size_t end) { auto tuple = pop(stack).toTuple(); push( stack, c10::ivalue::Tuple::create( tuple->elements().asArrayRef().slice(begin, end - begin))); } void dequantize(Stack& stack) { auto iv = pop(stack); if (iv.isTuple()) { auto tuple = iv.toTuple(); const auto& elems = tuple->elements(); std::vector output_elems; output_elems.reserve(elems.size()); for (const auto& elem : elems) { if (elem.isTensor()) { output_elems.emplace_back(at::dequantize(elem.toTensor())); } else { output_elems.emplace_back(elem); } } push(stack, c10::ivalue::Tuple::create(std::move(output_elems))); } else if (iv.isTensorList()) { auto elems = iv.toTensorList(); auto output_list = c10::impl::GenericList(elems.elementType()); for (auto&& elem : elems) { output_list.emplace_back(at::dequantize(elem)); } push(stack, std::move(output_list)); } else { TORCH_CHECK( false, "Unsupported type in dequantize, only List[Tensor] and \ Tuple[Tensor or other types] are supported, got type:", toString(iv.type())); } } } // namespace torch::jit