1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.renderscript; 18 19 import java.util.ArrayList; 20 21 /** 22 * ScriptGroup creates a group of kernels that are executed 23 * together with one execution call as if they were a single kernel. 24 * The kernels may be connected internally or to an external allocation. 25 * The intermediate results for internal connections are not observable 26 * after the execution of the script. 27 * <p> 28 * External connections are grouped into inputs and outputs. 29 * All outputs are produced by a script kernel and placed into a 30 * user-supplied allocation. Inputs provide the input of a kernel. 31 * Inputs bound to script globals are set directly upon the script. 32 * <p> 33 * A ScriptGroup must contain at least one kernel. A ScriptGroup 34 * must contain only a single directed acyclic graph (DAG) of 35 * script kernels and connections. Attempting to create a 36 * ScriptGroup with multiple DAGs or attempting to create 37 * a cycle within a ScriptGroup will throw an exception. 38 * <p> 39 * Currently, all kernels in a ScriptGroup must be from separate 40 * Script objects. Attempting to use multiple kernels from the same 41 * Script object will result in an {@link android.renderscript.RSInvalidStateException}. 42 * 43 **/ 44 public final class ScriptGroup extends BaseObj { 45 IO mOutputs[]; 46 IO mInputs[]; 47 48 static class IO { 49 Script.KernelID mKID; 50 Allocation mAllocation; 51 IO(Script.KernelID s)52 IO(Script.KernelID s) { 53 mKID = s; 54 } 55 } 56 57 static class ConnectLine { ConnectLine(Type t, Script.KernelID from, Script.KernelID to)58 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) { 59 mFrom = from; 60 mToK = to; 61 mAllocationType = t; 62 } 63 ConnectLine(Type t, Script.KernelID from, Script.FieldID to)64 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) { 65 mFrom = from; 66 mToF = to; 67 mAllocationType = t; 68 } 69 70 Script.FieldID mToF; 71 Script.KernelID mToK; 72 Script.KernelID mFrom; 73 Type mAllocationType; 74 } 75 76 static class Node { 77 Script mScript; 78 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); 79 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); 80 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); 81 int dagNumber; 82 83 Node mNext; 84 Node(Script s)85 Node(Script s) { 86 mScript = s; 87 } 88 } 89 90 ScriptGroup(long id, RenderScript rs)91 ScriptGroup(long id, RenderScript rs) { 92 super(id, rs); 93 } 94 95 /** 96 * Sets an input of the ScriptGroup. This specifies an 97 * Allocation to be used for kernels that require an input 98 * Allocation provided from outside of the ScriptGroup. 99 * 100 * @param s The ID of the kernel where the allocation should be 101 * connected. 102 * @param a The allocation to connect. 103 */ setInput(Script.KernelID s, Allocation a)104 public void setInput(Script.KernelID s, Allocation a) { 105 for (int ct=0; ct < mInputs.length; ct++) { 106 if (mInputs[ct].mKID == s) { 107 mInputs[ct].mAllocation = a; 108 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 109 return; 110 } 111 } 112 throw new RSIllegalArgumentException("Script not found"); 113 } 114 115 /** 116 * Sets an output of the ScriptGroup. This specifies an 117 * Allocation to be used for the kernels that require an output 118 * Allocation visible after the ScriptGroup is executed. 119 * 120 * @param s The ID of the kernel where the allocation should be 121 * connected. 122 * @param a The allocation to connect. 123 */ setOutput(Script.KernelID s, Allocation a)124 public void setOutput(Script.KernelID s, Allocation a) { 125 for (int ct=0; ct < mOutputs.length; ct++) { 126 if (mOutputs[ct].mKID == s) { 127 mOutputs[ct].mAllocation = a; 128 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 129 return; 130 } 131 } 132 throw new RSIllegalArgumentException("Script not found"); 133 } 134 135 /** 136 * Execute the ScriptGroup. This will run all the kernels in 137 * the ScriptGroup. No internal connection results will be visible 138 * after execution of the ScriptGroup. 139 */ execute()140 public void execute() { 141 mRS.nScriptGroupExecute(getID(mRS)); 142 } 143 144 145 /** 146 * Helper class to build a ScriptGroup. A ScriptGroup is 147 * created in two steps. 148 * <p> 149 * First, all kernels to be used by the ScriptGroup should be added. 150 * <p> 151 * Second, add connections between kernels. There are two types 152 * of connections: kernel to kernel and kernel to field. 153 * Kernel to kernel allows a kernel's output to be passed to 154 * another kernel as input. Kernel to field allows the output of 155 * one kernel to be bound as a script global. Kernel to kernel is 156 * higher performance and should be used where possible. 157 * <p> 158 * A ScriptGroup must contain a single directed acyclic graph (DAG); it 159 * cannot contain cycles. Currently, all kernels used in a ScriptGroup 160 * must come from different Script objects. Additionally, all kernels 161 * in a ScriptGroup must have at least one input, output, or internal 162 * connection. 163 * <p> 164 * Once all connections are made, a call to {@link #create} will 165 * return the ScriptGroup object. 166 * 167 */ 168 public static final class Builder { 169 private RenderScript mRS; 170 private ArrayList<Node> mNodes = new ArrayList<Node>(); 171 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>(); 172 private int mKernelCount; 173 174 /** 175 * Create a Builder for generating a ScriptGroup. 176 * 177 * 178 * @param rs The RenderScript context. 179 */ Builder(RenderScript rs)180 public Builder(RenderScript rs) { 181 mRS = rs; 182 } 183 184 // do a DFS from original node, looking for original node 185 // any cycle that could be created must contain original node validateCycle(Node target, Node original)186 private void validateCycle(Node target, Node original) { 187 for (int ct = 0; ct < target.mOutputs.size(); ct++) { 188 final ConnectLine cl = target.mOutputs.get(ct); 189 if (cl.mToK != null) { 190 Node tn = findNode(cl.mToK.mScript); 191 if (tn.equals(original)) { 192 throw new RSInvalidStateException("Loops in group not allowed."); 193 } 194 validateCycle(tn, original); 195 } 196 if (cl.mToF != null) { 197 Node tn = findNode(cl.mToF.mScript); 198 if (tn.equals(original)) { 199 throw new RSInvalidStateException("Loops in group not allowed."); 200 } 201 validateCycle(tn, original); 202 } 203 } 204 } 205 mergeDAGs(int valueUsed, int valueKilled)206 private void mergeDAGs(int valueUsed, int valueKilled) { 207 for (int ct=0; ct < mNodes.size(); ct++) { 208 if (mNodes.get(ct).dagNumber == valueKilled) 209 mNodes.get(ct).dagNumber = valueUsed; 210 } 211 } 212 validateDAGRecurse(Node n, int dagNumber)213 private void validateDAGRecurse(Node n, int dagNumber) { 214 // combine DAGs if this node has been seen already 215 if (n.dagNumber != 0 && n.dagNumber != dagNumber) { 216 mergeDAGs(n.dagNumber, dagNumber); 217 return; 218 } 219 220 n.dagNumber = dagNumber; 221 for (int ct=0; ct < n.mOutputs.size(); ct++) { 222 final ConnectLine cl = n.mOutputs.get(ct); 223 if (cl.mToK != null) { 224 Node tn = findNode(cl.mToK.mScript); 225 validateDAGRecurse(tn, dagNumber); 226 } 227 if (cl.mToF != null) { 228 Node tn = findNode(cl.mToF.mScript); 229 validateDAGRecurse(tn, dagNumber); 230 } 231 } 232 } 233 validateDAG()234 private void validateDAG() { 235 for (int ct=0; ct < mNodes.size(); ct++) { 236 Node n = mNodes.get(ct); 237 if (n.mInputs.size() == 0) { 238 if (n.mOutputs.size() == 0 && mNodes.size() > 1) { 239 throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); 240 } 241 validateDAGRecurse(n, ct+1); 242 } 243 } 244 int dagNumber = mNodes.get(0).dagNumber; 245 for (int ct=0; ct < mNodes.size(); ct++) { 246 if (mNodes.get(ct).dagNumber != dagNumber) { 247 throw new RSInvalidStateException("Multiple DAGs in group not allowed."); 248 } 249 } 250 } 251 findNode(Script s)252 private Node findNode(Script s) { 253 for (int ct=0; ct < mNodes.size(); ct++) { 254 if (s == mNodes.get(ct).mScript) { 255 return mNodes.get(ct); 256 } 257 } 258 return null; 259 } 260 findNode(Script.KernelID k)261 private Node findNode(Script.KernelID k) { 262 for (int ct=0; ct < mNodes.size(); ct++) { 263 Node n = mNodes.get(ct); 264 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 265 if (k == n.mKernels.get(ct2)) { 266 return n; 267 } 268 } 269 } 270 return null; 271 } 272 273 /** 274 * Adds a Kernel to the group. 275 * 276 * 277 * @param k The kernel to add. 278 * 279 * @return Builder Returns this. 280 */ addKernel(Script.KernelID k)281 public Builder addKernel(Script.KernelID k) { 282 if (mLines.size() != 0) { 283 throw new RSInvalidStateException( 284 "Kernels may not be added once connections exist."); 285 } 286 287 //android.util.Log.v("RSR", "addKernel 1 k=" + k); 288 if (findNode(k) != null) { 289 return this; 290 } 291 //android.util.Log.v("RSR", "addKernel 2 "); 292 mKernelCount++; 293 Node n = findNode(k.mScript); 294 if (n == null) { 295 //android.util.Log.v("RSR", "addKernel 3 "); 296 n = new Node(k.mScript); 297 mNodes.add(n); 298 } 299 n.mKernels.add(k); 300 return this; 301 } 302 303 /** 304 * Adds a connection to the group. 305 * 306 * 307 * @param t The type of the connection. This is used to 308 * determine the kernel launch sizes on the source side 309 * of this connection. 310 * @param from The source for the connection. 311 * @param to The destination of the connection. 312 * 313 * @return Builder Returns this 314 */ addConnection(Type t, Script.KernelID from, Script.FieldID to)315 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) { 316 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 317 318 Node nf = findNode(from); 319 if (nf == null) { 320 throw new RSInvalidStateException("From script not found."); 321 } 322 323 Node nt = findNode(to.mScript); 324 if (nt == null) { 325 throw new RSInvalidStateException("To script not found."); 326 } 327 328 ConnectLine cl = new ConnectLine(t, from, to); 329 mLines.add(new ConnectLine(t, from, to)); 330 331 nf.mOutputs.add(cl); 332 nt.mInputs.add(cl); 333 334 validateCycle(nf, nf); 335 return this; 336 } 337 338 /** 339 * Adds a connection to the group. 340 * 341 * 342 * @param t The type of the connection. This is used to 343 * determine the kernel launch sizes for both sides of 344 * this connection. 345 * @param from The source for the connection. 346 * @param to The destination of the connection. 347 * 348 * @return Builder Returns this 349 */ addConnection(Type t, Script.KernelID from, Script.KernelID to)350 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) { 351 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 352 353 Node nf = findNode(from); 354 if (nf == null) { 355 throw new RSInvalidStateException("From script not found."); 356 } 357 358 Node nt = findNode(to); 359 if (nt == null) { 360 throw new RSInvalidStateException("To script not found."); 361 } 362 363 ConnectLine cl = new ConnectLine(t, from, to); 364 mLines.add(new ConnectLine(t, from, to)); 365 366 nf.mOutputs.add(cl); 367 nt.mInputs.add(cl); 368 369 validateCycle(nf, nf); 370 return this; 371 } 372 373 374 375 /** 376 * Creates the Script group. 377 * 378 * 379 * @return ScriptGroup The new ScriptGroup 380 */ create()381 public ScriptGroup create() { 382 383 if (mNodes.size() == 0) { 384 throw new RSInvalidStateException("Empty script groups are not allowed"); 385 } 386 387 // reset DAG numbers in case we're building a second group 388 for (int ct=0; ct < mNodes.size(); ct++) { 389 mNodes.get(ct).dagNumber = 0; 390 } 391 validateDAG(); 392 393 ArrayList<IO> inputs = new ArrayList<IO>(); 394 ArrayList<IO> outputs = new ArrayList<IO>(); 395 396 long[] kernels = new long[mKernelCount]; 397 int idx = 0; 398 for (int ct=0; ct < mNodes.size(); ct++) { 399 Node n = mNodes.get(ct); 400 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 401 final Script.KernelID kid = n.mKernels.get(ct2); 402 kernels[idx++] = kid.getID(mRS); 403 404 boolean hasInput = false; 405 boolean hasOutput = false; 406 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) { 407 if (n.mInputs.get(ct3).mToK == kid) { 408 hasInput = true; 409 } 410 } 411 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) { 412 if (n.mOutputs.get(ct3).mFrom == kid) { 413 hasOutput = true; 414 } 415 } 416 if (!hasInput) { 417 inputs.add(new IO(kid)); 418 } 419 if (!hasOutput) { 420 outputs.add(new IO(kid)); 421 } 422 423 } 424 } 425 if (idx != mKernelCount) { 426 throw new RSRuntimeException("Count mismatch, should not happen."); 427 } 428 429 long[] src = new long[mLines.size()]; 430 long[] dstk = new long[mLines.size()]; 431 long[] dstf = new long[mLines.size()]; 432 long[] types = new long[mLines.size()]; 433 434 for (int ct=0; ct < mLines.size(); ct++) { 435 ConnectLine cl = mLines.get(ct); 436 src[ct] = cl.mFrom.getID(mRS); 437 if (cl.mToK != null) { 438 dstk[ct] = cl.mToK.getID(mRS); 439 } 440 if (cl.mToF != null) { 441 dstf[ct] = cl.mToF.getID(mRS); 442 } 443 types[ct] = cl.mAllocationType.getID(mRS); 444 } 445 446 long id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types); 447 if (id == 0) { 448 throw new RSRuntimeException("Object creation error, should not happen."); 449 } 450 451 ScriptGroup sg = new ScriptGroup(id, mRS); 452 sg.mOutputs = new IO[outputs.size()]; 453 for (int ct=0; ct < outputs.size(); ct++) { 454 sg.mOutputs[ct] = outputs.get(ct); 455 } 456 457 sg.mInputs = new IO[inputs.size()]; 458 for (int ct=0; ct < inputs.size(); ct++) { 459 sg.mInputs[ct] = inputs.get(ct); 460 } 461 462 return sg; 463 } 464 465 } 466 467 468 } 469 470 471