1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific reduction Ops. 17 18 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" 19 #include "tensorflow/compiler/tf2xla/type_util.h" 20 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/lib/constants.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/core/framework/kernel_def_builder.h" 26 27 namespace tensorflow { 28 namespace { 29 30 class SumOp : public XlaReductionOp { 31 public: SumOp(OpKernelConstruction * ctx)32 explicit SumOp(OpKernelConstruction* ctx) 33 : XlaReductionOp(ctx, 34 XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} InitialValue(xla::XlaBuilder * builder)35 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 36 return xla::Zero(builder, xla_reduction_type_); 37 } BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)38 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 39 const xla::XlaOp& scalar_rhs) override { 40 xla::Add(scalar_lhs, scalar_rhs); 41 } 42 }; 43 44 REGISTER_XLA_OP(Name("Sum").CompileTimeConstantInput("reduction_indices"), 45 SumOp); 46 47 class ProdOp : public XlaReductionOp { 48 public: ProdOp(OpKernelConstruction * ctx)49 explicit ProdOp(OpKernelConstruction* ctx) 50 : XlaReductionOp(ctx, 51 XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} 52 InitialValue(xla::XlaBuilder * builder)53 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 54 return xla::One(builder, xla_reduction_type_); 55 } 56 BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)57 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 58 const xla::XlaOp& scalar_rhs) override { 59 xla::Mul(scalar_lhs, scalar_rhs); 60 } 61 }; 62 63 REGISTER_XLA_OP(Name("Prod").CompileTimeConstantInput("reduction_indices"), 64 ProdOp); 65 66 class MinOp : public XlaReductionOp { 67 public: MinOp(OpKernelConstruction * ctx)68 explicit MinOp(OpKernelConstruction* ctx) 69 : XlaReductionOp(ctx, ctx->input_type(0)) {} 70 InitialValue(xla::XlaBuilder * builder)71 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 72 return xla::MaxValue(builder, xla_reduction_type_); 73 } 74 BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)75 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 76 const xla::XlaOp& scalar_rhs) override { 77 xla::Min(scalar_lhs, scalar_rhs); 78 } 79 }; 80 81 REGISTER_XLA_OP(Name("Min").CompileTimeConstantInput("reduction_indices"), 82 MinOp); 83 84 class MaxOp : public XlaReductionOp { 85 public: MaxOp(OpKernelConstruction * ctx)86 explicit MaxOp(OpKernelConstruction* ctx) 87 : XlaReductionOp(ctx, ctx->input_type(0)) {} 88 InitialValue(xla::XlaBuilder * builder)89 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 90 return xla::MinValue(builder, xla_reduction_type_); 91 } 92 BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)93 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 94 const xla::XlaOp& scalar_rhs) override { 95 xla::Max(scalar_lhs, scalar_rhs); 96 } 97 }; 98 99 REGISTER_XLA_OP(Name("Max").CompileTimeConstantInput("reduction_indices"), 100 MaxOp); 101 102 class MeanOp : public XlaReductionOp { 103 public: MeanOp(OpKernelConstruction * ctx)104 explicit MeanOp(OpKernelConstruction* ctx) 105 : XlaReductionOp(ctx, 106 XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} 107 InitialValue(xla::XlaBuilder * builder)108 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 109 return xla::Zero(builder, xla_reduction_type_); 110 } BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)111 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 112 const xla::XlaOp& scalar_rhs) override { 113 xla::Add(scalar_lhs, scalar_rhs); 114 } 115 BuildFinalizer(xla::XlaBuilder *,const xla::XlaOp & input,const xla::XlaOp & reduce_output,const std::vector<int64> & dimensions_to_reduce)116 xla::XlaOp BuildFinalizer( 117 xla::XlaBuilder* /*builder*/, const xla::XlaOp& input, 118 const xla::XlaOp& reduce_output, 119 const std::vector<int64>& dimensions_to_reduce) override { 120 if (dimensions_to_reduce.empty()) { 121 return reduce_output; 122 } 123 auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]); 124 for (int i = 1; i < dimensions_to_reduce.size(); i++) { 125 auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]); 126 divisor = xla::Mul(divisor, size); 127 } 128 divisor = xla::ConvertElementType(divisor, xla_reduction_type_); 129 return XlaHelpers::ConvertElementType(reduce_output / divisor, 130 input_type(0)); 131 } 132 }; 133 134 REGISTER_XLA_OP(Name("Mean").CompileTimeConstantInput("reduction_indices"), 135 MeanOp); 136 137 class AllOp : public XlaReductionOp { 138 public: AllOp(OpKernelConstruction * ctx)139 explicit AllOp(OpKernelConstruction* ctx) 140 : XlaReductionOp(ctx, ctx->input_type(0)) {} 141 InitialValue(xla::XlaBuilder * builder)142 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 143 return xla::ConstantR0<bool>(builder, true); 144 } 145 BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)146 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 147 const xla::XlaOp& scalar_rhs) override { 148 xla::And(scalar_lhs, scalar_rhs); 149 } 150 }; 151 152 REGISTER_XLA_OP(Name("All").CompileTimeConstantInput("reduction_indices"), 153 AllOp); 154 155 class AnyOp : public XlaReductionOp { 156 public: AnyOp(OpKernelConstruction * ctx)157 explicit AnyOp(OpKernelConstruction* ctx) 158 : XlaReductionOp(ctx, ctx->input_type(0)) {} 159 InitialValue(xla::XlaBuilder * builder)160 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 161 return xla::ConstantR0<bool>(builder, false); 162 } 163 BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)164 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, 165 const xla::XlaOp& scalar_rhs) override { 166 xla::Or(scalar_lhs, scalar_rhs); 167 } 168 }; 169 170 REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"), 171 AnyOp); 172 173 } // namespace 174 } // namespace tensorflow 175