1// Copyright 2024 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package os_test 6 7import ( 8 "bytes" 9 "errors" 10 "io" 11 "math/rand/v2" 12 "net" 13 "os" 14 "runtime" 15 "sync" 16 "testing" 17 18 "golang.org/x/net/nettest" 19) 20 21// Exercise sendfile/splice fast paths with a moderately large file. 22// 23// https://go.dev/issue/70000 24 25func TestLargeCopyViaNetwork(t *testing.T) { 26 const size = 10 * 1024 * 1024 27 dir := t.TempDir() 28 29 src, err := os.Create(dir + "/src") 30 if err != nil { 31 t.Fatal(err) 32 } 33 defer src.Close() 34 if _, err := io.CopyN(src, newRandReader(), size); err != nil { 35 t.Fatal(err) 36 } 37 if _, err := src.Seek(0, 0); err != nil { 38 t.Fatal(err) 39 } 40 41 dst, err := os.Create(dir + "/dst") 42 if err != nil { 43 t.Fatal(err) 44 } 45 defer dst.Close() 46 47 client, server := createSocketPair(t, "tcp") 48 var wg sync.WaitGroup 49 wg.Add(2) 50 go func() { 51 defer wg.Done() 52 if n, err := io.Copy(dst, server); n != size || err != nil { 53 t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size) 54 } 55 }() 56 go func() { 57 defer wg.Done() 58 defer client.Close() 59 if n, err := io.Copy(client, src); n != size || err != nil { 60 t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size) 61 } 62 }() 63 wg.Wait() 64 65 if _, err := dst.Seek(0, 0); err != nil { 66 t.Fatal(err) 67 } 68 if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil { 69 t.Fatal(err) 70 } 71} 72 73func compareReaders(a, b io.Reader) error { 74 bufa := make([]byte, 4096) 75 bufb := make([]byte, 4096) 76 for { 77 na, erra := io.ReadFull(a, bufa) 78 if erra != nil && erra != io.EOF { 79 return erra 80 } 81 nb, errb := io.ReadFull(b, bufb) 82 if errb != nil && errb != io.EOF { 83 return errb 84 } 85 if !bytes.Equal(bufa[:na], bufb[:nb]) { 86 return errors.New("contents mismatch") 87 } 88 if erra == io.EOF && errb == io.EOF { 89 break 90 } 91 } 92 return nil 93} 94 95type randReader struct { 96 rand *rand.Rand 97} 98 99func newRandReader() *randReader { 100 return &randReader{rand.New(rand.NewPCG(0, 0))} 101} 102 103func (r *randReader) Read(p []byte) (int, error) { 104 var v uint64 105 var n int 106 for i := range p { 107 if n == 0 { 108 v = r.rand.Uint64() 109 n = 8 110 } 111 p[i] = byte(v & 0xff) 112 v >>= 8 113 n-- 114 } 115 return len(p), nil 116} 117 118func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { 119 t.Helper() 120 if !nettest.TestableNetwork(proto) { 121 t.Skipf("%s does not support %q", runtime.GOOS, proto) 122 } 123 124 ln, err := nettest.NewLocalListener(proto) 125 if err != nil { 126 t.Fatalf("NewLocalListener error: %v", err) 127 } 128 t.Cleanup(func() { 129 if ln != nil { 130 ln.Close() 131 } 132 if client != nil { 133 client.Close() 134 } 135 if server != nil { 136 server.Close() 137 } 138 }) 139 ch := make(chan struct{}) 140 go func() { 141 var err error 142 server, err = ln.Accept() 143 if err != nil { 144 t.Errorf("Accept new connection error: %v", err) 145 } 146 ch <- struct{}{} 147 }() 148 client, err = net.Dial(proto, ln.Addr().String()) 149 <-ch 150 if err != nil { 151 t.Fatalf("Dial new connection error: %v", err) 152 } 153 return client, server 154} 155