cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2019.3.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/supervisor.go

238 lines · mode: code

1package origin
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "time"
8
9 "github.com/cloudflare/cloudflared/signal"
10
11 "github.com/google/uuid"
12)
13
const (
	// tunnelRetryDuration is the base delay used by Run's BackoffHandler
	// before a failed tunnel connection is started again.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is how old the previous edge-address resolution must be
	// before refreshEdgeIPs will start another one.
	resolveTTL = time.Hour
	// registrationInterval spaces out the start of successive HA tunnels
	// during initialize.
	registrationInterval = 1 * time.Second
)
22
23type Supervisor struct {
24 config *TunnelConfig
25 edgeIPs []*net.TCPAddr
26 // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
27 nextUnusedEdgeIP int
28 lastResolve time.Time
29 resolverC chan resolveResult
30 tunnelErrors chan tunnelError
31 tunnelsConnecting map[int]chan struct{}
32 // nextConnectedIndex and nextConnectedSignal are used to wait for all
33 // currently-connecting tunnels to finish connecting so we can reset backoff timer
34 nextConnectedIndex int
35 nextConnectedSignal chan struct{}
36}
37
// resolveResult is what a background edge-address resolution sends back to
// Run: the freshly resolved addresses, or the error that prevented them.
type resolveResult struct {
	edgeIPs []*net.TCPAddr // resolved addresses; only used when err is nil
	err     error          // resolution failure, if any
}
42
// tunnelError is sent on Supervisor.tunnelErrors when a tunnel goroutine
// exits: index identifies the tunnel and err is what its serve loop returned
// (nil is a valid, non-failure outcome).
type tunnelError struct {
	index int   // tunnel (HA connection) index
	err   error // exit error from ServeTunnelLoop, possibly nil
}
47
48func NewSupervisor(config *TunnelConfig) *Supervisor {
49 return &Supervisor{
50 config: config,
51 tunnelErrors: make(chan tunnelError),
52 tunnelsConnecting: map[int]chan struct{}{},
53 }
54}
55
56func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
57 logger := s.config.Logger
58 if err := s.initialize(ctx, connectedSignal, u); err != nil {
59 return err
60 }
61 var tunnelsWaiting []int
62 backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
63 var backoffTimer <-chan time.Time
64 tunnelsActive := s.config.HAConnections
65
66 for {
67 select {
68 // Context cancelled
69 case <-ctx.Done():
70 for tunnelsActive > 0 {
71 <-s.tunnelErrors
72 tunnelsActive--
73 }
74 return nil
75 // startTunnel returned with error
76 // (note that this may also be caused by context cancellation)
77 case tunnelError := <-s.tunnelErrors:
78 tunnelsActive--
79 if tunnelError.err != nil {
80 logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
81 tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
82 s.waitForNextTunnel(tunnelError.index)
83
84 if backoffTimer == nil {
85 backoffTimer = backoff.BackoffTimer()
86 }
87
88 // If the error is a dial error, the problem is likely to be network related
89 // try another addr before refreshing since we are likely to get back the
90 // same IPs in the same order. Same problem with duplicate connection error.
91 if s.unusedIPs() {
92 s.replaceEdgeIP(tunnelError.index)
93 } else {
94 s.refreshEdgeIPs()
95 }
96 }
97 // Backoff was set and its timer expired
98 case <-backoffTimer:
99 backoffTimer = nil
100 for _, index := range tunnelsWaiting {
101 go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
102 }
103 tunnelsActive += len(tunnelsWaiting)
104 tunnelsWaiting = nil
105 // Tunnel successfully connected
106 case <-s.nextConnectedSignal:
107 if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
108 // No more tunnels outstanding, clear backoff timer
109 backoff.SetGracePeriod()
110 }
111 // DNS resolution returned
112 case result := <-s.resolverC:
113 s.lastResolve = time.Now()
114 s.resolverC = nil
115 if result.err == nil {
116 logger.Debug("Service discovery refresh complete")
117 s.edgeIPs = result.edgeIPs
118 } else {
119 logger.WithError(result.err).Error("Service discovery error")
120 }
121 }
122 }
123}
124
125func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
126 logger := s.config.Logger
127 edgeIPs, err := ResolveEdgeIPs(logger, s.config.EdgeAddrs)
128 if err != nil {
129 logger.Infof("ResolveEdgeIPs err")
130 return err
131 }
132 s.edgeIPs = edgeIPs
133 if s.config.HAConnections > len(edgeIPs) {
134 logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
135 s.config.HAConnections = len(edgeIPs)
136 }
137 s.lastResolve = time.Now()
138 // check entitlement and version too old error before attempting to register more tunnels
139 s.nextUnusedEdgeIP = s.config.HAConnections
140 go s.startFirstTunnel(ctx, connectedSignal, u)
141 select {
142 case <-ctx.Done():
143 <-s.tunnelErrors
144 // Error can't be nil. A nil error signals that initialization succeed
145 return fmt.Errorf("context was canceled")
146 case tunnelError := <-s.tunnelErrors:
147 return tunnelError.err
148 case <-connectedSignal.Wait():
149 }
150 // At least one successful connection, so start the rest
151 for i := 1; i < s.config.HAConnections; i++ {
152 ch := signal.New(make(chan struct{}))
153 go s.startTunnel(ctx, i, ch, u)
154 time.Sleep(registrationInterval)
155 }
156 return nil
157}
158
159// startTunnel starts the first tunnel connection. The resulting error will be sent on
160// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
161func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
162 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
163 defer func() {
164 s.tunnelErrors <- tunnelError{index: 0, err: err}
165 }()
166
167 for s.unusedIPs() {
168 select {
169 case <-ctx.Done():
170 return
171 default:
172 }
173 switch err.(type) {
174 case nil:
175 return
176 // try the next address if it was a dialError(network problem) or
177 // dupConnRegisterTunnelError
178 case dialError, dupConnRegisterTunnelError:
179 s.replaceEdgeIP(0)
180 default:
181 return
182 }
183 err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
184 }
185}
186
187// startTunnel starts a new tunnel connection. The resulting error will be sent on
188// s.tunnelErrors.
189func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
190 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
191 s.tunnelErrors <- tunnelError{index: index, err: err}
192}
193
194func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
195 sig := make(chan struct{})
196 s.tunnelsConnecting[index] = sig
197 s.nextConnectedSignal = sig
198 s.nextConnectedIndex = index
199 return signal.New(sig)
200}
201
202func (s *Supervisor) waitForNextTunnel(index int) bool {
203 delete(s.tunnelsConnecting, index)
204 s.nextConnectedSignal = nil
205 for k, v := range s.tunnelsConnecting {
206 s.nextConnectedIndex = k
207 s.nextConnectedSignal = v
208 return true
209 }
210 return false
211}
212
213func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
214 return s.edgeIPs[index%len(s.edgeIPs)]
215}
216
217func (s *Supervisor) refreshEdgeIPs() {
218 if s.resolverC != nil {
219 return
220 }
221 if time.Since(s.lastResolve) < resolveTTL {
222 return
223 }
224 s.resolverC = make(chan resolveResult)
225 go func() {
226 edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
227 s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
228 }()
229}
230
231func (s *Supervisor) unusedIPs() bool {
232 return s.nextUnusedEdgeIP < len(s.edgeIPs)
233}
234
235func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
236 s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
237 s.nextUnusedEdgeIP++
238}
239