@@ -185,6 +185,12 @@ type WsServer interface {
185185 // The handler function is called whenever a new client attempts to connect, to check for credentials correctness.
186186 // The handler must return true if the credentials were correct, false otherwise.
187187 SetBasicAuthHandler (handler func (username string , password string ) bool )
188+ // SetCheckOriginHandler sets a handler for incoming websocket connections, allowing to perform
189+ // custom cross-origin checks.
190+ //
191+ // By default, if the Origin header is present in the request, and the Origin host is not equal
192+ // to the Host request header, the websocket handshake fails.
193+ SetCheckOriginHandler (handler func (r * http.Request ) bool )
188194}
189195
190196// Default implementation of a Websocket server.
@@ -264,6 +270,10 @@ func (server *Server) SetBasicAuthHandler(handler func(username string, password
264270 server .basicAuthHandler = handler
265271}
266272
273+ func (server * Server ) SetCheckOriginHandler (handler func (r * http.Request ) bool ) {
274+ server .upgrader .CheckOrigin = handler
275+ }
276+
267277func (server * Server ) error (err error ) {
268278 if server .errC != nil {
269279 server .errC <- err
@@ -518,6 +528,10 @@ type WsClient interface {
518528 // SetBasicAuth adds basic authentication credentials, to use when connecting to the server.
519529 // The credentials are automatically encoded in base64.
520530 SetBasicAuth (username string , password string )
531+ // SetHeaderValue sets a value on the HTTP header sent when opening a websocket connection to the server.
532+ //
533+ // The function overwrites previous header fields with the same key.
534+ SetHeaderValue (key string , value string )
521535}
522536
523537// Client is the the default implementation of a Websocket client.
@@ -527,7 +541,7 @@ type Client struct {
527541 webSocket WebSocket
528542 messageHandler func (data []byte ) error
529543 dialOptions []func (* websocket.Dialer )
530- authHeader http.Header
544+ header http.Header
531545 timeoutConfig ClientTimeoutConfig
532546 errC chan error
533547}
@@ -537,7 +551,7 @@ type Client struct {
537551// Additional options may be added using the AddOption function.
538552// Basic authentication can be set using the SetBasicAuth function.
539553func NewClient () * Client {
540- return & Client {dialOptions : []func (* websocket.Dialer ){}, timeoutConfig : NewClientTimeoutConfig ()}
554+ return & Client {dialOptions : []func (* websocket.Dialer ){}, timeoutConfig : NewClientTimeoutConfig (), header : http. Header {} }
541555}
542556
543557// Creates a new secure websocket client. If supported by the server, the websocket channel will use TLS.
@@ -558,7 +572,7 @@ func NewClient() *Client {
558572// self-signed certificate (do not use in production!), pass:
559573// InsecureSkipVerify: true
560574func NewTLSClient (tlsConfig * tls.Config ) * Client {
561- client := & Client {dialOptions : []func (* websocket.Dialer ){}, timeoutConfig : NewClientTimeoutConfig ()}
575+ client := & Client {dialOptions : []func (* websocket.Dialer ){}, timeoutConfig : NewClientTimeoutConfig (), header : http. Header {} }
562576 client .dialOptions = append (client .dialOptions , func (dialer * websocket.Dialer ) {
563577 dialer .TLSClientConfig = tlsConfig
564578 })
@@ -581,9 +595,11 @@ func (client *Client) AddOption(option interface{}) {
581595}
582596
583597func (client * Client ) SetBasicAuth (username string , password string ) {
584- client .authHeader = http.Header {
585- "Authorization" : {"Basic " + base64 .StdEncoding .EncodeToString ([]byte (username + ":" + password ))},
586- }
598+ client .header .Set ("Authorization" , "Basic " + base64 .StdEncoding .EncodeToString ([]byte (username + ":" + password )))
599+ }
600+
601+ func (client * Client ) SetHeaderValue (key string , value string ) {
602+ client .header .Set (key , value )
587603}
588604
589605func (client * Client ) writePump () {
@@ -669,7 +685,7 @@ func (client *Client) Start(url string) error {
669685 for _ , option := range client .dialOptions {
670686 option (& dialer )
671687 }
672- ws , resp , err := dialer .Dial (url , client .authHeader )
688+ ws , resp , err := dialer .Dial (url , client .header )
673689 if err != nil {
674690 if resp != nil {
675691 httpError := HttpConnectionError {Message : err .Error (), HttpStatus : resp .Status , HttpCode : resp .StatusCode }
0 commit comments