Searched refs:grad_state (Results 1 – 7 of 7) sorted by relevance
188 grad_state=self)202 grad_state=self)417 grad_state = self419 while pred is None and grad_state:420 pred = grad_state.history_map.get(cond_ctxt.pred.name)421 grad_state = grad_state.outer_grad_state539 for grad_state in self._map.values():540 for y in grad_state.forward_loop_exits:542 grad_state.pending_exits_count -= 1544 grad_state.unused_exits.append(y)[all …]
48 merge_grad = grad_ctxt.grad_state.switch_map.get(op)66 grad_ctxt.grad_state.switch_map[op] = merge_grad110 if grad_ctxt and grad_ctxt.grad_state:115 grad_state = grad_ctxt.grad_state116 real_pred = grad_state.history_map.get(pred.name)119 grad_ctxt = grad_state.grad_context121 history_pred = grad_state.AddForwardAccumulator(pred)126 real_pred = grad_state.AddBackpropAccumulatedValue(history_pred, pred)127 grad_state.history_map[pred.name] = real_pred160 if op_ctxt.grad_state:[all …]
311 elif (while_ctxt.grad_state and312 IsContainingContext(while_ctxt.grad_state.forward_context,321 elif (while_ctxt.grad_state and322 while_ctxt.grad_state.forward_context is328 elif (input_while_ctxt.grad_state and329 input_while_ctxt.grad_state.forward_context is while_ctxt):334 elif (input_while_ctxt.grad_state and335 input_ctxt.grad_state.forward_context.grad_state and336 input_ctxt.grad_state.forward_context.grad_state.forward_context is
757 grad_state = loop_state.GetGradState(x.op, before=False)758 grad_state.deferred_exits.append(x)759 grad_state.pending_exits_count -= 1760 if grad_state.pending_exits_count == 0:763 for y in grad_state.deferred_exits:768 grad_state.unused_exits.append(y)772 for y in grad_state.unused_exits:778 for y in grad_state.unused_exits:
700 def grad_state(self): member in ControlFlowContext891 def grad_state(self): member in CondContext893 return self.GetWhileContext().grad_state1462 grad_state=None, argument1486 self._grad_state = grad_state1613 def grad_state(self): member in WhileContext1706 if grad_ctxt.grad_state:1712 if forward_ctxt == grad_ctxt.grad_state.forward_context:1713 real_val = grad_ctxt.grad_state.GetRealValue(val)1758 if grad_ctxt.grad_state:[all …]
296 def grad_state(self): member in XLACompileContext
674 def grad_state(self): member in TPUReplicateContext2138 def grad_state(self): member in _TPUInferenceContext