• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for solvers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import nn_ops
27
28
def create_operator(matrix):
  """Wraps a dense rank-2 tensor as a minimal linear-operator record.

  Args:
    matrix: The rank-2 tensor backing the operator. SparseTensor inputs are
      not handled yet.

  Returns:
    A `LinearOperator` namedtuple with fields:
      shape: `matrix`'s static shape as a Python list when it is fully
        defined, otherwise the dynamic shape tensor `array_ops.shape(matrix)`.
      dtype: `matrix.dtype`.
      apply: callable mapping `v` to `math_ops.matmul(matrix, v)`.
      apply_adjoint: callable mapping `v` to `math_ops.matmul(matrix, v)`
        with `adjoint_a=True`, i.e. the conjugate transpose of `matrix`
        applied to `v`.
  """
  # TODO(rmlarsen): Handle SparseTensor.
  static_shape = matrix.get_shape()
  shape = (static_shape.as_list() if static_shape.is_fully_defined()
           else array_ops.shape(matrix))

  def _matvec(v):
    return math_ops.matmul(matrix, v, adjoint_a=False)

  def _adjoint_matvec(v):
    return math_ops.matmul(matrix, v, adjoint_a=True)

  # NOTE(review): the namedtuple class is rebuilt on every call, so operators
  # returned by different calls are instances of distinct (same-named) types.
  operator_cls = collections.namedtuple(
      "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])
  return operator_cls(shape, matrix.dtype, _matvec, _adjoint_matvec)
46
47
def identity_operator(matrix):
  """Builds an identity linear operator that reports `matrix`'s shape/dtype.

  The returned callables never touch `matrix`; both `apply` and
  `apply_adjoint` hand back their argument unchanged.

  NOTE(review): nothing here checks that `matrix` is square or actually an
  identity matrix; only its shape and dtype are used.

  Args:
    matrix: A rank-2 tensor whose shape and dtype the operator reports.

  Returns:
    A `LinearOperator` namedtuple (shape, dtype, apply, apply_adjoint) where
    `shape` is a Python list if `matrix`'s static shape is fully defined and
    the dynamic shape tensor `array_ops.shape(matrix)` otherwise.
  """
  static_shape = matrix.get_shape()
  if not static_shape.is_fully_defined():
    shape = array_ops.shape(matrix)
  else:
    shape = static_shape.as_list()

  operator_cls = collections.namedtuple(
      "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])
  return operator_cls(shape, matrix.dtype,
                      lambda vec: vec,   # I v == v
                      lambda vec: vec)   # I^H v == v
63
64
65# TODO(rmlarsen): Measure if we should just call matmul.
def dot(x, y):
  """Inner product of `x` and `y`, conjugating the first argument.

  Computes sum(conj(x) * y) reduced over every element (no axis is passed to
  `reduce_sum`), so for complex inputs this is the Hermitian inner product
  <x, y>; for real inputs `conj` leaves `x` unchanged.
  """
  conj_x = math_ops.conj(x)
  return math_ops.reduce_sum(conj_x * y)
68
69
70# TODO(rmlarsen): Implement matrix/vector norm op in C++ in core.
71# We need 1-norm, inf-norm, and Frobenius norm.
def l2norm_squared(v):
  """Returns the squared Euclidean norm of `v`.

  `nn_ops.l2_loss` yields half the sum of squares, so it is scaled by a
  constant 2 cast to `v`'s base dtype.

  NOTE(review): unlike `dot`, no conjugation is applied here; confirm that
  `l2_loss` supports every dtype callers pass (e.g. complex).
  """
  two = constant_op.constant(2, dtype=v.dtype.base_dtype)
  return two * nn_ops.l2_loss(v)
74
75
def l2norm(v):
  """Returns the Euclidean (L2) norm of `v`: sqrt of `l2norm_squared(v)`."""
  squared_norm = l2norm_squared(v)
  return math_ops.sqrt(squared_norm)
78
79
def l2normalize(v):
  """Scales `v` to unit L2 norm.

  NOTE(review): there is no guard for a zero norm; presumably an all-zero `v`
  yields non-finite values — confirm callers never pass one.

  Returns:
    A tuple `(v / norm, norm)` where `norm` is `l2norm(v)`.
  """
  v_norm = l2norm(v)
  unit_v = v / v_norm
  return unit_v, v_norm
83