1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Logging and Summary Operations.""" 16# pylint: disable=protected-access 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections as py_collections 22import os 23import pprint 24import random 25import sys 26 27from absl import logging 28import six 29 30from tensorflow.python import pywrap_tfe 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import sparse_tensor 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.ops import gen_logging_ops 36from tensorflow.python.ops import string_ops 37# go/tf-wildcard-import 38# pylint: disable=wildcard-import 39from tensorflow.python.ops.gen_logging_ops import * 40# pylint: enable=wildcard-import 41from tensorflow.python.platform import tf_logging 42from tensorflow.python.util import dispatch 43from tensorflow.python.util import nest 44from tensorflow.python.util.deprecation import deprecated 45from tensorflow.python.util.tf_export import tf_export 46 47# Register printing to the cell output if we are in a Colab or Jupyter Notebook. 
try:
  get_ipython()  # Exists in an ipython env like Jupyter or Colab
  pywrap_tfe.TFE_Py_EnableInteractivePythonLogging()
except NameError:
  # `get_ipython` is only defined when running under IPython; outside of it
  # there is nothing to register.
  pass

# The python wrapper for Assert is in control_flow_ops, as the Assert
# call relies on certain conditionals for its dependencies.  Use
# control_flow_ops.Assert.

# Assert and Print are special symbols in Python 2, so we must
# have an upper-case version of them. When support for it is dropped,
# we can allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053


# pylint: disable=invalid-name
@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
            "tf.print returns a no-output operator that directly "
            "prints the output. Outside of defuns or eager mode, "
            "this operator will not be executed unless it is "
            "directly specified in session.run or used as a "
            "control dependency for other operators. This is "
            "only a concern in graph mode. Below is an example "
            "of how to ensure tf.print executes in graph mode:\n")
@tf_export(v1=["Print"])
@dispatch.add_dispatch_support
def Print(input_, data, message=None, first_n=None, summarize=None, name=None):
  """Prints a list of tensors.

  This is an identity op (behaves like `tf.identity`) with the side effect
  of printing `data` when evaluating.

  Note: This op prints to the standard error. It is not currently compatible
    with jupyter notebook (printing to the notebook *server's* output, not into
    the notebook).

  Args:
    input_: A tensor passed through this op.
    data: A list of tensors to print out when op is evaluated.
    message: A string, prefix of the printed message.
    first_n: Only log `first_n` number of times. Negative numbers log always;
      this is the default.
    summarize: Only print this many entries of each tensor. If None, then a
      maximum of 3 elements are printed per input tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and contents as `input_`.

    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor = tf.range(10)
        print_op = tf.print(tensor)
        with tf.control_dependencies([print_op]):
          out = tf.add(tensor, tensor)
        sess.run(out)
    ```
  """
  return gen_logging_ops._print(input_, data, message, first_n, summarize, name)


# pylint: enable=invalid-name


def _generate_placeholder_string(x, default_placeholder="{}"):
  """Generate and return a string that does not appear in `x`.

  Random digits are appended to `default_placeholder` until the result is no
  longer a substring of `x`. The RNG uses a fixed seed, so the same `x` always
  yields the same placeholder.

  Args:
    x: The string the placeholder must not occur in.
    default_placeholder: The initial candidate placeholder.

  Returns:
    A string that is `default_placeholder`, possibly followed by digits, and
    that does not occur in `x`.
  """
  placeholder = default_placeholder
  rng = random.Random(5)
  while placeholder in x:
    placeholder = placeholder + str(rng.randint(0, 9))
  return placeholder


def _is_filepath(output_stream):
  """Returns True if output_stream is a file path."""
  return isinstance(output_stream, str) and output_stream.startswith("file://")


# Temporarily disable pylint g-doc-args error to allow giving more context
# about what the kwargs are.
# Because we are using arbitrary-length positional arguments, python 2
# does not support explicitly specifying the keyword arguments in the
# function definition.
# pylint: disable=g-doc-args
@tf_export("print")
@dispatch.add_dispatch_support
def print_v2(*inputs, **kwargs):
  """Print the specified inputs.

  A TensorFlow operator that prints the specified inputs to a desired
  output stream or logging level. The inputs may be dense or sparse Tensors,
  primitive python objects, data structures that contain tensors, and printable
  Python objects. Printed tensors will recursively show the first and last
  elements of each dimension to summarize.

  Example:
    Single-input usage:

    ```python
    tensor = tf.range(10)
    tf.print(tensor, output_stream=sys.stderr)
    ```

    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)

    Multi-input usage:

    ```python
    tensor = tf.range(10)
    tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
    ```

    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
    sys.stdout)

    Changing the input separator:
    ```python
    tensor_a = tf.range(2)
    tensor_b = tensor_a * 2
    tf.print(tensor_a, tensor_b, output_stream=sys.stderr, sep=',')
    ```

    (This prints "[0 1],[0 2]" to sys.stderr)

    Usage in a `tf.function`:

    ```python
    @tf.function
    def f():
        tensor = tf.range(10)
        tf.print(tensor, output_stream=sys.stderr)
        return tensor

    range_tensor = f()
    ```

    (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)

  @compatibility(TF 1.x Graphs and Sessions)
  In graphs manually created outside of `tf.function`, this method returns
  the created TF operator that prints the data. To make sure the
  operator runs, users need to pass the produced op to
  `tf.compat.v1.Session`'s run method, or to use the op as a control
  dependency for executed ops by specifying
  `with tf.compat.v1.control_dependencies([print_op])`.
  @end_compatibility

    Compatibility usage in TF 1.x graphs:

    ```python
    sess = tf.compat.v1.Session()
    with sess.as_default():
        tensor = tf.range(10)
        print_op = tf.print("tensors:", tensor, {2: tensor * 2},
                            output_stream=sys.stdout)
        with tf.control_dependencies([print_op]):
          tripled_tensor = tensor * 3
        sess.run(tripled_tensor)
    ```

    (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
    sys.stdout)

  Note: In Jupyter notebooks and colabs, `tf.print` prints to the notebook
    cell outputs. It will not write to the notebook kernel's console logs.

  Args:
    *inputs: Positional arguments that are the inputs to print. Inputs in the
      printed output will be separated by `sep` (spaces by default). Inputs may
      be python primitives, tensors, data structures such as dicts and lists
      that may contain tensors (with the data structures possibly nested in
      arbitrary ways), and printable python objects.
    output_stream: The output stream, logging level, or file to print to.
      Defaults to sys.stderr, but sys.stdout, tf.compat.v1.logging.info,
      tf.compat.v1.logging.warning, tf.compat.v1.logging.error,
      absl.logging.info, absl.logging.warning and absl.logging.error are also
      supported, as are the corresponding logging level constants (e.g.
      `tf.compat.v1.logging.INFO`, `absl.logging.INFO`). To print to a file,
      pass a string started with "file://" followed by the file path, e.g.,
      "file:///tmp/foo.out".
    summarize: The first and last `summarize` elements within each dimension are
      recursively printed per Tensor. If None, then the first 3 and last 3
      elements of each dimension are printed for each tensor. If set to -1, it
      will print all elements of every tensor.
    sep: The string to use to separate the inputs. Defaults to " ".
    end: End character that is appended at the end of the printed string.
      Defaults to `os.linesep`.
    name: A name for the operation (optional).

  Returns:
    None when executing eagerly. During graph tracing this returns
    a TF operator that prints the specified inputs in the specified output
    stream or logging level. This operator will be automatically executed
    except inside of `tf.compat.v1` graphs and sessions.

  Raises:
    ValueError: If an unsupported output stream is specified, or if an
      unrecognized keyword argument is passed.
  """
  # Because we are using arbitrary-length positional arguments, python 2
  # does not support explicitly specifying the keyword arguments in the
  # function definition. So, we manually get the keyword arguments w/ default
  # values here.
  output_stream = kwargs.pop("output_stream", sys.stderr)
  name = kwargs.pop("name", None)
  summarize = kwargs.pop("summarize", 3)
  sep = kwargs.pop("sep", " ")
  end = kwargs.pop("end", os.linesep)
  if kwargs:
    raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
  format_name = None
  if name:
    format_name = name + "_format"

  # Match the C++ string constants representing the different output streams.
  # Keep this updated!
  output_stream_to_constant = {
      sys.stdout: "stdout",
      sys.stderr: "stderr",
      tf_logging.INFO: "log(info)",
      tf_logging.info: "log(info)",
      tf_logging.WARN: "log(warning)",
      tf_logging.warning: "log(warning)",
      tf_logging.warn: "log(warning)",
      tf_logging.ERROR: "log(error)",
      tf_logging.error: "log(error)",
      logging.INFO: "log(info)",
      logging.info: "log(info)",
      logging.WARNING: "log(warning)",
      logging.WARN: "log(warning)",
      logging.warning: "log(warning)",
      logging.warn: "log(warning)",
      logging.ERROR: "log(error)",
      logging.error: "log(error)",
  }

  if _is_filepath(output_stream):
    output_stream_string = output_stream
  else:
    try:
      output_stream_string = output_stream_to_constant.get(output_stream)
    except TypeError:
      # Unhashable objects (e.g. lists) can never be a supported stream; report
      # them through the documented ValueError below instead of a TypeError.
      output_stream_string = None
    if not output_stream_string:
      raise ValueError("Unsupported output stream, logging level, or file. " +
                       str(output_stream) +
                       ". Supported streams are sys.stdout, "
                       "sys.stderr, tf.logging.info, "
                       "tf.logging.warning, tf.logging.error. " +
                       "File needs to be in the form of 'file://<filepath>'.")

  # If we are only printing a single string scalar, there is no need to format
  if (len(inputs) == 1 and tensor_util.is_tf_type(inputs[0]) and
      (not isinstance(inputs[0], sparse_tensor.SparseTensor)) and
      (inputs[0].shape.ndims == 0) and (inputs[0].dtype == dtypes.string)):
    formatted_string = inputs[0]
  # Otherwise, we construct an appropriate template for the tensors we are
  # printing, and format the template using those tensors.
  else:
    # For each input to this print function, we extract any nested tensors,
    # and construct an appropriate template to format representing the
    # printed input.
    templates = []
    tensors = []
    # If an input to the print function is of type `OrderedDict`, sort its
    # elements by the keys for consistency with the ordering of `nest.flatten`.
    # This is not needed for `dict` types because `pprint.pformat()` takes care
    # of printing the template in a sorted fashion.
    inputs_ordered_dicts_sorted = []
    for input_ in inputs:
      if isinstance(input_, py_collections.OrderedDict):
        inputs_ordered_dicts_sorted.append(
            py_collections.OrderedDict(sorted(input_.items())))
      else:
        inputs_ordered_dicts_sorted.append(input_)
    tensor_free_structure = nest.map_structure(
        lambda x: "" if tensor_util.is_tf_type(x) else x,
        inputs_ordered_dicts_sorted)

    # The final template joins the per-input templates with `sep`, so the
    # placeholder must also be absent from `sep`; otherwise a separator such
    # as "{}" would be mistaken for an extra placeholder by string_format.
    tensor_free_template = sep.join(
        pprint.pformat(x) for x in tensor_free_structure)
    placeholder = _generate_placeholder_string(tensor_free_template)

    for input_ in inputs:
      placeholders = []
      # Use the nest utilities to flatten & process any nested elements in this
      # input. The placeholder for a tensor in the template should be the
      # placeholder string, and the placeholder for a non-tensor can just be
      # the printed value of the non-tensor itself.
      for x in nest.flatten(input_):
        # support sparse tensors
        if isinstance(x, sparse_tensor.SparseTensor):
          tensors.extend([x.indices, x.values, x.dense_shape])
          placeholders.append(
              "SparseTensor(indices={}, values={}, shape={})".format(
                  placeholder, placeholder, placeholder))
        elif tensor_util.is_tf_type(x):
          tensors.append(x)
          placeholders.append(placeholder)
        else:
          placeholders.append(x)

      if isinstance(input_, six.string_types):
        # If the current input to format/print is a normal string, that string
        # can act as the template.
        cur_template = input_
      else:
        # We pack the placeholders into a data structure that matches the
        # input data structure format, then format that data structure
        # into a string template.
        #
        # NOTE: We must use pprint.pformat here for building the template for
        # unordered data structures such as `dict`, because `str` doesn't
        # guarantee orderings, while pprint prints in sorted order. pprint
        # will match the ordering of `nest.flatten`.
        # This even works when nest.flatten reorders OrderedDicts, because
        # pprint is printing *after* the OrderedDicts have been reordered.
        cur_template = pprint.pformat(
            nest.pack_sequence_as(input_, placeholders))
      templates.append(cur_template)

    # We join the templates for the various inputs into a single larger
    # template. We also remove all quotes surrounding the placeholders, so that
    # the formatted/printed output will not contain quotes around tensors.
    # (example of where these quotes might appear: if we have added a
    # placeholder string into a list, then pretty-formatted that list)
    template = sep.join(templates)
    template = template.replace("'" + placeholder + "'", placeholder)
    formatted_string = string_ops.string_format(
        inputs=tensors,
        template=template,
        placeholder=placeholder,
        summarize=summarize,
        name=format_name)

  return gen_logging_ops.print_v2(
      formatted_string, output_stream=output_stream_string, name=name, end=end)

# pylint: enable=g-doc-args


@ops.RegisterGradient("Print")
def _PrintGrad(op, *grad):
  """Gradient for the `Print` op.

  The incoming gradient is passed straight through to the first input
  (`input_`); every remaining input (the printed `data` tensors) receives no
  gradient.
  """
  return list(grad) + [None] * (len(op.inputs) - 1)


def _Collect(val, collections, default_collections):
  """Adds `val` to each graph collection in `collections`.

  Args:
    val: The value to add.
    collections: Collection keys to add `val` to, or None to use
      `default_collections`.
    default_collections: Collection keys used when `collections` is None.
  """
  if collections is None:
    collections = default_collections
  for key in collections:
    ops.add_to_collection(key, val)


@deprecated(
    "2016-11-30", "Please switch to tf.summary.histogram. Note that "
    "tf.summary.histogram uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in.")
def histogram_summary(tag, values, collections=None, name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with a histogram.

  This op is deprecated. Please switch to tf.summary.histogram.

  For an explanation of why this op was deprecated, and information on how to
  migrate, look
  ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py)

  The generated
  [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  has one summary value containing a histogram for `values`.

  This op reports an `InvalidArgument` error if any value is not finite.

  Args:
    tag: A `string` `Tensor`. 0-D.  Tag to use for the summary value.
    values: A real numeric `Tensor`. Any shape. Values to use to build the
      histogram.
    collections: Optional list of graph collections keys. The new summary op is
      added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  with ops.name_scope(name, "HistogramSummary", [tag, values]) as scope:
    val = gen_logging_ops.histogram_summary(tag=tag, values=values, name=scope)
    _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
  return val


@deprecated(
    "2016-11-30", "Please switch to tf.summary.image. Note that "
    "tf.summary.image uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in. Also, the max_images "
    "argument was renamed to max_outputs.")
def image_summary(tag, tensor, max_images=3, collections=None, name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with images.

  This op is deprecated. Please switch to tf.summary.image.

  For an explanation of why this op was deprecated, and information on how to
  migrate, look
  ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py)

  The summary has up to `max_images` summary values containing images. The
  images are built from `tensor` which must be 4-D with shape `[batch_size,
  height, width, channels]` and where `channels` can be:

  *  1: `tensor` is interpreted as Grayscale.
  *  3: `tensor` is interpreted as RGB.
  *  4: `tensor` is interpreted as RGBA.

  The images have the same number of channels as the input tensor. For float
  input, the values are normalized one image at a time to fit in the range
  `[0, 255]`.  `uint8` values are unchanged.  The op uses two different
  normalization algorithms:

  *  If the input values are all positive, they are rescaled so the largest one
     is 255.

  *  If any input value is negative, the values are shifted so input value 0.0
     is at 127.  They are then rescaled so that either the smallest value is 0,
     or the largest one is 255.

  The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
  build the `tag` of the summary values:

  *  If `max_images` is 1, the summary value tag is '*tag*/image'.
  *  If `max_images` is greater than 1, the summary value tags are
     generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.

  Args:
    tag: A scalar `Tensor` of type `string`. Used to build the `tag` of the
      summary values.
    tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height,
      width, channels]` where `channels` is 1, 3, or 4.
    max_images: Max number of batch elements to generate images for.
    collections: Optional list of ops.GraphKeys.  The collections to add the
      summary to.  Defaults to [ops.GraphKeys.SUMMARIES]
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  with ops.name_scope(name, "ImageSummary", [tag, tensor]) as scope:
    val = gen_logging_ops.image_summary(
        tag=tag, tensor=tensor, max_images=max_images, name=scope)
    _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
  return val


