Source file src/net/http/clientserver_test.go

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
     6  
     7  package http_test
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"context"
    13  	"crypto/rand"
    14  	"crypto/sha1"
    15  	"crypto/tls"
    16  	"fmt"
    17  	"hash"
    18  	"io"
    19  	"log"
    20  	"maps"
    21  	"net"
    22  	. "net/http"
    23  	"net/http/httptest"
    24  	"net/http/httptrace"
    25  	"net/http/httputil"
    26  	"net/textproto"
    27  	"net/url"
    28  	"os"
    29  	"reflect"
    30  	"runtime"
    31  	"slices"
    32  	"strings"
    33  	"sync"
    34  	"sync/atomic"
    35  	"testing"
    36  	"testing/synctest"
    37  	"time"
    38  )
    39  
    40  type testMode string
    41  
    42  const (
    43  	http1Mode            = testMode("h1")            // HTTP/1.1
    44  	https1Mode           = testMode("https1")        // HTTPS/1.1
    45  	http2Mode            = testMode("h2")            // HTTP/2
    46  	http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2
    47  )
    48  
    49  func (m testMode) Scheme() string {
    50  	switch m {
    51  	case http1Mode, http2UnencryptedMode:
    52  		return "http"
    53  	case https1Mode, http2Mode:
    54  		return "https"
    55  	}
    56  	panic("unknown testMode")
    57  }
    58  
    59  type testNotParallelOpt struct{}
    60  
    61  var (
    62  	testNotParallel = testNotParallelOpt{}
    63  )
    64  
    65  type TBRun[T any] interface {
    66  	testing.TB
    67  	Run(string, func(T)) bool
    68  }
    69  
    70  // run runs a client/server test in a variety of test configurations.
    71  //
    72  // Tests execute in HTTP/1.1 and HTTP/2 modes by default.
    73  // To run in a different set of configurations, pass a []testMode option.
    74  //
    75  // Tests call t.Parallel() by default.
    76  // To disable parallel execution, pass the testNotParallel option.
    77  func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
    78  	t.Helper()
    79  	modes := []testMode{http1Mode, http2Mode}
    80  	parallel := true
    81  	for _, opt := range opts {
    82  		switch opt := opt.(type) {
    83  		case []testMode:
    84  			modes = opt
    85  		case testNotParallelOpt:
    86  			parallel = false
    87  		default:
    88  			t.Fatalf("unknown option type %T", opt)
    89  		}
    90  	}
    91  	if t, ok := any(t).(*testing.T); ok && parallel {
    92  		setParallel(t)
    93  	}
    94  	for _, mode := range modes {
    95  		t.Run(string(mode), func(t T) {
    96  			t.Helper()
    97  			if t, ok := any(t).(*testing.T); ok && parallel {
    98  				setParallel(t)
    99  			}
   100  			t.Cleanup(func() {
   101  				afterTest(t)
   102  			})
   103  			f(t, mode)
   104  		})
   105  	}
   106  }
   107  
   108  // runSynctest is run combined with synctest.Run.
   109  //
   110  // The TB passed to f arranges for cleanup functions to be run in the synctest bubble.
   111  func runSynctest(t *testing.T, f func(t *testing.T, mode testMode), opts ...any) {
   112  	run(t, func(t *testing.T, mode testMode) {
   113  		synctest.Test(t, func(t *testing.T) {
   114  			f(t, mode)
   115  		})
   116  	}, opts...)
   117  }
   118  
// clientServerTest bundles a started httptest.Server with a client configured
// to talk to it. Values are created by newClientServerTest.
type clientServerTest struct {
	t  testing.TB
	h2 bool    // true when the test mode is http2Mode (HTTP/2 over TLS)
	h  Handler // the handler passed to newClientServerTest
	ts *httptest.Server
	tr *Transport // the Transport of c
	c  *Client
	li *fakeNetListener // non-nil only when the optFakeNet option was used
}
   128  
   129  func (t *clientServerTest) close() {
   130  	t.tr.CloseIdleConnections()
   131  	t.ts.Close()
   132  }
   133  
   134  func (t *clientServerTest) getURL(u string) string {
   135  	res, err := t.c.Get(u)
   136  	if err != nil {
   137  		t.t.Fatal(err)
   138  	}
   139  	defer res.Body.Close()
   140  	slurp, err := io.ReadAll(res.Body)
   141  	if err != nil {
   142  		t.t.Fatal(err)
   143  	}
   144  	return string(slurp)
   145  }
   146  
   147  func (t *clientServerTest) scheme() string {
   148  	if t.h2 {
   149  		return "https"
   150  	}
   151  	return "http"
   152  }
   153  
   154  var optQuietLog = func(ts *httptest.Server) {
   155  	ts.Config.ErrorLog = quietLog
   156  }
   157  
   158  func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
   159  	return func(ts *httptest.Server) {
   160  		ts.Config.ErrorLog = lg
   161  	}
   162  }
   163  
   164  var optFakeNet = new(struct{})
   165  
   166  // newClientServerTest creates and starts an httptest.Server.
   167  //
   168  // The mode parameter selects the implementation to test:
   169  // HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
   170  // the 'run' function, which will start a subtests for each tested mode.
   171  //
   172  // The vararg opts parameter can include functions to configure the
   173  // test server or transport.
   174  //
   175  //	func(*httptest.Server) // run before starting the server
   176  //	func(*http.Transport)
   177  //
   178  // The optFakeNet option configures the server and client to use a fake network implementation,
   179  // suitable for use in testing/synctest tests.
   180  func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
   181  	if mode == http2Mode || mode == http2UnencryptedMode {
   182  		CondSkipHTTP2(t)
   183  	}
   184  	cst := &clientServerTest{
   185  		t:  t,
   186  		h2: mode == http2Mode,
   187  		h:  h,
   188  	}
   189  
   190  	var transportFuncs []func(*Transport)
   191  
   192  	if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
   193  		opts = slices.Delete(opts, idx, idx+1)
   194  		cst.li = fakeNetListen()
   195  		cst.ts = &httptest.Server{
   196  			Config:   &Server{Handler: h},
   197  			Listener: cst.li,
   198  		}
   199  		transportFuncs = append(transportFuncs, func(tr *Transport) {
   200  			tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
   201  				return cst.li.connect(), nil
   202  			}
   203  		})
   204  	} else {
   205  		cst.ts = httptest.NewUnstartedServer(h)
   206  	}
   207  
   208  	for _, opt := range opts {
   209  		switch opt := opt.(type) {
   210  		case func(*Transport):
   211  			transportFuncs = append(transportFuncs, opt)
   212  		case func(*httptest.Server):
   213  			opt(cst.ts)
   214  		case func(*Server):
   215  			opt(cst.ts.Config)
   216  		default:
   217  			t.Fatalf("unhandled option type %T", opt)
   218  		}
   219  	}
   220  
   221  	if cst.ts.Config.ErrorLog == nil {
   222  		cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
   223  	}
   224  
   225  	p := &Protocols{}
   226  	if cst.ts.Config.Protocols == nil {
   227  		cst.ts.Config.Protocols = p
   228  	}
   229  	switch mode {
   230  	case http1Mode:
   231  		p.SetHTTP1(true)
   232  		cst.ts.Start()
   233  	case https1Mode:
   234  		p.SetHTTP1(true)
   235  		cst.ts.StartTLS()
   236  	case http2UnencryptedMode:
   237  		p.SetUnencryptedHTTP2(true)
   238  		cst.ts.Start()
   239  	case http2Mode:
   240  		p.SetHTTP2(true)
   241  		cst.ts.EnableHTTP2 = true
   242  		cst.ts.TLS = cst.ts.Config.TLSConfig
   243  		cst.ts.StartTLS()
   244  	default:
   245  		t.Fatalf("unknown test mode %v", mode)
   246  	}
   247  	cst.c = cst.ts.Client()
   248  	cst.tr = cst.c.Transport.(*Transport)
   249  	for _, f := range transportFuncs {
   250  		f(cst.tr)
   251  	}
   252  	if cst.tr.Protocols == nil {
   253  		cst.tr.Protocols = p
   254  	}
   255  
   256  	t.Cleanup(func() {
   257  		cst.close()
   258  	})
   259  	return cst
   260  }
   261  
   262  type testLogWriter struct {
   263  	t testing.TB
   264  }
   265  
   266  func (w testLogWriter) Write(b []byte) (int, error) {
   267  	w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
   268  	return len(b), nil
   269  }
   270  
   271  // Testing the newClientServerTest helper itself.
   272  func TestNewClientServerTest(t *testing.T) {
   273  	modes := []testMode{http1Mode, https1Mode, http2Mode}
   274  	t.Run("realnet", func(t *testing.T) {
   275  		run(t, func(t *testing.T, mode testMode) {
   276  			testNewClientServerTest(t, mode)
   277  		}, modes)
   278  	})
   279  	t.Run("synctest", func(t *testing.T) {
   280  		runSynctest(t, func(t *testing.T, mode testMode) {
   281  			testNewClientServerTest(t, mode, optFakeNet)
   282  		}, modes)
   283  	})
   284  }
   285  func testNewClientServerTest(t *testing.T, mode testMode, opts ...any) {
   286  	var got struct {
   287  		sync.Mutex
   288  		proto  string
   289  		hasTLS bool
   290  	}
   291  	h := HandlerFunc(func(w ResponseWriter, r *Request) {
   292  		got.Lock()
   293  		defer got.Unlock()
   294  		got.proto = r.Proto
   295  		got.hasTLS = r.TLS != nil
   296  	})
   297  	cst := newClientServerTest(t, mode, h, opts...)
   298  	if _, err := cst.c.Head(cst.ts.URL); err != nil {
   299  		t.Fatal(err)
   300  	}
   301  	var wantProto string
   302  	var wantTLS bool
   303  	switch mode {
   304  	case http1Mode:
   305  		wantProto = "HTTP/1.1"
   306  		wantTLS = false
   307  	case https1Mode:
   308  		wantProto = "HTTP/1.1"
   309  		wantTLS = true
   310  	case http2Mode:
   311  		wantProto = "HTTP/2.0"
   312  		wantTLS = true
   313  	}
   314  	if got.proto != wantProto {
   315  		t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
   316  	}
   317  	if got.hasTLS != wantTLS {
   318  		t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
   319  	}
   320  }
   321  
   322  func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
   323  func testChunkedResponseHeaders(t *testing.T, mode testMode) {
   324  	log.SetOutput(io.Discard) // is noisy otherwise
   325  	defer log.SetOutput(os.Stderr)
   326  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   327  		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
   328  		w.(Flusher).Flush()
   329  		fmt.Fprintf(w, "I am a chunked response.")
   330  	}))
   331  
   332  	res, err := cst.c.Get(cst.ts.URL)
   333  	if err != nil {
   334  		t.Fatalf("Get error: %v", err)
   335  	}
   336  	defer res.Body.Close()
   337  	if g, e := res.ContentLength, int64(-1); g != e {
   338  		t.Errorf("expected ContentLength of %d; got %d", e, g)
   339  	}
   340  	wantTE := []string{"chunked"}
   341  	if mode == http2Mode {
   342  		wantTE = nil
   343  	}
   344  	if !slices.Equal(res.TransferEncoding, wantTE) {
   345  		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
   346  	}
   347  	if got, haveCL := res.Header["Content-Length"]; haveCL {
   348  		t.Errorf("Unexpected Content-Length: %q", got)
   349  	}
   350  }
   351  
