• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
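//============================================================================
// Name        : autodiff.cpp
// Author      :
// Version     :
// Copyright   : Your copyright notice
// Description : AutoDiff library entry points: expression-tree node creation,
//               function evaluation, reverse-mode gradient and Hessian sweeps,
//               gradient/Hessian non-zero counting, and setup/cleanup of the
//               global evaluation stacks and tapes.
//============================================================================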
8 
#include <cstddef>
#include <iostream>
#include <sstream>
#include <numeric>
#include <boost/foreach.hpp>
#include "autodiff.h"
#include "Stack.h"
#include "Tape.h"
#include "BinaryOPNode.h"
#include "UaryOPNode.h"
18 
19 using namespace std;
20 
21 namespace AutoDiff
22 {
23 
#if FORWARD_ENABLED

// Number of independent variables expected by hess_forward(); its nvar
// argument is asserted to match this value.
unsigned int num_var = 0;

/*
 * Forward-mode Hessian evaluation of the expression tree rooted at root.
 *
 * nvar     - number of independent variables; must equal num_var.
 * hess_mat - output buffer, passed straight through to Node::hess_forward.
 *
 * The length handed to Node::hess_forward is nvar*(nvar+3)/2, which equals
 * nvar + nvar*(nvar+1)/2 -- presumably nvar first-order slots plus the lower
 * triangle of the symmetric Hessian (TODO confirm against Node::hess_forward).
 */
void hess_forward(Node* root, unsigned int nvar, double** hess_mat)
{
	assert(nvar == num_var);
	const unsigned int work_len = (nvar + 3) * nvar / 2;
	root->hess_forward(work_len, hess_mat);
}

#endif
36 
37 
create_param_node(double value)38 PNode* create_param_node(double value){
39 	return new PNode(value);
40 }
create_var_node(double v)41 VNode* create_var_node(double v)
42 {
43 	return new VNode(v);
44 }
create_binary_op_node(OPCODE code,Node * left,Node * right)45 OPNode* create_binary_op_node(OPCODE code, Node* left, Node* right)
46 {
47 	return BinaryOPNode::createBinaryOpNode(code,left,right);
48 }
create_uary_op_node(OPCODE code,Node * left)49 OPNode* create_uary_op_node(OPCODE code, Node* left)
50 {
51 	return UaryOPNode::createUnaryOpNode(code,left);
52 }
eval_function(Node * root)53 double eval_function(Node* root)
54 {
55 	assert(SD->size()==0);
56 	assert(SV->size()==0);
57 	root->eval_function();
58 	assert(SV->size()==1);
59 	double val = SV->pop_back();
60 	return val;
61 }
62 
grad_reverse(Node * root,vector<Node * > & vnodes,vector<double> & grad)63 double grad_reverse(Node* root,vector<Node*>& vnodes, vector<double>& grad)
64 {
65 	grad.clear();
66 	BOOST_FOREACH(Node* node, vnodes)
67 	{
68 		assert(node->getType()==VNode_Type);
69 		static_cast<VNode*>(node)->adj = NaN_Double;
70 	}
71 
72 	assert(SD->size()==0);
73 	root->grad_reverse_0();
74 	assert(SV->size()==1);
75 	root->grad_reverse_1_init_adj();
76 	root->grad_reverse_1();
77 	assert(SD->size()==0);
78 	double val = SV->pop_back();
79 	assert(SV->size()==0);
80 	//int i=0;
81 	BOOST_FOREACH(Node* node, vnodes)
82 	{
83 		assert(node->getType()==VNode_Type);
84 		grad.push_back(static_cast<VNode*>(node)->adj);
85 		static_cast<VNode*>(node)->adj = NaN_Double;
86 	}
87 	assert(grad.size()==vnodes.size());
88 	//all nodes are VNode and adj == NaN_Double -- this reset adj for this expression tree by root
89 	return val;
90 }
91 
grad_reverse(Node * root,vector<Node * > & vnodes,col_compress_matrix_row & rgrad)92 double grad_reverse(Node* root, vector<Node*>& vnodes, col_compress_matrix_row& rgrad)
93 {
94 	BOOST_FOREACH(Node* node, vnodes)
95 	{
96 		assert(node->getType()==VNode_Type);
97 		static_cast<VNode*>(node)->adj = NaN_Double;
98 	}
99 	assert(SD->size()==0);
100 	root->grad_reverse_0();
101 	assert(SV->size()==1);
102 	root->grad_reverse_1_init_adj();
103 	root->grad_reverse_1();
104 	assert(SD->size()==0);
105 	double val = SV->pop_back();
106 	assert(SV->size()==0);
107 	unsigned int i =0;
108 	BOOST_FOREACH(Node* node, vnodes)
109 	{
110 		assert((node)->getType()==VNode_Type);
111 		double diff = static_cast<VNode*>(node)->adj;
112 		if(!isnan(diff)){
113 			rgrad(i) = diff;
114 			static_cast<VNode*>(node)->adj = NaN_Double;
115 		}
116 		i++;
117 	}
118 	//all nodes are VNode and adj == NaN_Double -- this reset adj for this expression tree by root
119 	assert(i==vnodes.size());
120 	return val;
121 }
122 
hess_reverse(Node * root,vector<Node * > & vnodes,vector<double> & dhess)123 double hess_reverse(Node* root,vector<Node*>& vnodes,vector<double>& dhess)
124 {
125 	TT->clear();
126 	II->clear();
127 	assert(TT->empty());
128 	assert(II->empty());
129 	assert(TT->index==0);
130 	assert(II->index==0);
131 	dhess.clear();
132 
133 //	for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
134 //	{
135 //		assert((*it)->getType()==VNode_Type);
136 //		(*it)->index = 0;
137 //	} //this work complete in hess-reverse_0_init_index
138 
139 	assert(root->n_in_arcs == 0);
140 	root->hess_reverse_0_init_n_in_arcs();
141 	assert(root->n_in_arcs == 1);
142 	root->hess_reverse_0();
143 	double val = NaN_Double;
144 	root->hess_reverse_get_x(TT->index,val);
145 //	cout<<TT->toString();
146 //	cout<<endl;
147 //	cout<<II->toString();
148 //	cout<<"======================================= hess_reverse_0"<<endl;
149 	root->hess_reverse_1_init_x_bar(TT->index);
150 	assert(root->n_in_arcs == 1);
151 	root->hess_reverse_1(TT->index);
152 	assert(root->n_in_arcs == 0);
153 	assert(II->index==0);
154 //	cout<<TT->toString();
155 //	cout<<endl;
156 //	cout<<II->toString();
157 //	cout<<"======================================= hess_reverse_1"<<endl;
158 
159 	for(vector<Node*>::iterator it=vnodes.begin();it!=vnodes.end();it++)
160 	{
161 		assert((*it)->getType()==VNode_Type);
162 		dhess.push_back(TT->get((*it)->index-1));
163 	}
164 
165 	TT->clear();
166 	II->clear();
167 	root->hess_reverse_1_clear_index();
168 	return val;
169 }
170 
hess_reverse(Node * root,vector<Node * > & vnodes,col_compress_matrix_col & chess)171 double hess_reverse(Node* root,vector<Node*>& vnodes,col_compress_matrix_col& chess)
172 {
173 	TT->clear();
174 	II->clear();
175 	assert(TT->empty());
176 	assert(II->empty());
177 	assert(TT->index==0);
178 	assert(II->index==0);
179 
180 //	for(vector<Node*>::iterator it=nodes.begin();it!=nodes.end();it++)
181 //	{
182 //		assert((*it)->getType()==VNode_Type);
183 //		(*it)->index = 0;
184 //	} //this work complete in hess-reverse_0_init_index
185 
186 	assert(root->n_in_arcs == 0);
187 	//reset node index and n_in_arcs - for the Tape location
188 	root->hess_reverse_0_init_n_in_arcs();
189 	assert(root->n_in_arcs == 1);
190 	root->hess_reverse_0();
191 	double val = NaN_Double;
192 	root->hess_reverse_get_x(TT->index,val);
193 //	cout<<TT->toString();
194 //	cout<<endl;
195 //	cout<<II->toString();
196 //	cout<<"======================================= hess_reverse_0"<<endl;
197 	root->hess_reverse_1_init_x_bar(TT->index);
198 	assert(root->n_in_arcs == 1);
199 	root->hess_reverse_1(TT->index);
200 	assert(root->n_in_arcs == 0);
201 	assert(II->index==0);
202 //	cout<<TT->toString();
203 //	cout<<endl;
204 //	cout<<II->toString();
205 //	cout<<"======================================= hess_reverse_1"<<endl;
206 
207 	unsigned int i =0;
208 	BOOST_FOREACH(Node* node, vnodes)
209 	{
210 		assert(node->getType() == VNode_Type);
211 		//node->index = 0 means this VNode is not in the tree
212 		if(node->index!=0)
213 		{
214 			double hess = TT->get(node->index -1);
215 			if(!isnan(hess))
216 			{
217 				chess(i) = chess(i) + hess;
218 			}
219 		}
220 		i++;
221 	}
222 	assert(i==vnodes.size());
223 	root->hess_reverse_1_clear_index();
224 	TT->clear();
225 	II->clear();
226 	return val;
227 }
228 
nzGrad(Node * root)229 unsigned int nzGrad(Node* root)
230 {
231 	unsigned int nzgrad,total = 0;
232 	boost::unordered_set<Node*> nodes;
233 	root->collect_vnodes(nodes,total);
234 	nzgrad = nodes.size();
235 	return nzgrad;
236 }
237 
238 /*
239  * number of non-zero gradient in constraint tree root that also belong to vSet
240  */
nzGrad(Node * root,boost::unordered_set<Node * > & vSet)241 unsigned int nzGrad(Node* root, boost::unordered_set<Node*>& vSet)
242 {
243 	unsigned int nzgrad=0, total=0;
244 	boost::unordered_set<Node*> vnodes;
245 	root->collect_vnodes(vnodes,total);
246 	//cout<<"nzGrad - vnodes size["<<vnodes.size()<<"] -- total node["<<total<<"]"<<endl;
247 	for(boost::unordered_set<Node*>::iterator it=vnodes.begin();it!=vnodes.end();it++)
248 	{
249 		Node* n = *it;
250 		if(vSet.find(n) != vSet.end())
251 		{
252 			nzgrad++;
253 		}
254 	}
255 	return nzgrad;
256 }
257 
nonlinearEdges(Node * root,EdgeSet & edges)258 void nonlinearEdges(Node* root, EdgeSet& edges)
259 {
260 	root->nonlinearEdges(edges);
261 }
262 
nzHess(EdgeSet & eSet,boost::unordered_set<Node * > & set1,boost::unordered_set<Node * > & set2)263 unsigned int nzHess(EdgeSet& eSet,boost::unordered_set<Node*>& set1, boost::unordered_set<Node*>& set2)
264 {
265 	list<Edge>::iterator i = eSet.edges.begin();
266 	for(;i!=eSet.edges.end();)
267 	{
268 		Edge e =*i;
269 		Node* a = e.a;
270 		Node* b = e.b;
271 		if((set1.find(a)!=set1.end() && set2.find(b)!=set2.end())
272 			||
273 			(set1.find(b)!=set1.end() && set2.find(a)!=set2.end()))
274 		{
275 			//e is connected between set1 and set2
276 			i++;
277 		}
278 		else
279 		{
280 			i = eSet.edges.erase(i);
281 		}
282 	}
283 	unsigned int diag=eSet.numSelfEdges();
284 	unsigned int nzHess = (eSet.size())*2 - diag;
285 	return nzHess;
286 }
287 
nzHess(EdgeSet & edges)288 unsigned int nzHess(EdgeSet& edges)
289 {
290 	unsigned int diag=edges.numSelfEdges();
291 	unsigned int nzHess = (edges.size())*2 - diag;
292 	return nzHess;
293 }
294 
numTotalNodes(Node * root)295 unsigned int numTotalNodes(Node* root)
296 {
297 	unsigned int total = 0;
298 	boost::unordered_set<Node*> nodes;
299 	root->collect_vnodes(nodes,total);
300 	return total;
301 }
302 
tree_expr(Node * root)303 string tree_expr(Node* root)
304 {
305 	ostringstream oss;
306 	oss<<"visiting tree == "<<endl;
307 	int level = 0;
308 	root->inorder_visit(level,oss);
309 	return oss.str();
310 }
311 
print_tree(Node * root)312 void print_tree(Node* root)
313 {
314 	cout<<"visiting tree == "<<endl;
315 	int level = 0;
316 	root->inorder_visit(level,cout);
317 }
318 
autodiff_setup()319 void autodiff_setup()
320 {
321 	Stack::diff = new Stack();
322 	Stack::vals = new Stack();
323 	Tape<unsigned int>::indexTape = new Tape<unsigned int>();
324 	Tape<double>::valueTape = new Tape<double>();
325 }
326 
autodiff_cleanup()327 void autodiff_cleanup()
328 {
329 	delete Stack::diff;
330 	delete Stack::vals;
331 	delete Tape<unsigned int>::indexTape;
332 	delete Tape<double>::valueTape;
333 }
334 
335 } //AutoDiff namespace end
336