cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.3.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

connection/http2_test.go

409 lines · mode: code

1package connection
2
3import (
4 "context"
5 "fmt"
6 "io"
7 "io/ioutil"
8 "net"
9 "net/http"
10 "net/http/httptest"
11 "sync"
12 "testing"
13 "time"
14
15 "github.com/stretchr/testify/assert"
16
17 "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
18 tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
19
20 "github.com/gobwas/ws/wsutil"
21 "github.com/rs/zerolog"
22 "github.com/stretchr/testify/require"
23 "golang.org/x/net/http2"
24)
25
26var (
27 testTransport = http2.Transport{}
28)
29
30func newTestHTTP2Connection() (*http2Connection, net.Conn) {
31 edgeConn, originConn := net.Pipe()
32 var connIndex = uint8(0)
33 return NewHTTP2Connection(
34 originConn,
35 testConfig,
36 &NamedTunnelConfig{},
37 &pogs.ConnectionOptions{},
38 NewObserver(&log, &log, false),
39 connIndex,
40 mockConnectedFuse{},
41 nil,
42 ), edgeConn
43}
44
45func TestServeHTTP(t *testing.T) {
46 tests := []testRequest{
47 {
48 name: "ok",
49 endpoint: "ok",
50 expectedStatus: http.StatusOK,
51 expectedBody: []byte(http.StatusText(http.StatusOK)),
52 },
53 {
54 name: "large_file",
55 endpoint: "large_file",
56 expectedStatus: http.StatusOK,
57 expectedBody: testLargeResp,
58 },
59 {
60 name: "Bad request",
61 endpoint: "400",
62 expectedStatus: http.StatusBadRequest,
63 expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
64 },
65 {
66 name: "Internal server error",
67 endpoint: "500",
68 expectedStatus: http.StatusInternalServerError,
69 expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
70 },
71 {
72 name: "Proxy error",
73 endpoint: "error",
74 expectedStatus: http.StatusBadGateway,
75 expectedBody: nil,
76 isProxyError: true,
77 },
78 }
79
80 http2Conn, edgeConn := newTestHTTP2Connection()
81
82 ctx, cancel := context.WithCancel(context.Background())
83 var wg sync.WaitGroup
84 wg.Add(1)
85 go func() {
86 defer wg.Done()
87 http2Conn.Serve(ctx)
88 }()
89
90 edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
91 require.NoError(t, err)
92
93 for _, test := range tests {
94 endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
95 req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
96 require.NoError(t, err)
97
98 resp, err := edgeHTTP2Conn.RoundTrip(req)
99 require.NoError(t, err)
100 require.Equal(t, test.expectedStatus, resp.StatusCode)
101 if test.expectedBody != nil {
102 respBody, err := ioutil.ReadAll(resp.Body)
103 require.NoError(t, err)
104 require.Equal(t, test.expectedBody, respBody)
105 }
106 if test.isProxyError {
107 require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeaderField))
108 } else {
109 require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
110 }
111 }
112 cancel()
113 wg.Wait()
114}
115
116type mockNamedTunnelRPCClient struct {
117 shouldFail error
118 registered chan struct{}
119 unregistered chan struct{}
120}
121
122func (mc mockNamedTunnelRPCClient) RegisterConnection(
123 c context.Context,
124 config *NamedTunnelConfig,
125 options *tunnelpogs.ConnectionOptions,
126 connIndex uint8,
127 observer *Observer,
128) error {
129 if mc.shouldFail != nil {
130 return mc.shouldFail
131 }
132 close(mc.registered)
133 return nil
134}
135
136func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
137 close(mc.unregistered)
138}
139
140func (mockNamedTunnelRPCClient) Close() {}
141
142type mockRPCClientFactory struct {
143 shouldFail error
144 registered chan struct{}
145 unregistered chan struct{}
146}
147
148func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient {
149 return mockNamedTunnelRPCClient{
150 shouldFail: mf.shouldFail,
151 registered: mf.registered,
152 unregistered: mf.unregistered,
153 }
154}
155
156type wsRespWriter struct {
157 *httptest.ResponseRecorder
158 readPipe *io.PipeReader
159 writePipe *io.PipeWriter
160}
161
162func newWSRespWriter() *wsRespWriter {
163 readPipe, writePipe := io.Pipe()
164 return &wsRespWriter{
165 httptest.NewRecorder(),
166 readPipe,
167 writePipe,
168 }
169}
170
171func (w *wsRespWriter) RespBody() io.ReadWriter {
172 return nowriter{w.readPipe}
173}
174
175func (w *wsRespWriter) Write(data []byte) (n int, err error) {
176 return w.writePipe.Write(data)
177}
178
179func TestServeWS(t *testing.T) {
180 http2Conn, _ := newTestHTTP2Connection()
181
182 ctx, cancel := context.WithCancel(context.Background())
183 var wg sync.WaitGroup
184 wg.Add(1)
185 go func() {
186 defer wg.Done()
187 http2Conn.Serve(ctx)
188 }()
189
190 respWriter := newWSRespWriter()
191 readPipe, writePipe := io.Pipe()
192
193 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
194 require.NoError(t, err)
195 req.Header.Set(internalUpgradeHeader, websocketUpgrade)
196
197 wg.Add(1)
198 go func() {
199 defer wg.Done()
200 http2Conn.ServeHTTP(respWriter, req)
201 }()
202
203 data := []byte("test websocket")
204 err = wsutil.WriteClientText(writePipe, data)
205 require.NoError(t, err)
206
207 respBody, err := wsutil.ReadServerText(respWriter.RespBody())
208 require.NoError(t, err)
209 require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
210
211 cancel()
212 resp := respWriter.Result()
213 // http2RespWriter should rewrite status 101 to 200
214 require.Equal(t, http.StatusOK, resp.StatusCode)
215 require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
216
217 wg.Wait()
218}
219
220func TestServeControlStream(t *testing.T) {
221 http2Conn, edgeConn := newTestHTTP2Connection()
222
223 rpcClientFactory := mockRPCClientFactory{
224 registered: make(chan struct{}),
225 unregistered: make(chan struct{}),
226 }
227 http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
228
229 ctx, cancel := context.WithCancel(context.Background())
230 var wg sync.WaitGroup
231 wg.Add(1)
232 go func() {
233 defer wg.Done()
234 http2Conn.Serve(ctx)
235 }()
236
237 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
238 require.NoError(t, err)
239 req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
240
241 edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
242 require.NoError(t, err)
243
244 wg.Add(1)
245 go func() {
246 defer wg.Done()
247 edgeHTTP2Conn.RoundTrip(req)
248 }()
249
250 <-rpcClientFactory.registered
251 cancel()
252 <-rpcClientFactory.unregistered
253 assert.False(t, http2Conn.stoppedGracefully)
254
255 wg.Wait()
256}
257
258func TestFailRegistration(t *testing.T) {
259 http2Conn, edgeConn := newTestHTTP2Connection()
260
261 rpcClientFactory := mockRPCClientFactory{
262 shouldFail: errDuplicationConnection,
263 registered: make(chan struct{}),
264 unregistered: make(chan struct{}),
265 }
266 http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
267
268 ctx, cancel := context.WithCancel(context.Background())
269 var wg sync.WaitGroup
270 wg.Add(1)
271 go func() {
272 defer wg.Done()
273 http2Conn.Serve(ctx)
274 }()
275
276 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
277 require.NoError(t, err)
278 req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
279
280 edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
281 require.NoError(t, err)
282 resp, err := edgeHTTP2Conn.RoundTrip(req)
283 require.NoError(t, err)
284 require.Equal(t, http.StatusBadGateway, resp.StatusCode)
285
286 assert.NotNil(t, http2Conn.controlStreamErr)
287 cancel()
288 wg.Wait()
289}
290
291func TestGracefulShutdownHTTP2(t *testing.T) {
292 http2Conn, edgeConn := newTestHTTP2Connection()
293
294 rpcClientFactory := mockRPCClientFactory{
295 registered: make(chan struct{}),
296 unregistered: make(chan struct{}),
297 }
298 events := &eventCollectorSink{}
299 http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
300 http2Conn.observer.RegisterSink(events)
301 shutdownC := make(chan struct{})
302 http2Conn.gracefulShutdownC = shutdownC
303
304 ctx, cancel := context.WithCancel(context.Background())
305 var wg sync.WaitGroup
306 wg.Add(1)
307 go func() {
308 defer wg.Done()
309 http2Conn.Serve(ctx)
310 }()
311
312 req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
313 require.NoError(t, err)
314 req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
315
316 edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
317 require.NoError(t, err)
318
319 wg.Add(1)
320 go func() {
321 defer wg.Done()
322 _, _ = edgeHTTP2Conn.RoundTrip(req)
323 }()
324
325 select {
326 case <-rpcClientFactory.registered:
327 break //ok
328 case <-time.Tick(time.Second):
329 t.Fatal("timeout out waiting for registration")
330 }
331
332 // signal graceful shutdown
333 close(shutdownC)
334
335 select {
336 case <-rpcClientFactory.unregistered:
337 break //ok
338 case <-time.Tick(time.Second):
339 t.Fatal("timeout out waiting for unregistered signal")
340 }
341 assert.True(t, http2Conn.stoppedGracefully)
342
343 cancel()
344 wg.Wait()
345
346 events.assertSawEvent(t, Event{
347 Index: http2Conn.connIndex,
348 EventType: Unregistering,
349 })
350}
351
352func benchmarkServeHTTP(b *testing.B, test testRequest) {
353 http2Conn, edgeConn := newTestHTTP2Connection()
354
355 ctx, cancel := context.WithCancel(context.Background())
356 var wg sync.WaitGroup
357 wg.Add(1)
358 go func() {
359 defer wg.Done()
360 http2Conn.Serve(ctx)
361 }()
362
363 endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
364 req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
365 require.NoError(b, err)
366
367 edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
368 require.NoError(b, err)
369
370 b.ResetTimer()
371 for i := 0; i < b.N; i++ {
372 b.StartTimer()
373 resp, err := edgeHTTP2Conn.RoundTrip(req)
374 b.StopTimer()
375 require.NoError(b, err)
376 require.Equal(b, test.expectedStatus, resp.StatusCode)
377 if test.expectedBody != nil {
378 respBody, err := ioutil.ReadAll(resp.Body)
379 require.NoError(b, err)
380 require.Equal(b, test.expectedBody, respBody)
381 }
382 resp.Body.Close()
383 }
384
385 cancel()
386 wg.Wait()
387}
388
389func BenchmarkServeHTTPSimple(b *testing.B) {
390 test := testRequest{
391 name: "ok",
392 endpoint: "ok",
393 expectedStatus: http.StatusOK,
394 expectedBody: []byte(http.StatusText(http.StatusOK)),
395 }
396
397 benchmarkServeHTTP(b, test)
398}
399
400func BenchmarkServeHTTPLargeFile(b *testing.B) {
401 test := testRequest{
402 name: "large_file",
403 endpoint: "large_file",
404 expectedStatus: http.StatusOK,
405 expectedBody: testLargeResp,
406 }
407
408 benchmarkServeHTTP(b, test)
409}
410