cloudflare/cloudflared
Public · mirrored from https://github.com/cloudflare/cloudflared · Available
datagramsession/manager_test.go
222lines · modecode
| 1 | package datagramsession |
| 2 | |
| 3 | import ( |
| 4 | "bytes" |
| 5 | "context" |
| 6 | "fmt" |
| 7 | "io" |
| 8 | "net" |
| 9 | "testing" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/google/uuid" |
| 13 | "github.com/rs/zerolog" |
| 14 | "github.com/stretchr/testify/require" |
| 15 | "golang.org/x/sync/errgroup" |
| 16 | ) |
| 17 | |
| 18 | func TestManagerServe(t *testing.T) { |
| 19 | const ( |
| 20 | sessions = 20 |
| 21 | msgs = 50 |
| 22 | remoteUnregisterMsg = "eyeball closed connection" |
| 23 | ) |
| 24 | log := zerolog.Nop() |
| 25 | transport := &mockQUICTransport{ |
| 26 | reqChan: newDatagramChannel(1), |
| 27 | respChan: newDatagramChannel(1), |
| 28 | } |
| 29 | mg := NewManager(transport, &log) |
| 30 | |
| 31 | eyeballTracker := make(map[uuid.UUID]*datagramChannel) |
| 32 | for i := 0; i < sessions; i++ { |
| 33 | sessionID := uuid.New() |
| 34 | eyeballTracker[sessionID] = newDatagramChannel(1) |
| 35 | } |
| 36 | |
| 37 | ctx, cancel := context.WithCancel(context.Background()) |
| 38 | serveDone := make(chan struct{}) |
| 39 | go func(ctx context.Context) { |
| 40 | mg.Serve(ctx) |
| 41 | close(serveDone) |
| 42 | }(ctx) |
| 43 | |
| 44 | go func(ctx context.Context) { |
| 45 | for { |
| 46 | sessionID, payload, err := transport.respChan.Receive(ctx) |
| 47 | if err != nil { |
| 48 | require.Equal(t, context.Canceled, err) |
| 49 | return |
| 50 | } |
| 51 | respChan := eyeballTracker[sessionID] |
| 52 | require.NoError(t, respChan.Send(ctx, sessionID, payload)) |
| 53 | } |
| 54 | }(ctx) |
| 55 | |
| 56 | errGroup, ctx := errgroup.WithContext(ctx) |
| 57 | for sID, receiver := range eyeballTracker { |
| 58 | // Assign loop variables to local variables |
| 59 | sessionID := sID |
| 60 | eyeballRespReceiver := receiver |
| 61 | errGroup.Go(func() error { |
| 62 | payload := testPayload(sessionID) |
| 63 | expectResp := testResponse(payload) |
| 64 | |
| 65 | cfdConn, originConn := net.Pipe() |
| 66 | |
| 67 | origin := mockOrigin{ |
| 68 | expectMsgCount: msgs, |
| 69 | expectedMsg: payload, |
| 70 | expectedResp: expectResp, |
| 71 | conn: originConn, |
| 72 | } |
| 73 | eyeball := mockEyeball{ |
| 74 | expectMsgCount: msgs, |
| 75 | expectedMsg: expectResp, |
| 76 | expectSessionID: sessionID, |
| 77 | respReceiver: eyeballRespReceiver, |
| 78 | } |
| 79 | |
| 80 | reqErrGroup, reqCtx := errgroup.WithContext(ctx) |
| 81 | reqErrGroup.Go(func() error { |
| 82 | return origin.serve() |
| 83 | }) |
| 84 | reqErrGroup.Go(func() error { |
| 85 | return eyeball.serve(reqCtx) |
| 86 | }) |
| 87 | |
| 88 | session, err := mg.RegisterSession(ctx, sessionID, cfdConn) |
| 89 | require.NoError(t, err) |
| 90 | |
| 91 | sessionDone := make(chan struct{}) |
| 92 | go func() { |
| 93 | closedByRemote, err := session.Serve(ctx, time.Minute*2) |
| 94 | closeSession := &errClosedSession{ |
| 95 | message: remoteUnregisterMsg, |
| 96 | byRemote: true, |
| 97 | } |
| 98 | require.Equal(t, closeSession, err) |
| 99 | require.True(t, closedByRemote) |
| 100 | close(sessionDone) |
| 101 | }() |
| 102 | |
| 103 | for i := 0; i < msgs; i++ { |
| 104 | require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID))) |
| 105 | } |
| 106 | |
| 107 | // Make sure eyeball and origin have received all messages before unregistering the session |
| 108 | require.NoError(t, reqErrGroup.Wait()) |
| 109 | |
| 110 | require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true)) |
| 111 | <-sessionDone |
| 112 | |
| 113 | return nil |
| 114 | }) |
| 115 | } |
| 116 | |
| 117 | require.NoError(t, errGroup.Wait()) |
| 118 | cancel() |
| 119 | transport.close() |
| 120 | <-serveDone |
| 121 | } |
| 122 | |
// mockOrigin is a test double for the origin side of a session: it reads a
// fixed number of identical requests from conn and answers each of them with
// the same response.
type mockOrigin struct {
	// expectMsgCount is the number of request/response round trips serve performs.
	expectMsgCount int
	// expectedMsg is the exact payload every request must carry;
	// expectedResp is written back after each request that matches.
	expectedMsg, expectedResp []byte
	// conn is where requests are read from and responses are written to.
	conn io.ReadWriteCloser
}

