• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Libsvm decoder."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.contrib.libsvm.ops import gen_libsvm_ops
21from tensorflow.contrib.util import loader
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.platform import resource_loader
25
26
# Load the compiled LibSVM kernel library as a side effect of importing this
# module. The path is resolved via resource_loader from the data file name.
# NOTE(review): the returned handle is not referenced in the visible code; it
# is presumably kept so the library (and the op it registers) stays loaded.
_libsvm_so_path = resource_loader.get_path_to_datafile("_libsvm_ops.so")
_libsvm_ops_so = loader.load_op_library(_libsvm_so_path)
del _libsvm_so_path  # drop the temporary so the module namespace is unchanged
29
30
def decode_libsvm(content, num_features, dtype=None, label_dtype=None):
  """Decodes LibSVM-formatted records into sparse features and labels.

  This is a thin wrapper around the generated LibSVM decode op. The op yields
  the labels together with the indices / values / dense-shape triple of the
  feature matrix, and that triple is packaged here into a `SparseTensor`.

  Args:
    content: A `Tensor` of type `string`. Each string is one record/row in
      the LibSVM format.
    num_features: The number of features; forwarded unchanged to the op.
    dtype: Element type of the output feature tensor, forwarded unchanged to
      the op. Defaults to tf.float32.
    label_dtype: Element type of the output label tensor, forwarded unchanged
      to the op. Defaults to tf.int64.

  Returns:
    A `(features, labels)` pair:
      features: A `SparseTensor` of the shape `[input_shape, num_features]`.
      labels: A `Tensor` of the same shape as `content`.
  """
  decoded = gen_libsvm_ops.decode_libsvm(
      content, num_features, dtype=dtype, label_dtype=label_dtype)
  # The op's outputs arrive as (labels, indices, values, dense_shape).
  label_tensor, feature_indices, feature_values, dense_shape = decoded
  features = sparse_tensor.SparseTensor(
      feature_indices, feature_values, dense_shape)
  return features, label_tensor
48
49
# Declare the DecodeLibSVM op type as non-differentiable for TensorFlow's
# gradient machinery: its input is a string tensor of raw records, so there is
# no meaningful gradient to propagate back through the parse.
ops.NotDifferentiable("DecodeLibSVM")
51