1 /*
2 * Copyright 2018 Google Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 package trebuchet.extractors
18
19 import trebuchet.importers.ImportFeedback
20 import trebuchet.io.*
21 import trebuchet.util.indexOf
22 import java.util.zip.DataFormatException
23 import java.util.zip.Inflater
24 import kotlin.sequences.iterator
25
/** Marker text that [findStart] looks for; the payload is taken to begin right after it. */
private const val TRACE = "TRACE:"

/**
 * Computes the offset in [buffer] at which the payload starts.
 *
 * If the [TRACE] marker is found, the offset is the position just past the marker; otherwise
 * it is 0. Any run of `\n` / `\r` bytes at that position is then skipped.
 *
 * NOTE(review): the `100` handed to `indexOf` presumably bounds how far into the buffer the
 * marker is searched for -- confirm against `trebuchet.util.indexOf`.
 *
 * @param buffer the bytes to scan.
 * @return the index of the first byte after the optional marker and line breaks; this equals
 * `buffer.length` when nothing but line breaks follows.
 */
private fun findStart(buffer: GenericByteBuffer): Long {
    val markerIndex = buffer.indexOf(TRACE, 100)
    var cursor = if (markerIndex == -1L) 0L else markerIndex + TRACE.length
    while (cursor < buffer.length) {
        val current = buffer[cursor]
        // Stop at the first byte that is not part of a line break.
        if (current != '\n'.toByte() && current != '\r'.toByte()) break
        cursor++
    }
    return cursor
}
41
/**
 * [BufferProducer] that lazily inflates the compressed payload of a [StreamingReader] and hands
 * out one decompressed [DataSlice] per call to [next].
 *
 * Each output buffer is sized from a running average of the observed compression factor, so
 * that most input chunks are inflated in one or two calls.
 *
 * @param stream the reader holding the compressed data; its `source` is closed by [close].
 * @property feedback receives an error report when the data requires a preset dictionary,
 * which is not supported.
 */
private class DeflateProducer(stream: StreamingReader, val feedback: ImportFeedback)
    : BufferProducer {

    private val source = stream.source
    private val inflater = Inflater()
    private var closed = false

    private val sourceIterator = iterator {
        // NOTE(review): presumably makes sure the head of the stream is loaded before
        // findStart scans it -- confirm against StreamingReader.loadIndex.
        stream.loadIndex(stream.startIndex + 1024)
        val offset = findStart(stream)
        val buffIter = stream.iter(offset)
        var avgCompressFactor = 5.0
        try {
            while (buffIter.hasNext()) {
                val nextBuffer = buffIter.next()
                inflater.setInput(nextBuffer.buffer, nextBuffer.startIndex, nextBuffer.length)
                do {
                    val remaining = inflater.remaining
                    // Never hand the inflater an empty output array: it could not make
                    // progress and this loop would spin forever.
                    val estSize = (remaining * avgCompressFactor * 1.2).toInt()
                            .coerceAtLeast(MIN_OUTPUT_SIZE)
                    val array = ByteArray(estSize)
                    val len = inflater.inflate(array)
                    if (inflater.needsDictionary()) {
                        feedback.reportImportException(IllegalStateException(
                                "inflater needs dictionary, which isn't supported"))
                        return@iterator
                    }
                    val consumed = remaining - inflater.remaining
                    // Only fold in a sample when input was actually consumed; dividing by
                    // zero would turn the running average into NaN/Infinity for good.
                    if (consumed > 0) {
                        val compressFactor = len.toDouble() / consumed
                        avgCompressFactor = (avgCompressFactor * 9 + compressFactor) / 10
                    }
                    yield(array.asSlice(len))
                    // Stop once the consumer closed us, or once the compressed stream has
                    // ended: bytes after the end of the stream are never consumed, so
                    // needsInput() would stay false and the loop would never terminate.
                    if (closed || inflater.finished()) return@iterator
                } while (!inflater.needsInput())
            }
        } finally {
            // Release the inflater only after every input buffer has been processed (or on
            // any early exit), never between buffers. Calling end() again from close() is
            // harmless.
            inflater.end()
        }
    }

    /** Returns the next decompressed slice, or `null` once no more data is available. */
    override fun next(): DataSlice? {
        return if (sourceIterator.hasNext()) sourceIterator.next() else null
    }

    /** Marks the producer as closed, closes the underlying source and frees the inflater. */
    override fun close() {
        closed = true
        source.close()
        inflater.end()
    }

    private companion object {
        /** Smallest output buffer, in bytes, handed to the inflater for a single call. */
        const val MIN_OUTPUT_SIZE = 1024
    }
}
86
/**
 * [Extractor] for zlib-compressed input: the stream is inflated on the fly and the decompressed
 * data is exposed to the caller as a [BufferProducer].
 *
 * @property feedback sink for problems encountered while inflating.
 */
class ZlibExtractor(val feedback: ImportFeedback) : Extractor {

    /** Wraps [stream] in an inflating producer and passes it to [processSubStream]. */
    override fun extract(stream: StreamingReader, processSubStream: (BufferProducer) -> Unit) {
        processSubStream(DeflateProducer(stream, feedback))
    }

    /** Detects zlib-compressed input by trial-inflating the first bytes of the payload. */
    object Factory : ExtractorFactory {
        /** Maximum number of payload bytes fed to the trial inflate. */
        private const val SIZE_TO_CHECK = 200

        /**
         * Returns a [ZlibExtractor] when the start of [buffer]'s payload inflates to at least
         * one byte, and `null` otherwise.
         *
         * @param buffer the data to probe.
         * @param feedback passed on to the created extractor.
         */
        override fun extractorFor(buffer: GenericByteBuffer, feedback: ImportFeedback): Extractor? {
            val start = findStart(buffer)
            // Clamp while the value is still a Long: converting the raw difference to Int
            // first would overflow for buffers with 2 GiB or more left after the start offset.
            val toRead = minOf(buffer.length - start, SIZE_TO_CHECK.toLong()).toInt()
            // deflate must contain at least a 2 byte header + 4 byte checksum
            // So if there are 6 bytes or fewer this either isn't deflate or
            // there's not enough data to try an inflate anyway
            if (toRead <= 6) {
                return null
            }
            val inflate = Inflater()
            try {
                val tmpBuffer = ByteArray(toRead) { buffer[start + it] }
                inflate.setInput(tmpBuffer)
                val result = ByteArray(1024)
                if (inflate.inflate(result) > 0) {
                    return ZlibExtractor(feedback)
                }
            } catch (ex: DataFormatException) {
                // Must not be deflate format
            } finally {
                // Single release point for the native inflater, on every exit path.
                inflate.end()
            }
            return null
        }
    }
}