cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
origin/supervisor.go
238 lines · mode: code
| 1 | package origin |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "fmt" |
| 6 | "net" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/cloudflare/cloudflared/signal" |
| 10 | |
| 11 | "github.com/google/uuid" |
| 12 | ) |
| 13 | |
// Timing parameters used by the Supervisor.
const (
	// tunnelRetryDuration is the base delay handed to the BackoffHandler that
	// paces restarts of failed tunnel connections.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is how old the last edge address (SRV) resolution must be
	// before a refresh is actually started.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between starting successive HA tunnels
	// during initialization.
	registrationInterval = 1 * time.Second
)
| 22 | |
// Supervisor starts the HA tunnel connections, restarts failed ones with
// backoff, and moves failing tunnels to spare or re-resolved edge IPs.
type Supervisor struct {
	config *TunnelConfig
	// edgeIPs holds the resolved edge addresses; tunnel i uses
	// edgeIPs[i % len(edgeIPs)] (see getEdgeIP).
	edgeIPs []*net.TCPAddr
	// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
	// (the first spare not yet assigned to a tunnel).
	nextUnusedEdgeIP int
	// lastResolve is when the most recent edge address resolution completed,
	// whether or not it succeeded.
	lastResolve time.Time
	// resolverC is non-nil while a background re-resolution is in flight.
	resolverC chan resolveResult
	// tunnelErrors receives one tunnelError from each tunnel goroutine when it exits.
	tunnelErrors chan tunnelError
	// tunnelsConnecting maps a tunnel index to the connected-signal channel of a
	// tunnel restarted by Run that has not yet been seen to connect or fail.
	tunnelsConnecting map[int]chan struct{}
	// nextConnectedIndex and nextConnectedSignal are used to wait for all
	// currently-connecting tunnels to finish connecting so we can reset backoff timer
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}
}
| 37 | |
// resolveResult carries the outcome of a background edge address resolution
// from refreshEdgeIPs back to Run.
type resolveResult struct {
	edgeIPs []*net.TCPAddr
	err     error
}

