diff --git a/backend/internal/service/ldap_service.go b/backend/internal/service/ldap_service.go index c2da1d2e..2481c881 100644 --- a/backend/internal/service/ldap_service.go +++ b/backend/internal/service/ldap_service.go @@ -13,6 +13,9 @@ import ( "net/url" "strings" "time" + "unicode/utf8" + + "github.com/google/uuid" "github.com/go-ldap/ldap/v3" "github.com/pocket-id/pocket-id/backend/internal/common" @@ -122,7 +125,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. ldapGroupIDs := make(map[string]struct{}, len(result.Entries)) for _, value := range result.Entries { - ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value) + ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)) // Skip groups without a valid LDAP ID if ldapId == "" { @@ -194,7 +197,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap. syncGroup := dto.UserGroupCreateDto{ Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value), FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value), - LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value), + LdapID: ldapId, } if databaseGroup.ID == "" { @@ -286,7 +289,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C ldapUserIDs := make(map[string]struct{}, len(result.Entries)) for _, value := range result.Entries { - ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value) + ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)) // Skip users without a valid LDAP ID if ldapId == "" { @@ -468,3 +471,21 @@ func getDNProperty(property string, str string) string { // CN not found, return an empty string return "" } + +// convertLdapIdToString converts LDAP IDs to valid UTF-8 strings. +// LDAP servers may return binary UUIDs (16 bytes) or other non-UTF-8 data. +func convertLdapIdToString(ldapId string) string { + if utf8.ValidString(ldapId) { + return ldapId + } + + // Try to parse as binary UUID (16 bytes) + if len(ldapId) == 16 { + if parsedUUID, err := uuid.FromBytes([]byte(ldapId)); err == nil { + return parsedUUID.String() + } + } + + // As a last resort, encode as base64 to make it UTF-8 safe + return base64.StdEncoding.EncodeToString([]byte(ldapId)) +} diff --git a/backend/internal/service/ldap_service_test.go b/backend/internal/service/ldap_service_test.go index 3a3e8c0b..1a049bfe 100644 --- a/backend/internal/service/ldap_service_test.go +++ b/backend/internal/service/ldap_service_test.go @@ -71,3 +71,36 @@ func TestGetDNProperty(t *testing.T) { }) } } + +func TestConvertLdapIdToString(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "valid UTF-8 string", + input: "simple-utf8-id", + expected: "simple-utf8-id", + }, + { + name: "binary UUID (16 bytes)", + input: string([]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1}), + expected: "12345678-9abc-def0-1234-56789abcdef1", + }, + { + name: "non-UTF8, non-UUID returns base64", + input: string([]byte{0xff, 0xfe, 0xfd, 0xfc}), + expected: "//79/A==", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertLdapIdToString(tt.input) + if got != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, got) + } + }) + } +}