cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflaredAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2018.10.2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/tunnel.go

632lines · modecode

1package origin
2
3import (
4 "bufio"
5 "crypto/tls"
6 "fmt"
7 "io"
8 "net"
9 "net/http"
10 "net/url"
11 "strconv"
12 "strings"
13 "time"
14
15 "golang.org/x/net/context"
16 "golang.org/x/sync/errgroup"
17
18 "github.com/cloudflare/cloudflared/h2mux"
19 "github.com/cloudflare/cloudflared/tunnelrpc"
20 tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
21 "github.com/cloudflare/cloudflared/validation"
22 "github.com/cloudflare/cloudflared/websocket"
23
24 raven "github.com/getsentry/raven-go"
25 "github.com/pkg/errors"
26 _ "github.com/prometheus/client_golang/prometheus"
27 log "github.com/sirupsen/logrus"
28 rpc "zombiezen.com/go/capnproto2/rpc"
29)
30
// Unexported tuning values used while talking to the edge and the origin.
const (
	// dialTimeout bounds the TCP dial to the edge and, separately, the TLS
	// handshake that follows it (see NewTunnelHandler).
	dialTimeout = 15 * time.Second
	// lbProbeUserAgentPrefix is the User-Agent prefix that isLBProbeRequest
	// uses to recognise load balancer health checks.
	lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
)

// Exported protocol constants.
const (
	// TagHeaderNamePrefix is prepended to a tag name to build the request
	// header that carries the tag value (see AppendTagHeaders).
	TagHeaderNamePrefix = "Cf-Warp-Tag-"
	// DuplicateConnectionError is the registration error string that
	// RegisterTunnel maps to dupConnRegisterTunnelError.
	DuplicateConnectionError = "EDUPCONN"
)
37
38type TunnelConfig struct {
39 EdgeAddrs []string
40 OriginUrl string
41 Hostname string
42 OriginCert []byte
43 TlsConfig *tls.Config
44 ClientTlsConfig *tls.Config
45 Retries uint
46 HeartbeatInterval time.Duration
47 MaxHeartbeats uint64
48 ClientID string
49 BuildInfo *BuildInfo
50 ReportedVersion string
51 LBPool string
52 Tags []tunnelpogs.Tag
53 HAConnections int
54 HTTPTransport http.RoundTripper
55 Metrics *TunnelMetrics
56 MetricsUpdateFreq time.Duration
57 ProtocolLogger *log.Logger
58 Logger *log.Logger
59 IsAutoupdated bool
60 GracePeriod time.Duration
61 RunFromTerminal bool
62 NoChunkedEncoding bool
63 WSGI bool
64 CompressionQuality uint64
65}
66
// dialError wraps a failure that happened while connecting to the edge (TCP
// dial or TLS handshake). ServeTunnel treats it as recoverable.
type dialError struct {
	cause error
}

// Error returns the message of the wrapped cause. A dialError built without a
// cause yields a fixed message instead of panicking on the nil interface.
func (e dialError) Error() string {
	if e.cause == nil {
		return "dial error"
	}
	return e.cause.Error()
}
74
// dupConnRegisterTunnelError signals that the edge server rejected the
// registration because this client is already connected to it.
type dupConnRegisterTunnelError struct{}

// Error implements the error interface with a fixed message.
func (dupConnRegisterTunnelError) Error() string {
	const msg = "already connected to this server"
	return msg
}
80
// muxerShutdownError is returned when the muxer stops serving without an
// error, so that the other goroutines of the tunnel are told to stop.
type muxerShutdownError struct{}

// Error implements the error interface with a fixed message.
func (muxerShutdownError) Error() string {
	const msg = "muxer shutdown"
	return msg
}
86
// serverRegisterTunnelError is a RegisterTunnel error reported by the server.
// permanent tells the caller that retrying will not help.
type serverRegisterTunnelError struct {
	cause     error
	permanent bool
}

// Error returns the message of the wrapped cause. A value built without a
// cause yields a fixed message instead of panicking on the nil interface.
func (e serverRegisterTunnelError) Error() string {
	if e.cause == nil {
		return "register tunnel error from server"
	}
	return e.cause.Error()
}
96
// clientRegisterTunnelError is a RegisterTunnel error that originated on the
// client side (opening the RPC stream or performing the RPC call).
type clientRegisterTunnelError struct {
	cause error
}

