cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
origin/supervisor.go
239 lines · mode: code
| 1 | package origin |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "fmt" |
| 6 | "net" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/cloudflare/cloudflared/connection" |
| 10 | "github.com/cloudflare/cloudflared/signal" |
| 11 | |
| 12 | "github.com/google/uuid" |
| 13 | ) |
| 14 | |
// Timing parameters used by the Supervisor.
const (
	// tunnelRetryDuration is the base wait before retrying a failed tunnel connection.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is how long a resolved set of edge IPs (SRV records) is considered fresh.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between registering successive tunnels at startup.
	registrationInterval = 1 * time.Second
)
| 23 | |
| 24 | type Supervisor struct { |
| 25 | config *TunnelConfig |
| 26 | edgeIPs []*net.TCPAddr |
| 27 | // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try |
| 28 | nextUnusedEdgeIP int |
| 29 | lastResolve time.Time |
| 30 | resolverC chan resolveResult |
| 31 | tunnelErrors chan tunnelError |
| 32 | tunnelsConnecting map[int]chan struct{} |
| 33 | // nextConnectedIndex and nextConnectedSignal are used to wait for all |
| 34 | // currently-connecting tunnels to finish connecting so we can reset backoff timer |
| 35 | nextConnectedIndex int |
| 36 | nextConnectedSignal chan struct{} |
| 37 | } |
| 38 | |
type (
	// resolveResult carries the outcome of an asynchronous edge IP resolution.
	resolveResult struct {
		edgeIPs []*net.TCPAddr
		err     error
	}

	// tunnelError reports that the tunnel in slot index exited, with err being
	// the error it exited with (nil for a clean exit).
	tunnelError struct {
		index int
		err   error
	}
)
| 48 | |
| 49 | func NewSupervisor(config *TunnelConfig) *Supervisor { |
| 50 | return &Supervisor{ |
| 51 | config: config, |
| 52 | tunnelErrors: make(chan tunnelError), |
| 53 | tunnelsConnecting: map[int]chan struct{}{}, |
| 54 | } |
| 55 | } |
| 56 | |
// Run resolves the edge addresses and starts the configured HA tunnel
// connections (via initialize), then supervises them until ctx is cancelled.
//
// A tunnel that exits with an error is queued in tunnelsWaiting, and every
// queued tunnel is restarted together when the backoff timer fires. On each
// such error, the failed slot is moved to an unused edge IP if one remains;
// otherwise an edge IP refresh is requested. When no restarted tunnel is still
// connecting and none are waiting, backoff.SetGracePeriod is called.
//
// Run returns initialize's error if startup fails. After ctx is cancelled it
// waits for one result from every active tunnel goroutine, then returns nil.
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
	logger := s.config.Logger
	if err := s.initialize(ctx, connectedSignal, u); err != nil {
		return err
	}
	// Indices of tunnels that failed and are waiting for the backoff timer.
	var tunnelsWaiting []int
	backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
	// nil while no backoff is pending; receiving from a nil channel blocks
	// forever, which disables the corresponding select case.
	var backoffTimer <-chan time.Time
	// Number of tunnel goroutines that have not yet sent their result on
	// s.tunnelErrors (initialize starts HAConnections of them).
	tunnelsActive := s.config.HAConnections

	for {
		select {
		// Context cancelled: drain one result per active tunnel so none of
		// their goroutines stay blocked sending on s.tunnelErrors.
		case <-ctx.Done():
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// startTunnel returned with error
		// (note that this may also be caused by context cancellation).
		// A nil err is a clean exit and is not retried.
		case tunnelError := <-s.tunnelErrors:
			tunnelsActive--
			if tunnelError.err != nil {
				logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
				// This tunnel is no longer connecting; retarget nextConnectedSignal.
				s.waitForNextTunnel(tunnelError.index)

				// Only start a backoff if none is pending, so failures that
				// arrive meanwhile are restarted in the same batch.
				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}

				// If the error is a dial error, the problem is likely to be network related
				// try another addr before refreshing since we are likely to get back the
				// same IPs in the same order. Same problem with duplicate connection error.
				// NOTE(review): the error type is not actually checked here; any
				// error moves the slot to a spare IP while spares remain.
				if s.unusedIPs() {
					s.replaceEdgeIP(tunnelError.index)
				} else {
					s.refreshEdgeIPs()
				}
			}
		// Backoff was set and its timer expired: restart every waiting tunnel,
		// each with a fresh connected signal tracked in s.tunnelsConnecting.
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		// DNS resolution returned (s.resolverC is nil, and this case disabled,
		// while no refresh is in flight).
		case result := <-s.resolverC:
			s.lastResolve = time.Now()
			s.resolverC = nil
			if result.err == nil {
				logger.Debug("Service discovery refresh complete")
				s.edgeIPs = result.edgeIPs
			} else {
				logger.WithError(result.err).Error("Service discovery error")
			}
		}
	}
}
| 125 | |
| 126 | func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { |
| 127 | logger := s.config.Logger |
| 128 | edgeIPs, err := connection.ResolveEdgeIPs(logger, s.config.EdgeAddrs) |
| 129 | if err != nil { |
| 130 | logger.Infof("ResolveEdgeIPs err") |
| 131 | return err |
| 132 | } |
| 133 | s.edgeIPs = edgeIPs |
| 134 | if s.config.HAConnections > len(edgeIPs) { |
| 135 | logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) |
| 136 | s.config.HAConnections = len(edgeIPs) |
| 137 | } |
| 138 | s.lastResolve = time.Now() |
| 139 | // check entitlement and version too old error before attempting to register more tunnels |
| 140 | s.nextUnusedEdgeIP = s.config.HAConnections |
| 141 | go s.startFirstTunnel(ctx, connectedSignal, u) |
| 142 | select { |
| 143 | case <-ctx.Done(): |
| 144 | <-s.tunnelErrors |
| 145 | // Error can't be nil. A nil error signals that initialization succeed |
| 146 | return fmt.Errorf("context was canceled") |
| 147 | case tunnelError := <-s.tunnelErrors: |
| 148 | return tunnelError.err |
| 149 | case <-connectedSignal.Wait(): |
| 150 | } |
| 151 | // At least one successful connection, so start the rest |
| 152 | for i := 1; i < s.config.HAConnections; i++ { |
| 153 | ch := signal.New(make(chan struct{})) |
| 154 | go s.startTunnel(ctx, i, ch, u) |
| 155 | time.Sleep(registrationInterval) |
| 156 | } |
| 157 | return nil |
| 158 | } |
| 159 | |
| 160 | // startTunnel starts the first tunnel connection. The resulting error will be sent on |
| 161 | // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed |
| 162 | func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { |
| 163 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 164 | defer func() { |
| 165 | s.tunnelErrors <- tunnelError{index: 0, err: err} |
| 166 | }() |
| 167 | |
| 168 | for s.unusedIPs() { |
| 169 | select { |
| 170 | case <-ctx.Done(): |
| 171 | return |
| 172 | default: |
| 173 | } |
| 174 | switch err.(type) { |
| 175 | case nil: |
| 176 | return |
| 177 | // try the next address if it was a dialError(network problem) or |
| 178 | // dupConnRegisterTunnelError |
| 179 | case dialError, dupConnRegisterTunnelError: |
| 180 | s.replaceEdgeIP(0) |
| 181 | default: |
| 182 | return |
| 183 | } |
| 184 | err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 185 | } |
| 186 | } |
| 187 | |
| 188 | // startTunnel starts a new tunnel connection. The resulting error will be sent on |
| 189 | // s.tunnelErrors. |
| 190 | func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { |
| 191 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) |
| 192 | s.tunnelErrors <- tunnelError{index: index, err: err} |
| 193 | } |
| 194 | |
| 195 | func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { |
| 196 | sig := make(chan struct{}) |
| 197 | s.tunnelsConnecting[index] = sig |
| 198 | s.nextConnectedSignal = sig |
| 199 | s.nextConnectedIndex = index |
| 200 | return signal.New(sig) |
| 201 | } |
| 202 | |
| 203 | func (s *Supervisor) waitForNextTunnel(index int) bool { |
| 204 | delete(s.tunnelsConnecting, index) |
| 205 | s.nextConnectedSignal = nil |
| 206 | for k, v := range s.tunnelsConnecting { |
| 207 | s.nextConnectedIndex = k |
| 208 | s.nextConnectedSignal = v |
| 209 | return true |
| 210 | } |
| 211 | return false |
| 212 | } |
| 213 | |
| 214 | func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr { |
| 215 | return s.edgeIPs[index%len(s.edgeIPs)] |
| 216 | } |
| 217 | |
| 218 | func (s *Supervisor) refreshEdgeIPs() { |
| 219 | if s.resolverC != nil { |
| 220 | return |
| 221 | } |
| 222 | if time.Since(s.lastResolve) < resolveTTL { |
| 223 | return |
| 224 | } |
| 225 | s.resolverC = make(chan resolveResult) |
| 226 | go func() { |
| 227 | edgeIPs, err := connection.ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs) |
| 228 | s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err} |
| 229 | }() |
| 230 | } |
| 231 | |
| 232 | func (s *Supervisor) unusedIPs() bool { |
| 233 | return s.nextUnusedEdgeIP < len(s.edgeIPs) |
| 234 | } |
| 235 | |
| 236 | func (s *Supervisor) replaceEdgeIP(badIPIndex int) { |
| 237 | s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] |
| 238 | s.nextUnusedEdgeIP++ |
| 239 | } |
| 240 | |