Source file
src/net/http/clientserver_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "compress/gzip"
12 "context"
13 "crypto/rand"
14 "crypto/sha1"
15 "crypto/tls"
16 "fmt"
17 "hash"
18 "io"
19 "log"
20 "maps"
21 "net"
22 . "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/http/httputil"
26 "net/textproto"
27 "net/url"
28 "os"
29 "reflect"
30 "runtime"
31 "slices"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "testing"
36 "testing/synctest"
37 "time"
38 )
39
// testMode selects the protocol and transport security used by a
// client/server test.
type testMode string

const (
	http1Mode            testMode = "h1"            // HTTP/1 without TLS
	https1Mode           testMode = "https1"        // HTTP/1 over TLS
	http2Mode            testMode = "h2"            // HTTP/2 over TLS
	http2UnencryptedMode testMode = "h2unencrypted" // HTTP/2 without TLS
)

// Scheme returns the URL scheme for a server running in mode m: "https" for
// the TLS modes and "http" for the others. It panics on an unrecognized mode.
func (m testMode) Scheme() string {
	switch m {
	case https1Mode, http2Mode:
		return "https"
	case http1Mode, http2UnencryptedMode:
		return "http"
	default:
		panic("unknown testMode")
	}
}
58
// testNotParallelOpt is the type of testNotParallel. Passing testNotParallel
// to run stops it from calling setParallel on the test and its per-mode
// subtests.
type testNotParallelOpt struct{}

var testNotParallel = testNotParallelOpt{}

// TBRun is a testing.TB whose Run method passes the subtest the same
// concrete type T (for example *testing.T). run is generic over it.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
69
70
71
72
73
74
75
76
77 func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
78 t.Helper()
79 modes := []testMode{http1Mode, http2Mode}
80 parallel := true
81 for _, opt := range opts {
82 switch opt := opt.(type) {
83 case []testMode:
84 modes = opt
85 case testNotParallelOpt:
86 parallel = false
87 default:
88 t.Fatalf("unknown option type %T", opt)
89 }
90 }
91 if t, ok := any(t).(*testing.T); ok && parallel {
92 setParallel(t)
93 }
94 for _, mode := range modes {
95 t.Run(string(mode), func(t T) {
96 t.Helper()
97 if t, ok := any(t).(*testing.T); ok && parallel {
98 setParallel(t)
99 }
100 t.Cleanup(func() {
101 afterTest(t)
102 })
103 f(t, mode)
104 })
105 }
106 }
107
108
109
110
111 func runSynctest(t *testing.T, f func(t *testing.T, mode testMode), opts ...any) {
112 run(t, func(t *testing.T, mode testMode) {
113 synctest.Test(t, func(t *testing.T) {
114 f(t, mode)
115 })
116 }, opts...)
117 }
118
119 type clientServerTest struct {
120 t testing.TB
121 h2 bool
122 h Handler
123 ts *httptest.Server
124 tr *Transport
125 c *Client
126 li *fakeNetListener
127 }
128
129 func (t *clientServerTest) close() {
130 t.tr.CloseIdleConnections()
131 t.ts.Close()
132 }
133
134 func (t *clientServerTest) getURL(u string) string {
135 res, err := t.c.Get(u)
136 if err != nil {
137 t.t.Fatal(err)
138 }
139 defer res.Body.Close()
140 slurp, err := io.ReadAll(res.Body)
141 if err != nil {
142 t.t.Fatal(err)
143 }
144 return string(slurp)
145 }
146
147 func (t *clientServerTest) scheme() string {
148 if t.h2 {
149 return "https"
150 }
151 return "http"
152 }
153
154 var optQuietLog = func(ts *httptest.Server) {
155 ts.Config.ErrorLog = quietLog
156 }
157
158 func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
159 return func(ts *httptest.Server) {
160 ts.Config.ErrorLog = lg
161 }
162 }
163
// fakeNetOpt is the type of optFakeNet. It is a distinct named type so the
// option is recognized by its type, not by pointer identity. Pointers to
// distinct zero-size variables (such as the result of new(struct{})) may
// compare equal, so a *struct{} sentinel could match an unrelated zero-size
// pointer passed as an option.
type fakeNetOpt struct{}

// optFakeNet is a newClientServerTest option that runs the server on a fake
// network listener and points the client Transport's dialer at it.
var optFakeNet = fakeNetOpt{}
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
181 if mode == http2Mode || mode == http2UnencryptedMode {
182 CondSkipHTTP2(t)
183 }
184 cst := &clientServerTest{
185 t: t,
186 h2: mode == http2Mode,
187 h: h,
188 }
189
190 var transportFuncs []func(*Transport)
191
192 if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
193 opts = slices.Delete(opts, idx, idx+1)
194 cst.li = fakeNetListen()
195 cst.ts = &httptest.Server{
196 Config: &Server{Handler: h},
197 Listener: cst.li,
198 }
199 transportFuncs = append(transportFuncs, func(tr *Transport) {
200 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
201 return cst.li.connect(), nil
202 }
203 })
204 } else {
205 cst.ts = httptest.NewUnstartedServer(h)
206 }
207
208 for _, opt := range opts {
209 switch opt := opt.(type) {
210 case func(*Transport):
211 transportFuncs = append(transportFuncs, opt)
212 case func(*httptest.Server):
213 opt(cst.ts)
214 case func(*Server):
215 opt(cst.ts.Config)
216 default:
217 t.Fatalf("unhandled option type %T", opt)
218 }
219 }
220
221 if cst.ts.Config.ErrorLog == nil {
222 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
223 }
224
225 p := &Protocols{}
226 if cst.ts.Config.Protocols == nil {
227 cst.ts.Config.Protocols = p
228 }
229 switch mode {
230 case http1Mode:
231 p.SetHTTP1(true)
232 cst.ts.Start()
233 case https1Mode:
234 p.SetHTTP1(true)
235 cst.ts.StartTLS()
236 case http2UnencryptedMode:
237 p.SetUnencryptedHTTP2(true)
238 cst.ts.Start()
239 case http2Mode:
240 p.SetHTTP2(true)
241 cst.ts.EnableHTTP2 = true
242 cst.ts.TLS = cst.ts.Config.TLSConfig
243 cst.ts.StartTLS()
244 default:
245 t.Fatalf("unknown test mode %v", mode)
246 }
247 cst.c = cst.ts.Client()
248 cst.tr = cst.c.Transport.(*Transport)
249 for _, f := range transportFuncs {
250 f(cst.tr)
251 }
252 if cst.tr.Protocols == nil {
253 cst.tr.Protocols = p
254 }
255
256 t.Cleanup(func() {
257 cst.close()
258 })
259 return cst
260 }
261
// testLogWriter is an io.Writer that forwards each write to t.Logf with a
// "server log: " prefix. newClientServerTest uses it as the default server
// error log.
type testLogWriter struct {
	t testing.TB
}

