Source file src/net/http/transport_internal_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // White-box tests for transport.go (in package http instead of http_test).
     6  
     7  package http
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"io"
    15  	"net"
    16  	"net/http/internal/http2"
    17  	"net/http/internal/testcert"
    18  	"strings"
    19  	"testing"
    20  )
    21  
    22  // Issue 15446: incorrect wrapping of errors when server closes an idle connection.
    23  func TestTransportPersistConnReadLoopEOF(t *testing.T) {
    24  	ln := newLocalListener(t)
    25  	defer ln.Close()
    26  
    27  	connc := make(chan net.Conn, 1)
    28  	go func() {
    29  		defer close(connc)
    30  		c, err := ln.Accept()
    31  		if err != nil {
    32  			t.Error(err)
    33  			return
    34  		}
    35  		connc <- c
    36  	}()
    37  
    38  	tr := new(Transport)
    39  	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
    40  	req = req.WithT(t)
    41  	ctx, cancel := context.WithCancelCause(context.Background())
    42  	treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
    43  	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
    44  	pc, err := tr.getConn(treq, cm)
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  	defer pc.close(errors.New("test over"))
    49  
    50  	conn := <-connc
    51  	if conn == nil {
    52  		// Already called t.Error in the accept goroutine.
    53  		return
    54  	}
    55  	conn.Close() // simulate the server hanging up on the client
    56  
    57  	_, err = pc.roundTrip(treq)
    58  	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    59  		t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
    60  	}
    61  
    62  	<-pc.closech
    63  	err = pc.closed
    64  	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    65  		t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError, or nothingWrittenError", err, err)
    66  	}
    67  }
    68  
    69  func isNothingWrittenError(err error) bool {
    70  	_, ok := err.(nothingWrittenError)
    71  	return ok
    72  }
    73  
    74  func isTransportReadFromServerError(err error) bool {
    75  	_, ok := err.(transportReadFromServerError)
    76  	return ok
    77  }
    78  
    79  func newLocalListener(t *testing.T) net.Listener {
    80  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    81  	if err != nil {
    82  		ln, err = net.Listen("tcp6", "[::1]:0")
    83  	}
    84  	if err != nil {
    85  		t.Fatal(err)
    86  	}
    87  	return ln
    88  }
    89  
    90  func dummyRequest(method string) *Request {
    91  	req, err := NewRequest(method, "http://fake.tld/", nil)
    92  	if err != nil {
    93  		panic(err)
    94  	}
    95  	return req
    96  }
    97  func dummyRequestWithBody(method string) *Request {
    98  	req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
    99  	if err != nil {
   100  		panic(err)
   101  	}
   102  	return req
   103  }
   104  
   105  func dummyRequestWithBodyNoGetBody(method string) *Request {
   106  	req := dummyRequestWithBody(method)
   107  	req.GetBody = nil
   108  	return req
   109  }
   110  
   111  // issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn.
   112  type issue22091Error struct{}
   113  
   114  func (issue22091Error) IsHTTP2NoCachedConnError() {}
   115  func (issue22091Error) Error() string             { return "issue22091Error" }
   116  
   117  func TestTransportShouldRetryRequest(t *testing.T) {
   118  	tests := []struct {
   119  		pc  *persistConn
   120  		req *Request
   121  
   122  		err  error
   123  		want bool
   124  	}{
   125  		0: {
   126  			pc:   &persistConn{reused: false},
   127  			req:  dummyRequest("POST"),
   128  			err:  nothingWrittenError{},
   129  			want: false,
   130  		},
   131  		1: {
   132  			pc:   &persistConn{reused: true},
   133  			req:  dummyRequest("POST"),
   134  			err:  nothingWrittenError{},
   135  			want: true,
   136  		},
   137  		2: {
   138  			pc:   &persistConn{reused: true},
   139  			req:  dummyRequest("POST"),
   140  			err:  http2.ErrNoCachedConn,
   141  			want: true,
   142  		},
   143  		3: {
   144  			pc:   nil,
   145  			req:  nil,
   146  			err:  issue22091Error{}, // like an external http2ErrNoCachedConn
   147  			want: true,
   148  		},
   149  		4: {
   150  			pc:   &persistConn{reused: true},
   151  			req:  dummyRequest("POST"),
   152  			err:  errMissingHost,
   153  			want: false,
   154  		},
   155  		5: {
   156  			pc:   &persistConn{reused: true},
   157  			req:  dummyRequest("POST"),
   158  			err:  transportReadFromServerError{},
   159  			want: false,
   160  		},
   161  		6: {
   162  			pc:   &persistConn{reused: true},
   163  			req:  dummyRequest("GET"),
   164  			err:  transportReadFromServerError{},
   165  			want: true,
   166  		},
   167  		7: {
   168  			pc:   &persistConn{reused: true},
   169  			req:  dummyRequest("GET"),
   170  			err:  errServerClosedIdle,
   171  			want: true,
   172  		},
   173  		8: {
   174  			pc:   &persistConn{reused: true},
   175  			req:  dummyRequestWithBody("POST"),
   176  			err:  nothingWrittenError{},
   177  			want: true,
   178  		},
   179  		9: {
   180  			pc:   &persistConn{reused: true},
   181  			req:  dummyRequestWithBodyNoGetBody("POST"),
   182  			err:  nothingWrittenError{},
   183  			want: false,
   184  		},
   185  	}
   186  	for i, tt := range tests {
   187  		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
   188  		if got != tt.want {
   189  			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
   190  		}
   191  	}
   192  }
   193  
   194  type roundTripFunc func(r *Request) (*Response, error)
   195  
   196  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
   197  	return f(r)
   198  }
   199  
   200  // Issue 25009
   201  func TestTransportBodyAltRewind(t *testing.T) {
   202  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	ln := newLocalListener(t)
   207  	defer ln.Close()
   208  
   209  	go func() {
   210  		tln := tls.NewListener(ln, &tls.Config{
   211  			NextProtos:   []string{"foo"},
   212  			Certificates: []tls.Certificate{cert},
   213  		})
   214  		for i := 0; i < 2; i++ {
   215  			sc, err := tln.Accept()
   216  			if err != nil {
   217  				t.Error(err)
   218  				return
   219  			}
   220  			if err := sc.(*tls.Conn).Handshake(); err != nil {
   221  				t.Error(err)
   222  				return
   223  			}
   224  			sc.Close()
   225  		}
   226  	}()
   227  
   228  	addr := ln.Addr().String()
   229  	req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
   230  	roundTripped := false
   231  	tr := &Transport{
   232  		DisableKeepAlives: true,
   233  		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
   234  			"foo": func(authority string, c *tls.Conn) RoundTripper {
   235  				return roundTripFunc(func(r *Request) (*Response, error) {
   236  					n, _ := io.Copy(io.Discard, r.Body)
   237  					if n == 0 {
   238  						t.Error("body length is zero")
   239  					}
   240  					if roundTripped {
   241  						return &Response{
   242  							Body:       NoBody,
   243  							StatusCode: 200,
   244  						}, nil
   245  					}
   246  					roundTripped = true
   247  					return nil, http2.ErrNoCachedConn
   248  				})
   249  			},
   250  		},
   251  		DialTLS: func(_, _ string) (net.Conn, error) {
   252  			tc, err := tls.Dial("tcp", addr, &tls.Config{
   253  				InsecureSkipVerify: true,
   254  				NextProtos:         []string{"foo"},
   255  			})
   256  			if err != nil {
   257  				return nil, err
   258  			}
   259  			if err := tc.Handshake(); err != nil {
   260  				return nil, err
   261  			}
   262  			return tc, nil
   263  		},
   264  	}
   265  	c := &Client{Transport: tr}
   266  	_, err = c.Do(req)
   267  	if err != nil {
   268  		t.Error(err)
   269  	}
   270  }
   271  

View as plain text