cloudflare/cloudflared

Public

Mirrored from https://github.com/cloudflare/cloudflared — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2020.7.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

carrier/carrier_test.go

157 lines · mode: code

1package carrier
2
3import (
4 "bytes"
5 "io"
6 "net"
7 "net/http"
8 "net/http/httptest"
9 "sync"
10 "testing"
11
12 "github.com/cloudflare/cloudflared/logger"
13 ws "github.com/gorilla/websocket"
14 "github.com/stretchr/testify/assert"
15)
16
// testSecWebsocketKey is the sample Sec-WebSocket-Key nonce given as an
// example in RFC 6455; tests reuse it so handshake headers are deterministic.
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
21
// testStreamer is an in-memory io.ReadWriter used as the local end of a
// carrier stream in tests. Every access to buf is serialized by l.
type testStreamer struct {
	buf *bytes.Buffer
	// l guards buf. It is an exclusive lock (not an RWMutex) because
	// bytes.Buffer.Read advances the buffer's read offset: readers mutate
	// state just like writers, so concurrent Reads under a shared lock race.
	l sync.Mutex
}

// Compile-time check that testStreamer can be passed where an io.ReadWriter
// is expected.
var _ io.ReadWriter = (*testStreamer)(nil)

// newTestStream returns an empty testStreamer ready for use.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read drains up to len(p) bytes from the underlying buffer, with the same
// return semantics as bytes.Buffer.Read (an error is returned when the
// buffer is empty and len(p) > 0).
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the underlying buffer.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
43
44func TestStartClient(t *testing.T) {
45 message := "Good morning Austin! Time for another sunny day in the great state of Texas."
46 logger := logger.NewOutputWriter(logger.NewMockWriteManager())
47 wsConn := NewWSConnection(logger, false)
48 ts := newTestWebSocketServer()
49 defer ts.Close()
50
51 buf := newTestStream()
52 options := &StartOptions{
53 OriginURL: "http://" + ts.Listener.Addr().String(),
54 Headers: nil,
55 }
56 err := StartClient(wsConn, buf, options)
57 assert.NoError(t, err)
58 buf.Write([]byte(message))
59
60 readBuffer := make([]byte, len(message))
61 buf.Read(readBuffer)
62 assert.Equal(t, message, string(readBuffer))
63}
64
65func TestStartServer(t *testing.T) {
66 listener, err := net.Listen("tcp", "localhost:")
67 if err != nil {
68 t.Fatalf("Error starting listener: %v", err)
69 }
70 message := "Good morning Austin! Time for another sunny day in the great state of Texas."
71 logger := logger.NewOutputWriter(logger.NewMockWriteManager())
72 shutdownC := make(chan struct{})
73 wsConn := NewWSConnection(logger, false)
74 ts := newTestWebSocketServer()
75 defer ts.Close()
76 options := &StartOptions{
77 OriginURL: "http://" + ts.Listener.Addr().String(),
78 Headers: nil,
79 }
80
81 go func() {
82 err := Serve(wsConn, listener, shutdownC, options)
83 if err != nil {
84 t.Fatalf("Error running server: %v", err)
85 }
86 }()
87
88 conn, err := net.Dial("tcp", listener.Addr().String())
89 conn.Write([]byte(message))
90
91 readBuffer := make([]byte, len(message))
92 conn.Read(readBuffer)
93 assert.Equal(t, string(readBuffer), message)
94}
95
96func TestIsAccessResponse(t *testing.T) {
97 validLocationHeader := http.Header{}
98 validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah")
99 invalidLocationHeader := http.Header{}
100 invalidLocationHeader.Add("location", "https://google.com")
101 testCases := []struct {
102 Description string
103 In *http.Response
104 ExpectedOut bool
105 }{
106 {"nil response", nil, false},
107 {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
108 {"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
109 {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
110 {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
111 }
112
113 for i, tc := range testCases {
114 if IsAccessResponse(tc.In) != tc.ExpectedOut {
115 t.Fatalf("Failed case %d -- %s", i, tc.Description)
116 }
117 }
118
119}
120
121func newTestWebSocketServer() *httptest.Server {
122 upgrader := ws.Upgrader{
123 ReadBufferSize: 1024,
124 WriteBufferSize: 1024,
125 }
126
127 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
128 conn, _ := upgrader.Upgrade(w, r, nil)
129 defer conn.Close()
130 for {
131 mt, message, err := conn.ReadMessage()
132 if err != nil {
133 break
134 }
135
136 if err := conn.WriteMessage(mt, []byte(message)); err != nil {
137 break
138 }
139 }
140 }))
141}
142
143func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
144 req, err := http.NewRequest("GET", url, stream)
145 if err != nil {
146 t.Fatalf("testRequestHeader error")
147 }
148
149 req.Header.Add("Connection", "Upgrade")
150 req.Header.Add("Upgrade", "WebSocket")
151 req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
152 req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
153 req.Header.Add("Sec-Websocket-Version", "13")
154 req.Header.Add("User-Agent", "curl/7.59.0")
155
156 return req
157}