// reqFunc issues a request for url using c, for example (*Client).Get.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
// against each other.
//
// Its run method serves Handler from an HTTP/1 server and an HTTP/2 server,
// sends the same request to each, and reports differences between the
// normalized responses.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // required
	ReqFunc            reqFunc                           // optional; defaults to (*Client).Get
	CheckResponse      func(proto string, res *Response) // optional
	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
	Opts               []any                             // optional; passed to newClientServerTest for both servers
}
   363  
   364  func (tt h12Compare) reqFunc() reqFunc {
   365  	if tt.ReqFunc == nil {
   366  		return (*Client).Get
   367  	}
   368  	return tt.ReqFunc
   369  }
   370  
   371  func (tt h12Compare) run(t *testing.T) {
   372  	setParallel(t)
   373  	cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
   374  	defer cst1.close()
   375  	cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
   376  	defer cst2.close()
   377  
   378  	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
   379  	if err != nil {
   380  		t.Errorf("HTTP/1 request: %v", err)
   381  		return
   382  	}
   383  	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
   384  	if err != nil {
   385  		t.Errorf("HTTP/2 request: %v", err)
   386  		return
   387  	}
   388  
   389  	if fn := tt.EarlyCheckResponse; fn != nil {
   390  		fn("HTTP/1.1", res1)
   391  		fn("HTTP/2.0", res2)
   392  	}
   393  
   394  	tt.normalizeRes(t, res1, "HTTP/1.1")
   395  	tt.normalizeRes(t, res2, "HTTP/2.0")
   396  	res1body, res2body := res1.Body, res2.Body
   397  
   398  	eres1 := mostlyCopy(res1)
   399  	eres2 := mostlyCopy(res2)
   400  	if !reflect.DeepEqual(eres1, eres2) {
   401  		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
   402  			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
   403  	}
   404  	if !reflect.DeepEqual(res1body, res2body) {
   405  		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
   406  	}
   407  	if fn := tt.CheckResponse; fn != nil {
   408  		res1.Body, res2.Body = res1body, res2body
   409  		fn("HTTP/1.1", res1)
   410  		fn("HTTP/2.0", res2)
   411  	}
   412  }
   413  
   414  func mostlyCopy(r *Response) *Response {
   415  	c := *r
   416  	c.Body = nil
   417  	c.TransferEncoding = nil
   418  	c.TLS = nil
   419  	c.Request = nil
   420  	return &c
   421  }
   422  
   423  type slurpResult struct {
   424  	io.ReadCloser
   425  	body []byte
   426  	err  error
   427  }
   428  
   429  func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
   430  
   431  func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
   432  	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
   433  		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
   434  	} else {
   435  		t.Errorf("got %q response; want %q", res.Proto, wantProto)
   436  	}
   437  	slurp, err := io.ReadAll(res.Body)
   438  
   439  	res.Body.Close()
   440  	res.Body = slurpResult{
   441  		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
   442  		body:       slurp,
   443  		err:        err,
   444  	}
   445  	for i, v := range res.Header["Date"] {
   446  		res.Header["Date"][i] = strings.Repeat("x", len(v))
   447  	}
   448  	if res.Request == nil {
   449  		t.Errorf("for %s, no request", wantProto)
   450  	}
   451  	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
   452  		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
   453  	}
   454  }
   455  
   456  // Issue 13532
   457  func TestH12_HeadContentLengthNoBody(t *testing.T) {
   458  	h12Compare{
   459  		ReqFunc: (*Client).Head,
   460  		Handler: func(w ResponseWriter, r *Request) {
   461  		},
   462  	}.run(t)
   463  }
   464  
   465  func TestH12_HeadContentLengthSmallBody(t *testing.T) {
   466  	h12Compare{
   467  		ReqFunc: (*Client).Head,
   468  		Handler: func(w ResponseWriter, r *Request) {
   469  			io.WriteString(w, "small")
   470  		},
   471  	}.run(t)
   472  }
   473  
   474  func TestH12_HeadContentLengthLargeBody(t *testing.T) {
   475  	h12Compare{
   476  		ReqFunc: (*Client).Head,
   477  		Handler: func(w ResponseWriter, r *Request) {
   478  			chunk := strings.Repeat("x", 512<<10)
   479  			for i := 0; i < 10; i++ {
   480  				io.WriteString(w, chunk)
   481  			}
   482  		},
   483  	}.run(t)
   484  }
   485  
   486  func TestH12_200NoBody(t *testing.T) {
   487  	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
   488  }
   489  
   490  func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
   491  func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
   492  func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
   493  
   494  func testH12_noBody(t *testing.T, status int) {
   495  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   496  		w.WriteHeader(status)
   497  	}}.run(t)
   498  }
   499  
   500  func TestH12_SmallBody(t *testing.T) {
   501  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   502  		io.WriteString(w, "small body")
   503  	}}.run(t)
   504  }
   505  
   506  func TestH12_ExplicitContentLength(t *testing.T) {
   507  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   508  		w.Header().Set("Content-Length", "3")
   509  		io.WriteString(w, "foo")
   510  	}}.run(t)
   511  }
   512  
   513  func TestH12_FlushBeforeBody(t *testing.T) {
   514  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   515  		w.(Flusher).Flush()
   516  		io.WriteString(w, "foo")
   517  	}}.run(t)
   518  }
   519  
   520  func TestH12_FlushMidBody(t *testing.T) {
   521  	h12Compare{Handler: func(w ResponseWriter, r *Request) {
   522  		io.WriteString(w, "foo")
   523  		w.(Flusher).Flush()
   524  		io.WriteString(w, "bar")
   525  	}}.run(t)
   526  }
   527  
   528  func TestH12_Head_ExplicitLen(t *testing.T) {
   529  	h12Compare{
   530  		ReqFunc: (*Client).Head,
   531  		Handler: func(w ResponseWriter, r *Request) {
   532  			if r.Method != "HEAD" {
   533  				t.Errorf("unexpected method %q", r.Method)
   534  			}
   535  			w.Header().Set("Content-Length", "1235")
   536  		},
   537  	}.run(t)
   538  }
   539  
   540  func TestH12_Head_ImplicitLen(t *testing.T) {
   541  	h12Compare{
   542  		ReqFunc: (*Client).Head,
   543  		Handler: func(w ResponseWriter, r *Request) {
   544  			if r.Method != "HEAD" {
   545  				t.Errorf("unexpected method %q", r.Method)
   546  			}
   547  			io.WriteString(w, "foo")
   548  		},
   549  	}.run(t)
   550  }
   551  
   552  func TestH12_HandlerWritesTooLittle(t *testing.T) {
   553  	h12Compare{
   554  		Handler: func(w ResponseWriter, r *Request) {
   555  			w.Header().Set("Content-Length", "3")
   556  			io.WriteString(w, "12") // one byte short
   557  		},
   558  		CheckResponse: func(proto string, res *Response) {
   559  			sr, ok := res.Body.(slurpResult)
   560  			if !ok {
   561  				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
   562  				return
   563  			}
   564  			if sr.err != io.ErrUnexpectedEOF {
   565  				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
   566  			}
   567  			if string(sr.body) != "12" {
   568  				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
   569  			}
   570  		},
   571  	}.run(t)
   572  }
   573  
   574  // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
   575  // writing more than they declared. This test does not test whether
   576  // the transport deals with too much data, though, since the server
   577  // doesn't make it possible to send bogus data. For those tests, see
   578  // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
   579  // (for HTTP/2).
   580  func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
   581  func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
   582  	wantBody := []byte("123")
   583  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   584  		rc := NewResponseController(w)
   585  		w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
   586  		rc.Flush()
   587  		w.Write(wantBody)
   588  		rc.Flush()
   589  		n, err := io.WriteString(w, "x") // too many
   590  		if err == nil {
   591  			err = rc.Flush()
   592  		}
   593  		// TODO: Check that this is ErrContentLength, not just any error.
   594  		if err == nil {
   595  			t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
   596  		}
   597  	}))
   598  
   599  	res, err := cst.c.Get(cst.ts.URL)
   600  	if err != nil {
   601  		t.Fatal(err)
   602  	}
   603  	defer res.Body.Close()
   604  
   605  	gotBody, _ := io.ReadAll(res.Body)
   606  	if !bytes.Equal(gotBody, wantBody) {
   607  		t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
   608  	}
   609  }
   610  
   611  // Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
   612  // Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
   613  func TestH12_AutoGzip(t *testing.T) {
   614  	h12Compare{
   615  		Handler: func(w ResponseWriter, r *Request) {
   616  			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
   617  				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
   618  			}
   619  			w.Header().Set("Content-Encoding", "gzip")
   620  			gz := gzip.NewWriter(w)
   621  			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
   622  			gz.Close()
   623  		},
   624  	}.run(t)
   625  }
   626  
   627  func TestH12_AutoGzip_Disabled(t *testing.T) {
   628  	h12Compare{
   629  		Opts: []any{
   630  			func(tr *Transport) { tr.DisableCompression = true },
   631  		},
   632  		Handler: func(w ResponseWriter, r *Request) {
   633  			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
   634  			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
   635  				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
   636  			}
   637  		},
   638  	}.run(t)
   639  }
   640  
   641  // Test304Responses verifies that 304s don't declare that they're
   642  // chunking in their response headers and aren't allowed to produce
   643  // output.
   644  func Test304Responses(t *testing.T) { run(t, test304Responses) }
   645  func test304Responses(t *testing.T, mode testMode) {
   646  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   647  		w.WriteHeader(StatusNotModified)
   648  		_, err := w.Write([]byte("illegal body"))
   649  		if err != ErrBodyNotAllowed {
   650  			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
   651  		}
   652  	}))
   653  	defer cst.close()
   654  	res, err := cst.c.Get(cst.ts.URL)
   655  	if err != nil {
   656  		t.Fatal(err)
   657  	}
   658  	if len(res.TransferEncoding) > 0 {
   659  		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
   660  	}
   661  	body, err := io.ReadAll(res.Body)
   662  	if err != nil {
   663  		t.Error(err)
   664  	}
   665  	if len(body) > 0 {
   666  		t.Errorf("got unexpected body %q", string(body))
   667  	}
   668  }
   669  
   670  func TestH12_ServerEmptyContentLength(t *testing.T) {
   671  	h12Compare{
   672  		Handler: func(w ResponseWriter, r *Request) {
   673  			w.Header()["Content-Type"] = []string{""}
   674  			io.WriteString(w, "<html><body>hi</body></html>")
   675  		},
   676  	}.run(t)
   677  }
   678  
   679  func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
   680  	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
   681  }
   682  
   683  func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
   684  	h12requestContentLength(t, func() io.Reader { return nil }, 0)
   685  }
   686  
   687  func TestH12_RequestContentLength_Unknown(t *testing.T) {
   688  	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
   689  }
   690  
   691  func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
   692  	h12Compare{
   693  		Handler: func(w ResponseWriter, r *Request) {
   694  			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
   695  			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
   696  		},
   697  		ReqFunc: func(c *Client, url string) (*Response, error) {
   698  			return c.Post(url, "text/plain", bodyfn())
   699  		},
   700  		CheckResponse: func(proto string, res *Response) {
   701  			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
   702  				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
   703  			}
   704  		},
   705  	}.run(t)
   706  }
   707  
