cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2021.12.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

datagramsession/session_test.go

194 lines · mode: code

1package datagramsession
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "io"
8 "net"
9 "sync"
10 "testing"
11 "time"
12
13 "github.com/google/uuid"
14 "github.com/stretchr/testify/require"
15 "golang.org/x/sync/errgroup"
16)
17
// TestSessionCtxDone makes sure a session stops after the context passed to
// Serve is cancelled. The idle timeout is deliberately long so that only the
// context cancellation can end the session.
func TestSessionCtxDone(t *testing.T) {
	testSessionReturns(t, closeByContext, 2*time.Minute)
}
22
23// TestCloseSession makes sure a session will stop after close method is called
24func TestCloseSession(t *testing.T) {
25 testSessionReturns(t, closeByCallingClose, time.Minute*2)
26}
27
28// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
29func TestCloseIdle(t *testing.T) {
30 testSessionReturns(t, closeByTimeout, time.Millisecond*100)
31}
32
// testSessionReturns serves a session, forwards one payload from the transport
// side to the origin end of a net.Pipe, then ends the session using closeBy and
// checks the (closedByRemote, err) pair returned by Serve.
//
// closeAfterIdle is the idle timeout passed to Serve. For closeByTimeout the
// test also requires that Serve did not return before closeAfterIdle had
// elapsed since the origin read the payload.
//
// All assertions run on the test goroutine. require calls t.FailNow, which only
// stops the goroutine it is called from. A failing assertion inside a
// background goroutine would therefore leave this function blocked forever
// instead of failing the test, so background goroutines only report results.
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
	localCloseReason := &errClosedSession{
		message:  "connection closed by origin",
		byRemote: false,
	}
	sessionID := uuid.New()
	cfdConn, originConn := net.Pipe()
	// Closing the origin end unblocks any pending pipe I/O if an assertion
	// below fails early.
	defer originConn.Close()
	payload := testPayload(sessionID)
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(1),
		respChan: newDatagramChannel(1),
	}
	session := newSession(sessionID, transport, cfdConn)

	ctx, cancel := context.WithCancel(context.Background())
	// Deferred so Serve is released even when an assertion fails before the
	// session is explicitly closed.
	defer cancel()

	type serveResult struct {
		closedByRemote bool
		err            error
	}
	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already failed and stopped receiving.
	serveDone := make(chan serveResult, 1)
	go func() {
		closedByRemote, err := session.Serve(ctx, closeAfterIdle)
		serveDone <- serveResult{closedByRemote: closedByRemote, err: err}
	}()

	type writeResult struct {
		n   int
		err error
	}
	writeDone := make(chan writeResult, 1)
	go func() {
		n, err := session.transportToDst(payload)
		writeDone <- writeResult{n: n, err: err}
	}()

	readBuffer := make([]byte, len(payload)+1)
	n, err := originConn.Read(readBuffer)
	require.NoError(t, err)
	require.Equal(t, len(payload), n)

	lastRead := time.Now()

	written := <-writeDone
	require.NoError(t, written.err)
	require.Equal(t, len(payload), written.n)

	switch closeBy {
	case closeByContext:
		cancel()
	case closeByCallingClose:
		session.close(localCloseReason)
	}

	result := <-serveDone
	switch closeBy {
	case closeByContext:
		require.Equal(t, context.Canceled, result.err)
		require.False(t, result.closedByRemote)
	case closeByCallingClose:
		require.Equal(t, localCloseReason, result.err)
		require.Equal(t, localCloseReason.byRemote, result.closedByRemote)
	case closeByTimeout:
		require.Equal(t, SessionIdleErr(closeAfterIdle), result.err)
		require.False(t, result.closedByRemote)
		require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
	}
}
94
// closeMethod selects how testSessionReturns ends the session under test.
type closeMethod int

