cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2020.3.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

carrier/carrier_test.go

155 lines · mode: code

1package carrier
2
3import (
4 "bytes"
5 "io"
6 "net"
7 "net/http"
8 "net/http/httptest"
9 "sync"
10 "testing"
11
12 ws "github.com/gorilla/websocket"
13 "github.com/sirupsen/logrus"
14 "github.com/stretchr/testify/assert"
15)
16
// testSecWebsocketKey is the sample Sec-Websocket-Key value used in RFC 6455.
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
// testStreamer is an in-memory io.ReadWriter used as a stream in tests.
// Reads and writes share a single bytes.Buffer guarded by a mutex.
type testStreamer struct {
	buf *bytes.Buffer
	// l is held exclusively by both Read and Write: bytes.Buffer.Read advances
	// the buffer's read offset, so it mutates shared state just like Write and
	// cannot safely run under a shared (read) lock.
	l sync.Mutex
}

// Compile-time check that testStreamer can be handed to code expecting an
// io.ReadWriter.
var _ io.ReadWriter = (*testStreamer)(nil)

// newTestStream returns an empty testStreamer ready for use.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read drains up to len(p) bytes from the buffer. Like bytes.Buffer.Read, it
// returns io.EOF when the buffer is empty and len(p) > 0.
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the buffer.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
43
44func TestStartClient(t *testing.T) {
45 message := "Good morning Austin! Time for another sunny day in the great state of Texas."
46 logger := logrus.New()
47 ts := newTestWebSocketServer()
48 defer ts.Close()
49
50 buf := newTestStream()
51 options := &StartOptions{
52 OriginURL: "http://" + ts.Listener.Addr().String(),
53 Headers: nil,
54 }
55 err := StartClient(logger, buf, options)
56 assert.NoError(t, err)
57 buf.Write([]byte(message))
58
59 readBuffer := make([]byte, len(message))
60 buf.Read(readBuffer)
61 assert.Equal(t, message, string(readBuffer))
62}
63
64func TestStartServer(t *testing.T) {
65 listener, err := net.Listen("tcp", "localhost:")
66 if err != nil {
67 t.Fatalf("Error starting listener: %v", err)
68 }
69 message := "Good morning Austin! Time for another sunny day in the great state of Texas."
70 logger := logrus.New()
71 shutdownC := make(chan struct{})
72 ts := newTestWebSocketServer()
73 defer ts.Close()
74 options := &StartOptions{
75 OriginURL: "http://" + ts.Listener.Addr().String(),
76 Headers: nil,
77 }
78
79 go func() {
80 err := Serve(logger, listener, shutdownC, options)
81 if err != nil {
82 t.Fatalf("Error running server: %v", err)
83 }
84 }()
85
86 conn, err := net.Dial("tcp", listener.Addr().String())
87 conn.Write([]byte(message))
88
89 readBuffer := make([]byte, len(message))
90 conn.Read(readBuffer)
91 assert.Equal(t, string(readBuffer), message)
92}
93
94func TestIsAccessResponse(t *testing.T) {
95 validLocationHeader := http.Header{}
96 validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah")
97 invalidLocationHeader := http.Header{}
98 invalidLocationHeader.Add("location", "https://google.com")
99 testCases := []struct {
100 Description string
101 In *http.Response
102 ExpectedOut bool
103 }{
104 {"nil response", nil, false},
105 {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
106 {"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
107 {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
108 {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
109 }
110
111 for i, tc := range testCases {
112 if IsAccessResponse(tc.In) != tc.ExpectedOut {
113 t.Fatalf("Failed case %d -- %s", i, tc.Description)
114 }
115 }
116
117}
118
119func newTestWebSocketServer() *httptest.Server {
120 upgrader := ws.Upgrader{
121 ReadBufferSize: 1024,
122 WriteBufferSize: 1024,
123 }
124
125 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126 conn, _ := upgrader.Upgrade(w, r, nil)
127 defer conn.Close()
128 for {
129 mt, message, err := conn.ReadMessage()
130 if err != nil {
131 break
132 }
133
134 if err := conn.WriteMessage(mt, []byte(message)); err != nil {
135 break
136 }
137 }
138 }))
139}
140
141func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
142 req, err := http.NewRequest("GET", url, stream)
143 if err != nil {
144 t.Fatalf("testRequestHeader error")
145 }
146
147 req.Header.Add("Connection", "Upgrade")
148 req.Header.Add("Upgrade", "WebSocket")
149 req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
150 req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
151 req.Header.Add("Sec-Websocket-Version", "13")
152 req.Header.Add("User-Agent", "curl/7.59.0")
153
154 return req
155}
156