/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"

#include <complex>

#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/rng_alg.h"

namespace mlir {
namespace TF {

namespace {
// Returns a scalar (rank-0) DenseElementsAttr with the given int, float, or
// complex element type `ty`, holding the given integer value.
static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    FloatAttr attr = FloatAttr::get(float_ty, raw_value);
    return DenseElementsAttr::get(scalar_ty, attr);
  }

  if (auto int_ty = ty.dyn_cast<IntegerType>()) {
    IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
    return DenseElementsAttr::get(scalar_ty, attr);
  }

  if (auto complex_ty = ty.dyn_cast<ComplexType>()) {
    Type complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<float>>(raw_value));
    } else if (complex_element_ty.isF64()) {
      return DenseElementsAttr::get(
          scalar_ty, static_cast<std::complex<double>>(raw_value));
    }
  }
  llvm_unreachable("unsupported type");
}
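
// For example, GetScalarOfType(rewriter.getI64Type(), 256) produces the
// attribute dense<256> : tensor<i64>, as used for the step size below.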

// Returns the subtype of `resource` if one is present; otherwise returns an
// unranked tensor type with element type `element_type`.
static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
  auto resource_type = resource.getType()
                           .cast<TensorType>()
                           .getElementType()
                           .cast<ResourceType>();
  if (resource_type.getSubtypes().size() == 1)
    return resource_type.getSubtypes().front();

  return UnrankedTensorType::get(element_type);
}
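
// For instance, a resource value typed tensor<!tf.resource<tensor<2xi64>>>
// yields tensor<2xi64>, while a subtype-less tensor<!tf.resource> yields
// tensor<*xi64> for an i64 element_type.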

static bool HasResourceSubtype(Value resource) {
  return resource.getType()
             .cast<TensorType>()
             .getElementType()
             .cast<ResourceType>()
             .getSubtypes()
             .size() == 1;
}

static Type GetResourceSubtype(Value resource) {
  return resource.getType()
      .cast<TensorType>()
      .getElementType()
      .cast<ResourceType>()
      .getSubtypes()
      .front();
}

// Decomposes tf.RngReadAndSkip.
//
// For Philox, the resource variable holds a tensor<3xi64> with the state:
//   [counter_lo, counter_hi, key]
//
//   RngReadAndSkip increments the 128-bit counter value by 256 * delta and
//   returns the original state value.
//
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
//   [counter, key]
//
//   RngReadAndSkip increments the 64-bit counter value by 256 * delta and
//   returns a tensor<3xi64> value [counter, key, 0].
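//
// As a rough sketch (not verbatim output; types and attributes elided), the
// Philox case turns
//   %result = "tf.RngReadAndSkip"(%resource, %alg, %delta)
// into a sequence equivalent to:
//   %state = "tf.ReadVariableOp"(%resource)  // [counter_lo, counter_hi, key]
//   ... unpack the words, add 256 * delta with carry, repack ...
//   "tf.AssignVariableOp"(%resource, %new_state)
//   %result = "tf.Pad"(%state, %zero_paddings)  // the original state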
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
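  // Registers the pattern with benefit 1; the trailing list declares the ops
  // this pattern may generate, so drivers that track generated ops can
  // account for them.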
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(), 1, context,
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       }) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      return rewriter.notifyMatchFailure(
          op, "unable to determine algorithm statically");
    }

    if (alg_constant.getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
    }

    uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported alg");
    }

    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      return rewriter.notifyMatchFailure(op, "unexpected op type");
    }

    if (!HasResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "missing resource subtype");
    }

    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());

    // Extract the key and counter from the state.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter.
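    // Each word is widened to unsigned 64-bit so wraparound is well-defined;
    // when the addition overflows (new_word_u64 < word_u64), a carry of one
    // is propagated into the next, more significant word.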
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource.
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original state as necessary to fill the output shape.
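    // For Philox pad == 0 and the state passes through unchanged; for
    // Threefry pad == 1, yielding [counter, key, 0].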
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    rewriter.replaceOp(op, output);
    return success();
  }
};

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
}  // namespace

void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  patterns->insert<DecomposeRngReadAndSkipOp>(context);
  populateWithGenerated(*patterns);
}

}  // namespace TF
}  // namespace mlir