Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 "net/http"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "testing/synctest"
48 "time"
49
50 "golang.org/x/net/http/httpguts"
51 )
52
53
54
55
56
57 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
58 if r.FormValue("close") == "true" {
59 w.Header().Set("Connection", "close")
60 }
61 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
62 w.Write([]byte(r.RemoteAddr))
63
64
65
66 if c, ok := ResponseWriterConnForTesting(w); ok {
67 fmt.Fprintf(w, ", %T %p", c, c)
68 }
69 })
70
71
// testCloseConn wraps a net.Conn and records in set when it is closed.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks c as closed in its testConnSet, then closes the wrapped conn.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks every connection produced by a makeTestDial dialer and
// whether each one has been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex // guards closed and list
	closed map[net.Conn]bool
	list   []net.Conn // in dial order
}

// insert starts tracking c as an open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove records that c has been closed.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a fresh testConnSet and a dial function (suitable for
// Transport.Dial) whose connections are recorded in that set.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{c, connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports, via t.Errorf, every tracked connection that is still open.
// Connections may be closed asynchronously, so check polls a few times
// (releasing the lock while it sleeps) before reporting anything.
//
// The previous version shadowed its retry counter with the range index, so
// it repeatedly reported only connection #1 and never reported the others.
func (tcs *testConnSet) check(t *testing.T) {
	const (
		attempts     = 5
		pollInterval = 50 * time.Millisecond
	)
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	for attempt := 1; attempt < attempts && !tcs.allClosedLocked(); attempt++ {
		tcs.mu.Unlock()
		time.Sleep(pollInterval)
		tcs.mu.Lock()
	}
	for i, c := range tcs.list {
		if !tcs.closed[c] {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
		}
	}
}

// allClosedLocked reports whether every tracked connection has been closed.
// tcs.mu must be held.
func (tcs *testConnSet) allClosedLocked() bool {
	for _, c := range tcs.list {
		if !tcs.closed[c] {
			return false
		}
	}
	return true
}
142
143 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
144 func testReuseRequest(t *testing.T, mode testMode) {
145 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
146 w.Write([]byte("{}"))
147 })).ts
148
149 c := ts.Client()
150 req, _ := NewRequest("GET", ts.URL, nil)
151 res, err := c.Do(req)
152 if err != nil {
153 t.Fatal(err)
154 }
155 err = res.Body.Close()
156 if err != nil {
157 t.Fatal(err)
158 }
159
160 res, err = c.Do(req)
161 if err != nil {
162 t.Fatal(err)
163 }
164 err = res.Body.Close()
165 if err != nil {
166 t.Fatal(err)
167 }
168 }
169
170
171
172 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
173 func testTransportKeepAlives(t *testing.T, mode testMode) {
174 ts := newClientServerTest(t, mode, hostPortHandler).ts
175
176 c := ts.Client()
177 for _, disableKeepAlive := range []bool{false, true} {
178 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
179 fetch := func(n int) string {
180 res, err := c.Get(ts.URL)
181 if err != nil {
182 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
183 }
184 body, err := io.ReadAll(res.Body)
185 if err != nil {
186 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
187 }
188 return string(body)
189 }
190
191 body1 := fetch(1)
192 body2 := fetch(2)
193
194 bodiesDiffer := body1 != body2
195 if bodiesDiffer != disableKeepAlive {
196 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
197 disableKeepAlive, bodiesDiffer, body1, body2)
198 }
199 }
200 }
201
202 func TestTransportConnectionCloseOnResponse(t *testing.T) {
203 run(t, testTransportConnectionCloseOnResponse)
204 }
205 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
206 ts := newClientServerTest(t, mode, hostPortHandler).ts
207
208 connSet, testDial := makeTestDial(t)
209
210 c := ts.Client()
211 tr := c.Transport.(*Transport)
212 tr.Dial = testDial
213
214 for _, connectionClose := range []bool{false, true} {
215 fetch := func(n int) string {
216 req := new(Request)
217 var err error
218 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
219 if err != nil {
220 t.Fatalf("URL parse error: %v", err)
221 }
222 req.Method = "GET"
223 req.Proto = "HTTP/1.1"
224 req.ProtoMajor = 1
225 req.ProtoMinor = 1
226
227 res, err := c.Do(req)
228 if err != nil {
229 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
230 }
231 defer res.Body.Close()
232 body, err := io.ReadAll(res.Body)
233 if err != nil {
234 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
235 }
236 return string(body)
237 }
238
239 body1 := fetch(1)
240 body2 := fetch(2)
241 bodiesDiffer := body1 != body2
242 if bodiesDiffer != connectionClose {
243 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
244 connectionClose, bodiesDiffer, body1, body2)
245 }
246
247 tr.CloseIdleConnections()
248 }
249
250 connSet.check(t)
251 }
252
253
254
255
256
257
258
259 func TestTransportConnectionCloseOnRequest(t *testing.T) {
260 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
261 }
262 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
263 ts := newClientServerTest(t, mode, hostPortHandler).ts
264
265 connSet, testDial := makeTestDial(t)
266
267 c := ts.Client()
268 tr := c.Transport.(*Transport)
269 tr.Dial = testDial
270 for _, reqClose := range []bool{false, true} {
271 fetch := func(n int) string {
272 req := new(Request)
273 var err error
274 req.URL, err = url.Parse(ts.URL)
275 if err != nil {
276 t.Fatalf("URL parse error: %v", err)
277 }
278 req.Method = "GET"
279 req.Proto = "HTTP/1.1"
280 req.ProtoMajor = 1
281 req.ProtoMinor = 1
282 req.Close = reqClose
283
284 res, err := c.Do(req)
285 if err != nil {
286 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
287 }
288 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
289 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
290 reqClose, got, !reqClose)
291 }
292 body, err := io.ReadAll(res.Body)
293 if err != nil {
294 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
295 }
296 return string(body)
297 }
298
299 body1 := fetch(1)
300 body2 := fetch(2)
301
302 got := 1
303 if body1 != body2 {
304 got++
305 }
306 want := 1
307 if reqClose {
308 want = 2
309 }
310 if got != want {
311 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
312 reqClose, got, want, body1, body2)
313 }
314
315 tr.CloseIdleConnections()
316 }
317
318 connSet.check(t)
319 }
320
321
322
323
324 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
325 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
326 }
327 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
328 ts := newClientServerTest(t, mode, hostPortHandler).ts
329
330 c := ts.Client()
331 c.Transport.(*Transport).DisableKeepAlives = true
332
333 res, err := c.Get(ts.URL)
334 if err != nil {
335 t.Fatal(err)
336 }
337 res.Body.Close()
338 if res.Header.Get("X-Saw-Close") != "true" {
339 t.Errorf("handler didn't see Connection: close ")
340 }
341 }
342
343
344
345 func TestTransportRespectRequestWantsClose(t *testing.T) {
346 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
347 }
348 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
349 tests := []struct {
350 disableKeepAlives bool
351 close bool
352 }{
353 {disableKeepAlives: false, close: false},
354 {disableKeepAlives: false, close: true},
355 {disableKeepAlives: true, close: false},
356 {disableKeepAlives: true, close: true},
357 }
358
359 for _, tc := range tests {
360 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
361 func(t *testing.T) {
362 ts := newClientServerTest(t, mode, hostPortHandler).ts
363
364 c := ts.Client()
365 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
366 req, err := NewRequest("GET", ts.URL, nil)
367 if err != nil {
368 t.Fatal(err)
369 }
370 count := 0
371 trace := &httptrace.ClientTrace{
372 WroteHeaderField: func(key string, field []string) {
373 if key != "Connection" {
374 return
375 }
376 if httpguts.HeaderValuesContainsToken(field, "close") {
377 count += 1
378 }
379 },
380 }
381 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
382 req.Close = tc.close
383 res, err := c.Do(req)
384 if err != nil {
385 t.Fatal(err)
386 }
387 defer res.Body.Close()
388 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
389 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
390 }
391 })
392 }
393
394 }
395
396 func TestTransportIdleCacheKeys(t *testing.T) {
397 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
398 }
399 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
400 ts := newClientServerTest(t, mode, hostPortHandler).ts
401 c := ts.Client()
402 tr := c.Transport.(*Transport)
403
404 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
405 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
406 }
407
408 resp, err := c.Get(ts.URL)
409 if err != nil {
410 t.Error(err)
411 }
412 io.ReadAll(resp.Body)
413
414 keys := tr.IdleConnKeysForTesting()
415 if e, g := 1, len(keys); e != g {
416 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
417 }
418
419 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
420 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
421 }
422
423 tr.CloseIdleConnections()
424 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
425 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
426 }
427 }
428
429
430
431 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
432 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
433 const msg = "foobar"
434
435 var addrSeen map[string]int
436 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
437 addrSeen[r.RemoteAddr]++
438 if r.URL.Path == "/chunked/" {
439 w.WriteHeader(200)
440 w.(Flusher).Flush()
441 } else {
442 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
443 w.WriteHeader(200)
444 }
445 w.Write([]byte(msg))
446 })).ts
447
448 for pi, path := range []string{"/content-length/", "/chunked/"} {
449 wantLen := []int{len(msg), -1}[pi]
450 addrSeen = make(map[string]int)
451 for i := 0; i < 3; i++ {
452 res, err := ts.Client().Get(ts.URL + path)
453 if err != nil {
454 t.Errorf("Get %s: %v", path, err)
455 continue
456 }
457
458
459
460
461
462 defer res.Body.Close()
463
464 if res.ContentLength != int64(wantLen) {
465 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
466 }
467 got, err := io.ReadAll(res.Body)
468 if string(got) != msg || err != nil {
469 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
470 }
471 }
472 if len(addrSeen) != 1 {
473 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
474 }
475 }
476 }
477
478
479
480
481
482 func TestTransportNotReadToEndConnectionReuse(t *testing.T) {
483 run(t, testTransportNotReadToEndConnectionReuse, []testMode{http1Mode, https1Mode})
484 }
485 func testTransportNotReadToEndConnectionReuse(t *testing.T, mode testMode) {
486 tests := []struct {
487 name string
488 bodyLen int
489 contentLenKnown bool
490 headRequest bool
491 timeBetweenReqs time.Duration
492 responseTime time.Duration
493 wantReuse bool
494 }{
495 {
496 name: "unconsumed body within drain limit",
497 bodyLen: 200 * 1024,
498 timeBetweenReqs: http.MaxPostCloseReadTime,
499 wantReuse: true,
500 },
501 {
502 name: "unconsumed body within drain limit with known length",
503 bodyLen: 200 * 1024,
504 contentLenKnown: true,
505 timeBetweenReqs: http.MaxPostCloseReadTime,
506 wantReuse: true,
507 },
508 {
509 name: "unconsumed body larger than drain limit",
510 bodyLen: 500 * 1024,
511 timeBetweenReqs: http.MaxPostCloseReadTime,
512 wantReuse: false,
513 },
514 {
515 name: "unconsumed body larger than drain limit with known length",
516 bodyLen: 500 * 1024,
517 contentLenKnown: true,
518 timeBetweenReqs: http.MaxPostCloseReadTime,
519 wantReuse: false,
520 },
521 {
522 name: "new requests start before drain for old requests are finished",
523 bodyLen: 200 * 1024,
524 timeBetweenReqs: 0,
525 responseTime: time.Minute,
526 wantReuse: false,
527 },
528 {
529
530
531 name: "unconsumed body larger than drain limit for HEAD request",
532 bodyLen: 500 * 1024,
533 headRequest: true,
534 wantReuse: true,
535 },
536 }
537
538 for _, tc := range tests {
539 subtest := func(t *testing.T) {
540 addrSeen := make(map[string]int)
541 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
542 addrSeen[r.RemoteAddr]++
543 time.Sleep(tc.responseTime)
544 if tc.contentLenKnown {
545 w.Header().Add("Content-Length", strconv.Itoa(tc.bodyLen))
546 }
547 w.Write(slices.Repeat([]byte("a"), tc.bodyLen))
548 }), optFakeNet).ts
549
550 var wg sync.WaitGroup
551 for range 10 {
552 wg.Go(func() {
553 method := http.MethodGet
554 if tc.headRequest {
555 method = http.MethodHead
556 }
557 ctx, cancel := context.WithCancel(context.Background())
558 req, err := http.NewRequestWithContext(ctx, method, ts.URL, nil)
559 if err != nil {
560 log.Fatal(err)
561 }
562 resp, err := ts.Client().Do(req)
563 if err != nil {
564 t.Fatal(err)
565 }
566 if resp.StatusCode != http.StatusOK {
567 t.Errorf("expected HTTP 200, got: %v", resp.StatusCode)
568 }
569 resp.Body.Close()
570
571
572 cancel()
573 if n, err := resp.Body.Read([]byte{}); n != 0 || err == nil {
574 t.Errorf("read after body has been closed should not succeed, but read %v byte with %v error", n, err)
575 }
576 })
577 time.Sleep(tc.timeBetweenReqs)
578 synctest.Wait()
579 }
580 wg.Wait()
581 if (len(addrSeen) == 1) != tc.wantReuse {
582 t.Errorf("want connection reuse to be %v, but %v connections were created", tc.wantReuse, len(addrSeen))
583 }
584 }
585 t.Run(tc.name, func(t *testing.T) {
586 synctest.Test(t, subtest)
587 })
588 }
589 }
590
591 func TestTransportMaxPerHostIdleConns(t *testing.T) {
592 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
593 }
594 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
595 stop := make(chan struct{})
596 defer close(stop)
597
598 resch := make(chan string)
599 gotReq := make(chan bool)
600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
601 gotReq <- true
602 var msg string
603 select {
604 case <-stop:
605 return
606 case msg = <-resch:
607 }
608 _, err := w.Write([]byte(msg))
609 if err != nil {
610 t.Errorf("Write: %v", err)
611 return
612 }
613 })).ts
614
615 c := ts.Client()
616 tr := c.Transport.(*Transport)
617 maxIdleConnsPerHost := 2
618 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
619
620
621
622 donech := make(chan bool)
623 doReq := func() {
624 defer func() {
625 select {
626 case <-stop:
627 return
628 case donech <- t.Failed():
629 }
630 }()
631 resp, err := c.Get(ts.URL)
632 if err != nil {
633 t.Error(err)
634 return
635 }
636 if _, err := io.ReadAll(resp.Body); err != nil {
637 t.Errorf("ReadAll: %v", err)
638 return
639 }
640 }
641 go doReq()
642 <-gotReq
643 go doReq()
644 <-gotReq
645 go doReq()
646 <-gotReq
647
648 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
649 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
650 }
651
652 resch <- "res1"
653 <-donech
654 keys := tr.IdleConnKeysForTesting()
655 if e, g := 1, len(keys); e != g {
656 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
657 }
658 addr := ts.Listener.Addr().String()
659 cacheKey := "|http|" + addr
660 if keys[0] != cacheKey {
661 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
662 }
663 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
664 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
665 }
666
667 resch <- "res2"
668 <-donech
669 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
670 t.Errorf("after second response, idle conns = %d; want %d", g, w)
671 }
672
673 resch <- "res3"
674 <-donech
675 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
676 t.Errorf("after third response, idle conns = %d; want %d", g, w)
677 }
678 }
679
680 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
681 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
682 }
683 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
684 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
685 _, err := w.Write([]byte("foo"))
686 if err != nil {
687 t.Fatalf("Write: %v", err)
688 }
689 })).ts
690 c := ts.Client()
691 tr := c.Transport.(*Transport)
692 dialStarted := make(chan struct{})
693 stallDial := make(chan struct{})
694 tr.Dial = func(network, addr string) (net.Conn, error) {
695 dialStarted <- struct{}{}
696 <-stallDial
697 return net.Dial(network, addr)
698 }
699
700 tr.DisableKeepAlives = true
701 tr.MaxConnsPerHost = 1
702
703 preDial := make(chan struct{})
704 reqComplete := make(chan struct{})
705 doReq := func(reqId string) {
706 req, _ := NewRequest("GET", ts.URL, nil)
707 trace := &httptrace.ClientTrace{
708 GetConn: func(hostPort string) {
709 preDial <- struct{}{}
710 },
711 }
712 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
713 resp, err := tr.RoundTrip(req)
714 if err != nil {
715 t.Errorf("unexpected error for request %s: %v", reqId, err)
716 }
717 _, err = io.ReadAll(resp.Body)
718 if err != nil {
719 t.Errorf("unexpected error for request %s: %v", reqId, err)
720 }
721 reqComplete <- struct{}{}
722 }
723
724 go doReq("req1")
725 <-preDial
726 <-dialStarted
727
728
729 go doReq("req2")
730 <-preDial
731 select {
732 case <-dialStarted:
733 t.Error("req2 dial started while req1 dial in progress")
734 return
735 default:
736 }
737
738
739 stallDial <- struct{}{}
740 <-reqComplete
741
742
743 <-dialStarted
744 stallDial <- struct{}{}
745 <-reqComplete
746 }
747
748 func TestTransportMaxConnsPerHost(t *testing.T) {
749 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
750 }
751 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
752 CondSkipHTTP2(t)
753
754 h := HandlerFunc(func(w ResponseWriter, r *Request) {
755 _, err := w.Write([]byte("foo"))
756 if err != nil {
757 t.Fatalf("Write: %v", err)
758 }
759 })
760
761 ts := newClientServerTest(t, mode, h).ts
762 c := ts.Client()
763 tr := c.Transport.(*Transport)
764 tr.MaxConnsPerHost = 1
765
766 mu := sync.Mutex{}
767 var conns []net.Conn
768 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
769 tr.Dial = func(network, addr string) (net.Conn, error) {
770 atomic.AddInt32(&dialCnt, 1)
771 c, err := net.Dial(network, addr)
772 mu.Lock()
773 defer mu.Unlock()
774 conns = append(conns, c)
775 return c, err
776 }
777
778 doReq := func() {
779 trace := &httptrace.ClientTrace{
780 GotConn: func(connInfo httptrace.GotConnInfo) {
781 if !connInfo.Reused {
782 atomic.AddInt32(&gotConnCnt, 1)
783 }
784 },
785 TLSHandshakeStart: func() {
786 atomic.AddInt32(&tlsHandshakeCnt, 1)
787 },
788 }
789 req, _ := NewRequest("GET", ts.URL, nil)
790 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
791
792 resp, err := c.Do(req)
793 if err != nil {
794 t.Fatalf("request failed: %v", err)
795 }
796 defer resp.Body.Close()
797 _, err = io.ReadAll(resp.Body)
798 if err != nil {
799 t.Fatalf("read body failed: %v", err)
800 }
801 }
802
803 wg := sync.WaitGroup{}
804 for i := 0; i < 10; i++ {
805 wg.Add(1)
806 go func() {
807 defer wg.Done()
808 doReq()
809 }()
810 }
811 wg.Wait()
812
813 expected := int32(tr.MaxConnsPerHost)
814 if dialCnt != expected {
815 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
816 }
817 if gotConnCnt != expected {
818 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
819 }
820 if ts.TLS != nil && tlsHandshakeCnt != expected {
821 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
822 }
823
824 if t.Failed() {
825 t.FailNow()
826 }
827
828 mu.Lock()
829 for _, c := range conns {
830 c.Close()
831 }
832 conns = nil
833 mu.Unlock()
834 tr.CloseIdleConnections()
835
836 doReq()
837 expected++
838 if dialCnt != expected {
839 t.Errorf("round 2: too many dials: %d", dialCnt)
840 }
841 if gotConnCnt != expected {
842 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
843 }
844 if ts.TLS != nil && tlsHandshakeCnt != expected {
845 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
846 }
847 }
848
849 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
850 run(t, testTransportMaxConnsPerHostDialCancellation,
851 testNotParallel,
852 []testMode{http1Mode, https1Mode, http2Mode},
853 )
854 }
855
856 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
857 CondSkipHTTP2(t)
858
859 h := HandlerFunc(func(w ResponseWriter, r *Request) {
860 _, err := w.Write([]byte("foo"))
861 if err != nil {
862 t.Fatalf("Write: %v", err)
863 }
864 })
865
866 cst := newClientServerTest(t, mode, h)
867 defer cst.close()
868 ts := cst.ts
869 c := ts.Client()
870 tr := c.Transport.(*Transport)
871 tr.MaxConnsPerHost = 1
872
873
874 ctx, cancel := context.WithCancel(context.Background())
875 defer cancel()
876 SetPendingDialHooks(cancel, nil)
877 defer SetPendingDialHooks(nil, nil)
878
879 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
880 _, err := c.Do(req)
881 if !errors.Is(err, context.Canceled) {
882 t.Errorf("expected error %v, got %v", context.Canceled, err)
883 }
884
885
886 SetPendingDialHooks(nil, nil)
887 req, _ = NewRequest("GET", ts.URL, nil)
888 resp, err := c.Do(req)
889 if err != nil {
890 t.Fatalf("request failed: %v", err)
891 }
892 defer resp.Body.Close()
893 _, err = io.ReadAll(resp.Body)
894 if err != nil {
895 t.Fatalf("read body failed: %v", err)
896 }
897 }
898
899 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
900 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
901 }
902 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
903 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
904 io.WriteString(w, r.RemoteAddr)
905 })).ts
906
907 c := ts.Client()
908 tr := c.Transport.(*Transport)
909
910 doReq := func(name string) {
911
912
913 res, err := c.Post(ts.URL, "", nil)
914 if err != nil {
915 t.Fatalf("%s: %v", name, err)
916 }
917 if res.StatusCode != 200 {
918 t.Fatalf("%s: %v", name, res.Status)
919 }
920 defer res.Body.Close()
921 slurp, err := io.ReadAll(res.Body)
922 if err != nil {
923 t.Fatalf("%s: %v", name, err)
924 }
925 t.Logf("%s: ok (%q)", name, slurp)
926 }
927
928 doReq("first")
929 keys1 := tr.IdleConnKeysForTesting()
930
931 ts.CloseClientConnections()
932
933 var keys2 []string
934 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
935 keys2 = tr.IdleConnKeysForTesting()
936 if len(keys2) != 0 {
937 if d > 0 {
938 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
939 }
940 return false
941 }
942 return true
943 })
944
945 doReq("second")
946 }
947
948
949
950 func TestTransportServerClosingUnexpectedly(t *testing.T) {
951 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
952 }
953 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
954 ts := newClientServerTest(t, mode, hostPortHandler).ts
955 c := ts.Client()
956
957 fetch := func(n, retries int) string {
958 condFatalf := func(format string, arg ...any) {
959 if retries <= 0 {
960 t.Fatalf(format, arg...)
961 }
962 t.Logf("retrying shortly after expected error: "+format, arg...)
963 time.Sleep(time.Second / time.Duration(retries))
964 }
965 for retries >= 0 {
966 retries--
967 res, err := c.Get(ts.URL)
968 if err != nil {
969 condFatalf("error in req #%d, GET: %v", n, err)
970 continue
971 }
972 body, err := io.ReadAll(res.Body)
973 if err != nil {
974 condFatalf("error in req #%d, ReadAll: %v", n, err)
975 continue
976 }
977 res.Body.Close()
978 return string(body)
979 }
980 panic("unreachable")
981 }
982
983 body1 := fetch(1, 0)
984 body2 := fetch(2, 0)
985
986
987
988
989
990
991
992
993 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
994
995 body3 := fetch(3, 5)
996
997 if body1 != body2 {
998 t.Errorf("expected body1 and body2 to be equal")
999 }
1000 if body2 == body3 {
1001 t.Errorf("expected body2 and body3 to be different")
1002 }
1003 }
1004
1005
1006
1007 func TestStressSurpriseServerCloses(t *testing.T) {
1008 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
1009 }
1010 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
1011 if testing.Short() {
1012 t.Skip("skipping test in short mode")
1013 }
1014 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1015 w.Header().Set("Content-Length", "5")
1016 w.Header().Set("Content-Type", "text/plain")
1017 w.Write([]byte("Hello"))
1018 w.(Flusher).Flush()
1019 conn, buf, _ := w.(Hijacker).Hijack()
1020 buf.Flush()
1021 conn.Close()
1022 })).ts
1023 c := ts.Client()
1024
1025
1026
1027
1028
1029
1030
1031 const (
1032 numClients = 20
1033 reqsPerClient = 25
1034 )
1035 var wg sync.WaitGroup
1036 wg.Add(numClients * reqsPerClient)
1037 for i := 0; i < numClients; i++ {
1038 go func() {
1039 for i := 0; i < reqsPerClient; i++ {
1040 res, err := c.Get(ts.URL)
1041 if err == nil {
1042
1043
1044
1045
1046
1047
1048 res.Body.Close()
1049 }
1050 wg.Done()
1051 }
1052 }()
1053 }
1054
1055
1056 wg.Wait()
1057 }
1058
1059
1060
1061 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
1062 func testTransportHeadResponses(t *testing.T, mode testMode) {
1063 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1064 if r.Method != "HEAD" {
1065 panic("expected HEAD; got " + r.Method)
1066 }
1067 w.Header().Set("Content-Length", "123")
1068 w.WriteHeader(200)
1069 })).ts
1070 c := ts.Client()
1071
1072 for i := 0; i < 2; i++ {
1073 res, err := c.Head(ts.URL)
1074 if err != nil {
1075 t.Errorf("error on loop %d: %v", i, err)
1076 continue
1077 }
1078 if e, g := "123", res.Header.Get("Content-Length"); e != g {
1079 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
1080 }
1081 if e, g := int64(123), res.ContentLength; e != g {
1082 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
1083 }
1084 if all, err := io.ReadAll(res.Body); err != nil {
1085 t.Errorf("loop %d: Body ReadAll: %v", i, err)
1086 } else if len(all) != 0 {
1087 t.Errorf("Bogus body %q", all)
1088 }
1089 }
1090 }
1091
1092
1093
1094 func TestTransportHeadChunkedResponse(t *testing.T) {
1095 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
1096 }
1097 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
1098 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1099 if r.Method != "HEAD" {
1100 panic("expected HEAD; got " + r.Method)
1101 }
1102 w.Header().Set("Transfer-Encoding", "chunked")
1103 w.Header().Set("x-client-ipport", r.RemoteAddr)
1104 w.WriteHeader(200)
1105 })).ts
1106 c := ts.Client()
1107
1108
1109
1110 didRead := make(chan bool)
1111 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
1112 defer SetReadLoopBeforeNextReadHook(nil)
1113
1114 res1, err := c.Head(ts.URL)
1115 <-didRead
1116
1117 if err != nil {
1118 t.Fatalf("request 1 error: %v", err)
1119 }
1120
1121 res2, err := c.Head(ts.URL)
1122 <-didRead
1123
1124 if err != nil {
1125 t.Fatalf("request 2 error: %v", err)
1126 }
1127 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1128 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1129 }
1130 }
1131
// roundTripTests describes how Transport.RoundTrip should treat a request's
// Accept-Encoding header.
var roundTripTests = []struct {
	accept       string // Accept-Encoding set on the request ("" means none)
	expectAccept string // Accept-Encoding the server should receive
	compressed   bool   // whether the caller receives a still-gzipped body
}{
	// No Accept-Encoding: the transport asks for gzip itself and decodes
	// the body transparently.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other Accept-Encoding is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request is passed through and the body is left compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1144
1145
1146 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1147 func testRoundTripGzip(t *testing.T, mode testMode) {
1148 const responseBody = "test response body"
1149 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1150 accept := req.Header.Get("Accept-Encoding")
1151 if expect := req.FormValue("expect_accept"); accept != expect {
1152 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1153 req.FormValue("testnum"), accept, expect)
1154 }
1155 if accept == "gzip" {
1156 rw.Header().Set("Content-Encoding", "gzip")
1157 gz := gzip.NewWriter(rw)
1158 gz.Write([]byte(responseBody))
1159 gz.Close()
1160 } else {
1161 rw.Header().Set("Content-Encoding", accept)
1162 rw.Write([]byte(responseBody))
1163 }
1164 })).ts
1165 tr := ts.Client().Transport.(*Transport)
1166
1167 for i, test := range roundTripTests {
1168
1169 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1170 if test.accept != "" {
1171 req.Header.Set("Accept-Encoding", test.accept)
1172 }
1173 res, err := tr.RoundTrip(req)
1174 if err != nil {
1175 t.Errorf("%d. RoundTrip: %v", i, err)
1176 continue
1177 }
1178 var body []byte
1179 if test.compressed {
1180 var r *gzip.Reader
1181 r, err = gzip.NewReader(res.Body)
1182 if err != nil {
1183 t.Errorf("%d. gzip NewReader: %v", i, err)
1184 continue
1185 }
1186 body, err = io.ReadAll(r)
1187 res.Body.Close()
1188 } else {
1189 body, err = io.ReadAll(res.Body)
1190 }
1191 if err != nil {
1192 t.Errorf("%d. Error: %q", i, err)
1193 continue
1194 }
1195 if g, e := string(body), responseBody; g != e {
1196 t.Errorf("%d. body = %q; want %q", i, g, e)
1197 }
1198 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1199 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1200 }
1201 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1202 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1203 }
1204 }
1205
1206 }
1207
1208 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1209 func testTransportGzip(t *testing.T, mode testMode) {
1210 if mode == http2Mode {
1211 t.Skip("https://go.dev/issue/56020")
1212 }
1213 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1214 const nRandBytes = 1024 * 1024
1215 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1216 if req.Method == "HEAD" {
1217 if g := req.Header.Get("Accept-Encoding"); g != "" {
1218 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1219 }
1220 return
1221 }
1222 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1223 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1224 }
1225 rw.Header().Set("Content-Encoding", "gzip")
1226
1227 var w io.Writer = rw
1228 var buf bytes.Buffer
1229 if req.FormValue("chunked") == "0" {
1230 w = &buf
1231 defer io.Copy(rw, &buf)
1232 defer func() {
1233 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1234 }()
1235 }
1236 gz := gzip.NewWriter(w)
1237 gz.Write([]byte(testString))
1238 if req.FormValue("body") == "large" {
1239 io.CopyN(gz, rand.Reader, nRandBytes)
1240 }
1241 gz.Close()
1242 })).ts
1243 c := ts.Client()
1244
1245 for _, chunked := range []string{"1", "0"} {
1246
1247 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1248 if err != nil {
1249 t.Fatalf("large get: %v", err)
1250 }
1251 buf := make([]byte, len(testString))
1252 n, err := io.ReadFull(res.Body, buf)
1253 if err != nil {
1254 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1255 }
1256 if e, g := testString, string(buf); e != g {
1257 t.Errorf("partial read got %q, expected %q", g, e)
1258 }
1259 res.Body.Close()
1260
1261 n, err = res.Body.Read(buf)
1262 if n != 0 || err == nil {
1263 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1264 }
1265
1266
1267 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1268 if err != nil {
1269 t.Fatal(err)
1270 }
1271 body, err := io.ReadAll(res.Body)
1272 if err != nil {
1273 t.Fatal(err)
1274 }
1275 if g, e := string(body), testString; g != e {
1276 t.Fatalf("body = %q; want %q", g, e)
1277 }
1278 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1279 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1280 }
1281
1282
1283 n, err = res.Body.Read(buf)
1284 if n != 0 || err == nil {
1285 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1286 }
1287 res.Body.Close()
1288 n, err = res.Body.Read(buf)
1289 if n != 0 || err == nil {
1290 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1291 }
1292 }
1293
1294
1295 res, err := c.Head(ts.URL)
1296 if err != nil {
1297 t.Fatalf("Head: %v", err)
1298 }
1299 if res.StatusCode != 200 {
1300 t.Errorf("Head status=%d; want=200", res.StatusCode)
1301 }
1302 }
1303
1304
1305
// transport100ContinueTest holds the state of one scripted
// "Expect: 100-continue" exchange between a Transport (the client) and a raw
// connection driven by the test (the server).
type transport100ContinueTest struct {
	t *testing.T

	// reqdone is closed when the RoundTrip goroutine finishes. resp and
	// respErr hold RoundTrip's results and are read only after reqdone is
	// closed.
	reqdone chan struct{}
	resp    *Response
	respErr error

	// conn is the server side of the accepted connection. reader buffers
	// reads from conn; the request headers have already been consumed
	// from it.
	conn   net.Conn
	reader *bufio.Reader
}

