cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflaredAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2019.8.4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/supervisor.go

255lines · modecode

1package origin
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "time"
8
9 "github.com/sirupsen/logrus"
10
11 "github.com/cloudflare/cloudflared/connection"
12 "github.com/cloudflare/cloudflared/signal"
13
14 "github.com/google/uuid"
15)
16
// Timing parameters used by the Supervisor.
const (
	// tunnelRetryDuration is the base waiting time before retrying a failed
	// tunnel connection.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is the SRV record resolution TTL: resolved edge addresses
	// are not refreshed again until this much time has passed.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between registering new tunnels.
	registrationInterval = 1 * time.Second
)
25
// Supervisor holds the state needed to start tunnel connections to the edge
// and to restart them, with backoff, after they fail.
type Supervisor struct {
	// config is the tunnel configuration handed to NewSupervisor.
	config *TunnelConfig
	// edgeIPs is the current list of resolved edge addresses; getEdgeIP
	// indexes into it modulo its length.
	edgeIPs []*net.TCPAddr
	// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
	nextUnusedEdgeIP int
	// lastResolve is when edgeIPs was last resolved (or a refresh last completed).
	lastResolve time.Time
	// resolverC is non-nil only while a background address refresh is in flight.
	resolverC chan resolveResult
	// tunnelErrors receives one tunnelError from each tunnel goroutine when it exits.
	tunnelErrors chan tunnelError
	// tunnelsConnecting maps a tunnel index to the channel registered for it
	// by newConnectedTunnelSignal; entries are removed by waitForNextTunnel.
	tunnelsConnecting map[int]chan struct{}
	// nextConnectedIndex and nextConnectedSignal are used to wait for all
	// currently-connecting tunnels to finish connecting so we can reset backoff timer
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}

	// logger is config.Logger tagged with subsystem=supervisor.
	logger *logrus.Entry
}
42
// resolveResult carries the outcome of a background edge address resolution:
// the addresses returned by resolveEdgeIPs together with its error.
type resolveResult struct {
	edgeIPs []*net.TCPAddr
	err     error
}
47
// tunnelError is what a tunnel goroutine sends on Supervisor.tunnelErrors when
// it exits: the index of the tunnel and the error returned by ServeTunnelLoop
// (which may be nil).
type tunnelError struct {
	index int
	err   error
}
52
53func NewSupervisor(config *TunnelConfig) *Supervisor {
54 return &Supervisor{
55 config: config,
56 tunnelErrors: make(chan tunnelError),
57 tunnelsConnecting: map[int]chan struct{}{},
58 logger: config.Logger.WithField("subsystem", "supervisor"),
59 }
60}
61
62func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
63 logger := s.config.Logger
64 if err := s.initialize(ctx, connectedSignal, u); err != nil {
65 return err
66 }
67 var tunnelsWaiting []int
68 backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
69 var backoffTimer <-chan time.Time
70 tunnelsActive := s.config.HAConnections
71
72 for {
73 select {
74 // Context cancelled
75 case <-ctx.Done():
76 for tunnelsActive > 0 {
77 <-s.tunnelErrors
78 tunnelsActive--
79 }
80 return nil
81 // startTunnel returned with error
82 // (note that this may also be caused by context cancellation)
83 case tunnelError := <-s.tunnelErrors:
84 tunnelsActive--
85 if tunnelError.err != nil {
86 logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
87 tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
88 s.waitForNextTunnel(tunnelError.index)
89
90 if backoffTimer == nil {
91 backoffTimer = backoff.BackoffTimer()
92 }
93
94 // If the error is a dial error, the problem is likely to be network related
95 // try another addr before refreshing since we are likely to get back the
96 // same IPs in the same order. Same problem with duplicate connection error.
97 if s.unusedIPs() {
98 s.replaceEdgeIP(tunnelError.index)
99 } else {
100 s.refreshEdgeIPs()
101 }
102 }
103 // Backoff was set and its timer expired
104 case <-backoffTimer:
105 backoffTimer = nil
106 for _, index := range tunnelsWaiting {
107 go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
108 }
109 tunnelsActive += len(tunnelsWaiting)
110 tunnelsWaiting = nil
111 // Tunnel successfully connected
112 case <-s.nextConnectedSignal:
113 if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
114 // No more tunnels outstanding, clear backoff timer
115 backoff.SetGracePeriod()
116 }
117 // DNS resolution returned
118 case result := <-s.resolverC:
119 s.lastResolve = time.Now()
120 s.resolverC = nil
121 if result.err == nil {
122 logger.Debug("Service discovery refresh complete")
123 s.edgeIPs = result.edgeIPs
124 } else {
125 logger.WithError(result.err).Error("Service discovery error")
126 }
127 }
128 }
129}
130
131func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
132 logger := s.logger
133
134 edgeIPs, err := s.resolveEdgeIPs()
135
136 if err != nil {
137 logger.Infof("ResolveEdgeIPs err")
138 return err
139 }
140 s.edgeIPs = edgeIPs
141 if s.config.HAConnections > len(edgeIPs) {
142 logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
143 s.config.HAConnections = len(edgeIPs)
144 }
145 s.lastResolve = time.Now()
146 // check entitlement and version too old error before attempting to register more tunnels
147 s.nextUnusedEdgeIP = s.config.HAConnections
148 go s.startFirstTunnel(ctx, connectedSignal, u)
149 select {
150 case <-ctx.Done():
151 <-s.tunnelErrors
152 // Error can't be nil. A nil error signals that initialization succeed
153 return fmt.Errorf("context was canceled")
154 case tunnelError := <-s.tunnelErrors:
155 return tunnelError.err
156 case <-connectedSignal.Wait():
157 }
158 // At least one successful connection, so start the rest
159 for i := 1; i < s.config.HAConnections; i++ {
160 ch := signal.New(make(chan struct{}))
161 go s.startTunnel(ctx, i, ch, u)
162 time.Sleep(registrationInterval)
163 }
164 return nil
165}
166
167// startTunnel starts the first tunnel connection. The resulting error will be sent on
168// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
169func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
170 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
171 defer func() {
172 s.tunnelErrors <- tunnelError{index: 0, err: err}
173 }()
174
175 for s.unusedIPs() {
176 select {
177 case <-ctx.Done():
178 return
179 default:
180 }
181 switch err.(type) {
182 case nil:
183 return
184 // try the next address if it was a dialError(network problem) or
185 // dupConnRegisterTunnelError
186 case dialError, dupConnRegisterTunnelError:
187 s.replaceEdgeIP(0)
188 default:
189 return
190 }
191 err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
192 }
193}
194
195// startTunnel starts a new tunnel connection. The resulting error will be sent on
196// s.tunnelErrors.
197func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
198 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
199 s.tunnelErrors <- tunnelError{index: index, err: err}
200}
201
202func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
203 sig := make(chan struct{})
204 s.tunnelsConnecting[index] = sig
205 s.nextConnectedSignal = sig
206 s.nextConnectedIndex = index
207 return signal.New(sig)
208}
209
210func (s *Supervisor) waitForNextTunnel(index int) bool {
211 delete(s.tunnelsConnecting, index)
212 s.nextConnectedSignal = nil
213 for k, v := range s.tunnelsConnecting {
214 s.nextConnectedIndex = k
215 s.nextConnectedSignal = v
216 return true
217 }
218 return false
219}
220
221func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
222 return s.edgeIPs[index%len(s.edgeIPs)]
223}
224
225func (s *Supervisor) resolveEdgeIPs() ([]*net.TCPAddr, error) {
226 // If --edge is specfied, resolve edge server addresses
227 if len(s.config.EdgeAddrs) > 0 {
228 return connection.ResolveAddrs(s.config.EdgeAddrs)
229 }
230 // Otherwise lookup edge server addresses through service discovery
231 return connection.EdgeDiscovery(s.logger)
232}
233
234func (s *Supervisor) refreshEdgeIPs() {
235 if s.resolverC != nil {
236 return
237 }
238 if time.Since(s.lastResolve) < resolveTTL {
239 return
240 }
241 s.resolverC = make(chan resolveResult)
242 go func() {
243 edgeIPs, err := s.resolveEdgeIPs()
244 s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
245 }()
246}
247
248func (s *Supervisor) unusedIPs() bool {
249 return s.nextUnusedEdgeIP < len(s.edgeIPs)
250}
251
252func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
253 s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
254 s.nextUnusedEdgeIP++
255}
256