1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
17
18 #include "tensorflow/compiler/xla/client/lib/prng.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/shape.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31
32 namespace xla {
33 namespace {
34
GetPhiloxStateOp(XlaOp input_state,const Shape & state_shape)35 XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) {
36 if (state_shape.dimensions(0) >= 3) {
37 return Slice(input_state, {1}, {3}, {1});
38 }
39 return Rev(input_state, {0});
40 }
41
GetPhiloxOutputStateOp(XlaOp output_state,const Shape & state_shape)42 XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) {
43 if (state_shape.dimensions(0) < 3) {
44 output_state = Slice(output_state, {0}, {1}, {1});
45 }
46 return output_state;
47 }
48
49 } // namespace
50
InstructionMatchesPattern(HloInstruction * instruction)51 bool RngBitGeneratorExpander::InstructionMatchesPattern(
52 HloInstruction* instruction) {
53 return instruction->opcode() == HloOpcode::kRngBitGenerator;
54 }
55
GetGeneratorComputation(const Shape & data_shape,const Shape & state_shape,RandomAlgorithm algorithm,HloModule * module)56 StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
57 const Shape& data_shape, const Shape& state_shape,
58 RandomAlgorithm algorithm, HloModule* module) {
59 RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module};
60 auto it = computation_cache_.find(cache_key);
61 if (it != computation_cache_.end()) {
62 return it->second;
63 }
64
65 XlaBuilder builder("rng");
66 XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
67 XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
68 RngOutput output;
69 switch (algorithm) {
70 case RandomAlgorithm::RNG_THREE_FRY:
71 output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}),
72 data_shape);
73 break;
74 case RandomAlgorithm::RNG_PHILOX:
75 output = PhiloxBitGenerator(
76 key_op, GetPhiloxStateOp(state_param, state_shape), data_shape);
77 output.state = GetPhiloxOutputStateOp(output.state, state_shape);
78 break;
79 default:
80 return Unimplemented("Unsupported random algorthm: %s",
81 RandomAlgorithm_Name(algorithm));
82 }
83
84 XlaOp final_state =
85 ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
86 Tuple(&builder, {final_state, output.value});
87 TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
88
89 TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
90 xla_computation.GetProgramShape());
91 HloModuleConfig config(program_shape);
92 TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
93 xla_computation.proto(), config));
94 HloCloneContext context(module);
95 HloComputation* new_computation =
96 module->DeepCloneComputation(new_module->entry_computation(), &context);
97 computation_cache_.emplace(cache_key, new_computation);
98 return new_computation;
99 }
100
ExpandInstruction(HloInstruction * hlo)101 StatusOr<HloInstruction*> RngBitGeneratorExpander::ExpandInstruction(
102 HloInstruction* hlo) {
103 HloRngBitGeneratorInstruction* rng = Cast<HloRngBitGeneratorInstruction>(hlo);
104 RandomAlgorithm algorithm = rng->algorithm();
105 if (algorithm == RandomAlgorithm::RNG_DEFAULT) {
106 algorithm = default_algorithm_;
107 }
108
109 HloModule* module = hlo->parent()->parent();
110 const Shape& data_shape = rng->shape().tuple_shapes(1);
111 const Shape& state_shape = rng->operand(0)->shape();
112 TF_ASSIGN_OR_RETURN(
113 HloComputation * generator_computation,
114 GetGeneratorComputation(data_shape, state_shape, algorithm, module));
115 return hlo->parent()->AddInstruction(HloInstruction::CreateCall(
116 ShapeUtil::MakeTupleShape({state_shape, data_shape}),
117 {hlo->mutable_operand(0)}, generator_computation));
118 }
119
120 } // namespace xla
121