// transport100ContinueTestBody is the request body the client sends.
const transport100ContinueTestBody = "request body"
1318
1319
1320
1321 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1322 ln := newLocalListener(t)
1323 defer ln.Close()
1324
1325 test := &transport100ContinueTest{
1326 t: t,
1327 reqdone: make(chan struct{}),
1328 }
1329
1330 tr := &Transport{
1331 ExpectContinueTimeout: timeout,
1332 }
1333 go func() {
1334 defer close(test.reqdone)
1335 body := strings.NewReader(transport100ContinueTestBody)
1336 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1337 req.Header.Set("Expect", "100-continue")
1338 req.ContentLength = int64(len(transport100ContinueTestBody))
1339 test.resp, test.respErr = tr.RoundTrip(req)
1340 test.resp.Body.Close()
1341 }()
1342
1343 c, err := ln.Accept()
1344 if err != nil {
1345 t.Fatalf("Accept: %v", err)
1346 }
1347 t.Cleanup(func() {
1348 c.Close()
1349 })
1350 br := bufio.NewReader(c)
1351 _, err = ReadRequest(br)
1352 if err != nil {
1353 t.Fatalf("ReadRequest: %v", err)
1354 }
1355 test.conn = c
1356 test.reader = br
1357 t.Cleanup(func() {
1358 <-test.reqdone
1359 tr.CloseIdleConnections()
1360 got, _ := io.ReadAll(test.reader)
1361 if len(got) > 0 {
1362 t.Fatalf("Transport sent unexpected bytes: %q", got)
1363 }
1364 })
1365
1366 return test
1367 }
1368
1369
1370 func (test *transport100ContinueTest) respond(lines ...string) {
1371 for _, line := range lines {
1372 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1373 test.t.Fatalf("Write: %v", err)
1374 }
1375 }
1376 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1377 test.t.Fatalf("Write: %v", err)
1378 }
1379 }
1380
1381
1382 func (test *transport100ContinueTest) wantBodySent() {
1383 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1384 if err != nil {
1385 test.t.Fatalf("unexpected error reading body: %v", err)
1386 }
1387 if got, want := string(got), transport100ContinueTestBody; got != want {
1388 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1389 }
1390 }
1391
1392
1393 func (test *transport100ContinueTest) wantRequestDone(want int) {
1394 <-test.reqdone
1395 if test.respErr != nil {
1396 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1397 }
1398 if got := test.resp.StatusCode; got != want {
1399 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1400 }
1401 }
1402
1403 func TestTransportExpect100ContinueSent(t *testing.T) {
1404 test := newTransport100ContinueTest(t, 1*time.Hour)
1405
1406 test.respond("HTTP/1.1 100 Continue")
1407 test.wantBodySent()
1408 test.respond("HTTP/1.1 200", "Content-Length: 0")
1409 test.wantRequestDone(200)
1410 }
1411
1412 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1413 test := newTransport100ContinueTest(t, 1*time.Hour)
1414
1415 test.respond("HTTP/1.1 200", "Content-Length: 0")
1416 test.wantBodySent()
1417 test.wantRequestDone(200)
1418 }
1419
1420 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1421 test := newTransport100ContinueTest(t, 1*time.Hour)
1422
1423 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1424 test.wantRequestDone(200)
1425 }
1426
1427 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1428 test := newTransport100ContinueTest(t, 1*time.Hour)
1429
1430 test.respond("HTTP/1.1 500", "Content-Length: 0")
1431 test.wantBodySent()
1432 test.wantRequestDone(500)
1433 }
1434
1435 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1436 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1437 test.wantBodySent()
1438 test.respond("HTTP/1.1 200", "Content-Length: 0")
1439 test.wantRequestDone(200)
1440 }
1441
1442 func TestSOCKS5Proxy(t *testing.T) {
1443 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1444 }
1445 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1446 ch := make(chan string, 1)
1447 l := newLocalListener(t)
1448 defer l.Close()
1449 defer close(ch)
1450 proxy := func(t *testing.T) {
1451 s, err := l.Accept()
1452 if err != nil {
1453 t.Errorf("socks5 proxy Accept(): %v", err)
1454 return
1455 }
1456 defer s.Close()
1457 var buf [22]byte
1458 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1459 t.Errorf("socks5 proxy initial read: %v", err)
1460 return
1461 }
1462 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1463 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1464 return
1465 }
1466 if _, err := s.Write([]byte{5, 0}); err != nil {
1467 t.Errorf("socks5 proxy initial write: %v", err)
1468 return
1469 }
1470 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1471 t.Errorf("socks5 proxy second read: %v", err)
1472 return
1473 }
1474 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1475 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1476 return
1477 }
1478 var ipLen int
1479 switch buf[3] {
1480 case 1:
1481 ipLen = net.IPv4len
1482 case 4:
1483 ipLen = net.IPv6len
1484 default:
1485 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1486 return
1487 }
1488 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1489 t.Errorf("socks5 proxy address read: %v", err)
1490 return
1491 }
1492 ip := net.IP(buf[4 : ipLen+4])
1493 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1494 copy(buf[:3], []byte{5, 0, 0})
1495 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1496 t.Errorf("socks5 proxy connect write: %v", err)
1497 return
1498 }
1499 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1500
1501
1502 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1503 targetConn, err := net.Dial("tcp", targetHost)
1504 if err != nil {
1505 t.Errorf("net.Dial failed")
1506 return
1507 }
1508 go io.Copy(targetConn, s)
1509 io.Copy(s, targetConn)
1510 targetConn.Close()
1511 }
1512
1513 pu, err := url.Parse("socks5://" + l.Addr().String())
1514 if err != nil {
1515 t.Fatal(err)
1516 }
1517
1518 sentinelHeader := "X-Sentinel"
1519 sentinelValue := "12345"
1520 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1521 w.Header().Set(sentinelHeader, sentinelValue)
1522 })
1523 for _, useTLS := range []bool{false, true} {
1524 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1525 ts := newClientServerTest(t, mode, h).ts
1526 go proxy(t)
1527 c := ts.Client()
1528 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1529 r, err := c.Head(ts.URL)
1530 if err != nil {
1531 t.Fatal(err)
1532 }
1533 if r.Header.Get(sentinelHeader) != sentinelValue {
1534 t.Errorf("Failed to retrieve sentinel value")
1535 }
1536 got := <-ch
1537 ts.Close()
1538 tsu, err := url.Parse(ts.URL)
1539 if err != nil {
1540 t.Fatal(err)
1541 }
1542 want := "proxy for " + tsu.Host
1543 if got != want {
1544 t.Errorf("got %q, want %q", got, want)
1545 }
1546 })
1547 }
1548 }
1549
// TestTransportProxy checks routing through an HTTP or HTTPS proxy for both
// HTTP and HTTPS sites. For a plain-HTTP site, the proxy itself receives the
// HEAD request with the absolute site URL. For an HTTPS site, the proxy
// receives a CONNECT for the site's host, and the HEAD then reaches the site
// through the tunnel.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			// siteCh receives each request handled by the destination site.
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			// proxyCh receives each request handled by the proxy.
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// For CONNECT, act as a tunneling proxy: hijack the client
				// connection, dial the target, reply 200 OK, and copy bytes
				// in both directions.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// The client must trust the certificate of whichever server it
			// talks TLS to. For an HTTPS site the site's client is used;
			// otherwise the proxy's client is used. NOTE(review): when both
			// use TLS this relies on the site's client also trusting the
			// proxy (presumably the same test certificate) — confirm.
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// The proxy should have received a CONNECT for the site's
				// host.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The site itself should have received the tunneled HEAD.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// Plain proxying: the proxy sees the HEAD with the absolute
				// site URL.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1666
