diff --git a/api/v1_grants.go b/api/v1_grants.go index 28a0c1e7..59ad4549 100644 --- a/api/v1_grants.go +++ b/api/v1_grants.go @@ -20,11 +20,11 @@ type createGrantBody struct { } type addManagerBody struct { - ManagerUserId string `json:"manager_user_id"` + ManagerUserId trashid.HashId `json:"manager_user_id" validate:"required,min=1"` } type approveGrantBody struct { - GrantorUserId string `json:"grantor_user_id"` + GrantorUserId trashid.HashId `json:"grantor_user_id" validate:"required,min=1"` } // postV1UsersGrant creates a grant from the user to an app (user authorizes app to act on their behalf) @@ -162,17 +162,13 @@ func (app *ApiServer) postV1UsersManager(c *fiber.Ctx) error { "error": "Invalid request body", }) } - managerUserID, err := trashid.DecodeHashId(body.ManagerUserId) - if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ - "error": "Invalid manager_user_id", - }) + if err := app.requestValidator.Validate(&body); err != nil { + return err } - // Get manager's wallet (grantee_address) users, err := app.queries.Users(c.Context(), dbv1.GetUsersParams{ MyID: 0, - Ids: []int32{int32(managerUserID)}, + Ids: []int32{int32(body.ManagerUserId)}, }) if err != nil || len(users) == 0 { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ @@ -316,20 +312,16 @@ func (app *ApiServer) postV1UsersApproveGrant(c *fiber.Ctx) error { "error": "Invalid request body", }) } - grantorUserID, err := trashid.DecodeHashId(body.GrantorUserId) - if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ - "error": "Invalid grantor_user_id", - }) + if err := app.requestValidator.Validate(&body); err != nil { + return err } - signer, err := app.getApiSigner(c) if err != nil { return err } nonce := time.Now().UnixNano() - metadata := map[string]interface{}{"grantor_user_id": int64(grantorUserID)} + metadata := map[string]interface{}{"grantor_user_id": int64(body.GrantorUserId)} metadataBytes, _ := json.Marshal(metadata) manageEntityTx := &corev1.ManageEntityLegacy{ diff --git a/api/v1_playlist.go b/api/v1_playlist.go index f9eab63d..a28129ad 100644 --- a/api/v1_playlist.go +++ b/api/v1_playlist.go @@ -16,9 +16,9 @@ import ( ) type PlaylistTrackInfo struct { - TrackId string `json:"track_id" validate:"required"` - Timestamp int64 `json:"timestamp" validate:"required,min=0"` - MetadataTimestamp *int64 `json:"metadata_timestamp,omitempty" validate:"omitempty,min=0"` + TrackId trashid.IntId `json:"track_id" validate:"required,min=1"` + Timestamp int64 `json:"timestamp" validate:"required,min=0"` + MetadataTimestamp *int64 `json:"metadata_timestamp,omitempty" validate:"omitempty,min=0"` } type CreatePlaylistRequest struct { diff --git a/trashid/hashid.go b/trashid/hashid.go index bd600b24..57a08884 100644 --- a/trashid/hashid.go +++ b/trashid/hashid.go @@ -60,6 +60,30 @@ func MustDecodeHashID(id string) int { return val } +// IntId accepts a hash ID or raw int on input (JSON unmarshal) but +// always marshals back as a plain integer. Use this for fields that are +// part of chain metadata where the indexer expects numeric IDs. +type IntId int + +func (num IntId) MarshalJSON() ([]byte, error) { + return []byte(strconv.Itoa(int(num))), nil +} + +func (num *IntId) UnmarshalJSON(data []byte) error { + if data[0] == '"' { + idStr := strings.Trim(string(data), `"`) + id, err := DecodeHashId(idStr) + if err != nil { + return err + } + *num = IntId(id) + return nil + } + val, err := strconv.Atoi(string(data)) + *num = IntId(val) + return err +} + // type alias for int that will do hashid on the way out the door type HashId int diff --git a/trashid/hashid_test.go b/trashid/hashid_test.go index 6a649a63..66d7c94e 100644 --- a/trashid/hashid_test.go +++ b/trashid/hashid_test.go @@ -46,3 +46,38 @@ func TestHashId(t *testing.T) { assert.Equal(t, 0, int(h)) } } + +func TestIntId(t *testing.T) { + + // when we serialize... it emits a plain number (not a hash string) + { + i := IntId(44) + j, err := json.Marshal(i) + assert.NoError(t, err) + assert.Equal(t, `44`, string(j)) + } + + // when we parse a hashid string... it decodes to the numeric value + { + var i IntId + err := json.Unmarshal([]byte(`"eYorL"`), &i) + assert.NoError(t, err) + assert.Equal(t, 44, int(i)) + } + + // when we parse a raw number... it works as-is + { + var i IntId + err := json.Unmarshal([]byte("33"), &i) + assert.NoError(t, err) + assert.Equal(t, 33, int(i)) + } + + // errors on bad hashid string + { + var i IntId + err := json.Unmarshal([]byte(`"asdjkfalksdjfaklsdjf"`), &i) + assert.Error(t, err) + assert.Equal(t, 0, int(i)) + } +}