cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
carrier/carrier_test.go
155 lines · mode: code
| 1 | package carrier |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "io" |
| 6 | "net" |
| 7 | "net/http" |
| 8 | "net/http/httptest" |
| 9 | "sync" |
| 10 | "testing" |
| 11 | |
| 12 | ws "github.com/gorilla/websocket" |
| 13 | "github.com/sirupsen/logrus" |
| 14 | "github.com/stretchr/testify/assert" |
| 15 | ) |
| 16 | |
// testSecWebsocketKey is the example Sec-Websocket-Key value given in RFC 6455.
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
| 21 | |
// testStreamer is an in-memory io.ReadWriter backed by a bytes.Buffer whose
// Read and Write calls are serialized by a mutex.
type testStreamer struct {
	buf *bytes.Buffer
	// l guards buf. It is a plain Mutex rather than an RWMutex because
	// bytes.Buffer.Read advances the buffer's read offset, so a "read" is a
	// mutation and must never run concurrently with another Read or a Write.
	l sync.Mutex
}

// newTestStream returns a testStreamer with an empty buffer.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read drains up to len(p) bytes from the underlying buffer into p and
// returns whatever bytes.Buffer.Read reports.
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the underlying buffer and returns whatever
// bytes.Buffer.Write reports.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
| 43 | |
| 44 | func TestStartClient(t *testing.T) { |
| 45 | message := "Good morning Austin! Time for another sunny day in the great state of Texas." |
| 46 | logger := logrus.New() |
| 47 | ts := newTestWebSocketServer() |
| 48 | defer ts.Close() |
| 49 | |
| 50 | buf := newTestStream() |
| 51 | options := &StartOptions{ |
| 52 | OriginURL: "http://" + ts.Listener.Addr().String(), |
| 53 | Headers: nil, |
| 54 | } |
| 55 | err := StartClient(logger, buf, options) |
| 56 | assert.NoError(t, err) |
| 57 | buf.Write([]byte(message)) |
| 58 | |
| 59 | readBuffer := make([]byte, len(message)) |
| 60 | buf.Read(readBuffer) |
| 61 | assert.Equal(t, message, string(readBuffer)) |
| 62 | } |
| 63 | |
| 64 | func TestStartServer(t *testing.T) { |
| 65 | listener, err := net.Listen("tcp", "localhost:") |
| 66 | if err != nil { |
| 67 | t.Fatalf("Error starting listener: %v", err) |
| 68 | } |
| 69 | message := "Good morning Austin! Time for another sunny day in the great state of Texas." |
| 70 | logger := logrus.New() |
| 71 | shutdownC := make(chan struct{}) |
| 72 | ts := newTestWebSocketServer() |
| 73 | defer ts.Close() |
| 74 | options := &StartOptions{ |
| 75 | OriginURL: "http://" + ts.Listener.Addr().String(), |
| 76 | Headers: nil, |
| 77 | } |
| 78 | |
| 79 | go func() { |
| 80 | err := Serve(logger, listener, shutdownC, options) |
| 81 | if err != nil { |
| 82 | t.Fatalf("Error running server: %v", err) |
| 83 | } |
| 84 | }() |
| 85 | |
| 86 | conn, err := net.Dial("tcp", listener.Addr().String()) |
| 87 | conn.Write([]byte(message)) |
| 88 | |
| 89 | readBuffer := make([]byte, len(message)) |
| 90 | conn.Read(readBuffer) |
| 91 | assert.Equal(t, string(readBuffer), message) |
| 92 | } |
| 93 | |
| 94 | func TestIsAccessResponse(t *testing.T) { |
| 95 | validLocationHeader := http.Header{} |
| 96 | validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah") |
| 97 | invalidLocationHeader := http.Header{} |
| 98 | invalidLocationHeader.Add("location", "https://google.com") |
| 99 | testCases := []struct { |
| 100 | Description string |
| 101 | In *http.Response |
| 102 | ExpectedOut bool |
| 103 | }{ |
| 104 | {"nil response", nil, false}, |
| 105 | {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false}, |
| 106 | {"200 ok", &http.Response{StatusCode: http.StatusOK}, false}, |
| 107 | {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true}, |
| 108 | {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false}, |
| 109 | } |
| 110 | |
| 111 | for i, tc := range testCases { |
| 112 | if IsAccessResponse(tc.In) != tc.ExpectedOut { |
| 113 | t.Fatalf("Failed case %d -- %s", i, tc.Description) |
| 114 | } |
| 115 | } |
| 116 | |
| 117 | } |
| 118 | |
| 119 | func newTestWebSocketServer() *httptest.Server { |
| 120 | upgrader := ws.Upgrader{ |
| 121 | ReadBufferSize: 1024, |
| 122 | WriteBufferSize: 1024, |
| 123 | } |
| 124 | |
| 125 | return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 126 | conn, _ := upgrader.Upgrade(w, r, nil) |
| 127 | defer conn.Close() |
| 128 | for { |
| 129 | mt, message, err := conn.ReadMessage() |
| 130 | if err != nil { |
| 131 | break |
| 132 | } |
| 133 | |
| 134 | if err := conn.WriteMessage(mt, []byte(message)); err != nil { |
| 135 | break |
| 136 | } |
| 137 | } |
| 138 | })) |
| 139 | } |
| 140 | |
| 141 | func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { |
| 142 | req, err := http.NewRequest("GET", url, stream) |
| 143 | if err != nil { |
| 144 | t.Fatalf("testRequestHeader error") |
| 145 | } |
| 146 | |
| 147 | req.Header.Add("Connection", "Upgrade") |
| 148 | req.Header.Add("Upgrade", "WebSocket") |
| 149 | req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) |
| 150 | req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") |
| 151 | req.Header.Add("Sec-Websocket-Version", "13") |
| 152 | req.Header.Add("User-Agent", "curl/7.59.0") |
| 153 | |
| 154 | return req |
| 155 | } |
| 156 | |