• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/map_inliner.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_query.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
35 // MapInlinerVisitor traverses the HLO computation and inlines maps.
36 class MapInlinerVisitor : public DfsHloVisitorWithDefault {
37  public:
MapInlinerVisitor(HloComputation * computation)38   explicit MapInlinerVisitor(HloComputation* computation)
39       : computation_(computation) {}
40 
41   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)42   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
43     return Status::OK();
44   }
45 
46   Status HandleMap(HloInstruction* map) override;
47 
48   // Runs the visitor on a computation.
49   StatusOr<bool> Run(HloComputation* computation);
50 
51  private:
52   // Current HloComputation instance the MapInlinerVisitor is traversing.
53   HloComputation* computation_;
54 
55   // Whether algebraic simplification has occurred.
56   bool changed_ = false;
57 };
58 
Run(HloComputation * computation)59 StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
60   changed_ = false;
61   computation_ = computation;
62   TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
63   return changed_;
64 }
65 
HandleMap(HloInstruction * map)66 Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
67   HloComputation* function = map->to_apply();
68   HloInstruction& root = *function->root_instruction();
69   // Only inlining functions that are simply a single operation until a better
70   // profitability model for inlining is defined.
71   if (hlo_query::AllOperandsAreParameters(root)) {
72     if (root.opcode() == HloOpcode::kFusion ||
73         root.opcode() == HloOpcode::kTrace) {
74       // Cloning not supported for these instructions.
75       return Status::OK();
76     }
77     VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
78              << root.ToShortString();
79     if (root.opcode() == HloOpcode::kParameter) {
80       // If the root is a parameter, then use the corresponding operand as the
81       // result of the computation.
82       TF_RETURN_IF_ERROR(
83           map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
84       TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
85     } else if (root.opcode() == HloOpcode::kConstant) {
86       // If the input is a constant then the shape of the constant could be
87       // different than the map shape. Hence, a broadcast is needed, else the
88       // cloned operand with new shape and operands work.
89       //
90       // The constant is in an embedded computation and needs to be recreated
91       // as part of the computation that the broadcast is inserted into.
92       HloInstruction* constant = computation_->AddInstruction(root.Clone());
93       HloInstruction* placed_instruction = computation_->AddInstruction(
94           HloInstruction::CreateBroadcast(map->shape(), constant, {}));
95       TF_RETURN_IF_ERROR(
96           computation_->ReplaceInstruction(map, placed_instruction));
97     } else {
98       std::vector<HloInstruction*> params;
99       for (int64 o = 0; o < root.operands().size(); o++) {
100         params.push_back(map->operands()[root.operand(o)->parameter_number()]);
101       }
102       HloInstruction* placed_instruction = computation_->AddInstruction(
103           root.CloneWithNewOperands(map->shape(), params));
104       TF_RETURN_IF_ERROR(
105           computation_->ReplaceInstruction(map, placed_instruction));
106     }
107     changed_ = true;
108     return Status::OK();
109   }
110 
111   return Status::OK();
112 }
113 
Run(HloModule * module)114 StatusOr<bool> MapInliner::Run(HloModule* module) {
115   MapInlinerVisitor visitor(/*computation=*/nullptr);
116   bool changed = false;
117   for (HloComputation* computation : module->computations()) {
118     TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
119     changed |= computation_changed;
120   }
121   return changed;
122 }
123 
124 }  // namespace xla
125