@@ -336,6 +336,84 @@ func TestInvalidBasicAuth(t *testing.T) {
336336 wsServer .Stop ()
337337}
338338
339+ func TestInvalidOriginHeader (t * testing.T ) {
340+ var wsServer * Server
341+ wsServer = NewWebsocketServer (t , func (data []byte ) ([]byte , error ) {
342+ assert .Fail (t , "no message should be received from client!" )
343+ return nil , nil
344+ })
345+ wsServer .SetNewClientHandler (func (ws Channel ) {
346+ assert .Fail (t , "no new connection should be received from client!" )
347+ })
348+ go wsServer .Start (serverPort , serverPath )
349+ time .Sleep (500 * time .Millisecond )
350+
351+ // Test message
352+ wsClient := NewWebsocketClient (t , func (data []byte ) ([]byte , error ) {
353+ assert .Fail (t , "no message should be received from server!" )
354+ return nil , nil
355+ })
356+ // Set invalid origin header
357+ wsClient .SetHeaderValue ("Origin" , "example.org" )
358+ host := fmt .Sprintf ("localhost:%v" , serverPort )
359+ u := url.URL {Scheme : "ws" , Host : host , Path : testPath }
360+ // Attempt to connect and expect cross-origin error
361+ err := wsClient .Start (u .String ())
362+ require .Error (t , err )
363+ httpErr , ok := err .(HttpConnectionError )
364+ require .True (t , ok )
365+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
366+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
367+ assert .Equal (t , "websocket: bad handshake" , httpErr .Message )
368+ // Cleanup
369+ wsServer .Stop ()
370+ }
371+
372+ func TestCustomOriginHeaderHandler (t * testing.T ) {
373+ var wsServer * Server
374+ origin := "example.org"
375+ connected := make (chan bool )
376+ wsServer = NewWebsocketServer (t , func (data []byte ) ([]byte , error ) {
377+ assert .Fail (t , "no message should be received from client!" )
378+ return nil , nil
379+ })
380+ wsServer .SetNewClientHandler (func (ws Channel ) {
381+ connected <- true
382+ })
383+ wsServer .SetCheckOriginHandler (func (r * http.Request ) bool {
384+ return r .Header .Get ("Origin" ) == origin
385+ })
386+ go wsServer .Start (serverPort , serverPath )
387+ time .Sleep (500 * time .Millisecond )
388+
389+ // Test message
390+ wsClient := NewWebsocketClient (t , func (data []byte ) ([]byte , error ) {
391+ assert .Fail (t , "no message should be received from server!" )
392+ return nil , nil
393+ })
394+ // Set invalid origin header (not example.org)
395+ wsClient .SetHeaderValue ("Origin" , "localhost" )
396+ host := fmt .Sprintf ("localhost:%v" , serverPort )
397+ u := url.URL {Scheme : "ws" , Host : host , Path : testPath }
398+ // Attempt to connect and expect cross-origin error
399+ err := wsClient .Start (u .String ())
400+ require .Error (t , err )
401+ httpErr , ok := err .(HttpConnectionError )
402+ require .True (t , ok )
403+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
404+ assert .Equal (t , http .StatusForbidden , httpErr .HttpCode )
405+ assert .Equal (t , "websocket: bad handshake" , httpErr .Message )
406+
407+ // Re-attempt with correct header
408+ wsClient .SetHeaderValue ("Origin" , "example.org" )
409+ err = wsClient .Start (u .String ())
410+ require .NoError (t , err )
411+ result := <- connected
412+ assert .True (t , result )
413+ // Cleanup
414+ wsServer .Stop ()
415+ }
416+
339417func TestValidClientTLSCertificate (t * testing.T ) {
340418 var wsServer * Server
341419 // Create self-signed TLS certificate
0 commit comments