Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 169 additions & 153 deletions cmd/ssokenizer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"errors"
"flag"
"fmt"
"net/url"
"os"
"regexp"
"strings"

"github.com/superfly/ssokenizer"
Expand Down Expand Up @@ -68,6 +70,9 @@ func Run(ctx context.Context, args []string) error {
}

type Config struct {
// Full URL of the ssokenizer service
URL string `yaml:"url"`

// Tokenizer seal (public) key
SealKey string `yaml:"seal_key"`

Expand All @@ -79,6 +84,75 @@ type Config struct {
Log LogConfig `yaml:"log"`
HTTP HTTPConfig `yaml:"http"`
IdentityProviders map[string]IdentityProviderConfig `yaml:"identity_providers"`

// fields populated during validation
ssokenizerURL *url.URL
globalTAC tokenizer.AuthConfig
globalReturnURL *url.URL
providers ssokenizer.StaticProviderRegistry
}

func (c *Config) validate() error {
var err error

if c.HTTP.Address == "" {
return errors.New("missing http.address")
}

c.ssokenizerURL, err = url.Parse(c.URL)
switch {
case c.URL == "":
return errors.New("missing url")
case err != nil:
return fmt.Errorf("invalid URL (%q): %w", c.URL, err)
case c.ssokenizerURL.Scheme == "" || c.ssokenizerURL.Host == "":
return fmt.Errorf("malformed URL: %q", c.URL)
case c.ssokenizerURL.Path == "":
c.ssokenizerURL.Path = "/"
}

if c.SealKey == "" {
return errors.New("missing seal_key")
}

c.globalTAC, err = c.SecretAuth.tokenizerAuthConfig()
if err != nil {
return err
}

if c.ReturnURL != "" {
switch c.globalReturnURL, err = url.Parse(c.ReturnURL); {
case err != nil:
return err
case c.globalReturnURL.Scheme == "" || c.globalReturnURL.Host == "":
return fmt.Errorf("malformed return_url: %q", c.ReturnURL)
}
}

c.providers = make(ssokenizer.StaticProviderRegistry)
for name, pc := range c.IdentityProviders {
if _, dup := c.providers[name]; dup {
return fmt.Errorf("duplicate identity provider %q", name)
}

provider, err := pc.provider(name, c)
if err != nil {
return fmt.Errorf("invalid identity provider %q: %w", name, err)
}

c.providers[name] = provider
}

return nil
}

// tokenizerHostValidator returns validators that tokenizer can run to only
// allow tokens to be forwarded to specific hosts. In addition to whatever
// hostname pattern we want to allow for a given provider, we also include our
// own hostname so tokenizer can send us requests for refresh tokens.
func (c *Config) tokenizerHostValidator(pattern string) []tokenizer.RequestValidator {
re := regexp.MustCompile(fmt.Sprintf("^(%s|%s)$", regexp.QuoteMeta(c.ssokenizerURL.Hostname()), pattern))
return []tokenizer.RequestValidator{tokenizer.AllowHostPattern(re)}
}

// Specifies what authentication clients should be required to present to
Expand Down Expand Up @@ -116,27 +190,6 @@ func NewConfig() Config {
return config
}

// Validate returns an error if the config is invalid.
func (c *Config) Validate() error {
tac, err := c.SecretAuth.tokenizerAuthConfig()
if err != nil {
return err
}

if c.SealKey == "" {
return errors.New("missing seal_key")
}
if c.HTTP.Address == "" {
return errors.New("missing http.address")
}
for _, pc := range c.IdentityProviders {
if err := pc.Validate(c.ReturnURL == "", tac == nil); err != nil {
return err
}
}
return nil
}

type LogConfig struct {
Debug bool `yaml:"debug"`
}
Expand Down Expand Up @@ -171,151 +224,114 @@ type IdentityProviderConfig struct {
SecretAuth SecretAuthConfig `yaml:"secret_auth"`
}

func (c IdentityProviderConfig) providerConfig(name, returnURL string) (ssokenizer.ProviderConfig, error) {
switch c.Profile {
case "vanta":
return &vanta.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: xoauth2.Endpoint{
AuthURL: "https://app.vanta.com/oauth/authorize",
TokenURL: "https://api.vanta.com/oauth/token",
AuthStyle: xoauth2.AuthStyleInParams,
},
func (ic *IdentityProviderConfig) provider(name string, c *Config) (ssokenizer.Provider, error) {
switch {
case ic.ClientID == "":
return nil, errors.New("missing client_id")
case ic.ClientSecret == "":
return nil, errors.New("missing client_secret")
}

op := oauth2.Provider{
ProviderConfig: ssokenizer.ProviderConfig{
Tokenizer: ssokenizer.TokenizerConfig{
SealKey: c.SealKey,
},
ForwardParams: []string{"source_id"},
}, nil
URL: *c.ssokenizerURL.JoinPath("/" + name),
},
OAuthConfig: xoauth2.Config{
ClientID: ic.ClientID,
ClientSecret: ic.ClientSecret,
Scopes: ic.Scopes,
},
}

switch tac, err := ic.SecretAuth.tokenizerAuthConfig(); {
case err != nil:
return nil, err
case tac == nil && c.globalTAC == nil:
return nil, errors.New("missing secret_auth")
case tac == nil:
op.ProviderConfig.Tokenizer.Auth = c.globalTAC
default:
op.ProviderConfig.Tokenizer.Auth = tac
}

switch {
case ic.ReturnURL == "" && c.globalReturnURL == nil:
return nil, errors.New("missing return_url")
case ic.ReturnURL == "":
op.ProviderConfig.ReturnURL = *c.globalReturnURL
default:
switch u, err := url.Parse(ic.ReturnURL); {
case err != nil:
return nil, fmt.Errorf("invalid return_url: %w", err)
case u.Scheme == "" || u.Host == "":
return nil, fmt.Errorf("malformed return_url: %q", ic.ReturnURL)
default:
op.ProviderConfig.ReturnURL = *u
}
}

switch ic.Profile {
case "vanta":
op.OAuthConfig.Endpoint = xoauth2.Endpoint{
AuthURL: "https://app.vanta.com/oauth/authorize",
TokenURL: "https://api.vanta.com/oauth/token",
AuthStyle: xoauth2.AuthStyleInParams,
}

op.ForwardParams = []string{"source_id"}

return &vanta.Provider{Provider: op}, nil
case "oauth":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: xoauth2.Endpoint{
AuthURL: c.AuthURL,
TokenURL: c.TokenURL,
},
},
}, nil
switch {
case ic.AuthURL == "":
return nil, errors.New("missing auth_url")
case ic.TokenURL == "":
return nil, errors.New("missing token_url")
}

op.OAuthConfig.Endpoint = xoauth2.Endpoint{
AuthURL: ic.AuthURL,
TokenURL: ic.TokenURL,
}

return &op, nil
case "amazon":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: amazon.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = amazon.Endpoint
return &op, nil
case "bitbucket":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: bitbucket.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = bitbucket.Endpoint
return &op, nil
case "facebook":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: facebook.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = facebook.Endpoint
return &op, nil
case "github":
return &oauth2.Config{
Path: "/" + name,
AllowedHostPattern: `api\.github\.com`,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: github.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = github.Endpoint
op.Tokenizer.RequestValidators = c.tokenizerHostValidator(`api\.github\.com`)
return &op, nil
case "gitlab":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: gitlab.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = gitlab.Endpoint
return &op, nil
case "google":
return &oauth2.Config{
Path: "/" + name,
AllowedHostPattern: `.*\.googleapis\.com`,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: google.Endpoint,
},
ForwardParams: []string{"hd"},
}, nil
op.OAuthConfig.Endpoint = google.Endpoint
op.Tokenizer.RequestValidators = c.tokenizerHostValidator(`.*\.googleapis\.com`)
op.ForwardParams = []string{"hd"}
return &op, nil
case "heroku":
return &oauth2.Config{
Path: "/" + name,
AllowedHostPattern: `api\.heroku\.com`,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: heroku.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = heroku.Endpoint
op.Tokenizer.RequestValidators = c.tokenizerHostValidator(`api\.heroku\.com`)
return &op, nil
case "microsoft":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: microsoft.LiveConnectEndpoint,
},
}, nil
op.OAuthConfig.Endpoint = microsoft.LiveConnectEndpoint
return &op, nil
case "slack":
return &oauth2.Config{
Path: "/" + name,
Config: xoauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Scopes: c.Scopes,
Endpoint: slack.Endpoint,
},
}, nil
op.OAuthConfig.Endpoint = slack.Endpoint
return &op, nil
default:
return nil, fmt.Errorf("unknown identity provider profile: %s", c.Profile)
}
}

func (c IdentityProviderConfig) Validate(needsReturnURL, needsProxyAuthorization bool) error {
if c.Profile == "" {
return errors.New("missing identity_providers.profile")
}
if c.ReturnURL == "" && needsReturnURL {
return errors.New("missing return_url or identity_providers.return_url")
return nil, errors.New("unknown identity provider profile")
}

switch tac, err := c.SecretAuth.tokenizerAuthConfig(); {
case err != nil:
return err
case tac == nil && needsProxyAuthorization:
return errors.New("missing secret_auth or identity_providers.secret_auth")
}

return nil
}

// UnmarshalConfig unmarshals config from data. Expands variables as needed.
Expand Down
Loading