1667
1668
1669
// TestProxyWithInfiniteHeader checks that a request through a proxy fails
// with an error when the proxy answers CONNECT with an endless stream of
// bytes containing no line terminator. The Transport's
// MaxResponseHeaderBytes is set to 1024 to bound how much of that response
// header it reads.
func TestProxyWithInfiniteHeader(t *testing.T) {
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	// cancelc is closed when the test returns, telling the fake proxy
	// goroutine to stop writing.
	cancelc := make(chan struct{})
	defer close(cancelc)

	// Fake proxy: accept one connection, read the CONNECT request, then keep
	// writing bytes that never form a complete response header.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Write until the test finishes; write errors are ignored.
		for {
			// Yield between writes. Presumably this keeps the tight loop
			// from starving other goroutines — TODO confirm.
			runtime.Gosched()
			select {
			case <-cancelc:
				return
			default:
				c.Write([]byte("infinite stream of bytes"))
			}
		}
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
			// Bound how many response header bytes are read.
			MaxResponseHeaderBytes: 1024,
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}
}
1732
1733 func TestOnProxyConnectResponse(t *testing.T) {
1734
1735 var tcases = []struct {
1736 proxyStatusCode int
1737 err error
1738 }{
1739 {
1740 StatusOK,
1741 nil,
1742 },
1743 {
1744 StatusForbidden,
1745 errors.New("403"),
1746 },
1747 }
1748 for _, tcase := range tcases {
1749 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1750
1751 })
1752
1753 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1754
1755 if r.Method == "CONNECT" {
1756 if tcase.proxyStatusCode != StatusOK {
1757 w.WriteHeader(tcase.proxyStatusCode)
1758 return
1759 }
1760 hijacker, ok := w.(Hijacker)
1761 if !ok {
1762 t.Errorf("hijack not allowed")
1763 return
1764 }
1765 clientConn, _, err := hijacker.Hijack()
1766 if err != nil {
1767 t.Errorf("hijacking failed")
1768 return
1769 }
1770 res := &Response{
1771 StatusCode: StatusOK,
1772 Proto: "HTTP/1.1",
1773 ProtoMajor: 1,
1774 ProtoMinor: 1,
1775 Header: make(Header),
1776 }
1777
1778 targetConn, err := net.Dial("tcp", r.URL.Host)
1779 if err != nil {
1780 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1781 return
1782 }
1783
1784 if err := res.Write(clientConn); err != nil {
1785 t.Errorf("Writing 200 OK failed: %v", err)
1786 return
1787 }
1788
1789 go io.Copy(targetConn, clientConn)
1790 go func() {
1791 io.Copy(clientConn, targetConn)
1792 targetConn.Close()
1793 }()
1794 }
1795 })
1796 ts := newClientServerTest(t, https1Mode, h1).ts
1797 proxy := newClientServerTest(t, https1Mode, h2).ts
1798
1799 pu, err := url.Parse(proxy.URL)
1800 if err != nil {
1801 t.Fatal(err)
1802 }
1803
1804 c := proxy.Client()
1805
1806 var (
1807 dials atomic.Int32
1808 closes atomic.Int32
1809 )
1810 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1811 conn, err := net.Dial(network, addr)
1812 if err != nil {
1813 return nil, err
1814 }
1815 dials.Add(1)
1816 return noteCloseConn{
1817 Conn: conn,
1818 closeFunc: func() {
1819 closes.Add(1)
1820 },
1821 }, nil
1822 }
1823
1824 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1825 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1826 if proxyURL.String() != pu.String() {
1827 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1828 }
1829
1830 if "https://"+connectReq.URL.String() != ts.URL {
1831 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1832 }
1833 return tcase.err
1834 }
1835 wantCloses := int32(0)
1836 if _, err := c.Head(ts.URL); err != nil {
1837 wantCloses = 1
1838 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1839 t.Errorf("got %v, want %v", err, tcase.err)
1840 }
1841 } else {
1842 if tcase.err != nil {
1843 t.Errorf("got %v, want nil", err)
1844 }
1845 }
1846 if got, want := dials.Load(), int32(1); got != want {
1847 t.Errorf("got %v dials, want %v", got, want)
1848 }
1849
1850 if got, want := closes.Load(), wantCloses; got != want {
1851 t.Errorf("got %v closes, want %v", got, want)
1852 }
1853 }
1854 }
1855
1856
1857
// TestTransportProxyHTTPSConnectLeak checks that when a proxy reads a
// CONNECT request but never answers it, canceling the CONNECT context makes
// the Transport fail the request and close its connection to the proxy. The
// proxy's next read must see io.EOF, so the connection is not leaked.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	// Install a hook (presumably used in place of the Transport's proxy
	// CONNECT timeout — see SetTestHookProxyConnectTimeout) whose context is
	// canceled as soon as cancelc is closed. This lets the test decide when
	// the CONNECT attempt gives up.
	cancelc := make(chan struct{})
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		ctx, cancel := context.WithCancel(ctx)
		go func() {
			select {
			case <-cancelc:
			case <-ctx.Done():
			}
			cancel()
		}()
		return ctx, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	// listenerDone is closed when the fake proxy goroutine exits.
	listenerDone := make(chan struct{})
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Never respond to the CONNECT. Cancel the client's CONNECT context
		// instead, then expect the client to close the connection.
		close(cancelc)
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to observe the close (or report otherwise)
	// before the test returns.
	<-listenerDone
}
1930
1931
1932 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1933 defer afterTest(t)
1934
1935 var errDial = errors.New("some dial error")
1936
1937 tr := &Transport{
1938 Proxy: func(*Request) (*url.URL, error) {
1939 return url.Parse("http://proxy.fake.tld/")
1940 },
1941 Dial: func(string, string) (net.Conn, error) {
1942 return nil, errDial
1943 },
1944 }
1945 defer tr.CloseIdleConnections()
1946
1947 c := &Client{Transport: tr}
1948 req, _ := NewRequest("GET", "http://fake.tld", nil)
1949 res, err := c.Do(req)
1950 if err == nil {
1951 res.Body.Close()
1952 t.Fatal("wanted a non-nil error")
1953 }
1954
1955 uerr, ok := err.(*url.Error)
1956 if !ok {
1957 t.Fatalf("got %T, want *url.Error", err)
1958 }
1959 oe, ok := uerr.Err.(*net.OpError)
1960 if !ok {
1961 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1962 }
1963 want := &net.OpError{
1964 Op: "proxyconnect",
1965 Net: "tcp",
1966 Err: errDial,
1967 }
1968 if !reflect.DeepEqual(oe, want) {
1969 t.Errorf("Got error %#v; want %#v", oe, want)
1970 }
1971 }
1972
1973
1974
1975
1976
1977 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1978 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1979 }
1980 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1981 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1982 defer proxy.Close()
1983 c := proxy.Client()
1984
1985 tr := c.Transport.(*Transport)
1986 tr.Proxy = func(*Request) (*url.URL, error) {
1987 u, _ := url.Parse(proxy.URL)
1988 u.User = url.UserPassword("aladdin", "opensesame")
1989 return u, nil
1990 }
1991 h := tr.ProxyConnectHeader
1992 if h == nil {
1993 h = make(Header)
1994 }
1995 tr.ProxyConnectHeader = h.Clone()
1996
1997 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1998 if err != nil {
1999 t.Fatal(err)
2000 }
2001 _, err = c.Do(req)
2002 if err == nil {
2003 t.Errorf("unexpected Get success")
2004 }
2005
2006 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
2007 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
2008 }
2009 }
2010
2011
2012
2013
2014
2015 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
2016 func testTransportGzipRecursive(t *testing.T, mode testMode) {
2017 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2018 w.Header().Set("Content-Encoding", "gzip")
2019 w.Write(rgz)
2020 })).ts
2021
2022 c := ts.Client()
2023 res, err := c.Get(ts.URL)
2024 if err != nil {
2025 t.Fatal(err)
2026 }
2027 body, err := io.ReadAll(res.Body)
2028 if err != nil {
2029 t.Fatal(err)
2030 }
2031 if !bytes.Equal(body, rgz) {
2032 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
2033 body, rgz)
2034 }
2035 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
2036 t.Fatalf("Content-Encoding = %q; want %q", g, e)
2037 }
2038 }
2039
2040
2041
2042 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
2043 func testTransportGzipShort(t *testing.T, mode testMode) {
2044 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2045 w.Header().Set("Content-Encoding", "gzip")
2046 w.Write([]byte{0x1f, 0x8b})
2047 })).ts
2048
2049 c := ts.Client()
2050 res, err := c.Get(ts.URL)
2051 if err != nil {
2052 t.Fatal(err)
2053 }
2054 defer res.Body.Close()
2055 _, err = io.ReadAll(res.Body)
2056 if err == nil {
2057 t.Fatal("Expect an error from reading a body.")
2058 }
2059 if err != io.ErrUnexpectedEOF {
2060 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
2061 }
2062 }
2063
2064
// waitNumGoroutine returns the current goroutine count, waiting for it to
// drop to nmax or below. While the count exceeds nmax, it sleeps 50ms and
// runs a GC before sampling again, up to 10 times. The last sample is
// returned either way.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
2074
2075
// TestTransportPersistConnLeak checks that after a burst of concurrent
// requests completes and idle connections are closed, the goroutine count
// returns to within 5 of its value before the requests started.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// Each handler signals gotReqCh and then blocks on unblockCh, so
	// requests that reach the server stay in flight until unblockCh is
	// closed.
	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count, taken before any request is issued.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until each request has either reached the handler or failed.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// The request reached the handler.
		case <-failed:
			// The client request failed. It was logged above and is not
			// treated as a test failure.
		}
	}

	// Peak goroutine count, reported only in the failure log.
	nhigh := runtime.NumGoroutine()

	// Release all blocked handlers.
	close(unblockCh)

	// Wait for every client goroutine to finish its request.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow up to 5 extra goroutines; more suggests connection goroutines
	// are being leaked.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
2147
2148
2149
// TestTransportPersistConnLeakShortBody checks that requests which fail
// because their declared ContentLength is shorter than the actual body do
// not leave goroutines behind. After 20 such requests and closing idle
// connections, the goroutine count must return to within 5 of its starting
// value.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// The handler does nothing; each request is expected to fail on the
	// client side.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare 3 bytes while the body supplies 5.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Tolerate up to 5 extra goroutines.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if int(growth) > 5 {
		t.Error("too many new goroutines")
	}
}
2190
2191
// countedConn wraps a net.Conn so that each dialed connection is a distinct
// heap object to which a GC cleanup can be attached.
type countedConn struct {
	net.Conn
}

// countingDialer provides a DialContext that records the total number of
// connections it has returned and how many of their wrappers are still live
// (not yet garbage collected).
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with the underlying net.Dialer and wraps the result in
// a countedConn. On success it increments total and live, and registers a
// cleanup that decrements live once the wrapper becomes unreachable. Dial
// errors are returned unchanged and nothing is counted.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++
	runtime.AddCleanup(wrapped, func(owner *countingDialer) { owner.decrement(nil) }, d)
	return wrapped, nil
}

// decrement lowers the live count by one. Its parameter is unused.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns a consistent snapshot of the total and live counts.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2232
// TestTransportPersistConnLeakNeverIdle checks that the Transport does not
// retain connections that never become idle. The server hijacks and
// immediately closes every connection, so every request fails. Eventually
// at least one dialed connection wrapper must be garbage collected
// (live < total).
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Close the connection without writing a response.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	// Track how many client connections were dialed and how many of their
	// wrappers are still reachable.
	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	// Issue failing requests, running the GC after each, until some
	// connection has been collected or 1<<12 iterations have passed.
	for i := 0; ; i++ {
		total, live := d.Read()
		if live < total {
			break
		}
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		runtime.GC()
	}
}
2273
// countedContext wraps a context.Context so that each tracked context is a
// distinct heap object to which a GC cleanup can be attached.
type countedContext struct {
	context.Context
}

// contextCounter counts how many contexts returned by Track are still live
// (their cleanup has not yet run).
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a new countedContext, increments the live count, and
// registers a cleanup that decrements the count once the wrapper becomes
// unreachable. The returned context behaves like ctx.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	counted := new(countedContext)
	counted.Context = ctx
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live++
	// The cleanup uses the counter passed as its argument rather than
	// capturing cc. It must not reference counted, or counted could never
	// become unreachable.
	runtime.AddCleanup(counted, func(c *contextCounter) { c.decrement(nil) }, cc)
	return counted
}

// decrement lowers the live count by one. Its parameter is unused.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live--
}

// Read returns the current live count.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	return cc.live
}
2304
2305 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2306 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2307 }
2308 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2309 if mode == http2Mode {
2310 t.Skip("https://go.dev/issue/56021")
2311 }
2312
2313 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2314 runtime.Gosched()
2315 w.WriteHeader(StatusOK)
2316 })).ts
2317
2318 c := ts.Client()
2319 c.Transport.(*Transport).MaxConnsPerHost = 1
2320
2321 ctx := context.Background()
2322 body := []byte("Hello")
2323 doPosts := func(cc *contextCounter) {
2324 var wg sync.WaitGroup
2325 for n := 64; n > 0; n-- {
2326 wg.Add(1)
2327 go func() {
2328 defer wg.Done()
2329
2330 ctx := cc.Track(ctx)
2331 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2332 if err != nil {
2333 t.Error(err)
2334 }
2335
2336 _, err = c.Do(req.WithContext(ctx))
2337 if err != nil {
2338 t.Errorf("Do failed with error: %v", err)
2339 }
2340 }()
2341 }
2342 wg.Wait()
2343 }
2344
2345 var initialCC contextCounter
2346 doPosts(&initialCC)
2347
2348
2349
2350
2351 var flushCC contextCounter
2352 for i := 0; ; i++ {
2353 live := initialCC.Read()
2354 if live == 0 {
2355 break
2356 }
2357 if i >= 100 {
2358 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2359 }
2360 doPosts(&flushCC)
2361 runtime.GC()
2362 }
2363 }
2364
2365
2366 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2367 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2368 var tr *Transport
2369
2370 unblockCh := make(chan bool, 1)
2371 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2372 <-unblockCh
2373 tr.CloseIdleConnections()
2374 })).ts
2375 c := ts.Client()
2376 tr = c.Transport.(*Transport)
2377
2378 didreq := make(chan bool)
2379 go func() {
2380 res, err := c.Get(ts.URL)
2381 if err != nil {
2382 t.Error(err)
2383 } else {
2384 res.Body.Close()
2385 }
2386 didreq <- true
2387 }()
2388 unblockCh <- true
2389 <-didreq
2390 }
2391
2392
2393
2394
2395
2396 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2397 func testIssue3644(t *testing.T, mode testMode) {
2398 const numFoos = 5000
2399 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2400 w.Header().Set("Connection", "close")
2401 for i := 0; i < numFoos; i++ {
2402 w.Write([]byte("foo "))
2403 }
2404 })).ts
2405 c := ts.Client()
2406 res, err := c.Get(ts.URL)
2407 if err != nil {
2408 t.Fatal(err)
2409 }
2410 defer res.Body.Close()
2411 bs, err := io.ReadAll(res.Body)
2412 if err != nil {
2413 t.Fatal(err)
2414 }
2415 if len(bs) != numFoos*len("foo ") {
2416 t.Errorf("unexpected response length")
2417 }
2418 }
2419
2420
2421
2422 func TestIssue3595(t *testing.T) {
2423
2424 run(t, testIssue3595, testNotParallel)
2425 }
2426 func testIssue3595(t *testing.T, mode testMode) {
2427 runTimeSensitiveTest(t, []time.Duration{
2428 1 * time.Millisecond,
2429 5 * time.Millisecond,
2430 10 * time.Millisecond,
2431 50 * time.Millisecond,
2432 100 * time.Millisecond,
2433 500 * time.Millisecond,
2434 time.Second,
2435 5 * time.Second,
2436 }, func(t *testing.T, timeout time.Duration) error {
2437 SetRSTAvoidanceDelay(t, timeout)
2438 t.Logf("set RST avoidance delay to %v", timeout)
2439
2440 const deniedMsg = "sorry, denied."
2441 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2442 Error(w, deniedMsg, StatusUnauthorized)
2443 }))
2444
2445
2446 defer cst.close()
2447 ts := cst.ts
2448 c := ts.Client()
2449
2450 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2451 if err != nil {
2452 return fmt.Errorf("Post: %v", err)
2453 }
2454 got, err := io.ReadAll(res.Body)
2455 if err != nil {
2456 return fmt.Errorf("Body ReadAll: %v", err)
2457 }
2458 t.Logf("server response:\n%s", got)
2459 if !strings.Contains(string(got), deniedMsg) {
2460
2461
2462 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2463 }
2464 return nil
2465 })
2466 }
2467
2468
2469
2470 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2471 func testChunkedNoContent(t *testing.T, mode testMode) {
2472 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2473 w.WriteHeader(StatusNoContent)
2474 })).ts
2475
2476 c := ts.Client()
2477 for _, closeBody := range []bool{true, false} {
2478 const n = 4
2479 for i := 1; i <= n; i++ {
2480 res, err := c.Get(ts.URL)
2481 if err != nil {
2482 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2483 } else {
2484 if closeBody {
2485 res.Body.Close()
2486 }
2487 }
2488 }
2489 }
2490 }
2491
2492 func TestTransportConcurrency(t *testing.T) {
2493 run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
2494 }
2495 func testTransportConcurrency(t *testing.T, mode testMode) {
2496
2497 maxProcs, numReqs := 16, 500
2498 if testing.Short() {
2499 maxProcs, numReqs = 4, 50
2500 }
2501 defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
2502 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2503 fmt.Fprintf(w, "%v", r.FormValue("echo"))
2504 })).ts
2505
2506 var wg sync.WaitGroup
2507 wg.Add(numReqs)
2508
2509
2510
2511
2512
2513
2514
2515 SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
2516 defer SetPendingDialHooks(nil, nil)
2517
2518 c := ts.Client()
2519 reqs := make(chan string)
2520 defer close(reqs)
2521
2522 for i := 0; i < maxProcs*2; i++ {
2523 go func() {
2524 for req := range reqs {
2525 res, err := c.Get(ts.URL + "/?echo=" + req)
2526 if err != nil {
2527 if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
2528
2529
2530 t.Logf("error on req %s: %v", req, err)
2531 t.Logf("(see https://go.dev/issue/52168)")
2532 } else {
2533 t.Errorf("error on req %s: %v", req, err)
2534 }
2535 wg.Done()
2536 continue
2537 }
2538 all, err := io.ReadAll(res.Body)
2539 if err != nil {
2540 t.Errorf("read error on req %s: %v", req, err)
2541 } else if string(all) != req {
2542 t.Errorf("body of req %s = %q; want %q", req, all, req)
2543 }
2544 res.Body.Close()
2545 wg.Done()
2546 }
2547 }()
2548 }
2549 for i := 0; i < numReqs; i++ {
2550 reqs <- fmt.Sprintf("request-%d", i)
2551 }
2552 wg.Wait()
2553 }
2554
2555 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2556 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2557 mux := NewServeMux()
2558 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2559 io.Copy(w, neverEnding('a'))
2560 })
2561 ts := newClientServerTest(t, mode, mux).ts
2562
2563 connc := make(chan net.Conn, 1)
2564 c := ts.Client()
2565 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2566 conn, err := net.Dial(n, addr)
2567 if err != nil {
2568 return nil, err
2569 }
2570 select {
2571 case connc <- conn:
2572 default:
2573 }
2574 return conn, nil
2575 }
2576
2577 res, err := c.Get(ts.URL + "/get")
2578 if err != nil {
2579 t.Fatalf("Error issuing GET: %v", err)
2580 }
2581 defer res.Body.Close()
2582
2583 conn := <-connc
2584 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2585 _, err = io.Copy(io.Discard, res.Body)
2586 if err == nil {
2587 t.Errorf("Unexpected successful copy")
2588 }
2589 }
2590
2591 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2592 run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2593 }
2594 func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2595 const debug = false
2596 mux := NewServeMux()
2597 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2598 io.Copy(w, neverEnding('a'))
2599 })
2600 mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2601 defer r.Body.Close()
2602 io.Copy(io.Discard, r.Body)
2603 })
2604 ts := newClientServerTest(t, mode, mux).ts
2605 timeout := 100 * time.Millisecond
2606
2607 c := ts.Client()
2608 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2609 conn, err := net.Dial(n, addr)
2610 if err != nil {
2611 return nil, err
2612 }
2613 conn.SetDeadline(time.Now().Add(timeout))
2614 if debug {
2615 conn = NewLoggingConn("client", conn)
2616 }
2617 return conn, nil
2618 }
2619
2620 getFailed := false
2621 nRuns := 5
2622 if testing.Short() {
2623 nRuns = 1
2624 }
2625 for i := 0; i < nRuns; i++ {
2626 if debug {
2627 println("run", i+1, "of", nRuns)
2628 }
2629 sres, err := c.Get(ts.URL + "/get")
2630 if err != nil {
2631 if !getFailed {
2632
2633 getFailed = true
2634 t.Logf("increasing timeout")
2635 i--
2636 timeout *= 10
2637 continue
2638 }
2639 t.Errorf("Error issuing GET: %v", err)
2640 break
2641 }
2642 req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2643 _, err = c.Do(req)
2644 if err == nil {
2645 sres.Body.Close()
2646 t.Errorf("Unexpected successful PUT")
2647 break
2648 }
2649 sres.Body.Close()
2650 }
2651 if debug {
2652 println("tests complete; waiting for handlers to finish")
2653 }
2654 ts.Close()
2655 }
2656
2657 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2658 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2659 if testing.Short() {
2660 t.Skip("skipping timeout test in -short mode")
2661 }
2662
2663 timeout := 2 * time.Millisecond
2664 retry := true
2665 for retry && !t.Failed() {
2666 var srvWG sync.WaitGroup
2667 inHandler := make(chan bool, 1)
2668 mux := NewServeMux()
2669 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2670 inHandler <- true
2671 srvWG.Done()
2672 })
2673 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2674 inHandler <- true
2675 <-r.Context().Done()
2676 srvWG.Done()
2677 })
2678 ts := newClientServerTest(t, mode, mux).ts
2679
2680 c := ts.Client()
2681 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2682
2683 retry = false
2684 srvWG.Add(3)
2685 tests := []struct {
2686 path string
2687 wantTimeout bool
2688 }{
2689 {path: "/fast"},
2690 {path: "/slow", wantTimeout: true},
2691 {path: "/fast"},
2692 }
2693 for i, tt := range tests {
2694 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2695 req = req.WithT(t)
2696 res, err := c.Do(req)
2697 <-inHandler
2698 if err != nil {
2699 uerr, ok := err.(*url.Error)
2700 if !ok {
2701 t.Errorf("error is not a url.Error; got: %#v", err)
2702 continue
2703 }
2704 nerr, ok := uerr.Err.(net.Error)
2705 if !ok {
2706 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2707 continue
2708 }
2709 if !nerr.Timeout() {
2710 t.Errorf("want timeout error; got: %q", nerr)
2711 continue
2712 }
2713 if !tt.wantTimeout {
2714 if !retry {
2715
2716 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2717 timeout *= 2
2718 retry = true
2719 }
2720 }
2721 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2722 t.Errorf("%d. unexpected error: %v", i, err)
2723 }
2724 continue
2725 }
2726 if tt.wantTimeout {
2727 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2728 continue
2729 }
2730 if res.StatusCode != 200 {
2731 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2732 }
2733 }
2734
2735 srvWG.Wait()
2736 ts.Close()
2737 }
2738 }
2739
2740
// cancelTest describes one mechanism for canceling an in-flight request, so
// that a test body can be run once per mechanism.
type cancelTest struct {
	// mode is the client/server test mode the test body should use.
	mode testMode
	// newReq prepares req so that cancel can abort it (for example by
	// attaching a cancel channel or context) and returns the request to send.
	newReq func(req *Request) *Request
	// cancel cancels req. tr is the Transport that sent it; only the
	// Transport.CancelRequest variant uses it.
	cancel func(tr *Transport, req *Request)
	// checkErr reports a test error unless err is the cancellation error this
	// mechanism is expected to produce. when labels the failing operation.
	checkErr func(when string, err error)
}
2747
2748
2749 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2750 t.Run("TransportCancel", func(t *testing.T) {
2751 f(t, cancelTest{
2752 mode: mode,
2753 newReq: func(req *Request) *Request {
2754 return req
2755 },
2756 cancel: func(tr *Transport, req *Request) {
2757 tr.CancelRequest(req)
2758 },
2759 checkErr: func(when string, err error) {
2760 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2761 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2762 }
2763 },
2764 })
2765 })
2766 }
2767
2768
2769 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2770 cancelc := make(chan struct{})
2771 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2772 f(t, cancelTest{
2773 mode: mode,
2774 newReq: func(req *Request) *Request {
2775 req.Cancel = cancelc
2776 return req
2777 },
2778 cancel: func(tr *Transport, req *Request) {
2779 cancelOnce()
2780 },
2781 checkErr: func(when string, err error) {
2782 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2783 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2784 }
2785 },
2786 })
2787 }
2788
2789
2790 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2791 ctx, cancel := context.WithCancel(context.Background())
2792 f(t, cancelTest{
2793 mode: mode,
2794 newReq: func(req *Request) *Request {
2795 return req.WithContext(ctx)
2796 },
2797 cancel: func(tr *Transport, req *Request) {
2798 cancel()
2799 },
2800 checkErr: func(when string, err error) {
2801 if !errors.Is(err, context.Canceled) {
2802 t.Errorf("%v error = %v, want context.Canceled", when, err)
2803 }
2804 },
2805 })
2806 }
2807
2808 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2809 run(t, func(t *testing.T, mode testMode) {
2810 if mode == http1Mode {
2811 t.Run("TransportCancel", func(t *testing.T) {
2812 runCancelTestTransport(t, mode, f)
2813 })
2814 }
2815 t.Run("RequestCancel", func(t *testing.T) {
2816 runCancelTestChannel(t, mode, f)
2817 })
2818 t.Run("ContextCancel", func(t *testing.T) {
2819 runCancelTestContext(t, mode, f)
2820 })
2821 }, opts...)
2822 }
2823
2824 func TestTransportCancelRequest(t *testing.T) {
2825 runCancelTest(t, testTransportCancelRequest)
2826 }
2827 func testTransportCancelRequest(t *testing.T, test cancelTest) {
2828 if testing.Short() {
2829 t.Skip("skipping test in -short mode")
2830 }
2831
2832 const msg = "Hello"
2833 unblockc := make(chan bool)
2834 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2835 io.WriteString(w, msg)
2836 w.(Flusher).Flush()
2837 <-unblockc
2838 })).ts
2839 defer close(unblockc)
2840
2841 c := ts.Client()
2842 tr := c.Transport.(*Transport)
2843
2844 req, _ := NewRequest("GET", ts.URL, nil)
2845 req = test.newReq(req)
2846 res, err := c.Do(req)
2847 if err != nil {
2848 t.Fatal(err)
2849 }
2850 body := make([]byte, len(msg))
2851 n, _ := io.ReadFull(res.Body, body)
2852 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2853 t.Errorf("Body = %q; want %q", body[:n], msg)
2854 }
2855 test.cancel(tr, req)
2856
2857 tail, err := io.ReadAll(res.Body)
2858 res.Body.Close()
2859 test.checkErr("Body.Read", err)
2860 if len(tail) > 0 {
2861 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2862 }
2863
2864
2865
2866 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2867 n := tr.NumPendingRequestsForTesting()
2868 if n > 0 {
2869 if d > 0 {
2870 t.Logf("pending requests = %d after %v (want 0)", n, d)
2871 }
2872 return false
2873 }
2874 return true
2875 })
2876 }
2877
2878 func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2879 if testing.Short() {
2880 t.Skip("skipping test in -short mode")
2881 }
2882 unblockc := make(chan bool)
2883 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2884 <-unblockc
2885 })).ts
2886 defer close(unblockc)
2887
2888 c := ts.Client()
2889 tr := c.Transport.(*Transport)
2890
2891 donec := make(chan bool)
2892 req, _ := NewRequest("GET", ts.URL, body)
2893 req = test.newReq(req)
2894 go func() {
2895 defer close(donec)
2896 c.Do(req)
2897 }()
2898
2899 unblockc <- true
2900 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2901 test.cancel(tr, req)
2902 select {
2903 case <-donec:
2904 return true
2905 default:
2906 if d > 0 {
2907 t.Logf("Do of canceled request has not returned after %v", d)
2908 }
2909 return false
2910 }
2911 })
2912 }
2913
2914 func TestTransportCancelRequestInDo(t *testing.T) {
2915 runCancelTest(t, func(t *testing.T, test cancelTest) {
2916 testTransportCancelRequestInDo(t, test, nil)
2917 })
2918 }
2919
2920 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2921 runCancelTest(t, func(t *testing.T, test cancelTest) {
2922 testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2923 })
2924 }
2925
// TestTransportCancelRequestInDial cancels a request while the transport is
// blocked inside Dial. It checks that Do returns the mechanism's cancellation
// error without waiting for the dial to finish.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records the order of events. It is compared with the expected
	// sequence at the end, and dumped if the test hangs.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// The dialer stays blocked on unblockDial until the test function returns.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is the handshake between the dialer and the test: the test's
	// send below completes only once Dial is running. Receiving false (or a
	// zero value) makes the dialer give up immediately.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until the request is blocked in Dial.
	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice: a repeated cancel must be harmless.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// If Do never returns, panic shortly before the test deadline so the
		// failure includes the event log instead of a bare timeout.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2988
