cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
origin/supervisor.go
234 lines · mode: code
| 1 | package origin |
| 2 | |
| 3 | import ( |
| 4 | "fmt" |
| 5 | "net" |
| 6 | "time" |
| 7 | |
| 8 | "golang.org/x/net/context" |
| 9 | ) |
| 10 | |
const (
	// tunnelRetryDuration is the base wait before a failed tunnel connection
	// is retried (used as the backoff base time).
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is how long an SRV record resolution is considered fresh;
	// edge addresses are not re-resolved more often than this.
	resolveTTL = 60 * time.Minute
	// registrationInterval is the pause between registering successive
	// tunnels.
	registrationInterval = 1 * time.Second
)
| 19 | |
| 20 | type Supervisor struct { |
| 21 | config *TunnelConfig |
| 22 | edgeIPs []*net.TCPAddr |
| 23 | // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try |
| 24 | nextUnusedEdgeIP int |
| 25 | lastResolve time.Time |
| 26 | resolverC chan resolveResult |
| 27 | tunnelErrors chan tunnelError |
| 28 | tunnelsConnecting map[int]chan struct{} |
| 29 | // nextConnectedIndex and nextConnectedSignal are used to wait for all |
| 30 | // currently-connecting tunnels to finish connecting so we can reset backoff timer |
| 31 | nextConnectedIndex int |
| 32 | nextConnectedSignal chan struct{} |
| 33 | } |
| 34 | |
type (
	// resolveResult carries the outcome of a background edge address
	// resolution back to the supervisor loop.
	resolveResult struct {
		edgeIPs []*net.TCPAddr
		err     error
	}

	// tunnelError is what a tunnel goroutine reports when it returns: which
	// tunnel it was and the error it ended with (possibly nil).
	tunnelError struct {
		index int
		err   error
	}
)
| 44 | |
| 45 | func NewSupervisor(config *TunnelConfig) *Supervisor { |
| 46 | return &Supervisor{ |
| 47 | config: config, |
| 48 | tunnelErrors: make(chan tunnelError), |
| 49 | tunnelsConnecting: map[int]chan struct{}{}, |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error { |
| 54 | logger := s.config.Logger |
| 55 | if err := s.initialize(ctx, connectedSignal); err != nil { |
| 56 | return err |
| 57 | } |
| 58 | var tunnelsWaiting []int |
| 59 | backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} |
| 60 | var backoffTimer <-chan time.Time |
| 61 | tunnelsActive := s.config.HAConnections |
| 62 | |
| 63 | for { |
| 64 | select { |
| 65 | // Context cancelled |
| 66 | case <-ctx.Done(): |
| 67 | for tunnelsActive > 0 { |
| 68 | <-s.tunnelErrors |
| 69 | tunnelsActive-- |
| 70 | } |
| 71 | return nil |
| 72 | // startTunnel returned with error |
| 73 | // (note that this may also be caused by context cancellation) |
| 74 | case tunnelError := <-s.tunnelErrors: |
| 75 | tunnelsActive-- |
| 76 | if tunnelError.err != nil { |
| 77 | logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") |
| 78 | tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) |
| 79 | s.waitForNextTunnel(tunnelError.index) |
| 80 | |
| 81 | if backoffTimer == nil { |
| 82 | backoffTimer = backoff.BackoffTimer() |
| 83 | } |
| 84 | |
| 85 | // If the error is a dial error, the problem is likely to be network related |
| 86 | // try another addr before refreshing since we are likely to get back the |
| 87 | // same IPs in the same order. Same problem with duplicate connection error. |
| 88 | if s.unusedIPs() { |
| 89 | s.replaceEdgeIP(tunnelError.index) |
| 90 | } else { |
| 91 | s.refreshEdgeIPs() |
| 92 | } |
| 93 | } |
| 94 | // Backoff was set and its timer expired |
| 95 | case <-backoffTimer: |
| 96 | backoffTimer = nil |
| 97 | for _, index := range tunnelsWaiting { |
| 98 | go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) |
| 99 | } |
| 100 | tunnelsActive += len(tunnelsWaiting) |
| 101 | tunnelsWaiting = nil |
| 102 | // Tunnel successfully connected |
| 103 | case <-s.nextConnectedSignal: |
| 104 | if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { |
| 105 | // No more tunnels outstanding, clear backoff timer |
| 106 | backoff.SetGracePeriod() |
| 107 | } |
| 108 | // DNS resolution returned |
| 109 | case result := <-s.resolverC: |
| 110 | s.lastResolve = time.Now() |
| 111 | s.resolverC = nil |
| 112 | if result.err == nil { |
| 113 | logger.Debug("Service discovery refresh complete") |
| 114 | s.edgeIPs = result.edgeIPs |
| 115 | } else { |
| 116 | logger.WithError(result.err).Error("Service discovery error") |
| 117 | } |
| 118 | } |
| 119 | } |
| 120 | } |
| 121 | |
| 122 | func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { |
| 123 | logger := s.config.Logger |
| 124 | edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) |
| 125 | if err != nil { |
| 126 | logger.Infof("ResolveEdgeIPs err") |
| 127 | return err |
| 128 | } |
| 129 | s.edgeIPs = edgeIPs |
| 130 | if s.config.HAConnections > len(edgeIPs) { |
| 131 | logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) |
| 132 | s.config.HAConnections = len(edgeIPs) |
| 133 | } |
| 134 | s.lastResolve = time.Now() |
| 135 | // check entitlement and version too old error before attempting to register more tunnels |
| 136 | s.nextUnusedEdgeIP = s.config.HAConnections |
| 137 | go s.startFirstTunnel(ctx, connectedSignal) |
| 138 | select { |
| 139 | case <-ctx.Done(): |
| 140 | <-s.tunnelErrors |
| 141 | // Error can't be nil. A nil error signals that initialization succeed |
| 142 | return fmt.Errorf("context was canceled") |
| 143 | case tunnelError := <-s.tunnelErrors: |
| 144 | return tunnelError.err |
| 145 | case <-connectedSignal: |
| 146 | } |
| 147 | // At least one successful connection, so start the rest |
| 148 | for i := 1; i < s.config.HAConnections; i++ { |
| 149 | go s.startTunnel(ctx, i, make(chan struct{})) |
| 150 | time.Sleep(registrationInterval) |
| 151 | } |
| 152 | return nil |
| 153 | } |
| 154 | |
| 155 | // startTunnel starts the first tunnel connection. The resulting error will be sent on |
| 156 | // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed |
| 157 | func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) { |
| 158 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal) |
| 159 | defer func() { |
| 160 | s.tunnelErrors <- tunnelError{index: 0, err: err} |
| 161 | }() |
| 162 | |
| 163 | for s.unusedIPs() { |
| 164 | select { |
| 165 | case <-ctx.Done(): |
| 166 | return |
| 167 | default: |
| 168 | } |
| 169 | switch err.(type) { |
| 170 | case nil: |
| 171 | return |
| 172 | // try the next address if it was a dialError(network problem) or |
| 173 | // dupConnRegisterTunnelError |
| 174 | case dialError, dupConnRegisterTunnelError: |
| 175 | s.replaceEdgeIP(0) |
| 176 | default: |
| 177 | return |
| 178 | } |
| 179 | err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal) |
| 180 | } |
| 181 | } |
| 182 | |
| 183 | // startTunnel starts a new tunnel connection. The resulting error will be sent on |
| 184 | // s.tunnelErrors. |
| 185 | func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) { |
| 186 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal) |
| 187 | s.tunnelErrors <- tunnelError{index: index, err: err} |
| 188 | } |
| 189 | |
| 190 | func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} { |
| 191 | signal := make(chan struct{}) |
| 192 | s.tunnelsConnecting[index] = signal |
| 193 | s.nextConnectedSignal = signal |
| 194 | s.nextConnectedIndex = index |
| 195 | return signal |
| 196 | } |
| 197 | |
| 198 | func (s *Supervisor) waitForNextTunnel(index int) bool { |
| 199 | delete(s.tunnelsConnecting, index) |
| 200 | s.nextConnectedSignal = nil |
| 201 | for k, v := range s.tunnelsConnecting { |
| 202 | s.nextConnectedIndex = k |
| 203 | s.nextConnectedSignal = v |
| 204 | return true |
| 205 | } |
| 206 | return false |
| 207 | } |
| 208 | |
| 209 | func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr { |
| 210 | return s.edgeIPs[index%len(s.edgeIPs)] |
| 211 | } |
| 212 | |
| 213 | func (s *Supervisor) refreshEdgeIPs() { |
| 214 | if s.resolverC != nil { |
| 215 | return |
| 216 | } |
| 217 | if time.Since(s.lastResolve) < resolveTTL { |
| 218 | return |
| 219 | } |
| 220 | s.resolverC = make(chan resolveResult) |
| 221 | go func() { |
| 222 | edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) |
| 223 | s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err} |
| 224 | }() |
| 225 | } |
| 226 | |
| 227 | func (s *Supervisor) unusedIPs() bool { |
| 228 | return s.nextUnusedEdgeIP < len(s.edgeIPs) |
| 229 | } |
| 230 | |
| 231 | func (s *Supervisor) replaceEdgeIP(badIPIndex int) { |
| 232 | s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] |
| 233 | s.nextUnusedEdgeIP++ |
| 234 | } |