• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/client/xla_computation.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 
25 namespace xla {
26 
27 // Builds and returns a scalar less-than computation. The computation takes
28 // 2 * 'operand_types.size()' parameters; parameters 2 * i and 2 * i + 1 are
29 // both scalars of primitive type 'operand_types[i]'. The computation compares
30 // the first two parameters (indices 0 and 1). For floating point types, the
31 // comparison establishes a total order where
32 //   -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN
33 XlaComputation CreateScalarLtComputation(
34     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
35 
36 // Builds and returns a scalar greater-than computation. The computation takes
37 // 2 * 'operand_types.size()' parameters; parameters 2 * i and 2 * i + 1 are
38 // both scalars of primitive type 'operand_types[i]'. The computation compares
39 // the first two parameters (indices 0 and 1). For floating point types, the
40 // comparison establishes a total order where
41 //   NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN
42 XlaComputation CreateScalarGtComputation(
43     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder);
44 
45 // Builds and returns a scalar comparison computation. 'comparators' holds one
46 // optional comparator function per position; operands at positions whose entry
47 // is not nullopt are compared using the comparator given at that position.
// NOTE(review): this declaration uses std::string, std::optional, absl::Span
// and int64_t, but the includes visible in this header do not name <string>,
// <optional>, <cstdint> or "absl/types/span.h" directly. Presumably they are
// pulled in transitively (e.g. via xla_builder.h) -- confirm, or include them.
48 XlaComputation CreateScalarComparisonComputation(
49     const std::string& name, const std::vector<PrimitiveType>& operand_types,
50     const std::vector<
51         std::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64_t>)>>&
52         comparators,
53     XlaBuilder* builder);
54 
55 }  // namespace xla
56 
57 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_
58