2989
2990 func TestTransportCancelRequestWithBody(t *testing.T) {
2991 runCancelTest(t, testTransportCancelRequestWithBody)
2992 }
2993 func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
2994 if testing.Short() {
2995 t.Skip("skipping test in -short mode")
2996 }
2997
2998 const msg = "Hello"
2999 unblockc := make(chan struct{})
3000 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3001 io.WriteString(w, msg)
3002 w.(Flusher).Flush()
3003 <-unblockc
3004 })).ts
3005 defer close(unblockc)
3006
3007 c := ts.Client()
3008 tr := c.Transport.(*Transport)
3009
3010 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
3011 req = test.newReq(req)
3012
3013 res, err := c.Do(req)
3014 if err != nil {
3015 t.Fatal(err)
3016 }
3017 body := make([]byte, len(msg))
3018 n, _ := io.ReadFull(res.Body, body)
3019 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
3020 t.Errorf("Body = %q; want %q", body[:n], msg)
3021 }
3022 test.cancel(tr, req)
3023
3024 tail, err := io.ReadAll(res.Body)
3025 res.Body.Close()
3026 test.checkErr("Body.Read", err)
3027 if len(tail) > 0 {
3028 t.Errorf("Spurious bytes from Body.Read: %q", tail)
3029 }
3030
3031
3032
3033 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
3034 n := tr.NumPendingRequestsForTesting()
3035 if n > 0 {
3036 if d > 0 {
3037 t.Logf("pending requests = %d after %v (want 0)", n, d)
3038 }
3039 return false
3040 }
3041 return true
3042 })
3043 }
3044
3045 func TestTransportCancelRequestBeforeDo(t *testing.T) {
3046
3047 run(t, func(t *testing.T, mode testMode) {
3048 t.Run("RequestCancel", func(t *testing.T) {
3049 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
3050 })
3051 t.Run("ContextCancel", func(t *testing.T) {
3052 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
3053 })
3054 })
3055 }
3056 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
3057 unblockc := make(chan bool)
3058 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3059 <-unblockc
3060 }))
3061 defer close(unblockc)
3062
3063 c := cst.ts.Client()
3064
3065 req, _ := NewRequest("GET", cst.ts.URL, nil)
3066 req = test.newReq(req)
3067 test.cancel(cst.tr, req)
3068
3069 _, err := c.Do(req)
3070 test.checkErr("Do", err)
3071 }
3072
3073
3074 func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
3075 runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
3076 }
3077 func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
3078 defer afterTest(t)
3079
3080 serverConnCh := make(chan net.Conn, 1)
3081 tr := &Transport{
3082 Dial: func(network, addr string) (net.Conn, error) {
3083 cc, sc := net.Pipe()
3084 serverConnCh <- sc
3085 return cc, nil
3086 },
3087 }
3088 defer tr.CloseIdleConnections()
3089 errc := make(chan error, 1)
3090 req, _ := NewRequest("GET", "http://example.com/", nil)
3091 req = test.newReq(req)
3092 go func() {
3093 _, err := tr.RoundTrip(req)
3094 errc <- err
3095 }()
3096
3097 sc := <-serverConnCh
3098 verb := make([]byte, 3)
3099 if _, err := io.ReadFull(sc, verb); err != nil {
3100 t.Errorf("Error reading HTTP verb from server: %v", err)
3101 }
3102 if string(verb) != "GET" {
3103 t.Errorf("server received %q; want GET", verb)
3104 }
3105 defer sc.Close()
3106
3107 test.cancel(tr, req)
3108
3109 err := <-errc
3110 if err == nil {
3111 t.Fatalf("unexpected success from RoundTrip")
3112 }
3113 test.checkErr("RoundTrip", err)
3114 }
3115
3116
3117
3118
3119 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
3120 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
3121 writeErr := make(chan error, 1)
3122 msg := []byte("young\n")
3123 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3124 for {
3125 _, err := w.Write(msg)
3126 if err != nil {
3127 writeErr <- err
3128 return
3129 }
3130 w.(Flusher).Flush()
3131 }
3132 })).ts
3133
3134 c := ts.Client()
3135 tr := c.Transport.(*Transport)
3136
3137 req, _ := NewRequest("GET", ts.URL, nil)
3138 defer tr.CancelRequest(req)
3139
3140 res, err := c.Do(req)
3141 if err != nil {
3142 t.Fatal(err)
3143 }
3144
3145 const repeats = 3
3146 buf := make([]byte, len(msg)*repeats)
3147 want := bytes.Repeat(msg, repeats)
3148
3149 _, err = io.ReadFull(res.Body, buf)
3150 if err != nil {
3151 t.Fatal(err)
3152 }
3153 if !bytes.Equal(buf, want) {
3154 t.Fatalf("read %q; want %q", buf, want)
3155 }
3156
3157 if err := res.Body.Close(); err != nil {
3158 t.Errorf("Close = %v", err)
3159 }
3160
3161 if err := <-writeErr; err == nil {
3162 t.Errorf("expected non-nil write error")
3163 }
3164 }
3165
// fooProto serves the custom "foo" scheme without any network I/O. Every
// request gets a 200 response whose body names the requested URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
3177
3178 func TestTransportAltProto(t *testing.T) {
3179 defer afterTest(t)
3180 tr := &Transport{}
3181 c := &Client{Transport: tr}
3182 tr.RegisterProtocol("foo", fooProto{})
3183 res, err := c.Get("foo://bar.com/path")
3184 if err != nil {
3185 t.Fatal(err)
3186 }
3187 bodyb, err := io.ReadAll(res.Body)
3188 if err != nil {
3189 t.Fatal(err)
3190 }
3191 body := string(bodyb)
3192 if e := "You wanted foo://bar.com/path"; body != e {
3193 t.Errorf("got response %q, want %q", body, e)
3194 }
3195 }
3196
3197 func TestTransportNoHost(t *testing.T) {
3198 defer afterTest(t)
3199 tr := &Transport{}
3200 _, err := tr.RoundTrip(&Request{
3201 Header: make(Header),
3202 URL: &url.URL{
3203 Scheme: "http",
3204 },
3205 })
3206 want := "http: no Host in request URL"
3207 if got := fmt.Sprint(err); got != want {
3208 t.Errorf("error = %v; want %q", err, want)
3209 }
3210 }
3211
3212
// TestTransportEmptyMethod checks that an outgoing Request with an empty
// Method is written as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // must be treated as GET
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if text := string(dump); !strings.Contains(text, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3224
// TestTransportSocketLateBinding checks that a request waiting on a new dial
// can use an existing connection that becomes idle first. /bar is expected to
// be served on the same connection as /foo (same RemoteAddr).
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate holds /foo's handler open, keeping its connection busy, until
	// the test lets it finish.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Each value received from dialGate decides one dial: true performs a
	// real dial, and false (or the close deferred below) fails it. While a
	// dial waits for its gate value, it keeps announcing itself on dialing.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Permit the first dial, used by /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Wait until /bar has had the chance to start its own dial, then let
		// /foo finish so its connection becomes idle and /bar can take it.
		// In HTTP/1 the busy /foo connection cannot carry /bar, so a second
		// dial starts and blocks on dialGate; a receive on dialing confirms it.
		if mode == http2Mode {
			// HTTP/2 is expected to send /bar on the existing /foo
			// connection, so a second dial is an error. Give one a short
			// window to show up.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3304
3305
// TestTransportReading100Continue sends numReqs POSTs through a Transport
// whose connections are served by a hand-written fake server. The fake server
// answers each request with an interim "100 Continue" response followed by a
// 200. The test checks that the client returns the 200 and that the
// Echo-Request-Id header matches the request.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the fake server for one connection. It reads
	// requests from r and writes raw HTTP responses to w. It stops at EOF,
	// on an error, or after answering the request with the last request ID.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		// n is the 1-based index of the request on this connection. The body
		// check below assumes the n-th request carries reqBody(n), i.e. that
		// all requests share one connection (keep-alives are not disabled).
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			// X-Want-Response-Code, if set, replaces the interim status line
			// and skips the body check. None of the requests below set it.
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				// NOTE(review): err is always nil here; it is printed only
				// for context.
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// The interim response followed by the final 200, with LF line
			// endings converted to CRLF.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Each dial builds an in-memory connection from two pipes. The client
		// reads (cr) what the fake server writes to cw, and the fake server
		// reads (sr) what the client writes to sw.
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe()
			cr, cw := io.Pipe()
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks the status code and the echoed
	// Request-Id, then drains the body. name labels any failure.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Each request gets a "100 Continue" before its final 200.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3405
3406
3407
3408 func TestTransportIgnore1xxResponses(t *testing.T) {
3409 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3410 }
3411 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3412 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3413 conn, buf, _ := w.(Hijacker).Hijack()
3414 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3415 buf.Flush()
3416 conn.Close()
3417 }))
3418 cst.tr.DisableKeepAlives = true
3419
3420 var got strings.Builder
3421
3422 req, _ := NewRequest("GET", cst.ts.URL, nil)
3423 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3424 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3425 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3426 return nil
3427 },
3428 }))
3429 res, err := cst.c.Do(req)
3430 if err != nil {
3431 t.Fatal(err)
3432 }
3433 defer res.Body.Close()
3434
3435 res.Write(&got)
3436 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3437 if got.String() != want {
3438 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3439 }
3440 }
3441
3442 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3443 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3444 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3445 w.Header().Add("X-Header", strings.Repeat("a", 100))
3446 for i := 0; i < 10; i++ {
3447 w.WriteHeader(123)
3448 }
3449 w.WriteHeader(204)
3450 }))
3451 cst.tr.DisableKeepAlives = true
3452 cst.tr.MaxResponseHeaderBytes = 1000
3453
3454 res, err := cst.c.Get(cst.ts.URL)
3455 if err == nil {
3456 res.Body.Close()
3457 t.Fatalf("RoundTrip succeeded; want error")
3458 }
3459 for _, want := range []string{
3460 "response headers exceeded",
3461 "too many 1xx",
3462 "header list too large",
3463 } {
3464 if strings.Contains(err.Error(), want) {
3465 return
3466 }
3467 }
3468 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3469 }
3470
3471 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3472 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3473 }
3474 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3475 if mode == http2Mode {
3476 t.Skip("skip until x/net/http2 updated")
3477 }
3478 const num1xx = 10
3479 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3480 w.Header().Add("X-Header", strings.Repeat("a", 100))
3481 for i := 0; i < 10; i++ {
3482 w.WriteHeader(123)
3483 }
3484 w.WriteHeader(204)
3485 }))
3486 cst.tr.DisableKeepAlives = true
3487 cst.tr.MaxResponseHeaderBytes = 1000
3488
3489 got1xx := 0
3490 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3491 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3492 got1xx++
3493 return nil
3494 },
3495 })
3496 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3497 res, err := cst.c.Do(req)
3498 if err != nil {
3499 t.Fatal(err)
3500 }
3501 res.Body.Close()
3502 if got1xx != num1xx {
3503 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3504 }
3505 }
3506
3507
3508
3509 func TestTransportTreat101Terminal(t *testing.T) {
3510 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3511 }
3512 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3513 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3514 conn, buf, _ := w.(Hijacker).Hijack()
3515 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3516 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3517 buf.Flush()
3518 conn.Close()
3519 }))
3520 res, err := cst.c.Get(cst.ts.URL)
3521 if err != nil {
3522 t.Fatal(err)
3523 }
3524 defer res.Body.Close()
3525 if res.StatusCode != StatusSwitchingProtocols {
3526 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3527 }
3528 }
3529
// proxyFromEnvTest is one ProxyFromEnvironment case: the environment to set,
// the request URL to resolve, and the expected result.
type proxyFromEnvTest struct {
	req string // URL to request; "http://example.com" if empty

	env      string // value for HTTP_PROXY / http_proxy
	httpsenv string // value for HTTPS_PROXY / https_proxy
	noenv    string // value for NO_PROXY / no_proxy
	reqmeth  string // value for REQUEST_METHOD

	want    string // expected proxy URL formatted with %s; "<nil>" for no proxy
	wanterr error  // expected error, compared by its message
}

// String describes the case as space-separated key="value" pairs, omitting
// empty environment settings and always ending with the request URL.
func (tt proxyFromEnvTest) String() string {
	var parts []string
	add := func(key, val string) {
		if val != "" {
			parts = append(parts, fmt.Sprintf("%s=%q", key, val))
		}
	}
	add("http_proxy", tt.env)
	add("https_proxy", tt.httpsenv)
	add("no_proxy", tt.noenv)
	add("request_method", tt.reqmeth)

	target := tt.req
	if target == "" {
		target = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", target))
	return strings.Join(parts, " ")
}
3572
// proxyFromEnvTests lists the ProxyFromEnvironment cases shared by the
// upper-case and lower-case environment variable tests. A want of "<nil>"
// means no proxy is used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// HTTP_PROXY values with and without a scheme; a missing scheme yields http.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http request uses HTTP_PROXY, not HTTPS_PROXY.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https request uses HTTPS_PROXY; without a scheme it defaults to http.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// With REQUEST_METHOD set (a CGI environment), HTTP_PROXY is refused
	// with an error rather than used.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy configured at all.
	{want: "<nil>"},

	// NO_PROXY matching: a domain matches itself and its subdomains, while a
	// leading-dot entry, a partial label, or an unrelated domain does not
	// match the bare request host.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3603
// testProxyForRequest builds a GET request for tt.req (or
// "http://example.com" when tt.req is empty), resolves its proxy with
// proxyForRequest, and reports a test error if the error or the proxy URL
// differs from tt.wanterr / tt.want. Both are compared by their printed
// forms, so "no proxy" is expected as "<nil>".
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
	t.Helper()
	reqURL := tt.req
	if reqURL == "" {
		reqURL = "http://example.com"
	}
	req, err := NewRequest("GET", reqURL, nil)
	if err != nil {
		// A malformed table entry must not silently pass a nil request on.
		t.Fatalf("%v: NewRequest(%q): %v", tt, reqURL, err)
	}
	// Named proxyURL rather than url so the net/url package is not shadowed.
	proxyURL, err := proxyForRequest(req)
	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
		t.Errorf("%v: got error = %q, want %q", tt, g, e)
		return
	}
	// Report exactly the string that was compared (a nil *url.URL prints as "<nil>").
	if got := fmt.Sprintf("%s", proxyURL); got != tt.want {
		t.Errorf("%v: got URL = %q, want %q", tt, got, tt.want)
	}
}
3620
// TestProxyFromEnvironment runs every proxyFromEnvTests case through
// ProxyFromEnvironment, configuring it with the upper-case HTTP_PROXY,
// HTTPS_PROXY and NO_PROXY variables plus REQUEST_METHOD.
func TestProxyFromEnvironment(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			os.Setenv("HTTP_PROXY", tc.env)
			os.Setenv("HTTPS_PROXY", tc.httpsenv)
			os.Setenv("NO_PROXY", tc.noenv)
			os.Setenv("REQUEST_METHOD", tc.reqmeth)
			// Presumably discards environment values cached by an earlier case.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}
3635
// TestProxyFromEnvironmentLowerCase is TestProxyFromEnvironment using the
// lower-case http_proxy, https_proxy and no_proxy variables instead.
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			os.Setenv("http_proxy", tc.env)
			os.Setenv("https_proxy", tc.httpsenv)
			os.Setenv("no_proxy", tc.noenv)
			os.Setenv("REQUEST_METHOD", tc.reqmeth)
			// Presumably discards environment values cached by an earlier case.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}
3650
// TestIdleConnChannelLeak sends requests to several distinct fake host names,
// first with keep-alives disabled and then enabled, and checks after each
// round that the Transport's idle-connection wait map is empty. It is marked
// testNotParallel because it installs a hook with
// SetReadLoopBeforeNextReadHook (presumably package-level state).
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// n counts handler invocations under mu; it is not checked below.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// Buffered (capacity nReqs) so the hook can signal without waiting for
	// the test to receive.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Ignore the requested address and always dial the test server, so the
	// distinct fake host names below all reach the same listener.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// One round with keep-alives disabled, then one with them enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response body is intentionally not closed.
			// NOTE(review): this appears to rely on the handler writing no
			// body, so no unread response data is left behind; confirm
			// before changing the handler.
		}

		// Wait until each request's read loop has reached the hook before
		// inspecting the map. Presumably this keeps the check from racing
		// with the Transport's own per-connection bookkeeping.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3706