// TestCancelRequestMidBody tests that closing the Request.Cancel channel
// while the response body is still being read stops the read. The client
// gets only the bytes sent before cancellation, followed by
// ExportErrRequestCanceled. Issue 13159.
func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
func testCancelRequestMidBody(t *testing.T, mode testMode) {
	// unblock holds the handler between its two writes; it is closed when
	// this function returns.
	unblock := make(chan bool)
	// didFlush reports that "Hello" has been flushed to the client.
	didFlush := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, "Hello")
		w.(Flusher).Flush()
		didFlush <- true
		<-unblock
		io.WriteString(w, ", world.")
	}))
	defer close(unblock)

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	<-didFlush

	// Read a bit before we cancel. (Issue 13626)
	// We should have "Hello" at least sitting there.
	firstRead := make([]byte, 10)
	n, err := res.Body.Read(firstRead)
	if err != nil {
		t.Fatal(err)
	}
	firstRead = firstRead[:n]

	// Cancel while the handler is still parked on unblock, so ", world."
	// has not been written yet.
	close(cancel)

	rest, err := io.ReadAll(res.Body)
	all := string(firstRead) + string(rest)
	if all != "Hello" {
		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
	}
}
   754  
   755  // Tests that clients can send trailers to a server and that the server can read them.
   756  func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
   757  func testTrailersClientToServer(t *testing.T, mode testMode) {
   758  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   759  		slurp, err := io.ReadAll(r.Body)
   760  		if err != nil {
   761  			t.Errorf("Server reading request body: %v", err)
   762  		}
   763  		if string(slurp) != "foo" {
   764  			t.Errorf("Server read request body %q; want foo", slurp)
   765  		}
   766  		if r.Trailer == nil {
   767  			io.WriteString(w, "nil Trailer")
   768  		} else {
   769  			decl := slices.Sorted(maps.Keys(r.Trailer))
   770  			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
   771  				decl,
   772  				r.Trailer.Get("Client-Trailer-A"),
   773  				r.Trailer.Get("Client-Trailer-B"))
   774  		}
   775  	}))
   776  
   777  	var req *Request
   778  	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
   779  		eofReaderFunc(func() {
   780  			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
   781  		}),
   782  		strings.NewReader("foo"),
   783  		eofReaderFunc(func() {
   784  			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
   785  		}),
   786  	))
   787  	req.Trailer = Header{
   788  		"Client-Trailer-A": nil, //  to be set later
   789  		"Client-Trailer-B": nil, //  to be set later
   790  	}
   791  	req.ContentLength = -1
   792  	res, err := cst.c.Do(req)
   793  	if err != nil {
   794  		t.Fatal(err)
   795  	}
   796  	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
   797  		t.Error(err)
   798  	}
   799  }
   800  
   801  // Tests that servers send trailers to a client and that the client can read them.
   802  func TestTrailersServerToClient(t *testing.T) {
   803  	run(t, func(t *testing.T, mode testMode) {
   804  		testTrailersServerToClient(t, mode, false)
   805  	})
   806  }
   807  func TestTrailersServerToClientFlush(t *testing.T) {
   808  	run(t, func(t *testing.T, mode testMode) {
   809  		testTrailersServerToClient(t, mode, true)
   810  	})
   811  }
   812  
   813  func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
   814  	const body = "Some body"
   815  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   816  		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
   817  		w.Header().Add("Trailer", "Server-Trailer-C")
   818  
   819  		io.WriteString(w, body)
   820  		if flush {
   821  			w.(Flusher).Flush()
   822  		}
   823  
   824  		// How handlers set Trailers: declare it ahead of time
   825  		// with the Trailer header, and then mutate the
   826  		// Header() of those values later, after the response
   827  		// has been written (we wrote to w above).
   828  		w.Header().Set("Server-Trailer-A", "valuea")
   829  		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
   830  		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
   831  	}))
   832  
   833  	res, err := cst.c.Get(cst.ts.URL)
   834  	if err != nil {
   835  		t.Fatal(err)
   836  	}
   837  
   838  	wantHeader := Header{
   839  		"Content-Type": {"text/plain; charset=utf-8"},
   840  	}
   841  	wantLen := -1
   842  	if mode == http2Mode && !flush {
   843  		// In HTTP/1.1, any use of trailers forces HTTP/1.1
   844  		// chunking and a flush at the first write. That's
   845  		// unnecessary with HTTP/2's framing, so the server
   846  		// is able to calculate the length while still sending
   847  		// trailers afterwards.
   848  		wantLen = len(body)
   849  		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
   850  	}
   851  	if res.ContentLength != int64(wantLen) {
   852  		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
   853  	}
   854  
   855  	delete(res.Header, "Date") // irrelevant for test
   856  	if !reflect.DeepEqual(res.Header, wantHeader) {
   857  		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
   858  	}
   859  
   860  	if got, want := res.Trailer, (Header{
   861  		"Server-Trailer-A": nil,
   862  		"Server-Trailer-B": nil,
   863  		"Server-Trailer-C": nil,
   864  	}); !reflect.DeepEqual(got, want) {
   865  		t.Errorf("Trailer before body read = %v; want %v", got, want)
   866  	}
   867  
   868  	if err := wantBody(res, nil, body); err != nil {
   869  		t.Fatal(err)
   870  	}
   871  
   872  	if got, want := res.Trailer, (Header{
   873  		"Server-Trailer-A": {"valuea"},
   874  		"Server-Trailer-B": nil,
   875  		"Server-Trailer-C": {"valuec"},
   876  	}); !reflect.DeepEqual(got, want) {
   877  		t.Errorf("Trailer after body read = %v; want %v", got, want)
   878  	}
   879  }
   880  
   881  // Don't allow a Body.Read after Body.Close. Issue 13648.
   882  func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
   883  func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
   884  	const body = "Some body"
   885  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   886  		io.WriteString(w, body)
   887  	}))
   888  	res, err := cst.c.Get(cst.ts.URL)
   889  	if err != nil {
   890  		t.Fatal(err)
   891  	}
   892  	res.Body.Close()
   893  	data, err := io.ReadAll(res.Body)
   894  	if len(data) != 0 || err == nil {
   895  		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
   896  	}
   897  }
   898  
   899  func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
   900  func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
   901  	const reqBody = "some request body"
   902  	const resBody = "some response body"
   903  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   904  		var wg sync.WaitGroup
   905  		wg.Add(2)
   906  		didRead := make(chan bool, 1)
   907  		// Read in one goroutine.
   908  		go func() {
   909  			defer wg.Done()
   910  			data, err := io.ReadAll(r.Body)
   911  			if string(data) != reqBody {
   912  				t.Errorf("Handler read %q; want %q", data, reqBody)
   913  			}
   914  			if err != nil {
   915  				t.Errorf("Handler Read: %v", err)
   916  			}
   917  			didRead <- true
   918  		}()
   919  		// Write in another goroutine.
   920  		go func() {
   921  			defer wg.Done()
   922  			if mode != http2Mode {
   923  				// our HTTP/1 implementation intentionally
   924  				// doesn't permit writes during read (mostly
   925  				// due to it being undefined); if that is ever
   926  				// relaxed, change this.
   927  				<-didRead
   928  			}
   929  			io.WriteString(w, resBody)
   930  		}()
   931  		wg.Wait()
   932  	}))
   933  	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
   934  	req.Header.Add("Expect", "100-continue") // just to complicate things
   935  	res, err := cst.c.Do(req)
   936  	if err != nil {
   937  		t.Fatal(err)
   938  	}
   939  	data, err := io.ReadAll(res.Body)
   940  	defer res.Body.Close()
   941  	if err != nil {
   942  		t.Fatal(err)
   943  	}
   944  	if string(data) != resBody {
   945  		t.Errorf("read %q; want %q", data, resBody)
   946  	}
   947  }
   948  
   949  func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
   950  func testConnectRequest(t *testing.T, mode testMode) {
   951  	gotc := make(chan *Request, 1)
   952  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
   953  		gotc <- r
   954  	}))
   955  
   956  	u, err := url.Parse(cst.ts.URL)
   957  	if err != nil {
   958  		t.Fatal(err)
   959  	}
   960  
   961  	tests := []struct {
   962  		req  *Request
   963  		want string
   964  	}{
   965  		{
   966  			req: &Request{
   967  				Method: "CONNECT",
   968  				Header: Header{},
   969  				URL:    u,
   970  			},
   971  			want: u.Host,
   972  		},
   973  		{
   974  			req: &Request{
   975  				Method: "CONNECT",
   976  				Header: Header{},
   977  				URL:    u,
   978  				Host:   "example.com:123",
   979  			},
   980  			want: "example.com:123",
   981  		},
   982  	}
   983  
   984  	for i, tt := range tests {
   985  		res, err := cst.c.Do(tt.req)
   986  		if err != nil {
   987  			t.Errorf("%d. RoundTrip = %v", i, err)
   988  			continue
   989  		}
   990  		res.Body.Close()
   991  		req := <-gotc
   992  		if req.Method != "CONNECT" {
   993  			t.Errorf("method = %q; want CONNECT", req.Method)
   994  		}
   995  		if req.Host != tt.want {
   996  			t.Errorf("Host = %q; want %q", req.Host, tt.want)
   997  		}
   998  		if req.URL.Host != tt.want {
   999  			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
  1000  		}
  1001  	}
  1002  }
  1003  
  1004  func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
  1005  func testTransportUserAgent(t *testing.T, mode testMode) {
  1006  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1007  		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
  1008  	}))
  1009  
  1010  	either := func(a, b string) string {
  1011  		if mode == http2Mode {
  1012  			return b
  1013  		}
  1014  		return a
  1015  	}
  1016  
  1017  	tests := []struct {
  1018  		setup func(*Request)
  1019  		want  string
  1020  	}{
  1021  		{
  1022  			func(r *Request) {},
  1023  			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
  1024  		},
  1025  		{
  1026  			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
  1027  			`["foo/1.2.3"]`,
  1028  		},
  1029  		{
  1030  			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
  1031  			`["single"]`,
  1032  		},
  1033  		{
  1034  			func(r *Request) { r.Header.Set("User-Agent", "") },
  1035  			`[]`,
  1036  		},
  1037  		{
  1038  			func(r *Request) { r.Header["User-Agent"] = nil },
  1039  			`[]`,
  1040  		},
  1041  	}
  1042  	for i, tt := range tests {
  1043  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1044  		tt.setup(req)
  1045  		res, err := cst.c.Do(req)
  1046  		if err != nil {
  1047  			t.Errorf("%d. RoundTrip = %v", i, err)
  1048  			continue
  1049  		}
  1050  		slurp, err := io.ReadAll(res.Body)
  1051  		res.Body.Close()
  1052  		if err != nil {
  1053  			t.Errorf("%d. read body = %v", i, err)
  1054  			continue
  1055  		}
  1056  		if string(slurp) != tt.want {
  1057  			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
  1058  		}
  1059  	}
  1060  }
  1061  
  1062  func TestStarRequestMethod(t *testing.T) {
  1063  	for _, method := range []string{"FOO", "OPTIONS"} {
  1064  		t.Run(method, func(t *testing.T) {
  1065  			run(t, func(t *testing.T, mode testMode) {
  1066  				testStarRequest(t, method, mode)
  1067  			})
  1068  		})
  1069  	}
  1070  }
  1071  func testStarRequest(t *testing.T, method string, mode testMode) {
  1072  	gotc := make(chan *Request, 1)
  1073  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1074  		w.Header().Set("foo", "bar")
  1075  		gotc <- r
  1076  		w.(Flusher).Flush()
  1077  	}))
  1078  
  1079  	u, err := url.Parse(cst.ts.URL)
  1080  	if err != nil {
  1081  		t.Fatal(err)
  1082  	}
  1083  	u.Path = "*"
  1084  
  1085  	req := &Request{
  1086  		Method: method,
  1087  		Header: Header{},
  1088  		URL:    u,
  1089  	}
  1090  
  1091  	res, err := cst.c.Do(req)
  1092  	if err != nil {
  1093  		t.Fatalf("RoundTrip = %v", err)
  1094  	}
  1095  	res.Body.Close()
  1096  
  1097  	wantFoo := "bar"
  1098  	wantLen := int64(-1)
  1099  	if method == "OPTIONS" {
  1100  		wantFoo = ""
  1101  		wantLen = 0
  1102  	}
  1103  	if res.StatusCode != 200 {
  1104  		t.Errorf("status code = %v; want %d", res.Status, 200)
  1105  	}
  1106  	if res.ContentLength != wantLen {
  1107  		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
  1108  	}
  1109  	if got := res.Header.Get("foo"); got != wantFoo {
  1110  		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
  1111  	}
  1112  	select {
  1113  	case req = <-gotc:
  1114  	default:
  1115  		req = nil
  1116  	}
  1117  	if req == nil {
  1118  		if method != "OPTIONS" {
  1119  			t.Fatalf("handler never got request")
  1120  		}
  1121  		return
  1122  	}
  1123  	if req.Method != method {
  1124  		t.Errorf("method = %q; want %q", req.Method, method)
  1125  	}
  1126  	if req.URL.Path != "*" {
  1127  		t.Errorf("URL.Path = %q; want *", req.URL.Path)
  1128  	}
  1129  	if req.RequestURI != "*" {
  1130  		t.Errorf("RequestURI = %q; want *", req.RequestURI)
  1131  	}
  1132  }
  1133  
  1134  // Issue 13957
  1135  func TestTransportDiscardsUnneededConns(t *testing.T) {
  1136  	run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
  1137  }
  1138  func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
  1139  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1140  		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
  1141  	}))
  1142  	defer cst.close()
  1143  
  1144  	var numOpen, numClose int32 // atomic
  1145  
  1146  	tlsConfig := &tls.Config{InsecureSkipVerify: true}
  1147  	tr := &Transport{
  1148  		TLSClientConfig: tlsConfig,
  1149  		DialTLS: func(_, addr string) (net.Conn, error) {
  1150  			time.Sleep(10 * time.Millisecond)
  1151  			rc, err := net.Dial("tcp", addr)
  1152  			if err != nil {
  1153  				return nil, err
  1154  			}
  1155  			atomic.AddInt32(&numOpen, 1)
  1156  			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
  1157  			return tls.Client(c, tlsConfig), nil
  1158  		},
  1159  		Protocols: &Protocols{},
  1160  	}
  1161  	tr.Protocols.SetHTTP2(true)
  1162  	defer tr.CloseIdleConnections()
  1163  
  1164  	c := &Client{Transport: tr}
  1165  
  1166  	const N = 10
  1167  	gotBody := make(chan string, N)
  1168  	var wg sync.WaitGroup
  1169  	for i := 0; i < N; i++ {
  1170  		wg.Add(1)
  1171  		go func() {
  1172  			defer wg.Done()
  1173  			resp, err := c.Get(cst.ts.URL)
  1174  			if err != nil {
  1175  				// Try to work around spurious connection reset on loaded system.
  1176  				// See golang.org/issue/33585 and golang.org/issue/36797.
  1177  				time.Sleep(10 * time.Millisecond)
  1178  				resp, err = c.Get(cst.ts.URL)
  1179  				if err != nil {
  1180  					t.Errorf("Get: %v", err)
  1181  					return
  1182  				}
  1183  			}
  1184  			defer resp.Body.Close()
  1185  			slurp, err := io.ReadAll(resp.Body)
  1186  			if err != nil {
  1187  				t.Error(err)
  1188  			}
  1189  			gotBody <- string(slurp)
  1190  		}()
  1191  	}
  1192  	wg.Wait()
  1193  	close(gotBody)
  1194  
  1195  	var last string
  1196  	for got := range gotBody {
  1197  		if last == "" {
  1198  			last = got
  1199  			continue
  1200  		}
  1201  		if got != last {
  1202  			t.Errorf("Response body changed: %q -> %q", last, got)
  1203  		}
  1204  	}
  1205  
  1206  	var open, close int32
  1207  	for i := 0; i < 150; i++ {
  1208  		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
  1209  		if open < 1 {
  1210  			t.Fatalf("open = %d; want at least", open)
  1211  		}
  1212  		if close == open-1 {
  1213  			// Success
  1214  			return
  1215  		}
  1216  		time.Sleep(10 * time.Millisecond)
  1217  	}
  1218  	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
  1219  }
  1220  
  1221  // tests that Transport doesn't retain a pointer to the provided request.
  1222  func TestTransportGCRequest(t *testing.T) {
  1223  	run(t, func(t *testing.T, mode testMode) {
  1224  		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
  1225  		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
  1226  	})
  1227  }
  1228  func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
  1229  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1230  		io.ReadAll(r.Body)
  1231  		if body {
  1232  			io.WriteString(w, "Hello.")
  1233  		}
  1234  	}))
  1235  
  1236  	didGC := make(chan struct{})
  1237  	(func() {
  1238  		body := strings.NewReader("some body")
  1239  		req, _ := NewRequest("POST", cst.ts.URL, body)
  1240  		runtime.AddCleanup(req, func(ch chan struct{}) { close(ch) }, didGC)
  1241  		res, err := cst.c.Do(req)
  1242  		if err != nil {
  1243  			t.Fatal(err)
  1244  		}
  1245  		if _, err := io.ReadAll(res.Body); err != nil {
  1246  			t.Fatal(err)
  1247  		}
  1248  		if err := res.Body.Close(); err != nil {
  1249  			t.Fatal(err)
  1250  		}
  1251  	})()
  1252  	for {
  1253  		select {
  1254  		case <-didGC:
  1255  			return
  1256  		case <-time.After(1 * time.Millisecond):
  1257  			runtime.GC()
  1258  		}
  1259  	}
  1260  }
  1261  
  1262  func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
  1263  func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
  1264  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1265  		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
  1266  	}), optQuietLog)
  1267  	cst.tr.DisableKeepAlives = true
  1268  
  1269  	tests := []struct {
  1270  		key, val string
  1271  		ok       bool
  1272  	}{
  1273  		{"Foo", "capital-key", true}, // verify h2 allows capital keys
  1274  		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
  1275  		{"Foo", "two\nlines", false}, // \n byte in value not allowed
  1276  		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
  1277  		{"A space", "v", false},      // spaces in keys not allowed
  1278  		{"имя", "v", false},          // key must be ascii
  1279  		{"name", "валю", true},       // value may be non-ascii
  1280  		{"", "v", false},             // key must be non-empty
  1281  		{"k", "", true},              // value may be empty
  1282  	}
  1283  	for _, tt := range tests {
  1284  		dialedc := make(chan bool, 1)
  1285  		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
  1286  			dialedc <- true
  1287  			return net.Dial(netw, addr)
  1288  		}
  1289  		req, _ := NewRequest("GET", cst.ts.URL, nil)
  1290  		req.Header[tt.key] = []string{tt.val}
  1291  		res, err := cst.c.Do(req)
  1292  		var body []byte
  1293  		if err == nil {
  1294  			body, _ = io.ReadAll(res.Body)
  1295  			res.Body.Close()
  1296  		}
  1297  		var dialed bool
  1298  		select {
  1299  		case <-dialedc:
  1300  			dialed = true
  1301  		default:
  1302  		}
  1303  
  1304  		if !tt.ok && dialed {
  1305  			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
  1306  		} else if (err == nil) != tt.ok {
  1307  			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
  1308  		}
  1309  	}
  1310  }
  1311  
