diff --git a/backend/go.mod b/backend/go.mod index 0adeaa41..339dbde3 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -19,6 +19,7 @@ require ( github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/uuid v1.6.0 github.com/hashicorp/go-uuid v1.0.3 + github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 github.com/lestrrat-go/jwx/v3 v3.0.1 diff --git a/backend/go.sum b/backend/go.sum index 7e3a5370..9fccec08 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -140,6 +140,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= diff --git a/backend/internal/dto/dto_mapper.go b/backend/internal/dto/dto_mapper.go index f727dfc9..ff69d43e 100644 --- a/backend/internal/dto/dto_mapper.go +++ b/backend/internal/dto/dto_mapper.go @@ -1,162 +1,27 @@ package dto import ( - "errors" - "reflect" - "time" + "fmt" - datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "github.com/jinzhu/copier" ) // MapStructList maps a list of source structs to a list of destination structs -func MapStructList[S any, D any](source []S, destination *[]D) error { - *destination = make([]D, 0, len(source)) +func MapStructList[S any, D any](source []S, destination *[]D) (err error) { + *destination = make([]D, len(source)) - for _, item := range source { - var destItem D - if err := MapStruct(item, &destItem); err != nil { - return err + for i, item := range source { + err = MapStruct(item, &((*destination)[i])) + if err != nil { + return fmt.Errorf("failed to map field %d: %w", i, err) } - *destination = append(*destination, destItem) } return nil } // MapStruct maps a source struct to a destination struct -func MapStruct[S any, D any](source S, destination *D) error { - // Ensure destination is a non-nil pointer - destValue := reflect.ValueOf(destination) - if destValue.Kind() != reflect.Ptr || destValue.IsNil() { - return errors.New("destination must be a non-nil pointer to a struct") - } - - // Ensure source is a struct - sourceValue := reflect.ValueOf(source) - if sourceValue.Kind() != reflect.Struct { - return errors.New("source must be a struct") - } - - return mapStructInternal(sourceValue, destValue.Elem()) -} - -func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error { - for i := 0; i < destVal.NumField(); i++ { - destField := destVal.Field(i) - destFieldType := destVal.Type().Field(i) - - if destFieldType.Anonymous { - if err := mapStructInternal(sourceVal, destField); err != nil { - return err - } - continue - } - - sourceField := sourceVal.FieldByName(destFieldType.Name) - - if sourceField.IsValid() && destField.CanSet() { - if err := mapField(sourceField, destField); err != nil { - return err - } - } - } - return nil -} - -//nolint:gocognit -func mapField(sourceField reflect.Value, destField reflect.Value) error { - // Handle pointer to struct in source - if sourceField.Kind() == reflect.Ptr && !sourceField.IsNil() { - switch { - case sourceField.Elem().Kind() == reflect.Struct: - switch { - case destField.Kind() == reflect.Struct: - // Map from pointer to struct -> struct - return mapStructInternal(sourceField.Elem(), destField) - case destField.Kind() == reflect.Ptr && destField.CanSet(): - // Map from pointer to struct -> pointer to struct - if destField.IsNil() { - destField.Set(reflect.New(destField.Type().Elem())) - } - return mapStructInternal(sourceField.Elem(), destField.Elem()) - } - case destField.Kind() == reflect.Ptr && - destField.CanSet() && - sourceField.Elem().Type().AssignableTo(destField.Type().Elem()): - // Handle primitive pointer types (e.g., *string to *string) - if destField.IsNil() { - destField.Set(reflect.New(destField.Type().Elem())) - } - destField.Elem().Set(sourceField.Elem()) - return nil - case destField.Kind() != reflect.Ptr && - destField.CanSet() && - sourceField.Elem().Type().AssignableTo(destField.Type()): - // Handle *T to T conversion for primitive types - destField.Set(sourceField.Elem()) - return nil - } - } - - // Handle pointer to struct in destination - if destField.Kind() == reflect.Ptr && destField.CanSet() { - switch { - case sourceField.Kind() == reflect.Struct: - // Map from struct -> pointer to struct - if destField.IsNil() { - destField.Set(reflect.New(destField.Type().Elem())) - } - return mapStructInternal(sourceField, destField.Elem()) - case !sourceField.IsZero() && sourceField.Type().AssignableTo(destField.Type().Elem()): - // Handle T to *T conversion for primitive types - if destField.IsNil() { - destField.Set(reflect.New(destField.Type().Elem())) - } - destField.Elem().Set(sourceField) - return nil - } - } - - switch { - case sourceField.Type() == destField.Type(): - destField.Set(sourceField) - case sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice: - return mapSlice(sourceField, destField) - case sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct: - return mapStructInternal(sourceField, destField) - default: - return mapSpecialTypes(sourceField, destField) - } - return nil -} - -func mapSlice(sourceField reflect.Value, destField reflect.Value) error { - if sourceField.Type().Elem() == destField.Type().Elem() { - newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap()) - for j := 0; j < sourceField.Len(); j++ { - newSlice.Index(j).Set(sourceField.Index(j)) - } - destField.Set(newSlice) - } else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct { - newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap()) - for j := 0; j < sourceField.Len(); j++ { - sourceElem := sourceField.Index(j) - destElem := reflect.New(destField.Type().Elem()).Elem() - if err := mapStructInternal(sourceElem, destElem); err != nil { - return err - } - newSlice.Index(j).Set(destElem) - } - destField.Set(newSlice) - } - return nil -} - -func mapSpecialTypes(sourceField reflect.Value, destField reflect.Value) error { - if _, ok := sourceField.Interface().(datatype.DateTime); ok { - if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) { - dateValue := sourceField.Interface().(datatype.DateTime) - destField.Set(reflect.ValueOf(dateValue.ToTime())) - } - } - return nil +func MapStruct(source any, destination any) error { + return copier.CopyWithOption(destination, source, copier.Option{ + DeepCopy: true, + }) } diff --git a/backend/internal/dto/dto_mapper_test.go b/backend/internal/dto/dto_mapper_test.go new file mode 100644 index 00000000..43b86b14 --- /dev/null +++ b/backend/internal/dto/dto_mapper_test.go @@ -0,0 +1,197 @@ +package dto + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/model" + datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" + "github.com/pocket-id/pocket-id/backend/internal/utils" +) + +type sourceStruct struct { + AString string + AStringPtr *string + ABool bool + ABoolPtr *bool + ACustomDateTime datatype.DateTime + ACustomDateTimePtr *datatype.DateTime + ANilStringPtr *string + ASlice []string + AMap map[string]int + AStruct embeddedStruct + AStructPtr *embeddedStruct + + StringPtrToString *string + EmptyStringPtrToString *string + NilStringPtrToString *string + IntToInt64 int + AuditLogEventToString model.AuditLogEvent +} + +type destStruct struct { + AString string + AStringPtr *string + ABool bool + ABoolPtr *bool + ACustomDateTime datatype.DateTime + ACustomDateTimePtr *datatype.DateTime + ANilStringPtr *string + ASlice []string + AMap map[string]int + AStruct embeddedStruct + AStructPtr *embeddedStruct + + StringPtrToString string + EmptyStringPtrToString string + NilStringPtrToString string + IntToInt64 int64 + AuditLogEventToString string +} + +type embeddedStruct struct { + Foo string + Bar int64 +} + +func TestMapStruct(t *testing.T) { + src := sourceStruct{ + AString: "abcd", + AStringPtr: utils.Ptr("xyz"), + ABool: true, + ABoolPtr: utils.Ptr(false), + ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)), + ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))), + ANilStringPtr: nil, + ASlice: []string{"a", "b", "c"}, + AMap: map[string]int{ + "a": 1, + "b": 2, + }, + AStruct: embeddedStruct{ + Foo: "bar", + Bar: 42, + }, + AStructPtr: &embeddedStruct{ + Foo: "quo", + Bar: 111, + }, + + StringPtrToString: utils.Ptr("foobar"), + EmptyStringPtrToString: utils.Ptr(""), + NilStringPtrToString: nil, + IntToInt64: 99, + AuditLogEventToString: model.AuditLogEventAccountCreated, + } + var dst destStruct + err := MapStruct(src, &dst) + require.NoError(t, err) + + assert.Equal(t, src.AString, dst.AString) + _ = assert.NotNil(t, src.AStringPtr) && + assert.Equal(t, *src.AStringPtr, *dst.AStringPtr) + assert.Equal(t, src.ABool, dst.ABool) + _ = assert.NotNil(t, src.ABoolPtr) && + assert.Equal(t, *src.ABoolPtr, *dst.ABoolPtr) + assert.Equal(t, src.ACustomDateTime, dst.ACustomDateTime) + _ = assert.NotNil(t, src.ACustomDateTimePtr) && + assert.Equal(t, *src.ACustomDateTimePtr, *dst.ACustomDateTimePtr) + assert.Nil(t, dst.ANilStringPtr) + assert.Equal(t, src.ASlice, dst.ASlice) + assert.Equal(t, src.AMap, dst.AMap) + assert.Equal(t, "bar", dst.AStruct.Foo) + assert.Equal(t, int64(42), dst.AStruct.Bar) + _ = assert.NotNil(t, src.AStructPtr) && + assert.Equal(t, "quo", dst.AStructPtr.Foo) && + assert.Equal(t, int64(111), dst.AStructPtr.Bar) + assert.Equal(t, "foobar", dst.StringPtrToString) + assert.Empty(t, dst.EmptyStringPtrToString) + assert.Empty(t, dst.NilStringPtrToString) + assert.Equal(t, int64(99), dst.IntToInt64) + assert.Equal(t, "ACCOUNT_CREATED", dst.AuditLogEventToString) +} + +func TestMapStructList(t *testing.T) { + sources := []sourceStruct{ + { + AString: "first", + AStringPtr: utils.Ptr("one"), + ABool: true, + ABoolPtr: utils.Ptr(false), + ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)), + ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))), + ASlice: []string{"a", "b"}, + AMap: map[string]int{ + "a": 1, + "b": 2, + }, + AStruct: embeddedStruct{ + Foo: "first_struct", + Bar: 10, + }, + IntToInt64: 10, + }, + { + AString: "second", + AStringPtr: utils.Ptr("two"), + ABool: false, + ABoolPtr: utils.Ptr(true), + ACustomDateTime: datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)), + ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC))), + ASlice: []string{"c", "d", "e"}, + AMap: map[string]int{ + "c": 3, + "d": 4, + }, + AStruct: embeddedStruct{ + Foo: "second_struct", + Bar: 20, + }, + IntToInt64: 20, + }, + } + + var destinations []destStruct + err := MapStructList(sources, &destinations) + + require.NoError(t, err) + require.Len(t, destinations, 2) + + // Verify first element + assert.Equal(t, "first", destinations[0].AString) + assert.Equal(t, "one", *destinations[0].AStringPtr) + assert.True(t, destinations[0].ABool) + assert.False(t, *destinations[0].ABoolPtr) + assert.Equal(t, datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)), destinations[0].ACustomDateTime) + assert.Equal(t, datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)), *destinations[0].ACustomDateTimePtr) + assert.Equal(t, []string{"a", "b"}, destinations[0].ASlice) + assert.Equal(t, map[string]int{"a": 1, "b": 2}, destinations[0].AMap) + assert.Equal(t, "first_struct", destinations[0].AStruct.Foo) + assert.Equal(t, int64(10), destinations[0].AStruct.Bar) + assert.Equal(t, int64(10), destinations[0].IntToInt64) + + // Verify second element + assert.Equal(t, "second", destinations[1].AString) + assert.Equal(t, "two", *destinations[1].AStringPtr) + assert.False(t, destinations[1].ABool) + assert.True(t, *destinations[1].ABoolPtr) + assert.Equal(t, datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)), destinations[1].ACustomDateTime) + assert.Equal(t, datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC)), *destinations[1].ACustomDateTimePtr) + assert.Equal(t, []string{"c", "d", "e"}, destinations[1].ASlice) + assert.Equal(t, map[string]int{"c": 3, "d": 4}, destinations[1].AMap) + assert.Equal(t, "second_struct", destinations[1].AStruct.Foo) + assert.Equal(t, int64(20), destinations[1].AStruct.Bar) + assert.Equal(t, int64(20), destinations[1].IntToInt64) +} + +func TestMapStructList_EmptySource(t *testing.T) { + var sources []sourceStruct + var destinations []destStruct + + err := MapStructList(sources, &destinations) + require.NoError(t, err) + assert.Empty(t, destinations) +} diff --git a/backend/internal/dto/oidc_dto.go b/backend/internal/dto/oidc_dto.go index f9cab715..22f8a805 100644 --- a/backend/internal/dto/oidc_dto.go +++ b/backend/internal/dto/oidc_dto.go @@ -149,7 +149,7 @@ type AuthorizedOidcClientDto struct { } type OidcClientPreviewDto struct { - IdToken map[string]interface{} `json:"idToken"` - AccessToken map[string]interface{} `json:"accessToken"` - UserInfo map[string]interface{} `json:"userInfo"` + IdToken map[string]any `json:"idToken"` + AccessToken map[string]any `json:"accessToken"` + UserInfo map[string]any `json:"userInfo"` }