Skip to content

Commit efffd96

Browse files
committed
Improve websocket tests
Signed-off-by: Lorenzo Donini <[email protected]>
1 parent 2c8eaf1 commit efffd96

File tree

2 files changed

+116
-59
lines changed

2 files changed

+116
-59
lines changed

ws/websocket.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@ func (server *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
407407
tlsConnectionState: r.TLS,
408408
}
409409
server.connMutex.Lock()
410-
defer server.connMutex.Unlock()
411410
server.connections[ws.id] = &ws
411+
server.connMutex.Unlock()
412412
// Read and write routines are started in separate goroutines and function will return immediately
413413
go server.writePump(&ws)
414414
go server.readPump(&ws)

ws/websocket_test.go

Lines changed: 115 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/http"
1717
"net/url"
1818
"os"
19+
"strings"
1920
"testing"
2021
"time"
2122

@@ -79,39 +80,52 @@ func TestWebsocketSetConnected(t *testing.T) {
7980

8081
func TestWebsocketEcho(t *testing.T) {
8182
message := []byte("Hello WebSocket!")
82-
var wsServer *Server
83-
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
83+
triggerC := make(chan bool, 1)
84+
done := make(chan bool, 1)
85+
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
8486
assert.True(t, bytes.Equal(message, data))
87+
// Message received, notifying flow routine
88+
triggerC <- true
8589
return data, nil
8690
})
87-
go wsServer.Start(serverPort, serverPath)
88-
time.Sleep(1 * time.Second)
89-
90-
// Test message
91+
wsServer.SetNewClientHandler(func(ws Channel) {
92+
tlsState := ws.GetTLSConnectionState()
93+
assert.Nil(t, tlsState)
94+
})
95+
wsServer.SetDisconnectedClientHandler(func(ws Channel) {
96+
// Connection closed, completing test
97+
done <- true
98+
})
9199
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
92100
assert.True(t, bytes.Equal(message, data))
101+
// Echo response received, notifying flow routine
102+
triggerC <- true
93103
return nil, nil
94104
})
95-
host := fmt.Sprintf("localhost:%v", serverPort)
96-
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
97-
// Wait for connection to be established, then send a message
98-
go func() {
99-
timer := time.NewTimer(1 * time.Second)
100-
<-timer.C
101-
err := wsClient.Write(message)
102-
assert.Nil(t, err)
103-
}()
104-
done := make(chan bool)
105-
// Wait for messages to be exchanged, then close connection
105+
// Start server
106+
go wsServer.Start(serverPort, serverPath)
107+
// Start flow routine
106108
go func() {
107-
timer := time.NewTimer(3 * time.Second)
108-
<-timer.C
109+
// Wait for messages to be exchanged, then close connection
110+
sig, _ := <-triggerC
111+
assert.True(t, sig)
112+
err := wsServer.Write(testPath, message)
113+
require.Nil(t, err)
114+
sig, _ = <-triggerC
115+
assert.True(t, sig)
109116
wsClient.Stop()
110-
done <- true
111117
}()
118+
time.Sleep(200 * time.Millisecond)
119+
120+
// Test message
121+
host := fmt.Sprintf("localhost:%v", serverPort)
122+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
112123
err := wsClient.Start(u.String())
113-
assert.Nil(t, err)
114-
assert.True(t, wsClient.IsConnected())
124+
require.NoError(t, err)
125+
require.True(t, wsClient.IsConnected())
126+
err = wsClient.Write(message)
127+
require.NoError(t, err)
128+
// Wait for echo result
115129
result := <-done
116130
assert.True(t, result)
117131
// Cleanup
@@ -120,12 +134,23 @@ func TestWebsocketEcho(t *testing.T) {
120134

121135
func TestTLSWebsocketEcho(t *testing.T) {
122136
message := []byte("Hello Secure WebSocket!")
123-
var wsServer *Server
137+
triggerC := make(chan bool, 1)
138+
done := make(chan bool, 1)
124139
// Use NewTLSServer() when in different package
125-
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
140+
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
126141
assert.True(t, bytes.Equal(message, data))
142+
// Message received, notifying flow routine
143+
triggerC <- true
127144
return data, nil
128145
})
146+
wsServer.SetNewClientHandler(func(ws Channel) {
147+
tlsState := ws.GetTLSConnectionState()
148+
assert.NotNil(t, tlsState)
149+
})
150+
wsServer.SetDisconnectedClientHandler(func(ws Channel) {
151+
// Connection closed, completing test
152+
done <- true
153+
})
129154
// Create self-signed TLS certificate
130155
certFilename := "/tmp/cert.pem"
131156
keyFilename := "/tmp/key.pem"
@@ -137,12 +162,11 @@ func TestTLSWebsocketEcho(t *testing.T) {
137162
// Set self-signed TLS certificate
138163
wsServer.tlsCertificatePath = certFilename
139164
wsServer.tlsCertificateKey = keyFilename
140-
go wsServer.Start(serverPort, serverPath)
141-
time.Sleep(1 * time.Second)
142-
143165
// Create TLS client
144166
wsClient := newWebsocketClient(t, func(data []byte) ([]byte, error) {
145167
assert.True(t, bytes.Equal(message, data))
168+
// Echo response received, notifying flow routine
169+
triggerC <- true
146170
return nil, nil
147171
})
148172
wsClient.AddOption(func(dialer *websocket.Dialer) {
@@ -155,37 +179,66 @@ func TestTLSWebsocketEcho(t *testing.T) {
155179
RootCAs: certPool,
156180
}
157181
})
158-
// Test message
159-
host := fmt.Sprintf("localhost:%v", serverPort)
160-
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
161-
// Wait for connection to be established, then send a message to server
162-
go func() {
163-
timer := time.NewTimer(1 * time.Second)
164-
<-timer.C
165-
err := wsClient.Write(message)
166-
assert.Nil(t, err)
167-
}()
168-
done := make(chan bool)
169-
// Wait for messages to be exchanged, then close connection
182+
183+
// Start server
184+
go wsServer.Start(serverPort, serverPath)
185+
// Start flow routine
170186
go func() {
171-
timer := time.NewTimer(3 * time.Second)
172-
<-timer.C
187+
// Wait for messages to be exchanged, then close connection
188+
sig, _ := <-triggerC
189+
assert.True(t, sig)
190+
err := wsServer.Write(testPath, message)
191+
require.NoError(t, err)
192+
sig, _ = <-triggerC
193+
assert.True(t, sig)
173194
wsClient.Stop()
174-
done <- true
175195
}()
196+
time.Sleep(200 * time.Millisecond)
197+
198+
// Test message
199+
host := fmt.Sprintf("localhost:%v", serverPort)
200+
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
176201
err = wsClient.Start(u.String())
177-
assert.Nil(t, err)
202+
require.NoError(t, err)
203+
require.True(t, wsClient.IsConnected())
204+
err = wsClient.Write(message)
205+
require.NoError(t, err)
206+
// Wait for echo result
178207
result := <-done
179208
assert.True(t, result)
180209
// Cleanup
181210
wsServer.Stop()
182211
}
183212