// Error returns the message of the wrapped cause. A value built without a
// cause yields a fixed message instead of panicking on the nil interface.
func (e clientRegisterTunnelError) Error() string {
	if e.cause == nil {
		return "register tunnel error from client"
	}
	return e.cause.Error()
}
105
106func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string) *tunnelpogs.RegistrationOptions {
107 policy := tunnelrpc.ExistingTunnelPolicy_balance
108 if c.HAConnections <= 1 && c.LBPool == "" {
109 policy = tunnelrpc.ExistingTunnelPolicy_disconnect
110 }
111 return &tunnelpogs.RegistrationOptions{
112 ClientID: c.ClientID,
113 Version: c.ReportedVersion,
114 OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
115 ExistingTunnelPolicy: policy,
116 PoolName: c.LBPool,
117 Tags: c.Tags,
118 ConnectionID: connectionID,
119 OriginLocalIP: OriginLocalIP,
120 IsAutoupdated: c.IsAutoupdated,
121 RunFromTerminal: c.RunFromTerminal,
122 CompressionQuality: c.CompressionQuality,
123 }
124}
125
126func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
127 ctx, cancel := context.WithCancel(context.Background())
128 go func() {
129 <-shutdownC
130 cancel()
131 }()
132 // If a user specified negative HAConnections, we will treat it as requesting 1 connection
133 if config.HAConnections > 1 {
134 return NewSupervisor(config).Run(ctx, connectedSignal)
135 } else {
136 addrs, err := ResolveEdgeIPs(config.EdgeAddrs)
137 if err != nil {
138 return err
139 }
140 return ServeTunnelLoop(ctx, config, addrs[0], 0, connectedSignal)
141 }
142}
143
144func ServeTunnelLoop(ctx context.Context,
145 config *TunnelConfig,
146 addr *net.TCPAddr,
147 connectionID uint8,
148 connectedSignal chan struct{},
149) error {
150 logger := config.Logger
151 config.Metrics.incrementHaConnections()
152 defer config.Metrics.decrementHaConnections()
153 backoff := BackoffHandler{MaxRetries: config.Retries}
154 // Used to close connectedSignal no more than once
155 connectedFuse := h2mux.NewBooleanFuse()
156 go func() {
157 if connectedFuse.Await() {
158 close(connectedSignal)
159 }
160 }()
161 // Ensure the above goroutine will terminate if we return without connecting
162 defer connectedFuse.Fuse(false)
163 for {
164 err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
165 if recoverable {
166 if duration, ok := backoff.GetBackoffDuration(ctx); ok {
167 logger.Infof("Retrying in %s seconds", duration)
168 backoff.Backoff(ctx)
169 continue
170 }
171 }
172 return err
173 }
174}
175
176func ServeTunnel(
177 ctx context.Context,
178 config *TunnelConfig,
179 addr *net.TCPAddr,
180 connectionID uint8,
181 connectedFuse *h2mux.BooleanFuse,
182 backoff *BackoffHandler,
183) (err error, recoverable bool) {
184 // Treat panics as recoverable errors
185 defer func() {
186 if r := recover(); r != nil {
187 var ok bool
188 err, ok = r.(error)
189 if !ok {
190 err = fmt.Errorf("ServeTunnel: %v", r)
191 }
192 recoverable = true
193 }
194 }()
195
196 connectionTag := uint8ToString(connectionID)
197 logger := config.Logger.WithField("connectionID", connectionTag)
198
199 // additional tags to send other than hostname which is set in cloudflared main package
200 tags := make(map[string]string)
201 tags["ha"] = connectionTag
202
203 // Returns error from parsing the origin URL or handshake errors
204 handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
205 if err != nil {
206 errLog := config.Logger.WithError(err)
207 switch err.(type) {
208 case dialError:
209 errLog.Error("Unable to dial edge")
210 case h2mux.MuxerHandshakeError:
211 errLog.Error("Handshake failed with edge server")
212 default:
213 errLog.Error("Tunnel creation failure")
214 return err, false
215 }
216 return err, true
217 }
218
219 errGroup, serveCtx := errgroup.WithContext(ctx)
220
221 errGroup.Go(func() error {
222 err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
223 if err == nil {
224 connectedFuse.Fuse(true)
225 backoff.SetGracePeriod()
226 }
227 return err
228 })
229
230 errGroup.Go(func() error {
231 updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
232 for {
233 select {
234 case <-serveCtx.Done():
235 // UnregisterTunnel blocks until the RPC call returns
236 err := UnregisterTunnel(handler.muxer, config.GracePeriod, config.Logger)
237 handler.muxer.Shutdown()
238 return err
239 case <-updateMetricsTickC:
240 handler.UpdateMetrics(connectionTag)
241 }
242 }
243 })
244
245 errGroup.Go(func() error {
246 // All routines should stop when muxer finish serving. When muxer is shutdown
247 // gracefully, it doesn't return an error, so we need to return errMuxerShutdown
248 // here to notify other routines to stop
249 err := handler.muxer.Serve(serveCtx)
250 if err == nil {
251 return muxerShutdownError{}
252 }
253 return err
254 })
255
256 err = errGroup.Wait()
257 if err != nil {
258 switch castedErr := err.(type) {
259 case dupConnRegisterTunnelError:
260 logger.Info("Already connected to this server, selecting a different one")
261 return err, true
262 case serverRegisterTunnelError:
263 logger.WithError(castedErr.cause).Error("Register tunnel error from server side")
264 // Don't send registration error return from server to Sentry. They are
265 // logged on server side
266 return castedErr.cause, !castedErr.permanent
267 case clientRegisterTunnelError:
268 logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
269 raven.CaptureError(castedErr.cause, tags)
270 return err, true
271 case muxerShutdownError:
272 logger.Infof("Muxer shutdown")
273 return err, true
274 default:
275 logger.WithError(err).Error("Serve tunnel error")
276 raven.CaptureError(err, tags)
277 return err, true
278 }
279 }
280 return nil, true
281}
282
283func IsRPCStreamResponse(headers []h2mux.Header) bool {
284 if len(headers) != 1 {
285 return false
286 }
287 if headers[0].Name != ":status" || headers[0].Value != "200" {
288 return false
289 }
290 return true
291}
292
293func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
294 config.Logger.Debug("initiating RPC stream to register")
295 stream, err := muxer.OpenStream([]h2mux.Header{
296 {Name: ":method", Value: "RPC"},
297 {Name: ":scheme", Value: "capnp"},
298 {Name: ":path", Value: "*"},
299 }, nil)
300 if err != nil {
301 // RPC stream open error
302 return clientRegisterTunnelError{cause: err}
303 }
304 if !IsRPCStreamResponse(stream.Headers) {
305 // stream response error
306 return clientRegisterTunnelError{cause: err}
307 }
308 conn := rpc.NewConn(
309 tunnelrpc.NewTransportLogger(config.Logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
310 tunnelrpc.ConnLog(config.Logger.WithField("subsystem", "rpc-transport")),
311 )
312 defer conn.Close()
313 ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
314 // Request server info without blocking tunnel registration; must use capnp library directly.
315 tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
316 serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
317 return nil
318 })
319 registration, err := ts.RegisterTunnel(
320 ctx,
321 config.OriginCert,
322 config.Hostname,
323 config.RegistrationOptions(connectionID, originLocalIP),
324 )
325 LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, config.Logger)
326 if err != nil {
327 // RegisterTunnel RPC failure
328 return clientRegisterTunnelError{cause: err}
329 }
330 for _, logLine := range registration.LogLines {
331 config.Logger.Info(logLine)
332 }
333 if registration.Err == DuplicateConnectionError {
334 return dupConnRegisterTunnelError{}
335 } else if registration.Err != "" {
336 return serverRegisterTunnelError{
337 cause: fmt.Errorf("Server error: %s", registration.Err),
338 permanent: registration.PermanentFailure,
339 }
340 }
341
342 config.Logger.Info("Tunnel ID: " + registration.TunnelID)
343 config.Logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
344 return nil
345}
346
347func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
348 logger.Debug("initiating RPC stream to unregister")
349 stream, err := muxer.OpenStream([]h2mux.Header{
350 {Name: ":method", Value: "RPC"},
351 {Name: ":scheme", Value: "capnp"},
352 {Name: ":path", Value: "*"},
353 }, nil)
354 if err != nil {
355 // RPC stream open error
356 return err
357 }
358 if !IsRPCStreamResponse(stream.Headers) {
359 // stream response error
360 return err
361 }
362 ctx := context.Background()
363 conn := rpc.NewConn(
364 tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
365 tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
366 )
367 defer conn.Close()
368 ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
369 // gracePeriod is encoded in int64 using capnproto
370 return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
371}
372
373func LogServerInfo(
374 promise tunnelrpc.ServerInfo_Promise,
375 connectionID uint8,
376 metrics *TunnelMetrics,
377 logger *log.Logger,
378) {
379 serverInfoMessage, err := promise.Struct()
380 if err != nil {
381 logger.WithError(err).Warn("Failed to retrieve server information")
382 return
383 }
384 serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
385 if err != nil {
386 logger.WithError(err).Warn("Failed to retrieve server information")
387 return
388 }
389 logger.Infof("Connected to %s", serverInfo.LocationName)
390 metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
391}
392
393func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
394 for _, header := range h2 {
395 switch header.Name {
396 case ":method":
397 h1.Method = header.Value
398 case ":scheme":
399 case ":authority":
400 // Otherwise the host header will be based on the origin URL
401 h1.Host = header.Value
402 case ":path":
403 u, err := url.Parse(header.Value)
404 if err != nil {
405 return fmt.Errorf("unparseable path")
406 }
407 resolved := h1.URL.ResolveReference(u)
408 // prevent escaping base URL
409 if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
410 return fmt.Errorf("invalid path")
411 }
412 h1.URL = resolved
413 default:
414 h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
415 }
416 }
417 return nil
418}
419
420func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
421 h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
422 for headerName, headerValues := range h1.Header {
423 for _, headerValue := range headerValues {
424 h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
425 }
426 }
427 return
428}
429
// FindCfRayHeader returns the value of the request's Cf-Ray header, or the
// empty string when the header is absent.
func FindCfRayHeader(h1 *http.Request) string {
	const cfRayHeader = "Cf-Ray"
	return h1.Header.Get(cfRayHeader)
}
433
434type TunnelHandler struct {
435 originUrl string
436 muxer *h2mux.Muxer
437 httpClient http.RoundTripper
438 tlsConfig *tls.Config
439 tags []tunnelpogs.Tag
440 metrics *TunnelMetrics
441 // connectionID is only used by metrics, and prometheus requires labels to be string
442 connectionID string
443 logger *log.Logger
444 noChunkedEncoding bool
445}
446
// dialer is the package-wide dialer used by NewTunnelHandler to open the TCP
// connection to the edge; DualStack is enabled on it.
var dialer = net.Dialer{
	DualStack: true,
}
448
449// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
450func NewTunnelHandler(ctx context.Context,
451 config *TunnelConfig,
452 addr string,
453 connectionID uint8,
454) (*TunnelHandler, string, error) {
455 originURL, err := validation.ValidateUrl(config.OriginUrl)
456 if err != nil {
457 return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
458 }
459 h := &TunnelHandler{
460 originUrl: originURL,
461 httpClient: config.HTTPTransport,
462 tlsConfig: config.ClientTlsConfig,
463 tags: config.Tags,
464 metrics: config.Metrics,
465 connectionID: uint8ToString(connectionID),
466 logger: config.Logger,
467 noChunkedEncoding: config.NoChunkedEncoding,
468 }
469 if h.httpClient == nil {
470 h.httpClient = http.DefaultTransport
471 }
472 // Inherit from parent context so we can cancel (Ctrl-C) while dialing
473 dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
474 // TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
475 plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
476 dialCancel()
477 if err != nil {
478 return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
479 }
480 edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
481 edgeConn.SetDeadline(time.Now().Add(dialTimeout))
482 err = edgeConn.Handshake()
483 if err != nil {
484 return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
485 }
486 // clear the deadline on the conn; h2mux has its own timeouts
487 edgeConn.SetDeadline(time.Time{})
488 // Establish a muxed connection with the edge
489 // Client mux handshake with agent server
490 h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
491 Timeout: 5 * time.Second,
492 Handler: h,
493 IsClient: true,
494 HeartbeatInterval: config.HeartbeatInterval,
495 MaxHeartbeats: config.MaxHeartbeats,
496 Logger: config.ProtocolLogger.WithFields(log.Fields{}),
497 CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
498 })
499 if err != nil {
500 return h, "", errors.New("TLS handshake error")
501 }
502 return h, edgeConn.LocalAddr().String(), err
503}
504
505func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
506 for _, tag := range h.tags {
507 r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
508 }
509}
510
511func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
512 h.metrics.incrementRequests(h.connectionID)
513 req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
514 if err != nil {
515 h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
516 }
517 err = H2RequestHeadersToH1Request(stream.Headers, req)
518 if err != nil {
519 h.logger.WithError(err).Error("invalid request received")
520 }
521 h.AppendTagHeaders(req)
522 cfRay := FindCfRayHeader(req)
523 lbProbe := isLBProbeRequest(req)
524 h.logRequest(req, cfRay, lbProbe)
525 if websocket.IsWebSocketUpgrade(req) {
526 conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
527 if err != nil {
528 h.logError(stream, err)
529 } else {
530 stream.WriteHeaders(H1ResponseToH2Response(response))
531 defer conn.Close()
532 // Copy to/from stream to the undelying connection. Use the underlying
533 // connection because cloudflared doesn't operate on the message themselves
534 websocket.Stream(conn.UnderlyingConn(), stream)
535 h.metrics.incrementResponses(h.connectionID, "200")
536 h.logResponse(response, cfRay, lbProbe)
537 }
538 } else {
539 // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
540 if h.noChunkedEncoding {
541 req.TransferEncoding = []string{"gzip", "deflate"}
542 cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
543 if err == nil {
544 req.ContentLength = int64(cLength)
545 }
546 }
547
548 // Request origin to keep connection alive to improve performance
549 req.Header.Set("Connection", "keep-alive")
550
551 response, err := h.httpClient.RoundTrip(req)
552
553 if err != nil {
554 h.logError(stream, err)
555 } else {
556 defer response.Body.Close()
557 stream.WriteHeaders(H1ResponseToH2Response(response))
558 if h.isEventStream(response) {
559 h.writeEventStream(stream, response.Body)
560 } else {
561 // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
562 // compression generates dictionary on first write
563 io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
564 }
565
566 h.metrics.incrementResponses(h.connectionID, "200")
567 h.logResponse(response, cfRay, lbProbe)
568 }
569 }
570 h.metrics.decrementConcurrentRequests(h.connectionID)
571 return nil
572}
573
574func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
575 reader := bufio.NewReader(responseBody)
576 for {
577 line, err := reader.ReadBytes('\n')
578 if err != nil {
579 break
580 }
581 stream.Write(line)
582 }
583}
584
585func (h *TunnelHandler) isEventStream(response *http.Response) bool {
586 if response.Header.Get("content-type") == "text/event-stream" {
587 h.logger.Debug("Detected Server-Side Events from Origin")
588 return true
589 }
590 return false
591}
592
593func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
594 h.logger.WithError(err).Error("HTTP request error")
595 stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
596 stream.Write([]byte("502 Bad Gateway"))
597 h.metrics.incrementResponses(h.connectionID, "502")
598}
599
600func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) {
601 if cfRay != "" {
602 h.logger.WithField("CF-RAY", cfRay).Debugf("%s %s %s", req.Method, req.URL, req.Proto)
603 } else if lbProbe {
604 h.logger.Debugf("Load Balancer health check %s %s %s", req.Method, req.URL, req.Proto)
605 } else {
606 h.logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
607 }
608 h.logger.Debugf("Request Headers %+v", req.Header)
609}
610
611func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) {
612 if cfRay != "" {
613 h.logger.WithField("CF-RAY", cfRay).Debugf("%s", r.Status)
614 } else if lbProbe {
615 h.logger.Debugf("Response to Load Balancer health check %s", r.Status)
616 } else {
617 h.logger.Infof("%s", r.Status)
618 }
619 h.logger.Debugf("Response Headers %+v", r.Header)
620}
621
622func (h *TunnelHandler) UpdateMetrics(connectionID string) {
623 h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
624}
625
// uint8ToString formats input as a base-10 string.
func uint8ToString(input uint8) string {
	return strconv.Itoa(int(input))
}
629
630func isLBProbeRequest(req *http.Request) bool {
631 return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
632}
633