# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for querying registered kernels."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import kernel_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat


def _kernel_list_from_buffer(buf):
  """Parses the contents of a C API buffer into a `KernelList` proto.

  Args:
    buf: A buffer handle returned by one of the `c_api` kernel-registry
      queries. Its contents are read via `c_api.TF_GetBuffer`.

  Returns:
    A `kernel_def_pb2.KernelList` populated from the serialized bytes held in
    `buf`.
  """
  # NOTE(review): ownership/lifetime of `buf` is decided by the c_api binding;
  # confirm whether it needs to be released explicitly after reading.
  data = c_api.TF_GetBuffer(buf)
  kernel_list = kernel_def_pb2.KernelList()
  # `compat.as_bytes` normalizes the buffer contents to `bytes`, which is what
  # `ParseFromString` expects.
  kernel_list.ParseFromString(compat.as_bytes(data))
  return kernel_list


def get_all_registered_kernels():
  """Returns a KernelList proto of all registered kernels.

  Returns:
    A `kernel_def_pb2.KernelList` with one entry per kernel registered in the
    current process.
  """
  return _kernel_list_from_buffer(c_api.TF_GetAllRegisteredKernels())


def get_registered_kernels_for_op(name):
  """Returns a KernelList proto of registered kernels for a given op.

  Args:
    name: A string representing the name of the op whose kernels to retrieve.

  Returns:
    A `kernel_def_pb2.KernelList` containing the kernels registered for the op
    `name`.
  """
  return _kernel_list_from_buffer(c_api.TF_GetRegisteredKernelsForOp(name))