// Write logs b with surrounding whitespace trimmed and always reports that
// all of b was written.
func (w testLogWriter) Write(b []byte) (int, error) {
	msg := strings.TrimSpace(string(b))
	w.t.Logf("server log: %v", msg)
	return len(b), nil
}
270
271
272 func TestNewClientServerTest(t *testing.T) {
273 modes := []testMode{http1Mode, https1Mode, http2Mode}
274 t.Run("realnet", func(t *testing.T) {
275 run(t, func(t *testing.T, mode testMode) {
276 testNewClientServerTest(t, mode)
277 }, modes)
278 })
279 t.Run("synctest", func(t *testing.T) {
280 runSynctest(t, func(t *testing.T, mode testMode) {
281 testNewClientServerTest(t, mode, optFakeNet)
282 }, modes)
283 })
284 }
285 func testNewClientServerTest(t *testing.T, mode testMode, opts ...any) {
286 var got struct {
287 sync.Mutex
288 proto string
289 hasTLS bool
290 }
291 h := HandlerFunc(func(w ResponseWriter, r *Request) {
292 got.Lock()
293 defer got.Unlock()
294 got.proto = r.Proto
295 got.hasTLS = r.TLS != nil
296 })
297 cst := newClientServerTest(t, mode, h, opts...)
298 if _, err := cst.c.Head(cst.ts.URL); err != nil {
299 t.Fatal(err)
300 }
301 var wantProto string
302 var wantTLS bool
303 switch mode {
304 case http1Mode:
305 wantProto = "HTTP/1.1"
306 wantTLS = false
307 case https1Mode:
308 wantProto = "HTTP/1.1"
309 wantTLS = true
310 case http2Mode:
311 wantProto = "HTTP/2.0"
312 wantTLS = true
313 }
314 if got.proto != wantProto {
315 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
316 }
317 if got.hasTLS != wantTLS {
318 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
319 }
320 }
321
322 func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
323 func testChunkedResponseHeaders(t *testing.T, mode testMode) {
324 log.SetOutput(io.Discard)
325 defer log.SetOutput(os.Stderr)
326 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
327 w.Header().Set("Content-Length", "intentional gibberish")
328 w.(Flusher).Flush()
329 fmt.Fprintf(w, "I am a chunked response.")
330 }))
331
332 res, err := cst.c.Get(cst.ts.URL)
333 if err != nil {
334 t.Fatalf("Get error: %v", err)
335 }
336 defer res.Body.Close()
337 if g, e := res.ContentLength, int64(-1); g != e {
338 t.Errorf("expected ContentLength of %d; got %d", e, g)
339 }
340 wantTE := []string{"chunked"}
341 if mode == http2Mode {
342 wantTE = nil
343 }
344 if !slices.Equal(res.TransferEncoding, wantTE) {
345 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
346 }
347 if got, haveCL := res.Header["Content-Length"]; haveCL {
348 t.Errorf("Unexpected Content-Length: %q", got)
349 }
350 }
351
// reqFunc issues a request for url using c, for example (*Client).Get.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare describes a test that runs Handler behind both an HTTP/1 and an
// HTTP/2 server, sends each the same request, and compares the responses.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)
	ReqFunc            reqFunc                          // request to send; nil means GET
	CheckResponse      func(proto string, res *Response) // called after normalization
	EarlyCheckResponse func(proto string, res *Response) // called before normalization
	Opts               []any                            // extra newClientServerTest options
}

