• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2015 Square, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.squareup.kotlinpoet
17 
18 import com.squareup.kotlinpoet.AnnotationSpec.UseSiteTarget.FILE
19 import java.io.ByteArrayInputStream
20 import java.io.File
21 import java.io.IOException
22 import java.io.InputStream
23 import java.io.OutputStreamWriter
24 import java.net.URI
25 import java.nio.charset.StandardCharsets.UTF_8
26 import java.nio.file.Files
27 import java.nio.file.Path
28 import javax.annotation.processing.Filer
29 import javax.tools.JavaFileObject
30 import javax.tools.JavaFileObject.Kind
31 import javax.tools.SimpleJavaFileObject
32 import javax.tools.StandardLocation
33 import kotlin.reflect.KClass
34 
/**
 * A Kotlin file containing top level objects like classes, objects, functions, properties, and type
 * aliases.
 *
 * Items are output in the following order:
 * - Comment
 * - Annotations
 * - Package
 * - Imports
 * - Members
 *
 * For script files ([isScript]) the accumulated [body] is emitted in place of [members].
 */
public class FileSpec private constructor(
  builder: Builder,
  private val tagMap: TagMap = builder.buildTagMap(),
) : Taggable by tagMap {
  /** File-level annotations, snapshotted from the builder. */
  public val annotations: List<AnnotationSpec> = builder.annotations.toImmutableList()

  /** Comment emitted before anything else in the file. */
  public val comment: CodeBlock = builder.comment.build()

  /** Package this file belongs to; an empty string means no `package` line is emitted. */
  public val packageName: String = builder.packageName

  /** File name without extension. */
  public val name: String = builder.name

  /** Top level declarations emitted for non-script files. */
  public val members: List<Any> = builder.members.toList()

  /** Package names whose imports are filtered out of the emitted import list. */
  public val defaultImports: Set<String> = builder.defaultImports.toSet()

  /** Script body; only emitted when [isScript] is `true`. */
  public val body: CodeBlock = builder.body.build()

  /** Whether this file is a script (`.kts`) rather than a regular source file (`.kt`). */
  public val isScript: Boolean = builder.isScript
  private val memberImports = builder.memberImports.associateBy(Import::qualifiedName)
  private val indent = builder.indent
  private val extension = if (isScript) "kts" else "kt"

  /**
   * Writes this file's source to [out].
   *
   * An import-collecting [emit] pass is handed to [CodeWriter.withCollectedImports]; the writer it
   * returns is then used for the real emission and closed afterwards.
   */
  @Throws(IOException::class)
  public fun writeTo(out: Appendable) {
    val codeWriter = CodeWriter.withCollectedImports(
      out = out,
      indent = indent,
      memberImports = memberImports,
      emitStep = { importsCollector -> emit(importsCollector, collectingImports = true) },
    )
    emit(codeWriter, collectingImports = false)
    codeWriter.close()
  }

  /**
   * Writes this to `directory` as UTF-8 using the standard directory structure.
   *
   * Each segment of [packageName] becomes a nested directory (created if missing) and the file is
   * named `name.kt`, or `name.kts` for scripts.
   *
   * @throws IllegalArgumentException if [directory] exists but is not a directory.
   */
  @Throws(IOException::class)
  public fun writeTo(directory: Path) {
    require(Files.notExists(directory) || Files.isDirectory(directory)) {
      "path $directory exists but is not a directory."
    }
    var outputDirectory = directory
    if (packageName.isNotEmpty()) {
      for (packageComponent in packageName.split('.').dropLastWhile { it.isEmpty() }) {
        outputDirectory = outputDirectory.resolve(packageComponent)
      }
    }

    Files.createDirectories(outputDirectory)

    val outputPath = outputDirectory.resolve("$name.$extension")
    OutputStreamWriter(Files.newOutputStream(outputPath), UTF_8).use { writer -> writeTo(writer) }
  }

  /** Writes this to `directory` as UTF-8 using the standard directory structure.  */
  @Throws(IOException::class)
  public fun writeTo(directory: File): Unit = writeTo(directory.toPath())

  /**
   * Writes this to `filer`.
   *
   * The originating elements of every member that is an [OriginatingElementsHolder] are passed
   * along to the filer. If writing fails, the partially written resource is deleted on a
   * best-effort basis and the original exception is rethrown.
   */
  @Throws(IOException::class)
  public fun writeTo(filer: Filer) {
    val originatingElements = members.asSequence()
      .filterIsInstance<OriginatingElementsHolder>()
      .flatMap { it.originatingElements.asSequence() }
      .toSet()
    val filerSourceFile = filer.createResource(
      StandardLocation.SOURCE_OUTPUT,
      packageName,
      "$name.$extension",
      *originatingElements.toTypedArray(),
    )
    try {
      filerSourceFile.openWriter().use { writer -> writeTo(writer) }
    } catch (e: Exception) {
      try {
        filerSourceFile.delete()
      } catch (deleteFailure: Exception) {
        // Cleanup is best-effort: the write failure stays the primary exception, but the cleanup
        // failure is recorded on it instead of being dropped. Self-suppression is not permitted.
        if (deleteFailure !== e) e.addSuppressed(deleteFailure)
      }
      throw e
    }
  }

  /**
   * Emits the file to [codeWriter] in the order comment, annotations, package, imports, then
   * either the script [body] or the [members].
   *
   * @param collectingImports `true` when this pass only gathers imports; default-import filtering
   * is skipped in that case.
   */
  private fun emit(codeWriter: CodeWriter, collectingImports: Boolean) {
    if (comment.isNotEmpty()) {
      codeWriter.emitComment(comment)
    }

    if (annotations.isNotEmpty()) {
      codeWriter.emitAnnotations(annotations, inline = false)
      codeWriter.emit("\n")
    }

    codeWriter.pushPackage(packageName)

    val escapedPackageName = packageName.escapeSegmentsIfNecessary()

    if (escapedPackageName.isNotEmpty()) {
      codeWriter.emitCode("package·%L\n", escapedPackageName)
      codeWriter.emit("\n")
    }

    // If we don't have default imports or are collecting them, we don't need to filter
    var isDefaultImport: (String) -> Boolean = { false }
    if (!collectingImports && defaultImports.isNotEmpty()) {
      // A set keeps the per-import membership test constant-time.
      val escapedDefaultPackages = defaultImports.map(String::escapeSegmentsIfNecessary).toSet()
      isDefaultImport = { importName ->
        importName.substringBeforeLast(".") in escapedDefaultPackages
      }
    }
    // Aliased imports should always appear at the bottom of the imports list.
    val (aliasedImports, nonAliasedImports) = codeWriter.imports.values
      .partition { it.alias != null }
    val imports = nonAliasedImports.asSequence().map { it.toString() }
      .filterNot(isDefaultImport)
      .toSortedSet()
      .plus(aliasedImports.map { it.toString() }.toSortedSet())

    if (imports.isNotEmpty()) {
      for (import in imports) {
        codeWriter.emitCode("import·%L", import)
        codeWriter.emit("\n")
      }
      codeWriter.emit("\n")
    }

    if (isScript) {
      codeWriter.emitCode(body)
    } else {
      members.forEachIndexed { index, member ->
        // Members are separated by a single blank line.
        if (index > 0) codeWriter.emit("\n")
        when (member) {
          is TypeSpec -> member.emit(codeWriter, null)
          is FunSpec -> member.emit(codeWriter, null, setOf(KModifier.PUBLIC), true)
          is PropertySpec -> member.emit(codeWriter, setOf(KModifier.PUBLIC))
          is TypeAliasSpec -> member.emit(codeWriter)
          // Builder.members is a public MutableList<Any>, so name the offending type.
          else -> throw AssertionError("Unexpected member type: ${member.javaClass}")
        }
      }
    }

    codeWriter.popPackage()
  }

  /** Two instances are equal when they have the same runtime class and render the same source. */
  override fun equals(other: Any?): Boolean {
    if (this === other) return true
    if (other == null) return false
    if (javaClass != other.javaClass) return false
    return toString() == other.toString()
  }

  /** Hash of the rendered source, consistent with [equals]. */
  override fun hashCode(): Int = toString().hashCode()

  /** Renders this file's full source text via [writeTo]. */
  override fun toString(): String = buildString { writeTo(this) }

  /**
   * Wraps this file in a [JavaFileObject] of kind [Kind.SOURCE].
   *
   * The URI is the package path (dots replaced by slashes) followed by the file name and
   * extension. Content is rendered from this spec each time it is requested; the last-modified
   * timestamp is fixed when the object is created.
   */
  public fun toJavaFileObject(): JavaFileObject {
    val uri = URI.create(
      if (packageName.isEmpty()) {
        name
      } else {
        packageName.replace('.', '/') + '/' + name
      } + ".$extension",
    )
    return object : SimpleJavaFileObject(uri, Kind.SOURCE) {
      private val lastModified = System.currentTimeMillis()
      override fun getCharContent(ignoreEncodingErrors: Boolean): String {
        return this@FileSpec.toString()
      }

      override fun openInputStream(): InputStream {
        return ByteArrayInputStream(getCharContent(true).toByteArray(UTF_8))
      }

      override fun getLastModified() = lastModified
    }
  }

  /**
   * Returns a new [Builder] pre-populated with this file's annotations, comment, members, indent,
   * imports, default imports, tags and body.
   *
   * @param packageName package for the new builder; defaults to this file's package.
   * @param name file name for the new builder; defaults to this file's name.
   */
  @JvmOverloads
  public fun toBuilder(packageName: String = this.packageName, name: String = this.name): Builder {
    val builder = Builder(packageName, name, isScript)
    builder.annotations.addAll(annotations)
    builder.comment.add(comment)
    builder.members.addAll(this.members)
    builder.indent = indent
    builder.memberImports.addAll(memberImports.values)
    builder.defaultImports.addAll(defaultImports)
    builder.tags += tagMap.tags
    builder.body.add(body)
    return builder
  }

  /**
   * Mutable builder for [FileSpec].
   *
   * In script files ([isScript]) added types, functions, properties and type aliases are appended
   * to the script body in call order; otherwise they are collected in [members].
   */
  public class Builder internal constructor(
    public val packageName: String,
    public val name: String,
    public val isScript: Boolean,
  ) : Taggable.Builder<Builder> {
    internal val comment = CodeBlock.builder()
    internal val memberImports = sortedSetOf<Import>()
    internal var indent = DEFAULT_INDENT
    override val tags: MutableMap<KClass<*>, Any> = mutableMapOf()

    /** Package names whose imports will be filtered out of the emitted file. */
    public val defaultImports: MutableSet<String> = mutableSetOf()

    /** Snapshot of the imports registered so far. */
    public val imports: List<Import> get() = memberImports.toList()

    /** Top level declarations collected for non-script files. */
    public val members: MutableList<Any> = mutableListOf()

    /** File-level annotations; each must target `file` by the time [build] is called. */
    public val annotations: MutableList<AnnotationSpec> = mutableListOf()
    internal val body = CodeBlock.builder()

    /**
     * Add an annotation to the file.
     *
     * The annotation must either have a [`file` use-site target][AnnotationSpec.UseSiteTarget.FILE]
     * or not have a use-site target specified (in which case it will be changed to `file`).
     *
     * @throws IllegalStateException if the annotation has any other use-site target.
     */
    public fun addAnnotation(annotationSpec: AnnotationSpec): Builder = apply {
      val spec = when (annotationSpec.useSiteTarget) {
        FILE -> annotationSpec
        null -> annotationSpec.toBuilder().useSiteTarget(FILE).build()
        else -> error(
          "Use-site target ${annotationSpec.useSiteTarget} not supported for file annotations.",
        )
      }
      annotations += spec
    }

    /** Adds a file annotation of type [annotation] with no members. */
    public fun addAnnotation(annotation: ClassName): Builder =
      addAnnotation(AnnotationSpec.builder(annotation).build())

    /** Adds a file annotation of type [annotation] with no members. */
    public fun addAnnotation(annotation: Class<*>): Builder =
      addAnnotation(annotation.asClassName())

    /** Adds a file annotation of type [annotation] with no members. */
    public fun addAnnotation(annotation: KClass<*>): Builder =
      addAnnotation(annotation.asClassName())

    /** Adds a file-site comment. This is prefixed to the start of the file and different from [addBodyComment]. */
    public fun addFileComment(format: String, vararg args: Any): Builder = apply {
      // Spaces in the format are replaced with '·' before it is added to the comment block.
      comment.add(format.replace(' ', '·'), *args)
    }

    @Deprecated(
      "Use addFileComment() instead.",
      ReplaceWith("addFileComment(format, args)"),
      DeprecationLevel.ERROR,
    )
    public fun addComment(format: String, vararg args: Any): Builder = addFileComment(format, *args)

    /** Removes everything previously added to the file comment. */
    public fun clearComment(): Builder = apply {
      comment.clear()
    }

    /** Adds [typeSpec] to the script body (script files) or to [members]. */
    public fun addType(typeSpec: TypeSpec): Builder = apply {
      if (isScript) {
        body.add("%L", typeSpec)
      } else {
        members += typeSpec
      }
    }

    /**
     * Adds [funSpec] to the script body (script files) or to [members].
     *
     * @throws IllegalArgumentException if [funSpec] is a constructor or an accessor.
     */
    public fun addFunction(funSpec: FunSpec): Builder = apply {
      require(!funSpec.isConstructor && !funSpec.isAccessor) {
        "cannot add ${funSpec.name} to file $name"
      }
      if (isScript) {
        body.add("%L", funSpec)
      } else {
        members += funSpec
      }
    }

    /** Adds [propertySpec] to the script body (script files) or to [members]. */
    public fun addProperty(propertySpec: PropertySpec): Builder = apply {
      if (isScript) {
        body.add("%L", propertySpec)
      } else {
        members += propertySpec
      }
    }

    /** Adds [typeAliasSpec] to the script body (script files) or to [members]. */
    public fun addTypeAlias(typeAliasSpec: TypeAliasSpec): Builder = apply {
      if (isScript) {
        body.add("%L", typeAliasSpec)
      } else {
        members += typeAliasSpec
      }
    }

    /** Imports the enum [constant] by name from its declaring class. */
    public fun addImport(constant: Enum<*>): Builder = addImport(
      (constant as java.lang.Enum<*>).declaringClass.asClassName(),
      constant.name,
    )

    /**
     * Imports the given member [names] of [class].
     *
     * @throws IllegalArgumentException if [names] is empty or contains `*`.
     */
    public fun addImport(`class`: Class<*>, vararg names: String): Builder = apply {
      require(names.isNotEmpty()) { "names array is empty" }
      addImport(`class`.asClassName(), names.toList())
    }

    /**
     * Imports the given member [names] of [class].
     *
     * @throws IllegalArgumentException if [names] is empty or contains `*`.
     */
    public fun addImport(`class`: KClass<*>, vararg names: String): Builder = apply {
      require(names.isNotEmpty()) { "names array is empty" }
      addImport(`class`.asClassName(), names.toList())
    }

    /**
     * Imports the given member [names] of [className].
     *
     * @throws IllegalArgumentException if [names] is empty or contains `*`.
     */
    public fun addImport(className: ClassName, vararg names: String): Builder = apply {
      require(names.isNotEmpty()) { "names array is empty" }
      addImport(className, names.toList())
    }

    /** Imports the given member [names] of [class]; `*` is rejected. */
    public fun addImport(`class`: Class<*>, names: Iterable<String>): Builder =
      addImport(`class`.asClassName(), names)

    /** Imports the given member [names] of [class]; `*` is rejected. */
    public fun addImport(`class`: KClass<*>, names: Iterable<String>): Builder =
      addImport(`class`.asClassName(), names)

    /**
     * Registers one import per entry of [names], qualified by the canonical name of [className].
     *
     * @throws IllegalArgumentException if [names] contains `*`.
     */
    public fun addImport(className: ClassName, names: Iterable<String>): Builder = apply {
      require("*" !in names) { "Wildcard imports are not allowed" }
      for (name in names) {
        memberImports += Import(className.canonicalName + "." + name)
      }
    }

    /**
     * Imports the given [names] from [packageName].
     *
     * @throws IllegalArgumentException if [names] is empty or contains `*`.
     */
    public fun addImport(packageName: String, vararg names: String): Builder = apply {
      require(names.isNotEmpty()) { "names array is empty" }
      addImport(packageName, names.toList())
    }

    /**
     * Registers one import per entry of [names]. Names are qualified with [packageName] unless it
     * is empty, in which case each name is imported as-is.
     *
     * @throws IllegalArgumentException if [names] contains `*`.
     */
    public fun addImport(packageName: String, names: Iterable<String>): Builder = apply {
      require("*" !in names) { "Wildcard imports are not allowed" }
      for (name in names) {
        memberImports += if (packageName.isNotEmpty()) {
          Import("$packageName.$name")
        } else {
          Import(name)
        }
      }
    }

    /** Registers a pre-built [import]. */
    public fun addImport(import: Import): Builder = apply {
      memberImports += import
    }

    /** Removes every import registered so far. */
    public fun clearImports(): Builder = apply {
      memberImports.clear()
    }

    /** Imports [class] under the alias [as]. */
    public fun addAliasedImport(`class`: Class<*>, `as`: String): Builder =
      addAliasedImport(`class`.asClassName(), `as`)

    /** Imports [class] under the alias [as]. */
    public fun addAliasedImport(`class`: KClass<*>, `as`: String): Builder =
      addAliasedImport(`class`.asClassName(), `as`)

    /** Imports [className] under the alias [as]. */
    public fun addAliasedImport(className: ClassName, `as`: String): Builder = apply {
      memberImports += Import(className.canonicalName, `as`)
    }

    /** Imports the member [memberName] of [className] under the alias [as]. */
    public fun addAliasedImport(
      className: ClassName,
      memberName: String,
      `as`: String,
    ): Builder = apply {
      memberImports += Import("${className.canonicalName}.$memberName", `as`)
    }

    /** Imports [memberName] under the alias [as]. */
    public fun addAliasedImport(memberName: MemberName, `as`: String): Builder = apply {
      memberImports += Import(memberName.canonicalName, `as`)
    }

    /**
     * Adds a default import for the given [packageName].
     *
     * The format of this should be the qualified name of the package, e.g. `kotlin`, `java.lang`,
     * `org.gradle.api`, etc.
     */
    public fun addDefaultPackageImport(packageName: String): Builder = apply {
      defaultImports += packageName
    }

    /**
     * Adds Kotlin's standard default package imports as described
     * [here](https://kotlinlang.org/docs/packages.html#default-imports).
     *
     * @param includeJvm also add the JVM-specific default packages.
     * @param includeJs also add the JS-specific default packages.
     */
    public fun addKotlinDefaultImports(
      includeJvm: Boolean = true,
      includeJs: Boolean = true,
    ): Builder = apply {
      defaultImports += KOTLIN_DEFAULT_IMPORTS
      if (includeJvm) {
        defaultImports += KOTLIN_DEFAULT_JVM_IMPORTS
      }
      if (includeJs) {
        defaultImports += KOTLIN_DEFAULT_JS_IMPORTS
      }
    }

    /** Sets the indentation string used when the file is written. */
    public fun indent(indent: String): Builder = apply {
      this.indent = indent
    }

    /**
     * Appends formatted code to the script body.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun addCode(format: String, vararg args: Any?): Builder = apply {
      check(isScript) {
        "addCode() is only allowed in script files"
      }
      body.add(format, *args)
    }

    /**
     * Appends code with named arguments to the script body.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun addNamedCode(format: String, args: Map<String, *>): Builder = apply {
      check(isScript) {
        "addNamedCode() is only allowed in script files"
      }
      body.addNamed(format, args)
    }

    /**
     * Appends [codeBlock] to the script body.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun addCode(codeBlock: CodeBlock): Builder = apply {
      check(isScript) {
        "addCode() is only allowed in script files"
      }
      body.add(codeBlock)
    }

    /** Adds a comment to the body of this script file in the order that it was added. */
    public fun addBodyComment(format: String, vararg args: Any): Builder = apply {
      check(isScript) {
        "addBodyComment() is only allowed in script files"
      }
      body.add("//·${format.replace(' ', '·')}\n", *args)
    }

    /**
     * @param controlFlow the control flow construct and its code, such as "if (foo == 5)".
     * Shouldn't contain braces or newline characters.
     */
    public fun beginControlFlow(controlFlow: String, vararg args: Any): Builder = apply {
      check(isScript) {
        "beginControlFlow() is only allowed in script files"
      }
      body.beginControlFlow(controlFlow, *args)
    }

    /**
     * @param controlFlow the control flow construct and its code, such as "else if (foo == 10)".
     * Shouldn't contain braces or newline characters.
     */
    public fun nextControlFlow(controlFlow: String, vararg args: Any): Builder = apply {
      check(isScript) {
        "nextControlFlow() is only allowed in script files"
      }
      body.nextControlFlow(controlFlow, *args)
    }

    /**
     * Closes the control flow opened in the script body.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun endControlFlow(): Builder = apply {
      check(isScript) {
        "endControlFlow() is only allowed in script files"
      }
      body.endControlFlow()
    }

    /**
     * Appends a statement to the script body.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun addStatement(format: String, vararg args: Any): Builder = apply {
      check(isScript) {
        "addStatement() is only allowed in script files"
      }
      body.addStatement(format, *args)
    }

    /**
     * Discards the script body accumulated so far.
     *
     * @throws IllegalStateException if this is not a script file.
     */
    public fun clearBody(): Builder = apply {
      check(isScript) {
        "clearBody() is only allowed in script files"
      }
      body.clear()
    }

    /**
     * Builds the [FileSpec].
     *
     * [annotations] is publicly mutable, so use-site targets are re-validated here.
     *
     * @throws IllegalStateException if any annotation's use-site target is not `file`.
     */
    public fun build(): FileSpec {
      for (annotationSpec in annotations) {
        if (annotationSpec.useSiteTarget != FILE) {
          error(
            "Use-site target ${annotationSpec.useSiteTarget} not supported for file annotations.",
          )
        }
      }
      return FileSpec(this)
    }
  }

  public companion object {
    /**
     * Creates a file in [packageName] that contains only [typeSpec] and is named after it.
     *
     * @throws IllegalArgumentException if [typeSpec] has no name.
     */
    @JvmStatic public fun get(packageName: String, typeSpec: TypeSpec): FileSpec {
      val fileName = typeSpec.name
        ?: throw IllegalArgumentException("file name required but type has no name")
      return builder(packageName, fileName).addType(typeSpec).build()
    }

    /**
     * Creates a builder whose package and file name come from [className].
     *
     * @throws IllegalArgumentException if [className] denotes a nested type.
     */
    @JvmStatic public fun builder(className: ClassName): Builder {
      require(className.simpleNames.size == 1) {
        "nested types can't be used to name a file: ${className.simpleNames.joinToString(".")}"
      }
      return builder(className.packageName, className.simpleName)
    }

    /** Creates a builder for a regular (non-script) source file. */
    @JvmStatic public fun builder(packageName: String, fileName: String): Builder =
      Builder(packageName, fileName, isScript = false)

    /** Creates a builder for a script file; [packageName] defaults to the empty package. */
    @JvmStatic public fun scriptBuilder(fileName: String, packageName: String = ""): Builder =
      Builder(packageName, fileName, isScript = true)
  }
}
537 
/** Indentation a [FileSpec.Builder] starts with: two spaces. */
internal const val DEFAULT_INDENT: String = "  "

/** Packages added by `FileSpec.Builder.addKotlinDefaultImports` regardless of its flags. */
private val KOTLIN_DEFAULT_IMPORTS: Set<String> = setOf(
  "kotlin",
  "kotlin.annotation",
  "kotlin.collections",
  "kotlin.comparisons",
  "kotlin.io",
  "kotlin.ranges",
  "kotlin.sequences",
  "kotlin.text",
)

/** Extra packages added when `includeJvm` is set. */
private val KOTLIN_DEFAULT_JVM_IMPORTS: Set<String> = setOf("java.lang")

/** Extra packages added when `includeJs` is set. */
private val KOTLIN_DEFAULT_JS_IMPORTS: Set<String> = setOf("kotlin.js")
552