cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2020.7.4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

dbconnect/proxy_test.go

238 lines · mode: blame

759cd019Ashcon Partovi6 years ago1package dbconnect
2
3import (
4"context"
5"fmt"
6"io"
7"io/ioutil"
8"net"
9"net/http"
10"net/http/httptest"
11"strings"
12"testing"
13
14"github.com/gorilla/mux"
15
16"github.com/stretchr/testify/assert"
17)
18
19func TestNewInsecureProxy(t *testing.T) {
20origins := []string{
21"",
22":/",
23"http://localhost",
24"tcp://localhost:9000?debug=true",
25"mongodb://127.0.0.1",
26}
27
28for _, origin := range origins {
29proxy, err := NewInsecureProxy(context.Background(), origin)
30
31assert.Error(t, err)
32assert.Empty(t, proxy)
33}
34}
35
36func TestProxyIsAllowed(t *testing.T) {
37proxy := helperNewProxy(t)
38req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
39assert.True(t, proxy.IsAllowed(req))
40
41proxy = helperNewProxy(t, true)
42req.Header.Set("Cf-access-jwt-assertion", "xxx")
43assert.False(t, proxy.IsAllowed(req))
44}
45
46func TestProxyStart(t *testing.T) {
47proxy := helperNewProxy(t)
48ctx := context.Background()
49listenerC := make(chan net.Listener)
50
51err := proxy.Start(ctx, "1.1.1.1:", listenerC)
52assert.Error(t, err)
53
54err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
55assert.Error(t, err)
56
57ctx, cancel := context.WithTimeout(ctx, 0)
58defer cancel()
59
60err = proxy.Start(ctx, "127.0.0.1:", listenerC)
61assert.IsType(t, http.ErrServerClosed, err)
62}
63
64func TestProxyHTTPRouter(t *testing.T) {
65proxy := helperNewProxy(t)
66router := proxy.httpRouter()
67
68tests := []struct {
69path string
70method string
71valid bool
72}{
73{"", "GET", false},
74{"/", "GET", false},
75{"/ping", "GET", true},
76{"/ping", "HEAD", true},
77{"/ping", "POST", false},
78{"/submit", "POST", true},
79{"/submit", "GET", false},
80{"/submit/extra", "POST", false},
81}
82
83for _, test := range tests {
84match := &mux.RouteMatch{}
85ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
86
87assert.True(t, ok == test.valid, test.path)
88}
89}
90
91func TestProxyHTTPPing(t *testing.T) {
92proxy := helperNewProxy(t)
93
94server := httptest.NewServer(proxy.httpPing())
95defer server.Close()
96client := server.Client()
97
98res, err := client.Get(server.URL)
99assert.NoError(t, err)
100assert.Equal(t, http.StatusOK, res.StatusCode)
101assert.Equal(t, int64(2), res.ContentLength)
102
103res, err = client.Head(server.URL)
104assert.NoError(t, err)
105assert.Equal(t, http.StatusOK, res.StatusCode)
106assert.Equal(t, int64(-1), res.ContentLength)
107}
108
109func TestProxyHTTPSubmit(t *testing.T) {
110proxy := helperNewProxy(t)
111
112server := httptest.NewServer(proxy.httpSubmit())
113defer server.Close()
114client := server.Client()
115
116tests := []struct {
117input string
118status int
119output string
120}{
121{"", http.StatusBadRequest, "request body cannot be empty"},
122{"{}", http.StatusBadRequest, "cannot provide an empty statement"},
123{"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
124{"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
125{"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
126}
127
128for _, test := range tests {
129res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
130
131assert.NoError(t, err)
132assert.Equal(t, test.status, res.StatusCode)
133if res.StatusCode > http.StatusOK {
134assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
135} else {
136assert.Equal(t, "application/json", res.Header.Get("Content-type"))
137}
138
139data, err := ioutil.ReadAll(res.Body)
140defer res.Body.Close()
141str := string(data)
142
143assert.NoError(t, err)
144assert.Equal(t, test.output, str)
145}
146}
147
148func TestProxyHTTPSubmitForbidden(t *testing.T) {
149proxy := helperNewProxy(t, true)
150
151server := httptest.NewServer(proxy.httpSubmit())
152defer server.Close()
153client := server.Client()
154
155res, err := client.Get(server.URL)
156
157assert.NoError(t, err)
158assert.Equal(t, http.StatusForbidden, res.StatusCode)
159assert.Zero(t, res.ContentLength)
160}
161
162func TestProxyHTTPRespond(t *testing.T) {
163proxy := helperNewProxy(t)
164
165server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
166proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
167}))
168defer server.Close()
169client := server.Client()
170
171res, err := client.Get(server.URL)
172assert.NoError(t, err)
173assert.Equal(t, http.StatusAccepted, res.StatusCode)
174assert.Equal(t, int64(5), res.ContentLength)
175
176data, err := ioutil.ReadAll(res.Body)
177defer res.Body.Close()
178assert.Equal(t, []byte("Hello"), data)
179}
180
181func TestProxyHTTPRespondForbidden(t *testing.T) {
182proxy := helperNewProxy(t, true)
183
184server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
185proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
186}))
187defer server.Close()
188client := server.Client()
189
190res, err := client.Get(server.URL)
191
192assert.NoError(t, err)
193assert.Equal(t, http.StatusAccepted, res.StatusCode)
194assert.Equal(t, int64(0), res.ContentLength)
195}
196
197func TestHTTPError(t *testing.T) {
198_, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
199assert.Error(t, errTimeout)
200
201tests := []struct {
202input error
203status int
204output error
205}{
206{nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
207{io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
208{context.DeadlineExceeded, http.StatusRequestTimeout, nil},
209{context.Canceled, 444, nil},
210{errTimeout, http.StatusRequestTimeout, nil},
211{fmt.Errorf(""), http.StatusInternalServerError, nil},
212}
213
214for _, test := range tests {
215status, err := httpError(http.StatusInternalServerError, test.input)
216
217assert.Error(t, err)
218assert.Equal(t, test.status, status)
219if test.output == nil {
220test.output = test.input
221}
222assert.Equal(t, test.output, err)
223}
224}
225
226func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
227t.Helper()
228
229proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
230assert.NoError(t, err)
231assert.NotNil(t, proxy)
232
233if len(secure) == 0 {
234proxy.accessValidator = nil // Mark as insecure
235}
236
237return proxy
238}