• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""TensorFlow Ops for loss computation (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27from tensorflow.contrib.framework import deprecated
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops as array_ops_
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn
32from tensorflow.python.ops.losses import losses
33
34
@deprecated('2016-12-01', 'Use `tf.losses.mean_squared_error` '
            'and explicit logits computation.')
def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None):
  """Builds a linear regressor and its mean squared error loss.

  Args:
    tensor_in: Input feature `Tensor`, passed to `nn.xw_plus_b`.
    labels: Target `Tensor`.
    weights: Weight matrix of the linear layer.
    biases: Bias vector of the linear layer.
    name: Optional name for the enclosing name scope.

  Returns:
    A `(predictions, loss)` tuple. If `labels` has rank 1 and the linear
    output has rank 2, axis 1 is squeezed out of `predictions` before the
    loss is computed.
  """
  with ops.name_scope(name, 'mean_squared_error_regressor',
                      [tensor_in, labels]):
    outputs = nn.xw_plus_b(tensor_in, weights, biases)
    # Short-circuit deliberately: the rank of `outputs` is only inspected
    # when `labels` is a vector. Presumably axis 1 has size 1 here (a single
    # regression target) -- `squeeze` rejects a static non-1 dimension.
    if len(labels.get_shape()) == 1:
      if len(outputs.get_shape()) == 2:
        outputs = array_ops_.squeeze(outputs, axis=[1])
    loss = losses.mean_squared_error(labels, outputs)
    return outputs, loss
45
46
@deprecated('2016-12-01', 'Use `tf.losses.softmax_cross_entropy` '
            'and explicit logits computation.')
def softmax_classifier(tensor_in,
                       labels,
                       weights,
                       biases,
                       class_weight=None,
                       name=None):
  """Returns prediction and loss for softmax classifier.

  This function returns "probabilities" and a cross entropy loss. To obtain
  predictions, use `tf.argmax` on the returned probabilities.

  This function requires labels to be passed in one-hot encoding.

  Args:
    tensor_in: Input tensor, [batch_size, feature_size], features.
    labels: Tensor, [batch_size, n_classes], one-hot labels of the output
      classes.
    weights: Tensor, [feature_size, n_classes], linear transformation
      matrix; the logits are computed as `tensor_in @ weights + biases`.
    biases: Tensor, [n_classes], biases added to every row of the logits.
    class_weight: Tensor, optional, [n_classes], weight for each class.
      If not given, all classes are supposed to have weight one. Note that
      these weights multiply the logits themselves (they do not weight the
      per-example loss), so they change both the returned probabilities
      and the loss.
    name: Operation name.

  Returns:
    `tuple` `(probabilities, loss)`: the softmax of the (optionally
    class-weighted) logits, and the cross entropy loss produced by
    `losses.softmax_cross_entropy` for those same logits.
  """
  with ops.name_scope(name, 'softmax_classifier', [tensor_in, labels]):
    # Linear layer: [batch_size, feature_size] x [feature_size, n_classes].
    logits = nn.xw_plus_b(tensor_in, weights, biases)
    if class_weight is not None:
      # Per-class rescaling of the logits, broadcast across the batch axis.
      logits = math_ops.multiply(logits, class_weight)
    return nn.softmax(logits), losses.softmax_cross_entropy(labels, logits)
81