• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package op
18
19import (
20	"fmt"
21	"runtime/debug"
22
23	tf "github.com/tensorflow/tensorflow/tensorflow/go"
24)
25
26// Scope encapsulates common operation properties when building a Graph.
27//
28// A Scope object (and its derivatives, e.g., obtained from Scope.SubScope)
29// act as a builder for graphs. They allow common properties (such as
30// a name prefix) to be specified for multiple operations being added
31// to the graph.
32//
33// A Scope object and all its derivatives (e.g., obtained from Scope.SubScope)
34// are not safe for concurrent use by multiple goroutines.
35type Scope struct {
36	graph               *tf.Graph
37	namemap             map[string]int
38	namespace           string
39	controlDependencies []*tf.Operation
40	device              string
41	err                 *scopeErr
42}
43
// scopeErr is a shared, mutable error slot. A root Scope and every Scope
// derived from it hold the same *scopeErr, so an error recorded through any
// one of them becomes visible to all of them.
type scopeErr struct {
	err error // recorded error, or nil while construction is healthy
}
48
49// NewScope creates a Scope initialized with an empty Graph.
50func NewScope() *Scope {
51	return &Scope{graph: tf.NewGraph(), namemap: make(map[string]int), err: new(scopeErr)}
52}
53
54// NewScopeWithGraph creates a Scope initialized with the Graph thats passed in
55func NewScopeWithGraph(g *tf.Graph) *Scope {
56	return &Scope{graph: g, namemap: make(map[string]int), err: new(scopeErr)}
57}
58
59// Finalize returns the Graph on which this scope operates on and renders s
60// unusable. If there was an error during graph construction, that error is
61// returned instead.
62func (s *Scope) Finalize() (*tf.Graph, error) {
63	if err := s.Err(); err != nil {
64		return nil, err
65	}
66	s.err.err = fmt.Errorf("Scope has been finalized and is no longer usable")
67	return s.graph, nil
68}
69
70// AddOperation adds the operation to the Graph managed by s.
71//
72// If there is a name prefix associated with s (such as if s was created
73// by a call to SubScope), then this prefix will be applied to the name
74// of the operation being added. See also Graph.AddOperation.
75func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
76	if s.Err() != nil {
77		return nil
78	}
79	if args.Name == "" {
80		args.Name = args.Type
81	}
82	if s.namespace != "" {
83		args.Name = s.namespace + "/" + args.Name
84	}
85	args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
86	args.Device = s.device
87	op, err := s.graph.AddOperation(args)
88	if err != nil {
89		s.UpdateErr(args.Type, err)
90	}
91	return op
92}
93
94// SubScope returns a new Scope which will cause all operations added to the
95// graph to be namespaced with 'namespace'.  If namespace collides with an
96// existing namespace within the scope, then a suffix will be added.
97func (s *Scope) SubScope(namespace string) *Scope {
98	namespace = s.uniqueName(namespace)
99	if s.namespace != "" {
100		namespace = s.namespace + "/" + namespace
101	}
102	return &Scope{
103		graph:               s.graph,
104		namemap:             make(map[string]int),
105		namespace:           namespace,
106		controlDependencies: s.controlDependencies,
107		device:              s.device,
108		err:                 s.err,
109	}
110}
111
112// WithControlDependencies returns a new Scope which will cause all operations
113// added to the graph to execute only after all the provided operations have
114// executed first (in addition to any other control dependencies in s).
115func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
116	// Force a copy of the control dependencies into a new underlying array on
117	// every call.  We cannot alias the same underlying array as `ops`, otherwise
118	// the user could modify that array after calling s.WithControlDependencies,
119	// which would be confusing.  We cannot alias the same underlying array as the
120	// original `s.controlDependencies`, since Scopes form a logical tree, and
121	// other calls to s.WithControlDependencies could stomp on each other.
122	deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops))
123	deps = append(deps, s.controlDependencies...)
124	deps = append(deps, ops...)
125	return &Scope{
126		graph:               s.graph,
127		namemap:             s.namemap,
128		namespace:           s.namespace,
129		controlDependencies: deps,
130		device:              s.device,
131		err:                 s.err,
132	}
133}
134
135// WithDevice returns a new Scope which will cause all operations added to the
136// graph to execute on devices that match the provided device specification.
137//
138// For example, WithDevice("/device:GPU:0") will cause operations added to
139// the graph to execute on GPU #0.
140//
141// An empty string removes any device restrictions.
142func (s *Scope) WithDevice(device string) *Scope {
143	return &Scope{
144		graph:               s.graph,
145		namemap:             s.namemap,
146		namespace:           s.namespace,
147		controlDependencies: s.controlDependencies,
148		device:              device,
149		err:                 s.err,
150	}
151}
152
153// Err returns the error, if any, encountered during the construction
154// of the Graph managed by s.
155//
156// Once Err returns a non-nil error, all future calls will do the same,
157// indicating that the scope should be discarded as the graph could not
158// be constructed.
159func (s *Scope) Err() error {
160	return s.err.err
161}
162
163// UpdateErr is used to notify Scope of any graph construction errors
164// while creating the operation op.
165func (s *Scope) UpdateErr(op string, err error) {
166	if s.err.err == nil {
167		s.err.err = fmt.Errorf("failed to add operation %q: %v (Stacktrace: %s)", op, err, debug.Stack())
168	}
169}
170
171func (s *Scope) uniqueName(name string) string {
172	count := s.namemap[name]
173	s.namemap[name]++
174	if count == 0 {
175		return name
176	}
177	return fmt.Sprint(name, "_", count)
178}
179
180func (s *Scope) opName(typ string) string {
181	if s.namespace == "" {
182		return typ
183	}
184	return s.namespace + "/" + typ
185}
186