• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17 
18 #include <queue>
19 
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/compiler/xla/util.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace xla {
32 
TupleSimplifier(bool exclude_entry_computation)33 TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
34     : exclude_entry_computation_(exclude_entry_computation) {}
35 
RemoveWholeTuple(HloInstruction * tuple)36 StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
37   bool changed = false;
38   HloInstruction* top_tuple = nullptr;
39   bool can_simplify = true;
40   for (int64 operand_number = 0; operand_number < tuple->operand_count();
41        ++operand_number) {
42     HloInstruction* operand = tuple->mutable_operand(operand_number);
43     if (operand->opcode() != HloOpcode::kGetTupleElement ||
44         operand->tuple_index() != operand_number) {
45       can_simplify = false;
46       break;
47     }
48     if (top_tuple == nullptr) {
49       top_tuple = operand->mutable_operand(0);
50       if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
51         can_simplify = false;
52         break;
53       }
54     } else if (top_tuple != operand->operand(0)) {
55       can_simplify = false;
56       break;
57     }
58   }
59   if (can_simplify && top_tuple != nullptr) {
60     changed = true;
61     TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, top_tuple));
62   }
63   return changed;
64 }
65 
Run(HloModule * module)66 StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
67   // Initially add all GTE and Tuple instructions to the worklist.
68   bool changed = false;
69   for (auto* computation : module->computations()) {
70     if (exclude_entry_computation_ &&
71         computation == module->entry_computation()) {
72       continue;
73     }
74     for (auto* instruction : computation->MakeInstructionPostOrder()) {
75       if (instruction->opcode() == HloOpcode::kTuple) {
76         TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
77       } else {
78         auto ancestor = instruction->LatestNonGteAncestorAndIndex();
79         if (ancestor.first == instruction) {
80           continue;
81         }
82         // If possible replace a chain of GTE with the operation which produces
83         // the element. For example, replace uses of GTE with below with just
84         // 'Op' (assuming 'Op' is at the index of the GTE instruction):
85         //
86         //     ...  Op ...
87         //       \  |   /
88         //        Tuple
89         //          |
90         //         GTE
91         //         ...
92         //          |
93         //         GTE
94         //          |
95         //         GTE
96         //
97         // Note that this deletes the Tuple instruction altogether. In addition,
98         // if only a subset of tuple's elements are used, this transform
99         // optimizes them one at a time, and after the last use is optimized,
100         // the Tuple will also be deleted.
101         if (ShapeUtil::Compatible(ancestor.first->shape(),
102                                   instruction->shape())) {
103           changed = true;
104           TF_RETURN_IF_ERROR(
105               computation->ReplaceInstruction(instruction, ancestor.first));
106         } else if (ancestor.first->opcode() == HloOpcode::kTuple) {
107           changed = true;
108           TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
109               instruction,
110               ancestor.first->mutable_operand(ancestor.second[0])));
111         }
112       }
113     }
114   }
115   return changed;
116 }
117 
118 }  // namespace xla
119