213+
func TestServerStartErrors(t *testing.T) {
214+
triggerC := make(chan bool, 1)
215+
wsServer := newWebsocketServer(t, nil)
216+
wsServer.SetNewClientHandler(func(ws Channel) {
217+
triggerC <- true
218+
})
219+
// Make sure http server is initialized on start
220+
wsServer.httpServer = nil
221+
// Listen for errors
222+
go func() {
223+
err, ok := <-wsServer.Errors()
224+
assert.True(t, ok)
225+
assert.Error(t, err)
226+
triggerC <- true
227+
}()
228+
time.Sleep(100 * time.Millisecond)
229+
go wsServer.Start(serverPort, serverPath)
230+
time.Sleep(100 * time.Millisecond)
231+
// Starting server again throws error
232+
wsServer.Start(serverPort, serverPath)
233+
r, _ := <-triggerC
234+
require.True(t, r)
235+
wsServer.Stop()
236+
}
237+
184238
func TestWebsocketClientConnectionBreak(t *testing.T) {
185239
newClient := make(chan bool)
186240
disconnected := make(chan bool)
187-
var wsServer *Server
188-
wsServer = newWebsocketServer(t, nil)
241+
wsServer := newWebsocketServer(t, nil)
189242
wsServer.SetNewClientHandler(func(ws Channel) {
190243
newClient <- true
191244
})
@@ -217,9 +270,8 @@ func TestWebsocketClientConnectionBreak(t *testing.T) {
217270
}
218271

219272
func TestWebsocketServerConnectionBreak(t *testing.T) {
220-
var wsServer *Server
221273
disconnected := make(chan bool)
222-
wsServer = newWebsocketServer(t, nil)
274+
wsServer := newWebsocketServer(t, nil)
223275
wsServer.SetNewClientHandler(func(ws Channel) {
224276
assert.NotNil(t, ws)
225277
conn := wsServer.connections[ws.GetID()]
@@ -249,7 +301,6 @@ func TestWebsocketServerConnectionBreak(t *testing.T) {
249301
func TestValidBasicAuth(t *testing.T) {
250302
authUsername := "testUsername"
251303
authPassword := "testPassword"
252-
var wsServer *Server
253304
// Create self-signed TLS certificate
254305
certFilename := "/tmp/cert.pem"
255306
keyFilename := "/tmp/key.pem"
@@ -259,7 +310,7 @@ func TestValidBasicAuth(t *testing.T) {
259310
defer os.Remove(keyFilename)
260311

261312
// Create TLS server with self-signed certificate
262-
wsServer = NewTLSServer(certFilename, keyFilename, nil)
313+
wsServer := NewTLSServer(certFilename, keyFilename, nil)
263314
// Add basic auth handler
264315
wsServer.SetBasicAuthHandler(func(username string, password string) bool {
265316
require.Equal(t, authUsername, username)
@@ -300,7 +351,6 @@ func TestValidBasicAuth(t *testing.T) {
300351
func TestInvalidBasicAuth(t *testing.T) {
301352
authUsername := "testUsername"
302353
authPassword := "testPassword"
303-
var wsServer *Server
304354
// Create self-signed TLS certificate
305355
certFilename := "/tmp/cert.pem"
306356
keyFilename := "/tmp/key.pem"
@@ -310,7 +360,7 @@ func TestInvalidBasicAuth(t *testing.T) {
310360
defer os.Remove(keyFilename)
311361

312362
// Create TLS server with self-signed certificate
313-
wsServer = NewTLSServer(certFilename, keyFilename, nil)
363+
wsServer := NewTLSServer(certFilename, keyFilename, nil)
314364
// Add basic auth handler
315365
wsServer.SetBasicAuthHandler(func(username string, password string) bool {
316366
validCredentials := authUsername == username && authPassword == password
@@ -338,7 +388,14 @@ func TestInvalidBasicAuth(t *testing.T) {
338388
host := fmt.Sprintf("localhost:%v", serverPort)
339389
u := url.URL{Scheme: "wss", Host: host, Path: testPath}
340390
err = wsClient.Start(u.String())
391+
// Assert HTTP error
341392
assert.Error(t, err)
393+
httpErr, ok := err.(HttpConnectionError)
394+
require.True(t, ok)
395+
assert.Equal(t, http.StatusUnauthorized, httpErr.HttpCode)
396+
assert.Equal(t, "401 Unauthorized", httpErr.HttpStatus)
397+
assert.Equal(t, "websocket: bad handshake", httpErr.Message)
398+
assert.True(t, strings.Contains(err.Error(), "http status:"))
342399
// Add basic auth
343400
wsClient.SetBasicAuth(authUsername, "invalidPassword")
344401
// Test connection
@@ -353,8 +410,7 @@ func TestInvalidBasicAuth(t *testing.T) {
353410
}
354411

355412
func TestInvalidOriginHeader(t *testing.T) {
356-
var wsServer *Server
357-
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
413+
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
358414
assert.Fail(t, "no message should be received from client!")
359415
return nil, nil
360416
})
@@ -386,10 +442,9 @@ func TestInvalidOriginHeader(t *testing.T) {
386442
}
387443

388444
func TestCustomOriginHeaderHandler(t *testing.T) {
389-
var wsServer *Server
390445
origin := "example.org"
391446
connected := make(chan bool)
392-
wsServer = newWebsocketServer(t, func(data []byte) ([]byte, error) {
447+
wsServer := newWebsocketServer(t, func(data []byte) ([]byte, error) {
393448
assert.Fail(t, "no message should be received from client!")
394449
return nil, nil
395450
})
@@ -519,7 +574,7 @@ func TestInvalidClientTLSCertificate(t *testing.T) {
519574
})
520575
// Run server
521576
go wsServer.Start(serverPort, serverPath)
522-
time.Sleep(1 * time.Second)
577+
time.Sleep(200 * time.Millisecond)
523578

524579
// Create TLS client
525580
certPool = x509.NewCertPool()
@@ -546,9 +601,8 @@ func TestInvalidClientTLSCertificate(t *testing.T) {
546601
}
547602

548603
func TestUnsupportedSubprotocol(t *testing.T) {
549-
var wsServer *Server
550604
disconnected := make(chan bool)
551-
wsServer = newWebsocketServer(t, nil)
605+
wsServer := newWebsocketServer(t, nil)
552606
wsServer.SetNewClientHandler(func(ws Channel) {
553607
assert.Fail(t, "invalid subprotocol expected, but hit client handler instead")
554608
t.Fail()
@@ -710,6 +764,9 @@ func TestServerErrors(t *testing.T) {
710764
require.NoError(t, err)
711765
r, _ = <-triggerC
712766
assert.True(t, r)
767+
// Send message to non-existing client
768+
err = wsServer.Write("fakeId", []byte("dummy response"))
769+
require.Error(t, err)
713770
// Send unexpected close message and wait for error to be thrown
714771
err = wsClient.webSocket.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, ""))
715772
assert.NoError(t, err)

0 commit comments

Comments
 (0)