// reqFunc returns the request function to use: tt.ReqFunc when set,
// otherwise (*Client).Get.
func (tt h12Compare) reqFunc() reqFunc {
	if fn := tt.ReqFunc; fn != nil {
		return fn
	}
	return (*Client).Get
}
370
371 func (tt h12Compare) run(t *testing.T) {
372 setParallel(t)
373 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
374 defer cst1.close()
375 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
376 defer cst2.close()
377
378 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
379 if err != nil {
380 t.Errorf("HTTP/1 request: %v", err)
381 return
382 }
383 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
384 if err != nil {
385 t.Errorf("HTTP/2 request: %v", err)
386 return
387 }
388
389 if fn := tt.EarlyCheckResponse; fn != nil {
390 fn("HTTP/1.1", res1)
391 fn("HTTP/2.0", res2)
392 }
393
394 tt.normalizeRes(t, res1, "HTTP/1.1")
395 tt.normalizeRes(t, res2, "HTTP/2.0")
396 res1body, res2body := res1.Body, res2.Body
397
398 eres1 := mostlyCopy(res1)
399 eres2 := mostlyCopy(res2)
400 if !reflect.DeepEqual(eres1, eres2) {
401 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
402 cst1.ts.URL, eres1, cst2.ts.URL, eres2)
403 }
404 if !reflect.DeepEqual(res1body, res2body) {
405 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
406 }
407 if fn := tt.CheckResponse; fn != nil {
408 res1.Body, res2.Body = res1body, res2body
409 fn("HTTP/1.1", res1)
410 fn("HTTP/2.0", res2)
411 }
412 }
413
// mostlyCopy returns a shallow copy of r with Body, TransferEncoding, TLS and
// Request cleared. h12Compare.run uses it so reflect.DeepEqual compares only
// the remaining fields of the HTTP/1 and HTTP/2 responses. r itself is not
// modified.
func mostlyCopy(r *Response) *Response {
	cp := new(Response)
	*cp = *r
	cp.Body, cp.TransferEncoding, cp.TLS, cp.Request = nil, nil, nil, nil
	return cp
}
422
// slurpResult replaces a response body after it has been read in full. It
// still reads as the same bytes, and it records them along with the error
// the read returned.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String formats the captured body and read error for failure messages.
func (sr slurpResult) String() string {
	const format = "body %q; err %v"
	return fmt.Sprintf(format, sr.body, sr.err)
}
430
431 func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
432 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
433 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
434 } else {
435 t.Errorf("got %q response; want %q", res.Proto, wantProto)
436 }
437 slurp, err := io.ReadAll(res.Body)
438
439 res.Body.Close()
440 res.Body = slurpResult{
441 ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
442 body: slurp,
443 err: err,
444 }
445 for i, v := range res.Header["Date"] {
446 res.Header["Date"][i] = strings.Repeat("x", len(v))
447 }
448 if res.Request == nil {
449 t.Errorf("for %s, no request", wantProto)
450 }
451 if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
452 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
453 }
454 }
455
456
457 func TestH12_HeadContentLengthNoBody(t *testing.T) {
458 h12Compare{
459 ReqFunc: (*Client).Head,
460 Handler: func(w ResponseWriter, r *Request) {
461 },
462 }.run(t)
463 }
464
465 func TestH12_HeadContentLengthSmallBody(t *testing.T) {
466 h12Compare{
467 ReqFunc: (*Client).Head,
468 Handler: func(w ResponseWriter, r *Request) {
469 io.WriteString(w, "small")
470 },
471 }.run(t)
472 }
473
474 func TestH12_HeadContentLengthLargeBody(t *testing.T) {
475 h12Compare{
476 ReqFunc: (*Client).Head,
477 Handler: func(w ResponseWriter, r *Request) {
478 chunk := strings.Repeat("x", 512<<10)
479 for i := 0; i < 10; i++ {
480 io.WriteString(w, chunk)
481 }
482 },
483 }.run(t)
484 }
485
486 func TestH12_200NoBody(t *testing.T) {
487 h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
488 }
489
490 func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
491 func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
492 func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
493
494 func testH12_noBody(t *testing.T, status int) {
495 h12Compare{Handler: func(w ResponseWriter, r *Request) {
496 w.WriteHeader(status)
497 }}.run(t)
498 }
499
500 func TestH12_SmallBody(t *testing.T) {
501 h12Compare{Handler: func(w ResponseWriter, r *Request) {
502 io.WriteString(w, "small body")
503 }}.run(t)
504 }
505
506 func TestH12_ExplicitContentLength(t *testing.T) {
507 h12Compare{Handler: func(w ResponseWriter, r *Request) {
508 w.Header().Set("Content-Length", "3")
509 io.WriteString(w, "foo")
510 }}.run(t)
511 }
512
513 func TestH12_FlushBeforeBody(t *testing.T) {
514 h12Compare{Handler: func(w ResponseWriter, r *Request) {
515 w.(Flusher).Flush()
516 io.WriteString(w, "foo")
517 }}.run(t)
518 }
519
520 func TestH12_FlushMidBody(t *testing.T) {
521 h12Compare{Handler: func(w ResponseWriter, r *Request) {
522 io.WriteString(w, "foo")
523 w.(Flusher).Flush()
524 io.WriteString(w, "bar")
525 }}.run(t)
526 }
527
528 func TestH12_Head_ExplicitLen(t *testing.T) {
529 h12Compare{
530 ReqFunc: (*Client).Head,
531 Handler: func(w ResponseWriter, r *Request) {
532 if r.Method != "HEAD" {
533 t.Errorf("unexpected method %q", r.Method)
534 }
535 w.Header().Set("Content-Length", "1235")
536 },
537 }.run(t)
538 }
539
540 func TestH12_Head_ImplicitLen(t *testing.T) {
541 h12Compare{
542 ReqFunc: (*Client).Head,
543 Handler: func(w ResponseWriter, r *Request) {
544 if r.Method != "HEAD" {
545 t.Errorf("unexpected method %q", r.Method)
546 }
547 io.WriteString(w, "foo")
548 },
549 }.run(t)
550 }
551
552 func TestH12_HandlerWritesTooLittle(t *testing.T) {
553 h12Compare{
554 Handler: func(w ResponseWriter, r *Request) {
555 w.Header().Set("Content-Length", "3")
556 io.WriteString(w, "12")
557 },
558 CheckResponse: func(proto string, res *Response) {
559 sr, ok := res.Body.(slurpResult)
560 if !ok {
561 t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
562 return
563 }
564 if sr.err != io.ErrUnexpectedEOF {
565 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
566 }
567 if string(sr.body) != "12" {
568 t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
569 }
570 },
571 }.run(t)
572 }
573
574
575
576
577
578
579
580 func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
581 func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
582 wantBody := []byte("123")
583 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
584 rc := NewResponseController(w)
585 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
586 rc.Flush()
587 w.Write(wantBody)
588 rc.Flush()
589 n, err := io.WriteString(w, "x")
590 if err == nil {
591 err = rc.Flush()
592 }
593
594 if err == nil {
595 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
596 }
597 }))
598
599 res, err := cst.c.Get(cst.ts.URL)
600 if err != nil {
601 t.Fatal(err)
602 }
603 defer res.Body.Close()
604
605 gotBody, _ := io.ReadAll(res.Body)
606 if !bytes.Equal(gotBody, wantBody) {
607 t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
608 }
609 }
610
611
612
613 func TestH12_AutoGzip(t *testing.T) {
614 h12Compare{
615 Handler: func(w ResponseWriter, r *Request) {
616 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
617 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
618 }
619 w.Header().Set("Content-Encoding", "gzip")
620 gz := gzip.NewWriter(w)
621 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
622 gz.Close()
623 },
624 }.run(t)
625 }
626
627 func TestH12_AutoGzip_Disabled(t *testing.T) {
628 h12Compare{
629 Opts: []any{
630 func(tr *Transport) { tr.DisableCompression = true },
631 },
632 Handler: func(w ResponseWriter, r *Request) {
633 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
634 if ae := r.Header.Get("Accept-Encoding"); ae != "" {
635 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
636 }
637 },
638 }.run(t)
639 }
640
641
642
643
644 func Test304Responses(t *testing.T) { run(t, test304Responses) }
645 func test304Responses(t *testing.T, mode testMode) {
646 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
647 w.WriteHeader(StatusNotModified)
648 _, err := w.Write([]byte("illegal body"))
649 if err != ErrBodyNotAllowed {
650 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
651 }
652 }))
653 defer cst.close()
654 res, err := cst.c.Get(cst.ts.URL)
655 if err != nil {
656 t.Fatal(err)
657 }
658 if len(res.TransferEncoding) > 0 {
659 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
660 }
661 body, err := io.ReadAll(res.Body)
662 if err != nil {
663 t.Error(err)
664 }
665 if len(body) > 0 {
666 t.Errorf("got unexpected body %q", string(body))
667 }
668 }
669
670 func TestH12_ServerEmptyContentLength(t *testing.T) {
671 h12Compare{
672 Handler: func(w ResponseWriter, r *Request) {
673 w.Header()["Content-Type"] = []string{""}
674 io.WriteString(w, "<html><body>hi</body></html>")
675 },
676 }.run(t)
677 }
678
679 func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
680 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
681 }
682
683 func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
684 h12requestContentLength(t, func() io.Reader { return nil }, 0)
685 }
686
687 func TestH12_RequestContentLength_Unknown(t *testing.T) {
688 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
689 }
690
691 func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
692 h12Compare{
693 Handler: func(w ResponseWriter, r *Request) {
694 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
695 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
696 },
697 ReqFunc: func(c *Client, url string) (*Response, error) {
698 return c.Post(url, "text/plain", bodyfn())
699 },
700 CheckResponse: func(proto string, res *Response) {
701 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
702 t.Errorf("Proto %q got length %q; want %q", proto, got, want)
703 }
704 },
705 }.run(t)
706 }
707
708
709
710 func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
711 func testCancelRequestMidBody(t *testing.T, mode testMode) {
712 unblock := make(chan bool)
713 didFlush := make(chan bool, 1)
714 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
715 io.WriteString(w, "Hello")
716 w.(Flusher).Flush()
717 didFlush <- true
718 <-unblock
719 io.WriteString(w, ", world.")
720 }))
721 defer close(unblock)
722
723 req, _ := NewRequest("GET", cst.ts.URL, nil)
724 cancel := make(chan struct{})
725 req.Cancel = cancel
726
727 res, err := cst.c.Do(req)
728 if err != nil {
729 t.Fatal(err)
730 }
731 defer res.Body.Close()
732 <-didFlush
733
734
735
736 firstRead := make([]byte, 10)
737 n, err := res.Body.Read(firstRead)
738 if err != nil {
739 t.Fatal(err)
740 }
741 firstRead = firstRead[:n]
742
743 close(cancel)
744
745 rest, err := io.ReadAll(res.Body)
746 all := string(firstRead) + string(rest)
747 if all != "Hello" {
748 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
749 }
750 if err != ExportErrRequestCanceled {
751 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
752 }
753 }
754
755
756 func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
757 func testTrailersClientToServer(t *testing.T, mode testMode) {
758 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
759 slurp, err := io.ReadAll(r.Body)
760 if err != nil {
761 t.Errorf("Server reading request body: %v", err)
762 }
763 if string(slurp) != "foo" {
764 t.Errorf("Server read request body %q; want foo", slurp)
765 }
766 if r.Trailer == nil {
767 io.WriteString(w, "nil Trailer")
768 } else {
769 decl := slices.Sorted(maps.Keys(r.Trailer))
770 fmt.Fprintf(w, "decl: %v, vals: %s, %s",
771 decl,
772 r.Trailer.Get("Client-Trailer-A"),
773 r.Trailer.Get("Client-Trailer-B"))
774 }
775 }))
776
777 var req *Request
778 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
779 eofReaderFunc(func() {
780 req.Trailer["Client-Trailer-A"] = []string{"valuea"}
781 }),
782 strings.NewReader("foo"),
783 eofReaderFunc(func() {
784 req.Trailer["Client-Trailer-B"] = []string{"valueb"}
785 }),
786 ))
787 req.Trailer = Header{
788 "Client-Trailer-A": nil,
789 "Client-Trailer-B": nil,
790 }
791 req.ContentLength = -1
792 res, err := cst.c.Do(req)
793 if err != nil {
794 t.Fatal(err)
795 }
796 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
797 t.Error(err)
798 }
799 }
800
801
802 func TestTrailersServerToClient(t *testing.T) {
803 run(t, func(t *testing.T, mode testMode) {
804 testTrailersServerToClient(t, mode, false)
805 })
806 }
807 func TestTrailersServerToClientFlush(t *testing.T) {
808 run(t, func(t *testing.T, mode testMode) {
809 testTrailersServerToClient(t, mode, true)
810 })
811 }
812
813 func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
814 const body = "Some body"
815 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
816 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
817 w.Header().Add("Trailer", "Server-Trailer-C")
818
819 io.WriteString(w, body)
820 if flush {
821 w.(Flusher).Flush()
822 }
823
824
825
826
827
828 w.Header().Set("Server-Trailer-A", "valuea")
829 w.Header().Set("Server-Trailer-C", "valuec")
830 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
831 }))
832
833 res, err := cst.c.Get(cst.ts.URL)
834 if err != nil {
835 t.Fatal(err)
836 }
837
838 wantHeader := Header{
839 "Content-Type": {"text/plain; charset=utf-8"},
840 }
841 wantLen := -1
842 if mode == http2Mode && !flush {
843
844
845
846
847
848 wantLen = len(body)
849 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
850 }
851 if res.ContentLength != int64(wantLen) {
852 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
853 }
854
855 delete(res.Header, "Date")
856 if !reflect.DeepEqual(res.Header, wantHeader) {
857 t.Errorf("Header = %v; want %v", res.Header, wantHeader)
858 }
859
860 if got, want := res.Trailer, (Header{
861 "Server-Trailer-A": nil,
862 "Server-Trailer-B": nil,
863 "Server-Trailer-C": nil,
864 }); !reflect.DeepEqual(got, want) {
865 t.Errorf("Trailer before body read = %v; want %v", got, want)
866 }
867
868 if err := wantBody(res, nil, body); err != nil {
869 t.Fatal(err)
870 }
871
872 if got, want := res.Trailer, (Header{
873 "Server-Trailer-A": {"valuea"},
874 "Server-Trailer-B": nil,
875 "Server-Trailer-C": {"valuec"},
876 }); !reflect.DeepEqual(got, want) {
877 t.Errorf("Trailer after body read = %v; want %v", got, want)
878 }
879 }
880
881
882 func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
883 func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
884 const body = "Some body"
885 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
886 io.WriteString(w, body)
887 }))
888 res, err := cst.c.Get(cst.ts.URL)
889 if err != nil {
890 t.Fatal(err)
891 }
892 res.Body.Close()
893 data, err := io.ReadAll(res.Body)
894 if len(data) != 0 || err == nil {
895 t.Fatalf("ReadAll returned %q, %v; want error", data, err)
896 }
897 }
898
899 func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
900 func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
901 const reqBody = "some request body"
902 const resBody = "some response body"
903 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
904 var wg sync.WaitGroup
905 wg.Add(2)
906 didRead := make(chan bool, 1)
907
908 go func() {
909 defer wg.Done()
910 data, err := io.ReadAll(r.Body)
911 if string(data) != reqBody {
912 t.Errorf("Handler read %q; want %q", data, reqBody)
913 }
914 if err != nil {
915 t.Errorf("Handler Read: %v", err)
916 }
917 didRead <- true
918 }()
919
920 go func() {
921 defer wg.Done()
922 if mode != http2Mode {
923
924
925
926
927 <-didRead
928 }
929 io.WriteString(w, resBody)
930 }()
931 wg.Wait()
932 }))
933 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
934 req.Header.Add("Expect", "100-continue")
935 res, err := cst.c.Do(req)
936 if err != nil {
937 t.Fatal(err)
938 }
939 data, err := io.ReadAll(res.Body)
940 defer res.Body.Close()
941 if err != nil {
942 t.Fatal(err)
943 }
944 if string(data) != resBody {
945 t.Errorf("read %q; want %q", data, resBody)
946 }
947 }
948
949 func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
950 func testConnectRequest(t *testing.T, mode testMode) {
951 gotc := make(chan *Request, 1)
952 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
953 gotc <- r
954 }))
955
956 u, err := url.Parse(cst.ts.URL)
957 if err != nil {
958 t.Fatal(err)
959 }
960
961 tests := []struct {
962 req *Request
963 want string
964 }{
965 {
966 req: &Request{
967 Method: "CONNECT",
968 Header: Header{},
969 URL: u,
970 },
971 want: u.Host,
972 },
973 {
974 req: &Request{
975 Method: "CONNECT",
976 Header: Header{},
977 URL: u,
978 Host: "example.com:123",
979 },
980 want: "example.com:123",
981 },
982 }
983
984 for i, tt := range tests {
985 res, err := cst.c.Do(tt.req)
986 if err != nil {
987 t.Errorf("%d. RoundTrip = %v", i, err)
988 continue
989 }
990 res.Body.Close()
991 req := <-gotc
992 if req.Method != "CONNECT" {
993 t.Errorf("method = %q; want CONNECT", req.Method)
994 }
995 if req.Host != tt.want {
996 t.Errorf("Host = %q; want %q", req.Host, tt.want)
997 }
998 if req.URL.Host != tt.want {
999 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
1000 }
1001 }
1002 }
1003
1004 func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
1005 func testTransportUserAgent(t *testing.T, mode testMode) {
1006 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1007 fmt.Fprintf(w, "%q", r.Header["User-Agent"])
1008 }))
1009
1010 either := func(a, b string) string {
1011 if mode == http2Mode {
1012 return b
1013 }
1014 return a
1015 }
1016
1017 tests := []struct {
1018 setup func(*Request)
1019 want string
1020 }{
1021 {
1022 func(r *Request) {},
1023 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
1024 },
1025 {
1026 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
1027 `["foo/1.2.3"]`,
1028 },
1029 {
1030 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
1031 `["single"]`,
1032 },
1033 {
1034 func(r *Request) { r.Header.Set("User-Agent", "") },
1035 `[]`,
1036 },
1037 {
1038 func(r *Request) { r.Header["User-Agent"] = nil },
1039 `[]`,
1040 },
1041 }
1042 for i, tt := range tests {
1043 req, _ := NewRequest("GET", cst.ts.URL, nil)
1044 tt.setup(req)
1045 res, err := cst.c.Do(req)
1046 if err != nil {
1047 t.Errorf("%d. RoundTrip = %v", i, err)
1048 continue
1049 }
1050 slurp, err := io.ReadAll(res.Body)
1051 res.Body.Close()
1052 if err != nil {
1053 t.Errorf("%d. read body = %v", i, err)
1054 continue
1055 }
1056 if string(slurp) != tt.want {
1057 t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
1058 }
1059 }
1060 }
1061
1062 func TestStarRequestMethod(t *testing.T) {
1063 for _, method := range []string{"FOO", "OPTIONS"} {
1064 t.Run(method, func(t *testing.T) {
1065 run(t, func(t *testing.T, mode testMode) {
1066 testStarRequest(t, method, mode)
1067 })
1068 })
1069 }
1070 }
1071 func testStarRequest(t *testing.T, method string, mode testMode) {
1072 gotc := make(chan *Request, 1)
1073 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1074 w.Header().Set("foo", "bar")
1075 gotc <- r
1076 w.(Flusher).Flush()
1077 }))
1078
1079 u, err := url.Parse(cst.ts.URL)
1080 if err != nil {
1081 t.Fatal(err)
1082 }
1083 u.Path = "*"
1084
1085 req := &Request{
1086 Method: method,
1087 Header: Header{},
1088 URL: u,
1089 }
1090
1091 res, err := cst.c.Do(req)
1092 if err != nil {
1093 t.Fatalf("RoundTrip = %v", err)
1094 }
1095 res.Body.Close()
1096
1097 wantFoo := "bar"
1098 wantLen := int64(-1)
1099 if method == "OPTIONS" {
1100 wantFoo = ""
1101 wantLen = 0
1102 }
1103 if res.StatusCode != 200 {
1104 t.Errorf("status code = %v; want %d", res.Status, 200)
1105 }
1106 if res.ContentLength != wantLen {
1107 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1108 }
1109 if got := res.Header.Get("foo"); got != wantFoo {
1110 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1111 }
1112 select {
1113 case req = <-gotc:
1114 default:
1115 req = nil
1116 }
1117 if req == nil {
1118 if method != "OPTIONS" {
1119 t.Fatalf("handler never got request")
1120 }
1121 return
1122 }
1123 if req.Method != method {
1124 t.Errorf("method = %q; want %q", req.Method, method)
1125 }
1126 if req.URL.Path != "*" {
1127 t.Errorf("URL.Path = %q; want *", req.URL.Path)
1128 }
1129 if req.RequestURI != "*" {
1130 t.Errorf("RequestURI = %q; want *", req.RequestURI)
1131 }
1132 }
1133
1134
1135 func TestTransportDiscardsUnneededConns(t *testing.T) {
1136 run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1137 }
1138 func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1139 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1140 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1141 }))
1142 defer cst.close()
1143
1144 var numOpen, numClose int32
1145
1146 tlsConfig := &tls.Config{InsecureSkipVerify: true}
1147 tr := &Transport{
1148 TLSClientConfig: tlsConfig,
1149 DialTLS: func(_, addr string) (net.Conn, error) {
1150 time.Sleep(10 * time.Millisecond)
1151 rc, err := net.Dial("tcp", addr)
1152 if err != nil {
1153 return nil, err
1154 }
1155 atomic.AddInt32(&numOpen, 1)
1156 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1157 return tls.Client(c, tlsConfig), nil
1158 },
1159 Protocols: &Protocols{},
1160 }
1161 tr.Protocols.SetHTTP2(true)
1162 defer tr.CloseIdleConnections()
1163
1164 c := &Client{Transport: tr}
1165
1166 const N = 10
1167 gotBody := make(chan string, N)
1168 var wg sync.WaitGroup
1169 for i := 0; i < N; i++ {
1170 wg.Add(1)
1171 go func() {
1172 defer wg.Done()
1173 resp, err := c.Get(cst.ts.URL)
1174 if err != nil {
1175
1176
1177 time.Sleep(10 * time.Millisecond)
1178 resp, err = c.Get(cst.ts.URL)
1179 if err != nil {
1180 t.Errorf("Get: %v", err)
1181 return
1182 }
1183 }
1184 defer resp.Body.Close()
1185 slurp, err := io.ReadAll(resp.Body)
1186 if err != nil {
1187 t.Error(err)
1188 }
1189 gotBody <- string(slurp)
1190 }()
1191 }
1192 wg.Wait()
1193 close(gotBody)
1194
1195 var last string
1196 for got := range gotBody {
1197 if last == "" {
1198 last = got
1199 continue
1200 }
1201 if got != last {
1202 t.Errorf("Response body changed: %q -> %q", last, got)
1203 }
1204 }
1205
1206 var open, close int32
1207 for i := 0; i < 150; i++ {
1208 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1209 if open < 1 {
1210 t.Fatalf("open = %d; want at least", open)
1211 }
1212 if close == open-1 {
1213
1214 return
1215 }
1216 time.Sleep(10 * time.Millisecond)
1217 }
1218 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1219 }
1220
1221
1222 func TestTransportGCRequest(t *testing.T) {
1223 run(t, func(t *testing.T, mode testMode) {
1224 t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1225 t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1226 })
1227 }
1228 func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1229 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1230 io.ReadAll(r.Body)
1231 if body {
1232 io.WriteString(w, "Hello.")
1233 }
1234 }))
1235
1236 didGC := make(chan struct{})
1237 (func() {
1238 body := strings.NewReader("some body")
1239 req, _ := NewRequest("POST", cst.ts.URL, body)
1240 runtime.AddCleanup(req, func(ch chan struct{}) { close(ch) }, didGC)
1241 res, err := cst.c.Do(req)
1242 if err != nil {
1243 t.Fatal(err)
1244 }
1245 if _, err := io.ReadAll(res.Body); err != nil {
1246 t.Fatal(err)
1247 }
1248 if err := res.Body.Close(); err != nil {
1249 t.Fatal(err)
1250 }
1251 })()
1252 for {
1253 select {
1254 case <-didGC:
1255 return
1256 case <-time.After(1 * time.Millisecond):
1257 runtime.GC()
1258 }
1259 }
1260 }
1261
1262 func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1263 func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1264 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1265 fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1266 }), optQuietLog)
1267 cst.tr.DisableKeepAlives = true
1268
1269 tests := []struct {
1270 key, val string
1271 ok bool
1272 }{
1273 {"Foo", "capital-key", true},
1274 {"Foo", "foo\x00bar", false},
1275 {"Foo", "two\nlines", false},
1276 {"bogus\nkey", "v", false},
1277 {"A space", "v", false},
1278 {"имя", "v", false},
1279 {"name", "валю", true},
1280 {"", "v", false},
1281 {"k", "", true},
1282 }
1283 for _, tt := range tests {
1284 dialedc := make(chan bool, 1)
1285 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1286 dialedc <- true
1287 return net.Dial(netw, addr)
1288 }
1289 req, _ := NewRequest("GET", cst.ts.URL, nil)
1290 req.Header[tt.key] = []string{tt.val}
1291 res, err := cst.c.Do(req)
1292 var body []byte
1293 if err == nil {
1294 body, _ = io.ReadAll(res.Body)
1295 res.Body.Close()
1296 }
1297 var dialed bool
1298 select {
1299 case <-dialedc:
1300 dialed = true
1301 default:
1302 }
1303
1304 if !tt.ok && dialed {
1305 t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1306 } else if (err == nil) != tt.ok {
1307 t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1308 }
1309 }
1310 }
1311
1312 func TestInterruptWithPanic(t *testing.T) {
1313 run(t, func(t *testing.T, mode testMode) {
1314 t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1315 t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1316 t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1317 }, testNotParallel)
1318 }
1319 func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1320 const msg = "hello"
1321
1322 testDone := make(chan struct{})
1323 defer close(testDone)
1324
1325 var errorLog lockedBytesBuffer
1326 gotHeaders := make(chan bool, 1)
1327 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1328 io.WriteString(w, msg)
1329 w.(Flusher).Flush()
1330
1331 select {
1332 case <-gotHeaders:
1333 case <-testDone:
1334 }
1335 panic(panicValue)
1336 }), func(ts *httptest.Server) {
1337 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1338 })
1339 res, err := cst.c.Get(cst.ts.URL)
1340 if err != nil {
1341 t.Fatal(err)
1342 }
1343 gotHeaders <- true
1344 defer res.Body.Close()
1345 slurp, err := io.ReadAll(res.Body)
1346 if string(slurp) != msg {
1347 t.Errorf("client read %q; want %q", slurp, msg)
1348 }
1349 if err == nil {
1350 t.Errorf("client read all successfully; want some error")
1351 }
1352 logOutput := func() string {
1353 errorLog.Lock()
1354 defer errorLog.Unlock()
1355 return errorLog.String()
1356 }
1357 wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1358
1359 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1360 gotLog := logOutput()
1361 if !wantStackLogged {
1362 if gotLog == "" {
1363 return true
1364 }
1365 t.Fatalf("want no log output; got: %s", gotLog)
1366 }
1367 if gotLog == "" {
1368 if d > 0 {
1369 t.Logf("wanted a stack trace logged; got nothing after %v", d)
1370 }
1371 return false
1372 }
1373 if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1374 if d > 0 {
1375 t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1376 }
1377 return false
1378 }
1379 return true
1380 })
1381 }
1382
// lockedBytesBuffer is a bytes.Buffer whose Write method is serialized by the
// embedded mutex, so it can be the destination of a log.Logger written from
// several goroutines. Only Write takes the lock itself. Code that reads the
// buffer (String, Len, ...) while writers may be active must hold the lock
// around those reads.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex and returns the
// result of the underlying bytes.Buffer.Write.
func (lb *lockedBytesBuffer) Write(p []byte) (int, error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	return lb.Buffer.Write(p)
}
1393
1394
1395 func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1396 h12Compare{
1397 Handler: func(w ResponseWriter, r *Request) {
1398 h := w.Header()
1399 h.Set("Content-Encoding", "gzip")
1400 h.Set("Content-Length", "23")
1401 io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1402 },
1403 EarlyCheckResponse: func(proto string, res *Response) {
1404 if !res.Uncompressed {
1405 t.Errorf("%s: expected Uncompressed to be set", proto)
1406 }
1407 dump, err := httputil.DumpResponse(res, true)
1408 if err != nil {
1409 t.Errorf("%s: DumpResponse: %v", proto, err)
1410 return
1411 }
1412 if strings.Contains(string(dump), "Connection: close") {
1413 t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1414 }
1415 if !strings.Contains(string(dump), "FOO") {
1416 t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1417 }
1418 },
1419 }.run(t)
1420 }
1421
1422
1423 func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1424 func testCloseIdleConnections(t *testing.T, mode testMode) {
1425 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1426 w.Header().Set("X-Addr", r.RemoteAddr)
1427 }))
1428 get := func() string {
1429 res, err := cst.c.Get(cst.ts.URL)
1430 if err != nil {
1431 t.Fatal(err)
1432 }
1433 res.Body.Close()
1434 v := res.Header.Get("X-Addr")
1435 if v == "" {
1436 t.Fatal("didn't get X-Addr")
1437 }
1438 return v
1439 }
1440 a1 := get()
1441 cst.tr.CloseIdleConnections()
1442 a2 := get()
1443 if a1 == a2 {
1444 t.Errorf("didn't close connection")
1445 }
1446 }
1447
// noteCloseConn wraps a net.Conn and runs closeFunc every time Close is
// called, before closing the wrapped connection.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close invokes the notification hook, then closes the wrapped Conn and
// returns its error.
func (c noteCloseConn) Close() error {
	c.closeFunc()
	err := c.Conn.Close()
	return err
}
1457
// testErrorReader is a request body that must never be read: any Read call
// records a test failure on t and reports end of input.
type testErrorReader struct{ t *testing.T }

