1 //============================================================================
2 // Name : autodiff.cpp
3 // Author :
4 // Version :
5 // Copyright : Your copyright notice
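// Description : AutoDiff entry points: node construction, function evaluation,
//               reverse-mode gradient/Hessian sweeps, sparsity counts, and
//               setup/cleanup of the global stacks and tapes.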
7 //============================================================================
8
9 #include <iostream>
10 #include <sstream>
11 #include <numeric>
12 #include <boost/foreach.hpp>
13 #include "autodiff.h"
14 #include "Stack.h"
15 #include "Tape.h"
16 #include "BinaryOPNode.h"
17 #include "UaryOPNode.h"
18
19 using namespace std;
20
21 namespace AutoDiff
22 {
23
#if FORWARD_ENABLED

// Number of independent variables known to the forward mode;
// hess_forward() asserts that its caller passes this same count.
unsigned int num_var = 0;

/*
 * Forward-mode gradient/Hessian evaluation of the tree rooted at `root`.
 *
 *   root     - root of the expression tree
 *   nvar     - number of independent variables (must equal num_var)
 *   hess_mat - output buffer handed straight to root->hess_forward()
 *
 * The length passed along is nvar*(nvar+3)/2 == nvar + nvar*(nvar+1)/2,
 * presumably nvar gradient entries plus one triangle of the symmetric
 * Hessian -- TODO confirm against Node::hess_forward.
 */
void hess_forward(Node* root, unsigned int nvar, double** hess_mat)
{
	assert(num_var == nvar);
	unsigned int entries = nvar * (nvar + 3) / 2;
	root->hess_forward(entries, hess_mat);
}

#endif
36
37
create_param_node(double value)38 PNode* create_param_node(double value){
39 return new PNode(value);
40 }
create_var_node(double v)41 VNode* create_var_node(double v)
42 {
43 return new VNode(v);
44 }
create_binary_op_node(OPCODE code,Node * left,Node * right)45 OPNode* create_binary_op_node(OPCODE code, Node* left, Node* right)
46 {
47 return BinaryOPNode::createBinaryOpNode(code,left,right);
48 }
create_uary_op_node(OPCODE code,Node * left)49 OPNode* create_uary_op_node(OPCODE code, Node* left)
50 {
51 return UaryOPNode::createUnaryOpNode(code,left);
52 }
eval_function(Node * root)53 double eval_function(Node* root)
54 {
55 assert(SD->size()==0);
56 assert(SV->size()==0);
57 root->eval_function();
58 assert(SV->size()==1);
59 double val = SV->pop_back();
60 return val;
61 }
62
grad_reverse(Node * root,vector<Node * > & vnodes,vector<double> & grad)63 double grad_reverse(Node* root,vector<Node*>& vnodes, vector<double>& grad)
64 {
65 grad.clear();
66 BOOST_FOREACH(Node* node, vnodes)
67 {
68 assert(node->getType()==VNode_Type);
69 static_cast<VNode*>(node)->adj = NaN_Double;
70 }
71
72 assert(SD->size()==0);
73 root->grad_reverse_0();
74 assert(SV->size()==1);
75 root->grad_reverse_1_init_adj();
76 root->grad_reverse_1();
77 assert(SD->size()==0);
78 double val = SV->pop_back();
79 assert(SV->size()==0);
80 //int i=0;
81 BOOST_FOREACH(Node* node, vnodes)
82 {
83 assert(node->getType()==VNode_Type);
84 grad.push_back(static_cast<VNode*>(node)->adj);
85 static_cast<VNode*>(node)->adj = NaN_Double;
86 }
87 assert(grad.size()==vnodes.size());
88 //all nodes are VNode and adj == NaN_Double -- this reset adj for this expression tree by root
89 return val;
90 }
91
grad_reverse(Node * root,vector<Node * > & vnodes,col_compress_matrix_row & rgrad)92 double grad_reverse(Node* root, vector<Node*>& vnodes, col_compress_matrix_row& rgrad)
93 {
94 BOOST_FOREACH(Node* node, vnodes)
95 {
96 assert(node->getType()==VNode_Type);
97 static_cast<VNode*>(node)->adj = NaN_Double;
98 }
99 assert(SD->size()==0);
100 root->grad_reverse_0();
101 assert(SV->size()==1);
102 root->grad_reverse_1_init_adj();
103 root->grad_reverse_1();
104 assert(SD->size()==0);
105 double val = SV->pop_back();
106 assert(SV->size()==0);
107 unsigned int i =0;
108 BOOST_FOREACH(Node* node, vnodes)
109 {
110 assert((node)->getType()==VNode_Type);
111 double diff = static_cast<VNode*>(node)->adj;
112 if(!isnan(diff)){
113 rgrad(i) = diff;
114 static_cast<VNode*>(node)->adj = NaN_Double;
115 }
116 i++;
117 }
118 //all nodes are VNode and adj == NaN_Double -- this reset adj for this expression tree by root
119 assert(i==vnodes.size());
120 return val;
121 }
122
hess_reverse(Node * root,vector<Node * > & vnodes,vector<double> & dhess)123 double hess_reverse(Node* root,vector<Node*>& vnodes,vector<double>& dhess)
124 {
125 TT->clear();
126 II->clear();
127 assert(TT->empty());
128 assert(II->empty());
129 assert(TT->index==0);
130 assert(II->index==0);
131 dhess.clear();
132
133 // for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
134 // {
135 // assert((*it)->getType()==VNode_Type);
136 // (*it)->index = 0;
137 // } //this work complete in hess-reverse_0_init_index
138
139 assert(root->n_in_arcs == 0);
140 root->hess_reverse_0_init_n_in_arcs();
141 assert(root->n_in_arcs == 1);
142 root->hess_reverse_0();
143 double val = NaN_Double;
144 root->hess_reverse_get_x(TT->index,val);
145 // cout<<TT->toString();
146 // cout<<endl;
147 // cout<<II->toString();
148 // cout<<"======================================= hess_reverse_0"<<endl;
149 root->hess_reverse_1_init_x_bar(TT->index);
150 assert(root->n_in_arcs == 1);
151 root->hess_reverse_1(TT->index);
152 assert(root->n_in_arcs == 0);
153 assert(II->index==0);
154 // cout<<TT->toString();
155 // cout<<endl;
156 // cout<<II->toString();
157 // cout<<"======================================= hess_reverse_1"<<endl;
158
159 for(vector<Node*>::iterator it=vnodes.begin();it!=vnodes.end();it++)
160 {
161 assert((*it)->getType()==VNode_Type);
162 dhess.push_back(TT->get((*it)->index-1));
163 }
164
165 TT->clear();
166 II->clear();
167 root->hess_reverse_1_clear_index();
168 return val;
169 }
170
hess_reverse(Node * root,vector<Node * > & vnodes,col_compress_matrix_col & chess)171 double hess_reverse(Node* root,vector<Node*>& vnodes,col_compress_matrix_col& chess)
172 {
173 TT->clear();
174 II->clear();
175 assert(TT->empty());
176 assert(II->empty());
177 assert(TT->index==0);
178 assert(II->index==0);
179
180 // for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
181 // {
182 // assert((*it)->getType()==VNode_Type);
183 // (*it)->index = 0;
184 // } //this work complete in hess-reverse_0_init_index
185
186 assert(root->n_in_arcs == 0);
187 //reset node index and n_in_arcs - for the Tape location
188 root->hess_reverse_0_init_n_in_arcs();
189 assert(root->n_in_arcs == 1);
190 root->hess_reverse_0();
191 double val = NaN_Double;
192 root->hess_reverse_get_x(TT->index,val);
193 // cout<<TT->toString();
194 // cout<<endl;
195 // cout<<II->toString();
196 // cout<<"======================================= hess_reverse_0"<<endl;
197 root->hess_reverse_1_init_x_bar(TT->index);
198 assert(root->n_in_arcs == 1);
199 root->hess_reverse_1(TT->index);
200 assert(root->n_in_arcs == 0);
201 assert(II->index==0);
202 // cout<<TT->toString();
203 // cout<<endl;
204 // cout<<II->toString();
205 // cout<<"======================================= hess_reverse_1"<<endl;
206
207 unsigned int i =0;
208 BOOST_FOREACH(Node* node, vnodes)
209 {
210 assert(node->getType() == VNode_Type);
211 //node->index = 0 means this VNode is not in the tree
212 if(node->index!=0)
213 {
214 double hess = TT->get(node->index -1);
215 if(!isnan(hess))
216 {
217 chess(i) = chess(i) + hess;
218 }
219 }
220 i++;
221 }
222 assert(i==vnodes.size());
223 root->hess_reverse_1_clear_index();
224 TT->clear();
225 II->clear();
226 return val;
227 }
228
nzGrad(Node * root)229 unsigned int nzGrad(Node* root)
230 {
231 unsigned int nzgrad,total = 0;
232 boost::unordered_set<Node*> nodes;
233 root->collect_vnodes(nodes,total);
234 nzgrad = nodes.size();
235 return nzgrad;
236 }
237
238 /*
239 * number of non-zero gradient in constraint tree root that also belong to vSet
240 */
nzGrad(Node * root,boost::unordered_set<Node * > & vSet)241 unsigned int nzGrad(Node* root, boost::unordered_set<Node*>& vSet)
242 {
243 unsigned int nzgrad=0, total=0;
244 boost::unordered_set<Node*> vnodes;
245 root->collect_vnodes(vnodes,total);
246 //cout<<"nzGrad - vnodes size["<<vnodes.size()<<"] -- total node["<<total<<"]"<<endl;
247 for(boost::unordered_set<Node*>::iterator it=vnodes.begin();it!=vnodes.end();it++)
248 {
249 Node* n = *it;
250 if(vSet.find(n) != vSet.end())
251 {
252 nzgrad++;
253 }
254 }
255 return nzgrad;
256 }
257
nonlinearEdges(Node * root,EdgeSet & edges)258 void nonlinearEdges(Node* root, EdgeSet& edges)
259 {
260 root->nonlinearEdges(edges);
261 }
262
nzHess(EdgeSet & eSet,boost::unordered_set<Node * > & set1,boost::unordered_set<Node * > & set2)263 unsigned int nzHess(EdgeSet& eSet,boost::unordered_set<Node*>& set1, boost::unordered_set<Node*>& set2)
264 {
265 list<Edge>::iterator i = eSet.edges.begin();
266 for(;i!=eSet.edges.end();)
267 {
268 Edge e =*i;
269 Node* a = e.a;
270 Node* b = e.b;
271 if((set1.find(a)!=set1.end() && set2.find(b)!=set2.end())
272 ||
273 (set1.find(b)!=set1.end() && set2.find(a)!=set2.end()))
274 {
275 //e is connected between set1 and set2
276 i++;
277 }
278 else
279 {
280 i = eSet.edges.erase(i);
281 }
282 }
283 unsigned int diag=eSet.numSelfEdges();
284 unsigned int nzHess = (eSet.size())*2 - diag;
285 return nzHess;
286 }
287
nzHess(EdgeSet & edges)288 unsigned int nzHess(EdgeSet& edges)
289 {
290 unsigned int diag=edges.numSelfEdges();
291 unsigned int nzHess = (edges.size())*2 - diag;
292 return nzHess;
293 }
294
numTotalNodes(Node * root)295 unsigned int numTotalNodes(Node* root)
296 {
297 unsigned int total = 0;
298 boost::unordered_set<Node*> nodes;
299 root->collect_vnodes(nodes,total);
300 return total;
301 }
302
tree_expr(Node * root)303 string tree_expr(Node* root)
304 {
305 ostringstream oss;
306 oss<<"visiting tree == "<<endl;
307 int level = 0;
308 root->inorder_visit(level,oss);
309 return oss.str();
310 }
311
print_tree(Node * root)312 void print_tree(Node* root)
313 {
314 cout<<"visiting tree == "<<endl;
315 int level = 0;
316 root->inorder_visit(level,cout);
317 }
318
autodiff_setup()319 void autodiff_setup()
320 {
321 Stack::diff = new Stack();
322 Stack::vals = new Stack();
323 Tape<unsigned int>::indexTape = new Tape<unsigned int>();
324 Tape<double>::valueTape = new Tape<double>();
325 }
326
autodiff_cleanup()327 void autodiff_cleanup()
328 {
329 delete Stack::diff;
330 delete Stack::vals;
331 delete Tape<unsigned int>::indexTape;
332 delete Tape<double>::valueTape;
333 }
334
335 } //AutoDiff namespace end
336