cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2018.10.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

origin/supervisor.go

234 lines · mode: code

1package origin
2
3import (
4 "fmt"
5 "net"
6 "time"
7
8 "golang.org/x/net/context"
9)
10
// Timing parameters used by the Supervisor.
const (
	// tunnelRetryDuration is the base wait (BackoffHandler.BaseTime) before
	// restarting a tunnel connection that failed.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is the minimum age of the last edge (SRV) resolution before
	// refreshEdgeIPs will start a new lookup.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between starting successive tunnels
	// during initialization.
	registrationInterval = 1 * time.Second
)
19
// Supervisor starts the configured number of HA tunnel connections, restarts
// failed ones after a backoff, moves failed tunnels onto unused edge
// addresses, and re-resolves the edge addresses (at most once per resolveTTL)
// when no unused address remains.
type Supervisor struct {
	// config holds the tunnel settings; initialize may lower HAConnections
	// to the number of resolved edge addresses.
	config              *TunnelConfig
	// edgeIPs are the resolved edge addresses. Tunnel i uses
	// edgeIPs[i%len(edgeIPs)]; entries are overwritten in place by
	// replaceEdgeIP and the whole slice is swapped after a refresh.
	edgeIPs             []*net.TCPAddr
	// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
	nextUnusedEdgeIP    int
	// lastResolve is when the most recent edge address resolution completed.
	lastResolve         time.Time
	// resolverC is non-nil while an asynchronous resolution started by
	// refreshEdgeIPs is in flight; Run receives its result and resets it to nil.
	resolverC           chan resolveResult
	// tunnelErrors receives one report from every tunnel goroutine when it
	// stops (the reported error may be nil).
	tunnelErrors        chan tunnelError
	// tunnelsConnecting maps the index of each tunnel restarted by Run to the
	// signal channel it was started with, until Run stops waiting on it.
	tunnelsConnecting   map[int]chan struct{}
	// nextConnectedIndex and nextConnectedSignal are used to wait for all
	// currently-connecting tunnels to finish connecting so we can reset backoff timer
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}
}
34
// resolveResult carries the outcome of the asynchronous edge address
// resolution started by refreshEdgeIPs back to Run via Supervisor.resolverC.
type resolveResult struct {
	// edgeIPs are the freshly resolved addresses; Run only uses them when err is nil.
	edgeIPs []*net.TCPAddr
	err     error
}
39
// tunnelError is sent on Supervisor.tunnelErrors when the tunnel with the
// given index stops.
type tunnelError struct {
	// index identifies the tunnel (0 .. HAConnections-1).
	index int
	// err is the error the tunnel stopped with; nil if it returned cleanly.
	err   error
}
44
45func NewSupervisor(config *TunnelConfig) *Supervisor {
46 return &Supervisor{
47 config: config,
48 tunnelErrors: make(chan tunnelError),
49 tunnelsConnecting: map[int]chan struct{}{},
50 }
51}
52
// Run establishes the configured number of HA tunnels and keeps restarting
// failed ones until ctx is cancelled. connectedSignal is handed to the first
// tunnel; initialize waits on it before starting the remaining tunnels.
//
// Run returns any error from initialize. Otherwise it returns nil once ctx is
// done and every still-active tunnel has reported on s.tunnelErrors.
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
	logger := s.config.Logger
	if err := s.initialize(ctx, connectedSignal); err != nil {
		return err
	}
	// Indexes of failed tunnels to restart when backoffTimer fires.
	var tunnelsWaiting []int
	backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
	// nil while no backoff is pending; receiving from a nil channel never
	// proceeds, so that select case is effectively disabled.
	var backoffTimer <-chan time.Time
	// Number of tunnel goroutines that have not yet reported on s.tunnelErrors;
	// initialize started exactly s.config.HAConnections of them.
	tunnelsActive := s.config.HAConnections

	for {
		select {
		// Context cancelled
		case <-ctx.Done():
			// Drain one report per running tunnel so none of them stays
			// blocked sending on s.tunnelErrors.
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// startTunnel returned with error
		// (note that this may also be caused by context cancellation)
		// A nil err means the tunnel stopped cleanly; it is not restarted.
		// NOTE: the variable name shadows the tunnelError type inside this case.
		case tunnelError := <-s.tunnelErrors:
			tunnelsActive--
			if tunnelError.err != nil {
				logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
				// Stop waiting for this tunnel's connection signal (if it had one).
				s.waitForNextTunnel(tunnelError.index)

				// Arm the backoff only once; later failures join the same batch.
				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}

				// If the error is a dial error, the problem is likely to be network related
				// try another addr before refreshing since we are likely to get back the
				// same IPs in the same order. Same problem with duplicate connection error.
				// NOTE: no error type is checked here; any error moves this tunnel's
				// slot to an unused address while one remains, else a refresh is requested.
				if s.unusedIPs() {
					s.replaceEdgeIP(tunnelError.index)
				} else {
					s.refreshEdgeIPs()
				}
			}
		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			// Restart every waiting tunnel, each with a fresh tracked connection signal.
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		// (s.nextConnectedSignal is nil when no restarted tunnel is pending).
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		// DNS resolution returned
		// (s.resolverC is nil unless refreshEdgeIPs started a lookup).
		case result := <-s.resolverC:
			s.lastResolve = time.Now()
			s.resolverC = nil
			if result.err == nil {
				logger.Debug("Service discovery refresh complete")
				s.edgeIPs = result.edgeIPs
			} else {
				logger.WithError(result.err).Error("Service discovery error")
			}
		}
	}
}
121
122func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
123 logger := s.config.Logger
124 edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
125 if err != nil {
126 logger.Infof("ResolveEdgeIPs err")
127 return err
128 }
129 s.edgeIPs = edgeIPs
130 if s.config.HAConnections > len(edgeIPs) {
131 logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
132 s.config.HAConnections = len(edgeIPs)
133 }
134 s.lastResolve = time.Now()
135 // check entitlement and version too old error before attempting to register more tunnels
136 s.nextUnusedEdgeIP = s.config.HAConnections
137 go s.startFirstTunnel(ctx, connectedSignal)
138 select {
139 case <-ctx.Done():
140 <-s.tunnelErrors
141 // Error can't be nil. A nil error signals that initialization succeed
142 return fmt.Errorf("context was canceled")
143 case tunnelError := <-s.tunnelErrors:
144 return tunnelError.err
145 case <-connectedSignal:
146 }
147 // At least one successful connection, so start the rest
148 for i := 1; i < s.config.HAConnections; i++ {
149 go s.startTunnel(ctx, i, make(chan struct{}))
150 time.Sleep(registrationInterval)
151 }
152 return nil
153}
154
155// startTunnel starts the first tunnel connection. The resulting error will be sent on
156// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
157func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) {
158 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
159 defer func() {
160 s.tunnelErrors <- tunnelError{index: 0, err: err}
161 }()
162
163 for s.unusedIPs() {
164 select {
165 case <-ctx.Done():
166 return
167 default:
168 }
169 switch err.(type) {
170 case nil:
171 return
172 // try the next address if it was a dialError(network problem) or
173 // dupConnRegisterTunnelError
174 case dialError, dupConnRegisterTunnelError:
175 s.replaceEdgeIP(0)
176 default:
177 return
178 }
179 err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
180 }
181}
182
183// startTunnel starts a new tunnel connection. The resulting error will be sent on
184// s.tunnelErrors.
185func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
186 err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
187 s.tunnelErrors <- tunnelError{index: index, err: err}
188}
189
190func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
191 signal := make(chan struct{})
192 s.tunnelsConnecting[index] = signal
193 s.nextConnectedSignal = signal
194 s.nextConnectedIndex = index
195 return signal
196}
197
198func (s *Supervisor) waitForNextTunnel(index int) bool {
199 delete(s.tunnelsConnecting, index)
200 s.nextConnectedSignal = nil
201 for k, v := range s.tunnelsConnecting {
202 s.nextConnectedIndex = k
203 s.nextConnectedSignal = v
204 return true
205 }
206 return false
207}
208
209func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
210 return s.edgeIPs[index%len(s.edgeIPs)]
211}
212
213func (s *Supervisor) refreshEdgeIPs() {
214 if s.resolverC != nil {
215 return
216 }
217 if time.Since(s.lastResolve) < resolveTTL {
218 return
219 }
220 s.resolverC = make(chan resolveResult)
221 go func() {
222 edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
223 s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
224 }()
225}
226
227func (s *Supervisor) unusedIPs() bool {
228 return s.nextUnusedEdgeIP < len(s.edgeIPs)
229}
230
231func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
232 s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
233 s.nextUnusedEdgeIP++
234}
235