// Read reports "unexpected Read call" on the test and returns (0, io.EOF)
// without touching its argument.
func (er testErrorReader) Read(_ []byte) (int, error) {
	er.t.Error("unexpected Read call")
	return 0, io.EOF
}
1464
1465 func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1466 func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1467 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1468 w.WriteHeader(StatusUnauthorized)
1469 }))
1470
1471
1472 cst.tr.ExpectContinueTimeout = 10 * time.Second
1473
1474 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1475 if err != nil {
1476 t.Fatal(err)
1477 }
1478 req.ContentLength = 0
1479 req.Header.Set("Expect", "100-continue")
1480 res, err := cst.tr.RoundTrip(req)
1481 if err != nil {
1482 t.Fatal(err)
1483 }
1484 defer res.Body.Close()
1485 if res.StatusCode != StatusUnauthorized {
1486 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1487 }
1488 }
1489
1490 func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1491 func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1492 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1493 w.Header().Set("Foo", "Bar")
1494 w.Header().Set("Trailer:Foo", "Baz")
1495 w.(Flusher).Flush()
1496 w.Header().Add("Trailer:Foo", "Baz2")
1497 w.Header().Set("Trailer:Bar", "Quux")
1498 }))
1499 res, err := cst.c.Get(cst.ts.URL)
1500 if err != nil {
1501 t.Fatal(err)
1502 }
1503 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1504 t.Fatal(err)
1505 }
1506 res.Body.Close()
1507 delete(res.Header, "Date")
1508 delete(res.Header, "Content-Type")
1509
1510 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1511 t.Errorf("Header = %#v; want %#v", res.Header, want)
1512 }
1513 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1514 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1515 }
1516 }
1517
1518 func TestBadResponseAfterReadingBody(t *testing.T) {
1519 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1520 }
1521 func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1522 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1523 _, err := io.Copy(io.Discard, r.Body)
1524 if err != nil {
1525 t.Fatal(err)
1526 }
1527 c, _, err := w.(Hijacker).Hijack()
1528 if err != nil {
1529 t.Fatal(err)
1530 }
1531 defer c.Close()
1532 fmt.Fprintln(c, "some bogus crap")
1533 }))
1534
1535 closes := 0
1536 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1537 if err == nil {
1538 res.Body.Close()
1539 t.Fatal("expected an error to be returned from Post")
1540 }
1541 if closes != 1 {
1542 t.Errorf("closes = %d; want 1", closes)
1543 }
1544 }
1545
1546 func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1547 func testWriteHeader0(t *testing.T, mode testMode) {
1548 gotpanic := make(chan bool, 1)
1549 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1550 defer close(gotpanic)
1551 defer func() {
1552 if e := recover(); e != nil {
1553 got := fmt.Sprintf("%T, %v", e, e)
1554 want := "string, invalid WriteHeader code 0"
1555 if got != want {
1556 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1557 }
1558 gotpanic <- true
1559
1560
1561
1562
1563 w.WriteHeader(503)
1564 }
1565 }()
1566 w.WriteHeader(0)
1567 }))
1568 res, err := cst.c.Get(cst.ts.URL)
1569 if err != nil {
1570 t.Fatal(err)
1571 }
1572 if res.StatusCode != 503 {
1573 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1574 }
1575 if !<-gotpanic {
1576 t.Error("expected panic in handler")
1577 }
1578 }
1579
1580
1581
1582 func TestWriteHeaderNoCodeCheck(t *testing.T) {
1583 run(t, func(t *testing.T, mode testMode) {
1584 testWriteHeaderAfterWrite(t, mode, false)
1585 })
1586 }
1587 func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1588 testWriteHeaderAfterWrite(t, http1Mode, true)
1589 }
1590 func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1591 var errorLog lockedBytesBuffer
1592 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1593 if hijack {
1594 conn, _, _ := w.(Hijacker).Hijack()
1595 defer conn.Close()
1596 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1597 w.WriteHeader(0)
1598 conn.Write([]byte("bar"))
1599 return
1600 }
1601 io.WriteString(w, "foo")
1602 w.(Flusher).Flush()
1603 w.WriteHeader(0)
1604 io.WriteString(w, "bar")
1605 }), func(ts *httptest.Server) {
1606 ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1607 })
1608 res, err := cst.c.Get(cst.ts.URL)
1609 if err != nil {
1610 t.Fatal(err)
1611 }
1612 defer res.Body.Close()
1613 body, err := io.ReadAll(res.Body)
1614 if err != nil {
1615 t.Fatal(err)
1616 }
1617 if got, want := string(body), "foobar"; got != want {
1618 t.Errorf("got = %q; want %q", got, want)
1619 }
1620
1621
1622 if mode == http2Mode {
1623
1624
1625 return
1626 }
1627 gotLog := strings.TrimSpace(errorLog.String())
1628 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1629 if hijack {
1630 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1631 }
1632 if !strings.HasPrefix(gotLog, wantLog) {
1633 t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1634 }
1635 }
1636
1637 func TestBidiStreamReverseProxy(t *testing.T) {
1638 run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1639 }
1640 func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1641 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1642 if _, err := io.Copy(w, r.Body); err != nil {
1643 log.Printf("bidi backend copy: %v", err)
1644 }
1645 }))
1646
1647 backURL, err := url.Parse(backend.ts.URL)
1648 if err != nil {
1649 t.Fatal(err)
1650 }
1651 rp := httputil.NewSingleHostReverseProxy(backURL)
1652 rp.Transport = backend.tr
1653 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1654 rp.ServeHTTP(w, r)
1655 }))
1656
1657 bodyRes := make(chan any, 1)
1658 pr, pw := io.Pipe()
1659 req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1660 const size = 4 << 20
1661 go func() {
1662 h := sha1.New()
1663 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1664 go pw.Close()
1665 if err != nil {
1666 t.Errorf("body copy: %v", err)
1667 bodyRes <- err
1668 } else {
1669 bodyRes <- h
1670 }
1671 }()
1672 res, err := backend.c.Do(req)
1673 if err != nil {
1674 t.Fatal(err)
1675 }
1676 defer res.Body.Close()
1677 hgot := sha1.New()
1678 n, err := io.Copy(hgot, res.Body)
1679 if err != nil {
1680 t.Fatal(err)
1681 }
1682 if n != size {
1683 t.Fatalf("got %d bytes; want %d", n, size)
1684 }
1685 select {
1686 case v := <-bodyRes:
1687 switch v := v.(type) {
1688 default:
1689 t.Fatalf("body copy: %v", err)
1690 case hash.Hash:
1691 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1692 t.Errorf("written bytes didn't match received bytes")
1693 }
1694 }
1695 case <-time.After(10 * time.Second):
1696 t.Fatal("timeout")
1697 }
1698
1699 }
1700
1701
1702 func TestH12_WebSocketUpgrade(t *testing.T) {
1703 h12Compare{
1704 Handler: func(w ResponseWriter, r *Request) {
1705 h := w.Header()
1706 h.Set("Foo", "bar")
1707 },
1708 ReqFunc: func(c *Client, url string) (*Response, error) {
1709 req, _ := NewRequest("GET", url, nil)
1710 req.Header.Set("Connection", "Upgrade")
1711 req.Header.Set("Upgrade", "WebSocket")
1712 return c.Do(req)
1713 },
1714 EarlyCheckResponse: func(proto string, res *Response) {
1715 if res.Proto != "HTTP/1.1" {
1716 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1717 }
1718 res.Proto = "HTTP/IGNORE"
1719 },
1720 Opts: []any{
1721 func(s *Server) {
1722
1723
1724
1725 s.Protocols = &Protocols{}
1726 s.Protocols.SetHTTP1(true)
1727 s.Protocols.SetHTTP2(true)
1728 },
1729 },
1730 }.run(t)
1731 }
1732
1733 func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1734 func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1735 const body = "body"
1736 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1737 gotBody, _ := io.ReadAll(r.Body)
1738 if got, want := string(gotBody), body; got != want {
1739 t.Errorf("got request body = %q; want %q", got, want)
1740 }
1741 w.Header().Set("Transfer-Encoding", "identity")
1742 w.WriteHeader(StatusOK)
1743 w.(Flusher).Flush()
1744 io.WriteString(w, body)
1745 }))
1746 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1747 res, err := cst.c.Do(req)
1748 if err != nil {
1749 t.Fatal(err)
1750 }
1751 defer res.Body.Close()
1752 gotBody, err := io.ReadAll(res.Body)
1753 if err != nil {
1754 t.Fatal(err)
1755 }
1756 if got, want := string(gotBody), body; got != want {
1757 t.Errorf("got response body = %q; want %q", got, want)
1758 }
1759 }
1760
1761 func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1762 func testEarlyHintsRequest(t *testing.T, mode testMode) {
1763 var wg sync.WaitGroup
1764 wg.Add(1)
1765 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1766 h := w.Header()
1767
1768 h.Add("Content-Length", "123")
1769 h.Add("Link", "</style.css>; rel=preload; as=style")
1770 h.Add("Link", "</script.js>; rel=preload; as=script")
1771 w.WriteHeader(StatusEarlyHints)
1772
1773 wg.Wait()
1774
1775 h.Add("Link", "</foo.js>; rel=preload; as=script")
1776 w.WriteHeader(StatusEarlyHints)
1777
1778 w.Write([]byte("Hello"))
1779 }))
1780
1781 checkLinkHeaders := func(t *testing.T, expected, got []string) {
1782 t.Helper()
1783
1784 if len(expected) != len(got) {
1785 t.Errorf("got %d expected %d", len(got), len(expected))
1786 }
1787
1788 for i := range expected {
1789 if expected[i] != got[i] {
1790 t.Errorf("got %q expected %q", got[i], expected[i])
1791 }
1792 }
1793 }
1794
1795 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1796 t.Helper()
1797
1798 for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1799 if v, ok := header[h]; ok {
1800 t.Errorf("%s is %q; must not be sent", h, v)
1801 }
1802 }
1803 }
1804
1805 var respCounter uint8
1806 trace := &httptrace.ClientTrace{
1807 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1808 switch respCounter {
1809 case 0:
1810 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1811 checkExcludedHeaders(t, header)
1812
1813 wg.Done()
1814 case 1:
1815 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1816 checkExcludedHeaders(t, header)
1817
1818 default:
1819 t.Error("Unexpected 1xx response")
1820 }
1821
1822 respCounter++
1823
1824 return nil
1825 },
1826 }
1827 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1828
1829 res, err := cst.c.Do(req)
1830 if err != nil {
1831 t.Fatal(err)
1832 }
1833 defer res.Body.Close()
1834
1835 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1836 if cl := res.Header.Get("Content-Length"); cl != "123" {
1837 t.Errorf("Content-Length is %q; want 123", cl)
1838 }
1839
1840 body, _ := io.ReadAll(res.Body)
1841 if string(body) != "Hello" {
1842 t.Errorf("Read body %q; want Hello", body)
1843 }
1844 }
1845
View as plain text