3707
3708
3709
3710 func TestTransportClosesRequestBody(t *testing.T) {
3711 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3712 }
3713 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3714 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3715 io.Copy(io.Discard, r.Body)
3716 })).ts
3717
3718 c := ts.Client()
3719
3720 closes := 0
3721
3722 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3723 if err != nil {
3724 t.Fatal(err)
3725 }
3726 res.Body.Close()
3727 if closes != 1 {
3728 t.Errorf("closes = %d; want 1", closes)
3729 }
3730 }
3731
3732 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3733 defer afterTest(t)
3734 if testing.Short() {
3735 t.Skip("skipping in short mode")
3736 }
3737 ln := newLocalListener(t)
3738 defer ln.Close()
3739 testdonec := make(chan struct{})
3740 defer close(testdonec)
3741
3742 go func() {
3743 c, err := ln.Accept()
3744 if err != nil {
3745 t.Error(err)
3746 return
3747 }
3748 <-testdonec
3749 c.Close()
3750 }()
3751
3752 tr := &Transport{
3753 Dial: func(_, _ string) (net.Conn, error) {
3754 return net.Dial("tcp", ln.Addr().String())
3755 },
3756 TLSHandshakeTimeout: 250 * time.Millisecond,
3757 }
3758 cl := &Client{Transport: tr}
3759 _, err := cl.Get("https://dummy.tld/")
3760 if err == nil {
3761 t.Error("expected error")
3762 return
3763 }
3764 ue, ok := err.(*url.Error)
3765 if !ok {
3766 t.Errorf("expected url.Error; got %#v", err)
3767 return
3768 }
3769 ne, ok := ue.Err.(net.Error)
3770 if !ok {
3771 t.Errorf("expected net.Error; got %#v", err)
3772 return
3773 }
3774 if !ne.Timeout() {
3775 t.Errorf("expected timeout error; got %v", err)
3776 }
3777 if !strings.Contains(err.Error(), "handshake timeout") {
3778 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3779 }
3780 }
3781
3782
// TestTLSServerClosesConnection has the server write a complete keep-alive
// response and then close the TLS connection. It checks that a follow-up
// request on the same client succeeds in at least one of several trials;
// failures of the follow-up request are logged, not fatal.
func TestTLSServerClosesConnection(t *testing.T) {
	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
}
func testTLSServerClosesConnection(t *testing.T, mode testMode) {
	closedc := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
			// Write a complete HTTP/1.1 (keep-alive by default) response by
			// hand, then close the connection behind the client's back.
			conn, _, _ := w.(Hijacker).Hijack()
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
			conn.Close()
			closedc <- true
			return
		}
		fmt.Fprintf(w, "hello")
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	nSuccess := 0
	var errs []error
	const trials = 20
	for i := 0; i < trials; i++ {
		tr.CloseIdleConnections()
		res, err := c.Get(ts.URL + "/keep-alive-then-die")
		if err != nil {
			t.Fatal(err)
		}
		<-closedc
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			t.Fatal(err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Got %q, want foo", slurp)
		}

		// Issue a second request. The connection used above was closed by
		// the server, so success here means the client coped with that.
		res, err = c.Get(ts.URL + "/")
		if err != nil {
			errs = append(errs, err)
			continue
		}
		_, err = io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			errs = append(errs, err)
			continue
		}
		nSuccess++
	}
	if nSuccess > 0 {
		t.Logf("successes = %d of %d", nSuccess, trials)
	} else {
		t.Errorf("All runs failed:")
	}
	for _, err := range errs {
		t.Logf("  err: %v", err)
	}
}
3843
3844
3845
3846
3847 type byteFromChanReader chan byte
3848
3849 func (c byteFromChanReader) Read(p []byte) (n int, err error) {
3850 if len(p) == 0 {
3851 return
3852 }
3853 b, ok := <-c
3854 if !ok {
3855 return 0, io.EOF
3856 }
3857 p[0] = b
3858 return 1, nil
3859 }
3860
3861
3862
3863
3864
3865
3866
// TestTransportNoReuseAfterEarlyResponse sends a large POST whose last body
// byte is withheld, while the handler hijacks the connection and replies
// immediately. It then checks that a following GET on the same client
// succeeds. It is marked testNotParallel because it temporarily lowers
// *MaxWriteWaitBeforeConnReuse (presumably package-level state).
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten MaxWriteWaitBeforeConnReuse to 10ms; restore the old value on return.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection so cleanup can close it.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET succeeded; after that, closing the server
	// connection is expected and is not logged.
	var getOkay bool
	// willCopy tracks the goroutine that drains the hijacked connection.
	var willCopy sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On return, close the server connection (ending the drain goroutine's
	// io.Copy) and wait for that goroutine to finish.
	defer func() {
		closeConn()
		willCopy.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// Non-GET (the POST): take over the connection, send a complete
		// response without reading the request body first, then drain the
		// body in the background.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()

		willCopy.Add(1)
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
		go func() {
			io.Copy(io.Discard, conn)
			willCopy.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize-1 'x' bytes followed by one byte from
	// finalBit. finalBit is only fed after the GET below completes, so the
	// request body cannot finish being written before then.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Release the final byte so the POST body can complete.
	finalBit <- 'x'
	close(finalBit)
}
3934
3935
3936
3937 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3938 func testTransportIssue10457(t *testing.T, mode testMode) {
3939 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3940
3941
3942
3943
3944
3945 conn, _, _ := w.(Hijacker).Hijack()
3946 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3947 conn.Close()
3948 })).ts
3949 c := ts.Client()
3950
3951 res, err := c.Get(ts.URL)
3952 if err != nil {
3953 t.Fatalf("Get: %v", err)
3954 }
3955 defer res.Body.Close()
3956
3957
3958
3959
3960 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3961 t.Errorf("Foo header = %q; want %q", got, want)
3962 }
3963 }
3964
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close invokes the wrapped function and returns its error unchanged.
func (fn closerFunc) Close() error {
	err := fn()
	return err
}
3968
3969 type writerFuncConn struct {
3970 net.Conn
3971 write func(p []byte) (n int, err error)
3972 }
3973
3974 func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
3986
3987
// TestRetryRequestsOnError makes the second write on the client's connections
// fail and checks that the Transport retries the request on a fresh
// connection. It covers idempotent requests that reported a partial write and
// requests whose failed write reported nothing written. It is marked
// testNotParallel; it installs a hook via SetRoundTripRetried.
func TestRetryRequestsOnError(t *testing.T) {
	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
}
func testRetryRequestsOnError(t *testing.T, mode testMode) {
	newRequest := func(method, urlStr string, body io.Reader) *Request {
		req, err := NewRequest(method, urlStr, body)
		if err != nil {
			t.Fatal(err)
		}
		return req
	}

	testCases := []struct {
		name string
		// failureN and failureErr are returned, instead of writing, by the
		// second Write across all connections dialed in the subtest.
		failureN   int
		failureErr error
		// req builds a fresh request for each attempt, since a request body
		// can only be consumed once.
		req       func() *Request
		reqString string // the request as it is expected to appear on the wire
	}{
		{
			name: "IdempotentNoBodySomeWritten",
			// Report a partial write, so this is not merely the
			// nothing-was-sent case.
			failureN: 1,
			// NOTE(review): presumably the error the Transport treats as
			// retryable for idempotent requests; confirm in transport.go.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", nil)
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "IdempotentGetBodySomeWritten",
			// Report a partial write, as above.
			failureN: 1,
			// NOTE(review): see the previous case.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
		{
			name: "NothingWrittenNoBody",
			// Zero bytes written: the request is non-idempotent (DELETE)
			// but still expected to be retried.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			req: func() *Request {
				return newRequest("DELETE", "http://fake.golang", nil)
			},
			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "NothingWrittenGetBody",
			// Zero bytes written, non-idempotent POST with a body.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			// NOTE(review): the retry presumably needs the request's GetBody,
			// which NewRequest is expected to set for a *strings.Reader.
			req: func() *Request {
				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var (
				mu     sync.Mutex
				logbuf strings.Builder
			)
			// logf appends one line to logbuf; the final log is compared
			// against the expected sequence of events.
			logf := func(format string, args ...any) {
				mu.Lock()
				defer mu.Unlock()
				fmt.Fprintf(&logbuf, format, args...)
				logbuf.WriteByte('\n')
			}

			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				logf("Handler")
				w.Header().Set("X-Status", "ok")
			})).ts

			// Counts writes across every connection dialed in this subtest.
			var writeNumAtomic int32
			c := ts.Client()
			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
				logf("Dial")
				c, err := net.Dial(network, ts.Listener.Addr().String())
				if err != nil {
					logf("Dial error: %v", err)
					return nil, err
				}
				return &writerFuncConn{
					Conn: c,
					write: func(p []byte) (n int, err error) {
						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
							logf("intentional write failure")
							return tc.failureN, tc.failureErr
						}
						logf("Write(%q)", p)
						return c.Write(p)
					},
				}, nil
			}

			SetRoundTripRetried(func() {
				logf("Retried.")
			})
			defer SetRoundTripRetried(nil)

			for i := 0; i < 3; i++ {
				t0 := time.Now()
				req := tc.req()
				res, err := c.Do(req)
				if err != nil {
					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
						mu.Lock()
						got := logbuf.String()
						mu.Unlock()
						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
					}
					// %v, not %d: a time.Duration printed with %d is a raw
					// nanosecond count.
					t.Skipf("connection likely wasn't recycled within %v, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
				}
				res.Body.Close()
				if res.Request != req {
					t.Errorf("Response.Request != original request; want identical Request")
				}
			}

			// Expected: the first request dials and succeeds; the second
			// reuses that connection, hits the injected failure, is retried
			// on a new dial; the third reuses the new connection.
			mu.Lock()
			got := logbuf.String()
			mu.Unlock()
			want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
			if got != want {
				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
			}
		})
	}
}
4144
4145
// TestTransportClosesBodyOnError sends a request body that fails after 1 MiB.
// It checks three things:
//   - the client's Do returns an error containing the body's error text;
//   - the handler's read of the request body fails;
//   - the request body's Close is called.
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	handlerReadErr := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		handlerReadErr <- err
	})).ts
	c := ts.Client()

	fakeErr := errors.New("fake error")
	// 1 MiB of 'x', then fakeErr.
	failingBody := io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr))
	closed := make(chan bool, 1)
	recordClose := closerFunc(func() error {
		// Non-blocking send: only the first Close needs to be recorded.
		select {
		case closed <- true:
		default:
		}
		return nil
	})
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{failingBody, recordClose})

	res, err := c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}
	if err := <-handlerReadErr; err == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-closed:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
4185
// TestTransportDialTLS checks that a Transport with a DialTLS hook calls it
// to open the TLS connection and that the request reaches the handler.
func TestTransportDialTLS(t *testing.T) {
	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLS(t *testing.T, mode testMode) {
	// mu guards gotReq (set by the handler) and didDial (set by the hook).
	var mu sync.Mutex
	var gotReq, didDial bool

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
		mu.Lock()
		didDial = true
		mu.Unlock()
		conn, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		// On handshake failure, close the connection here and return only
		// the error, so the open conn is not handed back alongside it.
		if err := conn.Handshake(); err != nil {
			conn.Close()
			return nil, err
		}
		return conn, nil
	}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if !didDial {
		t.Error("didn't use dial hook")
	}
}
4223
// TestTransportDialContext checks that the Transport dials through its
// DialContext hook and that the context passed to the hook carries a value
// stored in the request's context.
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
func testTransportDialContext(t *testing.T, mode testMode) {
	// A private key type, as context.WithValue recommends, so the key cannot
	// collide with context keys defined elsewhere.
	type ctxKeyType string
	ctxKey := ctxKeyType("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		return net.Dial(netw, addr)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4265
// TestTransportDialTLSContext checks that the Transport opens TLS connections
// through its DialTLSContext hook and that the context passed to the hook
// carries a value stored in the request's context.
func TestTransportDialTLSContext(t *testing.T) {
	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLSContext(t *testing.T, mode testMode) {
	// A private key type, as context.WithValue recommends, so the key cannot
	// collide with context keys defined elsewhere.
	type ctxKeyType string
	ctxKey := ctxKeyType("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		conn, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
		if err != nil {
			return nil, err
		}
		// On handshake failure, close the connection here and return only
		// the error, so the open conn is not handed back alongside it.
		if err := conn.HandshakeContext(ctx); err != nil {
			conn.Close()
			return nil, err
		}
		return conn, nil
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4313
4314
4315
// TestRoundTripReturnsProxyError checks that an error from the Transport's
// Proxy function is returned by RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}

	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
4331
4332
4333 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4334 tr := &Transport{}
4335 wantIdle := func(when string, n int) bool {
4336 got := tr.IdleConnCountForTesting("http", "example.com")
4337 if got == n {
4338 return true
4339 }
4340 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4341 return false
4342 }
4343 wantIdle("start", 0)
4344 if !tr.PutIdleTestConn("http", "example.com") {
4345 t.Fatal("put failed")
4346 }
4347 if !tr.PutIdleTestConn("http", "example.com") {
4348 t.Fatal("second put failed")
4349 }
4350 wantIdle("after put", 2)
4351 tr.CloseIdleConnections()
4352 if !tr.IsIdleForTesting() {
4353 t.Error("should be idle after CloseIdleConnections")
4354 }
4355 wantIdle("after close idle", 0)
4356 if tr.PutIdleTestConn("http", "example.com") {
4357 t.Fatal("put didn't fail")
4358 }
4359 wantIdle("after second put", 0)
4360
4361 tr.QueueForIdleConnForTesting()
4362 if tr.IsIdleForTesting() {
4363 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4364 }
4365 if !tr.PutIdleTestConn("http", "example.com") {
4366 t.Fatal("after re-activation")
4367 }
4368 wantIdle("after final put", 1)
4369 }
4370
4371
4372
4373 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4374 tr := &Transport{}
4375 wantIdle := func(when string, n int) bool {
4376 got := tr.IdleConnCountForTesting("https", "example.com:443")
4377 if got == n {
4378 return true
4379 }
4380 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4381 return false
4382 }
4383 wantIdle("start", 0)
4384 alt := funcRoundTripper(func() {})
4385 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4386 t.Fatal("put failed")
4387 }
4388 wantIdle("after put", 1)
4389 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4390 GotConn: func(httptrace.GotConnInfo) {
4391
4392 t.Error("GotConn called")
4393 },
4394 })
4395 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4396 _, err := tr.RoundTrip(req)
4397 if err != errFakeRoundTrip {
4398 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4399 }
4400 wantIdle("after round trip", 1)
4401 }
4402
4403
4404
4405
4406
4407
// TestTransportIdleConnRacesRequest covers a connection whose dial finishes
// only after the request that started it was canceled. That connection is
// then (presumably) being closed for idle timeout while a second request
// starts; the second request must still succeed.
func TestTransportIdleConnRacesRequest(t *testing.T) {
	// Runs under synctest. Only http1Mode and http2UnencryptedMode are
	// listed, and the latter is currently skipped (#70515).
	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
}
func testTransportIdleConnRacesRequest(t *testing.T, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Dials block until dialc is closed, and connection Closes block until
	// closec is closed, so the test controls when each one happens.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	closec := make(chan struct{})
	cst.li.onClose = func(*fakeNetConn) {
		<-closec
	}
	// Request 1: it blocks in the dial.
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()
	// Wait for request 1 to block.
	synctest.Wait()
	// Cancel request 1 while its dial is still blocked; it must fail.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}
	// Let the dial complete; the new connection has no request to serve.
	close(dialc)

	// Advance past IdleConnTimeout.
	// NOTE(review): presumably the unused connection's idle timeout fires
	// here and its Close then blocks in onClose until closec is closed;
	// confirm against the Transport's idle handling.
	synctest.Wait()
	time.Sleep(timeout)
	synctest.Wait()

	// Request 2 starts while that Close is still blocked.
	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Release the blocked Close; request 2 must succeed.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4484
// TestTransportRemovesConnsAfterIdle checks that a connection idle for less
// than IdleConnTimeout is reused, and one idle for longer is not. The handler
// echoes the client's RemoteAddr in X-Addr so requests can tell which
// connection they used.
func TestTransportRemovesConnsAfterIdle(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterIdle)
}
func testTransportRemovesConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Second
	// At most one connection (and one idle connection) per host, so every
	// request either reuses it or replaces it.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)

	// makeRequest performs a GET and returns the X-Addr value the server
	// reported, which identifies the connection used.
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()

	// Idle for half the timeout: the connection should be reused.
	time.Sleep(timeout / 2)
	synctest.Wait()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
	}

	// Idle for the full timeout: the connection should have been dropped.
	time.Sleep(timeout)
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4530
// TestTransportRemovesConnsAfterBroken checks that successive requests reuse
// one connection until that connection is broken, after which a new one is
// used. The handler echoes the client's RemoteAddr in X-Addr so requests can
// tell which connection they used.
func TestTransportRemovesConnsAfterBroken(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterBroken)
}
func testTransportRemovesConnsAfterBroken(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// At most one connection (and one idle connection) per host.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)
	// Record connections in cst.li.conns so the test can break one below.
	cst.li.trackConns = true

	// makeRequest performs a GET and returns the X-Addr value the server
	// reported, which identifies the connection used.
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
	}

	// Break the first tracked connection by closing its peer, then let the
	// client observe the closure before the next request.
	synctest.Wait()
	cst.li.conns[0].peer.Close()
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4574
4575
4576
4577
4578
// TestTransportRangeAndGzip checks that a request carrying a Range header
// reaches the server with that header intact and without "gzip" in its
// Accept-Encoding header.
func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
func testTransportRangeAndGzip(t *testing.T, mode testMode) {
	inspect := func(w ResponseWriter, r *Request) {
		if acceptEnc := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEnc, "gzip") {
			t.Error("Transport advertised gzip support in the Accept header")
		}
		if len(r.Header.Get("Range")) == 0 {
			t.Error("no Range in request")
		}
	}
	server := newClientServerTest(t, mode, HandlerFunc(inspect)).ts
	client := server.Client()

	req, _ := NewRequest("GET", server.URL, nil)
	req.Header.Set("Range", "bytes=7-11")
	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4599
4600
// TestTransportResponseCancelRace reads a response fully, then calls
// Transport.CancelRequest on that already-completed request just before a
// second RoundTrip on the same Transport, and checks that the second request
// succeeds.
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Respond with a 1 KiB (all-zero) body.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	// Drain the body but do not Close it.
	// NOTE(review): presumably deliberate: reading to EOF lets the
	// connection return to the pool for req2 to reuse, whereas an early
	// Close might discard it. Confirm before adding a Close here.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the already-finished req immediately before reusing the Transport.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4636
4637
// TestTransportContentEncodingCaseInsensitive checks that a gzip-compressed
// response body reaches the client already decoded, whether the server spells
// the Content-Encoding value "gzip" or "GZIP".
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	run(t, testTransportContentEncodingCaseInsensitive)
}
func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
	for _, encoding := range []string{"gzip", "GZIP"} {
		t.Run(encoding, func(t *testing.T) {
			const want = "Hello Gopher"
			gzipHandler := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", encoding)
				zw := gzip.NewWriter(w)
				zw.Write([]byte(want))
				zw.Close()
			})
			server := newClientServerTest(t, mode, gzipHandler).ts

			res, err := server.Client().Get(server.URL)
			if err != nil {
				t.Fatal(err)
			}

			body, readErr := io.ReadAll(res.Body)
			res.Body.Close()
			if readErr != nil {
				t.Fatal(readErr)
			}

			if got := string(body); got != want {
				t.Fatalf("Expected body %q, got: %q\n", want, got)
			}
		})
	}
}
4669
4670
// TestConnClosedBeforeRequestIsWritten dials a connection whose every Read and
// Write fails. It checks that Post returns an error and that the request body
// is closed exactly once (counted by countCloseReader).
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
}
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
	alwaysFail := func([]byte) (int, error) {
		return 0, errors.New("error")
	}
	brokenDial := func(_ context.Context, network, addr string) (net.Conn, error) {
		// A connection on which every read and write fails immediately.
		return &funcConn{read: alwaysFail, write: alwaysFail}, nil
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(tr *Transport) { tr.DialContext = brokenDial },
	).ts

	// Pause briefly on entry to RoundTrip.
	// NOTE(review): presumably this lets the Transport observe the broken
	// connection before it starts writing the request; confirm.
	SetEnterRoundTripHook(func() { time.Sleep(1 * time.Millisecond) })
	defer SetEnterRoundTripHook(nil)

	var closes int
	body := countCloseReader{&closes, strings.NewReader("hello")}
	if _, err := ts.Client().Post(ts.URL, "text/plain", body); err == nil {
		t.Fatalf("expected request to fail, but it did not")
	}
	if closes != 1 {
		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
	}
}
4708
4709
4710
4711
4712 type logWritesConn struct {
4713 net.Conn
4714
4715 w io.Writer
4716
4717 rch <-chan io.Reader
4718 r io.Reader
4719
4720 mu sync.Mutex
4721 writes []string
4722 }
4723
4724 func (c *logWritesConn) Write(p []byte) (n int, err error) {
4725 c.mu.Lock()
4726 defer c.mu.Unlock()
4727 c.writes = append(c.writes, string(p))
4728 return c.w.Write(p)
4729 }
4730
4731 func (c *logWritesConn) Read(p []byte) (n int, err error) {
4732 if c.r == nil {
4733 c.r = <-c.rch
4734 }
4735 return c.r.Read(p)
4736 }
4737
4738 func (c *logWritesConn) Close() error { return nil }
4739
4740
// TestTransportFlushesBodyChunks streams a chunked request body in three
// pieces through a fake connection that records every Write. It checks that
// the header, each chunk and the terminating chunk each arrive as a separate
// Write.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)

	// Fake connection: client writes go into wireW, which the test reads as
	// the server via wireR. The response reader is supplied later on respc.
	respc := make(chan io.Reader, 1)
	wireR, wireW := io.Pipe()
	conn := &logWritesConn{
		rch: respc,
		w:   wireW,
	}
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return conn, nil
		},
	}

	// Produce the request body as three separate writes on a pipe.
	bodyR, bodyW := io.Pipe()
	go func() {
		defer bodyW.Close()
		for i := range 3 {
			fmt.Fprintf(bodyW, "num%d\n", i)
		}
	}()

	// Issue the request in the background. gotRes is closed without a value
	// if RoundTrip fails (the error has already been reported).
	gotRes := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyR)
		req.Header.Set("User-Agent", "x")
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(gotRes)
			return
		}
		gotRes <- res
	}()

	// Act as the server: consume the whole request, then supply a response.
	serverReq, err := ReadRequest(bufio.NewReader(wireR))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, serverReq.Body)

	respc <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-gotRes
	if !ok {
		return
	}
	defer res.Body.Close()

	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	if !slices.Equal(conn.writes, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", conn.writes, want)
	}
}
4800
4801
// TestTransportFlushesRequestHeader checks that the handler runs while the
// request body is still open. The body is only ended after the handler has
// been reached, so the request header must have been sent before then.
func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
	handlerCalled := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(handlerCalled)
	}))

	bodyR, bodyW := io.Pipe()
	req, err := NewRequest("POST", cst.ts.URL, bodyR)
	if err != nil {
		t.Fatal(err)
	}

	roundTripDone := make(chan struct{})
	go func() {
		defer close(roundTripDone)
		res, err := cst.tr.RoundTrip(req)
		if err != nil {
			t.Error(err)
			return
		}
		res.Body.Close()
	}()

	// Nothing has been written to the body yet; wait for the handler first.
	<-handlerCalled
	bodyW.Close()
	<-roundTripDone
}
4829
4830 type wgReadCloser struct {
4831 io.Reader
4832 wg *sync.WaitGroup
4833 closed bool
4834 }
4835
4836 func (c *wgReadCloser) Close() error {
4837 if c.closed {
4838 return net.ErrClosed
4839 }
4840 c.closed = true
4841 c.wg.Done()
4842 return nil
4843 }
4844
4845
// TestTransportPrefersResponseOverWriteError checks the case where the server
// answers with 400 before reading an oversized request body. The client should
// receive that response rather than an error from writing the rest of the body.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
	// Not parallel: the test changes the RST avoidance delay through
	// SetRSTAvoidanceDelay, which is assumed to be shared state. TODO confirm.
	run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
}