const (
	// closeByContext cancels the context that was passed to Serve.
	closeByContext closeMethod = 0
	// closeByCallingClose calls the session's close method directly.
	closeByCallingClose closeMethod = 1
	// closeByTimeout does nothing and lets the idle timeout end the session.
	closeByTimeout closeMethod = 2
)
102
103func TestWriteToDstSessionPreventClosed(t *testing.T) {
104 testActiveSessionNotClosed(t, false, true)
105}
106
107func TestReadFromDstSessionPreventClosed(t *testing.T) {
108 testActiveSessionNotClosed(t, true, false)
109}
110
// testActiveSessionNotClosed keeps a session busy for activeTime by moving a
// payload every closeAfterIdle/2. It checks that Serve does not return while
// traffic is flowing, and that once traffic stops the session ends with the
// idle-timeout error.
//
// When readFromDst is set, the origin end of the pipe writes payloads that the
// session reads from its destination connection. When writeToDst is set,
// payloads are pushed through transportToDst, and a reader on the origin end
// verifies each one.
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
	const closeAfterIdle = time.Millisecond * 100
	const activeTime = time.Millisecond * 500

	sessionID := uuid.New()
	cfdConn, originConn := net.Pipe()
	defer originConn.Close()
	payload := testPayload(sessionID)
	transport := &mockQUICTransport{
		reqChan:  newDatagramChannel(100),
		respChan: newDatagramChannel(100),
	}
	session := newSession(sessionID, transport, cfdConn)

	startTime := time.Now()
	activeUntil := startTime.Add(activeTime)
	ctx, cancel := context.WithCancel(context.Background())
	// Deferred so the context is released even if require fails below.
	defer cancel()
	errGroup, ctx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		_, err := session.Serve(ctx, closeAfterIdle)
		if time.Now().Before(activeUntil) {
			return fmt.Errorf("session closed while it's still active: %v", err)
		}
		// ctx is only cancelled if another goroutine fails (and that error
		// is the one errgroup reports). Otherwise nothing closes the session
		// after the traffic stops, so the idle timeout is the only expected
		// exit. Messages are compared because the error may be a pointer
		// value that is not == to a freshly built one.
		if expected := SessionIdleErr(closeAfterIdle); err == nil || err.Error() != expected.Error() {
			return fmt.Errorf("session closed with %v, expected %v", err, expected)
		}
		return nil
	})

	if readFromDst {
		errGroup.Go(func() error {
			for {
				if time.Now().After(activeUntil) {
					return nil
				}
				if _, err := originConn.Write(payload); err != nil {
					return err
				}
				time.Sleep(closeAfterIdle / 2)
			}
		})
	}
	if writeToDst {
		errGroup.Go(func() error {
			readBuffer := make([]byte, len(payload))
			for {
				n, err := originConn.Read(readBuffer)
				if err != nil {
					// The session closes its end of the pipe when Serve
					// returns, which ends this reader.
					if err == io.EOF || err == io.ErrClosedPipe {
						return nil
					}
					return err
				}
				if !bytes.Equal(payload, readBuffer[:n]) {
					return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
				}
			}
		})
		errGroup.Go(func() error {
			for {
				if time.Now().After(activeUntil) {
					return nil
				}
				if _, err := session.transportToDst(payload); err != nil {
					return err
				}
				time.Sleep(closeAfterIdle / 2)
			}
		})
	}

	require.NoError(t, errGroup.Wait())
}
181
// TestMarkActiveNotBlocking calls markActive concurrently on a session that is
// not being served (nil transport and connection). It fails if the calls do not
// all return within a bounded time. Without that bound, a blocking markActive
// would only show up as a hang until the global `go test` timeout.
func TestMarkActiveNotBlocking(t *testing.T) {
	const (
		concurrentCalls = 50
		// Generous bound: the calls are expected to return immediately.
		timeout = 5 * time.Second
	)
	session := newSession(uuid.New(), nil, nil)

	var wg sync.WaitGroup
	wg.Add(concurrentCalls)
	for i := 0; i < concurrentCalls; i++ {
		go func() {
			defer wg.Done()
			session.markActive()
		}()
	}

	allDone := make(chan struct{})
	go func() {
		wg.Wait()
		close(allDone)
	}()

	select {
	case <-allDone:
	case <-time.After(timeout):
		t.Fatalf("markActive calls did not return within %v", timeout)
	}
}
194}
195