Source file src/net/http/export_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Bridge package to expose http internals to tests in the http_test
     6  // package.
     7  
     8  package http
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"net"
    14  	"net/url"
    15  	"slices"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  var (
    22  	DefaultUserAgent                  = defaultUserAgent
    23  	NewLoggingConn                    = newLoggingConn
    24  	ExportRefererForURL               = refererForURL
    25  	ExportServerNewConn               = (*Server).newConn
    26  	ExportCloseWriteAndWait           = (*conn).closeWriteAndWait
    27  	ExportErrRequestCanceled          = errRequestCanceled
    28  	ExportErrRequestCanceledConn      = errRequestCanceledConn
    29  	ExportErrServerClosedIdle         = errServerClosedIdle
    30  	ExportServeFile                   = serveFile
    31  	ExportScanETag                    = scanETag
    32  	Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
    33  	Export_writeStatusLine            = writeStatusLine
    34  	Export_is408Message               = is408Message
    35  	MaxPostCloseReadTime              = maxPostCloseReadTime
    36  )
    37  
    38  var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse
    39  
    40  func init() {
    41  	// We only want to pay for this cost during testing.
    42  	// When not under test, these values are always nil
    43  	// and never assigned to.
    44  	testHookMu = new(sync.Mutex)
    45  
    46  	testHookClientDoResult = func(res *Response, err error) {
    47  		if err != nil {
    48  			if _, ok := err.(*url.Error); !ok {
    49  				panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err))
    50  			}
    51  		} else {
    52  			if res == nil {
    53  				panic("Client.Do returned nil, nil")
    54  			}
    55  			if res.Body == nil {
    56  				panic("Client.Do returned nil res.Body and no error")
    57  			}
    58  		}
    59  	}
    60  }
    61  
    62  func CondSkipHTTP2(t testing.TB) {
    63  	if omitBundledHTTP2 {
    64  		t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use")
    65  	}
    66  }
    67  
    68  var (
    69  	SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip)
    70  	SetRoundTripRetried   = hookSetter(&testHookRoundTripRetried)
    71  )
    72  
    73  func SetReadLoopBeforeNextReadHook(f func()) {
    74  	unnilTestHook(&f)
    75  	testHookReadLoopBeforeNextRead = f
    76  }
    77  
    78  // SetPendingDialHooks sets the hooks that run before and after handling
    79  // pending dials.
    80  func SetPendingDialHooks(before, after func()) {
    81  	unnilTestHook(&before)
    82  	unnilTestHook(&after)
    83  	testHookPrePendingDial, testHookPostPendingDial = before, after
    84  }
    85  
    86  func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
    87  
    88  func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
    89  	orig := testHookProxyConnectTimeout
    90  	t.Cleanup(func() {
    91  		testHookProxyConnectTimeout = orig
    92  	})
    93  	testHookProxyConnectTimeout = f
    94  }
    95  
    96  func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
    97  	return &timeoutHandler{
    98  		handler:     handler,
    99  		testContext: ctx,
   100  		// (no body)
   101  	}
   102  }
   103  
   104  func ResetCachedEnvironment() {
   105  	resetProxyConfig()
   106  }
   107  
   108  func (t *Transport) NumPendingRequestsForTesting() int {
   109  	t.reqMu.Lock()
   110  	defer t.reqMu.Unlock()
   111  	return len(t.reqCanceler)
   112  }
   113  
   114  func (t *Transport) IdleConnKeysForTesting() (keys []string) {
   115  	keys = make([]string, 0)
   116  	t.idleMu.Lock()
   117  	defer t.idleMu.Unlock()
   118  	for key := range t.idleConn {
   119  		keys = append(keys, key.String())
   120  	}
   121  	slices.Sort(keys)
   122  	return
   123  }
   124  
   125  func (t *Transport) IdleConnKeyCountForTesting() int {
   126  	t.idleMu.Lock()
   127  	defer t.idleMu.Unlock()
   128  	return len(t.idleConn)
   129  }
   130  
   131  func (t *Transport) IdleConnStrsForTesting() []string {
   132  	var ret []string
   133  	t.idleMu.Lock()
   134  	defer t.idleMu.Unlock()
   135  	for _, conns := range t.idleConn {
   136  		for _, pc := range conns {
   137  			if pc.conn == nil {
   138  				continue
   139  			}
   140  			ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String())
   141  		}
   142  	}
   143  	if f, ok := t.h2transport.(interface{ IdleConnStrsForTesting() []string }); ok {
   144  		ret = append(ret, f.IdleConnStrsForTesting()...)
   145  	}
   146  	slices.Sort(ret)
   147  	return ret
   148  }
   149  
   150  func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
   151  	t.idleMu.Lock()
   152  	defer t.idleMu.Unlock()
   153  	key := connectMethodKey{"", scheme, addr, false}
   154  	cacheKey := key.String()
   155  	for k, conns := range t.idleConn {
   156  		if k.String() == cacheKey {
   157  			return len(conns)
   158  		}
   159  	}
   160  	return 0
   161  }
   162  
   163  func (t *Transport) IdleConnWaitMapSizeForTesting() int {
   164  	t.idleMu.Lock()
   165  	defer t.idleMu.Unlock()
   166  	return len(t.idleConnWait)
   167  }
   168  
   169  func (t *Transport) IsIdleForTesting() bool {
   170  	t.idleMu.Lock()
   171  	defer t.idleMu.Unlock()
   172  	return t.closeIdle
   173  }
   174  
   175  func (t *Transport) QueueForIdleConnForTesting() {
   176  	t.queueForIdleConn(nil)
   177  }
   178  
   179  // PutIdleTestConn reports whether it was able to insert a fresh
   180  // persistConn for scheme, addr into the idle connection pool.
   181  func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
   182  	c, _ := net.Pipe()
   183  	key := connectMethodKey{"", scheme, addr, false}
   184  
   185  	if t.MaxConnsPerHost > 0 {
   186  		// Transport is tracking conns-per-host.
   187  		// Increment connection count to account
   188  		// for new persistConn created below.
   189  		t.connsPerHostMu.Lock()
   190  		if t.connsPerHost == nil {
   191  			t.connsPerHost = make(map[connectMethodKey]int)
   192  		}
   193  		t.connsPerHost[key]++
   194  		t.connsPerHostMu.Unlock()
   195  	}
   196  
   197  	return t.tryPutIdleConn(&persistConn{
   198  		t:        t,
   199  		conn:     c,                   // dummy
   200  		closech:  make(chan struct{}), // so it can be closed
   201  		cacheKey: key,
   202  	}) == nil
   203  }
   204  
   205  // PutIdleTestConnH2 reports whether it was able to insert a fresh
   206  // HTTP/2 persistConn for scheme, addr into the idle connection pool.
   207  func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt RoundTripper) bool {
   208  	key := connectMethodKey{"", scheme, addr, false}
   209  
   210  	if t.MaxConnsPerHost > 0 {
   211  		// Transport is tracking conns-per-host.
   212  		// Increment connection count to account
   213  		// for new persistConn created below.
   214  		t.connsPerHostMu.Lock()
   215  		if t.connsPerHost == nil {
   216  			t.connsPerHost = make(map[connectMethodKey]int)
   217  		}
   218  		t.connsPerHost[key]++
   219  		t.connsPerHostMu.Unlock()
   220  	}
   221  
   222  	return t.tryPutIdleConn(&persistConn{
   223  		t:        t,
   224  		alt:      alt,
   225  		cacheKey: key,
   226  	}) == nil
   227  }
   228  
   229  // All test hooks must be non-nil so they can be called directly,
   230  // but the tests use nil to mean hook disabled.
   231  func unnilTestHook(f *func()) {
   232  	if *f == nil {
   233  		*f = nop
   234  	}
   235  }
   236  
   237  func hookSetter(dst *func()) func(func()) {
   238  	return func(fn func()) {
   239  		unnilTestHook(&fn)
   240  		*dst = fn
   241  	}
   242  }
   243  
   244  func (s *Server) ExportAllConnsIdle() bool {
   245  	s.mu.Lock()
   246  	defer s.mu.Unlock()
   247  	for c := range s.activeConn {
   248  		st, unixSec := c.getState()
   249  		if unixSec == 0 || st != StateIdle {
   250  			return false
   251  		}
   252  	}
   253  	return true
   254  }
   255  
   256  func (s *Server) ExportAllConnsByState() map[ConnState]int {
   257  	states := map[ConnState]int{}
   258  	s.mu.Lock()
   259  	defer s.mu.Unlock()
   260  	for c := range s.activeConn {
   261  		st, _ := c.getState()
   262  		states[st] += 1
   263  	}
   264  	return states
   265  }
   266  
   267  func (r *Request) WithT(t *testing.T) *Request {
   268  	return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
   269  }
   270  
   271  func (r *Request) ExportIsReplayable() bool { return r.isReplayable() }
   272  
   273  // ExportCloseTransportConnsAbruptly closes all idle connections from
   274  // tr in an abrupt way, just reaching into the underlying Conns and
   275  // closing them, without telling the Transport or its persistConns
   276  // that it's doing so. This is to simulate the server closing connections
   277  // on the Transport.
   278  func ExportCloseTransportConnsAbruptly(tr *Transport) {
   279  	tr.idleMu.Lock()
   280  	for _, pcs := range tr.idleConn {
   281  		for _, pc := range pcs {
   282  			pc.conn.Close()
   283  		}
   284  	}
   285  	tr.idleMu.Unlock()
   286  }
   287  
   288  // ResponseWriterConnForTesting returns w's underlying connection, if w
   289  // is a regular *response ResponseWriter.
   290  func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
   291  	if r, ok := w.(*response); ok {
   292  		return r.conn.rwc, true
   293  	}
   294  	return nil, false
   295  }
   296  
   297  func init() {
   298  	// Set the default rstAvoidanceDelay to the minimum possible value to shake
   299  	// out tests that unexpectedly depend on it. Such tests should use
   300  	// runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay
   301  	// if needed.
   302  	rstAvoidanceDelay = 1 * time.Nanosecond
   303  }
   304  
   305  // SetRSTAvoidanceDelay sets how long we are willing to wait between calling
   306  // CloseWrite on a connection and fully closing the connection.
   307  func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) {
   308  	prevDelay := rstAvoidanceDelay
   309  	t.Cleanup(func() {
   310  		rstAvoidanceDelay = prevDelay
   311  	})
   312  	rstAvoidanceDelay = d
   313  }
   314  

View as plain text