cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.10.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

connection/h2mux.go

262 lines · mode: code

1package connection
2
3import (
4 "context"
5 "io"
6 "net"
7 "net/http"
8 "time"
9
10 "github.com/pkg/errors"
11 "github.com/rs/zerolog"
12 "golang.org/x/sync/errgroup"
13
14 "github.com/cloudflare/cloudflared/h2mux"
15 tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
16 "github.com/cloudflare/cloudflared/websocket"
17)
18
const (
	// openStreamTimeout bounds how long newRPCStream waits for the muxer to
	// open an RPC stream.
	openStreamTimeout = 30 * time.Second
	// muxerTimeout is the Timeout set on every h2mux.MuxerConfig produced by
	// MuxerConfig.H2MuxerConfig.
	muxerTimeout = 5 * time.Second
)
23
24type h2muxConnection struct {
25 config *Config
26 muxerConfig *MuxerConfig
27 muxer *h2mux.Muxer
28 // connectionID is only used by metrics, and prometheus requires labels to be string
29 connIndexStr string
30 connIndex uint8
31
32 observer *Observer
33 gracefulShutdownC <-chan struct{}
34 stoppedGracefully bool
35
36 log *zerolog.Logger
37
38 // newRPCClientFunc allows us to mock RPCs during testing
39 newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
40}
41
42type MuxerConfig struct {
43 HeartbeatInterval time.Duration
44 MaxHeartbeats uint64
45 CompressionSetting h2mux.CompressionSetting
46 MetricsUpdateFreq time.Duration
47}
48
49func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
50 return &h2mux.MuxerConfig{
51 Timeout: muxerTimeout,
52 Handler: h,
53 IsClient: true,
54 HeartbeatInterval: mc.HeartbeatInterval,
55 MaxHeartbeats: mc.MaxHeartbeats,
56 Log: log,
57 CompressionQuality: mc.CompressionSetting,
58 }
59}
60
61// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
62func NewH2muxConnection(
63 config *Config,
64 muxerConfig *MuxerConfig,
65 edgeConn net.Conn,
66 connIndex uint8,
67 observer *Observer,
68 gracefulShutdownC <-chan struct{},
69) (*h2muxConnection, error, bool) {
70 h := &h2muxConnection{
71 config: config,
72 muxerConfig: muxerConfig,
73 connIndexStr: uint8ToString(connIndex),
74 connIndex: connIndex,
75 observer: observer,
76 gracefulShutdownC: gracefulShutdownC,
77 newRPCClientFunc: newRegistrationRPCClient,
78 }
79
80 // Establish a muxed connection with the edge
81 // Client mux handshake with agent server
82 muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig.H2MuxerConfig(h, observer.logTransport), h2mux.ActiveStreams)
83 if err != nil {
84 recoverable := isHandshakeErrRecoverable(err, connIndex, observer)
85 return nil, err, recoverable
86 }
87 h.muxer = muxer
88 return h, nil, false
89}
90
91func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
92 errGroup, serveCtx := errgroup.WithContext(ctx)
93 errGroup.Go(func() error {
94 return h.serveMuxer(serveCtx)
95 })
96
97 errGroup.Go(func() error {
98 if err := h.registerNamedTunnel(serveCtx, namedTunnel, connOptions); err != nil {
99 return err
100 }
101 connectedFuse.Connected()
102 return nil
103 })
104
105 errGroup.Go(func() error {
106 h.controlLoop(serveCtx, connectedFuse, true)
107 return nil
108 })
109
110 err := errGroup.Wait()
111 if err == errMuxerStopped {
112 if h.stoppedGracefully {
113 return nil
114 }
115 h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
116 }
117 return err
118}
119
120func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
121 errGroup, serveCtx := errgroup.WithContext(ctx)
122 errGroup.Go(func() error {
123 return h.serveMuxer(serveCtx)
124 })
125
126 errGroup.Go(func() (err error) {
127 defer func() {
128 if err == nil {
129 connectedFuse.Connected()
130 }
131 }()
132 if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
133 err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
134 if err == nil {
135 return nil
136 }
137 // log errors and proceed to RegisterTunnel
138 h.observer.log.Err(err).
139 Uint8(LogFieldConnIndex, h.connIndex).
140 Msg("Couldn't reconnect connection. Re-registering it instead.")
141 }
142 return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
143 })
144
145 errGroup.Go(func() error {
146 h.controlLoop(serveCtx, connectedFuse, false)
147 return nil
148 })
149
150 err := errGroup.Wait()
151 if err == errMuxerStopped {
152 if h.stoppedGracefully {
153 return nil
154 }
155 h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
156 }
157 return err
158}
159
160func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
161 // All routines should stop when muxer finish serving. When muxer is shutdown
162 // gracefully, it doesn't return an error, so we need to return errMuxerShutdown
163 // here to notify other routines to stop
164 err := h.muxer.Serve(ctx)
165 if err == nil {
166 return errMuxerStopped
167 }
168 return err
169}
170
// controlLoop handles shutdown and periodic metrics for this connection. It
// returns once a graceful shutdown has completed, or as soon as ctx is done.
//
// It reacts to three events:
//   - gracefulShutdownC fires: unregister (if connected), record
//     stoppedGracefully, start a muxer shutdown, then keep looping until that
//     shutdown reports completion.
//   - ctx is done: unregister (unless a graceful shutdown was already started
//     or we are not connected), start a muxer shutdown, and return without
//     waiting for it.
//   - every MetricsUpdateFreq: push the muxer's metrics to the observer.
//
// isNamedTunnel is passed through to unregister.
func (h *h2muxConnection) controlLoop(ctx context.Context, connectedFuse ConnectedFuse, isNamedTunnel bool) {
	// time.NewTicker panics if MetricsUpdateFreq is not positive.
	updateMetricsTicker := time.NewTicker(h.muxerConfig.MetricsUpdateFreq)
	defer updateMetricsTicker.Stop()
	// shutdownCompleted stays nil until a graceful shutdown is started. A
	// receive from a nil channel blocks forever, so its case is inert until then.
	var shutdownCompleted <-chan struct{}
	for {
		select {
		case <-h.gracefulShutdownC:
			if connectedFuse.IsConnected() {
				h.unregister(isNamedTunnel)
			}
			h.stoppedGracefully = true
			// Setting the channel to nil disables this case, so it fires at most
			// once, even if the channel was closed.
			h.gracefulShutdownC = nil
			shutdownCompleted = h.muxer.Shutdown()

		case <-shutdownCompleted:
			return

		case <-ctx.Done():
			// UnregisterTunnel blocks until the RPC call returns
			if !h.stoppedGracefully && connectedFuse.IsConnected() {
				h.unregister(isNamedTunnel)
			}
			h.muxer.Shutdown()
			// don't wait for shutdown to finish when context is closed, this is the hard termination path
			return

		case <-updateMetricsTicker.C:
			h.observer.metrics.updateMuxerMetrics(h.connIndexStr, h.muxer.Metrics())
		}
	}
}
202
203func (h *h2muxConnection) newRPCStream(ctx context.Context, rpcName rpcName) (*h2mux.MuxedStream, error) {
204 openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
205 defer openStreamCancel()
206 stream, err := h.muxer.OpenRPCStream(openStreamCtx)
207 if err != nil {
208 return nil, err
209 }
210 return stream, nil
211}
212
213func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
214 respWriter := &h2muxRespWriter{stream}
215
216 req, reqErr := h.newRequest(stream)
217 if reqErr != nil {
218 respWriter.WriteErrorResponse()
219 return reqErr
220 }
221
222 var sourceConnectionType = TypeHTTP
223 if websocket.IsWebSocketUpgrade(req) {
224 sourceConnectionType = TypeWebsocket
225 }
226
227 err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
228 if err != nil {
229 respWriter.WriteErrorResponse()
230 }
231 return err
232}
233
234func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
235 req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
236 if err != nil {
237 return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
238 }
239 err = H2RequestHeadersToH1Request(stream.Headers, req)
240 if err != nil {
241 return nil, errors.Wrap(err, "invalid request received")
242 }
243 return req, nil
244}
245
246type h2muxRespWriter struct {
247 *h2mux.MuxedStream
248}
249
250func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
251 headers := H1ResponseToH2ResponseHeaders(status, header)
252 headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
253 return rp.WriteHeaders(headers)
254}
255
256func (rp *h2muxRespWriter) WriteErrorResponse() {
257 _ = rp.WriteHeaders([]h2mux.Header{
258 {Name: ":status", Value: "502"},
259 {Name: ResponseMetaHeader, Value: responseMetaHeaderCfd},
260 })
261 _, _ = rp.Write([]byte("502 Bad Gateway"))
262}
263