cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
origin/supervisor.go
255 lines · mode: code
| 1 | package origin |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "fmt" |
| 6 | "net" |
| 7 | "time" |
| 8 | |
| 9 | "github.com/sirupsen/logrus" |
| 10 | |
| 11 | "github.com/cloudflare/cloudflared/connection" |
| 12 | "github.com/cloudflare/cloudflared/signal" |
| 13 | |
| 14 | "github.com/google/uuid" |
| 15 | ) |
| 16 | |
const (
	// tunnelRetryDuration is the base wait before retrying a failed tunnel
	// connection; it seeds the supervisor's backoff handler.
	tunnelRetryDuration = 10 * time.Second
	// resolveTTL is the minimum age of the last edge address resolution before
	// another one may be started.
	resolveTTL = 1 * time.Hour
	// registrationInterval is the pause between starting consecutive tunnels
	// during initialization.
	registrationInterval = 1 * time.Second
)
| 25 | |
| 26 | type Supervisor struct { |
| 27 | config *TunnelConfig |
| 28 | edgeIPs []*net.TCPAddr |
| 29 | // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try |
| 30 | nextUnusedEdgeIP int |
| 31 | lastResolve time.Time |
| 32 | resolverC chan resolveResult |
| 33 | tunnelErrors chan tunnelError |
| 34 | tunnelsConnecting map[int]chan struct{} |
| 35 | // nextConnectedIndex and nextConnectedSignal are used to wait for all |
| 36 | // currently-connecting tunnels to finish connecting so we can reset backoff timer |
| 37 | nextConnectedIndex int |
| 38 | nextConnectedSignal chan struct{} |
| 39 | |
| 40 | logger *logrus.Entry |
| 41 | } |
| 42 | |
type (
	// resolveResult carries the outcome of an asynchronous edge address lookup.
	resolveResult struct {
		edgeIPs []*net.TCPAddr
		err     error
	}

	// tunnelError reports that the tunnel with the given index has returned,
	// together with the error it returned (which may be nil).
	tunnelError struct {
		index int
		err   error
	}
)
| 52 | |
| 53 | func NewSupervisor(config *TunnelConfig) *Supervisor { |
| 54 | return &Supervisor{ |
| 55 | config: config, |
| 56 | tunnelErrors: make(chan tunnelError), |
| 57 | tunnelsConnecting: map[int]chan struct{}{}, |
| 58 | logger: config.Logger.WithField("subsystem", "supervisor"), |
| 59 | } |
| 60 | } |
| 61 | |
| 62 | func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { |
| 63 | logger := s.config.Logger |
| 64 | if err := s.initialize(ctx, connectedSignal, u); err != nil { |
| 65 | return err |
| 66 | } |
| 67 | var tunnelsWaiting []int |
| 68 | backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} |
| 69 | var backoffTimer <-chan time.Time |
| 70 | tunnelsActive := s.config.HAConnections |
| 71 | |
| 72 | for { |
| 73 | select { |
| 74 | // Context cancelled |
| 75 | case <-ctx.Done(): |
| 76 | for tunnelsActive > 0 { |
| 77 | <-s.tunnelErrors |
| 78 | tunnelsActive-- |
| 79 | } |
| 80 | return nil |
| 81 | // startTunnel returned with error |
| 82 | // (note that this may also be caused by context cancellation) |
| 83 | case tunnelError := <-s.tunnelErrors: |
| 84 | tunnelsActive-- |
| 85 | if tunnelError.err != nil { |
| 86 | logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") |
| 87 | tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) |
| 88 | s.waitForNextTunnel(tunnelError.index) |
| 89 | |
| 90 | if backoffTimer == nil { |
| 91 | backoffTimer = backoff.BackoffTimer() |
| 92 | } |
| 93 | |
| 94 | // If the error is a dial error, the problem is likely to be network related |
| 95 | // try another addr before refreshing since we are likely to get back the |
| 96 | // same IPs in the same order. Same problem with duplicate connection error. |
| 97 | if s.unusedIPs() { |
| 98 | s.replaceEdgeIP(tunnelError.index) |
| 99 | } else { |
| 100 | s.refreshEdgeIPs() |
| 101 | } |
| 102 | } |
| 103 | // Backoff was set and its timer expired |
| 104 | case <-backoffTimer: |
| 105 | backoffTimer = nil |
| 106 | for _, index := range tunnelsWaiting { |
| 107 | go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u) |
| 108 | } |
| 109 | tunnelsActive += len(tunnelsWaiting) |
| 110 | tunnelsWaiting = nil |
| 111 | // Tunnel successfully connected |
| 112 | case <-s.nextConnectedSignal: |
| 113 | if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { |
| 114 | // No more tunnels outstanding, clear backoff timer |
| 115 | backoff.SetGracePeriod() |
| 116 | } |
| 117 | // DNS resolution returned |
| 118 | case result := <-s.resolverC: |
| 119 | s.lastResolve = time.Now() |
| 120 | s.resolverC = nil |
| 121 | if result.err == nil { |
| 122 | logger.Debug("Service discovery refresh complete") |
| 123 | s.edgeIPs = result.edgeIPs |
| 124 | } else { |
| 125 | logger.WithError(result.err).Error("Service discovery error") |
| 126 | } |
| 127 | } |
| 128 | } |
| 129 | } |
| 130 | |
| 131 | func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { |
| 132 | logger := s.logger |
| 133 | |
| 134 | edgeIPs, err := s.resolveEdgeIPs() |
| 135 | |
| 136 | if err != nil { |
| 137 | logger.Infof("ResolveEdgeIPs err") |
| 138 | return err |
| 139 | } |
| 140 | s.edgeIPs = edgeIPs |
| 141 | if s.config.HAConnections > len(edgeIPs) { |
| 142 | logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) |
| 143 | s.config.HAConnections = len(edgeIPs) |
| 144 | } |
| 145 | s.lastResolve = time.Now() |
| 146 | // check entitlement and version too old error before attempting to register more tunnels |
| 147 | s.nextUnusedEdgeIP = s.config.HAConnections |
| 148 | go s.startFirstTunnel(ctx, connectedSignal, u) |
| 149 | select { |
| 150 | case <-ctx.Done(): |
| 151 | <-s.tunnelErrors |
| 152 | // Error can't be nil. A nil error signals that initialization succeed |
| 153 | return fmt.Errorf("context was canceled") |
| 154 | case tunnelError := <-s.tunnelErrors: |
| 155 | return tunnelError.err |
| 156 | case <-connectedSignal.Wait(): |
| 157 | } |
| 158 | // At least one successful connection, so start the rest |
| 159 | for i := 1; i < s.config.HAConnections; i++ { |
| 160 | ch := signal.New(make(chan struct{})) |
| 161 | go s.startTunnel(ctx, i, ch, u) |
| 162 | time.Sleep(registrationInterval) |
| 163 | } |
| 164 | return nil |
| 165 | } |
| 166 | |
| 167 | // startTunnel starts the first tunnel connection. The resulting error will be sent on |
| 168 | // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed |
| 169 | func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { |
| 170 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 171 | defer func() { |
| 172 | s.tunnelErrors <- tunnelError{index: 0, err: err} |
| 173 | }() |
| 174 | |
| 175 | for s.unusedIPs() { |
| 176 | select { |
| 177 | case <-ctx.Done(): |
| 178 | return |
| 179 | default: |
| 180 | } |
| 181 | switch err.(type) { |
| 182 | case nil: |
| 183 | return |
| 184 | // try the next address if it was a dialError(network problem) or |
| 185 | // dupConnRegisterTunnelError |
| 186 | case dialError, dupConnRegisterTunnelError: |
| 187 | s.replaceEdgeIP(0) |
| 188 | default: |
| 189 | return |
| 190 | } |
| 191 | err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) |
| 192 | } |
| 193 | } |
| 194 | |
| 195 | // startTunnel starts a new tunnel connection. The resulting error will be sent on |
| 196 | // s.tunnelErrors. |
| 197 | func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { |
| 198 | err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) |
| 199 | s.tunnelErrors <- tunnelError{index: index, err: err} |
| 200 | } |
| 201 | |
| 202 | func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { |
| 203 | sig := make(chan struct{}) |
| 204 | s.tunnelsConnecting[index] = sig |
| 205 | s.nextConnectedSignal = sig |
| 206 | s.nextConnectedIndex = index |
| 207 | return signal.New(sig) |
| 208 | } |
| 209 | |
| 210 | func (s *Supervisor) waitForNextTunnel(index int) bool { |
| 211 | delete(s.tunnelsConnecting, index) |
| 212 | s.nextConnectedSignal = nil |
| 213 | for k, v := range s.tunnelsConnecting { |
| 214 | s.nextConnectedIndex = k |
| 215 | s.nextConnectedSignal = v |
| 216 | return true |
| 217 | } |
| 218 | return false |
| 219 | } |
| 220 | |
| 221 | func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr { |
| 222 | return s.edgeIPs[index%len(s.edgeIPs)] |
| 223 | } |
| 224 | |
| 225 | func (s *Supervisor) resolveEdgeIPs() ([]*net.TCPAddr, error) { |
| 226 | // If --edge is specfied, resolve edge server addresses |
| 227 | if len(s.config.EdgeAddrs) > 0 { |
| 228 | return connection.ResolveAddrs(s.config.EdgeAddrs) |
| 229 | } |
| 230 | // Otherwise lookup edge server addresses through service discovery |
| 231 | return connection.EdgeDiscovery(s.logger) |
| 232 | } |
| 233 | |
| 234 | func (s *Supervisor) refreshEdgeIPs() { |
| 235 | if s.resolverC != nil { |
| 236 | return |
| 237 | } |
| 238 | if time.Since(s.lastResolve) < resolveTTL { |
| 239 | return |
| 240 | } |
| 241 | s.resolverC = make(chan resolveResult) |
| 242 | go func() { |
| 243 | edgeIPs, err := s.resolveEdgeIPs() |
| 244 | s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err} |
| 245 | }() |
| 246 | } |
| 247 | |
| 248 | func (s *Supervisor) unusedIPs() bool { |
| 249 | return s.nextUnusedEdgeIP < len(s.edgeIPs) |
| 250 | } |
| 251 | |
| 252 | func (s *Supervisor) replaceEdgeIP(badIPIndex int) { |
| 253 | s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] |
| 254 | s.nextUnusedEdgeIP++ |
| 255 | } |
| 256 | |