• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains the definition for inception v3 classification network."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib import layers
22from tensorflow.contrib.framework.python.ops import arg_scope
23from tensorflow.contrib.layers.python.layers import initializers
24from tensorflow.contrib.layers.python.layers import layers as layers_lib
25from tensorflow.contrib.layers.python.layers import regularizers
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import init_ops
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.ops import variable_scope
31
def trunc_normal(stddev):
  """Return a zero-mean truncated-normal initializer with the given stddev."""
  return init_ops.truncated_normal_initializer(0.0, stddev)
33
34
def inception_v3_base(inputs,
                      final_endpoint='Mixed_7c',
                      min_depth=16,
                      depth_multiplier=1.0,
                      scope=None):
  """Inception model from http://arxiv.org/abs/1512.00567.

  Constructs an Inception v3 network from inputs to the given final endpoint.
  This method can construct the network up to the final inception block
  Mixed_7c.

  Note that the names of the layers in the paper do not correspond to the names
  of the endpoints registered by this function although they build the same
  network.

  Here is a mapping from the old_names to the new names:
  Old name          | New name
  =======================================
  conv0             | Conv2d_1a_3x3
  conv1             | Conv2d_2a_3x3
  conv2             | Conv2d_2b_3x3
  pool1             | MaxPool_3a_3x3
  conv3             | Conv2d_3b_1x1
  conv4             | Conv2d_4a_3x3
  pool2             | MaxPool_5a_3x3
  mixed_35x35x256a  | Mixed_5b
  mixed_35x35x288a  | Mixed_5c
  mixed_35x35x288b  | Mixed_5d
  mixed_17x17x768a  | Mixed_6a
  mixed_17x17x768b  | Mixed_6b
  mixed_17x17x768c  | Mixed_6c
  mixed_17x17x768d  | Mixed_6d
  mixed_17x17x768e  | Mixed_6e
  mixed_8x8x1280a   | Mixed_7a
  mixed_8x8x2048a   | Mixed_7b
  mixed_8x8x2048b   | Mixed_7c

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
      'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c',
      'Mixed_6d', 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'].
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    scope: Optional variable_scope.

  Returns:
    tensor_out: output tensor corresponding to the final_endpoint.
    end_points: a dict mapping endpoint name -> activation tensor, for
                external use such as summaries or losses.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values,
                or depth_multiplier <= 0
  """
  # end_points will collect relevant activations for external use, for example
  # summaries or losses.
  end_points = {}

  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')
  # Scale a nominal channel count by depth_multiplier; the min_depth floor
  # only has an effect when depth_multiplier < 1.
  depth = lambda d: max(int(d * depth_multiplier), min_depth)

  with variable_scope.variable_scope(scope, 'InceptionV3', [inputs]):
    # Stem: plain conv / max-pool layers. Default stride 1 and VALID padding;
    # individual layers override these where needed. The spatial-size comments
    # below assume the canonical 299 x 299 input.
    with arg_scope(
        [layers.conv2d, layers_lib.max_pool2d, layers_lib.avg_pool2d],
        stride=1,
        padding='VALID'):
      # 299 x 299 x 3
      end_point = 'Conv2d_1a_3x3'
      net = layers.conv2d(inputs, depth(32), [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 149 x 149 x 32
      end_point = 'Conv2d_2a_3x3'
      net = layers.conv2d(net, depth(32), [3, 3], scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 147 x 147 x 32
      end_point = 'Conv2d_2b_3x3'
      net = layers.conv2d(
          net, depth(64), [3, 3], padding='SAME', scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 147 x 147 x 64
      end_point = 'MaxPool_3a_3x3'
      net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 73 x 73 x 64
      end_point = 'Conv2d_3b_1x1'
      net = layers.conv2d(net, depth(80), [1, 1], scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 73 x 73 x 80.
      end_point = 'Conv2d_4a_3x3'
      net = layers.conv2d(net, depth(192), [3, 3], scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 71 x 71 x 192.
      end_point = 'MaxPool_5a_3x3'
      net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope=end_point)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # 35 x 35 x 192.

      # Inception blocks
    # Inception blocks use SAME padding by default so branch outputs keep the
    # same spatial size and can be depth-concatenated (axis 3 = channels).
    with arg_scope(
        [layers.conv2d, layers_lib.max_pool2d, layers_lib.avg_pool2d],
        stride=1,
        padding='SAME'):
      # mixed: 35 x 35 x 256.
      end_point = 'Mixed_5b'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(32), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_1: 35 x 35 x 288.
      end_point = 'Mixed_5c'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          # NOTE(review): the scope names 'Conv2d_0b_1x1' / 'Conv_1_0c_5x5'
          # break the naming pattern used elsewhere — presumably historical
          # and kept for checkpoint compatibility; do not "fix" them.
          branch_1 = layers.conv2d(
              net, depth(48), [1, 1], scope='Conv2d_0b_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv_1_0c_5x5')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_2: 35 x 35 x 288.
      end_point = 'Mixed_5d'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = layers.conv2d(
              branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_3: 17 x 17 x 768. Grid-size reduction block: stride-2 branches
      # with VALID padding halve the spatial resolution.
      end_point = 'Mixed_6a'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net,
              depth(384), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
          branch_1 = layers.conv2d(
              branch_1,
              depth(96), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_1x1')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers_lib.max_pool2d(
              net, [3, 3], stride=2, padding='VALID', scope='MaxPool_1a_3x3')
        net = array_ops.concat([branch_0, branch_1, branch_2], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed4: 17 x 17 x 768. 7x7 convolutions factorized into 1x7 and 7x1.
      end_point = 'Mixed_6b'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(128), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(128), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(128), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = layers.conv2d(
              branch_2, depth(128), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_5: 17 x 17 x 768.
      end_point = 'Mixed_6c'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # mixed_6: 17 x 17 x 768. Same topology as Mixed_6c.
      end_point = 'Mixed_6d'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = layers.conv2d(
              branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_7: 17 x 17 x 768. This endpoint also feeds the auxiliary
      # classifier head in inception_v3().
      end_point = 'Mixed_6e'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [7, 1], scope='Conv2d_0b_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0c_1x7')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [7, 1], scope='Conv2d_0d_7x1')
          branch_2 = layers.conv2d(
              branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7')
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_8: 8 x 8 x 1280. Second grid-size reduction block.
      end_point = 'Mixed_7a'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_0 = layers.conv2d(
              branch_0,
              depth(320), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_3x3')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [1, 7], scope='Conv2d_0b_1x7')
          branch_1 = layers.conv2d(
              branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1')
          branch_1 = layers.conv2d(
              branch_1,
              depth(192), [3, 3],
              stride=2,
              padding='VALID',
              scope='Conv2d_1a_3x3')
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers_lib.max_pool2d(
              net, [3, 3], stride=2, padding='VALID', scope='MaxPool_1a_3x3')
        net = array_ops.concat([branch_0, branch_1, branch_2], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
      # mixed_9: 8 x 8 x 2048. Branches 1 and 2 fan out into parallel 1x3 and
      # 3x1 convolutions whose outputs are concatenated along channels.
      end_point = 'Mixed_7b'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(320), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(384), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = array_ops.concat(
              [
                  layers.conv2d(
                      branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'),
                  layers.conv2d(
                      branch_1, depth(384), [3, 1], scope='Conv2d_0b_3x1')
              ],
              3)
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(448), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = array_ops.concat(
              [
                  layers.conv2d(
                      branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'),
                  layers.conv2d(
                      branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')
              ],
              3)
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points

      # mixed_10: 8 x 8 x 2048. Same topology as Mixed_7b.
      end_point = 'Mixed_7c'
      with variable_scope.variable_scope(end_point):
        with variable_scope.variable_scope('Branch_0'):
          branch_0 = layers.conv2d(
              net, depth(320), [1, 1], scope='Conv2d_0a_1x1')
        with variable_scope.variable_scope('Branch_1'):
          branch_1 = layers.conv2d(
              net, depth(384), [1, 1], scope='Conv2d_0a_1x1')
          branch_1 = array_ops.concat(
              [
                  layers.conv2d(
                      branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'),
                  layers.conv2d(
                      branch_1, depth(384), [3, 1], scope='Conv2d_0c_3x1')
              ],
              3)
        with variable_scope.variable_scope('Branch_2'):
          branch_2 = layers.conv2d(
              net, depth(448), [1, 1], scope='Conv2d_0a_1x1')
          branch_2 = layers.conv2d(
              branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3')
          branch_2 = array_ops.concat(
              [
                  layers.conv2d(
                      branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'),
                  layers.conv2d(
                      branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')
              ],
              3)
        with variable_scope.variable_scope('Branch_3'):
          branch_3 = layers_lib.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
          branch_3 = layers.conv2d(
              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
        net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
      end_points[end_point] = net
      if end_point == final_endpoint:
        return net, end_points
    # Reached only if final_endpoint matched none of the endpoints above.
    raise ValueError('Unknown final endpoint %s' % final_endpoint)
511
512
def inception_v3(inputs,
                 num_classes=1000,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 min_depth=16,
                 depth_multiplier=1.0,
                 prediction_fn=layers_lib.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='InceptionV3'):
  """Inception model from http://arxiv.org/abs/1512.00567.

  "Rethinking the Inception Architecture for Computer Vision"

  Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens,
  Zbigniew Wojna.

  With the default arguments this method constructs the exact model defined in
  the paper. However, one can experiment with variations of the inception_v3
  network by changing arguments dropout_keep_prob, min_depth and
  depth_multiplier.

  The default image size used to train this network is 299x299.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether is training or not.
    dropout_keep_prob: the percentage of activation values that are retained.
    min_depth: Minimum depth value (number of channels) for all convolution ops.
      Enforced when depth_multiplier < 1, and not an active constraint when
      depth_multiplier >= 1.
    depth_multiplier: Float multiplier for the depth (number of channels)
      for all convolution ops. The value must be greater than zero. Typical
      usage will be to set this value in (0, 1) to reduce the number of
      parameters or computation cost of the model.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
      of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
      To use this parameter, the input images must be smaller
      than 300x300 pixels, in which case the output logit layer
      does not contain spatial information and can be removed.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes]
    end_points: a dictionary from components of the network to the corresponding
      activation.

  Raises:
    ValueError: if 'depth_multiplier' is less than or equal to zero.
  """
  if depth_multiplier <= 0:
    raise ValueError('depth_multiplier is not greater than zero.')
  # Same depth scaling rule as inception_v3_base; kept local because it is
  # also needed for the auxiliary head built below.
  depth = lambda d: max(int(d * depth_multiplier), min_depth)

  with variable_scope.variable_scope(
      scope, 'InceptionV3', [inputs, num_classes], reuse=reuse) as scope:
    # batch_norm and dropout are the only layers that behave differently
    # between training and inference; wire is_training through once here.
    with arg_scope(
        [layers_lib.batch_norm, layers_lib.dropout], is_training=is_training):
      net, end_points = inception_v3_base(
          inputs,
          scope=scope,
          min_depth=min_depth,
          depth_multiplier=depth_multiplier)

      # Auxiliary Head logits: a small classifier attached to the Mixed_6e
      # activation (17x17x768) used as an extra training-time loss.
      with arg_scope(
          [layers.conv2d, layers_lib.max_pool2d, layers_lib.avg_pool2d],
          stride=1,
          padding='SAME'):
        aux_logits = end_points['Mixed_6e']
        with variable_scope.variable_scope('AuxLogits'):
          aux_logits = layers_lib.avg_pool2d(
              aux_logits, [5, 5],
              stride=3,
              padding='VALID',
              scope='AvgPool_1a_5x5')
          aux_logits = layers.conv2d(
              aux_logits, depth(128), [1, 1], scope='Conv2d_1b_1x1')

          # Shape of feature map before the final layer.
          kernel_size = _reduced_kernel_size_for_small_input(aux_logits, [5, 5])
          aux_logits = layers.conv2d(
              aux_logits,
              depth(768),
              kernel_size,
              weights_initializer=trunc_normal(0.01),
              padding='VALID',
              scope='Conv2d_2a_{}x{}'.format(*kernel_size))
          # Final 1x1 projection to class logits: linear (no activation, no
          # batch norm), like the main logits layer below.
          aux_logits = layers.conv2d(
              aux_logits,
              num_classes, [1, 1],
              activation_fn=None,
              normalizer_fn=None,
              weights_initializer=trunc_normal(0.001),
              scope='Conv2d_2b_1x1')
          if spatial_squeeze:
            aux_logits = array_ops.squeeze(
                aux_logits, [1, 2], name='SpatialSqueeze')
          end_points['AuxLogits'] = aux_logits

      # Final pooling and prediction
      with variable_scope.variable_scope('Logits'):
        # The pooling kernel is shrunk for inputs smaller than 299x299 so the
        # pooled map stays 1x1 spatially.
        kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8])
        net = layers_lib.avg_pool2d(
            net,
            kernel_size,
            padding='VALID',
            scope='AvgPool_1a_{}x{}'.format(*kernel_size))
        # 1 x 1 x 2048
        net = layers_lib.dropout(
            net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
        end_points['PreLogits'] = net
        # 2048
        logits = layers.conv2d(
            net,
            num_classes, [1, 1],
            activation_fn=None,
            normalizer_fn=None,
            scope='Conv2d_1c_1x1')
        if spatial_squeeze:
          logits = array_ops.squeeze(logits, [1, 2], name='SpatialSqueeze')
        # 1000
      end_points['Logits'] = logits
      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points
643
644
# Input spatial size (height == width) the network was designed for; see the
# inception_v3 docstring. Exposed as an attribute for callers that build
# preprocessing pipelines.
inception_v3.default_image_size = 299
646
647
648def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
649  """Define kernel size which is automatically reduced for small input.
650
651  If the shape of the input images is unknown at graph construction time this
652  function assumes that the input images are is large enough.
653
654  Args:
655    input_tensor: input tensor of size [batch_size, height, width, channels].
656    kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
657
658  Returns:
659    a tensor with the kernel size.
660
661  TODO(jrru): Make this function work with unknown shapes. Theoretically, this
662  can be done with the code below. Problems are two-fold: (1) If the shape was
663  known, it will be lost. (2) inception.tf.contrib.slim.ops._two_element_tuple
664  cannot
665  handle tensors that define the kernel size.
666      shape = tf.shape(input_tensor)
667      return = tf.stack([tf.minimum(shape[1], kernel_size[0]),
668                        tf.minimum(shape[2], kernel_size[1])])
669
670  """
671  shape = input_tensor.get_shape().as_list()
672  if shape[1] is None or shape[2] is None:
673    kernel_size_out = kernel_size
674  else:
675    kernel_size_out = [
676        min(shape[1], kernel_size[0]), min(shape[2], kernel_size[1])
677    ]
678  return kernel_size_out
679
680
def inception_v3_arg_scope(weight_decay=0.00004,
                           batch_norm_var_collection='moving_vars',
                           batch_norm_decay=0.9997,
                           batch_norm_epsilon=0.001,
                           updates_collections=ops.GraphKeys.UPDATE_OPS,
                           use_fused_batchnorm=True):
  """Defines the default InceptionV3 arg scope.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.
    batch_norm_decay: Decay for batch norm moving average
    batch_norm_epsilon: Small float added to variance to avoid division by zero
    updates_collections: Collections for the update ops of the layer
    use_fused_batchnorm: Enable fused batchnorm.

  Returns:
    An `arg_scope` to use for the inception v3 model.
  """
  # The moving mean/variance go into batch_norm_var_collection; beta and
  # gamma stay in the default variable collections.
  moving_statistics_collections = {
      'beta': None,
      'gamma': None,
      'moving_mean': [batch_norm_var_collection],
      'moving_variance': [batch_norm_var_collection],
  }
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': batch_norm_decay,
      # epsilon to prevent 0s in variance.
      'epsilon': batch_norm_epsilon,
      # collection containing update_ops.
      'updates_collections': updates_collections,
      # Use fused batch norm if possible.
      'fused': use_fused_batchnorm,
      'variables_collections': moving_statistics_collections,
  }

  weights_regularizer = regularizers.l2_regularizer(weight_decay)
  # Outer scope: L2 weight decay on both Conv and FC layers. Inner scope:
  # Conv-only defaults (initializer, ReLU, batch norm as normalizer).
  with arg_scope(
      [layers.conv2d, layers_lib.fully_connected],
      weights_regularizer=weights_regularizer):
    with arg_scope(
        [layers.conv2d],
        weights_initializer=initializers.variance_scaling_initializer(),
        activation_fn=nn_ops.relu,
        normalizer_fn=layers_lib.batch_norm,
        normalizer_params=batch_norm_params) as sc:
      return sc
730