Source file src/crypto/x509/root_unix_test.go

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
     6  
     7  package x509
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"os"
    13  	"path/filepath"
    14  	"slices"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  const (
    20  	testDirCN   = "test-dir"
    21  	testFile    = "test-file.crt"
    22  	testFileCN  = "test-file"
    23  	testMissing = "missing"
    24  )
    25  
    26  func TestEnvVars(t *testing.T) {
    27  	tmpDir := t.TempDir()
    28  	testCert, err := os.ReadFile("testdata/test-dir.crt")
    29  	if err != nil {
    30  		t.Fatalf("failed to read test cert: %s", err)
    31  	}
    32  	if err := os.WriteFile(filepath.Join(tmpDir, testFile), testCert, 0644); err != nil {
    33  		t.Fatalf("failed to write test cert: %s", err)
    34  	}
    35  
    36  	testCases := []struct {
    37  		name    string
    38  		fileEnv string
    39  		dirEnv  string
    40  		files   []string
    41  		dirs    []string
    42  		cns     []string
    43  	}{
    44  		{
    45  			// Environment variables override the default locations preventing fall through.
    46  			name:    "override-defaults",
    47  			fileEnv: testMissing,
    48  			dirEnv:  testMissing,
    49  			files:   []string{testFile},
    50  			dirs:    []string{tmpDir},
    51  			cns:     nil,
    52  		},
    53  		{
    54  			// File environment overrides default file locations.
    55  			name:    "file",
    56  			fileEnv: testFile,
    57  			dirEnv:  "",
    58  			files:   nil,
    59  			dirs:    nil,
    60  			cns:     []string{testFileCN},
    61  		},
    62  		{
    63  			// Directory environment overrides default directory locations.
    64  			name:    "dir",
    65  			fileEnv: "",
    66  			dirEnv:  tmpDir,
    67  			files:   nil,
    68  			dirs:    nil,
    69  			cns:     []string{testDirCN},
    70  		},
    71  		{
    72  			// File & directory environment overrides both default locations.
    73  			name:    "file+dir",
    74  			fileEnv: testFile,
    75  			dirEnv:  tmpDir,
    76  			files:   nil,
    77  			dirs:    nil,
    78  			cns:     []string{testFileCN, testDirCN},
    79  		},
    80  		{
    81  			// Environment variable empty / unset uses default locations.
    82  			name:    "empty-fall-through",
    83  			fileEnv: "",
    84  			dirEnv:  "",
    85  			files:   []string{testFile},
    86  			dirs:    []string{tmpDir},
    87  			cns:     []string{testFileCN, testDirCN},
    88  		},
    89  	}
    90  
    91  	// Save old settings so we can restore before the test ends.
    92  	origCertFiles, origCertDirectories := certFiles, certDirectories
    93  	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
    94  	defer func() {
    95  		certFiles = origCertFiles
    96  		certDirectories = origCertDirectories
    97  		os.Setenv(certFileEnv, origFile)
    98  		os.Setenv(certDirEnv, origDir)
    99  	}()
   100  
   101  	for _, tc := range testCases {
   102  		t.Run(tc.name, func(t *testing.T) {
   103  			if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil {
   104  				t.Fatalf("setenv %q failed: %v", certFileEnv, err)
   105  			}
   106  			if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil {
   107  				t.Fatalf("setenv %q failed: %v", certDirEnv, err)
   108  			}
   109  
   110  			certFiles, certDirectories = tc.files, tc.dirs
   111  
   112  			r, err := loadSystemRoots()
   113  			if err != nil {
   114  				t.Fatal("unexpected failure:", err)
   115  			}
   116  
   117  			if r == nil {
   118  				t.Fatal("nil roots")
   119  			}
   120  
   121  			// Verify that the returned certs match, otherwise report where the mismatch is.
   122  			for i, cn := range tc.cns {
   123  				if i >= r.len() {
   124  					t.Errorf("missing cert %v @ %v", cn, i)
   125  				} else if r.mustCert(t, i).Subject.CommonName != cn {
   126  					fmt.Printf("%#v\n", r.mustCert(t, 0).Subject)
   127  					t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn)
   128  				}
   129  			}
   130  			if r.len() > len(tc.cns) {
   131  				t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
   132  			}
   133  		})
   134  	}
   135  }
   136  
   137  // Ensure that "SSL_CERT_DIR" when used as the environment
   138  // variable delimited by colons, allows loadSystemRoots to
   139  // load all the roots from the respective directories.
   140  // See https://golang.org/issue/35325.
   141  func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
   142  	origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv)
   143  	origCertFiles := certFiles[:]
   144  
   145  	// To prevent any other certs from being loaded in
   146  	// through "SSL_CERT_FILE" or from known "certFiles",
   147  	// clear them all, and they'll be reverting on defer.
   148  	certFiles = certFiles[:0]
   149  	os.Setenv(certFileEnv, "")
   150  
   151  	defer func() {
   152  		certFiles = origCertFiles[:]
   153  		os.Setenv(certDirEnv, origDir)
   154  		os.Setenv(certFileEnv, origFile)
   155  	}()
   156  
   157  	tmpDir := t.TempDir()
   158  
   159  	rootPEMs := []string{
   160  		gtsRoot,
   161  		googleLeaf,
   162  	}
   163  
   164  	var certDirs []string
   165  	for i, certPEM := range rootPEMs {
   166  		certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i))
   167  		if err := os.MkdirAll(certDir, 0755); err != nil {
   168  			t.Fatalf("Failed to create certificate dir: %v", err)
   169  		}
   170  		certOutFile := filepath.Join(certDir, "cert.crt")
   171  		if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil {
   172  			t.Fatalf("Failed to write certificate to file: %v", err)
   173  		}
   174  		certDirs = append(certDirs, certDir)
   175  	}
   176  
   177  	// Sanity check: the number of certDirs should be equal to the number of roots.
   178  	if g, w := len(certDirs), len(rootPEMs); g != w {
   179  		t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w)
   180  	}
   181  
   182  	// Now finally concatenate them with a colon.
   183  	colonConcatCertDirs := strings.Join(certDirs, ":")
   184  	os.Setenv(certDirEnv, colonConcatCertDirs)
   185  	gotPool, err := loadSystemRoots()
   186  	if err != nil {
   187  		t.Fatalf("Failed to load system roots: %v", err)
   188  	}
   189  	subjects := gotPool.Subjects()
   190  	// We expect exactly len(rootPEMs) subjects back.
   191  	if g, w := len(subjects), len(rootPEMs); g != w {
   192  		t.Fatalf("Invalid number of subjects: got %d want %d", g, w)
   193  	}
   194  
   195  	wantPool := NewCertPool()
   196  	for _, certPEM := range rootPEMs {
   197  		wantPool.AppendCertsFromPEM([]byte(certPEM))
   198  	}
   199  	strCertPool := func(p *CertPool) string {
   200  		return string(bytes.Join(p.Subjects(), []byte("\n")))
   201  	}
   202  
   203  	if !certPoolEqual(gotPool, wantPool) {
   204  		g, w := strCertPool(gotPool), strCertPool(wantPool)
   205  		t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
   206  	}
   207  }
   208  
   209  func TestReadUniqueDirectoryEntries(t *testing.T) {
   210  	tmp := t.TempDir()
   211  	temp := func(base string) string { return filepath.Join(tmp, base) }
   212  	if f, err := os.Create(temp("file")); err != nil {
   213  		t.Fatal(err)
   214  	} else {
   215  		f.Close()
   216  	}
   217  	if err := os.Symlink("target-in", temp("link-in")); err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	if err := os.Symlink("../target-out", temp("link-out")); err != nil {
   221  		t.Fatal(err)
   222  	}
   223  	got, err := readUniqueDirectoryEntries(tmp)
   224  	if err != nil {
   225  		t.Fatal(err)
   226  	}
   227  	gotNames := []string{}
   228  	for _, fi := range got {
   229  		gotNames = append(gotNames, fi.Name())
   230  	}
   231  	wantNames := []string{"file", "link-out"}
   232  	if !slices.Equal(gotNames, wantNames) {
   233  		t.Errorf("got %q; want %q", gotNames, wantNames)
   234  	}
   235  }
   236  

View as plain text