// serve performs expectMsgCount rounds of "read one request, verify it, write
// the response" on conn. It returns the first read or write error, or a
// descriptive error when a request has the wrong length or content, and nil
// once every round has succeeded.
func (mo *mockOrigin) serve() error {
	want := mo.expectedMsg
	// The buffer has one spare byte, so a message longer than expected is
	// read as len(want)+1 bytes and fails the length check below.
	buf := make([]byte, len(want)+1)

	for remaining := mo.expectMsgCount; remaining > 0; remaining-- {
		n, err := mo.conn.Read(buf)
		switch {
		case err != nil:
			return err
		case n != len(want):
			return fmt.Errorf("Expect to read %d bytes, read %d", len(want), n)
		}

		got := buf[:n]
		if !bytes.Equal(got, want) {
			return fmt.Errorf("Expect %v, read %v", want, got)
		}

		if _, err := mo.conn.Write(mo.expectedResp); err != nil {
			return err
		}
	}
	return nil
}
| 152 | |
| 153 | func testPayload(sessionID uuid.UUID) []byte { |
| 154 | return []byte(fmt.Sprintf("Message from %s", sessionID)) |
| 155 | } |
| 156 | |
// testResponse derives the response expected for a given request payload.
// The request is rendered with %v, i.e. as a bracketed list of decimal byte
// values rather than as text.
func testResponse(msg []byte) []byte {
	resp := fmt.Sprintf("Response to %v", msg)
	return []byte(resp)
}
| 160 | |
| 161 | type mockEyeball struct { |
| 162 | expectMsgCount int |
| 163 | expectedMsg []byte |
| 164 | expectSessionID uuid.UUID |
| 165 | respReceiver *datagramChannel |
| 166 | } |
| 167 | |
| 168 | func (me *mockEyeball) serve(ctx context.Context) error { |
| 169 | for i := 0; i < me.expectMsgCount; i++ { |
| 170 | sessionID, msg, err := me.respReceiver.Receive(ctx) |
| 171 | if err != nil { |
| 172 | return err |
| 173 | } |
| 174 | if sessionID != me.expectSessionID { |
| 175 | return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID) |
| 176 | } |
| 177 | if !bytes.Equal(msg, me.expectedMsg) { |
| 178 | return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg) |
| 179 | } |
| 180 | } |
| 181 | return nil |
| 182 | } |
| 183 | |
| 184 | // datagramChannel is a channel for Datagram with wrapper to send/receive with context |
| 185 | type datagramChannel struct { |
| 186 | datagramChan chan *newDatagram |
| 187 | closedChan chan struct{} |
| 188 | } |
| 189 | |
| 190 | func newDatagramChannel(capacity uint) *datagramChannel { |
| 191 | return &datagramChannel{ |
| 192 | datagramChan: make(chan *newDatagram, capacity), |
| 193 | closedChan: make(chan struct{}), |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error { |
| 198 | select { |
| 199 | case <-ctx.Done(): |
| 200 | return ctx.Err() |
| 201 | case <-rc.closedChan: |
| 202 | return fmt.Errorf("datagram channel closed") |
| 203 | case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}: |
| 204 | return nil |
| 205 | } |
| 206 | } |
| 207 | |
| 208 | func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) { |
| 209 | select { |
| 210 | case <-ctx.Done(): |
| 211 | return uuid.Nil, nil, ctx.Err() |
| 212 | case <-rc.closedChan: |
| 213 | return uuid.Nil, nil, fmt.Errorf("datagram channel closed") |
| 214 | case msg := <-rc.datagramChan: |
| 215 | return msg.sessionID, msg.payload, nil |
| 216 | } |
| 217 | } |
| 218 | |
| 219 | func (rc *datagramChannel) Close() { |
| 220 | // No need to close msgChan, it will be garbage collect once there is no reference to it |
| 221 | close(rc.closedChan) |
| 222 | } |
| 223 | |