@deprecated(
    "2016-11-30", "Please switch to tf.summary.audio. Note that "
    "tf.summary.audio uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in.")
def audio_summary(tag,
                  tensor,
                  sample_rate,
                  max_outputs=3,
                  collections=None,
                  name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with audio.

  This op is deprecated. Please switch to tf.summary.audio.
  For an explanation of why this op was deprecated, and information on how to
  migrate, look
  ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py)

  The summary has up to `max_outputs` summary values containing audio. The
  audio is built from `tensor` which must be 3-D with shape `[batch_size,
  frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are
  assumed to be in the range of `[-1.0, 1.0]` with a sample rate of
  `sample_rate`.

  The `tag` argument is a scalar `Tensor` of type `string`.  It is used to
  build the `tag` of the summary values:

  *  If `max_outputs` is 1, the summary value tag is '*tag*/audio'.
  *  If `max_outputs` is greater than 1, the summary value tags are
     generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc.

  Args:
    tag: A scalar `Tensor` of type `string`. Used to build the `tag` of the
      summary values.
    tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]`
      or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`.
    sample_rate: A scalar `float32` `Tensor` indicating the sample rate of the
      signal in hertz. It is converted to a `float32` tensor before use.
    max_outputs: Max number of batch elements to generate audio for.
    collections: Optional list of ops.GraphKeys.  The collections to add the
      summary to.  Defaults to [ops.GraphKeys.SUMMARIES]
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  with ops.name_scope(name, "AudioSummary", [tag, tensor]) as scope:
    sample_rate = ops.convert_to_tensor(
        sample_rate, dtype=dtypes.float32, name="sample_rate")
    val = gen_logging_ops.audio_summary_v2(
        tag=tag,
        tensor=tensor,
        max_outputs=max_outputs,
        sample_rate=sample_rate,
        name=scope)
    _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
  return val


@deprecated("2016-11-30", "Please switch to tf.summary.merge.")
def merge_summary(inputs, collections=None, name=None):
  # pylint: disable=line-too-long
  """Merges summaries.

  This op is deprecated. Please switch to tf.compat.v1.summary.merge, which has
  identical behavior.

  This op creates a
  [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  protocol buffer that contains the union of all the values in the input
  summaries.

  When the Op is run, it reports an `InvalidArgument` error if multiple values
  in the summaries to merge use the same tag.

  Args:
    inputs: A list of `string` `Tensor` objects containing serialized `Summary`
      protocol buffers.
    collections: Optional list of graph collections keys. The new summary op is
      added to these collections. Defaults to `[]` (the merged summary is not
      added to any collection).
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer resulting from the merging.
  """
  with ops.name_scope(name, "MergeSummary", inputs):
    # NOTE(review): unlike the other summary wrappers, the op is named with
    # `name` rather than the scope name; kept as-is to preserve op names.
    val = gen_logging_ops.merge_summary(inputs=inputs, name=name)
    _Collect(val, collections, [])
  return val


@deprecated("2016-11-30", "Please switch to tf.summary.merge_all.")
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
  """Merges all summaries collected in the default graph.

  This op is deprecated. Please switch to tf.compat.v1.summary.merge_all, which
  has identical behavior.

  Args:
    key: `GraphKey` used to collect the summaries.  Defaults to
      `GraphKeys.SUMMARIES`.

  Returns:
    If no summaries were collected, returns None.  Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  summary_ops = ops.get_collection(key)
  if not summary_ops:
    return None
  else:
    return merge_summary(summary_ops)


def get_summary_op():
  """Returns a single Summary op that would run all summaries.

  Either existing one from `SUMMARY_OP` collection or merges all existing
  summaries. A newly merged op is added to the `SUMMARY_OP` collection so that
  later calls reuse it.

  Returns:
    If no summaries were collected, returns None.  Otherwise returns a scalar
    `Tensor` of type `string` containing the serialized `Summary` protocol
    buffer resulting from the merging.
  """
  existing = ops.get_collection(ops.GraphKeys.SUMMARY_OP)
  # Only the truthiness of the collection list is tested, never of the op.
  summary_op = existing[0] if existing else None
  if summary_op is None:
    summary_op = merge_all_summaries()
    if summary_op is not None:
      ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
  return summary_op


@deprecated(
    "2016-11-30", "Please switch to tf.summary.scalar. Note that "
    "tf.summary.scalar uses the node name instead of the tag. "
    "This means that TensorFlow will automatically de-duplicate summary "
    "names based on the scope they are created in. Also, passing a "
    "tensor or list of tags to a scalar summary op is no longer "
    "supported.")
def scalar_summary(tags, values, collections=None, name=None):
  # pylint: disable=line-too-long
  """Outputs a `Summary` protocol buffer with scalar values.

  This op is deprecated. Please switch to tf.summary.scalar.
  For an explanation of why this op was deprecated, and information on how to
  migrate, look
  ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py)

  The input `tags` and `values` must have the same shape.  The generated
  summary has a summary value for each tag-value pair in `tags` and `values`.

  Args:
    tags: A `string` `Tensor`.  Tags for the summaries.
    values: A real numeric Tensor.  Values for the summaries.
    collections: Optional list of graph collections keys. The new summary op is
      added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`. The serialized `Summary` protocol
    buffer.
  """
  with ops.name_scope(name, "ScalarSummary", [tags, values]) as scope:
    val = gen_logging_ops.scalar_summary(tags=tags, values=values, name=scope)
    _Collect(val, collections, [ops.GraphKeys.SUMMARIES])
  return val


ops.NotDifferentiable("HistogramSummary")
ops.NotDifferentiable("ImageSummary")
ops.NotDifferentiable("AudioSummary")
ops.NotDifferentiable("AudioSummaryV2")
ops.NotDifferentiable("MergeSummary")
ops.NotDifferentiable("ScalarSummary")
ops.NotDifferentiable("TensorSummary")
ops.NotDifferentiable("TensorSummaryV2")
ops.NotDifferentiable("Timestamp")