cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.10.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

carrier/websocket.go

199 lines · mode: code

1package carrier
2
3import (
4 "io"
5 "net/http"
6 "net/http/httputil"
7 "net/url"
8
9 "github.com/gorilla/websocket"
10 "github.com/rs/zerolog"
11
12 "github.com/cloudflare/cloudflared/token"
13 cfwebsocket "github.com/cloudflare/cloudflared/websocket"
14)
15
16// Websocket is used to carry data via WS binary frames over the tunnel from client to the origin
17// This implements the functions for glider proxy (sock5) and the carrier interface
18type Websocket struct {
19 log *zerolog.Logger
20 isSocks bool
21}
22
23// NewWSConnection returns a new connection object
24func NewWSConnection(log *zerolog.Logger) Connection {
25 return &Websocket{
26 log: log,
27 }
28}
29
// ServeStream opens a WebSocket client stream to the edge (via
// createWebsocketStream) and then blocks while cfwebsocket.Stream relays the
// raw data between conn and the tunnel. The WebSocket is closed once streaming
// finishes. If the connection cannot be established, the failure is logged with
// the origin URL and returned to the caller.
func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) error {
	edgeConn, dialErr := createWebsocketStream(options, ws.log)
	if dialErr != nil {
		ws.log.Err(dialErr).
			Str(LogFieldOriginURL, options.OriginURL).
			Msg("failed to connect to origin")
		return dialErr
	}
	defer edgeConn.Close()

	cfwebsocket.Stream(edgeConn, conn, ws.log)
	return nil
}
43
// createWebsocketStream opens a WebSocket connection to options.OriginURL to
// stream data over. The request carries options.Headers and, when set, uses
// options.Host as its Host.
//
// If the first dial fails and IsAccessResponse recognises the response, the
// origin is treated as protected by Access. The Access app info is looked up
// and stored in options.AppInfo. The connection is then retried through
// createAccessAuthenticatedStream, which presents the Access flow when no valid
// token is present. Any other dial error is returned unchanged.
func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) {
	req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
	if err != nil {
		return nil, err
	}
	// The request shares options.Headers rather than taking a copy.
	req.Header = options.Headers
	if options.Host != "" {
		req.Host = options.Host
	}

	// The dump is purely diagnostic: a failure to produce it is logged instead
	// of being silently discarded, and never aborts the connection attempt.
	// NOTE(review): this logs every request header at debug level; if
	// options.Headers can carry credentials, consider redacting them.
	if dump, dumpErr := httputil.DumpRequest(req, false); dumpErr != nil {
		log.Debug().Err(dumpErr).Msg("failed to dump Websocket request")
	} else {
		log.Debug().Msgf("Websocket request: %s", string(dump))
	}

	dialer := &websocket.Dialer{
		TLSClientConfig: options.TLSClientConfig,
		Proxy:           http.ProxyFromEnvironment,
	}
	wsConn, resp, err := clientConnect(req, dialer)
	defer closeRespBody(resp)

	if err != nil {
		if !IsAccessResponse(resp) {
			return nil, err
		}

		// Only fetch Access app info once we know the origin is protected by
		// Access. Parse the configured origin URL afresh instead of reusing
		// req.URL, whose scheme may have been rewritten to ws/wss for the dial.
		originReq, reqErr := http.NewRequest(http.MethodGet, options.OriginURL, nil)
		if reqErr != nil {
			return nil, reqErr
		}
		appInfo, infoErr := token.GetAppInfo(originReq.URL)
		if infoErr != nil {
			return nil, infoErr
		}
		options.AppInfo = appInfo

		wsConn, err = createAccessAuthenticatedStream(options, log)
		if err != nil {
			return nil, err
		}
	}

	return &cfwebsocket.GorillaConn{Conn: wsConn}, nil
}
90
91var stripWebsocketHeaders = []string{
92 "Upgrade",
93 "Connection",
94 "Sec-Websocket-Key",
95 "Sec-Websocket-Version",
96 "Sec-Websocket-Extensions",
97}
98
99// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
100// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
101// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
102func websocketHeaders(req *http.Request) http.Header {
103 wsHeaders := make(http.Header)
104 for key, val := range req.Header {
105 wsHeaders[key] = val
106 }
107 // Assume the header keys are in canonical format.
108 for _, header := range stripWebsocketHeaders {
109 wsHeaders.Del(header)
110 }
111 wsHeaders.Set("Host", req.Host) // See TUN-1097
112 return wsHeaders
113}
114
// clientConnect creates a WebSocket client connection for the provided request.
// The caller is responsible for closing the connection. The response body may
// not contain the entire response and does not need to be closed by the
// application.
//
// req is not modified: the dial uses a copy of req.URL whose scheme has been
// mapped to ws/wss by changeRequestScheme. Headers come from websocketHeaders.
// If dialler is nil, a default Dialer that uses http.ProxyFromEnvironment is
// used. On failure the (possibly nil) handshake response is still returned so
// callers can inspect it.
func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
	// Work on a copy so the caller's request keeps its original http(s)
	// scheme instead of being silently rewritten as a side effect.
	dialURL := *req.URL
	dialURL.Scheme = changeRequestScheme(req.URL)
	wsHeaders := websocketHeaders(req)
	if dialler == nil {
		dialler = &websocket.Dialer{
			Proxy: http.ProxyFromEnvironment,
		}
	}
	conn, response, err := dialler.Dial(dialURL.String(), wsHeaders)
	if err != nil {
		return nil, response, err
	}
	return conn, response, nil
}
132
// changeRequestScheme maps an HTTP URL scheme to its WebSocket counterpart,
// which the gorilla websocket library requires (even though the library maps
// it back to http/https internally, ¯\_(ツ)_/¯):
// "https" becomes "wss", while "http" and an empty scheme become "ws". Any
// other scheme is returned unchanged. reqURL itself is not modified.
func changeRequestScheme(reqURL *url.URL) string {
	scheme := reqURL.Scheme
	if scheme == "https" {
		return "wss"
	}
	if scheme == "http" || scheme == "" {
		return "ws"
	}
	return scheme
}
147
148// createAccessAuthenticatedStream will try load a token from storage and make
149// a connection with the token set on the request. If it still get redirect,
150// this probably means the token in storage is invalid (expired/revoked). If that
151// happens it deletes the token and runs the connection again, so the user can
152// login again and generate a new one.
153func createAccessAuthenticatedStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, error) {
154 wsConn, resp, err := createAccessWebSocketStream(options, log)
155 defer closeRespBody(resp)
156 if err == nil {
157 return wsConn, nil
158 }
159
160 if !IsAccessResponse(resp) {
161 return nil, err
162 }
163
164 // Access Token is invalid for some reason. Go through regen flow
165 if err := token.RemoveTokenIfExists(options.AppInfo); err != nil {
166 return nil, err
167 }
168 wsConn, resp, err = createAccessWebSocketStream(options, log)
169 defer closeRespBody(resp)
170 if err != nil {
171 return nil, err
172 }
173
174 return wsConn, nil
175}
176
// createAccessWebSocketStream builds an Access request with BuildAccessRequest
// and dials it as a WebSocket.
//
// The dialer uses options.TLSClientConfig and http.ProxyFromEnvironment, the
// same settings createWebsocketStream uses for the initial attempt. Previously
// a nil dialer was passed here, which silently dropped the caller's TLS
// configuration on the authenticated retry.
//
// The handshake response, if any, is returned even when err is non-nil so the
// caller can inspect it (e.g. with IsAccessResponse). Request and response
// dumps are logged at debug level.
func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, *http.Response, error) {
	req, err := BuildAccessRequest(options, log)
	if err != nil {
		return nil, nil, err
	}

	// The dump is purely diagnostic: report a failure instead of ignoring it,
	// but never abort the connection attempt because of it.
	if dump, dumpErr := httputil.DumpRequest(req, false); dumpErr != nil {
		log.Debug().Err(dumpErr).Msg("failed to dump Access Websocket request")
	} else {
		log.Debug().Msgf("Access Websocket request: %s", string(dump))
	}

	dialer := &websocket.Dialer{
		TLSClientConfig: options.TLSClientConfig,
		Proxy:           http.ProxyFromEnvironment,
	}
	conn, resp, err := clientConnect(req, dialer)

	if resp != nil {
		if r, dumpErr := httputil.DumpResponse(resp, true); dumpErr != nil {
			log.Debug().Msgf("Websocket response error: %v", dumpErr)
		} else {
			log.Debug().Msgf("Websocket response: %q", r)
		}
	}

	return conn, resp, err
}
200