1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ 18 19 #include <vector> 20 21 #include "tensorflow/compiler/xla/client/xla_builder.h" 22 #include "tensorflow/compiler/xla/client/xla_computation.h" 23 #include "tensorflow/compiler/xla/xla_data.pb.h" 24 25 namespace xla { 26 27 // Creates a scalar less-than computation and returns it. The created 28 // computation has 2 * 'operand_types.size()' many parameters, where parameters 29 // 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The 30 // computation compares the first two parameters. For floating point types, a 31 // total order is created where 32 // -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN 33 XlaComputation CreateScalarLtComputation( 34 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder); 35 36 // Creates a scalar greater-than computation and returns it. The created 37 // computation has 2 * 'operand_types.size()' many parameters, where parameters 38 // 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The 39 // computation compares the first two parameters. For floating point types, a 40 // total order is created where 41 // NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN 42 XlaComputation CreateScalarGtComputation( 43 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder); 44 45 // Creates a scalar comparison computation and returns it. This function takes 46 // a vector of comparator functions to compare the operands where the function 47 // isn't nullopt with the specified comparator at that location. 48 XlaComputation CreateScalarComparisonComputation( 49 const std::string& name, const std::vector<PrimitiveType>& operand_types, 50 const std::vector< 51 std::optional<XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64_t>)>>& 52 comparators, 53 XlaBuilder* builder); 54 55 } // namespace xla 56 57 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_COMPARATORS_H_ 58