func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Try progressively longer RST avoidance delays. runTimeSensitiveTest
	// presumably retries with the next duration when the callback returns an
	// error.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const contentLengthLimit = 1024 * 1024
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			if r.ContentLength >= contentLengthLimit {
				// Reject oversized uploads without reading the body.
				w.WriteHeader(StatusBadRequest)
				r.Body.Close()
				return
			}
			w.WriteHeader(StatusOK)
		}))
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		const count = 100

		bigBody := strings.Repeat("a", contentLengthLimit*2)

		// getBody hands out a fresh body that adds one count to wg. The count is
		// released on the body's first Close. The deferred Wait therefore checks
		// that every body given to the Transport (including ones obtained
		// through GetBody on retry) was eventually closed.
		var wg sync.WaitGroup
		defer wg.Wait()
		getBody := func() (io.ReadCloser, error) {
			wg.Add(1)
			body := &wgReadCloser{
				Reader: strings.NewReader(bigBody),
				wg:     &wg,
			}
			return body, nil
		}

		for i := 0; i < count; i++ {
			reqBody, _ := getBody() // getBody never fails.
			req, err := NewRequest("PUT", ts.URL, reqBody)
			if err != nil {
				reqBody.Close()
				t.Fatal(err)
			}
			req.ContentLength = int64(len(bigBody))
			req.GetBody = getBody

			resp, err := c.Do(req)
			if err != nil {
				return fmt.Errorf("Do %d: %w", i, err)
			}
			resp.Body.Close()
			if resp.StatusCode != StatusBadRequest {
				t.Errorf("Expected status code 400, got %v", resp.Status)
			}
		}
		return nil
	})
}
4920
// TestTransportAutomaticHTTP2 checks that a zero-value Transport enables
// HTTP/2 automatically.
func TestTransportAutomaticHTTP2(t *testing.T) {
	tr := &Transport{}
	testTransportAutoHTTP(t, tr, true)
}

// TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig
// checks that ForceAttemptHTTP2 keeps HTTP/2 enabled when a custom
// TLSClientConfig is also supplied.
func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
	tr := &Transport{
		ForceAttemptHTTP2: true,
		TLSClientConfig:   new(tls.Config),
	}
	testTransportAutoHTTP(t, tr, true)
}

// TestTransportAutomaticHTTP2_DefaultTransport checks that DefaultTransport
// has HTTP/2 enabled as well.
func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
	tr := DefaultTransport.(*Transport)
	testTransportAutoHTTP(t, tr, true)
}

// TestTransportAutomaticHTTP2_TLSNextProto checks that a caller-supplied,
// non-nil (even empty) TLSNextProto map turns off automatic HTTP/2.
func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
	tr := &Transport{
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{},
	}
	testTransportAutoHTTP(t, tr, false)
}

// TestTransportAutomaticHTTP2_TLSConfig checks that a custom TLSClientConfig
// without ForceAttemptHTTP2 turns off automatic HTTP/2.
func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
	tr := &Transport{
		TLSClientConfig: new(tls.Config),
	}
	testTransportAutoHTTP(t, tr, false)
}

// TestTransportAutomaticHTTP2_ExpectContinueTimeout checks that setting
// ExpectContinueTimeout alone leaves automatic HTTP/2 enabled.
func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
	tr := &Transport{
		ExpectContinueTimeout: 1 * time.Second,
	}
	testTransportAutoHTTP(t, tr, true)
}

// TestTransportAutomaticHTTP2_Dial checks that a custom Dial func turns off
// automatic HTTP/2.
func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
	var dialer net.Dialer
	tr := &Transport{
		Dial: dialer.Dial,
	}
	testTransportAutoHTTP(t, tr, false)
}

// TestTransportAutomaticHTTP2_DialContext checks that a custom DialContext
// func turns off automatic HTTP/2.
func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
	var dialer net.Dialer
	tr := &Transport{
		DialContext: dialer.DialContext,
	}
	testTransportAutoHTTP(t, tr, false)
}

// TestTransportAutomaticHTTP2_DialTLS checks that a custom DialTLS func turns
// off automatic HTTP/2. The dialer itself is never invoked.
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
	tr := &Transport{
		DialTLS: func(network, addr string) (net.Conn, error) {
			panic("unused")
		},
	}
	testTransportAutoHTTP(t, tr, false)
}

// testTransportAutoHTTP makes one RoundTrip with an empty request, which is
// expected to fail. That call presumably triggers the Transport's one-time
// HTTP/2 setup. The function then checks whether an "h2" handler was
// registered in tr.TLSNextProto, matching wantH2.
func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
	CondSkipHTTP2(t)
	if _, err := tr.RoundTrip(new(Request)); err == nil {
		t.Error("expected error from RoundTrip")
	}
	gotH2 := tr.TLSNextProto["h2"] != nil
	if gotH2 != wantH2 {
		t.Errorf("HTTP/2 registered = %v; want %v", gotH2, wantH2)
	}
}
4987
4988
4989
4990
4991
4992
4993
4994
// TestTransportReuseConnEmptyResponseBody checks that one connection is reused
// across many requests whose responses have an empty body.
func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
	run(t, testTransportReuseConnEmptyResponseBody)
}

func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Echo the client address so each response shows which connection
		// served it. The response body is intentionally left empty.
		w.Header().Set("X-Addr", r.RemoteAddr)
	}))
	n := 100
	if testing.Short() {
		n = 10
	}
	var firstAddr string
	for i := 0; i < n; i++ {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// Use t.Fatal, not log.Fatal. log.Fatal calls os.Exit, which
			// aborts the entire test binary and skips cleanups.
			t.Fatal(err)
		}
		addr := res.Header.Get("X-Addr")
		// Close before the reuse check so a failing iteration does not leak
		// the body.
		res.Body.Close()
		if i == 0 {
			firstAddr = addr
		} else if addr != firstAddr {
			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
		}
	}
}
5022
5023
// TestNoCrashReturningTransportAltConn checks what happens when a request is
// canceled while DialTLS is still producing a connection that negotiated a
// TLSNextProto protocol ("foo"). Do must fail with errRequestCanceledConn. The
// Transport must still build the alternate RoundTripper for the late
// connection (madeRoundTripper), but must never call it.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Track in-flight dials so the test can wait for the canceled dial to
	// finish before returning.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// TLS server that only advertises the "foo" protocol. It completes the
	// handshake and holds the connection open until the test finishes.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request, then hand back the connection only after
			// Do has already returned, so it arrives after the cancellation.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Let DialTLS return its connection, then wait for the "foo" handler to
	// be invoked on it and for the pending dial to be fully accounted for.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
5105
// TestTransportReuseConnection_Gzip_Chunked covers the case where the handler
// flushes headers before writing the gzip body, so the body length is not
// known up front.
func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, true)
	})
}

// TestTransportReuseConnection_Gzip_ContentLength covers the case where the
// gzip body is written without an early flush.
func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, false)
	})
}

// testTransportReuseConnection_Gzip sends two requests for a gzip-encoded
// response and checks that both arrive on the same connection (same
// server-side RemoteAddr). Each body is read to its full length but never
// closed.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	remoteAddrs := make(chan string, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		remoteAddrs <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz)
	})).ts
	c := ts.Client()

	// Log connection lifecycle events to help diagnose a failed reuse.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for attempt := 0; attempt < 2; attempt++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		res, err := c.Do(req.WithContext(ctx))
		if err != nil {
			t.Fatal(err)
		}
		got := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, got); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", attempt, n, err)
		}
		// res.Body is deliberately left unclosed. The test expects the
		// connection to be reusable once the full body has been read, even
		// without Close (presumably because the gzip reader reaches EOF).
	}
	first, second := <-remoteAddrs, <-remoteAddrs
	if first != second {
		t.Fatalf("didn't reuse connection")
	}
}
5160
// TestTransportResponseHeaderLength checks that Transport.MaxResponseHeaderBytes
// rejects a response whose headers exceed the configured limit.
func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.URL.Path == "/long" {
			w.Header().Set("Long", strings.Repeat("a", 1<<20))
		}
	})).ts
	c := ts.Client()
	// 512 KiB limit. The "/long" response carries a 1 MiB header value.
	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10

	// A response with ordinary headers must still succeed.
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	res, err = c.Get(ts.URL + "/long")
	if err == nil {
		defer res.Body.Close()
		var total int64
		for name, values := range res.Header {
			for _, v := range values {
				total += int64(len(name) + len(v))
			}
		}
		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, total)
	}
	const want = "server response headers exceeded 524288 bytes"
	if !strings.Contains(err.Error(), want) {
		t.Errorf("got error: %v; want %q", err, want)
	}
}
5195
// TestTransportEventTrace checks which httptrace events fire, and how often,
// for a POST sent with Expect: 100-continue.
func TestTransportEventTrace(t *testing.T) {
	withHooks := func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, false)
	}
	run(t, withHooks, testNotParallel)
}

// TestTransportEventTrace_NoHooks repeats the request with a non-nil
// ClientTrace in which every hook is nil.
func TestTransportEventTrace_NoHooks(t *testing.T) {
	withoutHooks := func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, true)
	}
	run(t, withoutHooks, testNotParallel)
}
5208
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the
// fake host dns-is-faked.golang. A context-provided resolver maps that host to
// the test server. The function records every httptrace event in a log and
// checks which events appear, and how many times. A second GET then checks
// that GetConn fires again. When noHooks is true, the ClientTrace is non-nil
// but all of its hooks are zeroed, and only the request's success is checked.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET gets an empty 200 response.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		// Don't respond until the client has reported WroteRequest, so that
		// event is logged before the response arrives.
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf. The mutex is needed because trace hooks
	// may run concurrently with the test goroutine.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake resolver that maps dns-is-faked.golang to the test
	// server's IP. Any other lookup is an error.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Keep a non-nil trace but reset every hook to nil.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Nothing was logged with all hooks nil. Reaching this point means the
		// request completed with an empty ClientTrace installed.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// Each header field must be reported exactly once. wantOnce only
		// counts occurrences, so the order of the fields is not checked.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Reuse the same trace for a second request. GetConn should be logged a
	// second time.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
5390
// TestTransportEventTraceTLSVerify checks that the TLS handshake trace hooks
// fire, and report the verification error, when the server's certificate does
// not cover the requested ServerName.
func TestTransportEventTraceTLSVerify(t *testing.T) {
	run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
}

func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
	var (
		mu  sync.Mutex
		buf strings.Builder
	)
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("Unexpected request")
	}), func(ts *httptest.Server) {
		// Route the server's error log into the same buffer so it appears in
		// the failure output.
		ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
			logf("%s", p)
			return len(p), nil
		}), "", 0)
	}).ts

	// Trust the server's certificate but ask for a name it does not cover.
	roots := x509.NewCertPool()
	roots.AddCert(ts.Certificate())
	tr := &Transport{
		TLSClientConfig: &tls.Config{
			ServerName: "dns-is-faked.golang",
			RootCAs:    roots,
		},
	}
	c := &Client{Transport: tr}

	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
		},
	}

	req, _ := NewRequest("GET", ts.URL, nil)
	ctx := httptrace.WithClientTrace(context.Background(), trace)
	if _, err := c.Do(req.WithContext(ctx)); err == nil {
		t.Error("Expected request to fail TLS verification")
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	for _, sub := range []string{
		"TLSHandshakeStart",
		"TLSHandshakeDone",
		"err = tls: failed to verify certificate: x509: certificate is valid for example.com",
	} {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}

	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
5455
5456 var isDNSHijacked = sync.OnceValue(func() bool {
5457 addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
5458 return len(addrs) != 0
5459 })
5460
5461 func skipIfDNSHijacked(t *testing.T) {
5462
5463
5464
5465 if isDNSHijacked() {
5466 t.Skip("skipping; test requires non-hijacking DNS server")
5467 }
5468 }
5469
// TestTransportEventTraceRealDNS performs a real lookup of a name that should
// not resolve. It checks that DNSStart and DNSDone (with an error) are traced
// and that no Connect events follow.
func TestTransportEventTraceRealDNS(t *testing.T) {
	skipIfDNSHijacked(t)
	defer afterTest(t)
	tr := &Transport{}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	var mu sync.Mutex
	var events strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&events, format, args...)
		events.WriteByte('\n')
	}

	trace := &httptrace.ClientTrace{
		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
	}
	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))

	if resp, err := c.Do(req); err == nil {
		resp.Body.Close()
		t.Fatal("expected error during DNS lookup")
	}

	mu.Lock()
	got := events.String()
	mu.Unlock()

	for _, sub := range []string{
		"DNSStart: {Host:dns-should-not-resolve.golang}",
		"DNSDone: {Addrs:[] Err:",
	} {
		if !strings.Contains(got, sub) {
			t.Errorf("expected substring %q in output.", sub)
		}
	}
	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
		t.Errorf("should not see Connect events")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
5519
5520
5521 func TestTransportRejectsAlphaPort(t *testing.T) {
5522 res, err := Get("http://dummy.tld:123foo/bar")
5523 if err == nil {
5524 res.Body.Close()
5525 t.Fatal("unexpected success")
5526 }
5527 ue, ok := err.(*url.Error)
5528 if !ok {
5529 t.Fatalf("got %#v; want *url.Error", err)
5530 }
5531 got := ue.Err.Error()
5532 want := `invalid port ":123foo" after host`
5533 if got != want {
5534 t.Errorf("got error %q; want %q", got, want)
5535 }
5536 }
5537
5538
5539
// TestTLSHandshakeTrace checks that an HTTPS request invokes both the
// TLSHandshakeStart and TLSHandshakeDone trace hooks, and that the handshake
// reports no error.
func TestTLSHandshakeTrace(t *testing.T) {
	run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
}

func testTLSHandshakeTrace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts

	// mu guards start and done. The trace hooks are not guaranteed to run on
	// the test goroutine.
	var mu sync.Mutex
	var start, done bool
	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() {
			mu.Lock()
			defer mu.Unlock()
			start = true
		},
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			mu.Lock()
			defer mu.Unlock()
			done = true
			if err != nil {
				// t.Error, not t.Fatal: FailNow may only be called from the
				// goroutine running the test, and this hook may run elsewhere.
				t.Error("Expected error to be nil but was:", err)
			}
		},
	}

	c := ts.Client()
	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal("Unable to construct test request:", err)
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	r, err := c.Do(req)
	if err != nil {
		t.Fatal("Unexpected error making request:", err)
	}
	r.Body.Close()
	mu.Lock()
	defer mu.Unlock()
	if !start {
		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
	}
	if !done {
		t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
	}
}
5585
// TestTransportMaxIdleConns checks that Transport.MaxIdleConns caps the total
// number of idle connections across hosts, and that a new idle connection
// evicts the oldest one.
func TestTransportMaxIdleConns(t *testing.T) {
	run(t, testTransportMaxIdleConns, []testMode{http1Mode})
}

func testTransportMaxIdleConns(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Empty response. Only the idle pool matters here.
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxIdleConns = 4

	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Resolve every host-N.dns-is-faked.golang name to the test server, so
	// each distinct host name gets its own idle connection.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(_ context.Context, _, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	hitHost := func(n int) {
		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
		res, err := c.Do(req.WithContext(ctx))
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	// idleKeys returns the expected idle-conn keys for hosts lo..hi
	// (inclusive), in order.
	idleKeys := func(lo, hi int) []string {
		keys := make([]string, 0, hi-lo+1)
		for n := lo; n <= hi; n++ {
			keys = append(keys, fmt.Sprintf("|http|host-%d.dns-is-faked.golang:", n)+port)
		}
		return keys
	}

	for n := 0; n < 4; n++ {
		hitHost(n)
	}
	if got, want := tr.IdleConnKeysForTesting(), idleKeys(0, 3); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
	}

	// A fifth host goes over MaxIdleConns, so host-0's idle conn is dropped.
	hitHost(4)
	if got, want := tr.IdleConnKeysForTesting(), idleKeys(1, 4); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
	}
}
5639
// TestTransportIdleConnTimeout checks that Transport.IdleConnTimeout closes an
// idle connection once the timeout elapses.
func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }

// testTransportIdleConnTimeout makes three requests, sleeping timeout/2
// between them. Each must leave exactly one idle connection, and it must be
// the same connection every time. The function then waits for the idle pool
// to empty. If the timeout proves too short (the conn disappears early or is
// replaced), the whole run is retried with the timeout doubled.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Empty response body.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// NOTE(review): this defer runs at function return, not per retry,
		// so one deferred call accumulates for each attempt.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		idleConns := func() []string {
			return tr.IdleConnStrsForTesting()
		}

		// conn remembers the first idle connection seen, so later requests
		// can check that it was reused.
		var conn string
		// doReq makes request n. It returns false when the current timeout
		// looks too short to keep the connection idle between requests.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
				// Any other error falls through to the idle-conn checks below.
			}
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// Wait for the idle timeout to close the connection. Log the
		// leftover conns only once well past the expected deadline.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5720
5721
5722
5723
5724
5725
5726
5727
5728
5729
5730
5731
// TestIdleConnH2Crash checks that the Transport does not crash when an HTTP/2
// connection finishes dialing after its request was canceled, and that
// connection is later subject to IdleConnTimeout.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never reached. The request is canceled during the dial.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request after a successful h2 handshake, then hold the
		// connection until Do has returned its error (or the test ends).
		cancel()

		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the late connection time to go idle and hit IdleConnTimeout,
	// presumably in the idle pool. The test fails only if that path panics.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5782
5783 type funcConn struct {
5784 net.Conn
5785 read func([]byte) (int, error)
5786 write func([]byte) (int, error)
5787 }
5788
5789 func (c funcConn) Read(p []byte) (int, error) { return c.read(p) }
5790 func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
5791 func (c funcConn) Close() error { return nil }
5792
5793
5794
// TestTransportReturnsPeekError checks what RoundTrip returns when reading
// the response from the connection fails. It must return that exact error
// value, with no wrapping or substitution.
func TestTransportReturnsPeekError(t *testing.T) {
	errValue := errors.New("specific error value")

	// wrote is closed on the first Write, so Read fails only after the
	// request has started going out.
	wrote := make(chan struct{})
	markWrote := sync.OnceFunc(func() { close(wrote) })

	dial := func(network, addr string) (net.Conn, error) {
		return funcConn{
			read: func([]byte) (int, error) {
				<-wrote
				return 0, errValue
			},
			write: func(p []byte) (int, error) {
				markWrote()
				return len(p), nil
			},
		}, nil
	}
	tr := &Transport{Dial: dial}

	_, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
	if err != errValue {
		t.Errorf("error = %#v; want %v", err, errValue)
	}
}
5821
5822
// TestTransportIDNA checks that a Unicode host name is sent in its punycode
// form in several places: to the resolver, in the connection key, in the Host
// header and, for HTTP/2 over TLS, as the TLS ServerName.
func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
func testTransportIDNA(t *testing.T, mode testMode) {
	const (
		uniDomain  = "гофер.го"
		punyDomain = "xn--c1ae0ajs.xn--c1aw"
	)

	// port is assigned once the listener address is known. The handler
	// reads it.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if want := punyDomain + ":" + port; r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			switch {
			case r.TLS == nil:
				t.Errorf("r.TLS == nil")
			case r.TLS.ServerName != punyDomain:
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	var ip string
	var err error
	ip, port, err = net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: only the punycode name may be looked up. It maps to the
	// test server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			if want := net.JoinHostPort(punyDomain, port); hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") == "1" {
		return
	}
	out, err := httputil.DumpResponse(res, true)
	if err != nil {
		t.Fatal(err)
	}
	t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
}
5891
5892
// TestTransportProxyConnectHeader checks that Transport.ProxyConnectHeader is
// sent on the CONNECT request made to an HTTP proxy.
func TestTransportProxyConnectHeader(t *testing.T) {
	run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
}

func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
	connectReqs := make(chan *Request, 1)
	// The test server acts as the proxy. It records the CONNECT request and
	// then hangs up, so the client's request itself is expected to fail.
	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		connectReqs <- r
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		conn.Close()
	})).ts

	c := proxy.Client()
	tr := c.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		return url.Parse(proxy.URL)
	}
	tr.ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}

	if res, err := c.Get("https://dummy.tld/"); err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	r := <-connectReqs
	for _, tc := range []struct{ key, want string }{
		{"User-Agent", "foo"},
		{"Other", "bar"},
	} {
		if got := r.Header.Get(tc.key); got != tc.want {
			t.Errorf("CONNECT request %s = %q; want %q", tc.key, got, tc.want)
		}
	}
}
5934
5935 func TestTransportProxyGetConnectHeader(t *testing.T) {
5936 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5937 }
5938 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5939 reqc := make(chan *Request, 1)
5940 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5941 if r.Method != "CONNECT" {
5942 t.Errorf("method = %q; want CONNECT", r.Method)
5943 }
5944 reqc <- r
5945 c, _, err := w.(Hijacker).Hijack()
5946 if err != nil {
5947 t.Errorf("Hijack: %v", err)
5948 return
5949 }
5950 c.Close()
5951 })).ts
5952
5953 c := ts.Client()
5954 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5955 return url.Parse(ts.URL)
5956 }
5957
5958 c.Transport.(*Transport).ProxyConnectHeader = Header{
5959 "User-Agent": {"foo"},
5960 "Other": {"bar"},
5961 }
5962 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5963 return Header{
5964 "User-Agent": {"foo2"},
5965 "Other": {"bar2"},
5966 }, nil
5967 }
5968
5969 res, err := c.Get("https://dummy.tld/")
5970 if err == nil {
5971 res.Body.Close()
5972 t.Errorf("unexpected success")
5973 }
5974
5975 r := <-reqc
5976 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5977 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5978 }
5979 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5980 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5981 }
5982 }
5983
// errFakeRoundTrip is the error every funcRoundTripper.RoundTrip returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function only for
// its side effects and then always fails.
type funcRoundTripper func()

