cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.12.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

carrier/websocket_test.go

123 lines · mode: code

1package carrier
2
3import (
4 "context"
5 "crypto/tls"
6 "crypto/x509"
7 "fmt"
8 "math/rand"
9 "testing"
10 "time"
11
12 gws "github.com/gorilla/websocket"
13 "github.com/rs/zerolog"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 "golang.org/x/net/websocket"
17
18 "github.com/cloudflare/cloudflared/hello"
19 "github.com/cloudflare/cloudflared/tlsconfig"
20 cfwebsocket "github.com/cloudflare/cloudflared/websocket"
21)
22
23func websocketClientTLSConfig(t *testing.T) *tls.Config {
24 certPool := x509.NewCertPool()
25 helloCert, err := tlsconfig.GetHelloCertificateX509()
26 assert.NoError(t, err)
27 certPool.AddCert(helloCert)
28 assert.NotNil(t, certPool)
29 return &tls.Config{RootCAs: certPool}
30}
31
32func TestWebsocketHeaders(t *testing.T) {
33 req := testRequest(t, "http://example.com", nil)
34 wsHeaders := websocketHeaders(req)
35 for _, header := range stripWebsocketHeaders {
36 assert.Empty(t, wsHeaders[header])
37 }
38 assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
39}
40
41func TestServe(t *testing.T) {
42 log := zerolog.Nop()
43 shutdownC := make(chan struct{})
44 errC := make(chan error)
45 listener, err := hello.CreateTLSListener("localhost:1111")
46 assert.NoError(t, err)
47 defer listener.Close()
48
49 go func() {
50 errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
51 }()
52
53 req := testRequest(t, "https://localhost:1111/ws", nil)
54
55 tlsConfig := websocketClientTLSConfig(t)
56 assert.NotNil(t, tlsConfig)
57 d := gws.Dialer{TLSClientConfig: tlsConfig}
58 conn, resp, err := clientConnect(req, &d)
59 assert.NoError(t, err)
60 assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
61
62 for i := 0; i < 1000; i++ {
63 messageSize := rand.Int()%2048 + 1
64 clientMessage := make([]byte, messageSize)
65 // rand.Read always returns len(clientMessage) and a nil error
66 rand.Read(clientMessage)
67 err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
68 assert.NoError(t, err)
69
70 messageType, message, err := conn.ReadMessage()
71 assert.NoError(t, err)
72 assert.Equal(t, websocket.BinaryFrame, messageType)
73 assert.Equal(t, clientMessage, message)
74 }
75
76 _ = conn.Close()
77 close(shutdownC)
78 <-errC
79}
80
81func TestWebsocketWrapper(t *testing.T) {
82 listener, err := hello.CreateTLSListener("localhost:0")
83 require.NoError(t, err)
84
85 serverErrorChan := make(chan error)
86 helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
87 defer func() { <-serverErrorChan }()
88 defer cancelHelloSvr()
89 go func() {
90 log := zerolog.Nop()
91 serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
92 }()
93
94 tlsConfig := websocketClientTLSConfig(t)
95 d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
96 testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
97 req := testRequest(t, testAddr, nil)
98 conn, resp, err := clientConnect(req, &d)
99 require.NoError(t, err)
100 assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
101
102 // Websocket now connected to test server so lets check our wrapper
103 wrapper := cfwebsocket.GorillaConn{Conn: conn}
104 buf := make([]byte, 100)
105 wrapper.Write([]byte("abc"))
106 n, err := wrapper.Read(buf)
107 require.NoError(t, err)
108 require.Equal(t, n, 3)
109 require.Equal(t, "abc", string(buf[:n]))
110
111 // Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
112 wrapper.Write([]byte("abc"))
113 buf = buf[:1]
114 n, err = wrapper.Read(buf)
115 require.NoError(t, err)
116 require.Equal(t, n, 1)
117 require.Equal(t, "a", string(buf[:n]))
118 buf = buf[:cap(buf)]
119 n, err = wrapper.Read(buf)
120 require.NoError(t, err)
121 require.Equal(t, n, 2)
122 require.Equal(t, "bc", string(buf[:n]))
123}
124