cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflaredAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2019.2.1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/supervisor.go

235lines · modecode

1package origin
2
3import (
4 "context"
5 "fmt"
6 "net"
7 "time"
8
9 "github.com/google/uuid"
10)
11
// Timing parameters used by the supervisor.
const (
	// tunnelRetryDuration is the base time given to the backoff handler that
	// schedules retries of failed tunnel connections.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is the minimum age of the last edge address resolution (SRV
	// lookup) before refreshEdgeIPs will resolve again.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between starting successive tunnels
	// during initialization.
	registrationInterval = 1 * time.Second
)
20
21type Supervisor struct {
22 config *TunnelConfig
23 edgeIPs []*net.TCPAddr
24 // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
25 nextUnusedEdgeIP int
26 lastResolve time.Time
27 resolverC chan resolveResult
28 tunnelErrors chan tunnelError
29 tunnelsConnecting map[int]chan struct{}
30 // nextConnectedIndex and nextConnectedSignal are used to wait for all
31 // currently-connecting tunnels to finish connecting so we can reset backoff timer
32 nextConnectedIndex int
33 nextConnectedSignal chan struct{}
34}
35
// resolveResult carries the outcome of one edge address resolution.
type resolveResult struct {
	// edgeIPs is the list of resolved addresses.
	edgeIPs []*net.TCPAddr
	// err is the resolution error, or nil on success.
	err error
}
40
// tunnelError is the message a tunnel goroutine sends on exit.
type tunnelError struct {
	// index identifies the tunnel that exited.
	index int
	// err is the error the tunnel ended with; it may be nil.
	err error
}
45
46func NewSupervisor(config *TunnelConfig) *Supervisor {
47 return &Supervisor{
48 config: config,
49 tunnelErrors: make(chan tunnelError),
50 tunnelsConnecting: map[int]chan struct{}{},
51 }
52}
53
54func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error {
55 logger := s.config.Logger
56 if err := s.initialize(ctx, connectedSignal, u); err != nil {
57 return err
58 }
59 var tunnelsWaiting []int
60 backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
61 var backoffTimer <-chan time.Time
62 tunnelsActive := s.config.HAConnections
63
64 for {
65 select {
66 // Context cancelled
67 case <-ctx.Done():
68 for tunnelsActive > 0 {
69 <-s.tunnelErrors
70 tunnelsActive--
71 }
72 return nil
73 // startTunnel returned with error
74 // (note that this may also be caused by context cancellation)
75 case tunnelError := <-s.tunnelErrors:
76 tunnelsActive--
77 if tunnelError.err != nil {
78 logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
79 tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
80 s.waitForNextTunnel(tunnelError.index)
81
82 if backoffTimer == nil {
83 backoffTimer = backoff.BackoffTimer()
84 }
85
86 // If the error is a dial error, the problem is likely to be network related
87 // try another addr before refreshing since we are likely to get back the
88 // same IPs in the same order. Same problem with duplicate connection error.
89 if s.unusedIPs() {
90 s.replaceEdgeIP(tunnelError.index)
91 } else {
92 s.refreshEdgeIPs()
93 }
94 }
95 // Backoff was set and its timer expired
96 case <-backoffTimer:
97 backoffTimer = nil
98 for _, index := range tunnelsWaiting {
99 go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
100 }
101 tunnelsActive += len(tunnelsWaiting)
102 tunnelsWaiting = nil
103 // Tunnel successfully connected
104 case <-s.nextConnectedSignal:
105 if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
106 // No more tunnels outstanding, clear backoff timer
107 backoff.SetGracePeriod()
108 }
109 // DNS resolution returned
110 case result := <-s.resolverC:
111 s.lastResolve = time.Now()
112 s.resolverC = nil
113 if result.err == nil {
114 logger.Debug("Service discovery refresh complete")
115 s.edgeIPs = result.edgeIPs
116 } else {
117 logger.WithError(result.err).Error("Service discovery error")
118 }
119 }
120 }
121}
122
123func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) error {
124 logger := s.config.Logger
125 edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
126 if err != nil {
127 logger.Infof("ResolveEdgeIPs err")
128 return err
129 }
130 s.edgeIPs = edgeIPs
131 if s.config.HAConnections > len(edgeIPs) {
132 logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
133 s.config.HAConnections = len(edgeIPs)
134 }
135 s.lastResolve = time.Now()
136 // check entitlement and version too old error before attempting to register more tunnels
137 s.nextUnusedEdgeIP = s.config.HAConnections
138 go s.startFirstTunnel(ctx, connectedSignal, u)
139 select {
140 case <-ctx.Done():
141 <-s.tunnelErrors
142 // Error can't be nil. A nil error signals that initialization succeed
143 return fmt.Errorf("context was canceled")
144 case tunnelError := <-s.tunnelErrors:
145 return tunnelError.err
146 case <-connectedSignal:
147 }
148 // At least one successful connection, so start the rest
149 for i := 1; i < s.config.HAConnections; i++ {
150 go s.startTunnel(ctx, i, make(chan struct{}), u)
151 time.Sleep(registrationInterval)
152 }
153 return nil
154}
155
156// startTunnel starts the first tunnel connection. The resulting error will be sent on
157// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
158func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}, u uuid.UUID) {
159 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
160 defer func() {
161 s.tunnelErrors <- tunnelError{index: 0, err: err}
162 }()
163
164 for s.unusedIPs() {
165 select {
166 case <-ctx.Done():
167 return
168 default:
169 }
170 switch err.(type) {
171 case nil:
172 return
173 // try the next address if it was a dialError(network problem) or
174 // dupConnRegisterTunnelError
175 case dialError, dupConnRegisterTunnelError:
176 s.replaceEdgeIP(0)
177 default:
178 return
179 }
180 err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
181 }
182}
183
184// startTunnel starts a new tunnel connection. The resulting error will be sent on
185// s.tunnelErrors.
186func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}, u uuid.UUID) {
187 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
188 s.tunnelErrors <- tunnelError{index: index, err: err}
189}
190
191func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
192 signal := make(chan struct{})
193 s.tunnelsConnecting[index] = signal
194 s.nextConnectedSignal = signal
195 s.nextConnectedIndex = index
196 return signal
197}
198
199func (s *Supervisor) waitForNextTunnel(index int) bool {
200 delete(s.tunnelsConnecting, index)
201 s.nextConnectedSignal = nil
202 for k, v := range s.tunnelsConnecting {
203 s.nextConnectedIndex = k
204 s.nextConnectedSignal = v
205 return true
206 }
207 return false
208}
209
210func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
211 return s.edgeIPs[index%len(s.edgeIPs)]
212}
213
214func (s *Supervisor) refreshEdgeIPs() {
215 if s.resolverC != nil {
216 return
217 }
218 if time.Since(s.lastResolve) < resolveTTL {
219 return
220 }
221 s.resolverC = make(chan resolveResult)
222 go func() {
223 edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
224 s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
225 }()
226}
227
228func (s *Supervisor) unusedIPs() bool {
229 return s.nextUnusedEdgeIP < len(s.edgeIPs)
230}
231
232func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
233 s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
234 s.nextUnusedEdgeIP++
235}
236