• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific reduction Ops.
17 
18 #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
19 #include "tensorflow/compiler/tf2xla/type_util.h"
20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/core/framework/kernel_def_builder.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 class SumOp : public XlaReductionOp {
31  public:
SumOp(OpKernelConstruction * ctx)32   explicit SumOp(OpKernelConstruction* ctx)
33       : XlaReductionOp(ctx,
34                        XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
InitialValue(xla::XlaBuilder * builder)35   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
36     return xla::Zero(builder, xla_reduction_type_);
37   }
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)38   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
39                     const xla::XlaOp& scalar_rhs) override {
40     xla::Add(scalar_lhs, scalar_rhs);
41   }
42 };
43 
44 REGISTER_XLA_OP(Name("Sum").CompileTimeConstantInput("reduction_indices"),
45                 SumOp);
46 
47 class ProdOp : public XlaReductionOp {
48  public:
ProdOp(OpKernelConstruction * ctx)49   explicit ProdOp(OpKernelConstruction* ctx)
50       : XlaReductionOp(ctx,
51                        XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
52 
InitialValue(xla::XlaBuilder * builder)53   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
54     return xla::One(builder, xla_reduction_type_);
55   }
56 
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)57   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
58                     const xla::XlaOp& scalar_rhs) override {
59     xla::Mul(scalar_lhs, scalar_rhs);
60   }
61 };
62 
63 REGISTER_XLA_OP(Name("Prod").CompileTimeConstantInput("reduction_indices"),
64                 ProdOp);
65 
66 class MinOp : public XlaReductionOp {
67  public:
MinOp(OpKernelConstruction * ctx)68   explicit MinOp(OpKernelConstruction* ctx)
69       : XlaReductionOp(ctx, ctx->input_type(0)) {}
70 
InitialValue(xla::XlaBuilder * builder)71   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
72     return xla::MaxValue(builder, xla_reduction_type_);
73   }
74 
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)75   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
76                     const xla::XlaOp& scalar_rhs) override {
77     xla::Min(scalar_lhs, scalar_rhs);
78   }
79 };
80 
81 REGISTER_XLA_OP(Name("Min").CompileTimeConstantInput("reduction_indices"),
82                 MinOp);
83 
84 class MaxOp : public XlaReductionOp {
85  public:
MaxOp(OpKernelConstruction * ctx)86   explicit MaxOp(OpKernelConstruction* ctx)
87       : XlaReductionOp(ctx, ctx->input_type(0)) {}
88 
InitialValue(xla::XlaBuilder * builder)89   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
90     return xla::MinValue(builder, xla_reduction_type_);
91   }
92 
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)93   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
94                     const xla::XlaOp& scalar_rhs) override {
95     xla::Max(scalar_lhs, scalar_rhs);
96   }
97 };
98 
99 REGISTER_XLA_OP(Name("Max").CompileTimeConstantInput("reduction_indices"),
100                 MaxOp);
101 
102 class MeanOp : public XlaReductionOp {
103  public:
MeanOp(OpKernelConstruction * ctx)104   explicit MeanOp(OpKernelConstruction* ctx)
105       : XlaReductionOp(ctx,
106                        XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
107 
InitialValue(xla::XlaBuilder * builder)108   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
109     return xla::Zero(builder, xla_reduction_type_);
110   }
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)111   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
112                     const xla::XlaOp& scalar_rhs) override {
113     xla::Add(scalar_lhs, scalar_rhs);
114   }
115 
BuildFinalizer(xla::XlaBuilder *,const xla::XlaOp & input,const xla::XlaOp & reduce_output,const std::vector<int64> & dimensions_to_reduce)116   xla::XlaOp BuildFinalizer(
117       xla::XlaBuilder* /*builder*/, const xla::XlaOp& input,
118       const xla::XlaOp& reduce_output,
119       const std::vector<int64>& dimensions_to_reduce) override {
120     if (dimensions_to_reduce.empty()) {
121       return reduce_output;
122     }
123     auto divisor = xla::GetDimensionSize(input, dimensions_to_reduce[0]);
124     for (int i = 1; i < dimensions_to_reduce.size(); i++) {
125       auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]);
126       divisor = xla::Mul(divisor, size);
127     }
128     divisor = xla::ConvertElementType(divisor, xla_reduction_type_);
129     return XlaHelpers::ConvertElementType(reduce_output / divisor,
130                                           input_type(0));
131   }
132 };
133 
134 REGISTER_XLA_OP(Name("Mean").CompileTimeConstantInput("reduction_indices"),
135                 MeanOp);
136 
137 class AllOp : public XlaReductionOp {
138  public:
AllOp(OpKernelConstruction * ctx)139   explicit AllOp(OpKernelConstruction* ctx)
140       : XlaReductionOp(ctx, ctx->input_type(0)) {}
141 
InitialValue(xla::XlaBuilder * builder)142   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
143     return xla::ConstantR0<bool>(builder, true);
144   }
145 
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)146   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
147                     const xla::XlaOp& scalar_rhs) override {
148     xla::And(scalar_lhs, scalar_rhs);
149   }
150 };
151 
152 REGISTER_XLA_OP(Name("All").CompileTimeConstantInput("reduction_indices"),
153                 AllOp);
154 
155 class AnyOp : public XlaReductionOp {
156  public:
AnyOp(OpKernelConstruction * ctx)157   explicit AnyOp(OpKernelConstruction* ctx)
158       : XlaReductionOp(ctx, ctx->input_type(0)) {}
159 
InitialValue(xla::XlaBuilder * builder)160   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
161     return xla::ConstantR0<bool>(builder, false);
162   }
163 
BuildReducer(xla::XlaBuilder * builder,const xla::XlaOp & scalar_lhs,const xla::XlaOp & scalar_rhs)164   void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
165                     const xla::XlaOp& scalar_rhs) override {
166     xla::Or(scalar_lhs, scalar_rhs);
167   }
168 };
169 
170 REGISTER_XLA_OP(Name("Any").CompileTimeConstantInput("reduction_indices"),
171                 AnyOp);
172 
173 }  // namespace
174 }  // namespace tensorflow
175