cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
carrier/carrier_test.go
157 lines · mode: code
| 1 | package carrier |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "io" |
| 6 | "net" |
| 7 | "net/http" |
| 8 | "net/http/httptest" |
| 9 | "sync" |
| 10 | "testing" |
| 11 | |
| 12 | "github.com/cloudflare/cloudflared/logger" |
| 13 | ws "github.com/gorilla/websocket" |
| 14 | "github.com/stretchr/testify/assert" |
| 15 | ) |
| 16 | |
// testSecWebsocketKey is the sample Sec-WebSocket-Key value from the
// opening-handshake example in RFC 6455 (the base64 encoding of
// "the sample nonce").
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
| 21 | |
// testStreamer is an in-memory io.ReadWriter backed by a bytes.Buffer and
// guarded by a mutex, so tests can read and write it from several goroutines.
type testStreamer struct {
	buf *bytes.Buffer
	// l serializes every access to buf. Reads need the exclusive lock too:
	// bytes.Buffer.Read advances the buffer's read offset (and resets it when
	// empty), so a read mutates state just like a write does.
	l sync.Mutex
}

// newTestStream returns an empty testStreamer ready for use.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read drains up to len(p) buffered bytes into p. It follows bytes.Buffer
// semantics, returning io.EOF when nothing is buffered and len(p) > 0.
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the buffer.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
| 43 | |
| 44 | func TestStartClient(t *testing.T) { |
| 45 | message := "Good morning Austin! Time for another sunny day in the great state of Texas." |
| 46 | logger := logger.NewOutputWriter(logger.NewMockWriteManager()) |
| 47 | wsConn := NewWSConnection(logger, false) |
| 48 | ts := newTestWebSocketServer() |
| 49 | defer ts.Close() |
| 50 | |
| 51 | buf := newTestStream() |
| 52 | options := &StartOptions{ |
| 53 | OriginURL: "http://" + ts.Listener.Addr().String(), |
| 54 | Headers: nil, |
| 55 | } |
| 56 | err := StartClient(wsConn, buf, options) |
| 57 | assert.NoError(t, err) |
| 58 | buf.Write([]byte(message)) |
| 59 | |
| 60 | readBuffer := make([]byte, len(message)) |
| 61 | buf.Read(readBuffer) |
| 62 | assert.Equal(t, message, string(readBuffer)) |
| 63 | } |
| 64 | |
| 65 | func TestStartServer(t *testing.T) { |
| 66 | listener, err := net.Listen("tcp", "localhost:") |
| 67 | if err != nil { |
| 68 | t.Fatalf("Error starting listener: %v", err) |
| 69 | } |
| 70 | message := "Good morning Austin! Time for another sunny day in the great state of Texas." |
| 71 | logger := logger.NewOutputWriter(logger.NewMockWriteManager()) |
| 72 | shutdownC := make(chan struct{}) |
| 73 | wsConn := NewWSConnection(logger, false) |
| 74 | ts := newTestWebSocketServer() |
| 75 | defer ts.Close() |
| 76 | options := &StartOptions{ |
| 77 | OriginURL: "http://" + ts.Listener.Addr().String(), |
| 78 | Headers: nil, |
| 79 | } |
| 80 | |
| 81 | go func() { |
| 82 | err := Serve(wsConn, listener, shutdownC, options) |
| 83 | if err != nil { |
| 84 | t.Fatalf("Error running server: %v", err) |
| 85 | } |
| 86 | }() |
| 87 | |
| 88 | conn, err := net.Dial("tcp", listener.Addr().String()) |
| 89 | conn.Write([]byte(message)) |
| 90 | |
| 91 | readBuffer := make([]byte, len(message)) |
| 92 | conn.Read(readBuffer) |
| 93 | assert.Equal(t, string(readBuffer), message) |
| 94 | } |
| 95 | |
| 96 | func TestIsAccessResponse(t *testing.T) { |
| 97 | validLocationHeader := http.Header{} |
| 98 | validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah") |
| 99 | invalidLocationHeader := http.Header{} |
| 100 | invalidLocationHeader.Add("location", "https://google.com") |
| 101 | testCases := []struct { |
| 102 | Description string |
| 103 | In *http.Response |
| 104 | ExpectedOut bool |
| 105 | }{ |
| 106 | {"nil response", nil, false}, |
| 107 | {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false}, |
| 108 | {"200 ok", &http.Response{StatusCode: http.StatusOK}, false}, |
| 109 | {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true}, |
| 110 | {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false}, |
| 111 | } |
| 112 | |
| 113 | for i, tc := range testCases { |
| 114 | if IsAccessResponse(tc.In) != tc.ExpectedOut { |
| 115 | t.Fatalf("Failed case %d -- %s", i, tc.Description) |
| 116 | } |
| 117 | } |
| 118 | |
| 119 | } |
| 120 | |
| 121 | func newTestWebSocketServer() *httptest.Server { |
| 122 | upgrader := ws.Upgrader{ |
| 123 | ReadBufferSize: 1024, |
| 124 | WriteBufferSize: 1024, |
| 125 | } |
| 126 | |
| 127 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 128 | conn, _ := upgrader.Upgrade(w, r, nil) |
| 129 | defer conn.Close() |
| 130 | for { |
| 131 | mt, message, err := conn.ReadMessage() |
| 132 | if err != nil { |
| 133 | break |
| 134 | } |
| 135 | |
| 136 | if err := conn.WriteMessage(mt, []byte(message)); err != nil { |
| 137 | break |
| 138 | } |
| 139 | } |
| 140 | })) |
| 141 | } |
| 142 | |
| 143 | func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { |
| 144 | req, err := http.NewRequest("GET", url, stream) |
| 145 | if err != nil { |
| 146 | t.Fatalf("testRequestHeader error") |
| 147 | } |
| 148 | |
| 149 | req.Header.Add("Connection", "Upgrade") |
| 150 | req.Header.Add("Upgrade", "WebSocket") |
| 151 | req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) |
| 152 | req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") |
| 153 | req.Header.Add("Sec-Websocket-Version", "13") |
| 154 | req.Header.Add("User-Agent", "curl/7.59.0") |
| 155 | |
| 156 | return req |
| 157 | } |