cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflaredAvailable

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.12.3

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

datagramsession/session_test.go

197 lines · mode: code

1package datagramsession
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "net"
9 "sync"
10 "testing"
11 "time"
12
13 "github.com/google/uuid"
14 "github.com/rs/zerolog"
15 "github.com/stretchr/testify/require"
16 "golang.org/x/sync/errgroup"
17)
18
19// TestCloseSession makes sure a session will stop after context is done
20func TestSessionCtxDone(t *testing.T) {
21 testSessionReturns(t, closeByContext, time.Minute*2)
22}
23
24// TestCloseSession makes sure a session will stop after close method is called
25func TestCloseSession(t *testing.T) {
26 testSessionReturns(t, closeByCallingClose, time.Minute*2)
27}
28
29// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
30func TestCloseIdle(t *testing.T) {
31 testSessionReturns(t, closeByTimeout, time.Millisecond*100)
32}
33
34func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
35 var (
36 localCloseReason = &errClosedSession{
37 message: "connection closed by origin",
38 byRemote: false,
39 }
40 )
41 sessionID := uuid.New()
42 cfdConn, originConn := net.Pipe()
43 payload := testPayload(sessionID)
44 transport := &mockQUICTransport{
45 reqChan: newDatagramChannel(1),
46 respChan: newDatagramChannel(1),
47 }
48 log := zerolog.Nop()
49 session := newSession(sessionID, transport, cfdConn, &log)
50
51 ctx, cancel := context.WithCancel(context.Background())
52 sessionDone := make(chan struct{})
53 go func() {
54 closedByRemote, err := session.Serve(ctx, closeAfterIdle)
55 switch closeBy {
56 case closeByContext:
57 require.Equal(t, context.Canceled, err)
58 require.False(t, closedByRemote)
59 case closeByCallingClose:
60 require.Equal(t, localCloseReason, err)
61 require.Equal(t, localCloseReason.byRemote, closedByRemote)
62 case closeByTimeout:
63 require.Equal(t, SessionIdleErr(closeAfterIdle), err)
64 require.False(t, closedByRemote)
65 }
66 close(sessionDone)
67 }()
68
69 go func() {
70 n, err := session.transportToDst(payload)
71 require.NoError(t, err)
72 require.Equal(t, len(payload), n)
73 }()
74
75 readBuffer := make([]byte, len(payload)+1)
76 n, err := originConn.Read(readBuffer)
77 require.NoError(t, err)
78 require.Equal(t, len(payload), n)
79
80 lastRead := time.Now()
81
82 switch closeBy {
83 case closeByContext:
84 cancel()
85 case closeByCallingClose:
86 session.close(localCloseReason)
87 }
88
89 <-sessionDone
90 if closeBy == closeByTimeout {
91 require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
92 }
93 // call cancelled again otherwise the linter will warn about possible context leak
94 cancel()
95}
96
// closeMethod selects how testSessionReturns ends the session it is exercising.
type closeMethod int

// The close methods understood by testSessionReturns.
const (
	closeByContext      closeMethod = 0 // cancel the context passed to Serve
	closeByCallingClose closeMethod = 1 // call the session's close method directly
	closeByTimeout      closeMethod = 2 // leave the session idle until closeAfterIdle elapses
)
104
105func TestWriteToDstSessionPreventClosed(t *testing.T) {
106 testActiveSessionNotClosed(t, false, true)
107}
108
109func TestReadFromDstSessionPreventClosed(t *testing.T) {
110 testActiveSessionNotClosed(t, true, false)
111}
112
113func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
114 const closeAfterIdle = time.Millisecond * 100
115 const activeTime = time.Millisecond * 500
116
117 sessionID := uuid.New()
118 cfdConn, originConn := net.Pipe()
119 payload := testPayload(sessionID)
120 transport := &mockQUICTransport{
121 reqChan: newDatagramChannel(100),
122 respChan: newDatagramChannel(100),
123 }
124 log := zerolog.Nop()
125 session := newSession(sessionID, transport, cfdConn, &log)
126
127 startTime := time.Now()
128 activeUntil := startTime.Add(activeTime)
129 ctx, cancel := context.WithCancel(context.Background())
130 errGroup, ctx := errgroup.WithContext(ctx)
131 errGroup.Go(func() error {
132 session.Serve(ctx, closeAfterIdle)
133 if time.Now().Before(startTime.Add(activeTime)) {
134 return fmt.Errorf("session closed while it's still active")
135 }
136 return nil
137 })
138
139 if readFromDst {
140 errGroup.Go(func() error {
141 for {
142 if time.Now().After(activeUntil) {
143 return nil
144 }
145 if _, err := originConn.Write(payload); err != nil {
146 return err
147 }
148 time.Sleep(closeAfterIdle / 2)
149 }
150 })
151 }
152 if writeToDst {
153 errGroup.Go(func() error {
154 readBuffer := make([]byte, len(payload))
155 for {
156 n, err := originConn.Read(readBuffer)
157 if err != nil {
158 if err == io.EOF || err == io.ErrClosedPipe {
159 return nil
160 }
161 return err
162 }
163 if !bytes.Equal(payload, readBuffer[:n]) {
164 return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
165 }
166 }
167 })
168 errGroup.Go(func() error {
169 for {
170 if time.Now().After(activeUntil) {
171 return nil
172 }
173 if _, err := session.transportToDst(payload); err != nil {
174 return err
175 }
176 time.Sleep(closeAfterIdle / 2)
177 }
178 })
179 }
180
181 require.NoError(t, errGroup.Wait())
182 cancel()
183}
184
185func TestMarkActiveNotBlocking(t *testing.T) {
186 const concurrentCalls = 50
187 session := newSession(uuid.New(), nil, nil, nil)
188 var wg sync.WaitGroup
189 wg.Add(concurrentCalls)
190 for i := 0; i < concurrentCalls; i++ {
191 go func() {
192 session.markActive()
193 wg.Done()
194 }()
195 }
196 wg.Wait()
197}
198