1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/platform/types.h" 21 22 extern "C" { 23 24 // Each entry in 'values' represents a 3-dimensional shape with dimensions 25 // [a, b, c]. The 'b' dimension of each shape is sorted into ascending order 26 // according to the results of comparisons using the provided 'less_than' 27 // function. 'values_count' must be > 0 and specifies the number of entries in 28 // 'values' and 'values_primitive_type_size_in_bytes'. The size of the primitive 29 // type of the i-th shape has exactly 'values_primitive_type_size_in_bytes[i]' 30 // bytes. 'is_stable' specifies whether the sorting should be stable. 31 // 'run_options' and 'prof_counters' are passed through to the less-than 32 // function, which expects the following arguments: 33 // - pointer to the return value buffer (char*) 34 // - xla::ExecutableRunOptions = 'run_options' (char*) 35 // - pointers to the parameter buffers (char**) 36 // - pointers to the buffer tables = nullptr for thread local functions (char**) 37 // - profile counters = 'prof_counters' (int64*) 38 extern void __xla_cpu_runtime_KeyValueSort( 39 tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, 40 char** values, tensorflow::int32 values_count, 41 tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, 42 char* run_options, tensorflow::int64* prof_counters, 43 void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)); 44 } 45 46 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ 47