• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
6
7package x509
8
9import (
10	"bytes"
11	"fmt"
12	"os"
13	"path/filepath"
14	"reflect"
15	"strings"
16	"testing"
17)
18
// Fixture names shared by the tests in this file.
const (
	// testFile is the certificate file path handed to loadSystemRoots, either
	// through certFileEnv or through certFiles.
	testFile = "test-file.crt"
	// testFileCN is the subject common name expected from the certificate
	// loaded via testFile.
	testFileCN = "test-file"

	// testDir is the certificate directory handed to loadSystemRoots, either
	// through certDirEnv or through certDirectories.
	testDir = "testdata"
	// testDirCN is the subject common name expected from the certificate
	// loaded via testDir.
	testDirCN = "test-dir"

	// testMissing is used as an environment value from which no certificates
	// are expected to be loaded.
	testMissing = "missing"
)
26
27func TestEnvVars(t *testing.T) {
28	testCases := []struct {
29		name    string
30		fileEnv string
31		dirEnv  string
32		files   []string
33		dirs    []string
34		cns     []string
35	}{
36		{
37			// Environment variables override the default locations preventing fall through.
38			name:    "override-defaults",
39			fileEnv: testMissing,
40			dirEnv:  testMissing,
41			files:   []string{testFile},
42			dirs:    []string{testDir},
43			cns:     nil,
44		},
45		{
46			// File environment overrides default file locations.
47			name:    "file",
48			fileEnv: testFile,
49			dirEnv:  "",
50			files:   nil,
51			dirs:    nil,
52			cns:     []string{testFileCN},
53		},
54		{
55			// Directory environment overrides default directory locations.
56			name:    "dir",
57			fileEnv: "",
58			dirEnv:  testDir,
59			files:   nil,
60			dirs:    nil,
61			cns:     []string{testDirCN},
62		},
63		{
64			// File & directory environment overrides both default locations.
65			name:    "file+dir",
66			fileEnv: testFile,
67			dirEnv:  testDir,
68			files:   nil,
69			dirs:    nil,
70			cns:     []string{testFileCN, testDirCN},
71		},
72		{
73			// Environment variable empty / unset uses default locations.
74			name:    "empty-fall-through",
75			fileEnv: "",
76			dirEnv:  "",
77			files:   []string{testFile},
78			dirs:    []string{testDir},
79			cns:     []string{testFileCN, testDirCN},
80		},
81	}
82
83	// Save old settings so we can restore before the test ends.
84	origCertFiles, origCertDirectories := certFiles, certDirectories
85	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
86	defer func() {
87		certFiles = origCertFiles
88		certDirectories = origCertDirectories
89		os.Setenv(certFileEnv, origFile)
90		os.Setenv(certDirEnv, origDir)
91	}()
92
93	for _, tc := range testCases {
94		t.Run(tc.name, func(t *testing.T) {
95			if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
96				t.Fatalf("setenv %q failed: %v", certFileEnv, err)
97			}
98			if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
99				t.Fatalf("setenv %q failed: %v", certDirEnv, err)
100			}
101
102			certFiles, certDirectories = tc.files, tc.dirs
103
104			r, err := loadSystemRoots()
105			if err != nil {
106				t.Fatal("unexpected failure:", err)
107			}
108
109			if r == nil {
110				t.Fatal("nil roots")
111			}
112
113			// Verify that the returned certs match, otherwise report where the mismatch is.
114			for i, cn := range tc.cns {
115				if i >= r.len() {
116					t.Errorf("missing cert %v @ %v", cn, i)
117				} else if r.mustCert(t, i).Subject.CommonName != cn {
118					fmt.Printf("%#v\n", r.mustCert(t, 0).Subject)
119					t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn)
120				}
121			}
122			if r.len() > len(tc.cns) {
123				t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
124			}
125		})
126	}
127}
128
129// Ensure that "SSL_CERT_DIR" when used as the environment
130// variable delimited by colons, allows loadSystemRoots to
131// load all the roots from the respective directories.
132// See https://golang.org/issue/35325.
133func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
134	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
135	origCertFiles := certFiles[:]
136
137	// To prevent any other certs from being loaded in
138	// through "SSL_CERT_FILE" or from known "certFiles",
139	// clear them all, and they'll be reverting on defer.
140	certFiles = certFiles[:0]
141	os.Setenv(certFileEnv, "")
142
143	defer func() {
144		certFiles = origCertFiles[:]
145		os.Setenv(certDirEnv, origDir)
146		os.Setenv(certFileEnv, origFile)
147	}()
148
149	tmpDir := t.TempDir()
150
151	rootPEMs := []string{
152		gtsRoot,
153		googleLeaf,
154		startComRoot,
155	}
156
157	var certDirs []string
158	for i, certPEM := range rootPEMs {
159		certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i))
160		if err := os.MkdirAll(certDir, 0755); err != nil {
161			t.Fatalf("Failed to create certificate dir: %v", err)
162		}
163		certOutFile := filepath.Join(certDir, "cert.crt")
164		if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil {
165			t.Fatalf("Failed to write certificate to file: %v", err)
166		}
167		certDirs = append(certDirs, certDir)
168	}
169
170	// Sanity check: the number of certDirs should be equal to the number of roots.
171	if g, w := len(certDirs), len(rootPEMs); g != w {
172		t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w)
173	}
174
175	// Now finally concatenate them with a colon.
176	colonConcatCertDirs := strings.Join(certDirs, ":")
177	os.Setenv(certDirEnv, colonConcatCertDirs)
178	gotPool, err := loadSystemRoots()
179	if err != nil {
180		t.Fatalf("Failed to load system roots: %v", err)
181	}
182	subjects := gotPool.Subjects()
183	// We expect exactly len(rootPEMs) subjects back.
184	if g, w := len(subjects), len(rootPEMs); g != w {
185		t.Fatalf("Invalid number of subjects: got %d want %d", g, w)
186	}
187
188	wantPool := NewCertPool()
189	for _, certPEM := range rootPEMs {
190		wantPool.AppendCertsFromPEM([]byte(certPEM))
191	}
192	strCertPool := func(p *CertPool) string {
193		return string(bytes.Join(p.Subjects(), []byte("\n")))
194	}
195
196	if !certPoolEqual(gotPool, wantPool) {
197		g, w := strCertPool(gotPool), strCertPool(wantPool)
198		t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
199	}
200}
201
202func TestReadUniqueDirectoryEntries(t *testing.T) {
203	tmp := t.TempDir()
204	temp := func(base string) string { return filepath.Join(tmp, base) }
205	if f, err := os.Create(temp("file")); err != nil {
206		t.Fatal(err)
207	} else {
208		f.Close()
209	}
210	if err := os.Symlink("target-in", temp("link-in")); err != nil {
211		t.Fatal(err)
212	}
213	if err := os.Symlink("../target-out", temp("link-out")); err != nil {
214		t.Fatal(err)
215	}
216	got, err := readUniqueDirectoryEntries(tmp)
217	if err != nil {
218		t.Fatal(err)
219	}
220	gotNames := []string{}
221	for _, fi := range got {
222		gotNames = append(gotNames, fi.Name())
223	}
224	wantNames := []string{"file", "link-out"}
225	if !reflect.DeepEqual(gotNames, wantNames) {
226		t.Errorf("got %q; want %q", gotNames, wantNames)
227	}
228}
229