• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.renderscript;
18 
19 import java.util.ArrayList;
20 
21 /**
22  * ScriptGroup creates a group of kernels that are executed
23  * together with one execution call as if they were a single kernel.
24  * The kernels may be connected internally or to an external allocation.
25  * The intermediate results for internal connections are not observable
26  * after the execution of the script.
27  * <p>
28  * External connections are grouped into inputs and outputs.
29  * All outputs are produced by a script kernel and placed into a
30  * user-supplied allocation. Inputs provide the input of a kernel.
31  * Inputs bound to script globals are set directly upon the script.
32  * <p>
33  * A ScriptGroup must contain at least one kernel. A ScriptGroup
34  * must contain only a single directed acyclic graph (DAG) of
35  * script kernels and connections. Attempting to create a
36  * ScriptGroup with multiple DAGs or attempting to create
37  * a cycle within a ScriptGroup will throw an exception.
38  * <p>
39  * Currently, all kernels in a ScriptGroup must be from separate
40  * Script objects. Attempting to use multiple kernels from the same
41  * Script object will result in an {@link android.renderscript.RSInvalidStateException}.
42  *
43  **/
44 public final class ScriptGroup extends BaseObj {
45     IO mOutputs[];
46     IO mInputs[];
47 
48     static class IO {
49         Script.KernelID mKID;
50         Allocation mAllocation;
51 
IO(Script.KernelID s)52         IO(Script.KernelID s) {
53             mKID = s;
54         }
55     }
56 
57     static class ConnectLine {
ConnectLine(Type t, Script.KernelID from, Script.KernelID to)58         ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
59             mFrom = from;
60             mToK = to;
61             mAllocationType = t;
62         }
63 
ConnectLine(Type t, Script.KernelID from, Script.FieldID to)64         ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
65             mFrom = from;
66             mToF = to;
67             mAllocationType = t;
68         }
69 
70         Script.FieldID mToF;
71         Script.KernelID mToK;
72         Script.KernelID mFrom;
73         Type mAllocationType;
74     }
75 
76     static class Node {
77         Script mScript;
78         ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
79         ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
80         ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
81         int dagNumber;
82 
83         Node mNext;
84 
Node(Script s)85         Node(Script s) {
86             mScript = s;
87         }
88     }
89 
90 
ScriptGroup(long id, RenderScript rs)91     ScriptGroup(long id, RenderScript rs) {
92         super(id, rs);
93     }
94 
95     /**
96      * Sets an input of the ScriptGroup. This specifies an
97      * Allocation to be used for kernels that require an input
98      * Allocation provided from outside of the ScriptGroup.
99      *
100      * @param s The ID of the kernel where the allocation should be
101      *          connected.
102      * @param a The allocation to connect.
103      */
setInput(Script.KernelID s, Allocation a)104     public void setInput(Script.KernelID s, Allocation a) {
105         for (int ct=0; ct < mInputs.length; ct++) {
106             if (mInputs[ct].mKID == s) {
107                 mInputs[ct].mAllocation = a;
108                 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
109                 return;
110             }
111         }
112         throw new RSIllegalArgumentException("Script not found");
113     }
114 
115     /**
116      * Sets an output of the ScriptGroup. This specifies an
117      * Allocation to be used for the kernels that require an output
118      * Allocation visible after the ScriptGroup is executed.
119      *
120      * @param s The ID of the kernel where the allocation should be
121      *          connected.
122      * @param a The allocation to connect.
123      */
setOutput(Script.KernelID s, Allocation a)124     public void setOutput(Script.KernelID s, Allocation a) {
125         for (int ct=0; ct < mOutputs.length; ct++) {
126             if (mOutputs[ct].mKID == s) {
127                 mOutputs[ct].mAllocation = a;
128                 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
129                 return;
130             }
131         }
132         throw new RSIllegalArgumentException("Script not found");
133     }
134 
135     /**
136      * Execute the ScriptGroup.  This will run all the kernels in
137      * the ScriptGroup.  No internal connection results will be visible
138      * after execution of the ScriptGroup.
139      */
execute()140     public void execute() {
141         mRS.nScriptGroupExecute(getID(mRS));
142     }
143 
144 
145     /**
146      * Helper class to build a ScriptGroup. A ScriptGroup is
147      * created in two steps.
148      * <p>
149      * First, all kernels to be used by the ScriptGroup should be added.
150      * <p>
151      * Second, add connections between kernels. There are two types
152      * of connections: kernel to kernel and kernel to field.
153      * Kernel to kernel allows a kernel's output to be passed to
154      * another kernel as input. Kernel to field allows the output of
155      * one kernel to be bound as a script global. Kernel to kernel is
156      * higher performance and should be used where possible.
157      * <p>
158      * A ScriptGroup must contain a single directed acyclic graph (DAG); it
159      * cannot contain cycles. Currently, all kernels used in a ScriptGroup
160      * must come from different Script objects. Additionally, all kernels
161      * in a ScriptGroup must have at least one input, output, or internal
162      * connection.
163      * <p>
164      * Once all connections are made, a call to {@link #create} will
165      * return the ScriptGroup object.
166      *
167      */
168     public static final class Builder {
169         private RenderScript mRS;
170         private ArrayList<Node> mNodes = new ArrayList<Node>();
171         private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
172         private int mKernelCount;
173 
174         /**
175          * Create a Builder for generating a ScriptGroup.
176          *
177          *
178          * @param rs The RenderScript context.
179          */
Builder(RenderScript rs)180         public Builder(RenderScript rs) {
181             mRS = rs;
182         }
183 
184         // do a DFS from original node, looking for original node
185         // any cycle that could be created must contain original node
validateCycle(Node target, Node original)186         private void validateCycle(Node target, Node original) {
187             for (int ct = 0; ct < target.mOutputs.size(); ct++) {
188                 final ConnectLine cl = target.mOutputs.get(ct);
189                 if (cl.mToK != null) {
190                     Node tn = findNode(cl.mToK.mScript);
191                     if (tn.equals(original)) {
192                         throw new RSInvalidStateException("Loops in group not allowed.");
193                     }
194                     validateCycle(tn, original);
195                 }
196                 if (cl.mToF != null) {
197                     Node tn = findNode(cl.mToF.mScript);
198                     if (tn.equals(original)) {
199                         throw new RSInvalidStateException("Loops in group not allowed.");
200                     }
201                     validateCycle(tn, original);
202                 }
203             }
204         }
205 
mergeDAGs(int valueUsed, int valueKilled)206         private void mergeDAGs(int valueUsed, int valueKilled) {
207             for (int ct=0; ct < mNodes.size(); ct++) {
208                 if (mNodes.get(ct).dagNumber == valueKilled)
209                     mNodes.get(ct).dagNumber = valueUsed;
210             }
211         }
212 
validateDAGRecurse(Node n, int dagNumber)213         private void validateDAGRecurse(Node n, int dagNumber) {
214             // combine DAGs if this node has been seen already
215             if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
216                 mergeDAGs(n.dagNumber, dagNumber);
217                 return;
218             }
219 
220             n.dagNumber = dagNumber;
221             for (int ct=0; ct < n.mOutputs.size(); ct++) {
222                 final ConnectLine cl = n.mOutputs.get(ct);
223                 if (cl.mToK != null) {
224                     Node tn = findNode(cl.mToK.mScript);
225                     validateDAGRecurse(tn, dagNumber);
226                 }
227                 if (cl.mToF != null) {
228                     Node tn = findNode(cl.mToF.mScript);
229                     validateDAGRecurse(tn, dagNumber);
230                 }
231             }
232         }
233 
validateDAG()234         private void validateDAG() {
235             for (int ct=0; ct < mNodes.size(); ct++) {
236                 Node n = mNodes.get(ct);
237                 if (n.mInputs.size() == 0) {
238                     if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
239                         throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
240                     }
241                     validateDAGRecurse(n, ct+1);
242                 }
243             }
244             int dagNumber = mNodes.get(0).dagNumber;
245             for (int ct=0; ct < mNodes.size(); ct++) {
246                 if (mNodes.get(ct).dagNumber != dagNumber) {
247                     throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
248                 }
249             }
250         }
251 
findNode(Script s)252         private Node findNode(Script s) {
253             for (int ct=0; ct < mNodes.size(); ct++) {
254                 if (s == mNodes.get(ct).mScript) {
255                     return mNodes.get(ct);
256                 }
257             }
258             return null;
259         }
260 
findNode(Script.KernelID k)261         private Node findNode(Script.KernelID k) {
262             for (int ct=0; ct < mNodes.size(); ct++) {
263                 Node n = mNodes.get(ct);
264                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
265                     if (k == n.mKernels.get(ct2)) {
266                         return n;
267                     }
268                 }
269             }
270             return null;
271         }
272 
273         /**
274          * Adds a Kernel to the group.
275          *
276          *
277          * @param k The kernel to add.
278          *
279          * @return Builder Returns this.
280          */
addKernel(Script.KernelID k)281         public Builder addKernel(Script.KernelID k) {
282             if (mLines.size() != 0) {
283                 throw new RSInvalidStateException(
284                     "Kernels may not be added once connections exist.");
285             }
286 
287             //android.util.Log.v("RSR", "addKernel 1 k=" + k);
288             if (findNode(k) != null) {
289                 return this;
290             }
291             //android.util.Log.v("RSR", "addKernel 2 ");
292             mKernelCount++;
293             Node n = findNode(k.mScript);
294             if (n == null) {
295                 //android.util.Log.v("RSR", "addKernel 3 ");
296                 n = new Node(k.mScript);
297                 mNodes.add(n);
298             }
299             n.mKernels.add(k);
300             return this;
301         }
302 
303         /**
304          * Adds a connection to the group.
305          *
306          *
307          * @param t The type of the connection. This is used to
308          *          determine the kernel launch sizes on the source side
309          *          of this connection.
310          * @param from The source for the connection.
311          * @param to The destination of the connection.
312          *
313          * @return Builder Returns this
314          */
addConnection(Type t, Script.KernelID from, Script.FieldID to)315         public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
316             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
317 
318             Node nf = findNode(from);
319             if (nf == null) {
320                 throw new RSInvalidStateException("From script not found.");
321             }
322 
323             Node nt = findNode(to.mScript);
324             if (nt == null) {
325                 throw new RSInvalidStateException("To script not found.");
326             }
327 
328             ConnectLine cl = new ConnectLine(t, from, to);
329             mLines.add(new ConnectLine(t, from, to));
330 
331             nf.mOutputs.add(cl);
332             nt.mInputs.add(cl);
333 
334             validateCycle(nf, nf);
335             return this;
336         }
337 
338         /**
339          * Adds a connection to the group.
340          *
341          *
342          * @param t The type of the connection. This is used to
343          *          determine the kernel launch sizes for both sides of
344          *          this connection.
345          * @param from The source for the connection.
346          * @param to The destination of the connection.
347          *
348          * @return Builder Returns this
349          */
addConnection(Type t, Script.KernelID from, Script.KernelID to)350         public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
351             //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
352 
353             Node nf = findNode(from);
354             if (nf == null) {
355                 throw new RSInvalidStateException("From script not found.");
356             }
357 
358             Node nt = findNode(to);
359             if (nt == null) {
360                 throw new RSInvalidStateException("To script not found.");
361             }
362 
363             ConnectLine cl = new ConnectLine(t, from, to);
364             mLines.add(new ConnectLine(t, from, to));
365 
366             nf.mOutputs.add(cl);
367             nt.mInputs.add(cl);
368 
369             validateCycle(nf, nf);
370             return this;
371         }
372 
373 
374 
375         /**
376          * Creates the Script group.
377          *
378          *
379          * @return ScriptGroup The new ScriptGroup
380          */
create()381         public ScriptGroup create() {
382 
383             if (mNodes.size() == 0) {
384                 throw new RSInvalidStateException("Empty script groups are not allowed");
385             }
386 
387             // reset DAG numbers in case we're building a second group
388             for (int ct=0; ct < mNodes.size(); ct++) {
389                 mNodes.get(ct).dagNumber = 0;
390             }
391             validateDAG();
392 
393             ArrayList<IO> inputs = new ArrayList<IO>();
394             ArrayList<IO> outputs = new ArrayList<IO>();
395 
396             long[] kernels = new long[mKernelCount];
397             int idx = 0;
398             for (int ct=0; ct < mNodes.size(); ct++) {
399                 Node n = mNodes.get(ct);
400                 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
401                     final Script.KernelID kid = n.mKernels.get(ct2);
402                     kernels[idx++] = kid.getID(mRS);
403 
404                     boolean hasInput = false;
405                     boolean hasOutput = false;
406                     for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
407                         if (n.mInputs.get(ct3).mToK == kid) {
408                             hasInput = true;
409                         }
410                     }
411                     for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
412                         if (n.mOutputs.get(ct3).mFrom == kid) {
413                             hasOutput = true;
414                         }
415                     }
416                     if (!hasInput) {
417                         inputs.add(new IO(kid));
418                     }
419                     if (!hasOutput) {
420                         outputs.add(new IO(kid));
421                     }
422 
423                 }
424             }
425             if (idx != mKernelCount) {
426                 throw new RSRuntimeException("Count mismatch, should not happen.");
427             }
428 
429             long[] src = new long[mLines.size()];
430             long[] dstk = new long[mLines.size()];
431             long[] dstf = new long[mLines.size()];
432             long[] types = new long[mLines.size()];
433 
434             for (int ct=0; ct < mLines.size(); ct++) {
435                 ConnectLine cl = mLines.get(ct);
436                 src[ct] = cl.mFrom.getID(mRS);
437                 if (cl.mToK != null) {
438                     dstk[ct] = cl.mToK.getID(mRS);
439                 }
440                 if (cl.mToF != null) {
441                     dstf[ct] = cl.mToF.getID(mRS);
442                 }
443                 types[ct] = cl.mAllocationType.getID(mRS);
444             }
445 
446             long id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
447             if (id == 0) {
448                 throw new RSRuntimeException("Object creation error, should not happen.");
449             }
450 
451             ScriptGroup sg = new ScriptGroup(id, mRS);
452             sg.mOutputs = new IO[outputs.size()];
453             for (int ct=0; ct < outputs.size(); ct++) {
454                 sg.mOutputs[ct] = outputs.get(ct);
455             }
456 
457             sg.mInputs = new IO[inputs.size()];
458             for (int ct=0; ct < inputs.size(); ct++) {
459                 sg.mInputs[ct] = inputs.get(ct);
460             }
461 
462             return sg;
463         }
464 
465     }
466 
467 
468 }
469 
470 
471