// RoundTrip ignores the request, invokes fn, and returns a nil *Response
// together with errFakeRoundTrip.
func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
	fn()
	var noResponse *Response
	return noResponse, errFakeRoundTrip
}
5992
// wantBody reports whether res carries exactly the body want.
//
// If err is non-nil it is returned unchanged, and res is not touched.
// Otherwise the whole body is read. The body is then always closed, even
// when the read fails or the content does not match, so it is never leaked.
// The result is nil on success. Otherwise it is an error describing the
// first problem found, in this order: a read failure (wrapped with %w), a
// content mismatch, or a failure from Body.Close (wrapped with %w).
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	// Close unconditionally; a read error or mismatch must not leak the body.
	closeErr := res.Body.Close()
	if err != nil {
		return fmt.Errorf("error reading body: %w", err)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
6009
// newLocalListener listens on an ephemeral loopback TCP port. It tries IPv4
// (127.0.0.1) first and falls back to IPv6 ([::1]). If both attempts fail, the
// test is aborted with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln6, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln6
}
6020
// countCloseReader is an io.ReadCloser that reads from the embedded Reader
// and counts Close calls in the int that n points to.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close increments the shared counter and never fails.
func (cr countCloseReader) Close() error {
	counter := cr.n
	*counter += 1
	return nil
}
6030
6031
// rgz is a fixed gzip stream. It starts with the gzip magic bytes 0x1f 0x8b,
// then compression method 0x08 (deflate) and FLG 0x08 (FNAME). Four zero
// MTIME bytes, XFL and OS follow, and then the NUL-terminated original file
// name "recursive" (0x72 ... 0x65 0x00).
// NOTE(review): it is not referenced in this chunk; presumably it is a gzip
// "quine" (it decompresses to itself) used by a gzip-handling test elsewhere
// in this file. Confirm before relying on that.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
6066
6067
6068
6069 func TestMissingStatusNoPanic(t *testing.T) {
6070 t.Parallel()
6071
6072 const want = "unknown status code"
6073
6074 ln := newLocalListener(t)
6075 addr := ln.Addr().String()
6076 done := make(chan bool)
6077 fullAddrURL := fmt.Sprintf("http://%s", addr)
6078 raw := "HTTP/1.1 400\r\n" +
6079 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6080 "Content-Type: text/html; charset=utf-8\r\n" +
6081 "Content-Length: 10\r\n" +
6082 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
6083 "Vary: Accept-Encoding\r\n\r\n" +
6084 "Aloha Olaa"
6085
6086 go func() {
6087 defer close(done)
6088
6089 conn, _ := ln.Accept()
6090 if conn != nil {
6091 io.WriteString(conn, raw)
6092 io.ReadAll(conn)
6093 conn.Close()
6094 }
6095 }()
6096
6097 proxyURL, err := url.Parse(fullAddrURL)
6098 if err != nil {
6099 t.Fatalf("proxyURL: %v", err)
6100 }
6101
6102 tr := &Transport{Proxy: ProxyURL(proxyURL)}
6103
6104 req, _ := NewRequest("GET", "https://golang.org/", nil)
6105 res, err, panicked := doFetchCheckPanic(tr, req)
6106 if panicked {
6107 t.Error("panicked, expecting an error")
6108 }
6109 if res != nil && res.Body != nil {
6110 io.Copy(io.Discard, res.Body)
6111 res.Body.Close()
6112 }
6113
6114 if err == nil || !strings.Contains(err.Error(), want) {
6115 t.Errorf("got=%v want=%q", err, want)
6116 }
6117
6118 ln.Close()
6119 <-done
6120 }
6121
// doFetchCheckPanic calls tr.RoundTrip(req). If RoundTrip panics, the panic
// is recovered and reported through panicked (res and err are then left at
// their zero values). Otherwise RoundTrip's results are returned with
// panicked == false.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
6131
6132
6133
6134 func TestNoBodyOnChunked304Response(t *testing.T) {
6135 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
6136 }
6137 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
6138 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6139 conn, buf, _ := w.(Hijacker).Hijack()
6140 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
6141 buf.Flush()
6142 conn.Close()
6143 }))
6144
6145
6146
6147
6148
6149 cst.tr.DisableKeepAlives = true
6150
6151 res, err := cst.c.Get(cst.ts.URL)
6152 if err != nil {
6153 t.Fatal(err)
6154 }
6155
6156 if res.Body != NoBody {
6157 t.Errorf("Unexpected body on 304 response")
6158 }
6159 }
6160
// funcWriter adapts an ordinary function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write forwards p to the wrapped function and returns its results as-is.
func (f funcWriter) Write(p []byte) (int, error) {
	written, err := f(p)
	return written, err
}
6164
// doneContext is a context.Context whose Done channel is always already
// closed and whose Err reports the configured err. All other methods come
// from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly allocated channel that is already closed.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	defer close(ch)
	return ch
}

// Err returns the error configured on d.
func (d doneContext) Err() error {
	return d.err
}

// TestTransportCheckContextDoneEarly checks that RoundTrip returns the
// context's error directly when the request context is already done.
func TestTransportCheckContextDoneEarly(t *testing.T) {
	wantErr := errors.New("some error")
	req, _ := NewRequest("GET", "http://fake.example/", nil)
	req = req.WithContext(doneContext{Context: context.Background(), err: wantErr})

	if _, err := (&Transport{}).RoundTrip(req); err != wantErr {
		t.Errorf("error = %v; want %v", err, wantErr)
	}
}
6189
6190
6191
6192
6193
6194
// TestClientTimeoutKillsConn_BeforeHeaders checks what happens when a
// Client.Timeout expires before response headers arrive. The handler then
// hijacks the connection and expects a 1-byte Read to return (0, io.EOF),
// meaning the client side has been closed.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a very short client timeout. Each attempt in which the
	// handler is not observed in time doubles it and retries with a fresh
	// server.
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the server-side request context is done.
			// NOTE(review): presumably this happens because the client
			// timed out and went away.
			<-r.Context().Done()

			// Either the test abandoned this attempt (cancelHandler
			// closed), or announce to the test that the handler got here.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Take over the raw connection and check that the peer has
			// closed it: a Read must return 0 bytes and io.EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler wasn't reached within 10x the timeout.
			// NOTE(review): presumably the request timed out before the
			// server saw it. Release the handler, tear down this server,
			// and retry with a doubled timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6253
6254
6255
6256
6257
6258
6259 func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
6260 run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
6261 }
6262 func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
6263 inHandler := make(chan bool)
6264 cancelHandler := make(chan struct{})
6265 handlerDone := make(chan bool)
6266 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6267 w.Header().Set("Content-Length", "100")
6268 w.(Flusher).Flush()
6269
6270 select {
6271 case <-cancelHandler:
6272 return
6273 case inHandler <- true:
6274 }
6275 defer func() { handlerDone <- true }()
6276
6277 conn, _, err := w.(Hijacker).Hijack()
6278 if err != nil {
6279 t.Error(err)
6280 return
6281 }
6282 conn.Write([]byte("foo"))
6283
6284 n, err := conn.Read([]byte{0})
6285
6286
6287
6288
6289
6290 if n != 0 || err == nil {
6291 t.Errorf("unexpected Read result: %v, %v", n, err)
6292 }
6293 conn.Close()
6294 }))
6295
6296
6297
6298
6299
6300 cst.c.Timeout = 24 * time.Hour
6301 req, _ := NewRequest("GET", cst.ts.URL, nil)
6302 cancelReq := make(chan struct{})
6303 req.Cancel = cancelReq
6304
6305 res, err := cst.c.Do(req)
6306 if err != nil {
6307 close(cancelHandler)
6308 t.Fatalf("Get error: %v", err)
6309 }
6310
6311
6312
6313
6314 close(cancelReq)
6315 got, err := io.ReadAll(res.Body)
6316 if err == nil {
6317 t.Errorf("unexpected success; read %q, nil", got)
6318 }
6319
6320
6321 <-inHandler
6322 <-handlerDone
6323 }
6324
6325 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
6326 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
6327 }
6328 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
6329 done := make(chan struct{})
6330 defer close(done)
6331 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6332 conn, _, err := w.(Hijacker).Hijack()
6333 if err != nil {
6334 t.Error(err)
6335 return
6336 }
6337 defer conn.Close()
6338 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6339 bs := bufio.NewScanner(conn)
6340 bs.Scan()
6341 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6342 <-done
6343 }))
6344
6345 req, _ := NewRequest("GET", cst.ts.URL, nil)
6346 req.Header.Set("Upgrade", "foo")
6347 req.Header.Set("Connection", "upgrade")
6348 res, err := cst.c.Do(req)
6349 if err != nil {
6350 t.Fatal(err)
6351 }
6352 if res.StatusCode != 101 {
6353 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6354 }
6355 rwc, ok := res.Body.(io.ReadWriteCloser)
6356 if !ok {
6357 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6358 }
6359 defer rwc.Close()
6360 bs := bufio.NewScanner(rwc)
6361 if !bs.Scan() {
6362 t.Fatalf("expected readable input")
6363 }
6364 if got, want := bs.Text(), "Some buffered data"; got != want {
6365 t.Errorf("read %q; want %q", got, want)
6366 }
6367 io.WriteString(rwc, "echo\n")
6368 if !bs.Scan() {
6369 t.Fatalf("expected another line")
6370 }
6371 if got, want := bs.Text(), "ECHO"; got != want {
6372 t.Errorf("read %q; want %q", got, want)
6373 }
6374 }
6375
6376 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6377 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6378 const target = "backend:443"
6379 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6380 if r.Method != "CONNECT" {
6381 t.Errorf("unexpected method %q", r.Method)
6382 w.WriteHeader(500)
6383 return
6384 }
6385 if r.RequestURI != target {
6386 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6387 w.WriteHeader(500)
6388 return
6389 }
6390 nc, brw, err := w.(Hijacker).Hijack()
6391 if err != nil {
6392 t.Error(err)
6393 return
6394 }
6395 defer nc.Close()
6396 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6397
6398 for {
6399 line, err := brw.ReadString('\n')
6400 if err != nil {
6401 if err != io.EOF {
6402 t.Error(err)
6403 }
6404 return
6405 }
6406 io.WriteString(brw, strings.ToUpper(line))
6407 brw.Flush()
6408 }
6409 }))
6410 pr, pw := io.Pipe()
6411 defer pw.Close()
6412 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6413 if err != nil {
6414 t.Fatal(err)
6415 }
6416 req.URL.Opaque = target
6417 res, err := cst.c.Do(req)
6418 if err != nil {
6419 t.Fatal(err)
6420 }
6421 defer res.Body.Close()
6422 if res.StatusCode != 200 {
6423 t.Fatalf("status code = %d; want 200", res.StatusCode)
6424 }
6425 br := bufio.NewReader(res.Body)
6426 for _, str := range []string{"foo", "bar", "baz"} {
6427 fmt.Fprintf(pw, "%s\n", str)
6428 got, err := br.ReadString('\n')
6429 if err != nil {
6430 t.Fatal(err)
6431 }
6432 got = strings.TrimSpace(got)
6433 want := strings.ToUpper(str)
6434 if got != want {
6435 t.Fatalf("got %q; want %q", got, want)
6436 }
6437 }
6438 }
6439
6440 func TestTransportRequestReplayable(t *testing.T) {
6441 someBody := io.NopCloser(strings.NewReader(""))
6442 tests := []struct {
6443 name string
6444 req *Request
6445 want bool
6446 }{
6447 {
6448 name: "GET",
6449 req: &Request{Method: "GET"},
6450 want: true,
6451 },
6452 {
6453 name: "GET_http.NoBody",
6454 req: &Request{Method: "GET", Body: NoBody},
6455 want: true,
6456 },
6457 {
6458 name: "GET_body",
6459 req: &Request{Method: "GET", Body: someBody},
6460 want: false,
6461 },
6462 {
6463 name: "POST",
6464 req: &Request{Method: "POST"},
6465 want: false,
6466 },
6467 {
6468 name: "POST_idempotency-key",
6469 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6470 want: true,
6471 },
6472 {
6473 name: "POST_x-idempotency-key",
6474 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6475 want: true,
6476 },
6477 {
6478 name: "POST_body",
6479 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6480 want: false,
6481 },
6482 }
6483 for _, tt := range tests {
6484 t.Run(tt.name, func(t *testing.T) {
6485 got := tt.req.ExportIsReplayable()
6486 if got != tt.want {
6487 t.Errorf("replyable = %v; want %v", got, tt.want)
6488 }
6489 })
6490 }
6491 }
6492
6493
6494
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method (the io.ReaderFrom fast path) was ever called.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom sets ReadFromCalled and then delegates to the wrapped TCPConn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	n, err := m.TCPConn.ReadFrom(src)
	return n, err
}
6505
6506 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6507 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6508 nBytes := int64(1 << 10)
6509 newFileFunc := func() (r io.Reader, done func(), err error) {
6510 f, err := os.CreateTemp("", "net-http-newfilefunc")
6511 if err != nil {
6512 return nil, nil, err
6513 }
6514
6515
6516 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6517 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6518 }
6519 if _, err := f.Seek(0, 0); err != nil {
6520 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6521 }
6522
6523 done = func() {
6524 f.Close()
6525 os.Remove(f.Name())
6526 }
6527
6528 return f, done, nil
6529 }
6530
6531 newBufferFunc := func() (io.Reader, func(), error) {
6532 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6533 }
6534
6535 cases := []struct {
6536 name string
6537 readerFunc func() (io.Reader, func(), error)
6538 contentLength int64
6539 expectedReadFrom bool
6540 }{
6541 {
6542 name: "file, length",
6543 readerFunc: newFileFunc,
6544 contentLength: nBytes,
6545 expectedReadFrom: true,
6546 },
6547 {
6548 name: "file, no length",
6549 readerFunc: newFileFunc,
6550 },
6551 {
6552 name: "file, negative length",
6553 readerFunc: newFileFunc,
6554 contentLength: -1,
6555 },
6556 {
6557 name: "buffer",
6558 contentLength: nBytes,
6559 readerFunc: newBufferFunc,
6560 },
6561 {
6562 name: "buffer, no length",
6563 readerFunc: newBufferFunc,
6564 },
6565 {
6566 name: "buffer, length -1",
6567 contentLength: -1,
6568 readerFunc: newBufferFunc,
6569 },
6570 }
6571
6572 for _, tc := range cases {
6573 t.Run(tc.name, func(t *testing.T) {
6574 r, cleanup, err := tc.readerFunc()
6575 if err != nil {
6576 t.Fatal(err)
6577 }
6578 defer cleanup()
6579
6580 tConn := &testMockTCPConn{}
6581 trFunc := func(tr *Transport) {
6582 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6583 var d net.Dialer
6584 conn, err := d.DialContext(ctx, network, addr)
6585 if err != nil {
6586 return nil, err
6587 }
6588
6589 tcpConn, ok := conn.(*net.TCPConn)
6590 if !ok {
6591 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6592 }
6593
6594 tConn.TCPConn = tcpConn
6595 return tConn, nil
6596 }
6597 }
6598
6599 cst := newClientServerTest(
6600 t,
6601 mode,
6602 HandlerFunc(func(w ResponseWriter, r *Request) {
6603 io.Copy(io.Discard, r.Body)
6604 r.Body.Close()
6605 w.WriteHeader(200)
6606 }),
6607 trFunc,
6608 )
6609
6610 req, err := NewRequest("PUT", cst.ts.URL, r)
6611 if err != nil {
6612 t.Fatal(err)
6613 }
6614 req.ContentLength = tc.contentLength
6615 req.Header.Set("Content-Type", "application/octet-stream")
6616 resp, err := cst.c.Do(req)
6617 if err != nil {
6618 t.Fatal(err)
6619 }
6620 defer resp.Body.Close()
6621 if resp.StatusCode != 200 {
6622 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6623 }
6624
6625 expectedReadFrom := tc.expectedReadFrom
6626 if mode != http1Mode {
6627 expectedReadFrom = false
6628 }
6629 if !tConn.ReadFromCalled && expectedReadFrom {
6630 t.Fatalf("did not call ReadFrom")
6631 }
6632
6633 if tConn.ReadFromCalled && !expectedReadFrom {
6634 t.Fatalf("ReadFrom was unexpectedly invoked")
6635 }
6636 })
6637 }
6638 }
6639
// TestTransportClone sets every exported Transport field to a non-zero
// value and checks that Clone copies each of them. It also checks that
// cloning a zero Transport leaves TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	src := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
		Protocols:              &Protocols{},
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	src.Protocols.SetHTTP1(true)
	src.Protocols.SetHTTP2(true)
	dst := src.Clone()

	// Every exported field of the clone must be non-zero.
	v := reflect.ValueOf(dst).Elem()
	typ := v.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		name := typ.Field(i).Name
		if token.IsExported(name) && v.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", name)
		}
	}

	if _, ok := dst.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// Cloning a zero Transport must not create a TLSNextProto map.
	if empty := new(Transport).Clone(); empty.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6698
6699 func TestIs408(t *testing.T) {
6700 tests := []struct {
6701 in string
6702 want bool
6703 }{
6704 {"HTTP/1.0 408", true},
6705 {"HTTP/1.1 408", true},
6706 {"HTTP/1.8 408", true},
6707 {"HTTP/2.0 408", false},
6708 {"HTTP/1.1 408 ", true},
6709 {"HTTP/1.1 40", false},
6710 {"http/1.0 408", false},
6711 {"HTTP/1-1 408", false},
6712 }
6713 for _, tt := range tests {
6714 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6715 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6716 }
6717 }
6718 }
6719
6720 func TestTransportIgnores408(t *testing.T) {
6721 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6722 }
6723 func testTransportIgnores408(t *testing.T, mode testMode) {
6724
6725 defer log.SetOutput(log.Writer())
6726
6727 var logout strings.Builder
6728 log.SetOutput(&logout)
6729
6730 const target = "backend:443"
6731
6732 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6733 nc, _, err := w.(Hijacker).Hijack()
6734 if err != nil {
6735 t.Error(err)
6736 return
6737 }
6738 defer nc.Close()
6739 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6740 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6741 }))
6742 req, err := NewRequest("GET", cst.ts.URL, nil)
6743 if err != nil {
6744 t.Fatal(err)
6745 }
6746 res, err := cst.c.Do(req)
6747 if err != nil {
6748 t.Fatal(err)
6749 }
6750 slurp, err := io.ReadAll(res.Body)
6751 if err != nil {
6752 t.Fatal(err)
6753 }
6754 if err != nil {
6755 t.Fatal(err)
6756 }
6757 if string(slurp) != "ok" {
6758 t.Fatalf("got %q; want ok", slurp)
6759 }
6760
6761 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6762 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6763 if d > 0 {
6764 t.Logf("%v idle conns still present after %v", n, d)
6765 }
6766 return false
6767 }
6768 return true
6769 })
6770 if got := logout.String(); got != "" {
6771 t.Fatalf("expected no log output; got: %s", got)
6772 }
6773 }
6774
6775 func TestInvalidHeaderResponse(t *testing.T) {
6776 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6777 }
6778 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6779 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6780 conn, buf, _ := w.(Hijacker).Hijack()
6781 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6782 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6783 "Content-Type: text/html; charset=utf-8\r\n" +
6784 "Content-Length: 0\r\n" +
6785 "Foo : bar\r\n\r\n"))
6786 buf.Flush()
6787 conn.Close()
6788 }))
6789 res, err := cst.c.Get(cst.ts.URL)
6790 if err != nil {
6791 t.Fatal(err)
6792 }
6793 defer res.Body.Close()
6794 if v := res.Header.Get("Foo"); v != "" {
6795 t.Errorf(`unexpected "Foo" header: %q`, v)
6796 }
6797 if v := res.Header.Get("Foo "); v != "bar" {
6798 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6799 }
6800 }
6801
// bodyCloser is an always-empty io.ReadCloser whose own boolean value
// records whether Close has been called.
type bodyCloser bool

