cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
carrier/websocket.go
206 lines · mode: code
| 1 | package carrier |
| 2 | |
| 3 | import ( |
| 4 | "io" |
| 5 | "net/http" |
| 6 | "net/http/httputil" |
| 7 | "net/url" |
| 8 | |
| 9 | "github.com/gorilla/websocket" |
| 10 | "github.com/rs/zerolog" |
| 11 | |
| 12 | "github.com/cloudflare/cloudflared/stream" |
| 13 | "github.com/cloudflare/cloudflared/token" |
| 14 | cfwebsocket "github.com/cloudflare/cloudflared/websocket" |
| 15 | ) |
| 16 | |
// Websocket carries data in WebSocket binary frames over the tunnel, from the
// client to the origin. It implements the functions used by the glider proxy
// (SOCKS5) and the carrier Connection interface (NewWSConnection returns it as
// a Connection).
type Websocket struct {
	// log is used to report connection errors, e.g. origin dial failures in
	// ServeStream.
	log *zerolog.Logger
	// isSocks is not read anywhere in this file and NewWSConnection leaves it
	// false. NOTE(review): presumably marks SOCKS-proxy use — confirm against
	// the rest of the package.
	isSocks bool
}
| 23 | |
| 24 | // NewWSConnection returns a new connection object |
| 25 | func NewWSConnection(log *zerolog.Logger) Connection { |
| 26 | return &Websocket{ |
| 27 | log: log, |
| 28 | } |
| 29 | } |
| 30 | |
| 31 | // ServeStream will create a Websocket client stream connection to the edge |
| 32 | // it blocks and writes the raw data from conn over the tunnel |
| 33 | func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) error { |
| 34 | wsConn, err := createWebsocketStream(options, ws.log) |
| 35 | if err != nil { |
| 36 | ws.log.Err(err).Str(LogFieldOriginURL, options.OriginURL).Msg("failed to connect to origin") |
| 37 | return err |
| 38 | } |
| 39 | defer wsConn.Close() |
| 40 | |
| 41 | stream.Pipe(wsConn, conn, ws.log) |
| 42 | return nil |
| 43 | } |
| 44 | |
| 45 | // createWebsocketStream will create a WebSocket connection to stream data over |
| 46 | // It also handles redirects from Access and will present that flow if |
| 47 | // the token is not present on the request |
| 48 | func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) { |
| 49 | req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) |
| 50 | if err != nil { |
| 51 | return nil, err |
| 52 | } |
| 53 | req.Header = options.Headers |
| 54 | if options.Host != "" { |
| 55 | req.Host = options.Host |
| 56 | } |
| 57 | |
| 58 | dump, err := httputil.DumpRequest(req, false) |
| 59 | if err != nil { |
| 60 | return nil, err |
| 61 | } |
| 62 | log.Debug().Msgf("Websocket request: %s", string(dump)) |
| 63 | |
| 64 | dialer := &websocket.Dialer{ |
| 65 | TLSClientConfig: options.TLSClientConfig, |
| 66 | Proxy: http.ProxyFromEnvironment, |
| 67 | } |
| 68 | wsConn, resp, err := clientConnect(req, dialer) |
| 69 | defer closeRespBody(resp) |
| 70 | |
| 71 | if err != nil && IsAccessResponse(resp) { |
| 72 | // Only get Access app info if we know the origin is protected by Access |
| 73 | originReq, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) |
| 74 | if err != nil { |
| 75 | return nil, err |
| 76 | } |
| 77 | |
| 78 | appInfo, err := token.GetAppInfo(originReq.URL) |
| 79 | if err != nil { |
| 80 | return nil, err |
| 81 | } |
| 82 | options.AppInfo = appInfo |
| 83 | |
| 84 | wsConn, err = createAccessAuthenticatedStream(options, log) |
| 85 | if err != nil { |
| 86 | return nil, err |
| 87 | } |
| 88 | } else if err != nil { |
| 89 | return nil, err |
| 90 | } |
| 91 | |
| 92 | return &cfwebsocket.GorillaConn{Conn: wsConn}, nil |
| 93 | } |
| 94 | |
// stripWebsocketHeaders names the handshake headers that websocketHeaders
// removes before dialing, because the gorilla websocket library sets its own
// values for them. Entries are written in canonical MIME header form.
var stripWebsocketHeaders = []string{"Upgrade", "Connection", "Sec-Websocket-Key", "Sec-Websocket-Version", "Sec-Websocket-Extensions"}
| 102 | |
| 103 | // the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, |
| 104 | // Sec-WebSocket-Version and Sec-Websocket-Extensions headers. |
| 105 | // https://github.com/gorilla/websocket/blob/master/client.go#L189-L194. |
| 106 | func websocketHeaders(req *http.Request) http.Header { |
| 107 | wsHeaders := make(http.Header) |
| 108 | for key, val := range req.Header { |
| 109 | wsHeaders[key] = val |
| 110 | } |
| 111 | // Assume the header keys are in canonical format. |
| 112 | for _, header := range stripWebsocketHeaders { |
| 113 | wsHeaders.Del(header) |
| 114 | } |
| 115 | wsHeaders.Set("Host", req.Host) // See TUN-1097 |
| 116 | return wsHeaders |
| 117 | } |
| 118 | |
| 119 | // clientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing |
| 120 | // the connection. The response body may not contain the entire response and does |
| 121 | // not need to be closed by the application. |
| 122 | func clientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) { |
| 123 | req.URL.Scheme = changeRequestScheme(req.URL) |
| 124 | wsHeaders := websocketHeaders(req) |
| 125 | if dialler == nil { |
| 126 | dialler = &websocket.Dialer{ |
| 127 | Proxy: http.ProxyFromEnvironment, |
| 128 | } |
| 129 | } |
| 130 | conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) |
| 131 | if err != nil { |
| 132 | return nil, response, err |
| 133 | } |
| 134 | return conn, response, nil |
| 135 | } |
| 136 | |
// changeRequestScheme returns the WebSocket equivalent of reqURL's scheme,
// which the gorilla websocket library requires (even though it converts the
// scheme back to http/https itself). "https" maps to "wss", while "http" and
// an empty scheme map to "ws". Any other scheme is returned unchanged.
func changeRequestScheme(reqURL *url.URL) string {
	scheme := reqURL.Scheme
	if scheme == "https" {
		return "wss"
	}
	if scheme == "http" || scheme == "" {
		return "ws"
	}
	return scheme
}
| 151 | |
| 152 | // createAccessAuthenticatedStream will try load a token from storage and make |
| 153 | // a connection with the token set on the request. If it still get redirect, |
| 154 | // this probably means the token in storage is invalid (expired/revoked). If that |
| 155 | // happens it deletes the token and runs the connection again, so the user can |
| 156 | // login again and generate a new one. |
| 157 | func createAccessAuthenticatedStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, error) { |
| 158 | wsConn, resp, err := createAccessWebSocketStream(options, log) |
| 159 | defer closeRespBody(resp) |
| 160 | if err == nil { |
| 161 | return wsConn, nil |
| 162 | } |
| 163 | |
| 164 | if !IsAccessResponse(resp) { |
| 165 | return nil, err |
| 166 | } |
| 167 | |
| 168 | // Access Token is invalid for some reason. Go through regen flow |
| 169 | if err := token.RemoveTokenIfExists(options.AppInfo); err != nil { |
| 170 | return nil, err |
| 171 | } |
| 172 | wsConn, resp, err = createAccessWebSocketStream(options, log) |
| 173 | defer closeRespBody(resp) |
| 174 | if err != nil { |
| 175 | return nil, err |
| 176 | } |
| 177 | |
| 178 | return wsConn, nil |
| 179 | } |
| 180 | |
| 181 | // createAccessWebSocketStream builds an Access request and makes a connection |
| 182 | func createAccessWebSocketStream(options *StartOptions, log *zerolog.Logger) (*websocket.Conn, *http.Response, error) { |
| 183 | req, err := BuildAccessRequest(options, log) |
| 184 | if err != nil { |
| 185 | return nil, nil, err |
| 186 | } |
| 187 | |
| 188 | dump, err := httputil.DumpRequest(req, false) |
| 189 | if err != nil { |
| 190 | return nil, nil, err |
| 191 | } |
| 192 | log.Debug().Msgf("Access Websocket request: %s", string(dump)) |
| 193 | |
| 194 | conn, resp, err := clientConnect(req, nil) |
| 195 | |
| 196 | if resp != nil { |
| 197 | r, err := httputil.DumpResponse(resp, true) |
| 198 | if r != nil { |
| 199 | log.Debug().Msgf("Websocket response: %q", r) |
| 200 | } else if err != nil { |
| 201 | log.Debug().Msgf("Websocket response error: %v", err) |
| 202 | } |
| 203 | } |
| 204 | |
| 205 | return conn, resp, err |
| 206 | } |
| 207 | |