diff --git a/handlers.go b/handlers.go index 0ad1735..5468a56 100644 --- a/handlers.go +++ b/handlers.go @@ -3,6 +3,7 @@ package scim import ( "encoding/json" "net/http" + "strings" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/schema" @@ -334,6 +335,132 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso } } +// rootResourcesGetHandler receives an HTTP GET request to the server root endpoint to query across all resource types. +func (s Server) rootResourcesGetHandler(w http.ResponseWriter, r *http.Request) { + count, startIndex, scimErr := s.parsePaginationParams(r) + if scimErr != nil { + s.errorHandler(w, scimErr) + return + } + + params := ListRequestParams{ + Count: count, + Filter: strings.TrimSpace(r.URL.Query().Get("filter")), + StartIndex: startIndex, + } + + page, getError := s.rootQueryHandler.GetAll(r, params) + if getError != nil { + scimErr := errors.CheckScimError(getError, http.MethodGet) + s.errorHandler(w, &scimErr) + return + } + + lr := listResponse{ + TotalResults: page.TotalResults, + Resources: page.rawResources(), + StartIndex: params.StartIndex, + ItemsPerPage: params.Count, + } + raw, err := json.Marshal(lr) + if err != nil { + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling list response", + "listResponse", lr, + "error", err, + ) + return + } + + _, err = w.Write(raw) + if err != nil { + s.log.Error( + "failed writing response", + "error", err, + ) + } +} + +// rootSearchHandler receives an HTTP POST request to /.search to query across all resource types. +// Per RFC 7644 Section 3.4.3, this is an alternative to GET / with query parameters. +func (s Server) rootSearchHandler(w http.ResponseWriter, r *http.Request) { + data, err := readBody(r) + if err != nil { + s.errorHandler(w, &errors.ScimErrorInternal) + return + } + + var sr searchRequest + if err := json.Unmarshal(data, &sr); err != nil { + scimErr := errors.ScimError{ + Status: http.StatusBadRequest, + Detail: "Invalid search request body.", + } + s.errorHandler(w, &scimErr) + return + } + + defaultCount := s.config.getItemsPerPage() + + count := defaultCount + if sr.Count != nil { + count = *sr.Count + } + if count > defaultCount { + count = defaultCount + } + if count < 0 { + count = 0 + } + + startIndex := defaultStartIndex + if sr.StartIndex != nil { + startIndex = *sr.StartIndex + } + if startIndex < 1 { + startIndex = defaultStartIndex + } + + params := ListRequestParams{ + Count: count, + Filter: sr.Filter, + StartIndex: startIndex, + } + + page, getError := s.rootQueryHandler.GetAll(r, params) + if getError != nil { + scimErr := errors.CheckScimError(getError, http.MethodPost) + s.errorHandler(w, &scimErr) + return + } + + lr := listResponse{ + TotalResults: page.TotalResults, + Resources: page.rawResources(), + StartIndex: params.StartIndex, + ItemsPerPage: params.Count, + } + raw, err := json.Marshal(lr) + if err != nil { + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling list response", + "listResponse", lr, + "error", err, + ) + return + } + + _, err = w.Write(raw) + if err != nil { + s.log.Error( + "failed writing response", + "error", err, + ) + } +} + // schemaHandler receives an HTTP GET to retrieve individual schema definitions which can be returned by appending the // schema URI to the /Schemas endpoint. For example: "/Schemas/urn:ietf:params:scim:schemas:core:2.0:User". func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) { @@ -440,3 +567,11 @@ func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Requ ) } } + +// searchRequest represents the JSON body of a POST /.search request per RFC 7644 Section 3.4.3. +type searchRequest struct { + Schemas []string `json:"schemas"` + Filter string `json:"filter"` + StartIndex *int `json:"startIndex"` + Count *int `json:"count"` +} diff --git a/handlers_test.go b/handlers_test.go index f040a1c..ba93a0b 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/optional" @@ -913,6 +914,28 @@ func TestServerResourcesGetHandler(t *testing.T) { assertEqual(t, 20, len(response.Resources)) } +func TestServerResourcesGetHandlerFilterOnCommonAttribute(t *testing.T) { + tests := []struct { + name string + filter string + }{ + {name: "meta.lastModified", filter: `meta.lastModified gt "2011-05-13T04:42:34Z"`}, + {name: "meta.resourceType", filter: `meta.resourceType eq "User"`}, + {name: "id", filter: `id eq "0001"`}, + {name: "externalId", filter: `externalId eq "external1"`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := fmt.Sprintf("/Users?filter=%s", url.QueryEscape(tt.filter)) + req := httptest.NewRequest(http.MethodGet, target, nil) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + }) + } +} + func TestServerResourcesGetHandlerMaxCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=20000", nil) rr := httptest.NewRecorder() @@ -960,6 +983,452 @@ func TestServerResourcesGetHandlerWithBaseURL(t *testing.T) { } } +func TestServerRootQuery(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 3, len(response.Resources)) +} + +func TestServerRootQueryExplicitStatusCode(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + w := &statusRecordingResponseWriter{ResponseWriter: rr} + s.ServeHTTP(w, req) + + if !w.calledWriteHeader { + t.Error("handler did not explicitly call WriteHeader") + } + assertEqualStatusCode(t, http.StatusOK, w.status) +} + +func TestServerRootQueryFilter(t *testing.T) { + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + }, + WithRootQueryHandler(testRootQueryHandlerCapture{}), + ) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, `/?filter=meta.resourceType+eq+"User"`, nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 1, response.TotalResults) + assertEqual(t, 1, len(response.Resources)) + + resource := response.Resources[0].(map[string]interface{}) + assertEqual(t, `meta.resourceType eq "User"`, resource["capturedFilter"]) +} + +func TestServerRootQueryHandlerError(t *testing.T) { + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + }, + WithRootQueryHandler(testRootQueryHandlerError{}), + ) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusInternalServerError, rr.Code) +} + +func TestServerRootQueryInjectsResourceFields(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + + // First resource has ID and ExternalID set on the Resource struct. + first := response.Resources[0].(map[string]interface{}) + assertEqual(t, "u1", first["id"]) + assertEqual(t, "ext-u1", first["externalId"]) + + // Second resource has ID but no ExternalID. + second := response.Resources[1].(map[string]interface{}) + assertEqual(t, "u2", second["id"]) + if _, ok := second["externalId"]; ok { + t.Error("externalId should not be present when not set on Resource") + } + + // The caller-provided "meta" map (with "resourceType") is preserved. + firstMeta := first["meta"].(map[string]interface{}) + assertEqual(t, "User", firstMeta["resourceType"]) +} + +func TestServerRootQueryInvalidCount(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/?count=BadBanana", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) +} + +func TestServerRootQueryMergesMetaFields(t *testing.T) { + created := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + modified := time.Date(2024, 6, 20, 14, 0, 0, 0, time.UTC) + + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + }, + WithRootQueryHandler(testRootQueryHandlerWithMeta{ + created: &created, + modified: &modified, + version: `W/"abc123"`, + }), + ) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + + resource := response.Resources[0].(map[string]interface{}) + assertEqual(t, "r1", resource["id"]) + + // Meta fields from Resource.Meta are merged with the caller-provided "meta" map. + m := resource["meta"].(map[string]interface{}) + assertEqual(t, "User", m["resourceType"]) + assertEqual(t, created.Format(time.RFC3339), m["created"]) + assertEqual(t, modified.Format(time.RFC3339), m["lastModified"]) + assertEqual(t, `W/"abc123"`, m["version"]) +} + +func TestServerRootQueryNonGetMethod(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + for _, method := range []string{http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} { + t.Run(method, func(t *testing.T) { + req := httptest.NewRequest(method, "/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusNotFound, rr.Code) + }) + } +} + +func TestServerRootQueryNotConfigured(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) + + var scimErr errors.ScimError + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &scimErr)) + assertEqual(t, errors.ScimTypeTooMany, scimErr.ScimType) +} + +func TestServerRootQueryNotConfiguredWithV2Prefix(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v2", nil) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) + + var scimErr errors.ScimError + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &scimErr)) + assertEqual(t, errors.ScimTypeTooMany, scimErr.ScimType) +} + +func TestServerRootQueryNotConfiguredWithV2PrefixTrailingSlash(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) + + var scimErr errors.ScimError + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &scimErr)) + assertEqual(t, errors.ScimTypeTooMany, scimErr.ScimType) +} + +func TestServerRootQueryPagination(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/?count=1&startIndex=2", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 1, len(response.Resources)) + assertEqual(t, 2, response.StartIndex) + assertEqual(t, 1, response.ItemsPerPage) +} + +func TestServerRootQueryWithV2Prefix(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/v2", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) +} + +func TestServerRootQueryWithV2PrefixTrailingSlash(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) +} + +func TestServerRootSearch(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{"schemas":["urn:ietf:params:scim:api:messages:2.0:SearchRequest"]}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 3, len(response.Resources)) +} + +func TestServerRootSearchCountExceedsMax(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{"count":999999}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 3, len(response.Resources)) +} + +func TestServerRootSearchDefaultParams(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 3, len(response.Resources)) + assertEqual(t, 1, response.StartIndex) +} + +func TestServerRootSearchExplicitStatusCode(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + w := &statusRecordingResponseWriter{ResponseWriter: rr} + s.ServeHTTP(w, req) + + if !w.calledWriteHeader { + t.Error("handler did not explicitly call WriteHeader") + } + assertEqualStatusCode(t, http.StatusOK, w.status) +} + +func TestServerRootSearchFilter(t *testing.T) { + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + }, + WithRootQueryHandler(testRootQueryHandlerCapture{}), + ) + if err != nil { + t.Fatal(err) + } + + body := strings.NewReader(`{"schemas":["urn:ietf:params:scim:api:messages:2.0:SearchRequest"],"filter":"meta.resourceType eq \"User\""}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 1, response.TotalResults) + + resource := response.Resources[0].(map[string]interface{}) + assertEqual(t, `meta.resourceType eq "User"`, resource["capturedFilter"]) +} + +func TestServerRootSearchHandlerError(t *testing.T) { + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + }, + WithRootQueryHandler(testRootQueryHandlerError{}), + ) + if err != nil { + t.Fatal(err) + } + + body := strings.NewReader(`{}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusInternalServerError, rr.Code) +} + +func TestServerRootSearchInvalidBody(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`not json`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) +} + +func TestServerRootSearchNegativeCount(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{"count":-5}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 0, len(response.Resources)) + assertEqual(t, 0, response.ItemsPerPage) +} + +func TestServerRootSearchNotConfigured(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/.search", strings.NewReader(`{}`)) + rr := httptest.NewRecorder() + newTestServer(t).ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) + + var scimErr errors.ScimError + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &scimErr)) + assertEqual(t, errors.ScimTypeTooMany, scimErr.ScimType) +} + +func TestServerRootSearchPagination(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{"schemas":["urn:ietf:params:scim:api:messages:2.0:SearchRequest"],"startIndex":2,"count":1}`) + req := httptest.NewRequest(http.MethodPost, "/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) + assertEqual(t, 1, len(response.Resources)) + assertEqual(t, 2, response.StartIndex) + assertEqual(t, 1, response.ItemsPerPage) +} + +func TestServerRootSearchWithV2Prefix(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + body := strings.NewReader(`{"schemas":["urn:ietf:params:scim:api:messages:2.0:SearchRequest"]}`) + req := httptest.NewRequest(http.MethodPost, "/v2/.search", body) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusOK, rr.Code) + + var response listResponse + assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &response)) + assertEqual(t, 3, response.TotalResults) +} + +func TestServerRootSearchWrongMethod(t *testing.T) { + s := newTestServerWithRootQueryHandler(t) + + for _, method := range []string{http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodDelete} { + t.Run(method, func(t *testing.T) { + req := httptest.NewRequest(method, "/.search", nil) + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + assertEqualStatusCode(t, http.StatusNotFound, rr.Code) + }) + } +} + func TestServerSchemaEndpointValid(t *testing.T) { tests := []struct { name string @@ -1262,6 +1731,30 @@ func newTestServerWithBaseURL(t *testing.T) Server { return s } +func newTestServerWithRootQueryHandler(t *testing.T) Server { + userSchema := getUserSchema() + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: userSchema, + Handler: newTestResourceHandler(), + }, + }, + }, + WithRootQueryHandler(testRootQueryHandler{}), + ) + if err != nil { + t.Fatal(err) + } + return s +} + // statusRecordingResponseWriter wraps an http.ResponseWriter and records // whether WriteHeader was called explicitly, simulating logging middleware. type statusRecordingResponseWriter struct { @@ -1275,3 +1768,66 @@ func (w *statusRecordingResponseWriter) WriteHeader(status int) { w.status = status w.ResponseWriter.WriteHeader(status) } + +type testRootQueryHandler struct{} + +func (h testRootQueryHandler) GetAll(r *http.Request, params ListRequestParams) (Page, error) { + resources := []Resource{ + {ID: "u1", ExternalID: optional.NewString("ext-u1"), Attributes: ResourceAttributes{"userName": "alice", "meta": map[string]interface{}{"resourceType": "User"}}}, + {ID: "u2", Attributes: ResourceAttributes{"userName": "bob", "meta": map[string]interface{}{"resourceType": "User"}}}, + {ID: "g1", Attributes: ResourceAttributes{"displayName": "admins", "meta": map[string]interface{}{"resourceType": "Group"}}}, + } + + start := params.StartIndex - 1 + if start > len(resources) { + start = len(resources) + } + end := start + params.Count + if end > len(resources) { + end = len(resources) + } + + return Page{ + TotalResults: len(resources), + Resources: resources[start:end], + }, nil +} + +type testRootQueryHandlerCapture struct{} + +func (h testRootQueryHandlerCapture) GetAll(r *http.Request, params ListRequestParams) (Page, error) { + return Page{ + TotalResults: 1, + Resources: []Resource{ + {ID: "1", Attributes: ResourceAttributes{"capturedFilter": params.Filter}}, + }, + }, nil +} + +type testRootQueryHandlerError struct{} + +func (h testRootQueryHandlerError) GetAll(r *http.Request, params ListRequestParams) (Page, error) { + return Page{}, errors.ScimError{ + Status: http.StatusInternalServerError, + Detail: "something went wrong", + } +} + +type testRootQueryHandlerWithMeta struct { + created *time.Time + modified *time.Time + version string +} + +func (h testRootQueryHandlerWithMeta) GetAll(r *http.Request, params ListRequestParams) (Page, error) { + return Page{ + TotalResults: 1, + Resources: []Resource{ + { + ID: "r1", + Meta: Meta{Created: h.created, LastModified: h.modified, Version: h.version}, + Attributes: ResourceAttributes{"meta": map[string]interface{}{"resourceType": "User"}}, + }, + }, + }, nil +} diff --git a/list_response.go b/list_response.go index 99828ee..821552e 100644 --- a/list_response.go +++ b/list_response.go @@ -2,6 +2,9 @@ package scim import ( "encoding/json" + "time" + + "github.com/elimity-com/scim/schema" ) // Page represents a paginated resource query response. @@ -12,6 +15,83 @@ type Page struct { Resources []Resource } +// rawResources returns resources as raw interface values for root queries (GET /). +// +// Unlike the resource-type-specific resources() method, rawResources does NOT have access to a ResourceType and +// therefore cannot inject: +// - "schemas": requires knowing the resource type's schema URIs. +// - "meta.resourceType": requires the resource type name. +// - "meta.location": requires the resource type endpoint. +// +// The caller (RootQueryHandler) is responsible for including these fields in the Resource's Attributes map. +// +// rawResources DOES inject the following fields from the Resource struct into each resource's attributes: +// - "id": from Resource.ID (always required per RFC 7643 Section 3.1). +// - "externalId": from Resource.ExternalID, when present. +// - "meta.created": from Resource.Meta.Created, when non-nil. +// - "meta.lastModified": from Resource.Meta.LastModified, when non-nil. +// - "meta.version": from Resource.Meta.Version, when non-empty. +// +// These meta fields are merged into any existing "meta" map in Attributes. If the caller already provides a "meta" +// map (e.g. with "resourceType"), the injected fields are added alongside it without overwriting existing keys. +func (p Page) rawResources() []interface{} { + if len(p.Resources) == 0 { + if p.Resources != nil { + return []interface{}{} + } + return nil + } + + var resources []interface{} + for _, v := range p.Resources { + attrs := v.Attributes + if attrs == nil { + attrs = ResourceAttributes{} + } + + attrs[schema.CommonAttributeID] = v.ID + if v.ExternalID.Present() { + attrs[schema.CommonAttributeExternalID] = v.ExternalID.Value() + } + + // Merge Meta fields into the existing "meta" map if present, or create a new one. + var metaMap map[string]interface{} + if existing, ok := attrs[schema.CommonAttributeMeta]; ok { + if m, ok := existing.(map[string]interface{}); ok { + metaMap = m + } + } + hasMeta := false + if v.Meta.Created != nil { + if metaMap == nil { + metaMap = map[string]interface{}{} + } + metaMap["created"] = v.Meta.Created.Format(time.RFC3339) + hasMeta = true + } + if v.Meta.LastModified != nil { + if metaMap == nil { + metaMap = map[string]interface{}{} + } + metaMap["lastModified"] = v.Meta.LastModified.Format(time.RFC3339) + hasMeta = true + } + if len(v.Meta.Version) != 0 { + if metaMap == nil { + metaMap = map[string]interface{}{} + } + metaMap["version"] = v.Meta.Version + hasMeta = true + } + if hasMeta { + attrs[schema.CommonAttributeMeta] = metaMap + } + + resources = append(resources, attrs) + } + return resources +} + func (p Page) resources(resourceType ResourceType, baseURL string) []interface{} { // If the page.Resources is nil, then it will also be represented as a `null` in the response. // Otherwise is it is an empty slice then it will result in an empty array `[]`. diff --git a/resource_handler.go b/resource_handler.go index b5aa131..adbc678 100644 --- a/resource_handler.go +++ b/resource_handler.go @@ -15,8 +15,15 @@ type ListRequestParams struct { // A value of "0" indicates that no resource results are to be returned except for "totalResults". Count int - // Filter represents the parsed and tokenized filter query parameter. + // Filter is the raw filter expression string. For resource-type-specific queries, the filter + // is also parsed and available via FilterValidator. For root queries (RootQueryHandler), + // only this raw string is provided since filter validation requires a known schema. + Filter string + + // FilterValidator represents the parsed and tokenized filter query parameter. // It is an optional parameter and thus will be nil when the parameter is not present. + // For root queries (RootQueryHandler), this is always nil since filter validation requires + // a known schema. FilterValidator *filter.Validator // StartIndex The 1-based index of the first query result. A value less than 1 SHALL be interpreted as 1. @@ -112,3 +119,51 @@ type ResourceHandler interface { // More information in Section 3.5.2 of RFC 7644: https://tools.ietf.org/html/rfc7644#section-3.5.2 Patch(r *http.Request, id string, operations []PatchOperation) (Resource, error) } + +// ResourceTypeFilter associates a resource type with a validated filter. +type ResourceTypeFilter struct { + // ResourceType is the resource type whose schema the filter validated against. + ResourceType ResourceType + // Validator is the filter validator for this resource type's schema. + Validator filter.Validator +} + +// ValidateFilterForResourceTypes validates a raw filter expression against each of the given resource types' schemas. +// It returns a ResourceTypeFilter for each resource type whose schema the filter is valid for. +// This is useful for RootQueryHandler implementations to determine which resource types a filter applies to. +// An empty result means the filter is not valid for any of the given resource types. +// A parse error in the filter expression results in an empty result. +func ValidateFilterForResourceTypes(rawFilter string, resourceTypes []ResourceType) []ResourceTypeFilter { + var results []ResourceTypeFilter + for _, rt := range resourceTypes { + s := rt.Schema + attrs := make([]schema.CoreAttribute, len(s.Attributes), len(s.Attributes)+len(schema.CommonAttributes())) + copy(attrs, s.Attributes) + s.Attributes = append(attrs, schema.CommonAttributes()...) + v, err := filter.NewValidator(rawFilter, s, rt.getSchemaExtensions()...) + if err != nil { + return nil + } + if err := v.Validate(); err != nil { + continue + } + results = append(results, ResourceTypeFilter{ + ResourceType: rt, + Validator: v, + }) + } + return results +} + +// RootQueryHandler represents an optional callback that handles queries against the server root endpoint (GET /). +// Per RFC 7644 Section 3.4.2.1, a query against the server root indicates that all resources within the server +// shall be included, subject to filtering. +// +// The server does not validate or parse the filter for root queries because there is no single target schema. +// ListRequestParams.FilterValidator will always be nil for root queries. The raw filter string can be +// obtained from the request via r.URL.Query().Get("filter"). The handler is responsible for interpreting +// the filter (e.g. meta.resourceType eq "User") as appropriate for its backing store. +type RootQueryHandler interface { + // GetAll returns a paginated list of resources across all resource types. + GetAll(r *http.Request, params ListRequestParams) (Page, error) +} diff --git a/resource_handler_test.go b/resource_handler_test.go index 17e557d..073867c 100644 --- a/resource_handler_test.go +++ b/resource_handler_test.go @@ -5,10 +5,12 @@ import ( "math/rand" "net/http" "strings" + "testing" "time" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/optional" + "github.com/elimity-com/scim/schema" ) func ExampleResourceHandler() { @@ -18,6 +20,79 @@ func ExampleResourceHandler() { // Output: true } +func TestValidateFilterForResourceTypes(t *testing.T) { + userSchema := getUserSchema() + groupSchema := schema.CoreGroupSchema() + + resourceTypes := []ResourceType{ + { + Name: "User", + Endpoint: "/Users", + Schema: userSchema, + }, + { + Name: "Group", + Endpoint: "/Groups", + Schema: groupSchema, + }, + } + + t.Run("filter matching only User", func(t *testing.T) { + results := ValidateFilterForResourceTypes(`userName eq "john"`, resourceTypes) + assertLen(t, results, 1) + assertEqual(t, "User", results[0].ResourceType.Name) + }) + + t.Run("filter matching only Group", func(t *testing.T) { + results := ValidateFilterForResourceTypes(`members.value eq "123"`, resourceTypes) + assertLen(t, results, 1) + assertEqual(t, "Group", results[0].ResourceType.Name) + }) + + t.Run("filter matching both", func(t *testing.T) { + results := ValidateFilterForResourceTypes(`displayName eq "test"`, resourceTypes) + assertLen(t, results, 2) + }) + + t.Run("meta.resourceType filter matches all", func(t *testing.T) { + results := ValidateFilterForResourceTypes(`meta.resourceType eq "User"`, resourceTypes) + assertLen(t, results, 2) + }) + + t.Run("unparseable filter", func(t *testing.T) { + results := ValidateFilterForResourceTypes(`not a valid ((( filter`, resourceTypes) + assertLen(t, results, 0) + }) + + t.Run("does not mutate original schema attributes", func(t *testing.T) { + // Create a schema with spare capacity so append can mutate the backing array. + commonAttrs := schema.CommonAttributes() + attrs := make([]schema.CoreAttribute, 1, 1+len(commonAttrs)) + attrs[0] = schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{ + Name: "userName", + })) + + // Extend into spare capacity to observe backing array writes. + full := attrs[:cap(attrs)] + + rt := []ResourceType{ + { + Name: "User", + Endpoint: "/Users", + Schema: schema.Schema{ + ID: "urn:ietf:params:scim:schemas:core:2.0:User", + Attributes: attrs, + }, + }, + } + + ValidateFilterForResourceTypes(`userName eq "john"`, rt) + + // If append mutated the backing array, full[1] now holds a common attribute. + assertEqual(t, "", full[1].Name()) + }) +} + type testData struct { resourceAttributes ResourceAttributes meta map[string]string diff --git a/resource_type.go b/resource_type.go index 2716c5b..a18a49d 100644 --- a/resource_type.go +++ b/resource_type.go @@ -82,7 +82,9 @@ func (t ResourceType) schemaWithCommon() schema.Schema { }), ) - s.Attributes = append(s.Attributes, externalID) + attrs := make([]schema.CoreAttribute, len(s.Attributes), len(s.Attributes)+1) + copy(attrs, s.Attributes) + s.Attributes = append(attrs, externalID) return s } diff --git a/server.go b/server.go index 66f8d88..f12335b 100644 --- a/server.go +++ b/server.go @@ -24,7 +24,11 @@ func getFilterValidator(r *http.Request, s schema.Schema, extensions ...schema.S return nil, nil // No filter present. } - validator, err := filter.NewValidator(f, s, extensions...) + withCommon := s + attrs := make([]schema.CoreAttribute, len(s.Attributes), len(s.Attributes)+len(schema.CommonAttributes())) + copy(attrs, s.Attributes) + withCommon.Attributes = append(attrs, schema.CommonAttributes()...) + validator, err := filter.NewValidator(f, withCommon, extensions...) if err != nil { return nil, err } @@ -68,10 +72,11 @@ func resourceLocation(resourceType ResourceType, id string, baseURL string) stri // Server represents a SCIM server which implements the HTTP-based SCIM protocol // that makes managing identities in multi-domain scenarios easier to support via a standardized service. type Server struct { - config ServiceProviderConfig - resourceTypes []ResourceType - log Logger - baseURL string + config ServiceProviderConfig + resourceTypes []ResourceType + rootQueryHandler RootQueryHandler + log Logger + baseURL string } func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) { @@ -108,6 +113,20 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/v2") switch { + case (path == "/" || path == "") && r.Method == http.MethodGet: + if s.rootQueryHandler == nil { + s.errorHandler(w, &errors.ScimErrorTooMany) + return + } + s.rootResourcesGetHandler(w, r) + return + case path == "/.search" && r.Method == http.MethodPost: + if s.rootQueryHandler == nil { + s.errorHandler(w, &errors.ScimErrorTooMany) + return + } + s.rootSearchHandler(w, r) + return case path == "/Me": s.errorHandler(w, &errors.ScimError{ Status: http.StatusNotImplemented, @@ -215,7 +234,7 @@ func (s Server) getSchemas() []schema.Schema { return schemas } -func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, refExtensions ...schema.Schema) (ListRequestParams, *errors.ScimError) { +func (s Server) parsePaginationParams(r *http.Request) (count, startIndex int, _ *errors.ScimError) { invalidParams := make([]string, 0) defaultCount := s.config.getItemsPerPage() @@ -224,11 +243,9 @@ func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, ref invalidParams = append(invalidParams, "count") } if count > defaultCount { - // Ensure the count isn't more then the allowable max. count = defaultCount } if count < 0 { - // A negative value shall be interpreted as 0. count = 0 } @@ -242,7 +259,16 @@ func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, ref if len(invalidParams) > 0 { scimErr := errors.ScimErrorBadParams(invalidParams) - return ListRequestParams{}, &scimErr + return 0, 0, &scimErr + } + + return count, startIndex, nil +} + +func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, refExtensions ...schema.Schema) (ListRequestParams, *errors.ScimError) { + count, startIndex, scimErr := s.parsePaginationParams(r) + if scimErr != nil { + return ListRequestParams{}, scimErr } validator, err := getFilterValidator(r, refSchema, refExtensions...) @@ -252,6 +278,7 @@ func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, ref return ListRequestParams{ Count: count, + Filter: strings.TrimSpace(r.URL.Query().Get("filter")), FilterValidator: validator, StartIndex: startIndex, }, nil @@ -282,6 +309,17 @@ func WithLogger(logger Logger) ServerOption { } } +// WithRootQueryHandler sets a handler for queries against the server root endpoint (GET /). +// Per RFC 7644 Section 3.4.2.1, a query against the server root indicates that all resources +// within the server shall be included, subject to filtering. +func WithRootQueryHandler(h RootQueryHandler) ServerOption { + return func(s *Server) { + if h != nil { + s.rootQueryHandler = h + } + } +} + // statusResponseWriter wraps http.ResponseWriter to ensure WriteHeader is // always called explicitly. If Write is called without a prior WriteHeader, // it defaults to http.StatusOK. Subsequent WriteHeader calls are ignored.