1/* 2Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package op 18 19import ( 20 "fmt" 21 "runtime/debug" 22 23 tf "github.com/tensorflow/tensorflow/tensorflow/go" 24) 25 26// Scope encapsulates common operation properties when building a Graph. 27// 28// A Scope object (and its derivatives, e.g., obtained from Scope.SubScope) 29// act as a builder for graphs. They allow common properties (such as 30// a name prefix) to be specified for multiple operations being added 31// to the graph. 32// 33// A Scope object and all its derivatives (e.g., obtained from Scope.SubScope) 34// are not safe for concurrent use by multiple goroutines. 35type Scope struct { 36 graph *tf.Graph 37 namemap map[string]int 38 namespace string 39 controlDependencies []*tf.Operation 40 device string 41 err *scopeErr 42} 43 44// scopeErr is used to share errors between all derivatives of a root scope. 45type scopeErr struct { 46 err error 47} 48 49// NewScope creates a Scope initialized with an empty Graph. 50func NewScope() *Scope { 51 return &Scope{graph: tf.NewGraph(), namemap: make(map[string]int), err: new(scopeErr)} 52} 53 54// NewScopeWithGraph creates a Scope initialized with the Graph thats passed in 55func NewScopeWithGraph(g *tf.Graph) *Scope { 56 return &Scope{graph: g, namemap: make(map[string]int), err: new(scopeErr)} 57} 58 59// Finalize returns the Graph on which this scope operates on and renders s 60// unusable. If there was an error during graph construction, that error is 61// returned instead. 62func (s *Scope) Finalize() (*tf.Graph, error) { 63 if err := s.Err(); err != nil { 64 return nil, err 65 } 66 s.err.err = fmt.Errorf("Scope has been finalized and is no longer usable") 67 return s.graph, nil 68} 69 70// AddOperation adds the operation to the Graph managed by s. 71// 72// If there is a name prefix associated with s (such as if s was created 73// by a call to SubScope), then this prefix will be applied to the name 74// of the operation being added. See also Graph.AddOperation. 75func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { 76 if s.Err() != nil { 77 return nil 78 } 79 if args.Name == "" { 80 args.Name = args.Type 81 } 82 if s.namespace != "" { 83 args.Name = s.namespace + "/" + args.Name 84 } 85 args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) 86 args.Device = s.device 87 op, err := s.graph.AddOperation(args) 88 if err != nil { 89 s.UpdateErr(args.Type, err) 90 } 91 return op 92} 93 94// SubScope returns a new Scope which will cause all operations added to the 95// graph to be namespaced with 'namespace'. If namespace collides with an 96// existing namespace within the scope, then a suffix will be added. 97func (s *Scope) SubScope(namespace string) *Scope { 98 namespace = s.uniqueName(namespace) 99 if s.namespace != "" { 100 namespace = s.namespace + "/" + namespace 101 } 102 return &Scope{ 103 graph: s.graph, 104 namemap: make(map[string]int), 105 namespace: namespace, 106 controlDependencies: s.controlDependencies, 107 device: s.device, 108 err: s.err, 109 } 110} 111 112// WithControlDependencies returns a new Scope which will cause all operations 113// added to the graph to execute only after all the provided operations have 114// executed first (in addition to any other control dependencies in s). 115func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { 116 // Force a copy of the control dependencies into a new underlying array on 117 // every call. We cannot alias the same underlying array as `ops`, otherwise 118 // the user could modify that array after calling s.WithControlDependencies, 119 // which would be confusing. We cannot alias the same underlying array as the 120 // original `s.controlDependencies`, since Scopes form a logical tree, and 121 // other calls to s.WithControlDependencies could stomp on each other. 122 deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops)) 123 deps = append(deps, s.controlDependencies...) 124 deps = append(deps, ops...) 125 return &Scope{ 126 graph: s.graph, 127 namemap: s.namemap, 128 namespace: s.namespace, 129 controlDependencies: deps, 130 device: s.device, 131 err: s.err, 132 } 133} 134 135// WithDevice returns a new Scope which will cause all operations added to the 136// graph to execute on devices that match the provided device specification. 137// 138// For example, WithDevice("/device:GPU:0") will cause operations added to 139// the graph to execute on GPU #0. 140// 141// An empty string removes any device restrictions. 142func (s *Scope) WithDevice(device string) *Scope { 143 return &Scope{ 144 graph: s.graph, 145 namemap: s.namemap, 146 namespace: s.namespace, 147 controlDependencies: s.controlDependencies, 148 device: device, 149 err: s.err, 150 } 151} 152 153// Err returns the error, if any, encountered during the construction 154// of the Graph managed by s. 155// 156// Once Err returns a non-nil error, all future calls will do the same, 157// indicating that the scope should be discarded as the graph could not 158// be constructed. 159func (s *Scope) Err() error { 160 return s.err.err 161} 162 163// UpdateErr is used to notify Scope of any graph construction errors 164// while creating the operation op. 165func (s *Scope) UpdateErr(op string, err error) { 166 if s.err.err == nil { 167 s.err.err = fmt.Errorf("failed to add operation %q: %v (Stacktrace: %s)", op, err, debug.Stack()) 168 } 169} 170 171func (s *Scope) uniqueName(name string) string { 172 count := s.namemap[name] 173 s.namemap[name]++ 174 if count == 0 { 175 return name 176 } 177 return fmt.Sprint(name, "_", count) 178} 179 180func (s *Scope) opName(typ string) string { 181 if s.namespace == "" { 182 return typ 183 } 184 return s.namespace + "/" + typ 185} 186