cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflaredAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2021.12.4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

connection/control.go

100lines · modepreview

package connection

import (
	"context"
	"io"
	"time"

	"github.com/rs/zerolog"

	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// RPCClientFunc derives a named tunnel rpc client that can then be used to register and unregister connections.
// It receives a context, the stream the client should talk over, and a logger,
// and returns the NamedTunnelRPCClient. In this file the returned client is
// used for RegisterConnection, GracefulShutdown and Close.
type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient

// controlStream is the ControlStreamHandler implementation returned by
// NewControlStream. It holds everything needed to register one connection
// over a control stream and to unregister it again later.
type controlStream struct {
	// observer supplies the logger, receives the unregistering event, and is
	// passed to RegisterConnection.
	observer *Observer

	// connectedFuse is signalled via Connected() once registration succeeds.
	connectedFuse     ConnectedFuse
	// namedTunnelConfig is passed to RegisterConnection.
	namedTunnelConfig *NamedTunnelConfig
	// connIndex identifies this connection in the registration call, the
	// unregistering event, and log output.
	connIndex         uint8

	// newRPCClientFunc builds the RPC client used by each ServeControlStream call.
	newRPCClientFunc RPCClientFunc

	// gracefulShutdownC starts the unregistration path when a receive on it succeeds.
	gracefulShutdownC <-chan struct{}
	// gracePeriod is passed to the RPC client's GracefulShutdown.
	gracePeriod       time.Duration
	// stoppedGracefully is set to true when unregistration was triggered by
	// gracefulShutdownC; it is reported by IsStopped.
	// NOTE(review): no synchronization around this field is visible in this file.
	stoppedGracefully bool
}

// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface {
	// ServeControlStream registers the connection over rw using connOptions.
	// In the implementation in this file, a registration failure is returned
	// as the error; on success, shouldWaitForUnregister selects whether the
	// call blocks until unregistration has finished (true) or returns
	// immediately and unregisters in the background (false).
	ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error
	// IsStopped reports, in the implementation in this file, whether
	// unregistration was triggered by the graceful shutdown signal.
	IsStopped() bool
}

// NewControlStream builds a ControlStreamHandler backed by a controlStream
// populated from the given arguments.
//
// newRPCClientFunc may be nil, in which case newRegistrationRPCClient is used
// to create the RPC client.
func NewControlStream(
	observer *Observer,
	connectedFuse ConnectedFuse,
	namedTunnelConfig *NamedTunnelConfig,
	connIndex uint8,
	newRPCClientFunc RPCClientFunc,
	gracefulShutdownC <-chan struct{},
	gracePeriod time.Duration,
) ControlStreamHandler {
	stream := &controlStream{
		observer:          observer,
		connectedFuse:     connectedFuse,
		namedTunnelConfig: namedTunnelConfig,
		connIndex:         connIndex,
		newRPCClientFunc:  newRPCClientFunc,
		gracefulShutdownC: gracefulShutdownC,
		gracePeriod:       gracePeriod,
	}
	// Fall back to the default client constructor when the caller supplied none.
	if stream.newRPCClientFunc == nil {
		stream.newRPCClientFunc = newRegistrationRPCClient
	}
	return stream
}

// ServeControlStream creates an RPC client over rw and registers this
// connection through it.
//
// If registration fails, the client is closed and the registration error is
// returned. Otherwise connectedFuse is signalled and unregistration handling
// is started: when shouldWaitForUnregister is true the call blocks in
// waitForUnregister until it finishes; when false, waitForUnregister runs in
// its own goroutine and the call returns right away. In both success cases the
// returned error is nil.
func (c *controlStream) ServeControlStream(
	ctx context.Context,
	rw io.ReadWriteCloser,
	connOptions *tunnelpogs.ConnectionOptions,
	shouldWaitForUnregister bool,
) error {
	client := c.newRPCClientFunc(ctx, rw, c.observer.log)

	err := client.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer)
	if err != nil {
		// Registration failed, so nothing else will close the client.
		client.Close()
		return err
	}
	c.connectedFuse.Connected()

	if !shouldWaitForUnregister {
		// waitForUnregister takes over closing the client in the background.
		go c.waitForUnregister(ctx, client)
		return nil
	}

	c.waitForUnregister(ctx, client)
	return nil
}

// waitForUnregister blocks until either ctx is done or a receive on
// gracefulShutdownC succeeds, then sends the unregistering event for this
// connection index, asks rpcClient to shut down gracefully with the
// configured grace period, and logs the unregistration. rpcClient is always
// closed before the function returns.
//
// stoppedGracefully is set only when the wake-up came from gracefulShutdownC.
// GracefulShutdown receives the same ctx, including when that ctx is already done.
//
// NOTE(review): ServeControlStream may run this method in a separate
// goroutine, while IsStopped reads stoppedGracefully without any
// synchronization visible in this file. Confirm whether callers can race on it.
func (c *controlStream) waitForUnregister(ctx context.Context, rpcClient NamedTunnelRPCClient) {
	defer rpcClient.Close()

	select {
	case <-c.gracefulShutdownC:
		c.stoppedGracefully = true
	case <-ctx.Done():
		// Connection terminated; continue with unregistration below.
	}

	c.observer.sendUnregisteringEvent(c.connIndex)
	rpcClient.GracefulShutdown(ctx, c.gracePeriod)
	c.observer.log.Info().
		Uint8(LogFieldConnIndex, c.connIndex).
		Msg("Unregistered tunnel connection")
}

// IsStopped reports whether stoppedGracefully has been set, which
// waitForUnregister does when it is woken by gracefulShutdownC rather than by
// the context being done.
// NOTE(review): the field is read without synchronization visible in this
// file; confirm callers do not race with waitForUnregister.
func (c *controlStream) IsStopped() bool {
	return c.stoppedGracefully
}