#include #include namespace torch::jit { // [value refinement algorithm] // When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made, // `cond` value carries information (refinements) about the len of `x`. // When `cond` is used as the conditional of an if statement, the information // it carries for its true value can be inserted into the true block // and the same for its false value. // For something like `y = len(x) if len(x) == 1 else 1`, in the true branch // we can replace len(x) with 1 because the true refinements from `len(x) == 1` // will be present in the true block. // Additionally, we can optimize something like: // if len(x) != 4: // raise Exception(...) // return len(x) // Because the true block always throws, whatever refinements exist in the false // block become present in the owning block of the if node. We can also merge // refinements carried by two different booleans across an if node join by // taking the intersections of their refinements. // if cond: // z = len(x) == 4 and len(y) == 5 // else: // z = len(x) == 4 // Here, z's true value will refine the len(x) to 4, but not len(y). // If the code was written as: // if cond: // z = len(x) == 4 and len(y) == 5 // else: // z = False // // Then z's true value would refine x and y, because if z is true it had to have // come from the true block. Code that is written with `and` or `or` will // desugar to something similar. Additionally, any True refinements that were // present on `cond` can also be associated with the if node True output value. // The intersection of the refinements is the Value* which are in both // refinements and are refined to the same length // in an example like: // if cond: // x = len(a) == 4 and len(b) == 5 // else: // x = len(a) == 4 // For the x output of the node we take the intersection between // the refinements stored on each block output, which will result // in only the refinement of len(a) == 4 ListRefinement intersectRefinements( const ListRefinement& ref1, const ListRefinement& ref2) { ListRefinement out; for (const auto& pair : ref1) { auto val2 = ref2.find(pair.first); if (val2 != ref2.end() && val2->second == pair.second) { out[pair.first] = pair.second; } } return out; } // To union, just take all refinements from both inputs. We do not need to worry // about len refinements disagreeing because a path like `if len(x) == 4 and // len(x) == 5` will never be taken // in an example like: // if len(a) == 5: // x = len(b) == 4 // else: // x = False // For the output x Value, if is true then the refinements present in the true // block must also be true, so we take the union of `len(a) == 5` and len(b) == // 4` and assign them to true refinements of the output x value. This is a very // common pattern in desugaring of `and` or `or` boolean expressions ListRefinement unionRefinements( const ListRefinement& ref1, const ListRefinement& ref2) { ListRefinement out = ref1; out.insert(ref2.begin(), ref2.end()); return out; } void joinIfRefinements( Node* if_node, std::unordered_set& throwing_blocks, ListRefinement& curr_block_refinements, ListRefinement& true_block_refinements, ListRefinement& false_block_refinements, std::unordered_map& boolean_value_refinements) { IfView if_n(if_node); Block* b = if_node->owningBlock(); bool true_block_throws = throwing_blocks.count(if_n.thenBlock()); bool false_block_throws = throwing_blocks.count(if_n.elseBlock()); // if one block throws, the refinements for the other block // become present in the current block, and all bool outputs // of the if node take their refinements from non throwing block // output if (true_block_throws || false_block_throws) { if (true_block_throws && false_block_throws) { throwing_blocks.insert(b); return; } if (true_block_throws) { curr_block_refinements.insert( false_block_refinements.begin(), false_block_refinements.end()); } else { curr_block_refinements.insert( true_block_refinements.begin(), true_block_refinements.end()); } Block* non_throwing_block = true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0); for (const auto i : c10::irange(if_n.outputs().size())) { if (boolean_value_refinements.count( non_throwing_block->outputs().at(i))) { boolean_value_refinements[if_node->outputs().at(i)] = boolean_value_refinements[non_throwing_block->outputs().at(i)]; } } return; } for (const auto i : c10::irange(if_n.outputs().size())) { if (!(if_n.outputs().at(i)->type() == BoolType::get())) { return; } Value* true_v = if_n.thenOutputs().at(i); Value* false_v = if_n.elseOutputs().at(i); if (!boolean_value_refinements.count(true_v) && !boolean_value_refinements.count(false_v) && !constant_as(true_v) && !constant_as(false_v)) { return; } // if either block has a constant bool output, e.g. `true` on the // true block, then for the `false` value we can take the false // refinements present on the false block and from the other block // output value bc if the output is false it had to have come from the // false block. if len(a) == 5: // x = len(b) == 4 // else: // x = False // if x is true, then we know both len(a) == 5 and len(b) == 4 // // if neither block has a constant bool value, we just take the // intersection of the refinements from boolean outputs. // if cond: // x = len(a) == 4 and len(b) == 5 // else: // x = len(a) == 4 // here, we know if x is true, then len(a) == 4, but not len(b) // == 5, because that refinement is not present in the true block. // TODO: could also take intersection of refinements present in // both blocks, but it's not a real use case. // boolean_value_refinements[value] is safe to access because // BooleanRefinementMapping has a default constructor BooleanRefinementMapping out; if (auto maybe_bool = constant_as(true_v)) { if (*maybe_bool) { out = BooleanRefinementMapping::FalseRefinements(unionRefinements( boolean_value_refinements[false_v].false_refine(), false_block_refinements)); } else { out = BooleanRefinementMapping::TrueRefinements(unionRefinements( boolean_value_refinements[false_v].true_refine(), false_block_refinements)); } } else if (auto maybe_bool = constant_as(false_v)) { if (*maybe_bool) { out = BooleanRefinementMapping::FalseRefinements(unionRefinements( boolean_value_refinements[true_v].false_refine(), true_block_refinements)); } else { out = BooleanRefinementMapping::TrueRefinements(unionRefinements( boolean_value_refinements[true_v].true_refine(), true_block_refinements)); } } else if ( boolean_value_refinements.count(true_v) && boolean_value_refinements.count(false_v)) { out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping( boolean_value_refinements[false_v]); } boolean_value_refinements[if_n.outputs().at(i)] = out; } } bool handleCommonRefinentOperators( Node* n, std::unordered_set& throwing_blocks, std::unordered_map& info) { if (n->kind() == prim::RaiseException) { throwing_blocks.insert(n->owningBlock()); return true; } if (n->kind() == aten::__not__ && n->inputs().at(0)->type()->cast()) { // __not__(inp) -> reverse refinements if (info.count(n->input())) { auto& input_ref = info[n->input()]; info[n->output()] = BooleanRefinementMapping( input_ref.false_refine(), input_ref.true_refine()); } return true; } if (n->matches("aten::eq(bool a, bool b) -> bool") || (n->matches("aten::ne(bool a, bool b) -> bool"))) { for (size_t const_index : {0, 1}) { if (n->input(const_index)->node()->kind() != prim::Constant) { continue; } auto const_input = constant_as(n->input(const_index)).value(); auto non_const_input = n->input(1 - const_index); if (!info.count(non_const_input)) { continue; } // value == False / value != True -> equivalent to __not__ value // value == True / value != False -> equivalent to value auto& input_ref = info[non_const_input]; if ((!const_input && n->kind() == aten::eq) || (const_input && n->kind() == aten::ne)) { info[n->output()] = BooleanRefinementMapping( input_ref.false_refine(), input_ref.true_refine()); } else { info[n->output()] = BooleanRefinementMapping( input_ref.true_refine(), input_ref.false_refine()); } } return true; } return false; } } // namespace torch::jit