// TestInterruptWithPanic runs testInterruptWithPanic in each test mode with a
// non-nil panic value, a nil panic value, and ErrAbortHandler.
// It is marked testNotParallel, presumably because the "nil" subtest calls
// t.Setenv, which cannot be used in parallel tests.
func TestInterruptWithPanic(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
		// GODEBUG=panicnil=1 keeps panic(nil) a genuine nil panic value
		// rather than converting it to a *runtime.PanicNilError.
		t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
	}, testNotParallel)
}

// testInterruptWithPanic has a handler flush part of a response and then panic
// with panicValue. The client must read exactly the flushed bytes followed by
// an error. The server's error log must then contain something that looks like
// a stack trace, unless panicValue is nil or ErrAbortHandler, in which case the
// log must stay empty.
func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
	const msg = "hello"

	// testDone is closed when this function returns, releasing the handler
	// if the test exits before it sends on gotHeaders (e.g. via t.Fatal).
	testDone := make(chan struct{})
	defer close(testDone)

	var errorLog lockedBytesBuffer
	// gotHeaders tells the handler that the client has received the
	// response headers, so the panic happens only after the response began.
	gotHeaders := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()

		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	})
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotHeaders <- true
	defer res.Body.Close()
	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}
	// The server writes errorLog from its own goroutines, so read it
	// while holding the buffer's lock.
	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	// A nil panic value and ErrAbortHandler must not produce a log entry.
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	// NOTE(review): waitCondition is defined elsewhere; it appears to call the
	// callback repeatedly until it returns true, passing the time waited so
	// far as d (0 on the first attempt) — confirm against its definition.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		gotLog := logOutput()
		if !wantStackLogged {
			if gotLog == "" {
				return true
			}
			t.Fatalf("want no log output; got: %s", gotLog)
		}
		if gotLog == "" {
			if d > 0 {
				t.Logf("wanted a stack trace logged; got nothing after %v", d)
			}
			return false
		}
		// Accept the output as a stack trace if it mentions "created by "
		// or spans at least six lines.
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
			if d > 0 {
				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
			}
			return false
		}
		return true
	})
}
  1382  
  1383  type lockedBytesBuffer struct {
  1384  	sync.Mutex
  1385  	bytes.Buffer
  1386  }
  1387  
  1388  func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
  1389  	b.Lock()
  1390  	defer b.Unlock()
  1391  	return b.Buffer.Write(p)
  1392  }
  1393  
  1394  // Issue 15366
  1395  func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
  1396  	h12Compare{
  1397  		Handler: func(w ResponseWriter, r *Request) {
  1398  			h := w.Header()
  1399  			h.Set("Content-Encoding", "gzip")
  1400  			h.Set("Content-Length", "23")
  1401  			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
  1402  		},
  1403  		EarlyCheckResponse: func(proto string, res *Response) {
  1404  			if !res.Uncompressed {
  1405  				t.Errorf("%s: expected Uncompressed to be set", proto)
  1406  			}
  1407  			dump, err := httputil.DumpResponse(res, true)
  1408  			if err != nil {
  1409  				t.Errorf("%s: DumpResponse: %v", proto, err)
  1410  				return
  1411  			}
  1412  			if strings.Contains(string(dump), "Connection: close") {
  1413  				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
  1414  			}
  1415  			if !strings.Contains(string(dump), "FOO") {
  1416  				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
  1417  			}
  1418  		},
  1419  	}.run(t)
  1420  }
  1421  
  1422  // Issue 14607
  1423  func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
  1424  func testCloseIdleConnections(t *testing.T, mode testMode) {
  1425  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1426  		w.Header().Set("X-Addr", r.RemoteAddr)
  1427  	}))
  1428  	get := func() string {
  1429  		res, err := cst.c.Get(cst.ts.URL)
  1430  		if err != nil {
  1431  			t.Fatal(err)
  1432  		}
  1433  		res.Body.Close()
  1434  		v := res.Header.Get("X-Addr")
  1435  		if v == "" {
  1436  			t.Fatal("didn't get X-Addr")
  1437  		}
  1438  		return v
  1439  	}
  1440  	a1 := get()
  1441  	cst.tr.CloseIdleConnections()
  1442  	a2 := get()
  1443  	if a1 == a2 {
  1444  		t.Errorf("didn't close connection")
  1445  	}
  1446  }
  1447  
  1448  type noteCloseConn struct {
  1449  	net.Conn
  1450  	closeFunc func()
  1451  }
  1452  
  1453  func (x noteCloseConn) Close() error {
  1454  	x.closeFunc()
  1455  	return x.Conn.Close()
  1456  }
  1457  
  1458  type testErrorReader struct{ t *testing.T }
  1459  
  1460  func (r testErrorReader) Read(p []byte) (n int, err error) {
  1461  	r.t.Error("unexpected Read call")
  1462  	return 0, io.EOF
  1463  }
  1464  
  1465  func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
  1466  func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
  1467  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1468  		w.WriteHeader(StatusUnauthorized)
  1469  	}))
  1470  
  1471  	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
  1472  	cst.tr.ExpectContinueTimeout = 10 * time.Second
  1473  
  1474  	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
  1475  	if err != nil {
  1476  		t.Fatal(err)
  1477  	}
  1478  	req.ContentLength = 0 // so transport is tempted to sniff it
  1479  	req.Header.Set("Expect", "100-continue")
  1480  	res, err := cst.tr.RoundTrip(req)
  1481  	if err != nil {
  1482  		t.Fatal(err)
  1483  	}
  1484  	defer res.Body.Close()
  1485  	if res.StatusCode != StatusUnauthorized {
  1486  		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
  1487  	}
  1488  }
  1489  
  1490  func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
  1491  func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
  1492  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1493  		w.Header().Set("Foo", "Bar")
  1494  		w.Header().Set("Trailer:Foo", "Baz")
  1495  		w.(Flusher).Flush()
  1496  		w.Header().Add("Trailer:Foo", "Baz2")
  1497  		w.Header().Set("Trailer:Bar", "Quux")
  1498  	}))
  1499  	res, err := cst.c.Get(cst.ts.URL)
  1500  	if err != nil {
  1501  		t.Fatal(err)
  1502  	}
  1503  	if _, err := io.Copy(io.Discard, res.Body); err != nil {
  1504  		t.Fatal(err)
  1505  	}
  1506  	res.Body.Close()
  1507  	delete(res.Header, "Date")
  1508  	delete(res.Header, "Content-Type")
  1509  
  1510  	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
  1511  		t.Errorf("Header = %#v; want %#v", res.Header, want)
  1512  	}
  1513  	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
  1514  		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
  1515  	}
  1516  }
  1517  
  1518  func TestBadResponseAfterReadingBody(t *testing.T) {
  1519  	run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
  1520  }
  1521  func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
  1522  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1523  		_, err := io.Copy(io.Discard, r.Body)
  1524  		if err != nil {
  1525  			t.Fatal(err)
  1526  		}
  1527  		c, _, err := w.(Hijacker).Hijack()
  1528  		if err != nil {
  1529  			t.Fatal(err)
  1530  		}
  1531  		defer c.Close()
  1532  		fmt.Fprintln(c, "some bogus crap")
  1533  	}))
  1534  
  1535  	closes := 0
  1536  	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
  1537  	if err == nil {
  1538  		res.Body.Close()
  1539  		t.Fatal("expected an error to be returned from Post")
  1540  	}
  1541  	if closes != 1 {
  1542  		t.Errorf("closes = %d; want 1", closes)
  1543  	}
  1544  }
  1545  
  1546  func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
  1547  func testWriteHeader0(t *testing.T, mode testMode) {
  1548  	gotpanic := make(chan bool, 1)
  1549  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1550  		defer close(gotpanic)
  1551  		defer func() {
  1552  			if e := recover(); e != nil {
  1553  				got := fmt.Sprintf("%T, %v", e, e)
  1554  				want := "string, invalid WriteHeader code 0"
  1555  				if got != want {
  1556  					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
  1557  				}
  1558  				gotpanic <- true
  1559  
  1560  				// Set an explicit 503. This also tests that the WriteHeader call panics
  1561  				// before it recorded that an explicit value was set and that bogus
  1562  				// value wasn't stuck.
  1563  				w.WriteHeader(503)
  1564  			}
  1565  		}()
  1566  		w.WriteHeader(0)
  1567  	}))
  1568  	res, err := cst.c.Get(cst.ts.URL)
  1569  	if err != nil {
  1570  		t.Fatal(err)
  1571  	}
  1572  	if res.StatusCode != 503 {
  1573  		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
  1574  	}
  1575  	if !<-gotpanic {
  1576  		t.Error("expected panic in handler")
  1577  	}
  1578  }
  1579  
  1580  // Issue 23010: don't be super strict checking WriteHeader's code if
  1581  // it's not even valid to call WriteHeader then anyway.
  1582  func TestWriteHeaderNoCodeCheck(t *testing.T) {
  1583  	run(t, func(t *testing.T, mode testMode) {
  1584  		testWriteHeaderAfterWrite(t, mode, false)
  1585  	})
  1586  }
  1587  func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
  1588  	testWriteHeaderAfterWrite(t, http1Mode, true)
  1589  }
  1590  func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
  1591  	var errorLog lockedBytesBuffer
  1592  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1593  		if hijack {
  1594  			conn, _, _ := w.(Hijacker).Hijack()
  1595  			defer conn.Close()
  1596  			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
  1597  			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1598  			conn.Write([]byte("bar"))
  1599  			return
  1600  		}
  1601  		io.WriteString(w, "foo")
  1602  		w.(Flusher).Flush()
  1603  		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
  1604  		io.WriteString(w, "bar")
  1605  	}), func(ts *httptest.Server) {
  1606  		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
  1607  	})
  1608  	res, err := cst.c.Get(cst.ts.URL)
  1609  	if err != nil {
  1610  		t.Fatal(err)
  1611  	}
  1612  	defer res.Body.Close()
  1613  	body, err := io.ReadAll(res.Body)
  1614  	if err != nil {
  1615  		t.Fatal(err)
  1616  	}
  1617  	if got, want := string(body), "foobar"; got != want {
  1618  		t.Errorf("got = %q; want %q", got, want)
  1619  	}
  1620  
  1621  	// Also check the stderr output:
  1622  	if mode == http2Mode {
  1623  		// TODO: also emit this log message for HTTP/2?
  1624  		// We historically haven't, so don't check.
  1625  		return
  1626  	}
  1627  	gotLog := strings.TrimSpace(errorLog.String())
  1628  	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1629  	if hijack {
  1630  		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
  1631  	}
  1632  	if !strings.HasPrefix(gotLog, wantLog) {
  1633  		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
  1634  	}
  1635  }
  1636  
  1637  func TestBidiStreamReverseProxy(t *testing.T) {
  1638  	run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
  1639  }
  1640  func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
  1641  	backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1642  		if _, err := io.Copy(w, r.Body); err != nil {
  1643  			log.Printf("bidi backend copy: %v", err)
  1644  		}
  1645  	}))
  1646  
  1647  	backURL, err := url.Parse(backend.ts.URL)
  1648  	if err != nil {
  1649  		t.Fatal(err)
  1650  	}
  1651  	rp := httputil.NewSingleHostReverseProxy(backURL)
  1652  	rp.Transport = backend.tr
  1653  	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1654  		rp.ServeHTTP(w, r)
  1655  	}))
  1656  
  1657  	bodyRes := make(chan any, 1) // error or hash.Hash
  1658  	pr, pw := io.Pipe()
  1659  	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
  1660  	const size = 4 << 20
  1661  	go func() {
  1662  		h := sha1.New()
  1663  		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
  1664  		go pw.Close()
  1665  		if err != nil {
  1666  			t.Errorf("body copy: %v", err)
  1667  			bodyRes <- err
  1668  		} else {
  1669  			bodyRes <- h
  1670  		}
  1671  	}()
  1672  	res, err := backend.c.Do(req)
  1673  	if err != nil {
  1674  		t.Fatal(err)
  1675  	}
  1676  	defer res.Body.Close()
  1677  	hgot := sha1.New()
  1678  	n, err := io.Copy(hgot, res.Body)
  1679  	if err != nil {
  1680  		t.Fatal(err)
  1681  	}
  1682  	if n != size {
  1683  		t.Fatalf("got %d bytes; want %d", n, size)
  1684  	}
  1685  	select {
  1686  	case v := <-bodyRes:
  1687  		switch v := v.(type) {
  1688  		default:
  1689  			t.Fatalf("body copy: %v", err)
  1690  		case hash.Hash:
  1691  			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
  1692  				t.Errorf("written bytes didn't match received bytes")
  1693  			}
  1694  		}
  1695  	case <-time.After(10 * time.Second):
  1696  		t.Fatal("timeout")
  1697  	}
  1698  
  1699  }
  1700  
  1701  // Always use HTTP/1.1 for WebSocket upgrades.
  1702  func TestH12_WebSocketUpgrade(t *testing.T) {
  1703  	h12Compare{
  1704  		Handler: func(w ResponseWriter, r *Request) {
  1705  			h := w.Header()
  1706  			h.Set("Foo", "bar")
  1707  		},
  1708  		ReqFunc: func(c *Client, url string) (*Response, error) {
  1709  			req, _ := NewRequest("GET", url, nil)
  1710  			req.Header.Set("Connection", "Upgrade")
  1711  			req.Header.Set("Upgrade", "WebSocket")
  1712  			return c.Do(req)
  1713  		},
  1714  		EarlyCheckResponse: func(proto string, res *Response) {
  1715  			if res.Proto != "HTTP/1.1" {
  1716  				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
  1717  			}
  1718  			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
  1719  		},
  1720  		Opts: []any{
  1721  			func(s *Server) {
  1722  				// Configure servers to support HTTP/1 and HTTP/2,
  1723  				// so we can verify that we use HTTP/1
  1724  				// even when HTTP/2 is an option.
  1725  				s.Protocols = &Protocols{}
  1726  				s.Protocols.SetHTTP1(true)
  1727  				s.Protocols.SetHTTP2(true)
  1728  			},
  1729  		},
  1730  	}.run(t)
  1731  }
  1732  
  1733  func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
  1734  func testIdentityTransferEncoding(t *testing.T, mode testMode) {
  1735  	const body = "body"
  1736  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1737  		gotBody, _ := io.ReadAll(r.Body)
  1738  		if got, want := string(gotBody), body; got != want {
  1739  			t.Errorf("got request body = %q; want %q", got, want)
  1740  		}
  1741  		w.Header().Set("Transfer-Encoding", "identity")
  1742  		w.WriteHeader(StatusOK)
  1743  		w.(Flusher).Flush()
  1744  		io.WriteString(w, body)
  1745  	}))
  1746  	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
  1747  	res, err := cst.c.Do(req)
  1748  	if err != nil {
  1749  		t.Fatal(err)
  1750  	}
  1751  	defer res.Body.Close()
  1752  	gotBody, err := io.ReadAll(res.Body)
  1753  	if err != nil {
  1754  		t.Fatal(err)
  1755  	}
  1756  	if got, want := string(gotBody), body; got != want {
  1757  		t.Errorf("got response body = %q; want %q", got, want)
  1758  	}
  1759  }
  1760  
  1761  func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
  1762  func testEarlyHintsRequest(t *testing.T, mode testMode) {
  1763  	var wg sync.WaitGroup
  1764  	wg.Add(1)
  1765  	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
  1766  		h := w.Header()
  1767  
  1768  		h.Add("Content-Length", "123") // must be ignored
  1769  		h.Add("Link", "</style.css>; rel=preload; as=style")
  1770  		h.Add("Link", "</script.js>; rel=preload; as=script")
  1771  		w.WriteHeader(StatusEarlyHints)
  1772  
  1773  		wg.Wait()
  1774  
  1775  		h.Add("Link", "</foo.js>; rel=preload; as=script")
  1776  		w.WriteHeader(StatusEarlyHints)
  1777  
  1778  		w.Write([]byte("Hello"))
  1779  	}))
  1780  
  1781  	checkLinkHeaders := func(t *testing.T, expected, got []string) {
  1782  		t.Helper()
  1783  
  1784  		if len(expected) != len(got) {
  1785  			t.Errorf("got %d expected %d", len(got), len(expected))
  1786  		}
  1787  
  1788  		for i := range expected {
  1789  			if expected[i] != got[i] {
  1790  				t.Errorf("got %q expected %q", got[i], expected[i])
  1791  			}
  1792  		}
  1793  	}
  1794  
  1795  	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
  1796  		t.Helper()
  1797  
  1798  		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
  1799  			if v, ok := header[h]; ok {
  1800  				t.Errorf("%s is %q; must not be sent", h, v)
  1801  			}
  1802  		}
  1803  	}
  1804  
  1805  	var respCounter uint8
  1806  	trace := &httptrace.ClientTrace{
  1807  		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
  1808  			switch respCounter {
  1809  			case 0:
  1810  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
  1811  				checkExcludedHeaders(t, header)
  1812  
  1813  				wg.Done()
  1814  			case 1:
  1815  				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
  1816  				checkExcludedHeaders(t, header)
  1817  
  1818  			default:
  1819  				t.Error("Unexpected 1xx response")
  1820  			}
  1821  
  1822  			respCounter++
  1823  
  1824  			return nil
  1825  		},
  1826  	}
  1827  	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
  1828  
  1829  	res, err := cst.c.Do(req)
  1830  	if err != nil {
  1831  		t.Fatal(err)
  1832  	}
  1833  	defer res.Body.Close()
  1834  
  1835  	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
  1836  	if cl := res.Header.Get("Content-Length"); cl != "123" {
  1837  		t.Errorf("Content-Length is %q; want 123", cl)
  1838  	}
  1839  
  1840  	body, _ := io.ReadAll(res.Body)
  1841  	if string(body) != "Hello" {
  1842  		t.Errorf("Read body %q; want Hello", body)
  1843  	}
  1844  }
  1845  

View as plain text