cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2019.6.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/supervisor.go

239 lines · mode: code

1package origin
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "time"
8
9 "github.com/cloudflare/cloudflared/connection"
10 "github.com/cloudflare/cloudflared/signal"
11
12 "github.com/google/uuid"
13)
14
// Timing knobs for tunnel retries, edge re-resolution and HA registration.
const (
	// tunnelRetryDuration is the base time handed to the backoff handler
	// before a failed tunnel connection is retried.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is how long a resolved SRV record set is trusted; no refresh
	// is issued while the last resolution is younger than this.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between launching successive HA tunnels.
	registrationInterval = 1 * time.Second
)
23
24type Supervisor struct {
25 config *TunnelConfig
26 edgeIPs []*net.TCPAddr
27 // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
28 nextUnusedEdgeIP int
29 lastResolve time.Time
30 resolverC chan resolveResult
31 tunnelErrors chan tunnelError
32 tunnelsConnecting map[int]chan struct{}
33 // nextConnectedIndex and nextConnectedSignal are used to wait for all
34 // currently-connecting tunnels to finish connecting so we can reset backoff timer
35 nextConnectedIndex int
36 nextConnectedSignal chan struct{}
37}
38
// resolveResult carries the outcome of one asynchronous edge-address
// resolution back to the supervisor loop.
type resolveResult struct {
	edgeIPs []*net.TCPAddr // addresses returned by the resolver
	err     error          // non-nil when resolution failed
}
43
// tunnelError reports that the tunnel goroutine with the given index has
// exited, along with the error it exited with (nil for a clean exit).
type tunnelError struct {
	index int   // HA connection index of the tunnel
	err   error // exit error; nil means the tunnel stopped without error
}
48
49func NewSupervisor(config *TunnelConfig) *Supervisor {
50 return &Supervisor{
51 config: config,
52 tunnelErrors: make(chan tunnelError),
53 tunnelsConnecting: map[int]chan struct{}{},
54 }
55}
56
57func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
58 logger := s.config.Logger
59 if err := s.initialize(ctx, connectedSignal, u); err != nil {
60 return err
61 }
62 var tunnelsWaiting []int
63 backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
64 var backoffTimer <-chan time.Time
65 tunnelsActive := s.config.HAConnections
66
67 for {
68 select {
69 // Context cancelled
70 case <-ctx.Done():
71 for tunnelsActive > 0 {
72 <-s.tunnelErrors
73 tunnelsActive--
74 }
75 return nil
76 // startTunnel returned with error
77 // (note that this may also be caused by context cancellation)
78 case tunnelError := <-s.tunnelErrors:
79 tunnelsActive--
80 if tunnelError.err != nil {
81 logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
82 tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
83 s.waitForNextTunnel(tunnelError.index)
84
85 if backoffTimer == nil {
86 backoffTimer = backoff.BackoffTimer()
87 }
88
89 // If the error is a dial error, the problem is likely to be network related
90 // try another addr before refreshing since we are likely to get back the
91 // same IPs in the same order. Same problem with duplicate connection error.
92 if s.unusedIPs() {
93 s.replaceEdgeIP(tunnelError.index)
94 } else {
95 s.refreshEdgeIPs()
96 }
97 }
98 // Backoff was set and its timer expired
99 case <-backoffTimer:
100 backoffTimer = nil
101 for _, index := range tunnelsWaiting {
102 go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
103 }
104 tunnelsActive += len(tunnelsWaiting)
105 tunnelsWaiting = nil
106 // Tunnel successfully connected
107 case <-s.nextConnectedSignal:
108 if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
109 // No more tunnels outstanding, clear backoff timer
110 backoff.SetGracePeriod()
111 }
112 // DNS resolution returned
113 case result := <-s.resolverC:
114 s.lastResolve = time.Now()
115 s.resolverC = nil
116 if result.err == nil {
117 logger.Debug("Service discovery refresh complete")
118 s.edgeIPs = result.edgeIPs
119 } else {
120 logger.WithError(result.err).Error("Service discovery error")
121 }
122 }
123 }
124}
125
126func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
127 logger := s.config.Logger
128 edgeIPs, err := connection.ResolveEdgeIPs(logger, s.config.EdgeAddrs)
129 if err != nil {
130 logger.Infof("ResolveEdgeIPs err")
131 return err
132 }
133 s.edgeIPs = edgeIPs
134 if s.config.HAConnections > len(edgeIPs) {
135 logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
136 s.config.HAConnections = len(edgeIPs)
137 }
138 s.lastResolve = time.Now()
139 // check entitlement and version too old error before attempting to register more tunnels
140 s.nextUnusedEdgeIP = s.config.HAConnections
141 go s.startFirstTunnel(ctx, connectedSignal, u)
142 select {
143 case <-ctx.Done():
144 <-s.tunnelErrors
145 // Error can't be nil. A nil error signals that initialization succeed
146 return fmt.Errorf("context was canceled")
147 case tunnelError := <-s.tunnelErrors:
148 return tunnelError.err
149 case <-connectedSignal.Wait():
150 }
151 // At least one successful connection, so start the rest
152 for i := 1; i < s.config.HAConnections; i++ {
153 ch := signal.New(make(chan struct{}))
154 go s.startTunnel(ctx, i, ch, u)
155 time.Sleep(registrationInterval)
156 }
157 return nil
158}
159
160// startTunnel starts the first tunnel connection. The resulting error will be sent on
161// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
162func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
163 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
164 defer func() {
165 s.tunnelErrors <- tunnelError{index: 0, err: err}
166 }()
167
168 for s.unusedIPs() {
169 select {
170 case <-ctx.Done():
171 return
172 default:
173 }
174 switch err.(type) {
175 case nil:
176 return
177 // try the next address if it was a dialError(network problem) or
178 // dupConnRegisterTunnelError
179 case dialError, dupConnRegisterTunnelError:
180 s.replaceEdgeIP(0)
181 default:
182 return
183 }
184 err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
185 }
186}
187
188// startTunnel starts a new tunnel connection. The resulting error will be sent on
189// s.tunnelErrors.
190func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
191 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
192 s.tunnelErrors <- tunnelError{index: index, err: err}
193}
194
195func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
196 sig := make(chan struct{})
197 s.tunnelsConnecting[index] = sig
198 s.nextConnectedSignal = sig
199 s.nextConnectedIndex = index
200 return signal.New(sig)
201}
202
203func (s *Supervisor) waitForNextTunnel(index int) bool {
204 delete(s.tunnelsConnecting, index)
205 s.nextConnectedSignal = nil
206 for k, v := range s.tunnelsConnecting {
207 s.nextConnectedIndex = k
208 s.nextConnectedSignal = v
209 return true
210 }
211 return false
212}
213
214func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
215 return s.edgeIPs[index%len(s.edgeIPs)]
216}
217
218func (s *Supervisor) refreshEdgeIPs() {
219 if s.resolverC != nil {
220 return
221 }
222 if time.Since(s.lastResolve) < resolveTTL {
223 return
224 }
225 s.resolverC = make(chan resolveResult)
226 go func() {
227 edgeIPs, err := connection.ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
228 s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
229 }()
230}
231
232func (s *Supervisor) unusedIPs() bool {
233 return s.nextUnusedEdgeIP < len(s.edgeIPs)
234}
235
236func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
237 s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
238 s.nextUnusedEdgeIP++
239}
240