# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for solvers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops


# Record type returned by every operator factory in this module. It is built
# once at import time so that all operators share a single class instead of
# each factory call creating a fresh, unrelated namedtuple type.
_LinearOperator = collections.namedtuple(
    "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])


def _operator_shape(matrix):
  """Returns the shape of `matrix`, preferring the statically known one.

  Args:
    matrix: An object exposing `get_shape()`, as tensors do.

  Returns:
    A Python list of dimensions when the static shape is fully defined,
    otherwise the result of `array_ops.shape(matrix)`.
  """
  static_shape = matrix.get_shape()
  if static_shape.is_fully_defined():
    return static_shape.as_list()
  # Some dimension is unknown at graph-construction time; defer to a
  # shape op that is evaluated at run time.
  return array_ops.shape(matrix)


def create_operator(matrix):
  """Creates a linear operator from a rank-2 tensor.

  Args:
    matrix: The tensor to wrap.

  Returns:
    A `LinearOperator` namedtuple with fields:
      shape: The shape of `matrix` (see `_operator_shape`).
      dtype: `matrix.dtype`.
      apply: Callable `v -> matmul(matrix, v, adjoint_a=False)`.
      apply_adjoint: Callable `v -> matmul(matrix, v, adjoint_a=True)`.
  """
  # TODO(rmlarsen): Handle SparseTensor.
  return _LinearOperator(
      shape=_operator_shape(matrix),
      dtype=matrix.dtype,
      apply=lambda v: math_ops.matmul(matrix, v, adjoint_a=False),
      apply_adjoint=lambda v: math_ops.matmul(matrix, v, adjoint_a=True))


def identity_operator(matrix):
  """Creates a linear operator from a rank-2 identity tensor.

  Only the shape and dtype of `matrix` are read; its values are never used.

  Args:
    matrix: The tensor whose shape and dtype the operator reports.

  Returns:
    A `LinearOperator` namedtuple whose `apply` and `apply_adjoint` both
    return their argument unchanged.
  """
  return _LinearOperator(
      shape=_operator_shape(matrix),
      dtype=matrix.dtype,
      apply=lambda v: v,
      apply_adjoint=lambda v: v)


# TODO(rmlarsen): Measure if we should just call matmul.
def dot(x, y):
  """Returns the sum over all entries of `conj(x) * y`.

  Args:
    x: Tensor; its conjugate is taken before the elementwise product.
    y: Tensor multiplied elementwise with `conj(x)`.

  Returns:
    The result of `reduce_sum` over the elementwise product.
  """
  return math_ops.reduce_sum(math_ops.conj(x) * y)


# TODO(rmlarsen): Implement matrix/vector norm op in C++ in core.
# We need 1-norm, inf-norm, and Frobenius norm.
def l2norm_squared(v):
  """Returns twice `nn_ops.l2_loss(v)`.

  `l2_loss` is documented as `sum(v ** 2) / 2`, so doubling it is intended to
  give the squared 2-norm of `v`.

  Args:
    v: Tensor to take the squared norm of.

  Returns:
    A tensor holding `2 * l2_loss(v)`.
  """
  # The constant is created with v's base dtype so the product does not mix
  # dtypes.
  two = constant_op.constant(2, dtype=v.dtype.base_dtype)
  return two * nn_ops.l2_loss(v)


def l2norm(v):
  """Returns the square root of `l2norm_squared(v)`."""
  return math_ops.sqrt(l2norm_squared(v))


def l2normalize(v):
  """Divides `v` by its 2-norm.

  No guard against a zero norm is applied here.

  Args:
    v: Tensor to normalize.

  Returns:
    A tuple `(v / norm, norm)` where `norm` is `l2norm(v)`.
  """
  norm = l2norm(v)
  return v / norm, norm