1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Logging and Summary Operations.""" 17# pylint: disable=protected-access 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import pprint 23import random 24import sys 25 26import six 27 28from tensorflow.python import pywrap_tensorflow 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import sparse_tensor 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.ops import gen_logging_ops 34from tensorflow.python.ops import string_ops 35# go/tf-wildcard-import 36# pylint: disable=wildcard-import 37from tensorflow.python.ops.gen_logging_ops import * 38# pylint: enable=wildcard-import 39from tensorflow.python.platform import tf_logging 40from tensorflow.python.util import nest 41from tensorflow.python.util.deprecation import deprecated 42from tensorflow.python.util.tf_export import tf_export 43 44# Register printing to the cell output if we are in a Colab or Jupyter Notebook. 45try: 46 get_ipython() # Exists in an ipython env like Jupyter or Colab 47 pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging() 48except NameError: 49 pass 50 51 52# The python wrapper for Assert is in control_flow_ops, as the Assert 53# call relies on certain conditionals for its dependencies. Use 54# control_flow_ops.Assert. 55 56 57# Assert and Print are special symbols in python, so we must 58# have an upper-case version of them. 59# 60# For users with Python 3 or Python 2.7 61# with `from __future__ import print_function`, we could also allow lowercase. 62# See https://github.com/tensorflow/tensorflow/issues/18053 63 64 65# pylint: disable=invalid-name 66@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that " 67 "tf.print returns a no-output operator that directly " 68 "prints the output. Outside of defuns or eager mode, " 69 "this operator will not be executed unless it is " 70 "directly specified in session.run or used as a " 71 "control dependency for other operators. This is " 72 "only a concern in graph mode. Below is an example " 73 "of how to ensure tf.print executes in graph mode:\n" 74 """```python 75 sess = tf.Session() 76 with sess.as_default(): 77 tensor = tf.range(10) 78 print_op = tf.print(tensor) 79 with tf.control_dependencies([print_op]): 80 out = tf.add(tensor, tensor) 81 sess.run(out) 82 ``` 83Additionally, to use tf.print in python 2.7, users must make sure to import 84the following: 85 86 `from __future__ import print_function` 87""") 88@tf_export(v1=["Print"]) 89def Print(input_, data, message=None, first_n=None, summarize=None, 90 name=None): 91 """Prints a list of tensors. 92 93 This is an identity op (behaves like `tf.identity`) with the side effect 94 of printing `data` when evaluating. 95 96 Note: This op prints to the standard error. It is not currently compatible 97 with jupyter notebook (printing to the notebook *server's* output, not into 98 the notebook). 99 100 Args: 101 input_: A tensor passed through this op. 102 data: A list of tensors to print out when op is evaluated. 103 message: A string, prefix of the error message. 104 first_n: Only log `first_n` number of times. Negative numbers log always; 105 this is the default. 106 summarize: Only print this many entries of each tensor. If None, then a 107 maximum of 3 elements are printed per input tensor. 108 name: A name for the operation (optional). 109 110 Returns: 111 A `Tensor`. Has the same type and contents as `input_`. 112 """ 113 return gen_logging_ops._print(input_, data, message, first_n, summarize, name) 114# pylint: enable=invalid-name 115 116 117def _generate_placeholder_string(x, default_placeholder="{}"): 118 """Generate and return a string that does not appear in `x`.""" 119 placeholder = default_placeholder 120 rng = random.Random(5) 121 while placeholder in x: 122 placeholder = placeholder + str(rng.randint(0, 9)) 123 return placeholder 124 125 126def _is_filepath(output_stream): 127 """Returns True if output_stream is a file path.""" 128 return isinstance(output_stream, str) and output_stream.startswith("file://") 129 130 131# Temporarily disable pylint g-doc-args error to allow giving more context 132# about what the kwargs are. 133# Because we are using arbitrary-length positional arguments, python 2 134# does not support explicitly specifying the keyword arguments in the 135# function definition. 136# pylint: disable=g-doc-args 137@tf_export("print") 138def print_v2(*inputs, **kwargs): 139 """Print the specified inputs. 140 141 Returns an operator that prints the specified inputs to a desired 142 output stream or logging level. The inputs may be dense or sparse Tensors, 143 primitive python objects, data structures that contain Tensors, and printable 144 python objects. Printed tensors will recursively show the first and last 145 `summarize` elements of each dimension. 146 147 With eager execution enabled and/or inside a `tf.contrib.eager.defun` this 148 operator will automatically execute, and users only need to call `tf.print` 149 without using the return value. When constructing graphs outside of a 150 `tf.contrib.eager.defun`, one must either include the returned op 151 in the input to `session.run`, or use the operator as a control dependency for 152 executed ops by specifying `with tf.control_dependencies([print_op])`. 153 154 @compatibility(python2) 155 In python 2.7, make sure to import the following: 156 `from __future__ import print_function` 157 @end_compatibility 158 159 Example: 160 Single-input usage: 161 ```python 162 tf.enable_eager_execution() 163 tensor = tf.range(10) 164 tf.print(tensor, output_stream=sys.stderr) 165 ``` 166 (This prints "[0 1 2 ... 7 8 9]" to sys.stderr) 167 168 Multi-input usage: 169 ```python 170 tf.enable_eager_execution() 171 tensor = tf.range(10) 172 tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout) 173 ``` 174 (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to 175 sys.stdout) 176 177 Usage in a defun: 178 ```python 179 tf.enable_eager_execution() 180 181 @tf.contrib.eager.defun 182 def f(): 183 tensor = tf.range(10) 184 tf.print(tensor, output_stream=sys.stderr) 185 return tensor 186 187 range_tensor = f() 188 ``` 189 (This prints "[0 1 2 ... 7 8 9]" to sys.stderr) 190 191 Usage when constructing graphs: 192 ```python 193 sess = tf.Session() 194 with sess.as_default(): 195 tensor = tf.range(10) 196 print_op = tf.print("tensors:", tensor, {2: tensor * 2}, 197 output_stream=sys.stdout) 198 with tf.control_dependencies([print_op]): 199 tripled_tensor = tensor * 3 200 sess.run(tripled_tensor) 201 ``` 202 (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to 203 sys.stdout) 204 205 Note: In Jupyter notebooks and colabs, this operator prints to the notebook 206 cell outputs. It will not write to the notebook kernel's console logs. 207 208 Args: 209 *inputs: Positional arguments that are the inputs to print. Inputs in the 210 printed output will be separated by spaces. Inputs may be python 211 primitives, tensors, data structures such as dicts and lists that 212 may contain tensors (with the data structures possibly nested in 213 arbitrary ways), and printable python objects. 214 output_stream: The output stream, logging level, or file to print to. 215 Defaults to sys.stderr, but sys.stdout, tf.logging.info, 216 tf.logging.warning, and tf.logging.error are also supported. To print to 217 a file, pass a string started with "file://" followed by the file path, 218 e.g., "file:///tmp/foo.out". 219 summarize: The first and last `summarize` elements within each dimension are 220 recursively printed per Tensor. If None, then the first 3 and last 3 221 elements of each dimension are printed for each tensor. If set to -1, it 222 will print all elements of every tensor. 223 name: A name for the operation (optional). 224 225 Returns: 226 A print operator that prints the specified inputs in the specified output 227 stream or logging level. 228 229 Raises: 230 ValueError: If an unsupported output stream is specified. 231 """ 232 # Because we are using arbitrary-length positional arguments, python 2 233 # does not support explicitly specifying the keyword arguments in the 234 # function definition. So, we manually get the keyword arguments w/ default 235 # values here. 236 output_stream = kwargs.pop("output_stream", sys.stderr) 237 name = kwargs.pop("name", None) 238 summarize = kwargs.pop("summarize", 3) 239 if kwargs: 240 raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs) 241 format_name = None 242 if name: 243 format_name = name + "_format" 244 245 # Match the C++ string constants representing the different output streams. 246 # Keep this updated! 247 output_stream_to_constant = { 248 sys.stdout: "stdout", 249 sys.stderr: "stderr", 250 tf_logging.INFO: "log(info)", 251 tf_logging.info: "log(info)", 252 tf_logging.WARN: "log(warning)", 253 tf_logging.warning: "log(warning)", 254 tf_logging.warn: "log(warning)", 255 tf_logging.ERROR: "log(error)", 256 tf_logging.error: "log(error)", 257 } 258 259 if _is_filepath(output_stream): 260 output_stream_string = output_stream 261 else: 262 output_stream_string = output_stream_to_constant.get(output_stream) 263 if not output_stream_string: 264 raise ValueError( 265 "Unsupported output stream, logging level, or file." + 266 str(output_stream) + ". Supported streams are sys.stdout, " 267 "sys.stderr, tf.logging.info, " 268 "tf.logging.warning, tf.logging.error. " + 269 "File needs to be in the form of 'file://<filepath>'.") 270 271 # If we are only printing a single string scalar, there is no need to format 272 if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0]) 273 and (not isinstance(inputs[0], sparse_tensor.SparseTensor)) 274 and (inputs[0].shape.ndims == 0)and (inputs[0].dtype == dtypes.string)): 275 formatted_string = inputs[0] 276 # Otherwise, we construct an appropriate template for the tensors we are 277 # printing, and format the template using those tensors. 278 else: 279 # For each input to this print function, we extract any nested tensors, 280 # and construct an appropriate template to format representing the 281 # printed input. 282 templates = [] 283 tensors = [] 284 tensor_free_structure = nest.map_structure( 285 lambda x: "" if tensor_util.is_tensor(x) else x, 286 inputs) 287 tensor_free_template = " ".join(pprint.pformat(x) 288 for x in tensor_free_structure) 289 placeholder = _generate_placeholder_string(tensor_free_template) 290 291 for input_ in inputs: 292 placeholders = [] 293 # Use the nest utilities to flatten & process any nested elements in this 294 # input. The placeholder for a tensor in the template should be the 295 # placeholder string, and the placeholder for a non-tensor can just be 296 # the printed value of the non-tensor itself. 297 for x in nest.flatten(input_): 298 # support sparse tensors 299 if isinstance(x, sparse_tensor.SparseTensor): 300 tensors.extend([x.indices, x.values, x.dense_shape]) 301 placeholders.append( 302 "SparseTensor(indices={}, values={}, shape={})".format( 303 placeholder, placeholder, placeholder) 304 ) 305 elif tensor_util.is_tensor(x): 306 tensors.append(x) 307 placeholders.append(placeholder) 308 else: 309 placeholders.append(x) 310 311 if isinstance(input_, six.string_types): 312 # If the current input to format/print is a normal string, that string 313 # can act as the template. 314 cur_template = input_ 315 else: 316 # We pack the placeholders into a data structure that matches the 317 # input data structure format, then format that data structure 318 # into a string template. 319 # 320 # NOTE: We must use pprint.pformat here for building the template for 321 # unordered data structures such as `dict`, because `str` doesn't 322 # guarantee orderings, while pprint prints in sorted order. pprint 323 # will match the ordering of `nest.flatten`. 324 # This even works when nest.flatten reorders OrderedDicts, because 325 # pprint is printing *after* the OrderedDicts have been reordered. 326 cur_template = pprint.pformat( 327 nest.pack_sequence_as(input_, placeholders)) 328 templates.append(cur_template) 329 330 # We join the templates for the various inputs into a single larger 331 # template. We also remove all quotes surrounding the placeholders, so that 332 # the formatted/printed output will not contain quotes around tensors. 333 # (example of where these quotes might appear: if we have added a 334 # placeholder string into a list, then pretty-formatted that list) 335 template = " ".join(templates) 336 template = template.replace("'" + placeholder + "'", placeholder) 337 formatted_string = string_ops.string_format( 338 inputs=tensors, template=template, placeholder=placeholder, 339 summarize=summarize, 340 name=format_name) 341 342 return gen_logging_ops.print_v2(formatted_string, 343 output_stream=output_stream_string, 344 name=name) 345# pylint: enable=g-doc-args 346 347 348@ops.RegisterGradient("Print") 349def _PrintGrad(op, *grad): 350 return list(grad) + [None] * (len(op.inputs) - 1) 351 352 353def _Collect(val, collections, default_collections): 354 if collections is None: 355 collections = default_collections 356 for key in collections: 357 ops.add_to_collection(key, val) 358 359 360@deprecated( 361 "2016-11-30", "Please switch to tf.summary.histogram. Note that " 362 "tf.summary.histogram uses the node name instead of the tag. " 363 "This means that TensorFlow will automatically de-duplicate summary " 364 "names based on the scope they are created in.") 365def histogram_summary(tag, values, collections=None, name=None): 366 # pylint: disable=line-too-long 367 """Outputs a `Summary` protocol buffer with a histogram. 368 369 This ops is deprecated. Please switch to tf.summary.histogram. 370 371 For an explanation of why this op was deprecated, and information on how to 372 migrate, look ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py) 373 374 The generated 375 [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) 376 has one summary value containing a histogram for `values`. 377 378 This op reports an `InvalidArgument` error if any value is not finite. 379 380 Args: 381 tag: A `string` `Tensor`. 0-D. Tag to use for the summary value. 382 values: A real numeric `Tensor`. Any shape. Values to use to 383 build the histogram. 384 collections: Optional list of graph collections keys. The new summary op is 385 added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. 386 name: A name for the operation (optional). 387 388 Returns: 389 A scalar `Tensor` of type `string`. The serialized `Summary` protocol 390 buffer. 391 """ 392 with ops.name_scope(name, "HistogramSummary", [tag, values]) as scope: 393 val = gen_logging_ops.histogram_summary( 394 tag=tag, values=values, name=scope) 395 _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) 396 return val 397 398 399@deprecated( 400 "2016-11-30", "Please switch to tf.summary.image. Note that " 401 "tf.summary.image uses the node name instead of the tag. " 402 "This means that TensorFlow will automatically de-duplicate summary " 403 "names based on the scope they are created in. Also, the max_images " 404 "argument was renamed to max_outputs.") 405def image_summary(tag, tensor, max_images=3, collections=None, name=None): 406 # pylint: disable=line-too-long 407 """Outputs a `Summary` protocol buffer with images. 408 409 For an explanation of why this op was deprecated, and information on how to 410 migrate, look ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py) 411 412 The summary has up to `max_images` summary values containing images. The 413 images are built from `tensor` which must be 4-D with shape `[batch_size, 414 height, width, channels]` and where `channels` can be: 415 416 * 1: `tensor` is interpreted as Grayscale. 417 * 3: `tensor` is interpreted as RGB. 418 * 4: `tensor` is interpreted as RGBA. 419 420 The images have the same number of channels as the input tensor. For float 421 input, the values are normalized one image at a time to fit in the range 422 `[0, 255]`. `uint8` values are unchanged. The op uses two different 423 normalization algorithms: 424 425 * If the input values are all positive, they are rescaled so the largest one 426 is 255. 427 428 * If any input value is negative, the values are shifted so input value 0.0 429 is at 127. They are then rescaled so that either the smallest value is 0, 430 or the largest one is 255. 431 432 The `tag` argument is a scalar `Tensor` of type `string`. It is used to 433 build the `tag` of the summary values: 434 435 * If `max_images` is 1, the summary value tag is '*tag*/image'. 436 * If `max_images` is greater than 1, the summary value tags are 437 generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. 438 439 Args: 440 tag: A scalar `Tensor` of type `string`. Used to build the `tag` 441 of the summary values. 442 tensor: A 4-D `uint8` or `float32` `Tensor` of shape `[batch_size, height, 443 width, channels]` where `channels` is 1, 3, or 4. 444 max_images: Max number of batch elements to generate images for. 445 collections: Optional list of ops.GraphKeys. The collections to add the 446 summary to. Defaults to [ops.GraphKeys.SUMMARIES] 447 name: A name for the operation (optional). 448 449 Returns: 450 A scalar `Tensor` of type `string`. The serialized `Summary` protocol 451 buffer. 452 """ 453 with ops.name_scope(name, "ImageSummary", [tag, tensor]) as scope: 454 val = gen_logging_ops.image_summary( 455 tag=tag, tensor=tensor, max_images=max_images, name=scope) 456 _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) 457 return val 458 459 460@deprecated( 461 "2016-11-30", "Please switch to tf.summary.audio. Note that " 462 "tf.summary.audio uses the node name instead of the tag. " 463 "This means that TensorFlow will automatically de-duplicate summary " 464 "names based on the scope they are created in.") 465def audio_summary(tag, 466 tensor, 467 sample_rate, 468 max_outputs=3, 469 collections=None, 470 name=None): 471 # pylint: disable=line-too-long 472 """Outputs a `Summary` protocol buffer with audio. 473 474 This op is deprecated. Please switch to tf.summary.audio. 475 For an explanation of why this op was deprecated, and information on how to 476 migrate, look ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py) 477 478 The summary has up to `max_outputs` summary values containing audio. The 479 audio is built from `tensor` which must be 3-D with shape `[batch_size, 480 frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are 481 assumed to be in the range of `[-1.0, 1.0]` with a sample rate of 482 `sample_rate`. 483 484 The `tag` argument is a scalar `Tensor` of type `string`. It is used to 485 build the `tag` of the summary values: 486 487 * If `max_outputs` is 1, the summary value tag is '*tag*/audio'. 488 * If `max_outputs` is greater than 1, the summary value tags are 489 generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. 490 491 Args: 492 tag: A scalar `Tensor` of type `string`. Used to build the `tag` 493 of the summary values. 494 tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` 495 or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`. 496 sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the 497 signal in hertz. 498 max_outputs: Max number of batch elements to generate audio for. 499 collections: Optional list of ops.GraphKeys. The collections to add the 500 summary to. Defaults to [ops.GraphKeys.SUMMARIES] 501 name: A name for the operation (optional). 502 503 Returns: 504 A scalar `Tensor` of type `string`. The serialized `Summary` protocol 505 buffer. 506 """ 507 with ops.name_scope(name, "AudioSummary", [tag, tensor]) as scope: 508 sample_rate = ops.convert_to_tensor(sample_rate, dtype=dtypes.float32, 509 name="sample_rate") 510 val = gen_logging_ops.audio_summary_v2( 511 tag=tag, 512 tensor=tensor, 513 max_outputs=max_outputs, 514 sample_rate=sample_rate, 515 name=scope) 516 _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) 517 return val 518 519 520@deprecated("2016-11-30", "Please switch to tf.summary.merge.") 521def merge_summary(inputs, collections=None, name=None): 522 # pylint: disable=line-too-long 523 """Merges summaries. 524 525 This op is deprecated. Please switch to tf.summary.merge, which has identical 526 behavior. 527 528 This op creates a 529 [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) 530 protocol buffer that contains the union of all the values in the input 531 summaries. 532 533 When the Op is run, it reports an `InvalidArgument` error if multiple values 534 in the summaries to merge use the same tag. 535 536 Args: 537 inputs: A list of `string` `Tensor` objects containing serialized `Summary` 538 protocol buffers. 539 collections: Optional list of graph collections keys. The new summary op is 540 added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. 541 name: A name for the operation (optional). 542 543 Returns: 544 A scalar `Tensor` of type `string`. The serialized `Summary` protocol 545 buffer resulting from the merging. 546 """ 547 with ops.name_scope(name, "MergeSummary", inputs): 548 val = gen_logging_ops.merge_summary(inputs=inputs, name=name) 549 _Collect(val, collections, []) 550 return val 551 552 553@deprecated("2016-11-30", "Please switch to tf.summary.merge_all.") 554def merge_all_summaries(key=ops.GraphKeys.SUMMARIES): 555 """Merges all summaries collected in the default graph. 556 557 This op is deprecated. Please switch to tf.summary.merge_all, which has 558 identical behavior. 559 560 Args: 561 key: `GraphKey` used to collect the summaries. Defaults to 562 `GraphKeys.SUMMARIES`. 563 564 Returns: 565 If no summaries were collected, returns None. Otherwise returns a scalar 566 `Tensor` of type `string` containing the serialized `Summary` protocol 567 buffer resulting from the merging. 568 """ 569 summary_ops = ops.get_collection(key) 570 if not summary_ops: 571 return None 572 else: 573 return merge_summary(summary_ops) 574 575 576def get_summary_op(): 577 """Returns a single Summary op that would run all summaries. 578 579 Either existing one from `SUMMARY_OP` collection or merges all existing 580 summaries. 581 582 Returns: 583 If no summaries were collected, returns None. Otherwise returns a scalar 584 `Tensor` of type `string` containing the serialized `Summary` protocol 585 buffer resulting from the merging. 586 """ 587 summary_op = ops.get_collection(ops.GraphKeys.SUMMARY_OP) 588 if summary_op is not None: 589 if summary_op: 590 summary_op = summary_op[0] 591 else: 592 summary_op = None 593 if summary_op is None: 594 summary_op = merge_all_summaries() 595 if summary_op is not None: 596 ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op) 597 return summary_op 598 599 600@deprecated( 601 "2016-11-30", "Please switch to tf.summary.scalar. Note that " 602 "tf.summary.scalar uses the node name instead of the tag. " 603 "This means that TensorFlow will automatically de-duplicate summary " 604 "names based on the scope they are created in. Also, passing a " 605 "tensor or list of tags to a scalar summary op is no longer " 606 "supported.") 607def scalar_summary(tags, values, collections=None, name=None): 608 # pylint: disable=line-too-long 609 """Outputs a `Summary` protocol buffer with scalar values. 610 611 This ops is deprecated. Please switch to tf.summary.scalar. 612 For an explanation of why this op was deprecated, and information on how to 613 migrate, look ['here'](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/deprecated/__init__.py) 614 615 The input `tags` and `values` must have the same shape. The generated 616 summary has a summary value for each tag-value pair in `tags` and `values`. 617 618 Args: 619 tags: A `string` `Tensor`. Tags for the summaries. 620 values: A real numeric Tensor. Values for the summaries. 621 collections: Optional list of graph collections keys. The new summary op is 622 added to these collections. Defaults to `[GraphKeys.SUMMARIES]`. 623 name: A name for the operation (optional). 624 625 Returns: 626 A scalar `Tensor` of type `string`. The serialized `Summary` protocol 627 buffer. 628 """ 629 with ops.name_scope(name, "ScalarSummary", [tags, values]) as scope: 630 val = gen_logging_ops.scalar_summary(tags=tags, values=values, name=scope) 631 _Collect(val, collections, [ops.GraphKeys.SUMMARIES]) 632 return val 633 634ops.NotDifferentiable("HistogramSummary") 635ops.NotDifferentiable("ImageSummary") 636ops.NotDifferentiable("AudioSummary") 637ops.NotDifferentiable("AudioSummaryV2") 638ops.NotDifferentiable("MergeSummary") 639ops.NotDifferentiable("ScalarSummary") 640ops.NotDifferentiable("TensorSummary") 641ops.NotDifferentiable("TensorSummaryV2") 642ops.NotDifferentiable("Timestamp") 643