// Close marks the bodyCloser as closed and never fails.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read reports end of stream immediately without touching b.
func (bc *bodyCloser) Read(_ []byte) (int, error) {
	return 0, io.EOF
}
6811
6812
6813
6814 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6815 run(t, testTransportClosesBodyOnInvalidRequests)
6816 }
6817 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6818 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6819 t.Errorf("Should not have been invoked")
6820 })).ts
6821
6822 u, _ := url.Parse(cst.URL)
6823
6824 tests := []struct {
6825 name string
6826 req *Request
6827 wantErr string
6828 }{
6829 {
6830 name: "invalid method",
6831 req: &Request{
6832 Method: " ",
6833 URL: u,
6834 },
6835 wantErr: `invalid method " "`,
6836 },
6837 {
6838 name: "nil URL",
6839 req: &Request{
6840 Method: "GET",
6841 },
6842 wantErr: `nil Request.URL`,
6843 },
6844 {
6845 name: "invalid header key",
6846 req: &Request{
6847 Method: "GET",
6848 Header: Header{"💡": {"emoji"}},
6849 URL: u,
6850 },
6851 wantErr: `invalid header field name "💡"`,
6852 },
6853 {
6854 name: "invalid header value",
6855 req: &Request{
6856 Method: "POST",
6857 Header: Header{"key": {"\x19"}},
6858 URL: u,
6859 },
6860 wantErr: `invalid header field value for "key"`,
6861 },
6862 {
6863 name: "non HTTP(s) scheme",
6864 req: &Request{
6865 Method: "POST",
6866 URL: &url.URL{Scheme: "faux"},
6867 },
6868 wantErr: `unsupported protocol scheme "faux"`,
6869 },
6870 {
6871 name: "no Host in URL",
6872 req: &Request{
6873 Method: "POST",
6874 URL: &url.URL{Scheme: "http"},
6875 },
6876 wantErr: `no Host in request URL`,
6877 },
6878 }
6879
6880 for _, tt := range tests {
6881 t.Run(tt.name, func(t *testing.T) {
6882 var bc bodyCloser
6883 req := tt.req
6884 req.Body = &bc
6885 _, err := cst.Client().Do(tt.req)
6886 if err == nil {
6887 t.Fatal("Expected an error")
6888 }
6889 if !bc {
6890 t.Fatal("Expected body to have been closed")
6891 }
6892 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6893 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6894 }
6895 })
6896 }
6897 }
6898
6899
6900
// breakableConn is a net.Conn whose writes start failing once the shared
// brokenState is marked broken.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag that one or more breakableConns
// share.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write returns an error without writing anything while the shared state
// is broken. Otherwise it writes to the underlying conn. The state lock is
// held for the duration of the call.
func (bc *breakableConn) Write(b []byte) (n int, err error) {
	bc.Lock()
	defer bc.Unlock()
	if !bc.broken {
		return bc.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6919
6920
// TestDontCacheBrokenHTTP2Conn checks that HTTP/2 connections whose writes
// fail are not reused. Every request must dial its own connection, and only
// one request may obtain a connection at all.
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
}
func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)

	// Shared by every dialed conn: while broken is set, all their Writes
	// fail.
	var brokenState brokenState

	const numReqs = 5
	var numDials, gotConns uint32 // accessed atomically

	// Count dials and wrap each new conn so that brokenState controls
	// whether its writes fail.
	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
		atomic.AddUint32(&numDials, 1)
		c, err := net.Dial(netw, addr)
		if err != nil {
			t.Errorf("unexpected Dial error: %v", err)
			return nil, err
		}
		return &breakableConn{c, &brokenState}, err
	}

	for i := 1; i <= numReqs; i++ {
		brokenState.Lock()
		brokenState.broken = false
		brokenState.Unlock()

		// Every request except the last breaks its conn; the last one must
		// succeed.
		doBreak := i != numReqs

		// GotConn counts the requests that obtained a conn. TLSHandshakeDone
		// breaks writes right after the handshake, so this request's
		// subsequent writes fail.
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
				atomic.AddUint32(&gotConns, 1)
			},
			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
				brokenState.Lock()
				defer brokenState.Unlock()
				if doBreak {
					brokenState.broken = true
				}
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		// Broken iterations must fail and the final one must succeed.
		_, err = cst.c.Do(req)
		if doBreak != (err != nil) {
			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
		}
	}
	// GotConn must fire exactly once (NOTE(review): presumably only for the
	// final, unbroken request). There must also be one dial per request,
	// i.e. no broken conn was reused.
	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
		t.Errorf("GotConn calls = %v; want %v", got, want)
	}
	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
		t.Errorf("Dials = %v; want %v", got, want)
	}
}
6981
6982
6983
6984
6985
6986 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6987 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6988 }
6989 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6990 CondSkipHTTP2(t)
6991
6992 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6993 _, err := w.Write([]byte("foo"))
6994 if err != nil {
6995 t.Fatalf("Write: %v", err)
6996 }
6997 })
6998
6999 ts := newClientServerTest(t, mode, h).ts
7000
7001 c := ts.Client()
7002 tr := c.Transport.(*Transport)
7003 tr.MaxConnsPerHost = 1
7004
7005 errCh := make(chan error, 300)
7006 doReq := func() {
7007 resp, err := c.Get(ts.URL)
7008 if err != nil {
7009 errCh <- fmt.Errorf("request failed: %v", err)
7010 return
7011 }
7012 defer resp.Body.Close()
7013 _, err = io.ReadAll(resp.Body)
7014 if err != nil {
7015 errCh <- fmt.Errorf("read body failed: %v", err)
7016 }
7017 }
7018
7019 var wg sync.WaitGroup
7020 for i := 0; i < 300; i++ {
7021 wg.Add(1)
7022 go func() {
7023 defer wg.Done()
7024 doReq()
7025 }()
7026 }
7027 wg.Wait()
7028 close(errCh)
7029
7030 for err := range errCh {
7031 t.Errorf("error occurred: %v", err)
7032 }
7033 }
7034
7035
7036
7037
7038 func TestAltProtoCancellation(t *testing.T) {
7039 defer afterTest(t)
7040 tr := &Transport{}
7041 c := &Client{
7042 Transport: tr,
7043 Timeout: time.Millisecond,
7044 }
7045 tr.RegisterProtocol("cancel", cancelProto{})
7046 _, err := c.Get("cancel://bar.com/path")
7047 if err == nil {
7048 t.Error("request unexpectedly succeeded")
7049 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
7050 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
7051 }
7052 }
7053
// errCancelProto is returned by cancelProto once its request is canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request on its own.
type cancelProto struct{}

// RoundTrip blocks until req.Cancel is closed (or yields a value) and then
// fails with errCancelProto.
func (cancelProto) RoundTrip(req *Request) (*Response, error) {
	select {
	case <-req.Cancel:
	}
	return nil, errCancelProto
}
7062
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(*Request) (*Response, error)

// RoundTrip calls f with r and returns f's results unchanged.
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	res, err := f(r)
	return res, err
}
7066
7067
7068 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
7069 func testIssue32441(t *testing.T, mode testMode) {
7070 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7071 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7072 t.Error("body length is zero")
7073 }
7074 })).ts
7075 c := ts.Client()
7076 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
7077
7078 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7079 t.Error("body length is zero during round trip")
7080 }
7081 return nil, ErrSkipAltProtocol
7082 }))
7083 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
7084 t.Error(err)
7085 }
7086 }
7087
7088
7089
7090 func TestTransportRejectsSignInContentLength(t *testing.T) {
7091 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
7092 }
7093 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
7094 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7095 w.Header().Set("Content-Length", "+3")
7096 w.Write([]byte("abc"))
7097 })).ts
7098
7099 c := cst.Client()
7100 res, err := c.Get(cst.URL)
7101 if err == nil || res != nil {
7102 t.Fatal("Expected a non-nil error and a nil http.Response")
7103 }
7104 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
7105 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
7106 }
7107 }
7108
7109
7110 type dumpConn struct {
7111 io.Writer
7112 io.Reader
7113 }
7114
7115 func (c *dumpConn) Close() error { return nil }
7116 func (c *dumpConn) LocalAddr() net.Addr { return nil }
7117 func (c *dumpConn) RemoteAddr() net.Addr { return nil }
7118 func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
7119 func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
7120 func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
7121
7122
7123
7124 type delegateReader struct {
7125 c chan io.Reader
7126 r io.Reader
7127 }
7128
7129 func (r *delegateReader) Read(p []byte) (int, error) {
7130 if r.r == nil {
7131 var ok bool
7132 if r.r, ok = <-r.c; !ok {
7133 return 0, errors.New("delegate closed")
7134 }
7135 }
7136 return r.r.Read(p)
7137 }
7138
// testTransportRace sends req through a fresh Transport whose only
// connection is an in-memory dumpConn. The request bytes go into a pipe
// that a helper goroutine reads, playing the server. The response is
// supplied lazily through a delegateReader. req.Body is restored to its
// original value before returning. Any RoundTrip result is ignored.
func testTransportRace(req *Request) {
	origBody := req.Body
	reqReader, reqWriter := io.Pipe()
	defer reqReader.Close()
	defer reqWriter.Close()
	respSource := &delegateReader{c: make(chan io.Reader)}

	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return &dumpConn{reqWriter, respSource}, nil
		},
	}
	defer tr.CloseIdleConnections()

	serverDone := make(chan struct{})

	go func() {
		defer close(serverDone)

		// Play the server: if a request arrives, drain its body completely
		// before any response is offered.
		if sreq, err := ReadRequest(bufio.NewReader(reqReader)); err == nil {
			io.Copy(io.Discard, sreq.Body)
			sreq.Body.Close()
		}
		select {
		case respSource.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case serverDone <- struct{}{}:
			// RoundTrip already returned without taking the response.
			// Closing the delegate makes any later Read fail instead of
			// blocking forever.
			close(respSource.c)
		}
	}()

	tr.RoundTrip(req)

	// Unblock the server goroutine if it is still inside ReadRequest,
	// then wait for it to finish.
	reqWriter.Close()
	<-serverDone

	req.Body = origBody
}
7182
7183
7184
7185
7186
// TestErrorWriteLoopRace runs testTransportRace 1000 times with POST
// requests whose contexts time out after a random 0-4ms, so that request
// cancellation races with the transport writing the 10000-byte body.
func TestErrorWriteLoopRace(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	t.Parallel()
	for range 1000 {
		func() {
			delay := time.Duration(mrand.Intn(5)) * time.Millisecond
			ctx, cancel := context.WithTimeout(context.Background(), delay)
			// Deferred inside a per-iteration closure so each context is
			// released when its iteration ends. The previous version left
			// 1000 deferred cancels queued until the test returned.
			defer cancel()

			r := bytes.NewBuffer(make([]byte, 10000))
			req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
			if err != nil {
				t.Fatal(err)
			}

			testTransportRace(req)
		}()
	}
}
7206
7207
7208
7209
// TestCancelRequestWhenSharingConnection cancels a second request in this
// situation: the request has reached the handler using the client's single
// allowed connection, and the first request's PutIdleConn trace hook is
// still blocked. The canceled request must fail with context.Canceled.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each request that reaches the handler publishes a channel on reqc and
	// blocks until the test closes that channel.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	// Allow at most one connection to the server, so both requests below
	// have to use the same one.
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	// Request 1: its PutIdleConn hook hands the test a channel and blocks
	// until that channel is closed.
	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Report that the connection is being returned to the
				// pool, then wait for the test to release it. Closing
				// putidlec makes a second PutIdleConn call panic (send on
				// a closed channel).
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler finish, then wait until request 1's
	// connection reaches the PutIdleConn hook (idlec is the channel the
	// hook is blocked on).
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	// Request 2: expected to be canceled while it is in the handler.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}
		// Request 2 is finished, so release request 1's PutIdleConn
		// hook.
		close(idlec)
	}()

	// Wait for request 2 to reach the handler, then cancel it while the
	// handler is still blocked.
	r2c := <-reqc
	cancel()
	// Wait until request 2's goroutine has seen the result of client.Do.
	<-idlec
	// Now let the handler for request 2 return.
	close(r2c)
	wg.Wait()
}
7295
// TestHandlerAbortRacesBodyRead has the handler abort with ErrAbortHandler
// while a goroutine it started is still reading the request body. It runs
// several large concurrent uploads against that handler so the abort races
// with the body read. Client-side errors are ignored.
func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		go io.Copy(io.Discard, req.Body)
		panic(ErrAbortHandler)
	})).ts

	const (
		workers       = 2
		reqsPerWorker = 10
		bodyLen       = 6 * 1024 * 1024
	)
	var wg sync.WaitGroup
	wg.Add(workers)
	for range workers {
		go func() {
			defer wg.Done()
			for range reqsPerWorker {
				body := &io.LimitedReader{R: neverEnding('x'), N: bodyLen}
				req, _ := NewRequest("POST", ts.URL, body)
				req.ContentLength = bodyLen
				if resp, _ := ts.Client().Transport.RoundTrip(req); resp != nil {
					resp.Body.Close()
				}
			}
		}()
	}
	wg.Wait()
}
7321
7322 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7323 func testRequestSanitization(t *testing.T, mode testMode) {
7324 if mode == http2Mode {
7325
7326 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7327 }
7328 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7329 if h, ok := req.Header["X-Evil"]; ok {
7330 t.Errorf("request has X-Evil header: %q", h)
7331 }
7332 })).ts
7333 req, _ := NewRequest("GET", ts.URL, nil)
7334 req.Host = "go.dev\r\nX-Evil:evil"
7335 resp, _ := ts.Client().Do(req)
7336 if resp != nil {
7337 resp.Body.Close()
7338 }
7339 }
7340
// TestProxyAuthHeader checks that credentials embedded in the proxy URL,
// including a password made of characters that need escaping, reach the
// proxy as a Basic Proxy-Authorization header.
func TestProxyAuthHeader(t *testing.T) {
	// Not parallel: the test sets an environment variable.
	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
}
func testProxyAuthHeader(t *testing.T, mode testMode) {
	const (
		username = "u"
		password = "@/?!"
	)
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// Request.BasicAuth decodes only the Authorization header, so copy
		// the Proxy-Authorization value into a scratch request first.
		scratch := Request{Header: Header{
			"Authorization": req.Header["Proxy-Authorization"],
		}}
		user, pass, ok := scratch.BasicAuth()
		if !ok || user != username || pass != password {
			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", user, pass, ok, username, password)
		}
	}))
	proxyURL, err := url.Parse(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	proxyURL.User = url.UserPassword(username, password)
	t.Setenv("HTTP_PROXY", proxyURL.String())
	cst.tr.Proxy = ProxyURL(proxyURL)
	resp, err := cst.c.Get("http://_/")
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()
}
7373
7374
// TestTransportReqCancelerCleanupOnRequestBodyWriteError covers this case:
// a response arrives, and then the connection is closed while the client is
// still writing a very large request body. The test checks that the
// transport's count of pending requests eventually drops to zero.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Close the connection on every exit path. Previously the early
		// return below leaked it; on the normal path it still closes only
		// after done is closed.
		defer conn.Close()

		// Wait until the request starts arriving before responding.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		// Hold the connection open until the client has its response.
		<-done
	}()

	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// The server reads only one byte, so this 1GiB body cannot be written
	// in full before the server closes the connection.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	close(done)

	// Wait until the transport's read loop reaches its before-next-read
	// hook.
	<-didRead

	resp.Body.Close()

	// Poll until the transport reports no pending requests.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7435
7436 func TestValidateClientRequestTrailers(t *testing.T) {
7437 run(t, testValidateClientRequestTrailers)
7438 }
7439
7440 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7441 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7442 rw.Write([]byte("Hello"))
7443 })).ts
7444
7445 cases := []struct {
7446 trailer Header
7447 wantErr string
7448 }{
7449 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7450 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7451 }
7452
7453 for i, tt := range cases {
7454 testName := fmt.Sprintf("%s%d", mode, i)
7455 t.Run(testName, func(t *testing.T) {
7456 req, err := NewRequest("GET", cst.URL, nil)
7457 if err != nil {
7458 t.Fatal(err)
7459 }
7460 req.Trailer = tt.trailer
7461 res, err := cst.Client().Do(req)
7462 if err == nil {
7463 t.Fatal("Expected an error")
7464 }
7465 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7466 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7467 }
7468 if res != nil {
7469 t.Fatal("Unexpected non-nil response")
7470 }
7471 })
7472 }
7473 }
7474
// TestTransportServerProtocols checks which protocol (HTTP/1.1 or HTTP/2)
// the client and server end up using under various combinations of
// Transport and Server Protocols, TLSNextProto, ForceAttemptHTTP2 and
// GODEBUG settings. A want of "error" means the request must fail.
func TestTransportServerProtocols(t *testing.T) {
	CondSkipHTTP2(t)
	DefaultTransport.(*Transport).CloseIdleConnections()

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	leafCert, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	certpool := x509.NewCertPool()
	certpool.AddCert(leafCert)

	for _, test := range []struct {
		name      string
		scheme    string
		setup     func(t *testing.T)
		transport func(*Transport)
		server    func(*Server)
		want      string
	}{{
		name:   "http default",
		scheme: "http",
		want:   "HTTP/1.1",
	}, {
		name:   "https default",
		scheme: "https",
		transport: func(tr *Transport) {
			// Leave the Transport's protocol settings untouched. Without
			// a transport func, the harness below enables HTTP/1 and
			// HTTP/2 explicitly.
		},
		want: "HTTP/1.1",
	}, {
		name:   "https transport protocols include HTTP2",
		scheme: "https",
		transport: func(tr *Transport) {
			// Explicitly enable both protocols via Transport.Protocols.
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
			tr.Protocols.SetHTTP2(true)
		},
		want: "HTTP/2.0",
	}, {
		name:   "https transport protocols only include HTTP1",
		scheme: "https",
		transport: func(tr *Transport) {
			// Enable HTTP/1 only.
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
		},
		want: "HTTP/1.1",
	}, {
		name:   "https transport ForceAttemptHTTP2",
		scheme: "https",
		transport: func(tr *Transport) {
			// Ask for HTTP/2 through ForceAttemptHTTP2, without Protocols.
			tr.ForceAttemptHTTP2 = true
		},
		want: "HTTP/2.0",
	}, {
		name:   "https transport protocols override TLSNextProto",
		scheme: "https",
		transport: func(tr *Transport) {
			// An empty, non-nil TLSNextProto normally disables HTTP/2.
			// Explicitly configured Protocols take precedence over it.
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
			tr.Protocols.SetHTTP2(true)
			tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
		},
		want: "HTTP/2.0",
	}, {
		name:   "https server disables HTTP2 with TLSNextProto",
		scheme: "https",
		server: func(srv *Server) {
			// An empty, non-nil TLSNextProto disables HTTP/2 on the server.
			srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
		},
		want: "HTTP/1.1",
	}, {
		name:   "https server Protocols overrides empty TLSNextProto",
		scheme: "https",
		server: func(srv *Server) {
			// Explicitly configured Protocols take precedence over an
			// empty, non-nil TLSNextProto.
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP1(true)
			srv.Protocols.SetHTTP2(true)
			srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
		},
		want: "HTTP/2.0",
	}, {
		name:   "https server protocols only include HTTP1",
		scheme: "https",
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP1(true)
		},
		want: "HTTP/1.1",
	}, {
		name:   "https server protocols include HTTP2",
		scheme: "https",
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP1(true)
			srv.Protocols.SetHTTP2(true)
		},
		want: "HTTP/2.0",
	}, {
		name:   "GODEBUG disables HTTP2 client",
		scheme: "https",
		setup: func(t *testing.T) {
			t.Setenv("GODEBUG", "http2client=0")
		},
		transport: func(tr *Transport) {
			// HTTP/2 is requested explicitly, so only the GODEBUG setting
			// can turn it off.
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
			tr.Protocols.SetHTTP2(true)
		},
		want: "HTTP/1.1",
	}, {
		name:   "GODEBUG disables HTTP2 server",
		scheme: "https",
		setup: func(t *testing.T) {
			t.Setenv("GODEBUG", "http2server=0")
		},
		transport: func(tr *Transport) {
			// HTTP/2 is requested explicitly, so only the GODEBUG setting
			// can turn it off.
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
			tr.Protocols.SetHTTP2(true)
		},
		want: "HTTP/1.1",
	}, {
		name:   "unencrypted HTTP2 with prior knowledge",
		scheme: "http",
		transport: func(tr *Transport) {
			tr.Protocols = &Protocols{}
			tr.Protocols.SetUnencryptedHTTP2(true)
		},
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP1(true)
			srv.Protocols.SetUnencryptedHTTP2(true)
		},
		want: "HTTP/2.0",
	}, {
		name:   "unencrypted HTTP2 only on server",
		scheme: "http",
		transport: func(tr *Transport) {
			tr.Protocols = &Protocols{}
			tr.Protocols.SetUnencryptedHTTP2(true)
		},
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetUnencryptedHTTP2(true)
		},
		want: "HTTP/2.0",
	}, {
		name:   "unencrypted HTTP2 with no server support",
		scheme: "http",
		transport: func(tr *Transport) {
			tr.Protocols = &Protocols{}
			tr.Protocols.SetUnencryptedHTTP2(true)
		},
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP1(true)
		},
		want: "error",
	}, {
		name:   "HTTP1 with no server support",
		scheme: "http",
		transport: func(tr *Transport) {
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
		},
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetUnencryptedHTTP2(true)
		},
		want: "error",
	}, {
		name:   "HTTPS1 with no server support",
		scheme: "https",
		transport: func(tr *Transport) {
			tr.Protocols = &Protocols{}
			tr.Protocols.SetHTTP1(true)
		},
		server: func(srv *Server) {
			srv.Protocols = &Protocols{}
			srv.Protocols.SetHTTP2(true)
		},
		want: "error",
	}} {
		t.Run(test.name, func(t *testing.T) {
			// Build the server and transport by hand so each case controls
			// their configuration completely. The handler echoes the
			// protocol it saw in the X-Proto response header.
			srv := &Server{
				TLSConfig: &tls.Config{
					Certificates: []tls.Certificate{cert},
				},
				Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
					w.Header().Set("X-Proto", req.Proto)
				}),
			}
			tr := &Transport{
				TLSClientConfig: &tls.Config{
					RootCAs: certpool,
				},
			}

			if test.setup != nil {
				test.setup(t)
			}
			if test.server != nil {
				test.server(srv)
			}
			if test.transport != nil {
				test.transport(tr)
			} else {
				// By default the client enables both HTTP/1 and HTTP/2.
				tr.Protocols = &Protocols{}
				tr.Protocols.SetHTTP1(true)
				tr.Protocols.SetHTTP2(true)
			}

			listener := newLocalListener(t)
			srvc := make(chan error, 1)
			go func() {
				switch test.scheme {
				case "http":
					srvc <- srv.Serve(listener)
				case "https":
					srvc <- srv.ServeTLS(listener, "", "")
				}
			}()
			t.Cleanup(func() {
				srv.Close()
				<-srvc
			})

			client := &Client{Transport: tr}
			// Release the client's idle connections when the subtest ends.
			t.Cleanup(tr.CloseIdleConnections)
			resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
			if err != nil {
				if test.want == "error" {
					return
				}
				t.Fatal(err)
			}
			// Close the body so the response's connection is not leaked.
			defer resp.Body.Close()
			got := resp.Header.Get("X-Proto")
			if test.want == "error" {
				t.Fatalf("request succeeded with proto %q, want error", got)
			}
			if got != test.want {
				t.Fatalf("request proto %q, want %q", got, test.want)
			}
		})
	}
}
7737
// TestIssue61474 sends a very large number of concurrent HTTP/2 requests,
// each with a 1ms timeout, through a transport limited to one connection
// per host. It is currently always skipped: it returns immediately in
// short mode and calls t.Skip otherwise.
func TestIssue61474(t *testing.T) {
	run(t, testIssue61474, []testMode{http2Mode})
}
func testIssue61474(t *testing.T, mode testMode) {
	if testing.Short() {
		return
	}
	// The body below is kept for manual runs but is skipped by default
	// because it is too expensive.
	t.Skip("test is too large")

	const numRequests = 100000
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {}),
		func(tr *Transport) { tr.MaxConnsPerHost = 1 })

	var wg sync.WaitGroup
	defer wg.Wait()
	for range numRequests {
		wg.Go(func() {
			ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
			defer cancel()
			req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
			if resp, err := cst.c.Do(req); err == nil {
				resp.Body.Close()
			}
		})
	}
}
7769
View as plain text