// tunnelError is sent on Supervisor.tunnelErrors when a tunnel goroutine exits.
// err may be nil.
type tunnelError struct {
	// index identifies which tunnel (and edge IP slot) exited.
	index int
	err   error
}
| 47 | |
| 48 | func NewSupervisor(config *TunnelConfig) *Supervisor { |
| 49 | return &Supervisor{ |
| 50 | config: config, |
| 51 | tunnelErrors: make(chan tunnelError), |
| 52 | tunnelsConnecting: map[int]chan struct{}{}, |
| 53 | } |
| 54 | } |
| 55 | |
// Run starts the tunnels and supervises them until ctx is cancelled.
//
// It first calls initialize and returns its error, if any. It then loops,
// multiplexing over:
//   - ctx cancellation: waits for every active tunnel to report on
//     s.tunnelErrors, then returns nil;
//   - a tunnel exiting: on a non-nil error the tunnel is queued for restart,
//     the backoff timer is armed if it is not already, and the tunnel's edge IP
//     slot is given a spare IP (or a re-resolution is requested when none is left);
//   - the backoff timer firing: all queued tunnels are restarted;
//   - the tracked connecting tunnel signalling that it connected;
//   - a background edge re-resolution completing.
//
// Receiving from a nil channel blocks forever, so while backoffTimer,
// s.nextConnectedSignal or s.resolverC is nil its select case is disabled.
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
	logger := s.config.Logger
	if err := s.initialize(ctx, connectedSignal, u); err != nil {
		return err
	}
	// Indices of tunnels that failed and are waiting for the backoff timer.
	var tunnelsWaiting []int
	backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
	var backoffTimer <-chan time.Time
	// Number of tunnel goroutines that have not yet reported on s.tunnelErrors;
	// initialize started s.config.HAConnections of them.
	tunnelsActive := s.config.HAConnections

	for {
		select {
		// Context cancelled
		case <-ctx.Done():
			// Receive one report per running tunnel so none stays blocked on its send.
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// A tunnel goroutine (startTunnel or startFirstTunnel) returned
		// (note that this may also be caused by context cancellation)
		case tunnelError := <-s.tunnelErrors:
			tunnelsActive--
			if tunnelError.err != nil {
				logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
				// Stop tracking this tunnel as connecting; the bool result is not needed here.
				s.waitForNextTunnel(tunnelError.index)

				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}

				// Regardless of the error type, first move this tunnel to a spare
				// edge IP: a dial or duplicate-connection error is likely tied to the
				// current addr, and refreshing would likely return the same IPs in
				// the same order. Only when no spare is left is a refresh requested
				// (refreshEdgeIPs itself skips it while the last resolve is fresh).
				if s.unusedIPs() {
					s.replaceEdgeIP(tunnelError.index)
				} else {
					s.refreshEdgeIPs()
				}
			}
		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		// DNS resolution returned
		case result := <-s.resolverC:
			s.lastResolve = time.Now()
			// Re-enable refreshEdgeIPs and disable this select case.
			s.resolverC = nil
			if result.err == nil {
				logger.Debug("Service discovery refresh complete")
				s.edgeIPs = result.edgeIPs
			} else {
				// Keep using the previous edgeIPs on failure.
				logger.WithError(result.err).Error("Service discovery error")
			}
		}
	}
}
| 124 | |
| 125 | func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { |
| 126 | logger := s.config.Logger |
| 127 | edgeIPs, err := ResolveEdgeIPs(logger, s.config.EdgeAddrs) |
| 128 | if err != nil { |
| 129 | logger.Infof("ResolveEdgeIPs err") |
| 130 | return err |
| 131 | } |
| 132 | s.edgeIPs = edgeIPs |
| 133 | if s.config.HAConnections > len(edgeIPs) { |
| 134 | logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) |
| 135 | s.config.HAConnections = len(edgeIPs) |
| 136 | } |
| 137 | s.lastResolve = time.Now() |
| 138 | // check entitlement and version too old error before attempting to register more tunnels |
| 139 | s.nextUnusedEdgeIP = s.config.HAConnections |
| 140 | go s.startFirstTunnel(ctx, connectedSignal, u) |
| 141 | select { |
| 142 | case <-ctx.Done(): |
| 143 | <-s.tunnelErrors |
| 144 | // Error can't be nil. A nil error signals that initialization succeed |
| 145 | return fmt.Errorf("context was canceled") |
| 146 | case tunnelError := <-s.tunnelErrors: |
| 147 | return tunnelError.err |
| 148 | case <-connectedSignal.Wait(): |
| 149 | } |
| 150 | // At least one successful connection, so start the rest |
| 151 | for i := 1; i < s.config.HAConnections; i++ { |
| 152 | ch := signal.New(make(chan struct{})) |
| 153 | go s.startTunnel(ctx, i, ch, u) |
| 154 | time.Sleep(registrationInterval) |
| 155 | } |
| 156 | return nil |
| 157 | } |
| 158 | |
| 159 | // startTunnel starts the first tunnel connection. The resulting error will be sent on |
| 160 | // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed |
| 161 | func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { |
| 162 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 163 | defer func() { |
| 164 | s.tunnelErrors <- tunnelError{index: 0, err: err} |
| 165 | }() |
| 166 | |
| 167 | for s.unusedIPs() { |
| 168 | select { |
| 169 | case <-ctx.Done(): |
| 170 | return |
| 171 | default: |
| 172 | } |
| 173 | switch err.(type) { |
| 174 | case nil: |
| 175 | return |
| 176 | // try the next address if it was a dialError(network problem) or |
| 177 | // dupConnRegisterTunnelError |
| 178 | case dialError, dupConnRegisterTunnelError: |
| 179 | s.replaceEdgeIP(0) |
| 180 | default: |
| 181 | return |
| 182 | } |
| 183 | err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 184 | } |
| 185 | } |
| 186 | |
| 187 | // startTunnel starts a new tunnel connection. The resulting error will be sent on |
| 188 | // s.tunnelErrors. |
| 189 | func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { |
| 190 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) |
| 191 | s.tunnelErrors <- tunnelError{index: index, err: err} |
| 192 | } |
| 193 | |
| 194 | func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { |
| 195 | sig := make(chan struct{}) |
| 196 | s.tunnelsConnecting[index] = sig |
| 197 | s.nextConnectedSignal = sig |
| 198 | s.nextConnectedIndex = index |
| 199 | return signal.New(sig) |
| 200 | } |
| 201 | |
| 202 | func (s *Supervisor) waitForNextTunnel(index int) bool { |
| 203 | delete(s.tunnelsConnecting, index) |
| 204 | s.nextConnectedSignal = nil |
| 205 | for k, v := range s.tunnelsConnecting { |
| 206 | s.nextConnectedIndex = k |
| 207 | s.nextConnectedSignal = v |
| 208 | return true |
| 209 | } |
| 210 | return false |
| 211 | } |
| 212 | |
| 213 | func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr { |
| 214 | return s.edgeIPs[index%len(s.edgeIPs)] |
| 215 | } |
| 216 | |
| 217 | func (s *Supervisor) refreshEdgeIPs() { |
| 218 | if s.resolverC != nil { |
| 219 | return |
| 220 | } |
| 221 | if time.Since(s.lastResolve) < resolveTTL { |
| 222 | return |
| 223 | } |
| 224 | s.resolverC = make(chan resolveResult) |
| 225 | go func() { |
| 226 | edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs) |
| 227 | s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err} |
| 228 | }() |
| 229 | } |
| 230 | |
| 231 | func (s *Supervisor) unusedIPs() bool { |
| 232 | return s.nextUnusedEdgeIP < len(s.edgeIPs) |
| 233 | } |
| 234 | |
| 235 | func (s *Supervisor) replaceEdgeIP(badIPIndex int) { |
| 236 | s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] |
| 237 | s.nextUnusedEdgeIP++ |
| 238 | } |
| 239 | |