/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/map_inliner.h"

#include <memory>
#include <string>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// MapInlinerVisitor traverses the HLO computation and inlines maps.
class MapInlinerVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit MapInlinerVisitor(HloComputation* computation)
      : computation_(computation) {}

  // Default visitor action is to do nothing and return OK.
  Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
    return Status::OK();
  }

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.
  StatusOr<bool> Run(HloComputation* computation);

 private:
  // Current HloComputation instance the MapInlinerVisitor is traversing.
  HloComputation* computation_;

  // Whether a map was inlined.
  bool changed_ = false;
};

StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
  changed_ = false;
  computation_ = computation;
  TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
  return changed_;
}

Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
  HloComputation* function = map->to_apply();
  HloInstruction& root = *function->root_instruction();
  // Only inline functions that are a single operation until a better
  // profitability model for inlining is defined.
  if (hlo_query::AllOperandsAreParameters(root)) {
    if (root.opcode() == HloOpcode::kFusion ||
        root.opcode() == HloOpcode::kTrace) {
      // Cloning is not supported for these instructions.
      return Status::OK();
    }
    VLOG(10) << "inlining map({X ... Y}, op) => op(X ... Y) with function "
             << root.ToShortString();
    if (root.opcode() == HloOpcode::kParameter) {
      // If the root is a parameter, then use the corresponding operand as the
      // result of the computation.
      TF_RETURN_IF_ERROR(
          map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
      TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
    } else if (root.opcode() == HloOpcode::kConstant) {
      // If the root is a constant, its shape may differ from the map shape, so
      // broadcast the constant to the map shape instead of cloning it
      // directly.
      //
      // The constant lives in the embedded computation and needs to be
      // recreated as part of the computation that the broadcast is inserted
      // into.
      HloInstruction* constant = computation_->AddInstruction(root.Clone());
      HloInstruction* placed_instruction = computation_->AddInstruction(
          HloInstruction::CreateBroadcast(map->shape(), constant, {}));
      TF_RETURN_IF_ERROR(
          computation_->ReplaceInstruction(map, placed_instruction));
    } else {
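      // Otherwise the root is a single operation whose operands are all
      // parameters: clone it into the caller, substituting the map's operands
      // for the corresponding parameters of the mapped computation.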
      std::vector<HloInstruction*> params;
      for (int64 o = 0; o < root.operands().size(); o++) {
        params.push_back(map->operands()[root.operand(o)->parameter_number()]);
      }
      HloInstruction* placed_instruction = computation_->AddInstruction(
          root.CloneWithNewOperands(map->shape(), params));
      TF_RETURN_IF_ERROR(
          computation_->ReplaceInstruction(map, placed_instruction));
    }
    changed_ = true;
    return Status::OK();
  }

  return Status::OK();
}

StatusOr<bool> MapInliner::Run(HloModule* module) {
  MapInlinerVisitor visitor(/*computation=*/nullptr);
  bool changed = false;
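  // Visit every computation in the module; Run() resets the visitor's
  // per-computation state before each traversal.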
  for (HloComputation* computation : module->computations()) {
    TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
    changed |= computation_changed;
  }
  return changed;
}

}  // namespace xla