Skip to content

Commit fdf0156

Browse files
committed
Add support for websocket header and origin config
- websocket client accepts custom header values - websocket server accepts a custom header origin handler function Signed-off-by: Lorenzo Donini <[email protected]>
1 parent 4186a52 commit fdf0156

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

ws/websocket.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,12 @@ type WsServer interface {
185185
// The handler function is called whenever a new client attempts to connect, to check for credentials correctness.
186186
// The handler must return true if the credentials were correct, false otherwise.
187187
SetBasicAuthHandler(handler func(username string, password string) bool)
188+
// SetCheckOriginHandler sets a handler for incoming websocket connections, allowing to perform
189+
// custom cross-origin checks.
190+
//
191+
// By default, if the Origin header is present in the request, and the Origin host is not equal
192+
// to the Host request header, the websocket handshake fails.
193+
SetCheckOriginHandler(handler func(r *http.Request) bool)
188194
}
189195

190196
// Default implementation of a Websocket server.
@@ -264,6 +270,10 @@ func (server *Server) SetBasicAuthHandler(handler func(username string, password
264270
server.basicAuthHandler = handler
265271
}
266272

273+
func (server *Server) SetCheckOriginHandler(handler func(r *http.Request) bool) {
274+
server.upgrader.CheckOrigin = handler
275+
}
276+
267277
func (server *Server) error(err error) {
268278
if server.errC != nil {
269279
server.errC <- err
@@ -518,6 +528,10 @@ type WsClient interface {
518528
// SetBasicAuth adds basic authentication credentials, to use when connecting to the server.
519529
// The credentials are automatically encoded in base64.
520530
SetBasicAuth(username string, password string)
531+
// SetHeaderValue sets a value on the HTTP header sent when opening a websocket connection to the server.
532+
//
533+
// The function overwrites previous header fields with the same key.
534+
SetHeaderValue(key string, value string)
521535
}
522536

523537
// Client is the the default implementation of a Websocket client.
@@ -527,7 +541,7 @@ type Client struct {
527541
webSocket WebSocket
528542
messageHandler func(data []byte) error
529543
dialOptions []func(*websocket.Dialer)
530-
authHeader http.Header
544+
header http.Header
531545
timeoutConfig ClientTimeoutConfig
532546
errC chan error
533547
}
@@ -537,7 +551,7 @@ type Client struct {
537551
// Additional options may be added using the AddOption function.
538552
// Basic authentication can be set using the SetBasicAuth function.
539553
func NewClient() *Client {
540-
return &Client{dialOptions: []func(*websocket.Dialer){}, timeoutConfig: NewClientTimeoutConfig()}
554+
return &Client{dialOptions: []func(*websocket.Dialer){}, timeoutConfig: NewClientTimeoutConfig(), header: http.Header{}}
541555
}
542556

543557
// Creates a new secure websocket client. If supported by the server, the websocket channel will use TLS.
@@ -558,7 +572,7 @@ func NewClient() *Client {
558572
// self-signed certificate (do not use in production!), pass:
559573
// InsecureSkipVerify: true
560574
func NewTLSClient(tlsConfig *tls.Config) *Client {
561-
client := &Client{dialOptions: []func(*websocket.Dialer){}, timeoutConfig: NewClientTimeoutConfig()}
575+
client := &Client{dialOptions: []func(*websocket.Dialer){}, timeoutConfig: NewClientTimeoutConfig(), header: http.Header{}}
562576
client.dialOptions = append(client.dialOptions, func(dialer *websocket.Dialer) {
563577
dialer.TLSClientConfig = tlsConfig
564578
})
@@ -581,9 +595,11 @@ func (client *Client) AddOption(option interface{}) {
581595
}
582596

583597
func (client *Client) SetBasicAuth(username string, password string) {
584-
client.authHeader = http.Header{
585-
"Authorization": {"Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))},
586-
}
598+
client.header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(username+":"+password)))
599+
}
600+
601+
func (client *Client) SetHeaderValue(key string, value string) {
602+
client.header.Set(key, value)
587603
}
588604

589605
func (client *Client) writePump() {
@@ -669,7 +685,7 @@ func (client *Client) Start(url string) error {
669685
for _, option := range client.dialOptions {
670686
option(&dialer)
671687
}
672-
ws, resp, err := dialer.Dial(url, client.authHeader)
688+
ws, resp, err := dialer.Dial(url, client.header)
673689
if err != nil {
674690
if resp != nil {
675691
httpError := HttpConnectionError{Message: err.Error(), HttpStatus: resp.Status, HttpCode: resp.StatusCode}

0 commit comments

Comments
 (0)