Skip to content

Commit 7f36eb0

Browse files
authored
feat(oauth): add support for X/Twitter v2 provider (#2275)
Adds support for X (formerly Twitter) v2 as an external OAuth provider. - Introduces an `oauth_client_states` table to persist the `code_verifier`s as X mandates the use of PKCE - Uses SHA256 challenge for PKCE - Updates the `GetOAuthToken` signature to accept a `code_verifier` as the second parameter - Uses the existing cleanup middleware to delete states - Adds the provider as `x` rather than `x_v2` as `twitter` is already used in old OAuth 1.0a provider to better align with the rebrand - The state is a UUIDv4 NOTE: today the `flow_states` table is overloaded, containing states, auth codes, provider tokens...the goal is to decouple that table eventually and the `oauth_states` table is the first step towards that.
1 parent 511c3a4 commit 7f36eb0

37 files changed

+663
-85
lines changed

hack/test.env

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ GOTRUE_EXTERNAL_TWITTER_ENABLED=true
105105
GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid
106106
GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret
107107
GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback
108+
GOTRUE_EXTERNAL_X_ENABLED=true
109+
GOTRUE_EXTERNAL_X_CLIENT_ID=testclientid
110+
GOTRUE_EXTERNAL_X_SECRET=testsecret
111+
GOTRUE_EXTERNAL_X_REDIRECT_URI=https://identity.services.netlify.com/callback
108112
GOTRUE_EXTERNAL_ZOOM_ENABLED=true
109113
GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid
110114
GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret

internal/api/apierrors/errorcode.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ const (
2323
ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used"
2424
ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found"
2525
ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired"
26+
ErrorCodeOAuthClientStateNotFound ErrorCode = "oauth_client_state_not_found"
27+
ErrorCodeOAuthClientStateExpired ErrorCode = "oauth_client_state_expired"
28+
ErrorCodeOAuthInvalidState ErrorCode = "oauth_invalid_state"
2629
ErrorCodeSignupDisabled ErrorCode = "signup_disabled"
2730
ErrorCodeUserBanned ErrorCode = "user_banned"
2831
ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification"

internal/api/context.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"net/url"
66

7+
"github.com/gofrs/uuid"
78
jwt "github.com/golang-jwt/jwt/v5"
89
"github.com/supabase/auth/internal/api/shared"
910
"github.com/supabase/auth/internal/models"
@@ -33,6 +34,7 @@ const (
3334
ssoProviderKey = contextKey("sso_provider")
3435
externalHostKey = contextKey("external_host")
3536
flowStateKey = contextKey("flow_state_id")
37+
oauthClientStateKey = contextKey("oauth_client_state_id")
3638
)
3739

3840
// withToken adds the JWT token to the context.
@@ -137,6 +139,19 @@ func getFlowStateID(ctx context.Context) string {
137139
return obj.(string)
138140
}
139141

142+
func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context {
143+
return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID)
144+
}
145+
146+
func getOAuthClientStateID(ctx context.Context) uuid.UUID {
147+
obj := ctx.Value(oauthClientStateKey)
148+
if obj == nil {
149+
return uuid.Nil
150+
}
151+
152+
return obj.(uuid.UUID)
153+
}
154+
140155
func getInviteToken(ctx context.Context) string {
141156
obj := ctx.Value(inviteTokenKey)
142157
if obj == nil {

internal/api/external.go

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ import (
2727
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
2828
type ExternalProviderClaims struct {
2929
AuthMicroserviceClaims
30-
Provider string `json:"provider"`
31-
InviteToken string `json:"invite_token,omitempty"`
32-
Referrer string `json:"referrer,omitempty"`
33-
FlowStateID string `json:"flow_state_id"`
34-
LinkingTargetID string `json:"linking_target_id,omitempty"`
35-
EmailOptional bool `json:"email_optional,omitempty"`
30+
Provider string `json:"provider"`
31+
InviteToken string `json:"invite_token,omitempty"`
32+
Referrer string `json:"referrer,omitempty"`
33+
FlowStateID string `json:"flow_state_id"`
34+
OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
35+
LinkingTargetID string `json:"linking_target_id,omitempty"`
36+
EmailOptional bool `json:"email_optional,omitempty"`
3637
}
3738

3839
// ExternalProviderRedirect redirects the request to the oauth provider
@@ -90,6 +91,32 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
9091
flowStateID = flowState.ID.String()
9192
}
9293

94+
authUrlParams := make([]oauth2.AuthCodeOption, 0)
95+
query.Del("scopes")
96+
query.Del("provider")
97+
query.Del("code_challenge")
98+
query.Del("code_challenge_method")
99+
for key := range query {
100+
if key == "workos_provider" {
101+
// See https://workos.com/docs/reference/sso/authorize/get
102+
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
103+
} else {
104+
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
105+
}
106+
}
107+
108+
oauthClientStateID := ""
109+
if oauthProvider, ok := p.(provider.OAuthProvider); ok && oauthProvider.RequiresPKCE() {
110+
codeVerifier := oauth2.GenerateVerifier()
111+
oauthClientState := models.NewOAuthClientState(providerType, &codeVerifier)
112+
err := db.Create(oauthClientState)
113+
if err != nil {
114+
return "", err
115+
}
116+
oauthClientStateID = oauthClientState.ID.String()
117+
authUrlParams = append(authUrlParams, oauth2.S256ChallengeOption(codeVerifier))
118+
}
119+
93120
claims := ExternalProviderClaims{
94121
AuthMicroserviceClaims: AuthMicroserviceClaims{
95122
RegisteredClaims: jwt.RegisteredClaims{
@@ -98,11 +125,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
98125
SiteURL: config.SiteURL,
99126
InstanceID: uuid.Nil.String(),
100127
},
101-
Provider: providerType,
102-
InviteToken: inviteToken,
103-
Referrer: redirectURL,
104-
FlowStateID: flowStateID,
105-
EmailOptional: pConfig.EmailOptional,
128+
Provider: providerType,
129+
InviteToken: inviteToken,
130+
Referrer: redirectURL,
131+
FlowStateID: flowStateID,
132+
OAuthClientStateID: oauthClientStateID,
133+
EmailOptional: pConfig.EmailOptional,
106134
}
107135

108136
if linkingTargetUser != nil {
@@ -115,20 +143,6 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
115143
return "", apierrors.NewInternalServerError("Error creating state").WithInternalError(err)
116144
}
117145

118-
authUrlParams := make([]oauth2.AuthCodeOption, 0)
119-
query.Del("scopes")
120-
query.Del("provider")
121-
query.Del("code_challenge")
122-
query.Del("code_challenge_method")
123-
for key := range query {
124-
if key == "workos_provider" {
125-
// See https://workos.com/docs/reference/sso/authorize/get
126-
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
127-
} else {
128-
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
129-
}
130-
}
131-
132146
authURL := p.AuthCodeURL(tokenString, authUrlParams...)
133147

134148
return authURL, nil
@@ -565,6 +579,13 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
565579
if claims.FlowStateID != "" {
566580
ctx = withFlowStateID(ctx, claims.FlowStateID)
567581
}
582+
if claims.OAuthClientStateID != "" {
583+
oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID)
584+
if err != nil {
585+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)")
586+
}
587+
ctx = withOAuthClientStateID(ctx, oauthClientStateID)
588+
}
568589
if claims.LinkingTargetID != "" {
569590
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
570591
if err != nil {
@@ -634,7 +655,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
634655
p, err = provider.NewLinkedinProvider(pConfig, scopes)
635656
case "linkedin_oidc":
636657
pConfig = config.External.LinkedinOIDC
637-
p, err = provider.NewLinkedinOIDCProvider(pConfig, scopes)
658+
p, err = provider.NewLinkedinOIDCProvider(ctx, pConfig, scopes)
638659
case "notion":
639660
pConfig = config.External.Notion
640661
p, err = provider.NewNotionProvider(pConfig)
@@ -656,9 +677,12 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
656677
case "twitter":
657678
pConfig = config.External.Twitter
658679
p, err = provider.NewTwitterProvider(pConfig, scopes)
680+
case "x":
681+
pConfig = config.External.X
682+
p, err = provider.NewXProvider(pConfig, scopes)
659683
case "vercel_marketplace":
660684
pConfig = config.External.VercelMarketplace
661-
p, err = provider.NewVercelMarketplaceProvider(pConfig, scopes)
685+
p, err = provider.NewVercelMarketplaceProvider(ctx, pConfig, scopes)
662686
case "workos":
663687
pConfig = config.External.WorkOS
664688
p, err = provider.NewWorkOSProvider(pConfig)

internal/api/external_oauth.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@ import (
66
"net/http"
77
"net/url"
88

9+
"github.com/gofrs/uuid"
910
"github.com/mrjones/oauth"
1011
"github.com/sirupsen/logrus"
1112
"github.com/supabase/auth/internal/api/apierrors"
1213
"github.com/supabase/auth/internal/api/provider"
1314
"github.com/supabase/auth/internal/conf"
15+
"github.com/supabase/auth/internal/models"
1416
"github.com/supabase/auth/internal/observability"
1517
"github.com/supabase/auth/internal/utilities"
18+
"golang.org/x/oauth2"
1619
)
1720

1821
// OAuthProviderData contains the userData and token returned by the oauth provider
@@ -55,6 +58,8 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
5558
}
5659

5760
func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
61+
db := a.db.WithContext(ctx)
62+
5863
var rq url.Values
5964
if err := r.ParseForm(); r.Method == http.MethodPost && err == nil {
6065
rq = r.Form
@@ -72,28 +77,56 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
7277
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing")
7378
}
7479

75-
oAuthProvider, _, err := a.OAuthProvider(ctx, providerType)
80+
oauthProvider, _, err := a.OAuthProvider(ctx, providerType)
7681
if err != nil {
7782
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
7883
}
7984

8085
log := observability.GetLogEntry(r).Entry
86+
87+
var oauthClientState *models.OAuthClientState
88+
// if there's a non-empty OAuthClientStateID we perform PKCE Flow for the external provider
89+
if oauthClientStateID := getOAuthClientStateID(ctx); oauthClientStateID != uuid.Nil {
90+
oauthClientState, err = models.FindAndDeleteOAuthClientStateByID(db, oauthClientStateID)
91+
if models.IsNotFoundError(err) {
92+
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateNotFound, "OAuth state not found").WithInternalError(err)
93+
} else if err != nil {
94+
return nil, apierrors.NewInternalServerError("Failed to find OAuth state").WithInternalError(err)
95+
}
96+
97+
if oauthClientState.ProviderType != providerType {
98+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth provider mismatch")
99+
}
100+
101+
if oauthClientState.IsExpired() {
102+
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateExpired, "OAuth state expired")
103+
}
104+
}
105+
106+
if oauthProvider.RequiresPKCE() && oauthClientState == nil {
107+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth PKCE code verifier missing")
108+
}
109+
81110
log.WithFields(logrus.Fields{
82111
"provider": providerType,
83112
"code": oauthCode,
84-
}).Debug("Exchanging oauth code")
113+
}).Debug("Exchanging OAuth code")
85114

86-
token, err := oAuthProvider.GetOAuthToken(oauthCode)
115+
var tokenOpts []oauth2.AuthCodeOption
116+
if oauthClientState != nil {
117+
tokenOpts = append(tokenOpts, oauth2.VerifierOption(*oauthClientState.CodeVerifier))
118+
}
119+
token, err := oauthProvider.GetOAuthToken(ctx, oauthCode, tokenOpts...)
87120
if err != nil {
88121
return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
89122
}
90123

91-
userData, err := oAuthProvider.GetUserData(ctx, token)
124+
userData, err := oauthProvider.GetUserData(ctx, token)
92125
if err != nil {
93126
return nil, apierrors.NewInternalServerError("Error getting user profile from external provider").WithInternalError(err)
94127
}
95128

96-
switch externalProvider := oAuthProvider.(type) {
129+
switch externalProvider := oauthProvider.(type) {
97130
case *provider.AppleProvider:
98131
// apple only returns user info the first time
99132
oauthUser := rq.Get("user")

0 commit comments

Comments
 (0)