• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import "go/ast"
8
// init hands the importTests table to addTestCases (defined elsewhere in
// package main), presumably registering these cases with the fix test
// harness. NOTE(review): the meaning of the nil second argument is defined
// by addTestCases, which is not visible here — confirm there.
func init() {
	addTestCases(importTests, nil)
}
12
// importTests exercises the import-editing helpers addImport, deleteImport
// and rewriteImport through the fix-function constructors defined at the
// bottom of this file (addImportFn, deleteImportFn, addDelImportFn,
// rewriteImportFn). Each case gives a fix name (Name), the fix function to
// apply (Fn), the source before the fix (In) and — presumably, per the
// testCase type defined elsewhere — the expected source after it (Out).
var importTests = []testCase{
	// addImportFn: adding a path that is already imported leaves the file
	// unchanged (import.0); otherwise a new import declaration is created
	// (import.1), added as a separate declaration after a lone import "C"
	// (import.2), or inserted in sorted position into an existing
	// parenthesized block rather than next to import "C" (import.3).
	{
		Name: "import.0",
		Fn:   addImportFn("os"),
		In: `package main

import (
	"os"
)
`,
		Out: `package main

import (
	"os"
)
`,
	},
	{
		Name: "import.1",
		Fn:   addImportFn("os"),
		In: `package main
`,
		Out: `package main

import "os"
`,
	},
	{
		Name: "import.2",
		Fn:   addImportFn("os"),
		In: `package main

// Comment
import "C"
`,
		Out: `package main

// Comment
import "C"
import "os"
`,
	},
	{
		Name: "import.3",
		Fn:   addImportFn("os"),
		In: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
	},
	// deleteImportFn: removing the only spec of a declaration drops the whole
	// declaration (import.4, import.5); removing one spec from a block keeps
	// the remaining specs (import.6, import.10-import.12). When the removed
	// spec had a trailing line comment, the expected output keeps that
	// comment on its own line (import.7-import.9).
	{
		Name: "import.4",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"os"
)
`,
		Out: `package main
`,
	},
	{
		Name: "import.5",
		Fn:   deleteImportFn("os"),
		In: `package main

// Comment
import "C"
import "os"
`,
		Out: `package main

// Comment
import "C"
`,
	},
	{
		Name: "import.6",
		Fn:   deleteImportFn("os"),
		In: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
	},
	{
		Name: "import.7",
		Fn:   deleteImportFn("io"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	// a
	"os"   // b
	"utf8" // c
)
`,
	},
	{
		Name: "import.8",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	"io" // a
	// b
	"utf8" // c
)
`,
	},
	{
		Name: "import.9",
		Fn:   deleteImportFn("utf8"),
		In: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		Out: `package main

import (
	"io" // a
	"os" // b
	// c
)
`,
	},
	{
		Name: "import.10",
		Fn:   deleteImportFn("io"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"os"
	"utf8"
)
`,
	},
	{
		Name: "import.11",
		Fn:   deleteImportFn("os"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"io"
	"utf8"
)
`,
	},
	{
		Name: "import.12",
		Fn:   deleteImportFn("utf8"),
		In: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		Out: `package main

import (
	"io"
	"os"
)
`,
	},
	// rewriteImportFn: a rewritten path moves to its sorted position within
	// the block and keeps its trailing comment (import.13, import.14); a line
	// comment on the following declaration stays with that declaration
	// (import.15); several old/new pairs can be rewritten in one fix
	// (import.16).
	{
		Name: "import.13",
		Fn:   rewriteImportFn("utf8", "encoding/utf8"),
		In: `package main

import (
	"io"
	"os"
	"utf8" // thanks ken
)
`,
		Out: `package main

import (
	"encoding/utf8" // thanks ken
	"io"
	"os"
)
`,
	},
	{
		Name: "import.14",
		Fn:   rewriteImportFn("asn1", "encoding/asn1"),
		In: `package main

import (
	"asn1"
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"time"
)

var x = 1
`,
		Out: `package main

import (
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"time"
)

var x = 1
`,
	},
	{
		Name: "import.15",
		Fn:   rewriteImportFn("url", "net/url"),
		In: `package main

import (
	"bufio"
	"net"
	"path"
	"url"
)

var x = 1 // comment on x, not on url
`,
		Out: `package main

import (
	"bufio"
	"net"
	"net/url"
	"path"
)

var x = 1 // comment on x, not on url
`,
	},
	{
		Name: "import.16",
		Fn:   rewriteImportFn("http", "net/http", "template", "text/template"),
		In: `package main

import (
	"flag"
	"http"
	"log"
	"template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
		Out: `package main

import (
	"flag"
	"log"
	"net/http"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
	},
	// Adding several paths to a block split into blank-line-separated
	// groups: both new paths land, sorted, in the group that already holds
	// "x/w", and the other groups are untouched (import.17).
	{
		Name: "import.17",
		Fn:   addImportFn("x/y/z", "x/a/c"),
		In: `package main

// Comment
import "C"

import (
	"a"
	"b"

	"x/w"

	"d/f"
)
`,
		Out: `package main

// Comment
import "C"

import (
	"a"
	"b"

	"x/a/c"
	"x/w"
	"x/y/z"

	"d/f"
)
`,
	},
	// addDelImportFn: a single fix adds "e" and removes "o" from the same
	// block (import.18).
	{
		Name: "import.18",
		Fn:   addDelImportFn("e", "o"),
		In: `package main

import (
	"f"
	"o"
	"z"
)
`,
		Out: `package main

import (
	"e"
	"f"
	"z"
)
`,
	},
}
408
409func addImportFn(path ...string) func(*ast.File) bool {
410	return func(f *ast.File) bool {
411		fixed := false
412		for _, p := range path {
413			if !imports(f, p) {
414				addImport(f, p)
415				fixed = true
416			}
417		}
418		return fixed
419	}
420}
421
422func deleteImportFn(path string) func(*ast.File) bool {
423	return func(f *ast.File) bool {
424		if imports(f, path) {
425			deleteImport(f, path)
426			return true
427		}
428		return false
429	}
430}
431
432func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
433	return func(f *ast.File) bool {
434		fixed := false
435		if !imports(f, p1) {
436			addImport(f, p1)
437			fixed = true
438		}
439		if imports(f, p2) {
440			deleteImport(f, p2)
441			fixed = true
442		}
443		return fixed
444	}
445}
446
447func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
448	return func(f *ast.File) bool {
449		fixed := false
450		for i := 0; i < len(oldnew); i += 2 {
451			if imports(f, oldnew[i]) {
452				rewriteImport(f, oldnew[i], oldnew[i+1])
453				fixed = true
454			}
455		}
456		return fixed
457	}
458}
459