Compare commits

..

165 Commits

Author SHA1 Message Date
Elias Schneider
6e4d2a4a33 release: 1.7.0 2025-08-10 20:01:03 +02:00
Elias Schneider
6c65bd34cd chore(translations): update translations via Crowdin (#820) 2025-08-10 19:50:36 +02:00
Kyle Mendell
7bfe4834d0 chore: switch from npm to pnpm (#786)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-08-10 12:16:30 -05:00
Kyle Mendell
484c2f6ef2 feat: user application dashboard (#727)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-08-10 15:56:03 +00:00
Elias Schneider
87956ea725 chore(translations): update translations via Crowdin (#819) 2025-08-10 10:18:30 -05:00
Elias Schneider
32dd403038 chore(translations): update translations via Crowdin (#817) 2025-08-10 14:49:24 +02:00
Elias Schneider
4d59e72866 fix: custom claims input suggestions instantly close after opening 2025-08-08 15:11:44 +02:00
Elias Schneider
9ac5d51187 fix: authorization animation not working 2025-08-08 12:23:32 +02:00
Elias Schneider
5a031f5d1b refactor: use reflection to mark file based env variables (#815) 2025-08-07 20:41:00 +02:00
Alessandro (Ale) Segala
535bc9f46b chore: additional logs for database connections (#813) 2025-08-06 18:04:25 +02:00
Kyle Mendell
f0c144c51c fix: admins can not delete or disable their own account 2025-08-05 16:14:25 -05:00
Elias Schneider
61e4ea45fb chore(translations): update translations via Crowdin (#811) 2025-08-05 15:56:45 -05:00
Etienne
06e1656923 feat: add robots.txt to block indexing (#806) 2025-08-02 18:30:50 +00:00
Alessandro (Ale) Segala
0a3b1c6530 feat: support reading secret env vars from _FILE (#799)
Co-authored-by: Kyle Mendell <ksm@ofkm.us>
2025-07-30 11:59:25 -05:00
Kyle Mendell
d479817b6a feat: add support for code_challenge_methods_supported (#794) 2025-07-29 17:34:49 -05:00
Elias Schneider
01bf31d23d chore(translations): update translations via Crowdin (#791) 2025-07-27 20:21:37 -05:00
Alessandro (Ale) Segala
42a861d206 refactor: complete conversion of log calls to slog (#787) 2025-07-27 04:34:23 +00:00
Alessandro (Ale) Segala
78266e3e4c feat: Support OTel and JSON for logs (via log/slog) (#760)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-07-27 01:03:52 +00:00
Alessandro (Ale) Segala
c8478d75be fix: delete WebAuthn registration session after use (#783) 2025-07-26 18:45:54 -05:00
Elias Schneider
28d93b00a3 chore(translations): update translations via Crowdin (#785) 2025-07-26 16:37:42 -05:00
Kyle Mendell
12a7a6a5c5 chore: update Vietnamese display name 2025-07-26 15:33:36 -05:00
Elias Schneider
a6d5071724 chore(translations): update translations via Crowdin (#782) 2025-07-25 15:48:52 -05:00
Elias Schneider
cebe2242b9 chore(translations): update translations via Crowdin (#779) 2025-07-24 20:28:07 -05:00
Kyle Mendell
56ee7d946f chore: fix federated credentials type error 2025-07-24 20:22:34 -05:00
Kyle Mendell
f3c6521f2b chore: update dependencies and fix zod/4 import path 2025-07-24 20:16:17 -05:00
Kyle Mendell
ffed465f09 chore: update dependencies and fix zod/4 import path 2025-07-24 20:14:25 -05:00
Kyle Mendell
c359b5be06 chore: rename glass-row-item to passkey-row 2025-07-24 19:50:27 -05:00
Elias Schneider
e9a023bb71 chore(translations): update translations via Crowdin (#778)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-07-24 19:35:16 -05:00
Kyle Mendell
60f0b28076 chore(transaltions): add Vietnamese files 2025-07-24 10:11:01 -05:00
Alessandro (Ale) Segala
d541c9ab4a fix: set input type 'email' for email-based login (#776) 2025-07-23 12:39:50 -05:00
dependabot[bot]
024ed53022 chore(deps): bump axios from 1.10.0 to 1.11.0 in /frontend in the npm_and_yarn group across 1 directory (#777)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-23 12:38:00 -05:00
Elias Schneider
2c78bd1b46 chore(translations): update translations via Crowdin (#767) 2025-07-22 15:08:04 -05:00
dependabot[bot]
5602d79611 chore(deps): bump form-data from 4.0.1 to 4.0.4 in /frontend in the npm_and_yarn group across 1 directory (#771)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-22 15:07:43 -05:00
Kyle Mendell
51b73c9c31 chore(translations): add Ukrainian files 2025-07-21 15:56:59 -05:00
Elias Schneider
10f0580a43 chore(translations): update translations via Crowdin (#763) 2025-07-21 07:32:57 -05:00
ItalyPaleAle
a1488565ea release: 1.6.4 2025-07-21 07:44:25 +02:00
Alessandro (Ale) Segala
35d5f887ce fix: migration fails on postgres (#762) 2025-07-20 22:36:22 -07:00
Kyle Mendell
4c76de45ed chore: remove labels from issue templates 2025-07-20 22:51:02 -05:00
Kyle Mendell
68fc9c0659 release: 1.6.3 2025-07-20 22:35:35 -05:00
Kyle Mendell
2952b15755 fix: show rename and delete buttons for passkeys without hovering over the row 2025-07-20 19:09:06 -05:00
Kyle Mendell
ef1d599662 fix: use user-agent for identifying known device signins 2025-07-20 19:02:17 -05:00
Kyle Mendell
4e49d3932a chore: upgrade dependencies (#752) 2025-07-14 23:36:36 -05:00
Elias Schneider
86d3c08494 chore(translations): update translations via Crowdin (#750) 2025-07-14 13:15:33 -05:00
Alessandro (Ale) Segala
7b4ccd1f30 fix: ensure user inputs are normalized (#724)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-07-13 16:15:57 +00:00
Kyle Mendell
f145903eb0 chore: use correct svelte 5 syntax for signup token modal 2025-07-11 22:53:01 -05:00
Kyle Mendell
d3bc1797b6 fix: use object-contain for images on oidc-client list 2025-07-11 22:46:40 -05:00
Kyle Mendell
db94f81937 chore: use issue types for new issues 2025-07-11 22:25:11 -05:00
Kyle Mendell
b03e91b653 fix: allow passkey names up to 50 characters 2025-07-11 22:10:59 -05:00
Kyle Mendell
505bdcb8ba release: 1.6.2 2025-07-09 16:56:34 -05:00
Kyle Mendell
f103a54790 fix: ensure confirmation dialog shows on top of other components 2025-07-09 16:50:01 -05:00
Alessandro (Ale) Segala
e1de593dcd fix: login failures on Postgres when IP is null (#737) 2025-07-09 08:45:07 -05:00
Elias Schneider
45f42772b1 chore(translations): update translations via Crowdin (#730) 2025-07-07 20:06:52 -05:00
XLion
98152640b1 chore(translations): Fix inconsistent punctuation marks for the language name of zh-TW (#731) 2025-07-07 12:54:45 +00:00
github-actions[bot]
04e235e805 chore: update AAGUIDs (#729)
Co-authored-by: stonith404 <58886915+stonith404@users.noreply.github.com>
2025-07-06 21:04:32 -05:00
Elias Schneider
ae737dddaa release: 1.6.1 2025-07-06 22:50:33 +02:00
Elias Schneider
f565c702e5 ci/cd: use latest-distroless tag for latest distroless images 2025-07-06 22:48:55 +02:00
Elias Schneider
f945b44bc9 release: 1.6.0 2025-07-06 20:19:45 +02:00
Elias Schneider
857b9cc864 refactor: run formatter 2025-07-06 15:32:19 +02:00
Elias Schneider
bf042563e9 feat: add support for OAuth 2.0 Authorization Server Issuer Identification 2025-07-06 15:29:26 +02:00
Elias Schneider
49f1ab2f75 fix: custom claims input suggestions flickering 2025-07-06 00:23:06 +02:00
Elias Schneider
e46f60ac8d fix: keep sidebar in settings sticky 2025-07-05 21:59:13 +02:00
Elias Schneider
5c9e504291 fix: show friendly name in user group selection 2025-07-05 21:58:56 +02:00
Alessandro (Ale) Segala
7fe83f8087 fix: actually fix linter issues (#720) 2025-07-04 21:14:44 -05:00
Alessandro (Ale) Segala
43f0114c57 fix: linter issues (#719) 2025-07-04 18:29:28 -05:00
Alessandro (Ale) Segala
1a41b05f60 feat: distroless container additional variant + healthcheck command (#716) 2025-07-04 12:26:01 -07:00
Elias Schneider
81315790a8 fix: support non UTF-8 LDAP IDs (#714) 2025-07-04 08:42:11 +02:00
Alessandro (Ale) Segala
8c8fc2304d feat: add "key-rotate" command (#709) 2025-07-03 22:23:24 +02:00
Elias Schneider
15ece0ab30 chore(translations): update translations via Crowdin (#712) 2025-07-03 13:47:27 -05:00
Alessandro (Ale) Segala
5550729120 feat: encrypt private keys saved on disk and in database (#682)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-07-03 13:34:34 -05:00
Elias Schneider
9872608d61 fix: allow profile picture update even if "allow own account edit" enabled 2025-07-03 10:57:56 +02:00
Elias Schneider
be52660227 feat: enhance language selection message and add translation contribution link 2025-07-03 09:20:39 +02:00
Elias Schneider
237342e876 chore(translations): update translations via Crowdin (#707) 2025-07-02 13:45:10 +02:00
Elias Schneider
cfbfbc9753 chore(translations): update translations via Crowdin (#705) 2025-07-01 17:11:42 -05:00
Elias Schneider
aefb308536 fix: token introspection authentication not handled correctly (#704)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
2025-07-01 21:14:07 +00:00
Alessandro (Ale) Segala
031181ad2a fix: auth fails when client IP is empty on Postgres (#695) 2025-06-30 14:04:30 +02:00
Elias Schneider
dbf3da41f3 chore(translations): update translations via Crowdin (#699)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-06-29 13:04:15 -05:00
Kyle Mendell
3a2902789e chore: use correct team name for codeowners 2025-06-29 09:31:32 -05:00
Kyle Mendell
459a4fd727 chore: update CODEOWNERS to be global 2025-06-29 09:30:00 -05:00
Kyle Mendell
2ecc1abbad chore: add CODEOWNERS file 2025-06-29 09:28:08 -05:00
Kyle Mendell
92c57ada1a fix: app config forms not updating with latest values (#696) 2025-06-29 15:13:06 +02:00
Elias Schneider
fceb6fa7b4 fix: add missing error check in initial user setup 2025-06-29 15:10:39 +02:00
Alessandro (Ale) Segala
c290c027fb refactor: use github.com/jinzhu/copier for MapStruct (#698) 2025-06-29 15:01:10 +02:00
Elias Schneider
ca205a8c73 chore(translations): update translations via Crowdin (#697) 2025-06-29 01:01:39 -05:00
Elias Schneider
968cf0b307 chore(translations): update translations via Crowdin (#694) 2025-06-28 21:25:58 -05:00
Elias Schneider
fd8bee94a4 chore(translations): update translations via Crowdin (#692) 2025-06-28 15:26:17 +02:00
Manuel Rais
41ac1be082 chore(translations) : translate missing french values (#691) 2025-06-28 15:26:05 +02:00
Elias Schneider
dd9b1d26ea release: 1.5.0 2025-06-27 23:56:16 +02:00
Elias Schneider
4b829757b2 tests: fix e2e tests 2025-06-27 23:52:43 +02:00
Elias Schneider
b5b01cb6dd chore(translations): update translations via Crowdin (#688) 2025-06-27 23:42:32 +02:00
Elias Schneider
287314f016 feat: improve initial admin creation workflow 2025-06-27 23:41:05 +02:00
Elias Schneider
73e7e0b1c5 refactor: add formatter to Playwright tests 2025-06-27 23:33:26 +02:00
Elias Schneider
d070b9a778 fix: double double full stops for certain error messages 2025-06-27 22:43:31 +02:00
Elias Schneider
d976bf5965 fix: improve accent color picker disabled state 2025-06-27 22:38:21 +02:00
Elias Schneider
052ac008c3 fix: margin of user sign up description 2025-06-27 22:31:55 +02:00
Elias Schneider
57a2b2bc83 chore(translations): update translations via Crowdin (#687) 2025-06-27 22:24:36 +02:00
ElevenNotes
043f82ad79 fix: less noisy logging for certain GET requests (#681)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-27 22:24:22 +02:00
Elias Schneider
ba61cdba4e feat: redact sensitive app config variables if set with env variable 2025-06-27 22:22:28 +02:00
Kyle Mendell
dcd1ae96e0 feat: self-service user signup (#672)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-27 15:01:10 -05:00
Elias Schneider
1fdb058386 docs: clarify confusing user update logic 2025-06-27 17:20:51 +02:00
Elias Schneider
29cb5513a0 fix: users can't be updated by admin if self account editing is disabled 2025-06-27 17:15:26 +02:00
Elias Schneider
6db57d9f27 chore(translations): update translations via Crowdin (#683) 2025-06-26 19:01:16 +02:00
Elias Schneider
1a77bd9914 fix: error page flickering after sign out 2025-06-24 21:56:40 +02:00
Elias Schneider
350335711b chore(translations): update translations via Crowdin (#677) 2025-06-24 09:00:57 -05:00
Ryan Kaskel
988c425150 fix: remove duplicate request logging (#678) 2025-06-24 13:48:11 +00:00
Elias Schneider
23827ba1d1 release: 1.4.1 2025-06-22 21:30:07 +02:00
Elias Schneider
7d36bda769 fix: app not starting if UI config is disabled and Postgres is used 2025-06-22 21:21:14 +02:00
Manuel Rais
8c559ea067 chore(translations) : typo in french language (#669) 2025-06-22 18:58:59 +00:00
Elias Schneider
88832d4bc9 chore(translations): update translations via Crowdin (#663) 2025-06-20 11:11:42 +02:00
Kyle Mendell
f5cece3b0e release: 1.4.0 2025-06-19 13:21:49 -05:00
Kyle Mendell
d5485238b8 feat: configurable local ipv6 ranges for audit log (#657) 2025-06-19 19:56:27 +02:00
Kyle Mendell
ac5a121f66 feat: location filter for global audit log (#662) 2025-06-19 17:12:53 +00:00
Elias Schneider
481df3bcb9 chore: add configuration for backend hot reloading 2025-06-19 18:45:01 +02:00
Mr Snake
7677a3de2c feat: allow setting unix socket mode (#661) 2025-06-18 18:41:57 +02:00
Elias Schneider
1f65c01b04 chore(translations): update translations via Crowdin (#659) 2025-06-18 09:11:36 -05:00
Elias Schneider
d5928f6fea chore: remove unused crypto util 2025-06-17 17:55:52 +02:00
Elias Schneider
bef77ac8dc fix: use inline style for dynamic background image URL instead of Tailwind class 2025-06-17 13:10:02 +02:00
Elias Schneider
c8eb034c49 chore: use v1 tag in example docker-compose.yml 2025-06-16 23:31:16 +02:00
Elias Schneider
c77167df46 ci/cd: cancel build-next action if new one starts 2025-06-16 16:12:33 +02:00
Elias Schneider
3717a663d9 ci/cd: only build required binaries for next image 2025-06-16 16:09:09 +02:00
Elias Schneider
5814549cbe refactor: run formatter 2025-06-16 16:06:11 +02:00
Elias Schneider
2e5d268798 fix: explicitly cache images to prevent unexpected behavior 2025-06-16 15:59:14 +02:00
Elias Schneider
4ed312251e chore(translations): update translations via Crowdin (#652) 2025-06-15 09:57:08 -05:00
Elias Schneider
946c534b08 fix: center oidc client images if they are smaller than the box 2025-06-14 19:33:40 +02:00
Kyle Mendell
883877adec feat: ui accent colors (#643)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-13 07:06:54 -05:00
Elias Schneider
215531d65c feat: use icon instead of text on application image update hover state 2025-06-13 12:07:14 +02:00
Elias Schneider
c0f055c3c0 chore(translations): update translations via Crowdin (#649) 2025-06-10 22:01:50 -05:00
Alessandro (Ale) Segala
d77044882d fix: reduce duration of animations on login and signin page (#648) 2025-06-10 21:14:55 +02:00
Alessandro (Ale) Segala
d6795300b1 feat: auto-focus on the login buttons (#647) 2025-06-10 21:13:36 +02:00
Elias Schneider
fd3c76ffa3 refactor: run formatter 2025-06-10 14:43:56 +02:00
Amazingca
698bc3a35a chore(translations): Update spelling and grammar in en.json (#650) 2025-06-10 07:34:49 -05:00
Elias Schneider
1bcb50edc3 fix: allow images with uppercase file extension 2025-06-10 11:11:03 +02:00
Elias Schneider
9700afb9cb chore(translations): update translations via Crowdin (#644) 2025-06-10 10:36:52 +02:00
Elias Schneider
9ce82fb205 release: 1.3.1 2025-06-09 22:59:13 +02:00
Elias Schneider
2935236ace fix: change timestamp of client_credentials.sql migration 2025-06-09 22:58:56 +02:00
Elias Schneider
c821b675b8 release: 1.3.0 2025-06-09 21:37:27 +02:00
Elias Schneider
a09d529027 chore: add branch check to release script 2025-06-09 21:37:00 +02:00
Alessandro (Ale) Segala
b62b61fb01 feat: allow introspection and device code endpoints to use Federated Client Credentials (#640) 2025-06-09 21:17:55 +02:00
Alessandro (Ale) Segala
df5c1ed1f8 chore: add docs link and rename to Federated Client Credentials (#636)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-09 21:15:37 +02:00
Elias Schneider
f4af35f86b chore(translations): update translations via Crowdin (#642) 2025-06-09 18:36:37 +02:00
Elias Schneider
657a51f7ed fix: misleading text for disable animations option 2025-06-09 18:22:55 +02:00
Elias Schneider
575b2f71e9 fix: use full width for audit log filters 2025-06-09 18:16:53 +02:00
Elias Schneider
97f7326da4 feat: new color theme for the UI 2025-06-09 18:09:13 +02:00
Elias Schneider
242d87a54b chore: upgrade to Shadcn v1.0.0 2025-06-09 18:08:39 +02:00
Kyle Mendell
c111b79147 feat: oidc client data preview (#624)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-09 15:46:03 +00:00
Elias Schneider
61bf14225b chore(translations): update translations via Crowdin (#637) 2025-06-09 11:58:14 +02:00
github-actions[bot]
c1e98411b6 chore: update AAGUIDs (#639)
Co-authored-by: stonith404 <58886915+stonith404@users.noreply.github.com>
2025-06-09 11:48:21 +02:00
Elias Schneider
b25e95fc4a ci/cd: add missing attestions permission 2025-06-08 16:15:23 +02:00
Elias Schneider
3cc82d8522 docs: remove difficult to maintain OpenAPI properties 2025-06-08 16:10:42 +02:00
Elias Schneider
ea4e48680c docs: fix pagination API docs 2025-06-08 16:04:58 +02:00
Elias Schneider
f403eed12c ci/cd: add missing permission 2025-06-08 16:03:40 +02:00
Elias Schneider
388a874922 refactor: upgrade to Zod v4 (#623) 2025-06-08 15:44:59 +02:00
Elias Schneider
9a4aab465a chore(translations): update translations via Crowdin (#632) 2025-06-08 15:44:22 +02:00
Kyle Mendell
a052cd6619 ci/cd: add workflow for building 'next' docker image (#633)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-08 15:42:41 +02:00
Elias Schneider
31a803b243 chore(translations): add Traditional Chinese files 2025-06-07 21:04:48 +02:00
Elias Schneider
1d2e41c04e chore(translations): update translations via Crowdin (#629) 2025-06-07 21:01:49 +02:00
Elias Schneider
b650d6d423 chore(translations): add Danish language files 2025-06-06 16:27:33 +02:00
Elias Schneider
156aad3057 chore(translations): update translations via Crowdin (#620) 2025-06-06 16:25:37 +02:00
Alessandro (Ale) Segala
05bfe00924 feat: JWT bearer assertions for client authentication (#566)
Co-authored-by: Kyle Mendell <ksm@ofkm.us>
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-06-06 12:23:51 +02:00
Mr Snake
035b2c022b feat: add unix socket support (#615) 2025-06-06 07:01:19 +00:00
Elias Schneider
61b62d4612 fix: OIDC client image can't be deleted 2025-06-06 08:50:33 +02:00
Elias Schneider
dc5d7bb2f3 refactor: run fomratter 2025-06-05 22:43:24 +02:00
Elias Schneider
5e9096e328 fix: UI config overridden by env variables don't apply on first start 2025-06-05 22:36:55 +02:00
Elias Schneider
34b4ba514f chore(translations): update translations via Crowdin (#614) 2025-06-05 15:42:28 +02:00
Elias Schneider
d217083059 feat: add API endpoint for user authorized clients 2025-06-04 09:23:44 +02:00
Elias Schneider
bdcef60cab fix: don't load app config and user on every route change 2025-06-04 08:52:34 +02:00
283 changed files with 19658 additions and 9348 deletions

View File

@@ -1,4 +1,4 @@
node_modules
**/node_modules
# Output
.output

1
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
* @pocket-id/maintainers

View File

@@ -1,7 +1,7 @@
name: "🐛 Bug Report"
description: "Report something that is not working as expected"
title: "🐛 Bug Report: "
labels: [bug]
type: 'Bug'
body:
- type: markdown
attributes:

View File

@@ -1,7 +1,7 @@
name: 🚀 Feature
description: "Submit a proposal for a new feature"
title: "🚀 Feature: "
labels: [feature]
type: 'Feature'
body:
- type: textarea
id: feature-description

View File

@@ -1,7 +1,7 @@
name: "🌐 Language request"
description: "You want to contribute to a language that isn't on Crowdin yet?"
title: "🌐 Language Request: <language name in english>"
labels: [language-request]
type: 'Language Request'
body:
- type: input
id: language-name-native

100
.github/workflows/build-next.yml vendored Normal file
View File

@@ -0,0 +1,100 @@
name: Build Next Image
on:
push:
branches:
- main
concurrency:
group: build-next-image
cancel-in-progress: true
jobs:
build-next:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
id-token: write
attestations: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 22
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: 'backend/go.mod'
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set DOCKER_IMAGE_NAME
run: |
# Lowercase REPO_OWNER which is required for containers
REPO_OWNER=${{ github.repository_owner }}
DOCKER_IMAGE_NAME="ghcr.io/${REPO_OWNER,,}/pocket-id"
echo "DOCKER_IMAGE_NAME=${DOCKER_IMAGE_NAME}" >>${GITHUB_ENV}
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Install frontend dependencies
run: pnpm install --frozen-lockfile
- name: Build frontend
working-directory: frontend
run: pnpm run build
- name: Build binaries
run: sh scripts/development/build-binaries.sh --docker-only
- name: Build and push container image
id: build-push-image
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ env.DOCKER_IMAGE_NAME }}:next
file: Dockerfile-prebuilt
- name: Build and push container image (distroless)
uses: docker/build-push-action@v6
id: container-build-push-distroless
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ env.DOCKER_IMAGE_NAME }}:next-distroless
file: Dockerfile-distroless
- name: Container image attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
subject-digest: ${{ steps.build-push-image.outputs.digest }}
push-to-registry: true
- name: Container image attestation (distroless)
uses: actions/attest-build-provenance@v2
with:
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
push-to-registry: true

View File

@@ -3,15 +3,15 @@ on:
push:
branches: [main]
paths-ignore:
- "docs/**"
- "**.md"
- ".github/**"
- 'docs/**'
- '**.md'
- '.github/**'
pull_request:
branches: [main]
paths-ignore:
- "docs/**"
- "**.md"
- ".github/**"
- 'docs/**'
- '**.md'
- '.github/**'
jobs:
build:
@@ -54,18 +54,22 @@ jobs:
needs: build
steps:
- uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
- uses: actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: frontend/package-lock.json
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-playwright-
@@ -96,13 +100,12 @@ jobs:
run: docker load < /tmp/lldap-image.tar
- name: Install test dependencies
working-directory: ./tests
run: npm ci
run: pnpm --filter pocket-id-tests install --frozen-lockfile
- name: Install Playwright Browsers
working-directory: ./tests
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium
run: pnpm dlx playwright install --with-deps chromium
- name: Run Docker Container with Sqlite DB and LDAP
working-directory: ./tests/setup
@@ -111,8 +114,8 @@ jobs:
docker compose logs -f pocket-id &> /tmp/backend.log &
- name: Run Playwright tests
working-directory: ./tests
run: npx playwright test
working-directory: tests
run: pnpm exec playwright test
- name: Upload Test Report
uses: actions/upload-artifact@v4
@@ -141,18 +144,22 @@ jobs:
needs: build
steps:
- uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
- uses: actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: frontend/package-lock.json
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
key: ${{ runner.os }}-playwright-${{ hashFiles('pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-playwright-
@@ -200,13 +207,12 @@ jobs:
run: docker load -i /tmp/docker-image.tar
- name: Install test dependencies
working-directory: ./tests
run: npm ci
run: pnpm --filter pocket-id-tests install
- name: Install Playwright Browsers
working-directory: ./tests
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium
run: pnpm dlx playwright install --with-deps chromium
- name: Run Docker Container with Postgres DB and LDAP
working-directory: ./tests/setup
@@ -216,7 +222,7 @@ jobs:
- name: Run Playwright tests
working-directory: ./tests
run: npx playwright test
run: pnpm exec playwright test
- name: Upload Test Report
uses: actions/upload-artifact@v4

View File

@@ -3,7 +3,7 @@ name: Release
on:
push:
tags:
- "v*.*.*"
- 'v*.*.*'
jobs:
build:
@@ -16,27 +16,29 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: frontend/package-lock.json
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml
- uses: actions/setup-go@v5
with:
go-version-file: "backend/go.mod"
go-version-file: 'backend/go.mod'
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set DOCKER_IMAGE_NAME
run: |
# Lowercase REPO_OWNER which is required for containers
REPO_OWNER=${{ github.repository_owner }}
DOCKER_IMAGE_NAME="ghcr.io/${REPO_OWNER,,}/pocket-id"
echo "DOCKER_IMAGE_NAME=${DOCKER_IMAGE_NAME}" >>${GITHUB_ENV}
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
@@ -53,17 +55,24 @@ jobs:
type=semver,pattern={{version}},prefix=v
type=semver,pattern={{major}}.{{minor}},prefix=v
type=semver,pattern={{major}},prefix=v
- name: Docker metadata (distroless)
id: meta-distroless
uses: docker/metadata-action@v5
with:
images: |
${{ env.DOCKER_IMAGE_NAME }}
flavor: |
suffix=-distroless,onlatest=true
tags: |
type=semver,pattern={{version}},prefix=v
type=semver,pattern={{major}}.{{minor}},prefix=v
type=semver,pattern={{major}},prefix=v
- name: Install frontend dependencies
working-directory: frontend
run: npm ci
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
- name: Build frontend
working-directory: frontend
run: npm run build
run: pnpm --filter pocket-id-frontend build
- name: Build binaries
run: sh scripts/development/build-binaries.sh
- name: Build and push container image
uses: docker/build-push-action@v6
id: container-build-push
@@ -74,19 +83,32 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
file: Dockerfile-prebuilt
- name: Build and push container image (distroless)
uses: docker/build-push-action@v6
id: container-build-push-distroless
with:
context: .
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta-distroless.outputs.tags }}
labels: ${{ steps.meta-distroless.outputs.labels }}
file: Dockerfile-distroless
- name: Binary attestation
uses: actions/attest-build-provenance@v2
with:
subject-path: "backend/.bin/pocket-id-**"
subject-path: 'backend/.bin/pocket-id-**'
- name: Container image attestation
uses: actions/attest-build-provenance@v2
with:
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
subject-digest: ${{ steps.container-build-push.outputs.digest }}
push-to-registry: true
- name: Container image attestation (distroless)
uses: actions/attest-build-provenance@v2
with:
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
push-to-registry: true
- name: Upload binaries to release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -101,6 +123,6 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Mark release as published
run: gh release edit ${{ github.ref_name }} --draft=false

View File

@@ -4,21 +4,21 @@ on:
push:
branches: [main]
paths:
- "frontend/src/**"
- ".github/svelte-check-matcher.json"
- "frontend/package.json"
- "frontend/package-lock.json"
- "frontend/tsconfig.json"
- "frontend/svelte.config.js"
- 'frontend/src/**'
- '.github/svelte-check-matcher.json'
- 'frontend/package.json'
- 'frontend/package-lock.json'
- 'frontend/tsconfig.json'
- 'frontend/svelte.config.js'
pull_request:
branches: [main]
paths:
- "frontend/src/**"
- ".github/svelte-check-matcher.json"
- "frontend/package.json"
- "frontend/package-lock.json"
- "frontend/tsconfig.json"
- "frontend/svelte.config.js"
- 'frontend/src/**'
- '.github/svelte-check-matcher.json'
- 'frontend/package.json'
- 'frontend/package-lock.json'
- 'frontend/tsconfig.json'
- 'frontend/svelte.config.js'
workflow_dispatch:
jobs:
@@ -36,24 +36,28 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache-dependency-path: frontend/package-lock.json
cache: 'pnpm'
cache-dependency-path: pnpm-lock.yaml
- name: Install dependencies
working-directory: frontend
run: npm ci
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
- name: Build Pocket ID Frontend
working-directory: frontend
run: npm run build
run: pnpm --filter pocket-id-frontend build
- name: Add svelte-check problem matcher
run: echo "::add-matcher::.github/svelte-check-matcher.json"
- name: Run svelte-check
working-directory: frontend
run: npm run check
run: pnpm --filter pocket-id-frontend check

1
.gitignore vendored
View File

@@ -10,6 +10,7 @@ node_modules
/frontend/build
/backend/bin
pocket-id
/tests/test-results/*.json
# OS
.DS_Store

View File

@@ -1 +1 @@
1.2.0
1.7.0

View File

@@ -1,3 +1,142 @@
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.4...v) (2025-08-10)
### Features
* add robots.txt to block indexing ([#806](https://github.com/pocket-id/pocket-id/issues/806)) ([06e1656](https://github.com/pocket-id/pocket-id/commit/06e1656923eb2f4531be497716f9147c09d60b65))
* add support for `code_challenge_methods_supported` ([#794](https://github.com/pocket-id/pocket-id/issues/794)) ([d479817](https://github.com/pocket-id/pocket-id/commit/d479817b6a7ca4807b5de500b3ba713d436b0770))
* Support OTel and JSON for logs (via log/slog) ([#760](https://github.com/pocket-id/pocket-id/issues/760)) ([78266e3](https://github.com/pocket-id/pocket-id/commit/78266e3e4cab2b23249c3baf20f4387d00eebd9e))
* support reading secret env vars from _FILE ([#799](https://github.com/pocket-id/pocket-id/issues/799)) ([0a3b1c6](https://github.com/pocket-id/pocket-id/commit/0a3b1c653050f2237d30ec437c5de88baa704a25))
* user application dashboard ([#727](https://github.com/pocket-id/pocket-id/issues/727)) ([484c2f6](https://github.com/pocket-id/pocket-id/commit/484c2f6ef20efc1fade1a41e2aeace54c7bb4f1b))
### Bug Fixes
* admins can not delete or disable their own account ([f0c144c](https://github.com/pocket-id/pocket-id/commit/f0c144c51c635bc348222a00d3bc88bc4e0711ef))
* authorization animation not working ([9ac5d51](https://github.com/pocket-id/pocket-id/commit/9ac5d5118710cad59c8c4ce7cef7ab09be3de664))
* custom claims input suggestions instantly close after opening ([4d59e72](https://github.com/pocket-id/pocket-id/commit/4d59e7286666480e20c728787a95e82513509240))
* delete WebAuthn registration session after use ([#783](https://github.com/pocket-id/pocket-id/issues/783)) ([c8478d7](https://github.com/pocket-id/pocket-id/commit/c8478d75bed7295625cd3cf62ef46fcd95902410))
* set input type 'email' for email-based login ([#776](https://github.com/pocket-id/pocket-id/issues/776)) ([d541c9a](https://github.com/pocket-id/pocket-id/commit/d541c9ab4af8d7283891a80f886dd5d4ebc52f53))
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.3...v) (2025-07-21)
### Bug Fixes
* migration fails on postgres ([#762](https://github.com/pocket-id/pocket-id/issues/762)) ([35d5f88](https://github.com/pocket-id/pocket-id/commit/35d5f887ce7c88933d7e4c2f0acd2aeedd18c214))
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.2...v) (2025-07-21)
### Bug Fixes
* allow passkey names up to 50 characters ([b03e91b](https://github.com/pocket-id/pocket-id/commit/b03e91b6530c2393ad20ac49aa2cb2b4962651b2))
* ensure user inputs are normalized ([#724](https://github.com/pocket-id/pocket-id/issues/724)) ([7b4ccd1](https://github.com/pocket-id/pocket-id/commit/7b4ccd1f306f4882c52fe30133fcda114ef0d18b))
* show rename and delete buttons for passkeys without hovering over the row ([2952b15](https://github.com/pocket-id/pocket-id/commit/2952b1575542ecd0062fe740e2d6a3caad05190d))
* use object-contain for images on oidc-client list ([d3bc179](https://github.com/pocket-id/pocket-id/commit/d3bc1797b65ec8bc9201c55d06f3612093f3a873))
* use user-agent for identifying known device signins ([ef1d599](https://github.com/pocket-id/pocket-id/commit/ef1d5996624fc534190f80a26f2c48bbad206f49))
## [](https://github.com/pocket-id/pocket-id/compare/v1.6.1...v) (2025-07-09)
### Bug Fixes
* ensure confirmation dialog shows on top of other components ([f103a54](https://github.com/pocket-id/pocket-id/commit/f103a547904070c5b192e519c8b5a8fed9d80e96))
* login failures on Postgres when IP is null ([#737](https://github.com/pocket-id/pocket-id/issues/737)) ([e1de593](https://github.com/pocket-id/pocket-id/commit/e1de593dcd30b7b04da3b003455134992b702595))
## [](https://github.com/pocket-id/pocket-id/compare/v1.5.0...v) (2025-07-06)
### Features
* add "key-rotate" command ([#709](https://github.com/pocket-id/pocket-id/issues/709)) ([8c8fc23](https://github.com/pocket-id/pocket-id/commit/8c8fc2304d8f33c1fea54b1138b109f282e78b8b))
* add support for OAuth 2.0 Authorization Server Issuer Identification ([bf04256](https://github.com/pocket-id/pocket-id/commit/bf042563e997d57bb087705a5789fd72ffbed467))
* distroless container additional variant + healthcheck command ([#716](https://github.com/pocket-id/pocket-id/issues/716)) ([1a41b05](https://github.com/pocket-id/pocket-id/commit/1a41b05f60d487fff78703bec1d4e832f96fd071))
* encrypt private keys saved on disk and in database ([#682](https://github.com/pocket-id/pocket-id/issues/682)) ([5550729](https://github.com/pocket-id/pocket-id/commit/5550729120ac9f5e9361c7f9cf25b9075a33a94a))
* enhance language selection message and add translation contribution link ([be52660](https://github.com/pocket-id/pocket-id/commit/be526602273c1689cb4057ca96d4214e7f817d1d))
### Bug Fixes
* actually fix linter issues ([#720](https://github.com/pocket-id/pocket-id/issues/720)) ([7fe83f8](https://github.com/pocket-id/pocket-id/commit/7fe83f8087f033f957bb6e0eee5e0c159417e1cd))
* add missing error check in initial user setup ([fceb6fa](https://github.com/pocket-id/pocket-id/commit/fceb6fa7b4701a3645c4c2353bcd108b15d69ded))
* allow profile picture update even if "allow own account edit" enabled ([9872608](https://github.com/pocket-id/pocket-id/commit/9872608d61a486f7b775f314d9392e0620bcd891))
* app config forms not updating with latest values ([#696](https://github.com/pocket-id/pocket-id/issues/696)) ([92c57ad](https://github.com/pocket-id/pocket-id/commit/92c57ada1a11f76963e36ca0a81bca8f52dbc84e))
* auth fails when client IP is empty on Postgres ([#695](https://github.com/pocket-id/pocket-id/issues/695)) ([031181a](https://github.com/pocket-id/pocket-id/commit/031181ad2ae8fae94cc5793dd1c614e79476a766))
* custom claims input suggestions flickering ([49f1ab2](https://github.com/pocket-id/pocket-id/commit/49f1ab2f75df97d551fff5acbadcd55df74af617))
* keep sidebar in settings sticky ([e46f60a](https://github.com/pocket-id/pocket-id/commit/e46f60ac8d6944bcea54d0708af1950d98f66c3c))
* linter issues ([#719](https://github.com/pocket-id/pocket-id/issues/719)) ([43f0114](https://github.com/pocket-id/pocket-id/commit/43f0114c579f7b5b32b372e09f46bcb2a9d7796e))
* show friendly name in user group selection ([5c9e504](https://github.com/pocket-id/pocket-id/commit/5c9e504291b3bffe947bcbe907701806e301d1fe))
* support non UTF-8 LDAP IDs ([#714](https://github.com/pocket-id/pocket-id/issues/714)) ([8131579](https://github.com/pocket-id/pocket-id/commit/81315790a8aa601a2565a1b54807df1e68f06dc5))
* token introspection authentication not handled correctly ([#704](https://github.com/pocket-id/pocket-id/issues/704)) ([aefb308](https://github.com/pocket-id/pocket-id/commit/aefb30853677baf7ed29ac8b539e1aadf56e14a4))
## [](https://github.com/pocket-id/pocket-id/compare/v1.4.1...v) (2025-06-27)
### Features
* improve initial admin creation workflow ([287314f](https://github.com/pocket-id/pocket-id/commit/287314f01644e42ddb2ce1b1115bd14f2f0c1768))
* redact sensitive app config variables if set with env variable ([ba61cdb](https://github.com/pocket-id/pocket-id/commit/ba61cdba4eb3d5659f3ae6b6c21249985c0aa630))
* self-service user signup ([#672](https://github.com/pocket-id/pocket-id/issues/672)) ([dcd1ae9](https://github.com/pocket-id/pocket-id/commit/dcd1ae96e048115be34b0cce275054e990462ebf))
### Bug Fixes
* double double full stops for certain error messages ([d070b9a](https://github.com/pocket-id/pocket-id/commit/d070b9a778d7d1a51f2fa62d003f2331a96d6c91))
* error page flickering after sign out ([1a77bd9](https://github.com/pocket-id/pocket-id/commit/1a77bd9914ea01e445ff3d6e116c9ed3bcfbf153))
* improve accent color picker disabled state ([d976bf5](https://github.com/pocket-id/pocket-id/commit/d976bf5965eda10e3ecb71821c23e93e5d712a02))
* less noisy logging for certain GET requests ([#681](https://github.com/pocket-id/pocket-id/issues/681)) ([043f82a](https://github.com/pocket-id/pocket-id/commit/043f82ad794eb64a5550d8b80703114a055701d9))
* margin of user sign up description ([052ac00](https://github.com/pocket-id/pocket-id/commit/052ac008c3a8c910d1ce79ee99b2b2f75e4090f4))
* remove duplicate request logging ([#678](https://github.com/pocket-id/pocket-id/issues/678)) ([988c425](https://github.com/pocket-id/pocket-id/commit/988c425150556b32cff1d341a21fcc9c69d9aaf8))
* users can't be updated by admin if self account editing is disabled ([29cb551](https://github.com/pocket-id/pocket-id/commit/29cb5513a03d1a9571969c8a42deec9b2bdee037))
## [](https://github.com/pocket-id/pocket-id/compare/v1.4.0...v) (2025-06-22)
### Bug Fixes
* app not starting if UI config is disabled and Postgres is used ([7d36bda](https://github.com/pocket-id/pocket-id/commit/7d36bda769e25497dec6b76206a4f7e151b0bd72))
## [](https://github.com/pocket-id/pocket-id/compare/v1.3.1...v) (2025-06-19)
### Features
* allow setting unix socket mode ([#661](https://github.com/pocket-id/pocket-id/issues/661)) ([7677a3d](https://github.com/pocket-id/pocket-id/commit/7677a3de2c923c11a58bc8c4d1b2121d403a1504))
* auto-focus on the login buttons ([#647](https://github.com/pocket-id/pocket-id/issues/647)) ([d679530](https://github.com/pocket-id/pocket-id/commit/d6795300b158b85dd9feadd561b6ecd891f5db0d))
* configurable local ipv6 ranges for audit log ([#657](https://github.com/pocket-id/pocket-id/issues/657)) ([d548523](https://github.com/pocket-id/pocket-id/commit/d5485238b8fd4cc566af00eae2b17d69a119f991))
* location filter for global audit log ([#662](https://github.com/pocket-id/pocket-id/issues/662)) ([ac5a121](https://github.com/pocket-id/pocket-id/commit/ac5a121f664b8127d0faf30c0f93432f30e7f33a))
* ui accent colors ([#643](https://github.com/pocket-id/pocket-id/issues/643)) ([883877a](https://github.com/pocket-id/pocket-id/commit/883877adec6fc3e65bd5a705499449959b894fb5))
* use icon instead of text on application image update hover state ([215531d](https://github.com/pocket-id/pocket-id/commit/215531d65c6683609b0b4a5505fdb72696fdb93e))
### Bug Fixes
* allow images with uppercase file extension ([1bcb50e](https://github.com/pocket-id/pocket-id/commit/1bcb50edc335886dd722a4c69960c48cc3cd1687))
* center oidc client images if they are smaller than the box ([946c534](https://github.com/pocket-id/pocket-id/commit/946c534b0877a074a6b658060f9af27e4061397c))
* explicitly cache images to prevent unexpected behavior ([2e5d268](https://github.com/pocket-id/pocket-id/commit/2e5d2687982186c12e530492292d49895cb6043a))
* reduce duration of animations on login and signin page ([#648](https://github.com/pocket-id/pocket-id/issues/648)) ([d770448](https://github.com/pocket-id/pocket-id/commit/d77044882d5a41da22df1c0099c1eb1f20bcbc5b))
* use inline style for dynamic background image URL instead of Tailwind class ([bef77ac](https://github.com/pocket-id/pocket-id/commit/bef77ac8dca2b98b6732677aaafbc28f79d00487))
## [](https://github.com/pocket-id/pocket-id/compare/v1.3.0...v) (2025-06-09)
### Bug Fixes
* change timestamp of `client_credentials.sql` migration ([2935236](https://github.com/pocket-id/pocket-id/commit/2935236acee9c78c2fe6787ec8b5f53ae0eca047))
## [](https://github.com/pocket-id/pocket-id/compare/v1.2.0...v) (2025-06-09)
### Features
* add API endpoint for user authorized clients ([d217083](https://github.com/pocket-id/pocket-id/commit/d217083059120171d5c555b09eefe6ba3c8a8d42))
* add unix socket support ([#615](https://github.com/pocket-id/pocket-id/issues/615)) ([035b2c0](https://github.com/pocket-id/pocket-id/commit/035b2c022bfd2b98f13355ec7a126e0f1ab3ebd8))
* allow introspection and device code endpoints to use Federated Client Credentials ([#640](https://github.com/pocket-id/pocket-id/issues/640)) ([b62b61f](https://github.com/pocket-id/pocket-id/commit/b62b61fb017dba31a6fc612c138bebf370d3956c))
* JWT bearer assertions for client authentication ([#566](https://github.com/pocket-id/pocket-id/issues/566)) ([05bfe00](https://github.com/pocket-id/pocket-id/commit/05bfe0092450c9bc26d03c6a54c21050eef8f63a))
* new color theme for the UI ([97f7326](https://github.com/pocket-id/pocket-id/commit/97f7326da40265a954340d519661969530f097a0))
* oidc client data preview ([#624](https://github.com/pocket-id/pocket-id/issues/624)) ([c111b79](https://github.com/pocket-id/pocket-id/commit/c111b7914731a3cafeaa55102b515f84a1ad74dc))
### Bug Fixes
* don't load app config and user on every route change ([bdcef60](https://github.com/pocket-id/pocket-id/commit/bdcef60cab6a61e1717661e918c42e3650d23fee))
* misleading text for disable animations option ([657a51f](https://github.com/pocket-id/pocket-id/commit/657a51f7ed8a77e8a937971032091058aacfded6))
* OIDC client image can't be deleted ([61b62d4](https://github.com/pocket-id/pocket-id/commit/61b62d461200c1359a16c92c9c62530362a4785c))
* UI config overridden by env variables don't apply on first start ([5e9096e](https://github.com/pocket-id/pocket-id/commit/5e9096e328741ba2a0e03835927fe62e6aea2a89))
* use full width for audit log filters ([575b2f7](https://github.com/pocket-id/pocket-id/commit/575b2f71e9f1ff9c4f6fd411b136676c213b7201))
## [](https://github.com/pocket-id/pocket-id/compare/v1.1.0...v) (2025-06-03)

View File

@@ -28,7 +28,7 @@ Before you submit the pull request for review please ensure that
- **refactor** - code change that neither fixes a bug nor adds a feature
- Your pull request has a detailed description
- You run `npm run format` to format the code
- You run `pnpm format` to format the code
## Development Environment
@@ -69,10 +69,10 @@ The backend is built with [Gin](https://gin-gonic.com) and written in Go. To set
The frontend is built with [SvelteKit](https://kit.svelte.dev) and written in TypeScript. To set it up, follow these steps:
1. Open the `frontend` folder
2. Copy the `.env.development-example` file to `.env` and edit the variables as needed
3. Install the dependencies with `npm install`
4. Start the frontend with `npm run dev`
1. Open the `pocket-id` project folder
2. Copy the `frontend/.env.development-example` file to `frontend/.env` and edit the variables as needed
3. Install the dependencies with `pnpm install`
4. Start the frontend with `pnpm dev`
You're all set! The application is now listening on `localhost:3000`. The backend gets proxied trough the frontend in development mode.
@@ -84,11 +84,13 @@ If you are contributing to a new feature please ensure that you add tests for it
The tests can be run like this:
1. Visit the setup folder by running `cd tests/setup`
1. Install the dependencies from the root of the project `pnpm install`
2. Start the test environment by running `docker compose up -d --build`
2. Visit the setup folder by running `cd tests/setup`
3. Go back to the test folder by running `cd ..`
4. Run the tests with `npx playwright test`
3. Start the test environment by running `docker compose up -d --build`
4. Go back to the test folder by running `cd ..`
5. Run the tests with `pnpm dlx playwright test` or from the root project folder `pnpm test`
If you make any changes to the application, you have to rebuild the test environment by running `docker compose up -d --build` again.

View File

@@ -5,11 +5,17 @@ ARG BUILD_TAGS=""
# Stage 1: Build Frontend
FROM node:22-alpine AS frontend-builder
RUN corepack enable
WORKDIR /build
COPY ./frontend/package*.json ./
RUN npm ci
COPY ./frontend ./
RUN BUILD_OUTPUT_PATH=dist npm run build
COPY pnpm-workspace.yaml pnpm-lock.yaml ./
COPY frontend/package.json ./frontend/
RUN pnpm --filter pocket-id-frontend install --frozen-lockfile
COPY ./frontend ./frontend/
RUN BUILD_OUTPUT_PATH=dist pnpm --filter pocket-id-frontend run build
# Stage 2: Build Backend
FROM golang:1.24-alpine AS backend-builder
@@ -19,7 +25,7 @@ COPY ./backend/go.mod ./backend/go.sum ./
RUN go mod download
COPY ./backend ./
COPY --from=frontend-builder /build/dist ./frontend/dist
COPY --from=frontend-builder /build/frontend/dist ./frontend/dist
COPY .version .version
WORKDIR /build/cmd
@@ -30,7 +36,7 @@ RUN VERSION=$(cat /build/.version) \
-tags "${BUILD_TAGS}" \
-ldflags="-X github.com/pocket-id/pocket-id/backend/internal/common.Version=${VERSION} -buildid=${VERSION}" \
-trimpath \
-o /build/pocket-id-backend \
-o /build/pocket-id \
.
# Stage 3: Production Image
@@ -39,7 +45,7 @@ WORKDIR /app
RUN apk add --no-cache curl su-exec
COPY --from=backend-builder /build/pocket-id-backend /app/pocket-id
COPY --from=backend-builder /build/pocket-id /app/pocket-id
COPY ./scripts/docker /app/docker
RUN chmod +x /app/pocket-id && \
@@ -48,5 +54,7 @@ RUN chmod +x /app/pocket-id && \
EXPOSE 1411
ENV APP_ENV=production
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
ENTRYPOINT ["sh", "/app/docker/entrypoint.sh"]
CMD ["/app/pocket-id"]

18
Dockerfile-distroless Normal file
View File

@@ -0,0 +1,18 @@
# This Dockerfile embeds a pre-built binary for the given Linux architecture
# Binaries must be built using "./scripts/development/build-binaries.sh --docker-only"
FROM gcr.io/distroless/static-debian12:nonroot
# TARGETARCH can be "amd64" or "arm64"
ARG TARGETARCH
WORKDIR /app
COPY ./backend/.bin/pocket-id-linux-${TARGETARCH} /app/pocket-id
EXPOSE 1411
ENV APP_ENV=production
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
CMD ["/app/pocket-id"]

View File

@@ -1,5 +1,5 @@
# This Dockerfile embeds a pre-built binary for the given Linux architecture
# Binaries must be built using ./scripts/development/build-binaries.sh first
# Binaries must be built using "./scripts/development/build-binaries.sh --docker-only"
FROM alpine
@@ -16,5 +16,7 @@ COPY ./scripts/docker /app/docker
EXPOSE 1411
ENV APP_ENV=production
HEALTHCHECK --interval=90s --timeout=5s --start-period=10s --retries=3 CMD [ "/app/pocket-id", "healthcheck" ]
ENTRYPOINT ["/app/docker/entrypoint.sh"]
CMD ["/app/pocket-id"]

12
backend/.air.toml Normal file
View File

@@ -0,0 +1,12 @@
root = "."
tmp_dir = ".bin"
[build]
bin = "./.bin/pocket-id"
cmd = "CGO_ENABLED=0 go build -o ./.bin/pocket-id ./cmd"
exclude_dir = ["resources", ".bin", "data"]
exclude_regex = [".*_test\\.go"]
stop_on_error = true
[misc]
clean_on_exit = true

View File

@@ -1,15 +1,9 @@
package main
import (
"flag"
"fmt"
"log"
_ "time/tzdata"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/cmds"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
// @title Pocket ID API
@@ -17,27 +11,5 @@ import (
// @description.markdown
func main() {
// Get the command
// By default, this starts the server
var cmd string
flag.Parse()
args := flag.Args()
if len(args) > 0 {
cmd = args[0]
}
var err error
switch cmd {
case "version":
fmt.Println("pocket-id " + common.Version)
case "one-time-access-token":
err = cmds.OneTimeAccessToken(args)
default:
// Start the server
err = bootstrap.Bootstrap()
}
if err != nil {
log.Fatal(err.Error())
}
cmds.Execute()
}

View File

@@ -11,6 +11,7 @@ require (
github.com/emersion/go-smtp v0.21.3
github.com/fxamacker/cbor/v2 v2.7.0
github.com/gin-gonic/gin v1.10.0
github.com/glebarez/go-sqlite v1.21.2
github.com/glebarez/sqlite v1.11.0
github.com/go-co-op/gocron/v2 v2.15.0
github.com/go-ldap/ldap/v3 v3.4.10
@@ -19,21 +20,32 @@ require (
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0
github.com/hashicorp/go-uuid v1.0.3
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2
github.com/lestrrat-go/jwx/v3 v3.0.1
github.com/lmittmann/tint v1.1.2
github.com/mattn/go-isatty v0.0.20
github.com/mileusna/useragent v1.3.5
github.com/orandin/slog-gorm v1.4.0
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
github.com/samber/slog-gin v1.15.1
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
go.opentelemetry.io/contrib/bridges/otelslog v0.12.0
go.opentelemetry.io/contrib/exporters/autoexport v0.59.0
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0
go.opentelemetry.io/otel v1.35.0
go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel v1.37.0
go.opentelemetry.io/otel/log v0.13.0
go.opentelemetry.io/otel/metric v1.37.0
go.opentelemetry.io/otel/sdk v1.35.0
go.opentelemetry.io/otel/sdk/log v0.10.0
go.opentelemetry.io/otel/sdk/metric v1.35.0
go.opentelemetry.io/otel/trace v1.35.0
golang.org/x/crypto v0.36.0
go.opentelemetry.io/otel/trace v1.37.0
golang.org/x/crypto v0.39.0
golang.org/x/image v0.24.0
golang.org/x/text v0.26.0
golang.org/x/time v0.9.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.25.12
@@ -54,9 +66,8 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.0.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
@@ -67,6 +78,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.2 // indirect
@@ -77,12 +89,10 @@ require (
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.3 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -98,6 +108,7 @@ require (
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
@@ -114,23 +125,20 @@ require (
go.opentelemetry.io/otel/exporters/stdout/stdoutlog v0.10.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.35.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.35.0 // indirect
go.opentelemetry.io/otel/log v0.10.0 // indirect
go.opentelemetry.io/otel/sdk/log v0.10.0 // indirect
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/arch v0.14.0 // indirect
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.14.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
google.golang.org/grpc v1.71.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.65.6 // indirect
modernc.org/libc v1.65.10 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.10.0 // indirect
modernc.org/sqlite v1.37.0 // indirect
modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.38.0 // indirect
)

View File

@@ -24,6 +24,7 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4=
github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -72,8 +73,8 @@ github.com/go-co-op/gocron/v2 v2.15.0/go.mod h1:ZF70ZwEqz0OO4RBXE1sNxnANy/zvwLca
github.com/go-ldap/ldap/v3 v3.4.10 h1:ot/iwPOhfpNVgB1o+AVXljizWZ9JTp7YF5oeyONmcJU=
github.com/go-ldap/ldap/v3 v3.4.10/go.mod h1:JXh4Uxgi40P6E9rdsYqpUtbW46D9UTjJ9QSwGRznplY=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
@@ -120,6 +121,8 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -140,6 +143,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -164,18 +169,20 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/blackmagic v1.0.3 h1:94HXkVLxkZO9vJI/w2u1T0DAoprShFd13xtnSINtDWs=
github.com/lestrrat-go/blackmagic v1.0.3/go.mod h1:6AWFyKNNj0zEXQYfTMPfZrAXUWUfTIZ5ECEUEJaijtw=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2 h1:SDxjGoH7qj0nBXVrcrxX8eD94wEnjR+EEuqqmeqQYlY=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta2/go.mod h1:Nwo81sMxE0DcvTB+rJyynNhv/DUu2yZErV7sscw9pHE=
github.com/lestrrat-go/jwx/v3 v3.0.1 h1:fH3T748FCMbXoF9UXXNS9i0q6PpYyJZK/rKSbkt2guY=
github.com/lestrrat-go/jwx/v3 v3.0.1/go.mod h1:XP2WqxMOSzHSyf3pfibCcfsLqbomxakAnNqiuaH8nwo=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
@@ -203,6 +210,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/orandin/slog-gorm v1.4.0 h1:FgA8hJufF9/jeNSYoEXmHPPBwET2gwlF3B85JdpsTUU=
github.com/orandin/slog-gorm v1.4.0/go.mod h1:MoZ51+b7xE9lwGNPYEhxcUtRNrYzjdcKvA8QXQQGEPA=
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2 h1:jG+FaCBv3h6GD5F+oenTfe3+0NmX8sCKjni5k3A5Dek=
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2/go.mod h1:rHaQJ5SjfCdL4sqCKa3FhklRcaXga2/qyvmQuA+ZJ6M=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
@@ -225,8 +234,15 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/samber/slog-gin v1.15.1 h1:jsnfr+S5HQPlz9pFPA3tOmKW7wN/znyZiE6hncucrTM=
github.com/samber/slog-gin v1.15.1/go.mod h1:mPAEinK/g2jPLauuWO11m3Q0Ca7aG4k9XjXjXY8IhMQ=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -250,6 +266,8 @@ github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcY
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/bridges/otelslog v0.12.0 h1:lFM7SZo8Ce01RzRfnUFQZEYeWRf/MtOA3A5MobOqk2g=
go.opentelemetry.io/contrib/bridges/otelslog v0.12.0/go.mod h1:Dw05mhFtrKAYu72Tkb3YBYeQpRUJ4quDgo2DQw3No5A=
go.opentelemetry.io/contrib/bridges/prometheus v0.59.0 h1:HY2hJ7yn3KuEBBBsKxvF3ViSmzLwsgeNvD+0utRMgzc=
go.opentelemetry.io/contrib/bridges/prometheus v0.59.0/go.mod h1:H4H7vs8766kwFnOZVEGMJFVF+phpBSmTckvvNRdJeDI=
go.opentelemetry.io/contrib/exporters/autoexport v0.59.0 h1:dKhAFwh7SSoOw+gwMtSv+XLkUGTFAwAGMT3X3XSE4FA=
@@ -258,8 +276,8 @@ go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.
go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin v0.60.0/go.mod h1:ZvRTVaYYGypytG0zRp2A60lpj//cMq3ZnxYdZaljVBM=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ=
go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y=
go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ=
go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I=
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.10.0 h1:5dTKu4I5Dn4P2hxyW3l3jTaZx9ACgg0ECos1eAVrheY=
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.10.0/go.mod h1:P5HcUI8obLrCCmM3sbVBohZFH34iszk/+CPWuakZWL8=
go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.10.0 h1:q/heq5Zh8xV1+7GoMGJpTxM2Lhq5+bFxB29tshuRuw0=
@@ -282,18 +300,18 @@ go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.35.0 h1:PB3Zrjs1sG1GBX
go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.35.0/go.mod h1:U2R3XyVPzn0WX7wOIypPuptulsMcPDPs/oiSVOMVnHY=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.35.0 h1:T0Ec2E+3YZf5bgTNQVet8iTDW7oIk03tXHq+wkwIDnE=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.35.0/go.mod h1:30v2gqH+vYGJsesLWFov8u47EpYTcIQcBjKpI6pJThg=
go.opentelemetry.io/otel/log v0.10.0 h1:1CXmspaRITvFcjA4kyVszuG4HjA61fPDxMb7q3BuyF0=
go.opentelemetry.io/otel/log v0.10.0/go.mod h1:PbVdm9bXKku/gL0oFfUF4wwsQsOPlpo4VEqjvxih+FM=
go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M=
go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE=
go.opentelemetry.io/otel/log v0.13.0 h1:yoxRoIZcohB6Xf0lNv9QIyCzQvrtGZklVbdCoyb7dls=
go.opentelemetry.io/otel/log v0.13.0/go.mod h1:INKfG4k1O9CL25BaM1qLe0zIedOpvlS5Z7XgSbmN83E=
go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE=
go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E=
go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY=
go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg=
go.opentelemetry.io/otel/sdk/log v0.10.0 h1:lR4teQGWfeDVGoute6l0Ou+RpFqQ9vaPdrNJlST0bvw=
go.opentelemetry.io/otel/sdk/log v0.10.0/go.mod h1:A+V1UTWREhWAittaQEG4bYm4gAZa6xnvVu+xKrIRkzo=
go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o=
go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w=
go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs=
go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc=
go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4=
go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0=
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
@@ -309,8 +327,8 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI=
golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -321,8 +339,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -343,8 +361,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -377,8 +395,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -413,22 +431,22 @@ modernc.org/cc/v4 v4.26.1 h1:+X5NtzVBn0KgsBCBe+xkDC7twLb/jNVj9FPgiwSQO3s=
modernc.org/cc/v4 v4.26.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/fileutil v1.3.1 h1:8vq5fe7jdtEvoCf3Zf9Nm0Q05sH6kGx0Op2CPx1wTC8=
modernc.org/fileutil v1.3.1/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/fileutil v1.3.3 h1:3qaU+7f7xxTUmvU1pJTZiDLAIoJVdUSSauJNHg9yXoA=
modernc.org/fileutil v1.3.3/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/libc v1.65.6 h1:OhJUhmuJ6MVZdqL5qmnd0/my46DKGFhSX4WOR7ijfyE=
modernc.org/libc v1.65.6/go.mod h1:MOiGAM9lrMBT9L8xT1nO41qYl5eg9gCp9/kWhz5L7WA=
modernc.org/libc v1.65.10 h1:ZwEk8+jhW7qBjHIT+wd0d9VjitRyQef9BnzlzGwMODc=
modernc.org/libc v1.65.10/go.mod h1:StFvYpx7i/mXtBAfVOjaU0PWZOvIRoZSgXhrwXzr8Po=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.10.0 h1:fzumd51yQ1DxcOxSO+S6X7+QTuVU+n8/Aj7swYjFfC4=
modernc.org/memory v1.10.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.37.0 h1:s1TMe7T3Q3ovQiK2Ouz4Jwh7dw4ZDqbebSDTlSJdfjI=
modernc.org/sqlite v1.37.0/go.mod h1:5YiWv+YviqGMuGw4V+PNplcyaJ5v+vQd7TQOgkACoJM=
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -1,7 +1,7 @@
package bootstrap
import (
"log"
"fmt"
"os"
"path"
"strings"
@@ -12,17 +12,17 @@ import (
)
// initApplicationImages copies the images from the images directory to the application-images directory
func initApplicationImages() {
func initApplicationImages() error {
dirPath := common.EnvConfig.UploadPath + "/application-images"
sourceFiles, err := resources.FS.ReadDir("images")
if err != nil && !os.IsNotExist(err) {
log.Fatalf("Error reading directory: %v", err)
return fmt.Errorf("failed to read directory: %w", err)
}
destinationFiles, err := os.ReadDir(dirPath)
if err != nil && !os.IsNotExist(err) {
log.Fatalf("Error reading directory: %v", err)
return fmt.Errorf("failed to read directory: %w", err)
}
// Copy images from the images directory to the application-images directory if they don't already exist
@@ -35,9 +35,11 @@ func initApplicationImages() {
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
if err != nil {
log.Fatalf("Error copying file: %v", err)
return fmt.Errorf("failed to copy file: %w", err)
}
}
return nil
}
func imageAlreadyExists(fileName string, destinationFiles []os.DirEntry) bool {

View File

@@ -3,7 +3,7 @@ package bootstrap
import (
"context"
"fmt"
"log"
"log/slog"
"time"
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -11,23 +11,26 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/job"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
)
func Bootstrap() error {
// Get a context that is canceled when the application is stopping
ctx := signals.SignalContext(context.Background())
initApplicationImages()
// Initialize the tracer and metrics exporter
shutdownFns, httpClient, err := initOtel(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
func Bootstrap(ctx context.Context) error {
// Initialize the observability stack, including the logger, distributed tracing, and metrics
shutdownFns, httpClient, err := initObservability(ctx, common.EnvConfig.MetricsEnabled, common.EnvConfig.TracingEnabled)
if err != nil {
return fmt.Errorf("failed to initialize OpenTelemetry: %w", err)
}
slog.InfoContext(ctx, "Pocket ID is starting")
err = initApplicationImages()
if err != nil {
return fmt.Errorf("failed to initialize application images: %w", err)
}
// Connect to the database
db := NewDatabase()
db, err := NewDatabase()
if err != nil {
return fmt.Errorf("failed to initialize database: %w", err)
}
// Create all services
svc, err := initServices(ctx, db, httpClient)
@@ -59,13 +62,14 @@ func Bootstrap() error {
// Invoke all shutdown functions
// We give these a timeout of 5s
// Note: we use a background context because the run context has been canceled already
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
err = utils.
NewServiceRunner(shutdownFns...).
Run(shutdownCtx)
Run(shutdownCtx) //nolint:contextcheck
if err != nil {
log.Printf("Error shutting down services: %v", err)
slog.Error("Error shutting down services", slog.Any("error", err))
}
return nil

View File

@@ -3,9 +3,8 @@ package bootstrap
import (
"errors"
"fmt"
"log"
"log/slog"
"net/url"
"os"
"strings"
"time"
@@ -15,22 +14,24 @@ import (
postgresMigrate "github.com/golang-migrate/migrate/v4/database/postgres"
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
slogGorm "github.com/orandin/slog-gorm"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
gormLogger "gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
sqliteutil "github.com/pocket-id/pocket-id/backend/internal/utils/sqlite"
"github.com/pocket-id/pocket-id/backend/resources"
)
func NewDatabase() (db *gorm.DB) {
db, err := connectDatabase()
func NewDatabase() (db *gorm.DB, err error) {
db, err = connectDatabase()
if err != nil {
log.Fatalf("failed to connect to database: %v", err)
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
sqlDb, err := db.DB()
if err != nil {
log.Fatalf("failed to get sql.DB: %v", err)
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
}
// Choose the correct driver for the database provider
@@ -42,18 +43,18 @@ func NewDatabase() (db *gorm.DB) {
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
default:
// Should never happen at this point
log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider)
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
if err != nil {
log.Fatalf("failed to create migration driver: %v", err)
return nil, fmt.Errorf("failed to create migration driver: %w", err)
}
// Run migrations
if err := migrateDatabase(driver); err != nil {
log.Fatalf("failed to run migrations: %v", err)
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return db
return db, nil
}
func migrateDatabase(driver database.Driver) error {
@@ -88,6 +89,7 @@ func connectDatabase() (db *gorm.DB, err error) {
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
}
sqliteutil.RegisterSqliteFunctions()
connString, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
if err != nil {
return nil, err
@@ -105,16 +107,19 @@ func connectDatabase() (db *gorm.DB, err error) {
for i := 1; i <= 3; i++ {
db, err = gorm.Open(dialector, &gorm.Config{
TranslateError: true,
Logger: getLogger(),
Logger: getGormLogger(),
})
if err == nil {
slog.Info("Connected to database", slog.String("provider", string(common.EnvConfig.DbProvider)))
return db, nil
}
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
slog.Warn("Failed to connect to database, will retry in 3s", slog.Int("attempt", i), slog.String("provider", string(common.EnvConfig.DbProvider)), slog.Any("error", err))
time.Sleep(3 * time.Second)
}
slog.Error("Failed to connect to database after 3 attempts", slog.String("provider", string(common.EnvConfig.DbProvider)), slog.Any("error", err))
return nil, err
}
@@ -162,24 +167,25 @@ func parseSqliteConnectionString(connString string) (string, error) {
return connStringUrl.String(), nil
}
func getLogger() logger.Interface {
isProduction := common.EnvConfig.AppEnv == "production"
func getGormLogger() gormLogger.Interface {
loggerOpts := make([]slogGorm.Option, 0, 5)
loggerOpts = append(loggerOpts,
slogGorm.WithSlowThreshold(200*time.Millisecond),
slogGorm.WithErrorField("error"),
)
var logLevel logger.LogLevel
if isProduction {
logLevel = logger.Error
if common.EnvConfig.AppEnv == "production" {
loggerOpts = append(loggerOpts,
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelWarn),
slogGorm.WithIgnoreTrace(),
)
} else {
logLevel = logger.Info
loggerOpts = append(loggerOpts,
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelDebug),
slogGorm.WithRecordNotFoundError(),
slogGorm.WithTraceAll(),
)
}
return logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logLevel,
IgnoreRecordNotFoundError: isProduction,
ParameterizedQueries: isProduction,
Colorful: !isProduction,
},
)
return slogGorm.New(loggerOpts...)
}

View File

@@ -3,6 +3,9 @@
package bootstrap
import (
"log/slog"
"os"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@@ -14,7 +17,13 @@ import (
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services){
func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *services) {
testService := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
testService, err := service.NewTestService(db, svc.appConfigService, svc.jwtService, svc.ldapService)
if err != nil {
slog.Error("Failed to initialize test service", slog.Any("error", err))
os.Exit(1)
return
}
controller.NewTestController(apiGroup, testService)
},
}

View File

@@ -0,0 +1,210 @@
package bootstrap
import (
"context"
"fmt"
"log/slog"
"net/http"
"os"
"time"
"github.com/lmittmann/tint"
"github.com/mattn/go-isatty"
"go.opentelemetry.io/contrib/bridges/otelslog"
"go.opentelemetry.io/contrib/exporters/autoexport"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
globallog "go.opentelemetry.io/otel/log/global"
metricnoop "go.opentelemetry.io/otel/metric/noop"
"go.opentelemetry.io/otel/propagation"
sdklog "go.opentelemetry.io/otel/sdk/log"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.30.0"
tracenoop "go.opentelemetry.io/otel/trace/noop"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func defaultResource() (*resource.Resource, error) {
return resource.Merge(
resource.Default(),
resource.NewSchemaless(
semconv.ServiceName(common.Name),
semconv.ServiceVersion(common.Version),
),
)
}
func initObservability(ctx context.Context, metrics, traces bool) (shutdownFns []utils.Service, httpClient *http.Client, err error) {
resource, err := defaultResource()
if err != nil {
return nil, nil, fmt.Errorf("failed to create OpenTelemetry resource: %w", err)
}
shutdownFns = make([]utils.Service, 0, 2)
httpClient = &http.Client{}
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
// Indicates a development-time error
panic("Default transport is not of type *http.Transport")
}
httpClient.Transport = defaultTransport.Clone()
// Logging
err = initOtelLogging(ctx, resource)
if err != nil {
return nil, nil, err
}
// Tracing
tracingShutdownFn, err := initOtelTracing(ctx, traces, resource, httpClient)
if err != nil {
return nil, nil, err
} else if tracingShutdownFn != nil {
shutdownFns = append(shutdownFns, tracingShutdownFn)
}
// Metrics
metricsShutdownFn, err := initOtelMetrics(ctx, metrics, resource)
if err != nil {
return nil, nil, err
} else if metricsShutdownFn != nil {
shutdownFns = append(shutdownFns, metricsShutdownFn)
}
return shutdownFns, httpClient, nil
}
func initOtelLogging(ctx context.Context, resource *resource.Resource) error {
// If the env var OTEL_LOGS_EXPORTER is empty, we set it to "none", for autoexport to work
if os.Getenv("OTEL_LOGS_EXPORTER") == "" {
os.Setenv("OTEL_LOGS_EXPORTER", "none")
}
exp, err := autoexport.NewLogExporter(ctx)
if err != nil {
return fmt.Errorf("failed to initialize OpenTelemetry log exporter: %w", err)
}
level := slog.LevelDebug
if common.EnvConfig.AppEnv == "production" {
level = slog.LevelInfo
}
// Create the handler
var handler slog.Handler
switch {
case common.EnvConfig.LogJSON:
// Log as JSON if configured
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
})
case isatty.IsTerminal(os.Stdout.Fd()):
// Enable colors if we have a TTY
handler = tint.NewHandler(os.Stdout, &tint.Options{
TimeFormat: time.StampMilli,
Level: level,
})
default:
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: level,
})
}
// Create the logger provider
provider := sdklog.NewLoggerProvider(
sdklog.WithProcessor(
sdklog.NewBatchProcessor(exp),
),
sdklog.WithResource(resource),
)
// Set the logger provider globally
globallog.SetLoggerProvider(provider)
// Wrap the handler in a "fanout" one
handler = utils.LogFanoutHandler{
handler,
otelslog.NewHandler(common.Name, otelslog.WithLoggerProvider(provider)),
}
// Set the default slog to send logs to OTel and add the app name
log := slog.New(handler).
With(slog.String("app", common.Name)).
With(slog.String("version", common.Version))
slog.SetDefault(log)
return nil
}
func initOtelTracing(ctx context.Context, traces bool, resource *resource.Resource, httpClient *http.Client) (shutdownFn utils.Service, err error) {
if !traces {
otel.SetTracerProvider(tracenoop.NewTracerProvider())
return nil, nil
}
tr, err := autoexport.NewSpanExporter(ctx)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenTelemetry span exporter: %w", err)
}
tp := sdktrace.NewTracerProvider(
sdktrace.WithResource(resource),
sdktrace.WithBatcher(tr),
)
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(
propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
)
shutdownFn = func(shutdownCtx context.Context) error { //nolint:contextcheck
tpCtx, tpCancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer tpCancel()
shutdownErr := tp.Shutdown(tpCtx)
if shutdownErr != nil {
return fmt.Errorf("failed to gracefully shut down traces exporter: %w", shutdownErr)
}
return nil
}
// Add tracing to the HTTP client
httpClient.Transport = otelhttp.NewTransport(httpClient.Transport)
return shutdownFn, nil
}
func initOtelMetrics(ctx context.Context, metrics bool, resource *resource.Resource) (shutdownFn utils.Service, err error) {
if !metrics {
otel.SetMeterProvider(metricnoop.NewMeterProvider())
return nil, nil
}
mr, err := autoexport.NewMetricReader(ctx)
if err != nil {
return nil, fmt.Errorf("failed to initialize OpenTelemetry metric reader: %w", err)
}
mp := metric.NewMeterProvider(
metric.WithResource(resource),
metric.WithReader(mr),
)
otel.SetMeterProvider(mp)
shutdownFn = func(shutdownCtx context.Context) error { //nolint:contextcheck
mpCtx, mpCancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer mpCancel()
shutdownErr := mp.Shutdown(mpCtx)
if shutdownErr != nil {
return fmt.Errorf("failed to gracefully shut down metrics exporter: %w", shutdownErr)
}
return nil
}
return shutdownFn, nil
}

View File

@@ -1,107 +0,0 @@
package bootstrap
import (
"context"
"fmt"
"net/http"
"time"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"go.opentelemetry.io/contrib/exporters/autoexport"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
metricnoop "go.opentelemetry.io/otel/metric/noop"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.30.0"
tracenoop "go.opentelemetry.io/otel/trace/noop"
)
func defaultResource() (*resource.Resource, error) {
return resource.Merge(
resource.Default(),
resource.NewSchemaless(
semconv.ServiceName("pocket-id-backend"),
semconv.ServiceVersion(common.Version),
),
)
}
func initOtel(ctx context.Context, metrics, traces bool) (shutdownFns []utils.Service, httpClient *http.Client, err error) {
resource, err := defaultResource()
if err != nil {
return nil, nil, fmt.Errorf("failed to create OpenTelemetry resource: %w", err)
}
shutdownFns = make([]utils.Service, 0, 2)
httpClient = &http.Client{}
defaultTransport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
// Indicates a development-time error
panic("Default transport is not of type *http.Transport")
}
httpClient.Transport = defaultTransport.Clone()
if traces {
tr, err := autoexport.NewSpanExporter(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to initialize OpenTelemetry span exporter: %w", err)
}
tp := sdktrace.NewTracerProvider(
sdktrace.WithResource(resource),
sdktrace.WithBatcher(tr),
)
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(
propagation.NewCompositeTextMapPropagator(
propagation.TraceContext{},
propagation.Baggage{},
),
)
shutdownFns = append(shutdownFns, func(shutdownCtx context.Context) error { //nolint:contextcheck
tpCtx, tpCancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer tpCancel()
shutdownErr := tp.Shutdown(tpCtx)
if shutdownErr != nil {
return fmt.Errorf("failed to gracefully shut down traces exporter: %w", shutdownErr)
}
return nil
})
httpClient.Transport = otelhttp.NewTransport(httpClient.Transport)
} else {
otel.SetTracerProvider(tracenoop.NewTracerProvider())
}
if metrics {
mr, err := autoexport.NewMetricReader(ctx)
if err != nil {
return nil, nil, fmt.Errorf("failed to initialize OpenTelemetry metric reader: %w", err)
}
mp := metric.NewMeterProvider(
metric.WithResource(resource),
metric.WithReader(mr),
)
otel.SetMeterProvider(mp)
shutdownFns = append(shutdownFns, func(shutdownCtx context.Context) error { //nolint:contextcheck
mpCtx, mpCancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer mpCancel()
shutdownErr := mp.Shutdown(mpCtx)
if shutdownErr != nil {
return fmt.Errorf("failed to gracefully shut down metrics exporter: %w", shutdownErr)
}
return nil
})
} else {
otel.SetMeterProvider(metricnoop.NewMeterProvider())
}
return shutdownFns, httpClient, nil
}

View File

@@ -4,18 +4,21 @@ import (
"context"
"errors"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/pocket-id/pocket-id/backend/frontend"
"github.com/gin-gonic/gin"
sloggin "github.com/samber/slog-gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
"golang.org/x/time/rate"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/frontend"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/controller"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
@@ -29,7 +32,8 @@ var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, svc *
func initRouter(db *gorm.DB, svc *services) utils.Service {
runner, err := initRouterInternal(db, svc)
if err != nil {
log.Fatalf("failed to init router: %v", err)
slog.Error("Failed to init router", "error", err)
os.Exit(1)
}
return runner
}
@@ -45,15 +49,37 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
gin.SetMode(gin.TestMode)
}
r := gin.Default()
r.Use(gin.Logger())
// do not log these URLs
loggerSkipPathsPrefix := []string{
"GET /application-configuration/logo",
"GET /application-configuration/background-image",
"GET /application-configuration/favicon",
"GET /_app",
"GET /fonts",
"GET /healthz",
"HEAD /healthz",
}
r := gin.New()
r.Use(sloggin.NewWithConfig(slog.Default(), sloggin.Config{
Filters: []sloggin.Filter{
func(c *gin.Context) bool {
for _, prefix := range loggerSkipPathsPrefix {
if strings.HasPrefix(c.Request.Method+" "+c.Request.URL.String(), prefix) {
return false
}
}
return true
},
},
}))
if !common.EnvConfig.TrustProxy {
_ = r.SetTrustedProxies(nil)
}
if common.EnvConfig.TracingEnabled {
r.Use(otelgin.Middleware("pocket-id-backend"))
r.Use(otelgin.Middleware(common.Name))
}
rateLimitMiddleware := middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60)
@@ -64,7 +90,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
err := frontend.RegisterFrontend(r)
if errors.Is(err, frontend.ErrFrontendNotIncluded) {
log.Println("Frontend is not included in the build. Skipping frontend registration.")
slog.Warn("Frontend is not included in the build. Skipping frontend registration.")
} else if err != nil {
return nil, fmt.Errorf("failed to register frontend: %w", err)
}
@@ -101,21 +127,39 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
// Set up the server
srv := &http.Server{
Addr: net.JoinHostPort(common.EnvConfig.Host, common.EnvConfig.Port),
MaxHeaderBytes: 1 << 20,
ReadHeaderTimeout: 10 * time.Second,
Handler: r,
}
// Set up the listener
listener, err := net.Listen("tcp", srv.Addr)
network := "tcp"
addr := net.JoinHostPort(common.EnvConfig.Host, common.EnvConfig.Port)
if common.EnvConfig.UnixSocket != "" {
network = "unix"
addr = common.EnvConfig.UnixSocket
}
listener, err := net.Listen(network, addr)
if err != nil {
return nil, fmt.Errorf("failed to create TCP listener: %w", err)
return nil, fmt.Errorf("failed to create %s listener: %w", network, err)
}
// Set the socket mode if using a Unix socket
if network == "unix" && common.EnvConfig.UnixSocketMode != "" {
mode, err := strconv.ParseUint(common.EnvConfig.UnixSocketMode, 8, 32)
if err != nil {
return nil, fmt.Errorf("failed to parse UNIX socket mode '%s': %w", common.EnvConfig.UnixSocketMode, err)
}
if err := os.Chmod(addr, os.FileMode(mode)); err != nil {
return nil, fmt.Errorf("failed to set UNIX socket mode '%s': %w", common.EnvConfig.UnixSocketMode, err)
}
}
// Service runner function
runFn := func(ctx context.Context) error {
log.Printf("Server listening on %s", srv.Addr)
slog.Info("Server listening", slog.String("addr", addr))
// Start the server in a background goroutine
go func() {
@@ -124,7 +168,8 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
// Next call blocks until the server is shut down
srvErr := srv.Serve(listener)
if srvErr != http.ErrServerClosed {
log.Fatalf("Error starting app server: %v", srvErr)
slog.Error("Error starting app server", "error", srvErr)
os.Exit(1)
}
}()
@@ -132,7 +177,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
err = systemd.SdNotifyReady()
if err != nil {
// Log the error only
log.Printf("[WARN] Unable to notify systemd that the service is ready: %v", err)
slog.Warn("Unable to notify systemd that the service is ready", "error", err)
}
// Block until the context is canceled
@@ -145,7 +190,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
shutdownCancel()
if shutdownErr != nil {
// Log the error only (could be context canceled)
log.Printf("[WARN] App server shutdown error: %v", shutdownErr)
slog.Warn("App server shutdown error", "error", shutdownErr)
}
return nil

View File

@@ -26,27 +26,42 @@ type services struct {
}
// Initializes all services
// The context should be used by services only for initialization, and not for running
func initServices(initCtx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
svc = &services{}
svc.appConfigService = service.NewAppConfigService(initCtx, db)
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
if err != nil {
return nil, fmt.Errorf("failed to create app config service: %w", err)
}
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
if err != nil {
return nil, fmt.Errorf("unable to create email service: %w", err)
return nil, fmt.Errorf("failed to create email service: %w", err)
}
svc.geoLiteService = service.NewGeoLiteService(httpClient)
svc.auditLogService = service.NewAuditLogService(db, svc.appConfigService, svc.emailService, svc.geoLiteService)
svc.jwtService = service.NewJwtService(svc.appConfigService)
svc.jwtService, err = service.NewJwtService(db, svc.appConfigService)
if err != nil {
return nil, fmt.Errorf("failed to create JWT service: %w", err)
}
svc.userService = service.NewUserService(db, svc.jwtService, svc.auditLogService, svc.emailService, svc.appConfigService)
svc.customClaimService = service.NewCustomClaimService(db)
svc.oidcService = service.NewOidcService(db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
}
svc.userGroupService = service.NewUserGroupService(db, svc.appConfigService)
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
svc.webauthnService = service.NewWebAuthnService(db, svc.jwtService, svc.auditLogService, svc.appConfigService)
svc.webauthnService, err = service.NewWebAuthnService(db, svc.jwtService, svc.auditLogService, svc.appConfigService)
if err != nil {
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
}
return svc, nil
}

View File

@@ -0,0 +1,83 @@
package cmds
import (
"context"
"log/slog"
"net/http"
"os"
"time"
"github.com/spf13/cobra"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
type healthcheckFlags struct {
Endpoint string
Verbose bool
}
func init() {
var flags healthcheckFlags
healthcheckCmd := &cobra.Command{
Use: "healthcheck",
Short: "Performs a healthcheck of a running Pocket ID instance",
Run: func(cmd *cobra.Command, args []string) {
start := time.Now()
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
url := flags.Endpoint + "/healthz"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
slog.ErrorContext(ctx,
"Failed to create request object",
"error", err,
"url", url,
"ms", time.Since(start).Milliseconds(),
)
os.Exit(1)
}
res, err := http.DefaultClient.Do(req)
if err != nil {
slog.ErrorContext(ctx,
"Failed to perform request",
"error", err,
"url", url,
"ms", time.Since(start).Milliseconds(),
)
os.Exit(1)
}
defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 {
if err != nil {
slog.ErrorContext(ctx,
"Healthcheck failed",
"status", res.StatusCode,
"url", url,
"ms", time.Since(start).Milliseconds(),
)
os.Exit(1)
}
}
if flags.Verbose {
slog.InfoContext(ctx,
"Healthcheck succeeded",
"status", res.StatusCode,
"url", url,
"ms", time.Since(start).Milliseconds(),
)
}
},
}
healthcheckCmd.Flags().StringVarP(&flags.Endpoint, "endpoint", "e", "http://localhost:"+common.EnvConfig.Port, "Endpoint for Pocket ID")
healthcheckCmd.Flags().BoolVarP(&flags.Verbose, "verbose", "v", false, "Enable verbose mode")
rootCmd.AddCommand(healthcheckCmd)
}

View File

@@ -0,0 +1,113 @@
package cmds
import (
"context"
"errors"
"fmt"
"strings"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/spf13/cobra"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
type keyRotateFlags struct {
Alg string
Crv string
Yes bool
}
func init() {
var flags keyRotateFlags
keyRotateCmd := &cobra.Command{
Use: "key-rotate",
Short: "Generates a new token signing key and replaces the current one",
RunE: func(cmd *cobra.Command, args []string) error {
db, err := bootstrap.NewDatabase()
if err != nil {
return err
}
return keyRotate(cmd.Context(), flags, db, &common.EnvConfig)
},
}
keyRotateCmd.Flags().StringVarP(&flags.Alg, "alg", "a", "RS256", "Key algorithm. Supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
keyRotateCmd.Flags().StringVarP(&flags.Crv, "crv", "c", "", "Curve name when using EdDSA keys. Supported values: Ed25519")
keyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
rootCmd.AddCommand(keyRotateCmd)
}
func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
// Validate the flags
switch strings.ToUpper(flags.Alg) {
case jwa.RS256().String(), jwa.RS384().String(), jwa.RS512().String(),
jwa.ES256().String(), jwa.ES384().String(), jwa.ES512().String():
// All good, but uppercase it for consistency
flags.Alg = strings.ToUpper(flags.Alg)
case strings.ToUpper(jwa.EdDSA().String()):
// Ensure Crv is set and valid
switch strings.ToUpper(flags.Crv) {
case strings.ToUpper(jwa.Ed25519().String()):
// All good, but ensure consistency in casing
flags.Crv = jwa.Ed25519().String()
case "":
return errors.New("a curve name is required when algorithm is EdDSA")
default:
return errors.New("unsupported EdDSA curve; supported values: Ed25519")
}
case "":
return errors.New("key algorithm is required")
default:
return errors.New("unsupported key algorithm; supported values: RS256, RS384, RS512, ES256, ES384, ES512, EdDSA")
}
if !flags.Yes {
fmt.Println("WARNING: Rotating the private key will invalidate all existing tokens. Both pocket-id and all client applications will likely need to be restarted.")
ok, err := utils.PromptForConfirmation("Confirm")
if err != nil {
return err
}
if !ok {
fmt.Println("Aborted")
return nil
}
}
// Init the services we need
appConfigService, err := service.NewAppConfigService(ctx, db)
if err != nil {
return fmt.Errorf("failed to create app config service: %w", err)
}
// Get the key provider
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, appConfigService.GetDbConfig().InstanceID.Value)
if err != nil {
return fmt.Errorf("failed to get key provider: %w", err)
}
// Generate a new key
key, err := jwkutils.GenerateKey(flags.Alg, flags.Crv)
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Save the key
err = keyProvider.SaveKey(key)
if err != nil {
return fmt.Errorf("failed to store new key: %w", err)
}
fmt.Println("Key rotated successfully")
fmt.Println("Note: if pocket-id is running, you will need to restart it for the new key to be loaded")
return nil
}

View File

@@ -0,0 +1,216 @@
package cmds
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
func TestKeyRotate(t *testing.T) {
tests := []struct {
name string
flags keyRotateFlags
wantErr bool
errMsg string
}{
{
name: "valid RS256",
flags: keyRotateFlags{
Alg: "RS256",
Yes: true,
},
wantErr: false,
},
{
name: "valid EdDSA with Ed25519",
flags: keyRotateFlags{
Alg: "EdDSA",
Crv: "Ed25519",
Yes: true,
},
wantErr: false,
},
{
name: "invalid algorithm",
flags: keyRotateFlags{
Alg: "INVALID",
Yes: true,
},
wantErr: true,
errMsg: "unsupported key algorithm",
},
{
name: "EdDSA without curve",
flags: keyRotateFlags{
Alg: "EdDSA",
Yes: true,
},
wantErr: true,
errMsg: "a curve name is required when algorithm is EdDSA",
},
{
name: "empty algorithm",
flags: keyRotateFlags{
Alg: "",
Yes: true,
},
wantErr: true,
errMsg: "key algorithm is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run("file storage", func(t *testing.T) {
testKeyRotateWithFileStorage(t, tt.flags, tt.wantErr, tt.errMsg)
})
t.Run("database storage", func(t *testing.T) {
testKeyRotateWithDatabaseStorage(t, tt.flags, tt.wantErr, tt.errMsg)
})
})
}
}
func testKeyRotateWithFileStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
// Create temporary directory for keys
tempDir := t.TempDir()
keysPath := filepath.Join(tempDir, "keys")
err := os.MkdirAll(keysPath, 0755)
require.NoError(t, err)
// Set up file storage config
envConfig := &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: keysPath,
}
// Create test database
db := testingutils.NewDatabaseForTest(t)
// Initialize app config service and create instance
appConfigService, err := service.NewAppConfigService(t.Context(), db)
require.NoError(t, err)
instanceID := appConfigService.GetDbConfig().InstanceID.Value
// Check if key exists before rotation
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
require.NoError(t, err)
// Run the key rotation
err = keyRotate(t.Context(), flags, db, envConfig)
if wantErr {
require.Error(t, err)
if errMsg != "" {
require.ErrorContains(t, err, errMsg)
}
return
}
require.NoError(t, err)
// Verify key was created
key, err := keyProvider.LoadKey()
require.NoError(t, err)
require.NotNil(t, key)
// Verify the algorithm matches what we requested
alg, _ := key.Algorithm()
assert.NotEmpty(t, alg)
if flags.Alg != "" {
expectedAlg := flags.Alg
if expectedAlg == "EdDSA" {
// EdDSA keys should have the EdDSA algorithm
assert.Equal(t, "EdDSA", alg.String())
} else {
assert.Equal(t, expectedAlg, alg.String())
}
}
}
func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantErr bool, errMsg string) {
// Set up database storage config
envConfig := &common.EnvConfigSchema{
KeysStorage: "database",
EncryptionKey: []byte("test-encryption-key-characters-long"),
}
// Create test database
db := testingutils.NewDatabaseForTest(t)
// Initialize app config service and create instance
appConfigService, err := service.NewAppConfigService(t.Context(), db)
require.NoError(t, err)
instanceID := appConfigService.GetDbConfig().InstanceID.Value
// Get key provider
keyProvider, err := jwkutils.GetKeyProvider(db, envConfig, instanceID)
require.NoError(t, err)
// Run the key rotation
err = keyRotate(t.Context(), flags, db, envConfig)
if wantErr {
require.Error(t, err)
if errMsg != "" {
require.ErrorContains(t, err, errMsg)
}
return
}
require.NoError(t, err)
// Verify key was created
key, err := keyProvider.LoadKey()
require.NoError(t, err)
require.NotNil(t, key)
// Verify the algorithm matches what we requested
alg, _ := key.Algorithm()
assert.NotEmpty(t, alg)
if flags.Alg != "" {
expectedAlg := flags.Alg
if expectedAlg == "EdDSA" {
// EdDSA keys should have the EdDSA algorithm
assert.Equal(t, "EdDSA", alg.String())
} else {
assert.Equal(t, expectedAlg, alg.String())
}
}
}
func TestKeyRotateMultipleAlgorithms(t *testing.T) {
algorithms := []struct {
alg string
crv string
}{
{"RS256", ""},
{"RS384", ""},
// Skip RSA-4096 key generation test as it can take a long time
// {"RS512", ""},
{"ES256", ""},
{"ES384", ""},
{"ES512", ""},
{"EdDSA", "Ed25519"},
}
for _, algo := range algorithms {
t.Run(algo.alg, func(t *testing.T) {
// Test with database storage for all algorithms
testKeyRotateWithDatabaseStorage(t, keyRotateFlags{
Alg: algo.alg,
Crv: algo.crv,
Yes: true,
}, false, "")
})
}
}

View File

@@ -6,77 +6,80 @@ import (
"fmt"
"time"
"github.com/spf13/cobra"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
)
// OneTimeAccessToken creates a one-time access token for the given user
// Args must contain the username or email of the user
func OneTimeAccessToken(args []string) error {
// Get a context that is canceled when the application is stopping
ctx := signals.SignalContext(context.Background())
var oneTimeAccessTokenCmd = &cobra.Command{
Use: "one-time-access-token [username or email]",
Short: "Generates a one-time access token for the given user",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// Get the username or email of the user
userArg := args[0]
// Get the username or email of the user
// Note length is 2 because the first argument is always the command (one-time-access-token)
if len(args) != 2 {
return errors.New("missing username or email of user; usage: one-time-access-token <username or email>")
}
userArg := args[1]
// Connect to the database
db := bootstrap.NewDatabase()
// Create the access token
var oneTimeAccessToken *model.OneTimeAccessToken
err := db.Transaction(func(tx *gorm.DB) error {
// Load the user to retrieve the user ID
var user model.User
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
txErr := tx.
WithContext(queryCtx).
Where("username = ? OR email = ?", userArg, userArg).
First(&user).
Error
switch {
case errors.Is(txErr, gorm.ErrRecordNotFound):
return errors.New("user not found")
case txErr != nil:
return fmt.Errorf("failed to query for user: %w", txErr)
case user.ID == "":
return errors.New("invalid user loaded: ID is empty")
// Connect to the database
db, err := bootstrap.NewDatabase()
if err != nil {
return err
}
// Create a new access token that expires in 1 hour
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
if txErr != nil {
return fmt.Errorf("failed to generate access token: %w", txErr)
// Create the access token
var oneTimeAccessToken *model.OneTimeAccessToken
err = db.Transaction(func(tx *gorm.DB) error {
// Load the user to retrieve the user ID
var user model.User
queryCtx, queryCancel := context.WithTimeout(cmd.Context(), 10*time.Second)
defer queryCancel()
txErr := tx.
WithContext(queryCtx).
Where("username = ? OR email = ?", userArg, userArg).
First(&user).
Error
switch {
case errors.Is(txErr, gorm.ErrRecordNotFound):
return errors.New("user not found")
case txErr != nil:
return fmt.Errorf("failed to query for user: %w", txErr)
case user.ID == "":
return errors.New("invalid user loaded: ID is empty")
}
// Create a new access token that expires in 1 hour
oneTimeAccessToken, txErr = service.NewOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
if txErr != nil {
return fmt.Errorf("failed to generate access token: %w", txErr)
}
queryCtx, queryCancel = context.WithTimeout(cmd.Context(), 10*time.Second)
defer queryCancel()
txErr = tx.
WithContext(queryCtx).
Create(oneTimeAccessToken).
Error
if txErr != nil {
return fmt.Errorf("failed to save access token: %w", txErr)
}
return nil
})
if err != nil {
return err
}
queryCtx, queryCancel = context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
txErr = tx.
WithContext(queryCtx).
Create(oneTimeAccessToken).
Error
if txErr != nil {
return fmt.Errorf("failed to save access token: %w", txErr)
}
// Print the result
fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg)
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
return nil
})
if err != nil {
return err
}
// Print the result
fmt.Printf(`A one-time access token valid for 1 hour has been created for "%s".`+"\n", userArg)
fmt.Printf("Use the following URL to sign in once: %s/lc/%s\n", common.EnvConfig.AppURL, oneTimeAccessToken.Token)
return nil
},
}
func init() {
rootCmd.AddCommand(oneTimeAccessTokenCmd)
}

View File

@@ -0,0 +1,36 @@
package cmds
import (
"context"
"log/slog"
"os"
"github.com/spf13/cobra"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
)
var rootCmd = &cobra.Command{
Use: "pocket-id",
Short: "A simple and easy-to-use OIDC provider that allows users to authenticate with their passkeys to your services.",
Long: "By default, this command starts the pocket-id server.",
Run: func(cmd *cobra.Command, args []string) {
// Start the server
err := bootstrap.Bootstrap(cmd.Context())
if err != nil {
slog.Error("Failed to run pocket-id", "error", err)
os.Exit(1)
}
},
}
func Execute() {
// Get a context that is canceled when the application is stopping
ctx := signals.SignalContext(context.Background())
err := rootCmd.ExecuteContext(ctx)
if err != nil {
os.Exit(1)
}
}

View File

@@ -0,0 +1,19 @@
package cmds
import (
"fmt"
"github.com/spf13/cobra"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
func init() {
rootCmd.AddCommand(&cobra.Command{
Use: "version",
Short: "Print the version number",
Run: func(cmd *cobra.Command, args []string) {
fmt.Println("pocket-id " + common.Version)
},
})
}

View File

@@ -1,8 +1,13 @@
package common
import (
"log"
"errors"
"fmt"
"log/slog"
"net/url"
"os"
"reflect"
"strings"
"github.com/caarlos0/env/v11"
_ "github.com/joho/godotenv/autoload"
@@ -18,73 +23,182 @@ const (
)
const (
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
DbProviderSqlite DbProvider = "sqlite"
DbProviderPostgres DbProvider = "postgres"
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
defaultSqliteConnString string = "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate"
)
type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"`
AppURL string `env:"APP_URL"`
DbProvider DbProvider `env:"DB_PROVIDER"`
DbConnectionString string `env:"DB_CONNECTION_STRING"`
DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"`
KeysStorage string `env:"KEYS_STORAGE"`
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
Port string `env:"PORT"`
Host string `env:"HOST"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
UnixSocket string `env:"UNIX_SOCKET"`
UnixSocketMode string `env:"UNIX_SOCKET_MODE"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"`
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
GeoLiteDBUrl string `env:"GEOLITE_DB_URL"`
LocalIPv6Ranges string `env:"LOCAL_IPV6_RANGES"`
UiConfigDisabled bool `env:"UI_CONFIG_DISABLED"`
MetricsEnabled bool `env:"METRICS_ENABLED"`
TracingEnabled bool `env:"TRACING_ENABLED"`
LogJSON bool `env:"LOG_JSON"`
TrustProxy bool `env:"TRUST_PROXY"`
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
}
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DbProvider: "sqlite",
DbConnectionString: "file:data/pocket-id.db?_pragma=journal_mode(WAL)&_pragma=busy_timeout(2500)&_txlock=immediate",
UploadPath: "data/uploads",
KeysPath: "data/keys",
AppURL: "http://localhost:1411",
Port: "1411",
Host: "0.0.0.0",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
UiConfigDisabled: false,
MetricsEnabled: false,
TracingEnabled: false,
TrustProxy: false,
AnalyticsDisabled: false,
}
var EnvConfig = defaultConfig()
func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
err := parseEnvConfig()
if err != nil {
slog.Error("Configuration error", slog.Any("error", err))
os.Exit(1)
}
}
func defaultConfig() EnvConfigSchema {
return EnvConfigSchema{
AppEnv: "production",
DbProvider: "sqlite",
DbConnectionString: "",
UploadPath: "data/uploads",
KeysPath: "data/keys",
KeysStorage: "", // "database" or "file"
EncryptionKey: nil,
AppURL: "http://localhost:1411",
Port: "1411",
Host: "0.0.0.0",
UnixSocket: "",
UnixSocketMode: "",
MaxMindLicenseKey: "",
GeoLiteDBPath: "data/GeoLite2-City.mmdb",
GeoLiteDBUrl: MaxMindGeoLiteCityUrl,
LocalIPv6Ranges: "",
UiConfigDisabled: false,
MetricsEnabled: false,
TracingEnabled: false,
TrustProxy: false,
AnalyticsDisabled: false,
}
}
func parseEnvConfig() error {
parsers := map[reflect.Type]env.ParserFunc{
reflect.TypeOf([]byte{}): func(value string) (interface{}, error) {
return []byte(value), nil
},
}
err := env.ParseWithOptions(&EnvConfig, env.Options{
FuncMap: parsers,
})
if err != nil {
return fmt.Errorf("error parsing env config: %w", err)
}
err = resolveFileBasedEnvVariables(&EnvConfig)
if err != nil {
return err
}
// Validate the environment variables
switch EnvConfig.DbProvider {
case DbProviderSqlite:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
EnvConfig.DbConnectionString = defaultSqliteConnString
}
case DbProviderPostgres:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
default:
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
if err != nil {
log.Fatal("APP_URL is not a valid URL")
return errors.New("APP_URL is not a valid URL")
}
if parsedAppUrl.Path != "" {
log.Fatal("APP_URL must not contain a path")
return errors.New("APP_URL must not contain a path")
}
switch EnvConfig.KeysStorage {
// KeysStorage defaults to "file" if empty
case "":
EnvConfig.KeysStorage = "file"
case "database":
if EnvConfig.EncryptionKey == nil {
return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
}
case "file":
// All good, these are valid values
default:
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
}
return nil
}
// resolveFileBasedEnvVariables uses reflection to automatically resolve file-based secrets
func resolveFileBasedEnvVariables(config *EnvConfigSchema) error {
val := reflect.ValueOf(config).Elem()
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// Only process string and []byte fields
isString := field.Kind() == reflect.String
isByteSlice := field.Kind() == reflect.Slice && field.Type().Elem().Kind() == reflect.Uint8
if !isString && !isByteSlice {
continue
}
// Only process fields with the "options" tag set to "file"
optionsTag := fieldType.Tag.Get("options")
if optionsTag != "file" {
continue
}
// Only process fields with the "env" tag
envTag := fieldType.Tag.Get("env")
if envTag == "" {
continue
}
envVarName := envTag
if commaIndex := len(envTag); commaIndex > 0 {
envVarName = envTag[:commaIndex]
}
// If the file environment variable is not set, skip
envVarFileName := envVarName + "_FILE"
envVarFileValue := os.Getenv(envVarFileName)
if envVarFileValue == "" {
continue
}
fileContent, err := os.ReadFile(envVarFileValue)
if err != nil {
return fmt.Errorf("failed to read file for env var %s: %w", envVarFileName, err)
}
if isString {
field.SetString(strings.TrimSpace(string(fileContent)))
} else {
field.SetBytes(fileContent)
}
}
return nil
}

View File

@@ -0,0 +1,305 @@
package common
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseEnvConfig(t *testing.T) {
// Store original config to restore later
originalConfig := EnvConfig
t.Cleanup(func() {
EnvConfig = originalConfig
})
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
})
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
t.Setenv("APP_URL", "https://example.com")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, DbProviderPostgres, EnvConfig.DbProvider)
})
t.Run("should fail with invalid DB_PROVIDER", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "invalid")
t.Setenv("DB_CONNECTION_STRING", "test")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid DB_PROVIDER value")
})
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, defaultSqliteConnString, EnvConfig.DbConnectionString)
})
t.Run("should fail when Postgres DB_CONNECTION_STRING is missing", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "missing required env var 'DB_CONNECTION_STRING' for Postgres")
})
t.Run("should fail with invalid APP_URL", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "€://not-a-valid-url")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL is not a valid URL")
})
t.Run("should fail when APP_URL contains path", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000/path")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "APP_URL must not contain a path")
})
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "file", EnvConfig.KeysStorage)
})
t.Run("should fail when KEYS_STORAGE is 'database' but no encryption key", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "database")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
})
t.Run("should accept valid KEYS_STORAGE values", func(t *testing.T) {
validStorageTypes := []string{"file", "database"}
for _, storage := range validStorageTypes {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", storage)
if storage == "database" {
t.Setenv("ENCRYPTION_KEY", "test-key")
}
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, storage, EnvConfig.KeysStorage)
}
})
t.Run("should fail with invalid KEYS_STORAGE value", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("KEYS_STORAGE", "invalid")
err := parseEnvConfig()
require.Error(t, err)
assert.ErrorContains(t, err, "invalid value for KEYS_STORAGE")
})
t.Run("should parse boolean environment variables correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "sqlite")
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
t.Setenv("APP_URL", "http://localhost:3000")
t.Setenv("UI_CONFIG_DISABLED", "true")
t.Setenv("METRICS_ENABLED", "true")
t.Setenv("TRACING_ENABLED", "false")
t.Setenv("TRUST_PROXY", "true")
t.Setenv("ANALYTICS_DISABLED", "false")
err := parseEnvConfig()
require.NoError(t, err)
assert.True(t, EnvConfig.UiConfigDisabled)
assert.True(t, EnvConfig.MetricsEnabled)
assert.False(t, EnvConfig.TracingEnabled)
assert.True(t, EnvConfig.TrustProxy)
assert.False(t, EnvConfig.AnalyticsDisabled)
})
t.Run("should parse string environment variables correctly", func(t *testing.T) {
EnvConfig = defaultConfig()
t.Setenv("DB_PROVIDER", "postgres")
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
t.Setenv("APP_URL", "https://prod.example.com")
t.Setenv("APP_ENV", "staging")
t.Setenv("UPLOAD_PATH", "/custom/uploads")
t.Setenv("KEYS_PATH", "/custom/keys")
t.Setenv("PORT", "8080")
t.Setenv("HOST", "127.0.0.1")
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
err := parseEnvConfig()
require.NoError(t, err)
assert.Equal(t, "staging", EnvConfig.AppEnv)
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
assert.Equal(t, "8080", EnvConfig.Port)
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
})
}
func TestResolveFileBasedEnvVariables(t *testing.T) {
// Create temporary directory for test files
tempDir := t.TempDir()
// Create test files
encryptionKeyFile := tempDir + "/encryption_key.txt"
encryptionKeyContent := "test-encryption-key-123"
err := os.WriteFile(encryptionKeyFile, []byte(encryptionKeyContent), 0600)
require.NoError(t, err)
dbConnFile := tempDir + "/db_connection.txt"
dbConnContent := "postgres://user:pass@localhost/testdb"
err = os.WriteFile(dbConnFile, []byte(dbConnContent), 0600)
require.NoError(t, err)
// Create a binary file for testing binary data handling
binaryKeyFile := tempDir + "/binary_key.bin"
binaryKeyContent := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}
err = os.WriteFile(binaryKeyFile, binaryKeyContent, 0600)
require.NoError(t, err)
t.Run("should read file content for fields with options:file tag", func(t *testing.T) {
config := defaultConfig()
// Set environment variables pointing to files
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
t.Setenv("DB_CONNECTION_STRING_FILE", dbConnFile)
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// Verify file contents were read correctly
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
assert.Equal(t, dbConnContent, config.DbConnectionString)
})
t.Run("should skip fields without options:file tag", func(t *testing.T) {
config := defaultConfig()
originalAppURL := config.AppURL
// Set a file for a field that doesn't have options:file tag
t.Setenv("APP_URL_FILE", "/tmp/nonexistent.txt")
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// AppURL should remain unchanged
assert.Equal(t, originalAppURL, config.AppURL)
})
t.Run("should skip non-string fields", func(t *testing.T) {
// This test verifies that non-string fields are skipped
// We test this indirectly by ensuring the function doesn't error
// when processing the actual EnvConfigSchema which has bool fields
config := defaultConfig()
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
})
t.Run("should skip when _FILE environment variable is not set", func(t *testing.T) {
config := defaultConfig()
originalEncryptionKey := config.EncryptionKey
// Don't set ENCRYPTION_KEY_FILE environment variable
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// EncryptionKey should remain unchanged
assert.Equal(t, originalEncryptionKey, config.EncryptionKey)
})
t.Run("should handle multiple file-based variables simultaneously", func(t *testing.T) {
config := defaultConfig()
// Set multiple file environment variables
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
t.Setenv("DB_CONNECTION_STRING_FILE", dbConnFile)
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// All should be resolved correctly
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
assert.Equal(t, dbConnContent, config.DbConnectionString)
})
t.Run("should handle mixed file and non-file environment variables", func(t *testing.T) {
config := defaultConfig()
// Set both file and non-file environment variables
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// File-based should be resolved, others should remain as set by env parser
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
assert.Equal(t, "http://localhost:1411", config.AppURL)
})
t.Run("should handle binary data correctly", func(t *testing.T) {
config := defaultConfig()
// Set environment variable pointing to binary file
t.Setenv("ENCRYPTION_KEY_FILE", binaryKeyFile)
err := resolveFileBasedEnvVariables(&config)
require.NoError(t, err)
// Verify binary data was read correctly without corruption
assert.Equal(t, binaryKeyContent, config.EncryptionKey)
})
}

View File

@@ -65,6 +65,11 @@ type OidcClientSecretInvalidError struct{}
func (e *OidcClientSecretInvalidError) Error() string { return "invalid client secret" }
func (e *OidcClientSecretInvalidError) HttpStatusCode() int { return 400 }
type OidcClientAssertionInvalidError struct{}
func (e *OidcClientAssertionInvalidError) Error() string { return "invalid client assertion" }
func (e *OidcClientAssertionInvalidError) HttpStatusCode() int { return 400 }
type OidcInvalidAuthorizationCodeError struct{}
func (e *OidcInvalidAuthorizationCodeError) Error() string { return "invalid authorization code" }
@@ -344,3 +349,13 @@ func (e *OidcAuthorizationPendingError) Error() string {
func (e *OidcAuthorizationPendingError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OpenSignupDisabledError struct{}
func (e *OpenSignupDisabledError) Error() string {
return "Open user signup is not enabled"
}
func (e *OpenSignupDisabledError) HttpStatusCode() int {
return http.StatusForbidden
}

View File

@@ -1,5 +1,8 @@
package common
// Name is the name of the application
const Name = "pocket-id"
// Version contains the Pocket ID version.
//
// It can be set at build time using -ldflags.

View File

@@ -38,10 +38,10 @@ func NewApiKeyController(group *gin.RouterGroup, authMiddleware *middleware.Auth
// @Summary List API keys
// @Description Get a paginated list of API keys belonging to the current user
// @Tags API Keys
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.ApiKeyDto]
// @Router /api/api-keys [get]
func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
@@ -82,7 +82,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
var input dto.ApiKeyCreateDto
if err := ctx.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(ctx, &input); err != nil {
_ = ctx.Error(err)
return
}

View File

@@ -3,6 +3,7 @@ package controller
import (
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -57,7 +58,6 @@ type AppConfigController struct {
// @Accept json
// @Produce json
// @Success 200 {array} dto.PublicAppConfigVariableDto
// @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration := acc.appConfigService.ListAppConfig(false)
@@ -85,7 +85,6 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
// @Accept json
// @Produce json
// @Success 200 {array} dto.AppConfigVariableDto
// @Security BearerAuth
// @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration := acc.appConfigService.ListAppConfig(true)
@@ -107,11 +106,10 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
// @Produce json
// @Param body body dto.AppConfigUpdateDto true "Application Configuration"
// @Success 200 {array} dto.AppConfigVariableDto
// @Security BearerAuth
// @Router /api/application-configuration [put]
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -164,7 +162,6 @@ func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
// @Tags Application Configuration
// @Produce image/x-icon
// @Success 200 {file} binary "Favicon image"
// @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/favicon [get]
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico")
@@ -177,7 +174,6 @@ func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
// @Produce image/png
// @Produce image/jpeg
// @Success 200 {file} binary "Background image"
// @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/background-image [get]
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
@@ -192,7 +188,6 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
// @Param file formData file true "Logo image file"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/logo [put]
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
dbConfig := acc.appConfigService.GetDbConfig()
@@ -218,7 +213,6 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
// @Accept multipart/form-data
// @Param file formData file true "Favicon file (.ico)"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/favicon [put]
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
@@ -242,7 +236,6 @@ func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
// @Accept multipart/form-data
// @Param file formData file true "Background image file"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/background-image [put]
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
@@ -255,6 +248,8 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType)
utils.SetCacheControlHeader(c, 15*time.Minute, 24*time.Hour)
c.File(imagePath)
}
@@ -280,7 +275,6 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
// @Description Manually trigger LDAP synchronization
// @Tags Application Configuration
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/sync-ldap [post]
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
err := acc.ldapService.SyncAll(c.Request.Context())
@@ -297,7 +291,6 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
// @Description Send a test email to verify email configuration
// @Tags Application Configuration
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/test-email [post]
func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
userID := c.GetString("userID")

View File

@@ -34,10 +34,10 @@ type AuditLogController struct {
// @Summary List audit logs
// @Description Get a paginated list of audit logs for the current user
// @Tags Audit Logs
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /api/audit-logs [get]
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
@@ -82,13 +82,14 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
// @Summary List all audit logs
// @Description Get a paginated list of all audit logs (admin only)
// @Tags Audit Logs
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param user_id query string false "Filter by user ID"
// @Param event query string false "Filter by event type"
// @Param client_name query string false "Filter by client name"
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Param filters[userId] query string false "Filter by user ID"
// @Param filters[event] query string false "Filter by event type"
// @Param filters[clientName] query string false "Filter by client name"
// @Param filters[location] query string false "Filter by location type (external or internal)"
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /api/audit-logs/all [get]
func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {

View File

@@ -35,10 +35,6 @@ type CustomClaimController struct {
// @Tags Custom Claims
// @Produce json
// @Success 200 {array} string "List of suggested custom claim names"
// @Failure 401 {object} object "Unauthorized"
// @Failure 403 {object} object "Forbidden"
// @Failure 500 {object} object "Internal server error"
// @Security BearerAuth
// @Router /api/custom-claims/suggestions [get]
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
@@ -63,7 +59,7 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -93,12 +89,11 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
// @Param userGroupId path string true "User Group ID"
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user group"
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
// @Security BearerAuth
// @Router /api/custom-claims/user-group/{userGroupId} [put]
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -14,6 +14,10 @@ func NewTestController(group *gin.RouterGroup, testService *service.TestService)
testController := &TestController{TestService: testService}
group.POST("/test/reset", testController.resetAndSeedHandler)
group.POST("/test/refreshtoken", testController.signRefreshToken)
group.GET("/externalidp/jwks.json", testController.externalIdPJWKS)
group.POST("/externalidp/sign", testController.externalIdPSignToken)
}
type TestController struct {
@@ -21,19 +25,31 @@ type TestController struct {
}
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
var baseURL string
if c.Request.TLS != nil {
baseURL = "https://" + c.Request.Host
} else {
baseURL = "http://" + c.Request.Host
}
skipLdap := c.Query("skip-ldap") == "true"
skipSeed := c.Query("skip-seed") == "true"
if err := tc.TestService.ResetDatabase(); err != nil {
_ = c.Error(err)
return
}
if err := tc.TestService.ResetApplicationImages(); err != nil {
if err := tc.TestService.ResetApplicationImages(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
if err := tc.TestService.SeedDatabase(); err != nil {
_ = c.Error(err)
return
if !skipSeed {
if err := tc.TestService.SeedDatabase(baseURL); err != nil {
_ = c.Error(err)
return
}
}
if err := tc.TestService.ResetAppConfig(c.Request.Context()); err != nil {
@@ -41,17 +57,71 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return
}
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
if !skipLdap {
if err := tc.TestService.SetLdapTestConfig(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
_ = c.Error(err)
return
if err := tc.TestService.SyncLdap(c.Request.Context()); err != nil {
_ = c.Error(err)
return
}
}
tc.TestService.SetJWTKeys()
c.Status(http.StatusNoContent)
}
func (tc *TestController) externalIdPJWKS(c *gin.Context) {
jwks, err := tc.TestService.GetExternalIdPJWKS()
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, jwks)
}
func (tc *TestController) externalIdPSignToken(c *gin.Context) {
var input struct {
Aud string `json:"aud"`
Iss string `json:"iss"`
Sub string `json:"sub"`
}
err := c.ShouldBindJSON(&input)
if err != nil {
_ = c.Error(err)
return
}
token, err := tc.TestService.SignExternalIdPToken(input.Iss, input.Sub, input.Aud)
if err != nil {
_ = c.Error(err)
return
}
c.Writer.WriteString(token)
}
func (tc *TestController) signRefreshToken(c *gin.Context) {
var input struct {
UserID string `json:"user"`
ClientID string `json:"client"`
RefreshToken string `json:"rt"`
}
err := c.ShouldBindJSON(&input)
if err != nil {
_ = c.Error(err)
return
}
token, err := tc.TestService.SignRefreshToken(input.UserID, input.ClientID, input.RefreshToken)
if err != nil {
_ = c.Error(err)
return
}
c.Writer.WriteString(token)
}

View File

@@ -2,19 +2,20 @@ package controller
import (
"errors"
"log"
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"time"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
)
// NewOidcController creates a new controller for OIDC related endpoints
@@ -48,9 +49,17 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
group.GET("/oidc/clients/:id/preview/:userId", authMiddleware.Add(), oc.getClientPreviewHandler)
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
group.GET("/oidc/users/me/clients", authMiddleware.WithAdminNotRequired().Add(), oc.listOwnAuthorizedClientsHandler)
group.GET("/oidc/users/:id/clients", authMiddleware.Add(), oc.listAuthorizedClientsHandler)
group.DELETE("/oidc/users/me/clients/:clientId", authMiddleware.WithAdminNotRequired().Add(), oc.revokeOwnClientAuthorizationHandler)
}
type OidcController struct {
@@ -66,7 +75,6 @@ type OidcController struct {
// @Produce json
// @Param request body dto.AuthorizeOidcClientRequestDto true "Authorization request parameters"
// @Success 200 {object} dto.AuthorizeOidcClientResponseDto "Authorization code and callback URL"
// @Security BearerAuth
// @Router /api/oidc/authorize [post]
func (oc *OidcController) authorizeHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientRequestDto
@@ -84,6 +92,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
response := dto.AuthorizeOidcClientResponseDto{
Code: code,
CallbackURL: callbackURL,
Issuer: common.EnvConfig.AppURL,
}
c.JSON(http.StatusOK, response)
@@ -97,7 +106,6 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
// @Produce json
// @Param request body dto.AuthorizationRequiredDto true "Authorization check parameters"
// @Success 200 {object} object "{ \"authorizationRequired\": true/false }"
// @Security BearerAuth
// @Router /api/oidc/authorization-required [post]
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
var input dto.AuthorizationRequiredDto
@@ -121,11 +129,13 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
// @Tags OIDC
// @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth or client assertions)"
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
// @Param client_assertion formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
// @Param client_assertion_type formData string false "Client assertion type (for 'authorization_code' grant when using client assertions)"
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) {
@@ -195,7 +205,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
return
}
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
token, err := oc.jwtService.VerifyOAuthAccessToken(authToken)
if err != nil {
_ = c.Error(err)
return
@@ -224,7 +234,6 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
// @Description End user session and handle OIDC logout
// @Tags OIDC
// @Accept application/x-www-form-urlencoded
// @Produce html
// @Param id_token_hint query string false "ID token"
// @Param post_logout_redirect_uri query string false "URL to redirect to after logout"
// @Param state query string false "State parameter to include in the redirect"
@@ -251,7 +260,7 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID"))
if err != nil {
// If the validation fails, the user has to confirm the logout manually and doesn't get redirected
log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err)
slog.WarnContext(c.Request.Context(), "Error getting logout callback URL, the user has to confirm the logout manually", "error", err)
c.Redirect(http.StatusFound, common.EnvConfig.AppURL+"/logout")
return
}
@@ -304,9 +313,21 @@ func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
// find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth()
var (
creds service.ClientAuthCredentials
ok bool
)
creds.ClientID, creds.ClientSecret, ok = c.Request.BasicAuth()
if !ok {
// If there's no basic auth, check if we have a bearer token
bearer, ok := utils.BearerAuth(c.Request)
if ok {
creds.ClientAssertionType = service.ClientAssertionTypeJWTBearer
creds.ClientAssertion = bearer
}
}
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), creds, input.Token)
if err != nil {
_ = c.Error(err)
return
@@ -348,7 +369,6 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
// @Produce json
// @Param id path string true "Client ID"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Client information"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [get]
func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
@@ -360,12 +380,12 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
if err != nil {
_ = c.Error(err)
return
}
_ = c.Error(err)
c.JSON(http.StatusOK, clientDto)
}
// listClientsHandler godoc
@@ -373,12 +393,11 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
// @Description Get a paginated list of OIDC clients with optional search and sorting
// @Tags OIDC
// @Param search query string false "Search term to filter clients by name"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("name")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.OidcClientWithAllowedGroupsCountDto]
// @Security BearerAuth
// @Router /api/oidc/clients [get]
func (oc *OidcController) listClientsHandler(c *gin.Context) {
searchTerm := c.Query("search")
@@ -424,7 +443,6 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
// @Produce json
// @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 201 {object} dto.OidcClientWithAllowedUserGroupsDto "Created client"
// @Security BearerAuth
// @Router /api/oidc/clients [post]
func (oc *OidcController) createClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto
@@ -454,7 +472,6 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
// @Tags OIDC
// @Param id path string true "Client ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [delete]
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
@@ -475,7 +492,6 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
// @Param id path string true "Client ID"
// @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Updated client"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [put]
func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto
@@ -506,7 +522,6 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
// @Produce json
// @Param id path string true "Client ID"
// @Success 200 {object} object "{ \"secret\": \"string\" }"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/secret [post]
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
@@ -535,6 +550,8 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
return
}
utils.SetCacheControlHeader(c, 15*time.Minute, 12*time.Hour)
c.Header("Content-Type", mimeType)
c.File(imagePath)
}
@@ -545,9 +562,8 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
// @Tags OIDC
// @Accept multipart/form-data
// @Param id path string true "Client ID"
// @Param file formData file true "Logo image file (PNG, JPG, or SVG, max 2MB)"
// @Param file formData file true "Logo image file (PNG, JPG, or SVG)"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [post]
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file")
@@ -571,7 +587,6 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
// @Tags OIDC
// @Param id path string true "Client ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [delete]
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
@@ -592,7 +607,6 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
// @Param id path string true "Client ID"
// @Param groups body dto.OidcUpdateAllowedUserGroupsDto true "User group IDs"
// @Success 200 {object} dto.OidcClientDto "Updated client"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/allowed-user-groups [put]
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
var input dto.OidcUpdateAllowedUserGroupsDto
@@ -637,6 +651,83 @@ func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
c.JSON(http.StatusOK, response)
}
// listOwnAuthorizedClientsHandler godoc
// @Summary List authorized clients for current user
// @Description Get a paginated list of OIDC clients that the current user has authorized
// @Tags OIDC
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.AuthorizedOidcClientDto]
// @Router /api/oidc/users/me/clients [get]
func (oc *OidcController) listOwnAuthorizedClientsHandler(c *gin.Context) {
userID := c.GetString("userID")
oc.listAuthorizedClients(c, userID)
}
// listAuthorizedClientsHandler godoc
// @Summary List authorized clients for a user
// @Description Get a paginated list of OIDC clients that a specific user has authorized
// @Tags OIDC
// @Param id path string true "User ID"
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.AuthorizedOidcClientDto]
// @Router /api/oidc/users/{id}/clients [get]
func (oc *OidcController) listAuthorizedClientsHandler(c *gin.Context) {
userID := c.Param("id")
oc.listAuthorizedClients(c, userID)
}
func (oc *OidcController) listAuthorizedClients(c *gin.Context, userID string) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
_ = c.Error(err)
return
}
authorizedClients, pagination, err := oc.oidcService.ListAuthorizedClients(c.Request.Context(), userID, sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
}
// Map the clients to DTOs
var authorizedClientsDto []dto.AuthorizedOidcClientDto
if err := dto.MapStructList(authorizedClients, &authorizedClientsDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, dto.Paginated[dto.AuthorizedOidcClientDto]{
Data: authorizedClientsDto,
Pagination: pagination,
})
}
// revokeOwnClientAuthorizationHandler godoc
// @Summary Revoke authorization for an OIDC client
// @Description Revoke the authorization for a specific OIDC client for the current user
// @Tags OIDC
// @Param clientId path string true "Client ID to revoke authorization for"
// @Success 204 "No Content"
// @Router /api/oidc/users/me/clients/{clientId} [delete]
func (oc *OidcController) revokeOwnClientAuthorizationHandler(c *gin.Context) {
clientID := c.Param("clientId")
userID := c.GetString("userID")
err := oc.oidcService.RevokeAuthorizedClient(c.Request.Context(), userID, clientID)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
userCode := c.Query("code")
if userCode == "" {
@@ -672,3 +763,43 @@ func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
c.JSON(http.StatusOK, deviceCodeInfo)
}
// getClientPreviewHandler godoc
// @Summary Preview OIDC client data for user
// @Description Get a preview of the OIDC data (ID token, access token, userinfo) that would be sent to the client for a specific user
// @Tags OIDC
// @Produce json
// @Param id path string true "Client ID"
// @Param userId path string true "User ID to preview data for"
// @Param scopes query string false "Scopes to include in the preview (comma-separated)"
// @Success 200 {object} dto.OidcClientPreviewDto "Preview data including ID token, access token, and userinfo payloads"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/preview/{userId} [get]
func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
clientID := c.Param("id")
userID := c.Param("userId")
scopes := c.Query("scopes")
if clientID == "" {
_ = c.Error(&common.ValidationError{Message: "client ID is required"})
return
}
if userID == "" {
_ = c.Error(&common.ValidationError{Message: "user ID is required"})
return
}
if scopes == "" {
_ = c.Error(&common.ValidationError{Message: "scopes are required"})
return
}
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, preview)
}

View File

@@ -44,11 +44,17 @@ func NewUserController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-email", authMiddleware.Add(), uc.RequestOneTimeAccessEmailAsAdminHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.RequestOneTimeAccessEmailAsUnauthenticatedUserHandler)
group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler)
group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler)
group.POST("/signup-tokens", authMiddleware.Add(), uc.createSignupTokenHandler)
group.GET("/signup-tokens", authMiddleware.Add(), uc.listSignupTokensHandler)
group.DELETE("/signup-tokens/:id", authMiddleware.Add(), uc.deleteSignupTokenHandler)
group.POST("/signup", rateLimitMiddleware.Add(rate.Every(1*time.Minute), 10), uc.signupHandler)
group.POST("/signup/setup", uc.signUpInitialAdmin)
}
type UserController struct {
@@ -85,10 +91,10 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
// @Description Get a paginated list of users with optional search and sorting
// @Tags Users
// @Param search query string false "Search term to filter users"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.UserDto]
// @Router /api/users [get]
func (uc *UserController) listUsersHandler(c *gin.Context) {
@@ -187,7 +193,7 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
// @Router /api/users [post]
func (uc *UserController) createUserHandler(c *gin.Context) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -250,10 +256,7 @@ func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
defer picture.Close()
}
_, ok := c.GetQuery("skipCache")
if !ok {
c.Header("Cache-Control", "public, max-age=900")
}
utils.SetCacheControlHeader(c, 15*time.Minute, 1*time.Hour)
c.DataFromReader(http.StatusOK, size, "image/png", picture, nil)
}
@@ -375,7 +378,7 @@ func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
// @Router /api/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsUnauthenticatedUserDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -443,14 +446,23 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
c.JSON(http.StatusOK, userDto)
}
// getSetupAccessTokenHandler godoc
// @Summary Setup initial admin
// @Description Generate setup access token for initial admin user configuration
// signUpInitialAdmin godoc
// @Summary Sign up initial admin user
// @Description Sign up and generate setup access token for initial admin user
// @Tags Users
// @Accept json
// @Produce json
// @Param body body dto.SignUpDto true "User information"
// @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/setup [post]
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
// @Router /api/signup/setup [post]
func (uc *UserController) signUpInitialAdmin(c *gin.Context) {
var input dto.SignUpDto
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
user, token, err := uc.userService.SignUpInitialAdmin(c.Request.Context(), input)
if err != nil {
_ = c.Error(err)
return
@@ -498,10 +510,132 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
c.JSON(http.StatusOK, userDto)
}
// createSignupTokenHandler godoc
// @Summary Create signup token
// @Description Create a new signup token that allows user registration
// @Tags Users
// @Accept json
// @Produce json
// @Param token body dto.SignupTokenCreateDto true "Signup token information"
// @Success 201 {object} dto.SignupTokenDto
// @Router /api/signup-tokens [post]
func (uc *UserController) createSignupTokenHandler(c *gin.Context) {
var input dto.SignupTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
signupToken, err := uc.userService.CreateSignupToken(c.Request.Context(), input.ExpiresAt, input.UsageLimit)
if err != nil {
_ = c.Error(err)
return
}
var tokenDto dto.SignupTokenDto
if err := dto.MapStruct(signupToken, &tokenDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, tokenDto)
}
// listSignupTokensHandler godoc
// @Summary List signup tokens
// @Description Get a paginated list of signup tokens
// @Tags Users
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.SignupTokenDto]
// @Router /api/signup-tokens [get]
func (uc *UserController) listSignupTokensHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
_ = c.Error(err)
return
}
tokens, pagination, err := uc.userService.ListSignupTokens(c.Request.Context(), sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return
}
var tokensDto []dto.SignupTokenDto
if err := dto.MapStructList(tokens, &tokensDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, dto.Paginated[dto.SignupTokenDto]{
Data: tokensDto,
Pagination: pagination,
})
}
// deleteSignupTokenHandler godoc
// @Summary Delete signup token
// @Description Delete a signup token by ID
// @Tags Users
// @Param id path string true "Token ID"
// @Success 204 "No Content"
// @Router /api/signup-tokens/{id} [delete]
func (uc *UserController) deleteSignupTokenHandler(c *gin.Context) {
tokenID := c.Param("id")
err := uc.userService.DeleteSignupToken(c.Request.Context(), tokenID)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// signupWithTokenHandler godoc
// @Summary Sign up
// @Description Create a new user account
// @Tags Users
// @Accept json
// @Produce json
// @Param user body dto.SignUpDto true "User information"
// @Success 201 {object} dto.SignUpDto
// @Router /api/signup [post]
func (uc *UserController) signupHandler(c *gin.Context) {
var input dto.SignUpDto
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
ipAddress := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
user, accessToken, err := uc.userService.SignUp(c.Request.Context(), input, ipAddress, userAgent)
if err != nil {
_ = c.Error(err)
return
}
maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, accessToken)
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, userDto)
}
// updateUser is an internal helper method, not exposed as an API endpoint
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}

View File

@@ -40,10 +40,10 @@ type UserGroupController struct {
// @Description Get a paginated list of user groups with optional search and sorting
// @Tags User Groups
// @Param search query string false "Search term to filter user groups by name"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("name")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Param pagination[page] query int false "Page number for pagination" default(1)
// @Param pagination[limit] query int false "Number of items per page" default(20)
// @Param sort[column] query string false "Column to sort by"
// @Param sort[direction] query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) {
@@ -92,7 +92,6 @@ func (ugc *UserGroupController) list(c *gin.Context) {
// @Produce json
// @Param id path string true "User Group ID"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth
// @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
@@ -118,11 +117,10 @@ func (ugc *UserGroupController) get(c *gin.Context) {
// @Produce json
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
// @Security BearerAuth
// @Router /api/user-groups [post]
func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -151,11 +149,10 @@ func (ugc *UserGroupController) create(c *gin.Context) {
// @Param id path string true "User Group ID"
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
// @Security BearerAuth
// @Router /api/user-groups/{id} [put]
func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
if err := dto.ShouldBindWithNormalizedJSON(c, &input); err != nil {
_ = c.Error(err)
return
}
@@ -183,7 +180,6 @@ func (ugc *UserGroupController) update(c *gin.Context) {
// @Produce json
// @Param id path string true "User Group ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/user-groups/{id} [delete]
func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
@@ -203,7 +199,6 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
// @Param id path string true "User Group ID"
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth
// @Router /api/user-groups/{id}/users [put]
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto

View File

@@ -3,8 +3,9 @@ package controller
import (
"encoding/json"
"fmt"
"log"
"log/slog"
"net/http"
"os"
"github.com/gin-gonic/gin"
@@ -23,7 +24,9 @@ func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtServi
var err error
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
if err != nil {
log.Fatalf("Failed to pre-compute OpenID Connect configuration document: %v", err)
slog.Error("Failed to pre-compute OpenID Connect configuration document", slog.Any("error", err))
os.Exit(1)
return
}
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
@@ -69,20 +72,22 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
}
config := map[string]any{
"issuer": appUrl,
"authorization_endpoint": appUrl + "/authorize",
"token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session",
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
"jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
"scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{alg.String()},
"issuer": appUrl,
"authorization_endpoint": appUrl + "/authorize",
"token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session",
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
"jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
"scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{alg.String()},
"authorization_response_iss_parameter_supported": true,
"code_challenge_methods_supported": []string{"plain", "S256"},
}
return json.Marshal(config)
}

View File

@@ -5,15 +5,15 @@ import (
)
type ApiKeyCreateDto struct {
Name string `json:"name" binding:"required,min=3,max=50"`
Description string `json:"description"`
Name string `json:"name" binding:"required,min=3,max=50" unorm:"nfc"`
Description *string `json:"description" unorm:"nfc"`
ExpiresAt datatype.DateTime `json:"expiresAt" binding:"required"`
}
type ApiKeyDto struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Description *string `json:"description"`
ExpiresAt datatype.DateTime `json:"expiresAt"`
LastUsedAt *datatype.DateTime `json:"lastUsedAt"`
CreatedAt datatype.DateTime `json:"createdAt"`

View File

@@ -12,11 +12,13 @@ type AppConfigVariableDto struct {
}
type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"`
AppName string `json:"appName" binding:"required,min=1,max=30" unorm:"nfc"`
SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"`
DisableAnimations string `json:"disableAnimations" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
AllowUserSignups string `json:"allowUserSignups" binding:"required,oneof=disabled withToken open"`
AccentColor string `json:"accentColor"`
SmtpHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`

View File

@@ -1,7 +1,6 @@
package dto
import (
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
@@ -9,18 +8,19 @@ type AuditLogDto struct {
ID string `json:"id"`
CreatedAt datatype.DateTime `json:"createdAt"`
Event model.AuditLogEvent `json:"event"`
IpAddress string `json:"ipAddress"`
Country string `json:"country"`
City string `json:"city"`
Device string `json:"device"`
UserID string `json:"userID"`
Username string `json:"username"`
Data model.AuditLogData `json:"data"`
Event string `json:"event"`
IpAddress string `json:"ipAddress"`
Country string `json:"country"`
City string `json:"city"`
Device string `json:"device"`
UserID string `json:"userID"`
Username string `json:"username"`
Data map[string]string `json:"data"`
}
type AuditLogFilterDto struct {
UserID string `form:"filters[userId]"`
Event string `form:"filters[event]"`
ClientName string `form:"filters[clientName]"`
Location string `form:"filters[location]"`
}

View File

@@ -6,6 +6,6 @@ type CustomClaimDto struct {
}
type CustomClaimCreateDto struct {
Key string `json:"key" binding:"required"`
Value string `json:"value" binding:"required"`
Key string `json:"key" binding:"required" unorm:"nfc"`
Value string `json:"value" binding:"required" unorm:"nfc"`
}

View File

@@ -1,109 +1,27 @@
package dto
import (
"errors"
"reflect"
"time"
"fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/jinzhu/copier"
)
// MapStructList maps a list of source structs to a list of destination structs
func MapStructList[S any, D any](source []S, destination *[]D) error {
*destination = make([]D, 0, len(source))
func MapStructList[S any, D any](source []S, destination *[]D) (err error) {
*destination = make([]D, len(source))
for _, item := range source {
var destItem D
if err := MapStruct(item, &destItem); err != nil {
return err
for i, item := range source {
err = MapStruct(item, &((*destination)[i]))
if err != nil {
return fmt.Errorf("failed to map field %d: %w", i, err)
}
*destination = append(*destination, destItem)
}
return nil
}
// MapStruct maps a source struct to a destination struct
func MapStruct[S any, D any](source S, destination *D) error {
// Ensure destination is a non-nil pointer
destValue := reflect.ValueOf(destination)
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
return errors.New("destination must be a non-nil pointer to a struct")
}
// Ensure source is a struct
sourceValue := reflect.ValueOf(source)
if sourceValue.Kind() != reflect.Struct {
return errors.New("source must be a struct")
}
return mapStructInternal(sourceValue, destValue.Elem())
}
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
for i := 0; i < destVal.NumField(); i++ {
destField := destVal.Field(i)
destFieldType := destVal.Type().Field(i)
if destFieldType.Anonymous {
if err := mapStructInternal(sourceVal, destField); err != nil {
return err
}
continue
}
sourceField := sourceVal.FieldByName(destFieldType.Name)
if sourceField.IsValid() && destField.CanSet() {
if err := mapField(sourceField, destField); err != nil {
return err
}
}
}
return nil
}
func mapField(sourceField reflect.Value, destField reflect.Value) error {
switch {
case sourceField.Type() == destField.Type():
destField.Set(sourceField)
case sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice:
return mapSlice(sourceField, destField)
case sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct:
return mapStructInternal(sourceField, destField)
default:
return mapSpecialTypes(sourceField, destField)
}
return nil
}
func mapSlice(sourceField reflect.Value, destField reflect.Value) error {
if sourceField.Type().Elem() == destField.Type().Elem() {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
return nil
}
func mapSpecialTypes(sourceField reflect.Value, destField reflect.Value) error {
if _, ok := sourceField.Interface().(datatype.DateTime); ok {
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
dateValue := sourceField.Interface().(datatype.DateTime)
destField.Set(reflect.ValueOf(dateValue.ToTime()))
}
}
return nil
func MapStruct(source any, destination any) error {
return copier.CopyWithOption(destination, source, copier.Option{
DeepCopy: true,
})
}

View File

@@ -0,0 +1,197 @@
package dto
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
type sourceStruct struct {
AString string
AStringPtr *string
ABool bool
ABoolPtr *bool
ACustomDateTime datatype.DateTime
ACustomDateTimePtr *datatype.DateTime
ANilStringPtr *string
ASlice []string
AMap map[string]int
AStruct embeddedStruct
AStructPtr *embeddedStruct
StringPtrToString *string
EmptyStringPtrToString *string
NilStringPtrToString *string
IntToInt64 int
AuditLogEventToString model.AuditLogEvent
}
type destStruct struct {
AString string
AStringPtr *string
ABool bool
ABoolPtr *bool
ACustomDateTime datatype.DateTime
ACustomDateTimePtr *datatype.DateTime
ANilStringPtr *string
ASlice []string
AMap map[string]int
AStruct embeddedStruct
AStructPtr *embeddedStruct
StringPtrToString string
EmptyStringPtrToString string
NilStringPtrToString string
IntToInt64 int64
AuditLogEventToString string
}
type embeddedStruct struct {
Foo string
Bar int64
}
func TestMapStruct(t *testing.T) {
src := sourceStruct{
AString: "abcd",
AStringPtr: utils.Ptr("xyz"),
ABool: true,
ABoolPtr: utils.Ptr(false),
ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)),
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))),
ANilStringPtr: nil,
ASlice: []string{"a", "b", "c"},
AMap: map[string]int{
"a": 1,
"b": 2,
},
AStruct: embeddedStruct{
Foo: "bar",
Bar: 42,
},
AStructPtr: &embeddedStruct{
Foo: "quo",
Bar: 111,
},
StringPtrToString: utils.Ptr("foobar"),
EmptyStringPtrToString: utils.Ptr(""),
NilStringPtrToString: nil,
IntToInt64: 99,
AuditLogEventToString: model.AuditLogEventAccountCreated,
}
var dst destStruct
err := MapStruct(src, &dst)
require.NoError(t, err)
assert.Equal(t, src.AString, dst.AString)
_ = assert.NotNil(t, src.AStringPtr) &&
assert.Equal(t, *src.AStringPtr, *dst.AStringPtr)
assert.Equal(t, src.ABool, dst.ABool)
_ = assert.NotNil(t, src.ABoolPtr) &&
assert.Equal(t, *src.ABoolPtr, *dst.ABoolPtr)
assert.Equal(t, src.ACustomDateTime, dst.ACustomDateTime)
_ = assert.NotNil(t, src.ACustomDateTimePtr) &&
assert.Equal(t, *src.ACustomDateTimePtr, *dst.ACustomDateTimePtr)
assert.Nil(t, dst.ANilStringPtr)
assert.Equal(t, src.ASlice, dst.ASlice)
assert.Equal(t, src.AMap, dst.AMap)
assert.Equal(t, "bar", dst.AStruct.Foo)
assert.Equal(t, int64(42), dst.AStruct.Bar)
_ = assert.NotNil(t, src.AStructPtr) &&
assert.Equal(t, "quo", dst.AStructPtr.Foo) &&
assert.Equal(t, int64(111), dst.AStructPtr.Bar)
assert.Equal(t, "foobar", dst.StringPtrToString)
assert.Empty(t, dst.EmptyStringPtrToString)
assert.Empty(t, dst.NilStringPtrToString)
assert.Equal(t, int64(99), dst.IntToInt64)
assert.Equal(t, "ACCOUNT_CREATED", dst.AuditLogEventToString)
}
func TestMapStructList(t *testing.T) {
sources := []sourceStruct{
{
AString: "first",
AStringPtr: utils.Ptr("one"),
ABool: true,
ABoolPtr: utils.Ptr(false),
ACustomDateTime: datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)),
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC))),
ASlice: []string{"a", "b"},
AMap: map[string]int{
"a": 1,
"b": 2,
},
AStruct: embeddedStruct{
Foo: "first_struct",
Bar: 10,
},
IntToInt64: 10,
},
{
AString: "second",
AStringPtr: utils.Ptr("two"),
ABool: false,
ABoolPtr: utils.Ptr(true),
ACustomDateTime: datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)),
ACustomDateTimePtr: utils.Ptr(datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC))),
ASlice: []string{"c", "d", "e"},
AMap: map[string]int{
"c": 3,
"d": 4,
},
AStruct: embeddedStruct{
Foo: "second_struct",
Bar: 20,
},
IntToInt64: 20,
},
}
var destinations []destStruct
err := MapStructList(sources, &destinations)
require.NoError(t, err)
require.Len(t, destinations, 2)
// Verify first element
assert.Equal(t, "first", destinations[0].AString)
assert.Equal(t, "one", *destinations[0].AStringPtr)
assert.True(t, destinations[0].ABool)
assert.False(t, *destinations[0].ABoolPtr)
assert.Equal(t, datatype.DateTime(time.Date(2025, 1, 2, 3, 4, 5, 0, time.UTC)), destinations[0].ACustomDateTime)
assert.Equal(t, datatype.DateTime(time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)), *destinations[0].ACustomDateTimePtr)
assert.Equal(t, []string{"a", "b"}, destinations[0].ASlice)
assert.Equal(t, map[string]int{"a": 1, "b": 2}, destinations[0].AMap)
assert.Equal(t, "first_struct", destinations[0].AStruct.Foo)
assert.Equal(t, int64(10), destinations[0].AStruct.Bar)
assert.Equal(t, int64(10), destinations[0].IntToInt64)
// Verify second element
assert.Equal(t, "second", destinations[1].AString)
assert.Equal(t, "two", *destinations[1].AStringPtr)
assert.False(t, destinations[1].ABool)
assert.True(t, *destinations[1].ABoolPtr)
assert.Equal(t, datatype.DateTime(time.Date(2026, 6, 7, 8, 9, 10, 0, time.UTC)), destinations[1].ACustomDateTime)
assert.Equal(t, datatype.DateTime(time.Date(2023, 6, 7, 8, 9, 10, 0, time.UTC)), *destinations[1].ACustomDateTimePtr)
assert.Equal(t, []string{"c", "d", "e"}, destinations[1].ASlice)
assert.Equal(t, map[string]int{"c": 3, "d": 4}, destinations[1].AMap)
assert.Equal(t, "second_struct", destinations[1].AStruct.Foo)
assert.Equal(t, int64(20), destinations[1].AStruct.Bar)
assert.Equal(t, int64(20), destinations[1].IntToInt64)
}
func TestMapStructList_EmptySource(t *testing.T) {
var sources []sourceStruct
var destinations []destStruct
err := MapStructList(sources, &destinations)
require.NoError(t, err)
assert.Empty(t, destinations)
}

View File

@@ -0,0 +1,94 @@
package dto
import (
"net/http"
"reflect"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"golang.org/x/text/unicode/norm"
)
// Normalize iterates through an object and performs Unicode normalization on all string fields with the `unorm` tag.
func Normalize(obj any) {
v := reflect.ValueOf(obj)
if v.Kind() != reflect.Ptr || v.IsNil() {
return
}
v = v.Elem()
// Handle case where obj is a slice of models
if v.Kind() == reflect.Slice {
for i := 0; i < v.Len(); i++ {
elem := v.Index(i)
if elem.Kind() == reflect.Ptr && !elem.IsNil() && elem.Elem().Kind() == reflect.Struct {
Normalize(elem.Interface())
} else if elem.Kind() == reflect.Struct && elem.CanAddr() {
Normalize(elem.Addr().Interface())
}
}
return
}
if v.Kind() != reflect.Struct {
return
}
// Iterate through all fields looking for those with the "unorm" tag
t := v.Type()
loop:
for i := range t.NumField() {
field := t.Field(i)
unormTag := field.Tag.Get("unorm")
if unormTag == "" {
continue
}
fv := v.Field(i)
if !fv.CanSet() || fv.Kind() != reflect.String {
continue
}
var form norm.Form
switch unormTag {
case "nfc":
form = norm.NFC
case "nfkc":
form = norm.NFKC
case "nfd":
form = norm.NFD
case "nfkd":
form = norm.NFKD
default:
continue loop
}
val := fv.String()
val = form.String(val)
fv.SetString(val)
}
}
func ShouldBindWithNormalizedJSON(ctx *gin.Context, obj any) error {
return ctx.ShouldBindWith(obj, binding.JSON)
}
type NormalizerJSONBinding struct{}
func (NormalizerJSONBinding) Name() string {
return "json"
}
func (NormalizerJSONBinding) Bind(req *http.Request, obj any) error {
// Use the default JSON binder
err := binding.JSON.Bind(req, obj)
if err != nil {
return err
}
// Perform normalization
Normalize(obj)
return nil
}

View File

@@ -0,0 +1,84 @@
package dto
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/text/unicode/norm"
)
type testDto struct {
Name string `unorm:"nfc"`
Description string `unorm:"nfd"`
Other string
BadForm string `unorm:"bad"`
}
func TestNormalize(t *testing.T) {
input := testDto{
// Is in NFC form already
Name: norm.NFC.String("Café"),
// NFC form will be normalized to NFD
Description: norm.NFC.String("vërø"),
// Should be unchanged
Other: "NöTag",
// Should be unchanged
BadForm: "BåD",
}
Normalize(&input)
assert.Equal(t, norm.NFC.String("Café"), input.Name)
assert.Equal(t, norm.NFD.String("vërø"), input.Description)
assert.Equal(t, "NöTag", input.Other)
assert.Equal(t, "BåD", input.BadForm)
}
func TestNormalizeSlice(t *testing.T) {
obj1 := testDto{
Name: norm.NFC.String("Café1"),
Description: norm.NFC.String("vërø1"),
Other: "NöTag1",
BadForm: "BåD1",
}
obj2 := testDto{
Name: norm.NFD.String("Résumé2"),
Description: norm.NFD.String("accéléré2"),
Other: "NöTag2",
BadForm: "BåD2",
}
t.Run("slice of structs", func(t *testing.T) {
slice := []testDto{obj1, obj2}
Normalize(&slice)
// Verify first element
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
assert.Equal(t, "NöTag1", slice[0].Other)
assert.Equal(t, "BåD1", slice[0].BadForm)
// Verify second element
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
assert.Equal(t, "NöTag2", slice[1].Other)
assert.Equal(t, "BåD2", slice[1].BadForm)
})
t.Run("slice of pointers to structs", func(t *testing.T) {
slice := []*testDto{&obj1, &obj2}
Normalize(&slice)
// Verify first element
assert.Equal(t, norm.NFC.String("Café1"), slice[0].Name)
assert.Equal(t, norm.NFD.String("vërø1"), slice[0].Description)
assert.Equal(t, "NöTag1", slice[0].Other)
assert.Equal(t, "BåD1", slice[0].BadForm)
// Verify second element
assert.Equal(t, norm.NFC.String("Résumé2"), slice[1].Name)
assert.Equal(t, norm.NFD.String("accéléré2"), slice[1].Description)
assert.Equal(t, "NöTag2", slice[1].Other)
assert.Equal(t, "BåD2", slice[1].BadForm)
})
}

View File

@@ -1,17 +1,21 @@
package dto
import datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
type OidcClientMetaDataDto struct {
ID string `json:"id"`
Name string `json:"name"`
HasLogo bool `json:"hasLogo"`
ID string `json:"id"`
Name string `json:"name"`
HasLogo bool `json:"hasLogo"`
LaunchURL *string `json:"launchURL"`
}
type OidcClientDto struct {
OidcClientMetaDataDto
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Credentials OidcClientCredentialsDto `json:"credentials"`
}
type OidcClientWithAllowedUserGroupsDto struct {
@@ -25,11 +29,24 @@ type OidcClientWithAllowedGroupsCountDto struct {
}
type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
Credentials OidcClientCredentialsDto `json:"credentials"`
LaunchURL *string `json:"launchURL" binding:"omitempty,url"`
}
type OidcClientCredentialsDto struct {
FederatedIdentities []OidcClientFederatedIdentityDto `json:"federatedIdentities,omitempty"`
}
type OidcClientFederatedIdentityDto struct {
Issuer string `json:"issuer"`
Subject string `json:"subject,omitempty"`
Audience string `json:"audience,omitempty"`
JWKS string `json:"jwks,omitempty"`
}
type AuthorizeOidcClientRequestDto struct {
@@ -44,6 +61,7 @@ type AuthorizeOidcClientRequestDto struct {
type AuthorizeOidcClientResponseDto struct {
Code string `json:"code"`
CallbackURL string `json:"callbackURL"`
Issuer string `json:"issuer"`
}
type AuthorizationRequiredDto struct {
@@ -52,13 +70,15 @@ type AuthorizationRequiredDto struct {
}
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"`
DeviceCode string `form:"device_code"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"`
DeviceCode string `form:"device_code"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
ClientAssertion string `form:"client_assertion"`
ClientAssertionType string `form:"client_assertion_type"`
}
type OidcIntrospectDto struct {
@@ -98,9 +118,11 @@ type OidcIntrospectionResponseDto struct {
}
type OidcDeviceAuthorizationRequestDto struct {
ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"`
ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"`
ClientAssertion string `form:"client_assertion"`
ClientAssertionType string `form:"client_assertion_type"`
}
type OidcDeviceAuthorizationResponseDto struct {
@@ -125,3 +147,15 @@ type DeviceCodeInfoDto struct {
AuthorizationRequired bool `json:"authorizationRequired"`
Client OidcClientMetaDataDto `json:"client"`
}
type AuthorizedOidcClientDto struct {
Scope string `json:"scope"`
Client OidcClientMetaDataDto `json:"client"`
LastUsedAt datatype.DateTime `json:"lastUsedAt"`
}
type OidcClientPreviewDto struct {
IdToken map[string]any `json:"idToken"`
AccessToken map[string]any `json:"accessToken"`
UserInfo map[string]any `json:"userInfo"`
}

View File

@@ -0,0 +1,21 @@
package dto
import (
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type SignupTokenCreateDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
UsageLimit int `json:"usageLimit" binding:"required,min=1,max=100"`
}
type SignupTokenDto struct {
ID string `json:"id"`
Token string `json:"token"`
ExpiresAt datatype.DateTime `json:"expiresAt"`
UsageLimit int `json:"usageLimit"`
UsageCount int `json:"usageCount"`
CreatedAt datatype.DateTime `json:"createdAt"`
}

View File

@@ -1,6 +1,8 @@
package dto
import "time"
import (
"time"
)
type UserDto struct {
ID string `json:"id"`
@@ -17,10 +19,10 @@ type UserDto struct {
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"max=50"`
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
@@ -33,7 +35,7 @@ type OneTimeAccessTokenCreateDto struct {
}
type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
Email string `json:"email" binding:"required,email"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
RedirectPath string `json:"redirectPath"`
}
@@ -44,3 +46,11 @@ type OneTimeAccessEmailAsAdminDto struct {
type UserUpdateUserGroupDto struct {
UserGroupIds []string `json:"userGroupIds" binding:"required"`
}
type SignUpDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
Email string `json:"email" binding:"required,email" unorm:"nfc"`
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
Token string `json:"token"`
}

View File

@@ -34,8 +34,8 @@ type UserGroupDtoWithUserCount struct {
}
type UserGroupCreateDto struct {
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50"`
Name string `json:"name" binding:"required,min=2,max=255"`
FriendlyName string `json:"friendlyName" binding:"required,min=2,max=50" unorm:"nfc"`
Name string `json:"name" binding:"required,min=2,max=255" unorm:"nfc"`
LdapID string `json:"-"`
}

View File

@@ -1,26 +1,29 @@
package dto
import (
"log"
"log/slog"
"os"
"regexp"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
// [a-zA-Z0-9] : The username must start with an alphanumeric character
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
regex := "^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
return validateUsernameRegex.MatchString(fl.Field().String())
}
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("username", validateUsername); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
v, _ := binding.Validator.Engine().(*validator.Validate)
err := v.RegisterValidation("username", validateUsername)
if err != nil {
slog.Error("Failed to register custom validation", slog.Any("error", err))
os.Exit(1)
return
}
}

View File

@@ -19,5 +19,5 @@ type WebauthnCredentialDto struct {
}
type WebauthnCredentialUpdateDto struct {
Name string `json:"name" binding:"required,min=1,max=30"`
Name string `json:"name" binding:"required,min=1,max=50"`
}

View File

@@ -22,6 +22,7 @@ func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) erro
return errors.Join(
s.registerJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true),
s.registerJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true),
s.registerJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true),
s.registerJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true),
s.registerJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true),
s.registerJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true),
@@ -60,6 +61,21 @@ func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
return nil
}
// ClearSignupTokens deletes signup tokens that have expired
func (j *DbCleanupJobs) clearSignupTokens(ctx context.Context) error {
// Delete tokens that are expired OR have reached their usage limit
st := j.db.
WithContext(ctx).
Delete(&model.SignupToken{}, "expires_at < ?", datatype.DateTime(time.Now()))
if st.Error != nil {
return fmt.Errorf("failed to clean expired tokens: %w", st.Error)
}
slog.InfoContext(ctx, "Cleaned expired tokens", slog.Int64("count", st.RowsAffected))
return nil
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
st := j.db.

View File

@@ -29,7 +29,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
// Skip rate limiting for localhost and test environment
// If the client ip is localhost the request comes from the frontend
if ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
if ip == "" || ip == "127.0.0.1" || ip == "::1" || common.EnvConfig.AppEnv == "test" {
c.Next()
return
}

View File

@@ -8,6 +8,8 @@ import (
"strconv"
"strings"
"time"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
type AppConfigVariable struct {
@@ -35,8 +37,10 @@ type AppConfig struct {
AppName AppConfigVariable `key:"appName,public"` // Public
SessionDuration AppConfigVariable `key:"sessionDuration"`
EmailsVerified AppConfigVariable `key:"emailsVerified"`
AccentColor AppConfigVariable `key:"accentColor,public"` // Public
DisableAnimations AppConfigVariable `key:"disableAnimations,public"` // Public
AllowOwnAccountEdit AppConfigVariable `key:"allowOwnAccountEdit,public"` // Public
AllowUserSignups AppConfigVariable `key:"allowUserSignups,public"` // Public
// Internal
BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
@@ -47,7 +51,7 @@ type AppConfig struct {
SmtpPort AppConfigVariable `key:"smtpPort"`
SmtpFrom AppConfigVariable `key:"smtpFrom"`
SmtpUser AppConfigVariable `key:"smtpUser"`
SmtpPassword AppConfigVariable `key:"smtpPassword"`
SmtpPassword AppConfigVariable `key:"smtpPassword,sensitive"`
SmtpTls AppConfigVariable `key:"smtpTls"`
SmtpSkipCertVerify AppConfigVariable `key:"smtpSkipCertVerify"`
EmailLoginNotificationEnabled AppConfigVariable `key:"emailLoginNotificationEnabled"`
@@ -58,7 +62,7 @@ type AppConfig struct {
LdapEnabled AppConfigVariable `key:"ldapEnabled,public"` // Public
LdapUrl AppConfigVariable `key:"ldapUrl"`
LdapBindDn AppConfigVariable `key:"ldapBindDn"`
LdapBindPassword AppConfigVariable `key:"ldapBindPassword"`
LdapBindPassword AppConfigVariable `key:"ldapBindPassword,sensitive"`
LdapBase AppConfigVariable `key:"ldapBase"`
LdapUserSearchFilter AppConfigVariable `key:"ldapUserSearchFilter"`
LdapUserGroupSearchFilter AppConfigVariable `key:"ldapUserGroupSearchFilter"`
@@ -76,7 +80,7 @@ type AppConfig struct {
LdapSoftDeleteUsers AppConfigVariable `key:"ldapSoftDeleteUsers"`
}
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool, redactSensitiveValues bool) []AppConfigVariable {
// Use reflection to iterate through all fields
cfgValue := reflect.ValueOf(c).Elem()
cfgType := cfgValue.Type()
@@ -96,11 +100,16 @@ func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
continue
}
fieldValue := cfgValue.Field(i)
value := cfgValue.Field(i).FieldByName("Value").String()
// Redact sensitive values if the value isn't empty, the UI config is disabled, and redactSensitiveValues is true
if value != "" && common.EnvConfig.UiConfigDisabled && redactSensitiveValues && attrs == "sensitive" {
value = "XXXXXXXXXX"
}
appConfigVariable := AppConfigVariable{
Key: key,
Value: fieldValue.FieldByName("Value").String(),
Value: value,
}
res = append(res, appConfigVariable)
@@ -169,7 +178,7 @@ type AppConfigKeyNotFoundError struct {
}
func (e AppConfigKeyNotFoundError) Error() string {
return fmt.Sprintf("cannot find config key '%s'", e.field)
return "cannot find config key '" + e.field + "'"
}
func (e AppConfigKeyNotFoundError) Is(target error) bool {
@@ -183,7 +192,7 @@ type AppConfigInternalForbiddenError struct {
}
func (e AppConfigInternalForbiddenError) Error() string {
return fmt.Sprintf("field '%s' is internal and can't be updated", e.field)
return "field '" + e.field + "' is internal and can't be updated"
}
func (e AppConfigInternalForbiddenError) Is(target error) bool {

View File

@@ -10,7 +10,7 @@ type AuditLog struct {
Base
Event AuditLogEvent `sortable:"true"`
IpAddress string `sortable:"true"`
IpAddress *string `sortable:"true"`
Country string `sortable:"true"`
City string `sortable:"true"`
UserAgent string `sortable:"true"`
@@ -28,6 +28,7 @@ type AuditLogEvent string //nolint:recvcheck
const (
AuditLogEventSignIn AuditLogEvent = "SIGN_IN"
AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN"
AuditLogEventAccountCreated AuditLogEvent = "ACCOUNT_CREATED"
AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION"
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
AuditLogEventDeviceCodeAuthorization AuditLogEvent = "DEVICE_CODE_AUTHORIZATION"

View File

@@ -0,0 +1,11 @@
package model
type KV struct {
Key string `gorm:"primaryKey;not null"`
Value *string
}
// TableName overrides the table name used by KV to `kv`
func (KV) TableName() string {
return "kv"
}

View File

@@ -5,12 +5,15 @@ import (
"encoding/json"
"fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type UserAuthorizedOidcClient struct {
Scope string
Scope string
LastUsedAt datatype.DateTime `sortable:"true"`
UserID string `gorm:"primary_key;"`
User User
@@ -45,6 +48,8 @@ type OidcClient struct {
HasLogo bool `gorm:"-"`
IsPublic bool
PkceEnabled bool
Credentials OidcClientCredentials
LaunchURL *string
AllowedUserGroups []UserGroup `gorm:"many2many:oidc_clients_allowed_user_groups;"`
CreatedByID string
@@ -71,9 +76,49 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
return nil
}
type OidcClientCredentials struct { //nolint:recvcheck
FederatedIdentities []OidcClientFederatedIdentity `json:"federatedIdentities,omitempty"`
}
type OidcClientFederatedIdentity struct {
Issuer string `json:"issuer"`
Subject string `json:"subject,omitempty"`
Audience string `json:"audience,omitempty"`
JWKS string `json:"jwks,omitempty"` // URL of the JWKS
}
func (occ OidcClientCredentials) FederatedIdentityForIssuer(issuer string) (OidcClientFederatedIdentity, bool) {
if issuer == "" {
return OidcClientFederatedIdentity{}, false
}
for _, fi := range occ.FederatedIdentities {
if fi.Issuer == issuer {
return fi, true
}
}
return OidcClientFederatedIdentity{}, false
}
func (occ *OidcClientCredentials) Scan(value any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, occ)
case string:
return json.Unmarshal([]byte(v), occ)
default:
return fmt.Errorf("unsupported type: %T", value)
}
}
func (occ OidcClientCredentials) Value() (driver.Value, error) {
return json.Marshal(occ)
}
type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error {
func (cu *UrlList) Scan(value any) error {
switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu)

View File

@@ -0,0 +1,28 @@
package model
import (
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type SignupToken struct {
Base
Token string `json:"token"`
ExpiresAt datatype.DateTime `json:"expiresAt" sortable:"true"`
UsageLimit int `json:"usageLimit" sortable:"true"`
UsageCount int `json:"usageCount" sortable:"true"`
}
func (st *SignupToken) IsExpired() bool {
return time.Time(st.ExpiresAt).Before(time.Now())
}
func (st *SignupToken) IsUsageLimitReached() bool {
return st.UsageCount >= st.UsageLimit
}
func (st *SignupToken) IsValid() bool {
return !st.IsExpired() && !st.IsUsageLimitReached()
}

View File

@@ -55,8 +55,8 @@ func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input d
apiKey := model.ApiKey{
Name: input.Name,
Key: utils.CreateSha256Hash(token), // Hash the token for storage
Description: &input.Description,
ExpiresAt: datatype.DateTime(input.ExpiresAt),
Description: input.Description,
ExpiresAt: input.ExpiresAt,
UserID: userID,
}

View File

@@ -4,17 +4,14 @@ import (
"context"
"errors"
"fmt"
"log"
"mime/multipart"
"os"
"reflect"
"slices"
"strings"
"sync/atomic"
"time"
"github.com/hashicorp/go-uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@@ -29,22 +26,22 @@ type AppConfigService struct {
db *gorm.DB
}
func NewAppConfigService(initCtx context.Context, db *gorm.DB) *AppConfigService {
func NewAppConfigService(ctx context.Context, db *gorm.DB) (*AppConfigService, error) {
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(initCtx)
err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
return nil, fmt.Errorf("failed to initialize app config service: %w", err)
}
err = service.initInstanceID(initCtx)
err = service.initInstanceID(ctx)
if err != nil {
log.Fatalf("Failed to initialize instance ID: %v", err)
return nil, fmt.Errorf("failed to initialize instance ID: %w", err)
}
return service
return service, nil
}
// GetDbConfig returns the application configuration.
@@ -68,6 +65,8 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
EmailsVerified: model.AppConfigVariable{Value: "false"},
DisableAnimations: model.AppConfigVariable{Value: "false"},
AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
AllowUserSignups: model.AppConfigVariable{Value: "disabled"},
AccentColor: model.AppConfigVariable{Value: "default"},
// Internal
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
@@ -232,11 +231,11 @@ func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppCon
s.dbConfig.Store(cfg)
// Return the updated config
res := cfg.ToAppConfigVariableSlice(true)
res := cfg.ToAppConfigVariableSlice(true, false)
return res, nil
}
// UpdateAppConfigValues
// UpdateAppConfigValues updates the application configuration values in the database.
func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
// Count of keysAndValues must be even
if len(keysAndValues)%2 != 0 {
@@ -317,11 +316,11 @@ func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndVal
}
func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
return s.GetDbConfig().ToAppConfigVariableSlice(showAll, true)
}
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename)
fileType := strings.ToLower(utils.GetFileExtension(uploadedFile.Filename))
mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" {
return &common.FileTypeNotSupportedError{}
@@ -355,66 +354,23 @@ func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multip
// LoadDbConfig loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
var dest *model.AppConfig
// If the UI config is disabled, only load from the env
if common.EnvConfig.UiConfigDisabled {
dest, err = s.loadDbConfigFromEnv(ctx, s.db)
} else {
dest, err = s.loadDbConfigInternal(ctx, s.db)
}
dest, err := s.loadDbConfigInternal(ctx, s.db)
if err != nil {
return err
}
// Update the value in the object
s.dbConfig.Store(dest)
return nil
}
func (s *AppConfigService) loadDbConfigFromEnv(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the key and internal tag values
tagValue := strings.Split(field.Tag.Get("key"), ",")
key := tagValue[0]
isInternal := slices.Contains(tagValue, "internal")
// Internal fields are loaded from the database as they can't be set from the environment
if isInternal {
var value string
err := tx.WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Select("value").
First(&value).Error
if err == nil {
rv.Field(i).FieldByName("Value").SetString(value)
}
continue
}
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// If the UI config is disabled, only load from the env
if common.EnvConfig.UiConfigDisabled {
dest, err := s.loadDbConfigFromEnv(ctx, tx)
return dest, err
}
return dest, nil
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
@@ -444,6 +400,59 @@ func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB
return dest, nil
}
func (s *AppConfigService) loadDbConfigFromEnv(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the key and internal tag values
key, attrs, _ := strings.Cut(field.Tag.Get("key"), ",")
// Internal fields are loaded from the database as they can't be set from the environment
if attrs == "internal" {
var value string
err := tx.WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Select("value").
First(&value).Error
if err == nil {
rv.Field(i).FieldByName("Value").SetString(value)
}
continue
}
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
continue
}
// If it's sensitive, we also allow reading from file
if attrs == "sensitive" {
fileName := os.Getenv(envVarName + "_FILE")
if fileName != "" {
b, err := os.ReadFile(fileName)
if err != nil {
return nil, fmt.Errorf("failed to read secret '%s' from file '%s': %w", envVarName, fileName, err)
}
rv.Field(i).FieldByName("Value").SetString(string(b))
continue
}
}
}
return dest, nil
}
func (s *AppConfigService) initInstanceID(ctx context.Context) error {
// Check if the instance ID is already set
instanceID := s.GetDbConfig().InstanceID.Value

View File

@@ -3,17 +3,13 @@ package service
import (
"sync/atomic"
"testing"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
@@ -28,7 +24,7 @@ func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -42,7 +38,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
@@ -72,7 +68,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
@@ -93,7 +89,7 @@ func TestLoadDbConfig(t *testing.T) {
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
@@ -135,7 +131,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -171,7 +167,7 @@ func TestLoadDbConfig(t *testing.T) {
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
@@ -195,7 +191,7 @@ func TestLoadDbConfig(t *testing.T) {
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -220,7 +216,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -264,7 +260,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -285,7 +281,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -301,7 +297,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -319,7 +315,7 @@ func TestUpdateAppConfigValues(t *testing.T) {
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
@@ -392,7 +388,7 @@ func TestUpdateAppConfig(t *testing.T) {
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
@@ -457,7 +453,7 @@ func TestUpdateAppConfig(t *testing.T) {
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
db := testutils.NewDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
@@ -475,49 +471,3 @@ func TestUpdateAppConfig(t *testing.T) {
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -3,13 +3,14 @@ package service
import (
"context"
"fmt"
"log"
"log/slog"
userAgentParser "github.com/mileusna/useragent"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
)
@@ -21,19 +22,24 @@ type AuditLogService struct {
}
func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailService *EmailService, geoliteService *GeoLiteService) *AuditLogService {
return &AuditLogService{db: db, appConfigService: appConfigService, emailService: emailService, geoliteService: geoliteService}
return &AuditLogService{
db: db,
appConfigService: appConfigService,
emailService: emailService,
geoliteService: geoliteService,
}
}
// Create creates a new audit log entry in the database
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) (model.AuditLog, bool) {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil {
log.Printf("Failed to get IP location: %v", err)
// Log the error but don't interrupt the operation
slog.Warn("Failed to get IP location", "error", err)
}
auditLog := model.AuditLog{
Event: event,
IpAddress: ipAddress,
Country: country,
City: city,
UserAgent: userAgent,
@@ -41,33 +47,47 @@ func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent,
Data: data,
}
if ipAddress != "" {
// Only set ipAddress if not empty, because on Postgres we use INET columns that don't allow non-null empty values
auditLog.IpAddress = &ipAddress
}
// Save the audit log in the database
err = tx.
WithContext(ctx).
Create(&auditLog).
Error
if err != nil {
log.Printf("Failed to create audit log: %v", err)
return model.AuditLog{}
slog.Error("Failed to create audit log", "error", err)
return model.AuditLog{}, false
}
return auditLog
return auditLog, true
}
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
createdAuditLog, ok := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
if !ok {
// At this point the transaction has been canceled already, and error has been logged
return createdAuditLog
}
// Count the number of times the user has logged in from the same device
var count int64
err := tx.
stmt := tx.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
Count(&count).
Error
Where("user_id = ? AND user_agent = ?", userID, userAgent)
if ipAddress == "" {
// An empty IP address is stored as NULL in the database
stmt = stmt.Where("ip_address IS NULL")
} else {
stmt = stmt.Where("ip_address = ?", ipAddress)
}
err := stmt.Count(&count).Error
if err != nil {
log.Printf("Failed to count audit logs: %v\n", err)
slog.ErrorContext(ctx, "Failed to count audit logs", slog.Any("error", err))
return createdAuditLog
}
@@ -76,7 +96,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
innerCtx := context.Background()
span := trace.SpanFromContext(ctx)
innerCtx := trace.ContextWithSpan(context.Background(), span)
// Note we don't use the transaction here because this is running in background
var user model.User
@@ -86,7 +107,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
First(&user).
Error
if innerErr != nil {
log.Printf("Failed to load user: %v", innerErr)
slog.ErrorContext(innerCtx, "Failed to load user from database to send notification email", slog.Any("error", innerErr))
return
}
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
@@ -100,7 +122,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
DateTime: createdAuditLog.CreatedAt.UTC(),
})
if innerErr != nil {
log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
slog.ErrorContext(innerCtx, "Failed to send notification email", slog.Any("error", innerErr), slog.String("address", user.Email))
return
}
}()
}
@@ -150,6 +173,14 @@ func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPagination
return nil, utils.PaginationResponse{}, fmt.Errorf("unsupported database dialect: %s", dialect)
}
}
if filters.Location != "" {
switch filters.Location {
case "external":
query = query.Where("country != 'Internal Network'")
case "internal":
query = query.Where("country = 'Internal Network'")
}
}
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
if err != nil {

View File

@@ -5,23 +5,28 @@ package service
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"encoding/base64"
"fmt"
"log"
"log/slog"
"os"
"path/filepath"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
"github.com/pocket-id/pocket-id/backend/resources"
)
@@ -30,14 +35,43 @@ type TestService struct {
jwtService *JwtService
appConfigService *AppConfigService
ldapService *LdapService
externalIdPKey jwk.Key
}
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) *TestService {
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService, ldapService: ldapService}
func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService *JwtService, ldapService *LdapService) (*TestService, error) {
s := &TestService{
db: db,
appConfigService: appConfigService,
jwtService: jwtService,
ldapService: ldapService,
}
err := s.initExternalIdP()
if err != nil {
return nil, fmt.Errorf("failed to initialize external IdP: %w", err)
}
return s, nil
}
// Initializes the "external IdP"
// This creates a new "issuing authority" containing a public JWKS
// It also stores the private key internally that will be used to issue JWTs
func (s *TestService) initExternalIdP() error {
// Generate a new ECDSA key
rawKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return fmt.Errorf("failed to generate private key: %w", err)
}
s.externalIdPKey, err = jwkutils.ImportRawKey(rawKey, jwa.ES256().String(), "")
if err != nil {
return fmt.Errorf("failed to import private key: %w", err)
}
return nil
}
//nolint:gocognit
func (s *TestService) SeedDatabase() error {
func (s *TestService) SeedDatabase(baseURL string) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{
{
@@ -120,6 +154,7 @@ func (s *TestService) SeedDatabase() error {
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
},
Name: "Nextcloud",
LaunchURL: utils.Ptr("https://nextcloud.local"),
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
CallbackURLs: model.UrlList{"http://nextcloud/auth/callback"},
LogoutCallbackURLs: model.UrlList{"http://nextcloud/auth/logout/callback"},
@@ -138,6 +173,36 @@ func (s *TestService) SeedDatabase() error {
userGroups[1],
},
},
{
Base: model.Base{
ID: "7c21a609-96b5-4011-9900-272b8d31a9d1",
},
Name: "Tailscale",
Secret: "$2a$10$xcRReBsvkI1XI6FG8xu/pOgzeF00bH5Wy4d/NThwcdi3ZBpVq/B9a", // n4VfQeXlTzA6yKpWbR9uJcMdSx2qH0Lo
CallbackURLs: model.UrlList{"http://tailscale/auth/callback"},
LogoutCallbackURLs: model.UrlList{"http://tailscale/auth/logout/callback"},
CreatedByID: users[0].ID,
},
{
Base: model.Base{
ID: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
},
Name: "Federated",
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
CallbackURLs: model.UrlList{"http://federated/auth/callback"},
CreatedByID: users[1].ID,
AllowedUserGroups: []model.UserGroup{},
Credentials: model.OidcClientCredentials{
FederatedIdentities: []model.OidcClientFederatedIdentity{
{
Issuer: "https://external-idp.local",
Audience: "api://PocketID",
Subject: "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
JWKS: baseURL + "/api/externalidp/jwks.json",
},
},
},
},
}
for _, client := range oidcClients {
if err := tx.Create(&client).Error; err != nil {
@@ -145,16 +210,28 @@ func (s *TestService) SeedDatabase() error {
}
}
authCode := model.OidcAuthorizationCode{
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
authCodes := []model.OidcAuthorizationCode{
{
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
},
{
Code: "federated",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[1].ID,
ClientID: oidcClients[2].ID,
},
}
if err := tx.Create(&authCode).Error; err != nil {
return err
for _, authCode := range authCodes {
if err := tx.Create(&authCode).Error; err != nil {
return err
}
}
refreshToken := model.OidcRefreshToken{
@@ -177,13 +254,30 @@ func (s *TestService) SeedDatabase() error {
return err
}
userAuthorizedClient := model.UserAuthorizedOidcClient{
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
userAuthorizedClients := []model.UserAuthorizedOidcClient{
{
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
LastUsedAt: datatype.DateTime(time.Date(2025, 8, 1, 13, 0, 0, 0, time.UTC)),
},
{
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[2].ID,
LastUsedAt: datatype.DateTime(time.Date(2025, 8, 10, 14, 0, 0, 0, time.UTC)),
},
{
Scope: "openid profile email",
UserID: users[1].ID,
ClientID: oidcClients[3].ID,
LastUsedAt: datatype.DateTime(time.Date(2025, 8, 12, 12, 0, 0, 0, time.UTC)),
},
}
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
return err
for _, userAuthorizedClient := range userAuthorizedClients {
if err := tx.Create(&userAuthorizedClient).Error; err != nil {
return err
}
}
// To generate a new key pair, run the following command:
@@ -237,6 +331,50 @@ func (s *TestService) SeedDatabase() error {
return err
}
signupTokens := []model.SignupToken{
{
Base: model.Base{
ID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
},
Token: "VALID1234567890A",
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
UsageLimit: 1,
UsageCount: 0,
},
{
Base: model.Base{
ID: "b2c3d4e5-f6g7-8901-bcde-f12345678901",
},
Token: "PARTIAL567890ABC",
ExpiresAt: datatype.DateTime(time.Now().Add(7 * 24 * time.Hour)),
UsageLimit: 5,
UsageCount: 2,
},
{
Base: model.Base{
ID: "c3d4e5f6-g7h8-9012-cdef-123456789012",
},
Token: "EXPIRED34567890B",
ExpiresAt: datatype.DateTime(time.Now().Add(-24 * time.Hour)), // Expired
UsageLimit: 3,
UsageCount: 1,
},
{
Base: model.Base{
ID: "d4e5f6g7-h8i9-0123-def0-234567890123",
},
Token: "FULLYUSED567890C",
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
UsageLimit: 1,
UsageCount: 1, // Usage limit reached
},
}
for _, token := range signupTokens {
if err := tx.Create(&token).Error; err != nil {
return err
}
}
return nil
})
@@ -283,9 +421,9 @@ func (s *TestService) ResetDatabase() error {
return err
}
func (s *TestService) ResetApplicationImages() error {
func (s *TestService) ResetApplicationImages(ctx context.Context) error {
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
log.Printf("Error removing directory: %v", err)
slog.ErrorContext(ctx, "Error removing directory", slog.Any("error", err))
return err
}
@@ -405,3 +543,45 @@ func (s *TestService) SetLdapTestConfig(ctx context.Context) error {
return nil
}
func (s *TestService) SignRefreshToken(userID, clientID, refreshToken string) (string, error) {
return s.jwtService.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
}
// GetExternalIdPJWKS returns the JWKS for the "external IdP".
func (s *TestService) GetExternalIdPJWKS() (jwk.Set, error) {
pubKey, err := s.externalIdPKey.PublicKey()
if err != nil {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
set := jwk.NewSet()
err = set.AddKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to add public key to set: %w", err)
}
return set, nil
}
func (s *TestService) SignExternalIdPToken(iss, sub, aud string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(sub).
Expiration(now.Add(time.Hour)).
IssuedAt(now).
Issuer(iss).
Audience([]string{aud}).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
alg, _ := s.externalIdPKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.externalIdPKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}

View File

@@ -7,12 +7,13 @@ import (
"errors"
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -22,9 +23,10 @@ import (
)
type GeoLiteService struct {
httpClient *http.Client
disableUpdater bool
mutex sync.RWMutex
httpClient *http.Client
disableUpdater bool
mutex sync.RWMutex
localIPv6Ranges []*net.IPNet
}
var localhostIPNets = []*net.IPNet{
@@ -50,21 +52,89 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
// Warn the user, and disable the periodic updater
log.Println("MAXMIND_LICENSE_KEY environment variable is empty. The GeoLite2 City database won't be updated.")
slog.Warn("MAXMIND_LICENSE_KEY environment variable is empty: the GeoLite2 City database won't be updated")
service.disableUpdater = true
}
// Initialize IPv6 local ranges
err := service.initializeIPv6LocalRanges()
if err != nil {
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
}
return service
}
// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable
func (s *GeoLiteService) initializeIPv6LocalRanges() error {
rangesEnv := common.EnvConfig.LocalIPv6Ranges
if rangesEnv == "" {
return nil // No local IPv6 ranges configured
}
ranges := strings.Split(rangesEnv, ",")
localRanges := make([]*net.IPNet, 0, len(ranges))
for _, rangeStr := range ranges {
rangeStr = strings.TrimSpace(rangeStr)
if rangeStr == "" {
continue
}
_, ipNet, err := net.ParseCIDR(rangeStr)
if err != nil {
return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err)
}
// Ensure it's an IPv6 range
if ipNet.IP.To4() != nil {
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
}
localRanges = append(localRanges, ipNet)
}
s.localIPv6Ranges = localRanges
if len(localRanges) > 0 {
slog.Info("Initialized IPv6 local ranges", slog.Int("count", len(localRanges)))
}
return nil
}
// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges
func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool {
if ip.To4() != nil {
return false // Not an IPv6 address
}
for _, localRange := range s.localIPv6Ranges {
if localRange.Contains(ip) {
return true
}
}
return false
}
func (s *GeoLiteService) DisableUpdater() bool {
return s.disableUpdater
}
// GetLocationByIP returns the country and city of the given IP address.
func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string, err error) {
if ipAddress == "" {
return "", "", nil
}
// Check the IP address against known private IP ranges
if ip := net.ParseIP(ipAddress); ip != nil {
// Check IPv6 local ranges first
if s.isLocalIPv6(ip) {
return "Internal Network", "LAN", nil
}
// Check existing IPv4 ranges
for _, ipNet := range tailscaleIPNets {
if ipNet.Contains(ip) {
return "Internal Network", "Tailscale", nil
@@ -82,6 +152,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
}
}
addr, err := netip.ParseAddr(ipAddress)
if err != nil {
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
}
// Race condition between reading and writing the database.
s.mutex.RLock()
defer s.mutex.RUnlock()
@@ -92,11 +167,6 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
}
defer db.Close()
addr, err := netip.ParseAddr(ipAddress)
if err != nil {
return "", "", fmt.Errorf("failed to parse IP address: %w", err)
}
var record struct {
City struct {
Names map[string]string `maxminddb:"names"`
@@ -117,11 +187,11 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
func (s *GeoLiteService) UpdateDatabase(parentCtx context.Context) error {
if s.isDatabaseUpToDate() {
log.Println("GeoLite2 City database is up-to-date")
slog.Info("GeoLite2 City database is up-to-date")
return nil
}
log.Println("Updating GeoLite2 City database")
slog.Info("Updating GeoLite2 City database")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
@@ -148,7 +218,7 @@ func (s *GeoLiteService) UpdateDatabase(parentCtx context.Context) error {
return fmt.Errorf("failed to extract database: %w", err)
}
log.Println("GeoLite2 City database successfully updated.")
slog.Info("GeoLite2 City database successfully updated.")
return nil
}

View File

@@ -0,0 +1,220 @@
package service
import (
"net"
"net/http"
"testing"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
tests := []struct {
name string
localRanges string
testIP string
expectedCountry string
expectedCity string
expectError bool
}{
{
name: "IPv6 in local range",
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
testIP: "2001:0db8:abcd:000::1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "IPv6 not in local range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:ffff:000::1",
expectError: true,
},
{
name: "Multiple ranges - second range match",
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
testIP: "2001:0db8:abcd:001::1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "Empty local ranges",
localRanges: "",
testIP: "2001:0db8:abcd:000::1",
expectError: true,
},
{
name: "IPv4 private address still works",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "192.168.1.1",
expectedCountry: "Internal Network",
expectedCity: "LAN",
expectError: false,
},
{
name: "IPv6 loopback",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "::1",
expectedCountry: "Internal Network",
expectedCity: "localhost",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := NewGeoLiteService(&http.Client{})
country, city, err := service.GetLocationByIP(tt.testIP)
if tt.expectError {
if err == nil && country != "Internal Network" {
t.Errorf("Expected error or internal network classification for external IP")
}
} else {
require.NoError(t, err)
assert.Equal(t, tt.expectedCountry, country)
assert.Equal(t, tt.expectedCity, city)
}
})
}
}
func TestGeoLiteService_isLocalIPv6(t *testing.T) {
tests := []struct {
name string
localRanges string
testIP string
expected bool
}{
{
name: "Valid IPv6 in range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:abcd:000::1",
expected: true,
},
{
name: "Valid IPv6 not in range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:ffff:000::1",
expected: false,
},
{
name: "IPv4 address should return false",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "192.168.1.1",
expected: false,
},
{
name: "No ranges configured",
localRanges: "",
testIP: "2001:0db8:abcd:000::1",
expected: false,
},
{
name: "Edge of range",
localRanges: "2001:0db8:abcd:000::/56",
testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := NewGeoLiteService(&http.Client{})
ip := net.ParseIP(tt.testIP)
if ip == nil {
t.Fatalf("Invalid test IP: %s", tt.testIP)
}
result := service.isLocalIPv6(ip)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
tests := []struct {
name string
envValue string
expectError bool
expectCount int
}{
{
name: "Valid IPv6 ranges",
envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
expectError: false,
expectCount: 2,
},
{
name: "Empty environment variable",
envValue: "",
expectError: false,
expectCount: 0,
},
{
name: "Invalid CIDR notation",
envValue: "2001:0db8:abcd:000::/999",
expectError: true,
expectCount: 0,
},
{
name: "IPv4 range in IPv6 env var",
envValue: "192.168.1.0/24",
expectError: true,
expectCount: 0,
},
{
name: "Mixed valid and invalid ranges",
envValue: "2001:0db8:abcd:000::/56,invalid-range",
expectError: true,
expectCount: 0,
},
{
name: "Whitespace handling",
envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ",
expectError: false,
expectCount: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
originalConfig := common.EnvConfig.LocalIPv6Ranges
common.EnvConfig.LocalIPv6Ranges = tt.envValue
defer func() {
common.EnvConfig.LocalIPv6Ranges = originalConfig
}()
service := &GeoLiteService{
httpClient: &http.Client{},
}
err := service.initializeIPv6LocalRanges()
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
})
}
}

View File

@@ -2,25 +2,19 @@ package service
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"os"
"path/filepath"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
const (
@@ -28,8 +22,9 @@ const (
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
@@ -41,9 +36,15 @@ const (
// TokenTypeClaim is the claim used to identify the type of token
TokenTypeClaim = "type"
// RefreshTokenClaim is the claim used for the refresh token's value
RefreshTokenClaim = "rt"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
// OAuthRefreshTokenJWTType identifies a JWT as an OAuth refresh token
OAuthRefreshTokenJWTType = "refresh-token"
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
AccessTokenJWTType = "access-token"
@@ -55,58 +56,74 @@ const (
)
type JwtService struct {
envConfig *common.EnvConfigSchema
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
func NewJwtService(db *gorm.DB, appConfigService *AppConfigService) (*JwtService, error) {
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
err := service.init(db, appConfigService, &common.EnvConfig)
if err != nil {
return nil, err
}
return service
return service, nil
}
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
func (s *JwtService) init(db *gorm.DB, appConfigService *AppConfigService, envConfig *common.EnvConfigSchema) (err error) {
s.appConfigService = appConfigService
s.envConfig = envConfig
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
return s.loadOrGenerateKey(db)
}
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
func (s *JwtService) loadOrGenerateKey(db *gorm.DB) error {
// Get the key provider
keyProvider, err := jwkutils.GetKeyProvider(db, s.envConfig, s.appConfigService.GetDbConfig().InstanceID.Value)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
return fmt.Errorf("failed to get key provider: %w", err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
// Try loading a key
key, err := keyProvider.LoadKey()
if err != nil {
return fmt.Errorf("failed to load key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
// If we have a key, store it in the object and we're done
if key != nil {
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
err = s.generateKey()
if err != nil {
return fmt.Errorf("failed to generate key: %w", err)
}
// Save the newly-generated key
err = keyProvider.SaveKey(s.privateKey)
if err != nil {
return fmt.Errorf("failed to save private key (provider type '%s'): %w", s.envConfig.KeysStorage, err)
}
return nil
}
// generateKey generates a new key and stores it in the object
func (s *JwtService) generateKey() error {
// Default is to generate RS256 (RSA-2048) keys
key, err := jwkutils.GenerateKey(jwa.RS256().String(), "")
if err != nil {
return fmt.Errorf("failed to generate new private key: %w", err)
}
@@ -117,12 +134,6 @@ func (s *JwtService) loadOrGenerateKey(keysPath string) error {
return fmt.Errorf("failed to set private key: %w", err)
}
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
@@ -188,13 +199,13 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
Subject(user.ID).
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, common.EnvConfig.AppURL)
err = SetAudienceString(token, s.envConfig.AppURL)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
@@ -225,8 +236,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithAudience(s.envConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
)
if err != nil {
@@ -236,41 +247,52 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
return token, nil
}
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
// BuildIDToken creates an ID token with all claims
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
return nil, fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
for k, v := range userClaims {
err = token.Set(k, v)
if err != nil {
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
return nil, fmt.Errorf("failed to set claim '%s': %w", k, err)
}
}
if nonce != "" {
err = token.Set("nonce", nonce)
if err != nil {
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
return nil, fmt.Errorf("failed to set claim 'nonce': %w", err)
}
}
return token, nil
}
// GenerateIDToken creates and signs an ID token
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
token, err := s.BuildIDToken(userClaims, clientID, nonce)
if err != nil {
return "", err
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
@@ -290,7 +312,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
)
@@ -313,24 +335,88 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
return token, nil
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
// BuildOAuthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthAccessTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
return token, nil
}
// GenerateOAuthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOAuthAccessToken(user, clientID)
if err != nil {
return "", err
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}
func (s *JwtService) VerifyOAuthAccessToken(tokenString string) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
return token, nil
}
func (s *JwtService) GenerateOAuthRefreshToken(userID string, clientID string, refreshToken string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(userID).
Expiration(now.Add(RefreshTokenDuration)).
IssuedAt(now).
Issuer(s.envConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = token.Set(RefreshTokenClaim, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to set 'rt' claim in token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthAccessTokenJWTType)
err = SetTokenType(token, OAuthRefreshTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
@@ -344,21 +430,58 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return string(signed), nil
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
func (s *JwtService) VerifyOAuthRefreshToken(tokenString string) (userID, clientID, rt string, err error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
jwt.WithIssuer(s.envConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthRefreshTokenJWTType)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
return "", "", "", fmt.Errorf("failed to parse token: %w", err)
}
return token, nil
err = token.Get(RefreshTokenClaim, &rt)
if err != nil {
return "", "", "", fmt.Errorf("failed to get '%s' claim from token: %w", RefreshTokenClaim, err)
}
audiences, ok := token.Audience()
if !ok || len(audiences) != 1 || audiences[0] == "" {
return "", "", "", errors.New("failed to get 'aud' claim from token")
}
clientID = audiences[0]
userID, ok = token.Subject()
if !ok {
return "", "", "", errors.New("failed to get 'sub' claim from token")
}
return userID, clientID, rt, nil
}
// GetTokenType returns the type of the JWT token issued by Pocket ID, but **does not validate it**.
func (s *JwtService) GetTokenType(tokenString string) (string, jwt.Token, error) {
// Disable validation and verification to parse the token without checking it
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(false),
jwt.WithVerify(false),
)
if err != nil {
return "", nil, fmt.Errorf("failed to parse token: %w", err)
}
var tokenType string
err = token.Get(TokenTypeClaim, &tokenType)
if err != nil {
return "", nil, fmt.Errorf("failed to get token type claim: %w", err)
}
return tokenType, token, nil
}
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
@@ -372,7 +495,7 @@ func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
return nil, fmt.Errorf("failed to get public key: %w", err)
}
EnsureAlgInKey(pubKey)
jwkutils.EnsureAlgInKey(pubKey, "", "")
return pubKey, nil
}
@@ -401,107 +524,6 @@ func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
return alg, nil
}
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return importRawKey(rawKey)
}
func importRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, err
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {
@@ -509,7 +531,10 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
}
var isAdmin bool
err := token.Get(IsAdminClaim, &isAdmin)
return isAdmin, err
if err != nil {
return false, fmt.Errorf("failed to get 'isAdmin' claim from token: %w", err)
}
return isAdmin, nil
}
// SetTokenType sets the "type" claim in the token

View File

@@ -21,6 +21,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
func TestJwtService_Init(t *testing.T) {
@@ -32,9 +33,16 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Initialize the JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Verify the private key was set
@@ -65,9 +73,16 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// First create a service to generate a key
firstService := &JwtService{}
err := firstService.init(mockConfig, tempDir)
err := firstService.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Get the key ID of the first service
@@ -76,7 +91,7 @@ func TestJwtService_Init(t *testing.T) {
// Now create a new service that should load the existing key
secondService := &JwtService{}
err = secondService.init(mockConfig, tempDir)
err = secondService.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Verify the loaded key has the same ID as the original
@@ -89,12 +104,19 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a new JWK and save it to disk
origKeyID := createECDSAKeyJWK(t, tempDir)
// Now create a new service that should load the existing key
svc := &JwtService{}
err := svc.init(mockConfig, tempDir)
err := svc.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Ensure loaded key has the right algorithm
@@ -112,12 +134,19 @@ func TestJwtService_Init(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a new JWK and save it to disk
origKeyID := createEdDSAKeyJWK(t, tempDir)
// Now create a new service that should load the existing key
svc := &JwtService{}
err := svc.init(mockConfig, tempDir)
err := svc.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err)
// Ensure loaded key has the right algorithm and curve
@@ -146,9 +175,16 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create a JWT service with initialized key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -177,12 +213,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create an ECDSA key and save it as JWK
originalKeyID := createECDSAKeyJWK(t, tempDir)
// Create a JWT service that loads the ECDSA key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -215,12 +258,19 @@ func TestJwtService_GetPublicJWK(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
// Create an EdDSA key and save it as JWK
originalKeyID := createEdDSAKeyJWK(t, tempDir)
// Create a JWT service that loads the EdDSA key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Get the JWK (public key)
@@ -275,16 +325,16 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates token for regular user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -327,7 +377,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
t.Run("generates token for admin user", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test admin user
@@ -363,7 +413,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
})
service := &JwtService{}
err := service.init(customMockConfig, tempDir)
err := service.init(nil, customMockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -398,7 +448,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -452,7 +505,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -506,7 +562,10 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -562,16 +621,16 @@ func TestGenerateVerifyIdToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies ID token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims
@@ -600,7 +659,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
@@ -613,7 +672,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
t.Run("can accept expired tokens if told so", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims
@@ -627,7 +686,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a token that's already expired
token, err := jwt.NewBuilder().
Subject(userClaims["sub"].(string)).
Issuer(common.EnvConfig.AppURL).
Issuer(service.envConfig.AppURL).
Audience([]string{clientID}).
IssuedAt(time.Now().Add(-2 * time.Hour)).
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
@@ -665,13 +724,13 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, userClaims["sub"], subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
})
t.Run("generates and verifies ID token with nonce", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create test claims with nonce
@@ -702,7 +761,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
t.Run("fails verification with incorrect issuer", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token with standard claims
@@ -713,7 +772,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
require.NoError(t, err, "Failed to generate ID token")
// Temporarily change the app URL to simulate wrong issuer
common.EnvConfig.AppURL = "https://wrong-issuer.com"
service.envConfig.AppURL = "https://wrong-issuer.com"
// Verify should fail due to issuer mismatch
_, err = service.VerifyIdToken(tokenString, false)
@@ -730,7 +789,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -761,7 +823,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "eddsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is OKP
publicKey, err := service.GetPublicJWK()
@@ -783,7 +845,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -794,7 +859,6 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create test claims
userClaims := map[string]interface{}{
"sub": "ecdsauser456",
"name": "ECDSA User",
"email": "ecdsauser@example.com",
}
const clientID = "ecdsa-client-123"
@@ -814,7 +878,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "ecdsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is EC
publicKey, err := service.GetPublicJWK()
@@ -836,7 +900,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -867,21 +934,11 @@ func TestGenerateVerifyIdToken(t *testing.T) {
assert.Equal(t, "rsauser456", subject, "Token subject should match user ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
// Verify the key type is RSA
publicKey, err := service.GetPublicJWK()
require.NoError(t, err)
assert.Equal(t, jwa.RSA().String(), publicKey.KeyType().String(), "Key type should be RSA")
// Verify the algorithm is RS256
alg, ok := publicKey.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String(), "Algorithm should be RS256")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
})
}
func TestGenerateVerifyOauthAccessToken(t *testing.T) {
func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
@@ -891,16 +948,16 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
common.EnvConfig.AppURL = "https://test.example.com"
defer func() {
common.EnvConfig.AppURL = originalAppURL
}()
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies OAuth access token with standard claims", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -913,12 +970,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-123"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token")
// Check the claims
@@ -930,7 +987,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
assert.Equal(t, []string{clientID}, audience, "Audience should contain the client ID")
issuer, ok := claims.Issuer()
_ = assert.True(t, ok, "Issuer not found in token") &&
assert.Equal(t, common.EnvConfig.AppURL, issuer, "Issuer should match app URL")
assert.Equal(t, service.envConfig.AppURL, issuer, "Issuer should match app URL")
// Check token expiration time is approximately 1 hour from now
expectedExp := time.Now().Add(1 * time.Hour)
@@ -943,7 +1000,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service with a mock function to generate an expired token
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
@@ -960,7 +1017,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{clientID}).
Issuer(common.EnvConfig.AppURL).
Issuer(service.envConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
@@ -971,7 +1028,7 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, err = service.VerifyOauthAccessToken(string(signed))
_, err = service.VerifyOAuthAccessToken(string(signed))
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
})
@@ -979,11 +1036,17 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(mockConfig, t.TempDir()) // Use a different temp dir
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(mockConfig, t.TempDir()) // Use a different temp dir
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize second JWT service")
// Create a test user
@@ -995,11 +1058,11 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "test-client-789"
// Generate a token with the first service
tokenString, err := service1.GenerateOauthAccessToken(user, clientID)
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token")
// Verify with the second service should fail due to different keys
_, err = service2.VerifyOauthAccessToken(tokenString)
_, err = service2.VerifyOAuthAccessToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
})
@@ -1013,7 +1076,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1031,12 +1097,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "eddsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1067,7 +1133,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1085,12 +1154,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "ecdsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1121,7 +1190,10 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
// Create a JWT service that loads the key
service := &JwtService{}
err := service.init(mockConfig, tempDir)
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
// Verify it loaded the right key
@@ -1139,12 +1211,12 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
const clientID = "rsa-oauth-client"
// Generate a token
tokenString, err := service.GenerateOauthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
claims, err := service.VerifyOauthAccessToken(tokenString)
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token with key")
// Check the claims
@@ -1167,6 +1239,98 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
})
}
func TestGenerateVerifyOAuthRefreshToken(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := NewTestAppConfigService(&model.AppConfig{})
// Setup the environment variable required by the token verification
mockEnvConfig := &common.EnvConfigSchema{
AppURL: "https://test.example.com",
KeysStorage: "file",
KeysPath: tempDir,
}
t.Run("generates and verifies refresh token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Create a test user
const (
userID = "user123"
clientID = "client123"
refreshToken = "rt-123"
)
// Generate a token
tokenString, err := service.GenerateOAuthRefreshToken(userID, clientID, refreshToken)
require.NoError(t, err, "Failed to generate refresh token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
// Verify the token
resUser, resClient, resRT, err := service.VerifyOAuthRefreshToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
assert.Equal(t, userID, resUser, "Should return correct user ID")
assert.Equal(t, clientID, resClient, "Should return correct client ID")
assert.Equal(t, refreshToken, resRT, "Should return correct refresh token")
})
t.Run("fails verification for expired token", func(t *testing.T) {
// Create a JWT service
service := &JwtService{}
err := service.init(nil, mockConfig, mockEnvConfig)
require.NoError(t, err, "Failed to initialize JWT service")
// Generate a token using JWT directly to create an expired token
token, err := jwt.NewBuilder().
Subject("user789").
Expiration(time.Now().Add(-1 * time.Hour)). // Expired 1 hour ago
IssuedAt(time.Now().Add(-2 * time.Hour)).
Audience([]string{"client123"}).
Issuer(service.envConfig.AppURL).
Build()
require.NoError(t, err, "Failed to build token")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
require.NoError(t, err, "Failed to sign token")
// Verify should fail due to expiration
_, _, _, err = service.VerifyOAuthRefreshToken(string(signed))
require.Error(t, err, "Verification should fail with expired token")
assert.Contains(t, err.Error(), `"exp" not satisfied`, "Error message should indicate token verification failure")
})
t.Run("fails verification with invalid signature", func(t *testing.T) {
// Create two JWT services with different keys
service1 := &JwtService{}
err := service1.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize first JWT service")
service2 := &JwtService{}
err = service2.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: t.TempDir(), // Use a different temp dir
})
require.NoError(t, err, "Failed to initialize second JWT service")
// Generate a token with the first service
tokenString, err := service1.GenerateOAuthRefreshToken("user789", "client123", "my-rt-123")
require.NoError(t, err, "Failed to generate refresh token")
// Verify with the second service should fail due to different keys
_, _, _, err = service2.VerifyOAuthRefreshToken(tokenString)
require.Error(t, err, "Verification should fail with invalid signature")
assert.Contains(t, err.Error(), "verification error", "Error message should indicate token verification failure")
})
}
func TestTokenTypeValidator(t *testing.T) {
// Create a context for the validator function
ctx := context.Background()
@@ -1212,16 +1376,125 @@ func TestTokenTypeValidator(t *testing.T) {
require.Error(t, err, "Validator should reject token without type claim")
assert.Contains(t, err.Error(), "failed to get token type claim")
})
}
func TestGetTokenType(t *testing.T) {
// Create a temporary directory for the test
tempDir := t.TempDir()
// Initialize the JWT service
mockConfig := NewTestAppConfigService(&model.AppConfig{})
service := &JwtService{}
err := service.init(nil, mockConfig, &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: tempDir,
})
require.NoError(t, err, "Failed to initialize JWT service")
buildTokenForType := func(t *testing.T, typ string, setClaimsFn func(b *jwt.Builder)) string {
t.Helper()
b := jwt.NewBuilder()
b.Subject("user123")
if setClaimsFn != nil {
setClaimsFn(b)
}
token, err := b.Build()
require.NoError(t, err, "Failed to build token")
err = SetTokenType(token, typ)
require.NoError(t, err, "Failed to set token type")
alg, _ := service.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, service.privateKey))
require.NoError(t, err, "Failed to sign token")
return string(signed)
}
t.Run("correctly identifies access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified as access token")
})
t.Run("correctly identifies ID tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, IDTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, IDTokenJWTType, tokenType, "Token type should be correctly identified as ID token")
})
t.Run("correctly identifies OAuth access tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthAccessTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthAccessTokenJWTType, tokenType, "Token type should be correctly identified as OAuth access token")
})
t.Run("correctly identifies refresh tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, OAuthRefreshTokenJWTType, nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error")
assert.Equal(t, OAuthRefreshTokenJWTType, tokenType, "Token type should be correctly identified as refresh token")
})
t.Run("works with expired tokens", func(t *testing.T) {
tokenString := buildTokenForType(t, AccessTokenJWTType, func(b *jwt.Builder) {
b.Expiration(time.Now().Add(-1 * time.Hour)) // Expired 1 hour ago
})
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.NoError(t, err, "GetTokenType should not return an error for expired tokens")
assert.Equal(t, AccessTokenJWTType, tokenType, "Token type should be correctly identified even for expired tokens")
})
t.Run("returns error for malformed tokens", func(t *testing.T) {
// Try to get the token type of a malformed token
tokenType, _, err := service.GetTokenType("not.a.valid.jwt.token")
require.Error(t, err, "GetTokenType should return an error for malformed tokens")
assert.Empty(t, tokenType, "Token type should be empty for malformed tokens")
})
t.Run("returns error for tokens without type claim", func(t *testing.T) {
// Create a token without type claim
tokenString := buildTokenForType(t, "", nil)
// Get the token type without validating
tokenType, _, err := service.GetTokenType(tokenString)
require.Error(t, err, "GetTokenType should return an error for tokens without type claim")
assert.Empty(t, tokenType, "Token type should be empty when type claim is missing")
assert.Contains(t, err.Error(), "failed to get token type claim", "Error message should indicate missing token type claim")
})
}
func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper()
privateKey, err := importRawKey(privateKeyRaw)
privateKey, err := jwkutils.ImportRawKey(privateKeyRaw, "", "")
require.NoError(t, err, "Failed to import private key")
err = SaveKeyJWK(privateKey, filepath.Join(path, PrivateKeyFile))
keyProvider := &jwkutils.KeyProviderFile{}
err = keyProvider.Init(jwkutils.KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysStorage: "file",
KeysPath: path,
},
})
require.NoError(t, err, "Failed to init file key provider")
err = keyProvider.SaveKey(privateKey)
require.NoError(t, err, "Failed to save key")
kid, _ := privateKey.KeyID()

View File

@@ -8,17 +8,21 @@ import (
"errors"
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
"golang.org/x/text/unicode/norm"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm"
)
type LdapService struct {
@@ -122,11 +126,11 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries {
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
// Skip groups without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
continue
}
@@ -164,13 +168,13 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
userResult, err := client.Search(userSearchReq)
if err != nil || len(userResult.Entries) == 0 {
log.Printf("Could not resolve group member DN '%s': %v", member, err)
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
continue
}
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
if username == "" {
log.Printf("Could not extract username from group member DN '%s'", member)
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
continue
}
}
@@ -178,7 +182,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
var databaseUser model.User
err = tx.
WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", username).
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -194,8 +198,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
LdapID: ldapId,
}
dto.Normalize(syncGroup)
if databaseGroup.ID == "" {
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
@@ -245,7 +250,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
}
log.Printf("Deleted group '%s'", group.Name)
slog.Info("Deleted group", slog.String("group", group.Name))
}
return nil
@@ -286,11 +291,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries {
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
// Skip users without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeUserUniqueIdentifier.Value)
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
continue
}
@@ -306,7 +311,6 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
// If a user is found (even if disabled), enable them since they're now back in LDAP
if databaseUser.ID != "" && databaseUser.Disabled {
// Use the transaction instead of the direct context
err = tx.
WithContext(ctx).
Model(&model.User{}).
@@ -315,7 +319,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
Error
if err != nil {
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
return fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
}
}
@@ -341,11 +345,12 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
IsAdmin: isAdmin,
LdapID: ldapId,
}
dto.Normalize(newUser)
if databaseUser.ID == "" {
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err)
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
@@ -353,7 +358,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
} else {
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err)
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
continue
} else if err != nil {
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
@@ -366,7 +371,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
if err != nil {
// This is not a fatal error
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
slog.Warn("Error saving profile picture for user", slog.String("username", newUser.Username), slog.Any("error", err))
}
}
}
@@ -395,7 +400,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
return fmt.Errorf("failed to disable user %s: %w", user.Username, err)
}
log.Printf("Disabled user '%s'", user.Username)
slog.Info("Disabled user", slog.String("username", user.Username))
} else {
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
target := &common.LdapUserUpdateError{}
@@ -405,7 +410,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
return fmt.Errorf("failed to delete user %s: %w", user.Username, err)
}
log.Printf("Deleted user '%s'", user.Username)
slog.Info("Deleted user", slog.String("username", user.Username))
}
}
@@ -468,3 +473,21 @@ func getDNProperty(property string, str string) string {
// CN not found, return an empty string
return ""
}
// convertLdapIdToString converts LDAP IDs to valid UTF-8 strings.
// LDAP servers may return binary UUIDs (16 bytes) or other non-UTF-8 data.
func convertLdapIdToString(ldapId string) string {
if utf8.ValidString(ldapId) {
return norm.NFC.String(ldapId)
}
// Try to parse as binary UUID (16 bytes)
if len(ldapId) == 16 {
if parsedUUID, err := uuid.FromBytes([]byte(ldapId)); err == nil {
return parsedUUID.String()
}
}
// As a last resort, encode as base64 to make it UTF-8 safe
return base64.StdEncoding.EncodeToString([]byte(ldapId))
}

View File

@@ -71,3 +71,36 @@ func TestGetDNProperty(t *testing.T) {
})
}
}
func TestConvertLdapIdToString(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "valid UTF-8 string",
input: "simple-utf8-id",
expected: "simple-utf8-id",
},
{
name: "binary UUID (16 bytes)",
input: string([]byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf1}),
expected: "12345678-9abc-def0-1234-56789abcdef1",
},
{
name: "non-UTF8, non-UUID returns base64",
input: string([]byte{0xff, 0xfe, 0xfd, 0xfc}),
expected: "//79/A==",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := convertLdapIdToString(tt.input)
if got != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, got)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,381 @@
package service
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
// generateTestECDSAKey creates an ECDSA key for testing
func generateTestECDSAKey(t *testing.T) (jwk.Key, []byte) {
t.Helper()
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
privateJwk, err := jwk.Import(privateKey)
require.NoError(t, err)
err = privateJwk.Set(jwk.KeyIDKey, "test-key-1")
require.NoError(t, err)
err = privateJwk.Set(jwk.AlgorithmKey, "ES256")
require.NoError(t, err)
err = privateJwk.Set("use", "sig")
require.NoError(t, err)
publicJwk, err := jwk.PublicKeyOf(privateJwk)
require.NoError(t, err)
// Create a JWK Set with the public key
jwkSet := jwk.NewSet()
err = jwkSet.AddKey(publicJwk)
require.NoError(t, err)
jwkSetJSON, err := json.Marshal(jwkSet)
require.NoError(t, err)
return privateJwk, jwkSetJSON
}
func TestOidcService_jwkSetForURL(t *testing.T) {
// Generate a test key for JWKS
_, jwkSetJSON1 := generateTestECDSAKey(t)
_, jwkSetJSON2 := generateTestECDSAKey(t)
// Create a mock HTTP client with responses for different URLs
const (
url1 = "https://example.com/.well-known/jwks.json"
url2 = "https://other-issuer.com/jwks"
)
mockResponses := map[string]*http.Response{
//nolint:bodyclose
url1: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON1)),
//nolint:bodyclose
url2: testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON2)),
}
httpClient := &http.Client{
Transport: &testutils.MockRoundTripper{
Responses: mockResponses,
},
}
// Create the OidcService with our mock client
s := &OidcService{
httpClient: httpClient,
}
var err error
s.jwkCache, err = s.getJWKCache(t.Context())
require.NoError(t, err)
t.Run("Fetches and caches JWK set", func(t *testing.T) {
jwks, err := s.jwkSetForURL(t.Context(), url1)
require.NoError(t, err)
require.NotNil(t, jwks)
// Verify the JWK set contains our key
require.Equal(t, 1, jwks.Len())
})
t.Run("Fails with invalid URL", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
defer cancel()
_, err := s.jwkSetForURL(ctx, "https://bad-url.com")
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
})
t.Run("Safe for concurrent use", func(t *testing.T) {
const concurrency = 20
// Channel to collect errors
errChan := make(chan error, concurrency)
// Start concurrent requests
for range concurrency {
go func() {
jwks, err := s.jwkSetForURL(t.Context(), url2)
if err != nil {
errChan <- err
return
}
// Verify the JWK set is valid
if jwks == nil || jwks.Len() != 1 {
errChan <- assert.AnError
return
}
errChan <- nil
}()
}
// Check for errors
for range concurrency {
assert.NoError(t, <-errChan, "Concurrent JWK set fetching should not produce errors")
}
})
}
func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
const (
federatedClientIssuer = "https://external-idp.com"
federatedClientAudience = "https://pocket-id.com"
federatedClientIssuerDefaults = "https://external-idp-defaults.com/"
)
var err error
// Create a test database
db := testutils.NewDatabaseForTest(t)
// Create two JWKs for testing
privateJWK, jwkSetJSON := generateTestECDSAKey(t)
require.NoError(t, err)
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
require.NoError(t, err)
// Create a mock HTTP client with custom transport to return the JWKS
httpClient := &http.Client{
Transport: &testutils.MockRoundTripper{
Responses: map[string]*http.Response{
//nolint:bodyclose
federatedClientIssuer + "/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSON)),
//nolint:bodyclose
federatedClientIssuerDefaults + ".well-known/jwks.json": testutils.NewMockResponse(http.StatusOK, string(jwkSetJSONDefaults)),
},
},
}
// Init the OidcService
s := &OidcService{
db: db,
httpClient: httpClient,
}
s.jwkCache, err = s.getJWKCache(t.Context())
require.NoError(t, err)
// Create the test clients
// 1. Confidential client
confidentialClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Confidential Client",
CallbackURLs: []string{"https://example.com/callback"},
}, "test-user-id")
require.NoError(t, err)
// Create a client secret for the confidential client
confidentialSecret, err := s.CreateClientSecret(t.Context(), confidentialClient.ID)
require.NoError(t, err)
// 2. Public client
publicClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Public Client",
CallbackURLs: []string{"https://example.com/callback"},
IsPublic: true,
}, "test-user-id")
require.NoError(t, err)
// 3. Confidential client with federated identity
federatedClient, err := s.CreateClient(t.Context(), dto.OidcClientCreateDto{
Name: "Federated Client",
CallbackURLs: []string{"https://example.com/callback"},
}, "test-user-id")
require.NoError(t, err)
federatedClient, err = s.UpdateClient(t.Context(), federatedClient.ID, dto.OidcClientCreateDto{
Name: federatedClient.Name,
CallbackURLs: federatedClient.CallbackURLs,
Credentials: dto.OidcClientCredentialsDto{
FederatedIdentities: []dto.OidcClientFederatedIdentityDto{
{
Issuer: federatedClientIssuer,
Audience: federatedClientAudience,
Subject: federatedClient.ID,
JWKS: federatedClientIssuer + "/jwks.json",
},
{Issuer: federatedClientIssuerDefaults},
},
},
})
require.NoError(t, err)
// Test cases for confidential client (using client secret)
t.Run("Confidential client", func(t *testing.T) {
t.Run("Succeeds with valid secret", func(t *testing.T) {
// Test with valid client credentials
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: confidentialSecret,
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, confidentialClient.ID, client.ID)
})
t.Run("Fails with invalid secret", func(t *testing.T) {
// Test with invalid client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
ClientSecret: "invalid-secret",
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
assert.Nil(t, client)
})
t.Run("Fails with missing secret", func(t *testing.T) {
// Test with missing client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: confidentialClient.ID,
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
})
})
// Test cases for public client
t.Run("Public client", func(t *testing.T) {
t.Run("Succeeds with no credentials", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, publicClient.ID, client.ID)
})
t.Run("Fails with no credentials if allowPublicClientsWithoutAuth is false", func(t *testing.T) {
// Public clients don't require client secret
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: publicClient.ID,
}, false)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
assert.Nil(t, client)
})
})
// Test cases for federated client using JWT assertion
t.Run("Federated client", func(t *testing.T) {
t.Run("Succeeds with valid JWT", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)
})
t.Run("Fails with malformed JWT", func(t *testing.T) {
// Test with invalid JWT assertion (just a random string)
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: "invalid.jwt.token",
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
assert.Nil(t, client)
})
testBadJWT := func(builderFn func(builder *jwt.Builder)) func(t *testing.T) {
return func(t *testing.T) {
// Populate all claims with valid values
builder := jwt.NewBuilder().
Issuer(federatedClientIssuer).
Audience([]string{federatedClientAudience}).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute))
// Call builderFn to override the claims
builderFn(builder)
token, err := builder.Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
require.NoError(t, err)
// Test with invalid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
}, true)
require.Error(t, err)
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
require.Nil(t, client)
}
}
t.Run("Fails with expired JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Expiration(time.Now().Add(-30 * time.Minute))
}))
t.Run("Fails with wrong issuer in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Issuer("https://bad-issuer.com")
}))
t.Run("Fails with wrong audience in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Audience([]string{"bad-audience"})
}))
t.Run("Fails with wrong subject in JWT", testBadJWT(func(builder *jwt.Builder) {
builder.Subject("bad-subject")
}))
t.Run("Uses default values for audience and subject", func(t *testing.T) {
// Create JWT for federated identity
token, err := jwt.NewBuilder().
Issuer(federatedClientIssuerDefaults).
Audience([]string{common.EnvConfig.AppURL}).
Subject(federatedClient.ID).
IssuedAt(time.Now()).
Expiration(time.Now().Add(10 * time.Minute)).
Build()
require.NoError(t, err)
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWKDefaults))
require.NoError(t, err)
// Test with valid JWT assertion
client, err := s.verifyClientCredentialsInternal(t.Context(), s.db, ClientAuthCredentials{
ClientID: federatedClient.ID,
ClientAssertionType: ClientAssertionTypeJWTBearer,
ClientAssertion: string(signedToken),
}, true)
require.NoError(t, err)
require.NotNil(t, client)
assert.Equal(t, federatedClient.ID, client.ID)
})
})
}

View File

@@ -6,13 +6,14 @@ import (
"errors"
"fmt"
"io"
"log"
"log/slog"
"net/url"
"os"
"strings"
"time"
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -33,7 +34,13 @@ type UserService struct {
}
func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, emailService *EmailService, appConfigService *AppConfigService) *UserService {
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
return &UserService{
db: db,
jwtService: jwtService,
auditLogService: auditLogService,
emailService: emailService,
appConfigService: appConfigService,
}
}
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
@@ -45,7 +52,8 @@ func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPa
if searchTerm != "" {
searchPattern := "%" + searchTerm + "%"
query = query.Where("email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
query = query.Where(
"email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
searchPattern, searchPattern, searchPattern, searchPattern)
}
@@ -118,13 +126,14 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
defaultPictureBytes := defaultPicture.Bytes()
go func() {
// Ensure the directory exists
err = os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
if err != nil {
log.Printf("Failed to create directory for default profile picture: %v", err)
errInternal := os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
if errInternal != nil {
slog.Error("Failed to create directory for default profile picture", slog.Any("error", errInternal))
return
}
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
log.Printf("Failed to cache default profile picture for initials %s: %v", user.Initials(), err)
errInternal = utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath)
if errInternal != nil {
slog.Error("Failed to cache default profile picture for initials", slog.String("initials", user.Initials()), slog.Any("error", errInternal))
}
}()
@@ -296,15 +305,21 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
isLdapUser := user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue()
allowOwnAccountEdit := s.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue()
// For LDAP users or if own account editing is not allowed, only allow updating the locale unless it's an LDAP sync
if !isLdapSync && (isLdapUser || (!allowOwnAccountEdit && !updateOwnUser)) {
if !isLdapSync && (isLdapUser || (!allowOwnAccountEdit && updateOwnUser)) {
// Restricted update: Only locale can be changed when:
// - User is from LDAP, OR
// - User is editing their own account but global setting disallows self-editing
// (Exception: LDAP sync operations can update everything)
user.Locale = updatedUser.Locale
} else {
// Full update: Allow updating all personal fields
user.FirstName = updatedUser.FirstName
user.LastName = updatedUser.LastName
user.Email = updatedUser.Email
user.Username = updatedUser.Username
user.Locale = updatedUser.Locale
// Admin-only fields: Only allow updates when not updating own account
if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin
user.Disabled = updatedUser.Disabled
@@ -387,7 +402,8 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
innerCtx := context.Background()
span := trace.SpanFromContext(ctx)
innerCtx := trace.ContextWithSpan(context.Background(), span)
link := common.EnvConfig.AppURL + "/lc"
linkWithCode := link + "/" + oneTimeAccessToken
@@ -408,7 +424,8 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
ExpirationString: utils.DurationToString(time.Until(expiration).Round(time.Second)),
})
if errInternal != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", user.Email))
return
}
}()
@@ -463,9 +480,7 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
return model.User{}, "", err
}
if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
}
s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
err = tx.Commit().Error
if err != nil {
@@ -523,7 +538,7 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
return user, nil
}
func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
func (s *UserService) SignUpInitialAdmin(ctx context.Context, signUpData dto.SignUpDto) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -533,26 +548,23 @@ func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string
if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err
}
if userCount > 1 {
if userCount != 0 {
return model.User{}, "", &common.SetupAlreadyCompletedError{}
}
user := model.User{
FirstName: "Admin",
LastName: "Admin",
Username: "admin",
Email: "admin@admin.com",
userToCreate := dto.UserCreateDto{
FirstName: signUpData.FirstName,
LastName: signUpData.LastName,
Username: signUpData.Username,
Email: signUpData.Email,
IsAdmin: true,
}
if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
if err != nil {
return model.User{}, "", err
}
if len(user.Credentials) > 0 {
return model.User{}, "", &common.SetupAlreadyCompletedError{}
}
token, err := s.jwtService.GenerateAccessToken(user)
if err != nil {
return model.User{}, "", err
@@ -630,6 +642,110 @@ func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx
Error
}
func (s *UserService) CreateSignupToken(ctx context.Context, expiresAt time.Time, usageLimit int) (model.SignupToken, error) {
return s.createSignupTokenInternal(ctx, expiresAt, usageLimit, s.db)
}
func (s *UserService) createSignupTokenInternal(ctx context.Context, expiresAt time.Time, usageLimit int, tx *gorm.DB) (model.SignupToken, error) {
signupToken, err := NewSignupToken(expiresAt, usageLimit)
if err != nil {
return model.SignupToken{}, err
}
if err := tx.WithContext(ctx).Create(signupToken).Error; err != nil {
return model.SignupToken{}, err
}
return *signupToken, nil
}
func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
tokenProvided := signupData.Token != ""
config := s.appConfigService.GetDbConfig()
if config.AllowUserSignups.Value != "open" && !tokenProvided {
return model.User{}, "", &common.OpenSignupDisabledError{}
}
var signupToken model.SignupToken
if tokenProvided {
err := tx.
WithContext(ctx).
Where("token = ?", signupData.Token).
First(&signupToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
}
return model.User{}, "", err
}
if !signupToken.IsValid() {
return model.User{}, "", &common.TokenInvalidOrExpiredError{}
}
}
userToCreate := dto.UserCreateDto{
Username: signupData.Username,
Email: signupData.Email,
FirstName: signupData.FirstName,
LastName: signupData.LastName,
}
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
if err != nil {
return model.User{}, "", err
}
accessToken, err := s.jwtService.GenerateAccessToken(user)
if err != nil {
return model.User{}, "", err
}
if tokenProvided {
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
"signupToken": signupToken.Token,
}, tx)
signupToken.UsageCount++
err = tx.WithContext(ctx).Save(&signupToken).Error
if err != nil {
return model.User{}, "", err
}
} else {
s.auditLogService.Create(ctx, model.AuditLogEventAccountCreated, ipAddress, userAgent, user.ID, model.AuditLogData{
"method": "open_signup",
}, tx)
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return user, accessToken, nil
}
func (s *UserService) ListSignupTokens(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.SignupToken, utils.PaginationResponse, error) {
var tokens []model.SignupToken
query := s.db.WithContext(ctx).Model(&model.SignupToken{})
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &tokens)
return tokens, pagination, err
}
func (s *UserService) DeleteSignupToken(ctx context.Context, tokenID string) error {
return s.db.WithContext(ctx).Delete(&model.SignupToken{}, "id = ?", tokenID).Error
}
func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAccessToken, error) {
// If expires at is less than 15 minutes, use a 6-character token instead of 16
tokenLength := 16
@@ -650,3 +766,20 @@ func NewOneTimeAccessToken(userID string, expiresAt time.Time) (*model.OneTimeAc
return o, nil
}
func NewSignupToken(expiresAt time.Time, usageLimit int) (*model.SignupToken, error) {
// Generate a random token
randomString, err := utils.GenerateRandomAlphanumericString(16)
if err != nil {
return nil, err
}
token := &model.SignupToken{
Token: randomString,
ExpiresAt: datatype.DateTime(expiresAt),
UsageLimit: usageLimit,
UsageCount: 0,
}
return token, nil
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
@@ -24,8 +25,8 @@ type WebAuthnService struct {
appConfigService *AppConfigService
}
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) (*WebAuthnService, error) {
wa, err := webauthn.New(&webauthn.Config{
RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL},
@@ -44,15 +45,18 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
TimeoutUVD: time.Second * 60,
},
},
})
if err != nil {
return nil, fmt.Errorf("failed to init webauthn object: %w", err)
}
wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{
db: db,
webAuthn: wa,
jwtService: jwtService,
auditLogService: auditLogService,
appConfigService: appConfigService,
}
}, nil
}
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
@@ -70,8 +74,7 @@ func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string)
Find(&user, "id = ?", userID).
Error
if err != nil {
tx.Rollback()
return nil, err
return nil, fmt.Errorf("failed to load user: %w", err)
}
options, session, err := s.webAuthn.BeginRegistration(
@@ -80,7 +83,7 @@ func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string)
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to begin WebAuthn registration: %w", err)
}
sessionToStore := &model.WebauthnSession{
@@ -94,12 +97,12 @@ func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string)
Create(&sessionToStore).
Error
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to save WebAuthn session: %w", err)
}
err = tx.Commit().Error
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
return &model.PublicKeyCredentialCreationOptions{
@@ -115,13 +118,15 @@ func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, use
tx.Rollback()
}()
// Load & delete the session row
var storedSession model.WebauthnSession
err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Clauses(clause.Returning{}).
Delete(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.WebauthnCredential{}, err
return model.WebauthnCredential{}, fmt.Errorf("failed to load WebAuthn session: %w", err)
}
session := webauthn.SessionData{
@@ -136,12 +141,12 @@ func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, use
Find(&user, "id = ?", userID).
Error
if err != nil {
return model.WebauthnCredential{}, err
return model.WebauthnCredential{}, fmt.Errorf("failed to load user: %w", err)
}
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
if err != nil {
return model.WebauthnCredential{}, err
return model.WebauthnCredential{}, fmt.Errorf("failed to finish WebAuthn registration: %w", err)
}
// Determine passkey name using AAGUID and User-Agent
@@ -162,12 +167,12 @@ func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, use
Create(&credentialToStore).
Error
if err != nil {
return model.WebauthnCredential{}, err
return model.WebauthnCredential{}, fmt.Errorf("failed to store WebAuthn credential: %w", err)
}
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, err
return model.WebauthnCredential{}, fmt.Errorf("failed to commit transaction: %w", err)
}
return credentialToStore, nil

View File

@@ -4,7 +4,7 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"log/slog"
"sync"
"github.com/pocket-id/pocket-id/backend/resources"
@@ -57,12 +57,13 @@ func loadAAGUIDsFromFile() {
// Read from embedded file system
data, err := resources.FS.ReadFile("aaguids.json")
if err != nil {
log.Printf("Error reading embedded AAGUID file: %v", err)
slog.Error("Error reading embedded AAGUID file", slog.Any("error", err))
return
}
if err := json.Unmarshal(data, &aaguidMap); err != nil {
log.Printf("Error unmarshalling AAGUID data: %v", err)
err = json.Unmarshal(data, &aaguidMap)
if err != nil {
slog.Error("Error unmarshalling AAGUID data", slog.Any("error", err))
return
}
}

View File

@@ -0,0 +1,24 @@
package utils
import (
"bufio"
"fmt"
"os"
"strings"
)
// PromptForConfirmation prompts the user to answer "y" in the terminal
func PromptForConfirmation(prompt string) (bool, error) {
fmt.Print(prompt + " [y/N]: ")
reader := bufio.NewReader(os.Stdin)
r, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("failed to read response: %w", err)
}
r = strings.TrimSpace(strings.ToLower(r))
ok := r == "yes" || r == "y"
return ok, nil
}

View File

@@ -0,0 +1,69 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
)
// ErrDecrypt is returned by Decrypt when the operation failed for any reason
var ErrDecrypt = errors.New("failed to decrypt data")
// Encrypt a byte slice using AES-GCM and a random nonce
// Important: do not encrypt more than ~4 billion messages with the same key!
func Encrypt(key []byte, plaintext []byte, associatedData []byte) (ciphertext []byte, err error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create block cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
_, err = io.ReadFull(rand.Reader, nonce)
if err != nil {
return nil, fmt.Errorf("failed to generate random nonce: %w", err)
}
// Allocate the slice for the result, with additional space for the nonce and overhead
ciphertext = make([]byte, 0, len(plaintext)+aead.NonceSize()+aead.Overhead())
ciphertext = append(ciphertext, nonce...)
// Encrypt the plaintext
// Tag is automatically added at the end
ciphertext = aead.Seal(ciphertext, nonce, plaintext, associatedData)
return ciphertext, nil
}
// Decrypt a byte slice using AES-GCM
func Decrypt(key []byte, ciphertext []byte, associatedData []byte) (plaintext []byte, err error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("failed to create block cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
}
// Extract the nonce
if len(ciphertext) < (aead.NonceSize() + aead.Overhead()) {
return nil, ErrDecrypt
}
// Decrypt the data
plaintext, err = aead.Open(nil, ciphertext[:aead.NonceSize()], ciphertext[aead.NonceSize():], associatedData)
if err != nil {
// Note: we do not return the exact error here, to avoid disclosing information
return nil, ErrDecrypt
}
return plaintext, nil
}

View File

@@ -0,0 +1,208 @@
package crypto
import (
"crypto/rand"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEncryptDecrypt(t *testing.T) {
tests := []struct {
name string
keySize int
plaintext string
associatedData []byte
}{
{
name: "AES-128 with short plaintext",
keySize: 16,
plaintext: "Hello, World!",
associatedData: []byte("test-aad"),
},
{
name: "AES-192 with medium plaintext",
keySize: 24,
plaintext: "This is a longer message to test encryption and decryption",
associatedData: []byte("associated-data-192"),
},
{
name: "AES-256 with unicode",
keySize: 32,
plaintext: "Hello 世界! 🌍 Testing unicode characters", //nolint:gosmopolitan
associatedData: []byte("unicode-test"),
},
{
name: "No associated data",
keySize: 32,
plaintext: "Testing without associated data",
associatedData: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Generate random key
key := make([]byte, tt.keySize)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
plaintext := []byte(tt.plaintext)
// Test encryption
ciphertext, err := Encrypt(key, plaintext, tt.associatedData)
require.NoError(t, err, "Encrypt should succeed")
// Verify ciphertext is different from plaintext (unless empty)
if len(plaintext) > 0 {
assert.NotEqual(t, plaintext, ciphertext)
}
// Test decryption
decrypted, err := Decrypt(key, ciphertext, tt.associatedData)
require.NoError(t, err, "Decrypt should succeed")
// Verify decrypted text matches original
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original")
})
}
}
func TestEncryptWithInvalidKeySize(t *testing.T) {
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
for _, keySize := range invalidKeySizes {
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
key := make([]byte, keySize)
plaintext := []byte("test message")
_, err := Encrypt(key, plaintext, nil)
require.Error(t, err)
assert.ErrorContains(t, err, "invalid key size")
})
}
}
func TestDecryptWithInvalidKeySize(t *testing.T) {
invalidKeySizes := []int{8, 12, 33, 47, 55, 128}
for _, keySize := range invalidKeySizes {
t.Run(fmt.Sprintf("Key size %d", keySize), func(t *testing.T) {
key := make([]byte, keySize)
ciphertext := []byte("fake ciphertext")
_, err := Decrypt(key, ciphertext, nil)
require.Error(t, err)
assert.ErrorContains(t, err, "invalid key size")
})
}
}
func TestDecryptWithInvalidCiphertext(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
tests := []struct {
name string
ciphertext []byte
}{
{
name: "empty ciphertext",
ciphertext: []byte{},
},
{
name: "too short ciphertext",
ciphertext: []byte("short"),
},
{
name: "random invalid data",
ciphertext: []byte("this is not valid encrypted data"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := Decrypt(key, tt.ciphertext, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
})
}
}
func TestDecryptWithWrongKey(t *testing.T) {
// Generate two different keys
key1 := make([]byte, 32)
key2 := make([]byte, 32)
_, err := rand.Read(key1)
require.NoError(t, err)
_, err = rand.Read(key2)
require.NoError(t, err)
plaintext := []byte("secret message")
// Encrypt with key1
ciphertext, err := Encrypt(key1, plaintext, nil)
require.NoError(t, err)
// Try to decrypt with key2
_, err = Decrypt(key2, ciphertext, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
}
func TestDecryptWithWrongAssociatedData(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err, "Failed to generate random key")
plaintext := []byte("secret message")
correctAAD := []byte("correct-aad")
wrongAAD := []byte("wrong-aad")
// Encrypt with correct AAD
ciphertext, err := Encrypt(key, plaintext, correctAAD)
require.NoError(t, err)
// Try to decrypt with wrong AAD
_, err = Decrypt(key, ciphertext, wrongAAD)
require.Error(t, err)
require.ErrorIs(t, err, ErrDecrypt)
// Verify correct AAD works
decrypted, err := Decrypt(key, ciphertext, correctAAD)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted, "Decrypted text should match original when using correct AAD")
}
func TestEncryptDecryptConsistency(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
plaintext := []byte("consistency test message")
associatedData := []byte("test-aad")
// Encrypt multiple times and verify we get different ciphertexts (due to random IV)
ciphertext1, err := Encrypt(key, plaintext, associatedData)
require.NoError(t, err)
ciphertext2, err := Encrypt(key, plaintext, associatedData)
require.NoError(t, err)
// Ciphertexts should be different (due to random IV)
assert.NotEqual(t, ciphertext1, ciphertext2, "Multiple encryptions of same plaintext should produce different ciphertexts")
// Both should decrypt to the same plaintext
decrypted1, err := Decrypt(key, ciphertext1, associatedData)
require.NoError(t, err)
decrypted2, err := Decrypt(key, ciphertext2, associatedData)
require.NoError(t, err)
assert.Equal(t, plaintext, decrypted1, "First decrypted text should match original")
assert.Equal(t, plaintext, decrypted2, "Second decrypted text should match original")
assert.Equal(t, decrypted1, decrypted2, "Both decrypted texts should be identical")
}

View File

@@ -0,0 +1,33 @@
package utils
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// BearerAuth returns the value of the bearer token in the Authorization header if present
func BearerAuth(r *http.Request) (string, bool) {
const prefix = "bearer "
authHeader := r.Header.Get("Authorization")
if len(authHeader) >= len(prefix) && strings.ToLower(authHeader[:len(prefix)]) == prefix {
return authHeader[len(prefix):], true
}
return "", false
}
// SetCacheControlHeader sets the Cache-Control header for the response.
func SetCacheControlHeader(ctx *gin.Context, maxAge, staleWhileRevalidate time.Duration) {
_, ok := ctx.GetQuery("skipCache")
if !ok {
maxAgeSeconds := strconv.Itoa(int(maxAge.Seconds()))
staleWhileRevalidateSeconds := strconv.Itoa(int(staleWhileRevalidate.Seconds()))
ctx.Header("Cache-Control", "public, max-age="+maxAgeSeconds+", stale-while-revalidate="+staleWhileRevalidateSeconds)
}
}

View File

@@ -0,0 +1,65 @@
package utils
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBearerAuth(t *testing.T) {
tests := []struct {
name string
authHeader string
expectedToken string
expectedFound bool
}{
{
name: "Valid bearer token",
authHeader: "Bearer token123",
expectedToken: "token123",
expectedFound: true,
},
{
name: "Valid bearer token with mixed case",
authHeader: "beARer token456",
expectedToken: "token456",
expectedFound: true,
},
{
name: "No bearer prefix",
authHeader: "Basic dXNlcjpwYXNz",
expectedToken: "",
expectedFound: false,
},
{
name: "Empty auth header",
authHeader: "",
expectedToken: "",
expectedFound: false,
},
{
name: "Bearer prefix only",
authHeader: "Bearer ",
expectedToken: "",
expectedFound: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", nil)
require.NoError(t, err, "Failed to create request")
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
token, found := BearerAuth(req)
assert.Equal(t, tt.expectedFound, found)
assert.Equal(t, tt.expectedToken, token)
})
}
}

View File

@@ -29,9 +29,9 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
pr, pw := io.Pipe()
go func() {
err = imaging.Encode(pw, img, imaging.PNG)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
innerErr := imaging.Encode(pw, img, imaging.PNG)
if innerErr != nil {
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", innerErr))
return
}
pw.Close()

View File

@@ -0,0 +1,50 @@
package jwk
import (
"fmt"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
type KeyProviderOpts struct {
EnvConfig *common.EnvConfigSchema
DB *gorm.DB
Kek []byte
}
type KeyProvider interface {
Init(opts KeyProviderOpts) error
LoadKey() (jwk.Key, error)
SaveKey(key jwk.Key) error
}
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
// Load the encryption key (KEK) if present
kek, err := LoadKeyEncryptionKey(envConfig, instanceID)
if err != nil {
return nil, fmt.Errorf("failed to load encryption key: %w", err)
}
// Get the key provider
switch envConfig.KeysStorage {
case "file", "":
keyProvider = &KeyProviderFile{}
case "database":
keyProvider = &KeyProviderDatabase{}
default:
return nil, fmt.Errorf("invalid key storage '%s'", envConfig.KeysStorage)
}
err = keyProvider.Init(KeyProviderOpts{
DB: db,
EnvConfig: envConfig,
Kek: kek,
})
if err != nil {
return nil, fmt.Errorf("failed to init key provider of type '%s': %w", envConfig.KeysStorage, err)
}
return keyProvider, nil
}

View File

@@ -0,0 +1,109 @@
package jwk
import (
"context"
"encoding/base64"
"errors"
"fmt"
"time"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
const PrivateKeyDBKey = "jwt_private_key.json"
type KeyProviderDatabase struct {
db *gorm.DB
kek []byte
}
func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
if len(opts.Kek) == 0 {
return errors.New("an encryption key is required when using the 'database' key provider")
}
f.db = opts.DB
f.kek = opts.Kek
return nil
}
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
row := model.KV{
Key: PrivateKeyDBKey,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err = f.db.WithContext(ctx).First(&row).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// Key not present in the database - return nil so a new one can be generated
return nil, nil
} else if err != nil {
return nil, fmt.Errorf("failed to retrieve private key from the database: %w", err)
}
if row.Value == nil || *row.Value == "" {
// Key not present in the database - return nil so a new one can be generated
return nil, nil
}
// Decode from base64
enc, err := base64.StdEncoding.DecodeString(*row.Value)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key: not a valid base64-encoded value: %w", err)
}
// Decrypt the data
data, err := cryptoutils.Decrypt(f.kek, enc, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt private key: %w", err)
}
// Parse the key
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse encrypted private key: %w", err)
}
return key, nil
}
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
// Encode the key to JSON
data, err := EncodeJWKBytes(key)
if err != nil {
return fmt.Errorf("failed to encode key to JSON: %w", err)
}
// Encrypt the key then encode to Base64
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
if err != nil {
return fmt.Errorf("failed to encrypt key: %w", err)
}
encB64 := base64.StdEncoding.EncodeToString(enc)
// Save to database
row := model.KV{
Key: PrivateKeyDBKey,
Value: &encB64,
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err = f.db.WithContext(ctx).Create(&row).Error
if err != nil {
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
return fmt.Errorf("failed to store private key in database: %w", err)
}
return nil
}
// Compile-time interface check
var _ KeyProvider = (*KeyProviderDatabase)(nil)

View File

@@ -0,0 +1,275 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"testing"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/model"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
func TestKeyProviderDatabase_Init(t *testing.T) {
t.Run("Init fails when KEK is not provided", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: nil, // No KEK
})
require.Error(t, err, "Expected error when KEK is not provided")
require.ErrorContains(t, err, "encryption key is required")
})
t.Run("Init succeeds with KEK", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: generateTestKEK(t),
})
require.NoError(t, err, "Expected no error when KEK is provided")
})
}
func TestKeyProviderDatabase_LoadKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("LoadKey with no existing key", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
})
t.Run("LoadKey with existing key", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Save a key
err = provider.SaveKey(key)
require.NoError(t, err)
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey with invalid base64", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Insert invalid base64 data
invalidBase64 := "not-valid-base64"
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &invalidBase64,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading key with invalid base64")
require.ErrorContains(t, err, "not a valid base64-encoded value")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with invalid encrypted data", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Insert valid base64 but invalid encrypted data
invalidData := base64.StdEncoding.EncodeToString([]byte("not-valid-encrypted-data"))
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &invalidData,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with valid encrypted data but wrong KEK", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
originalKek := generateTestKEK(t)
// Save a key with the original KEK
originalProvider := &KeyProviderDatabase{}
err := originalProvider.Init(KeyProviderOpts{
DB: db,
Kek: originalKek,
})
require.NoError(t, err)
err = originalProvider.SaveKey(key)
require.NoError(t, err)
// Now try to load with a different KEK
differentKek := generateTestKEK(t)
differentProvider := &KeyProviderDatabase{}
err = differentProvider.Init(KeyProviderOpts{
DB: db,
Kek: differentKek,
})
require.NoError(t, err)
// Attempt to load the key with the wrong KEK
loadedKey, err := differentProvider.LoadKey()
require.Error(t, err, "Expected error when loading key with wrong KEK")
require.ErrorContains(t, err, "failed to decrypt")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
t.Run("LoadKey with invalid key data", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Create invalid key data (valid JSON but not a valid JWK)
invalidKeyData := []byte(`{"not": "a valid jwk"}`)
// Encrypt the invalid key data
encryptedData, err := cryptoutils.Encrypt(kek, invalidKeyData, nil)
require.NoError(t, err)
// Base64 encode the encrypted data
encodedData := base64.StdEncoding.EncodeToString(encryptedData)
// Save to database
err = db.Create(&model.KV{
Key: PrivateKeyDBKey,
Value: &encodedData,
}).Error
require.NoError(t, err)
// Attempt to load the key
loadedKey, err := provider.LoadKey()
require.Error(t, err, "Expected error when loading invalid key data")
require.ErrorContains(t, err, "failed to parse")
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
})
}
func TestKeyProviderDatabase_SaveKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("SaveKey and verify database record", func(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
kek := generateTestKEK(t)
provider := &KeyProviderDatabase{}
err := provider.Init(KeyProviderOpts{
DB: db,
Kek: kek,
})
require.NoError(t, err)
// Save the key
err = provider.SaveKey(key)
require.NoError(t, err, "Expected no error when saving key")
// Verify record exists in database
var kv model.KV
err = db.Where("key = ?", PrivateKeyDBKey).First(&kv).Error
require.NoError(t, err, "Expected to find key in database")
require.NotNil(t, kv.Value, "Expected non-nil value in database")
assert.NotEmpty(t, *kv.Value, "Expected non-empty value in database")
// Decode and decrypt to verify content
encBytes, err := base64.StdEncoding.DecodeString(*kv.Value)
require.NoError(t, err, "Expected valid base64 encoding")
decBytes, err := cryptoutils.Decrypt(kek, encBytes, nil)
require.NoError(t, err, "Expected valid encrypted data")
parsedKey, err := jwk.ParseKey(decBytes)
require.NoError(t, err, "Expected valid JWK data")
// Compare keys
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
})
}
func generateTestKEK(t *testing.T) []byte {
t.Helper()
// Generate a 32-byte kek
kek := make([]byte, 32)
_, err := rand.Read(kek)
require.NoError(t, err)
return kek
}

View File

@@ -0,0 +1,202 @@
package jwk
import (
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
const (
// PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// PrivateKeyFileEncrypted is the path in the data/keys folder where the encrypted key is stored
// This is a encrypted JSON file containing a key encoded as JWK
PrivateKeyFileEncrypted = "jwt_private_key.json.enc"
)
type KeyProviderFile struct {
envConfig *common.EnvConfigSchema
kek []byte
}
func (f *KeyProviderFile) Init(opts KeyProviderOpts) error {
f.envConfig = opts.EnvConfig
f.kek = opts.Kek
return nil
}
func (f *KeyProviderFile) LoadKey() (jwk.Key, error) {
if len(f.kek) > 0 {
return f.loadEncryptedKey()
}
return f.loadKey()
}
func (f *KeyProviderFile) SaveKey(key jwk.Key) error {
if len(f.kek) > 0 {
return f.saveKeyEncrypted(key)
}
return f.saveKey(key)
}
func (f *KeyProviderFile) loadKey() (jwk.Key, error) {
var key jwk.Key
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := f.jwkPath()
ok, err := utils.FileExists(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if private key file exists at path '%s': %w", jwkPath, err)
}
if !ok {
// File doesn't exist, no key was loaded
return nil, nil
}
data, err := os.ReadFile(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read private key file at path '%s': %w", jwkPath, err)
}
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse private key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) loadEncryptedKey() (key jwk.Key, err error) {
// First, check if we have an encrypted JWK file
// If we do, then we just load that
encJwkPath := f.encJwkPath()
ok, err := utils.FileExists(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to check if encrypted private key file exists at path '%s': %w", encJwkPath, err)
}
if ok {
encB64, err := os.ReadFile(encJwkPath)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': %w", encJwkPath, err)
}
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
if err != nil {
return nil, fmt.Errorf("failed to read encrypted private key file at path '%s': not a valid base64-encoded file: %w", encJwkPath, err)
}
// Decrypt the data
data, err := cryptoutils.Decrypt(f.kek, enc[:n], nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt private key file at path '%s': %w", encJwkPath, err)
}
// Parse the key
key, err = jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse encrypted private key file at path '%s': %w", encJwkPath, err)
}
return key, nil
}
// Check if we have an un-encrypted JWK file
key, err = f.loadKey()
if err != nil {
return nil, fmt.Errorf("failed to load un-encrypted key file: %w", err)
}
if key == nil {
// No key exists, encrypted or un-encrypted
return nil, nil
}
// If we are here, we have loaded a key that was un-encrypted
// We need to replace the plaintext key with the encrypted one before we return
err = f.saveKeyEncrypted(key)
if err != nil {
return nil, fmt.Errorf("failed to save encrypted key file: %w", err)
}
jwkPath := f.jwkPath()
err = os.Remove(jwkPath)
if err != nil {
return nil, fmt.Errorf("failed to remove un-encrypted key file at path '%s': %w", jwkPath, err)
}
return key, nil
}
func (f *KeyProviderFile) saveKey(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", f.envConfig.KeysPath, err)
}
jwkPath := f.jwkPath()
keyFile, err := os.OpenFile(jwkPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file at path '%s': %w", jwkPath, err)
}
defer keyFile.Close()
// Write the JSON file to disk
err = EncodeJWK(keyFile, key)
if err != nil {
return fmt.Errorf("failed to write key file at path '%s': %w", jwkPath, err)
}
return nil
}
func (f *KeyProviderFile) saveKeyEncrypted(key jwk.Key) error {
err := os.MkdirAll(f.envConfig.KeysPath, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for encrypted key file: %w", f.envConfig.KeysPath, err)
}
// Encode the key to JSON
data, err := EncodeJWKBytes(key)
if err != nil {
return fmt.Errorf("failed to encode key to JSON: %w", err)
}
// Encrypt the key then encode to Base64
enc, err := cryptoutils.Encrypt(f.kek, data, nil)
if err != nil {
return fmt.Errorf("failed to encrypt key: %w", err)
}
encB64 := make([]byte, base64.StdEncoding.EncodedLen(len(enc)))
base64.StdEncoding.Encode(encB64, enc)
// Write to disk
encJwkPath := f.encJwkPath()
err = os.WriteFile(encJwkPath, encB64, 0600)
if err != nil {
return fmt.Errorf("failed to write encrypted key file at path '%s': %w", encJwkPath, err)
}
return nil
}
func (f *KeyProviderFile) jwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFile)
}
func (f *KeyProviderFile) encJwkPath() string {
return filepath.Join(f.envConfig.KeysPath, PrivateKeyFileEncrypted)
}
// Compile-time interface check
var _ KeyProvider = (*KeyProviderFile)(nil)

View File

@@ -0,0 +1,320 @@
package jwk
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"os"
"path/filepath"
"testing"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils"
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
)
func TestKeyProviderFile_LoadKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("LoadKey with no existing key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with no existing key (with kek)", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Load key when none exists
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.Nil(t, loadedKey, "Expected nil key when no key exists")
})
t.Run("LoadKey with unencrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save a key
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey with encrypted key", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: makeKEK(t),
})
require.NoError(t, err)
// Save a key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Make sure the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Make sure the unencrypted key file does not exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Load the key
loadedKey, err := provider.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when encrypted key exists")
// Verify the loaded key is the same as the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key")
})
t.Run("LoadKey replaces unencrypted key with encrypted key when kek is provided", func(t *testing.T) {
tempDir := t.TempDir()
// First, create an unencrypted key
providerNoKek := &KeyProviderFile{}
err := providerNoKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save an unencrypted key
err = providerNoKek.SaveKey(key)
require.NoError(t, err)
// Verify unencrypted key exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected unencrypted key file to exist")
// Now create a provider with a kek
kek := make([]byte, 32)
_, err = rand.Read(kek)
require.NoError(t, err)
providerWithKek := &KeyProviderFile{}
err = providerWithKek.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Load the key - this should convert the unencrypted key to encrypted
loadedKey, err := providerWithKek.LoadKey()
require.NoError(t, err)
assert.NotNil(t, loadedKey, "Expected non-nil key when loading and converting key")
// Verify the unencrypted key no longer exists
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to be removed")
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err = utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist after conversion")
// Verify the key data
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
loadedKeyBytes, err := EncodeJWKBytes(loadedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, loadedKeyBytes, "Expected loaded key to match original key after conversion")
})
}
func TestKeyProviderFile_SaveKey(t *testing.T) {
// Generate a test key to use in our tests
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
key, err := jwk.Import(pk)
require.NoError(t, err)
t.Run("SaveKey unencrypted", func(t *testing.T) {
tempDir := t.TempDir()
provider := &KeyProviderFile{}
err := provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
})
require.NoError(t, err)
// Save the key
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the key file exists
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err := utils.FileExists(keyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected key file to exist")
// Verify the content of the key file
data, err := os.ReadFile(keyPath)
require.NoError(t, err)
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the saved key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected saved key to match original key")
})
t.Run("SaveKey encrypted", func(t *testing.T) {
tempDir := t.TempDir()
// Generate a 64-byte kek
kek := makeKEK(t)
provider := &KeyProviderFile{}
err = provider.Init(KeyProviderOpts{
EnvConfig: &common.EnvConfigSchema{
KeysPath: tempDir,
},
Kek: kek,
})
require.NoError(t, err)
// Save the key (will be encrypted)
err = provider.SaveKey(key)
require.NoError(t, err)
// Verify the encrypted key file exists
encKeyPath := filepath.Join(tempDir, PrivateKeyFileEncrypted)
exists, err := utils.FileExists(encKeyPath)
require.NoError(t, err)
assert.True(t, exists, "Expected encrypted key file to exist")
// Verify the unencrypted key file doesn't exist
keyPath := filepath.Join(tempDir, PrivateKeyFile)
exists, err = utils.FileExists(keyPath)
require.NoError(t, err)
assert.False(t, exists, "Expected unencrypted key file to not exist")
// Manually decrypt the encrypted key file to verify it contains the correct key
encB64, err := os.ReadFile(encKeyPath)
require.NoError(t, err)
// Decode from base64
enc := make([]byte, base64.StdEncoding.DecodedLen(len(encB64)))
n, err := base64.StdEncoding.Decode(enc, encB64)
require.NoError(t, err)
enc = enc[:n] // Trim any padding
// Decrypt the data
data, err := cryptoutils.Decrypt(kek, enc, nil)
require.NoError(t, err)
// Parse the key
parsedKey, err := jwk.ParseKey(data)
require.NoError(t, err)
// Compare the decrypted key with the original
keyBytes, err := EncodeJWKBytes(key)
require.NoError(t, err)
parsedKeyBytes, err := EncodeJWKBytes(parsedKey)
require.NoError(t, err)
assert.Equal(t, keyBytes, parsedKeyBytes, "Expected decrypted key to match original key")
})
}
func makeKEK(t *testing.T) []byte {
t.Helper()
// Generate a 32-byte kek
kek := make([]byte, 32)
_, err := rand.Read(kek)
require.NoError(t, err)
return kek
}

View File

@@ -0,0 +1,168 @@
package jwk
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha3"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
const (
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
)
// EncodeJWK encodes a jwk.Key to a writable stream.
func EncodeJWK(w io.Writer, key jwk.Key) error {
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
return enc.Encode(key)
}
// EncodeJWKBytes encodes a jwk.Key to a byte slice.
func EncodeJWKBytes(key jwk.Key) ([]byte, error) {
b := &bytes.Buffer{}
err := EncodeJWK(b, key)
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
// LoadKeyEncryptionKey loads the key encryption key for JWKs
func LoadKeyEncryptionKey(envConfig *common.EnvConfigSchema, instanceID string) (kek []byte, err error) {
// If there's no key, return
if len(envConfig.EncryptionKey) == 0 {
return nil, nil
}
// We need a 256-bit key for encryption with AES-GCM-256
// We use HMAC with SHA3-256 here to derive the key from the one passed as input
// The key is tied to a specific instance of Pocket ID
h := hmac.New(func() hash.Hash { return sha3.New256() }, []byte(envConfig.EncryptionKey))
fmt.Fprint(h, "pocketid/"+instanceID+"/jwk-kek")
kek = h.Sum(nil)
return kek, nil
}
// ImportRawKey imports a crypto key in "raw" format (e.g. crypto.PrivateKey) into a jwk.Key.
// It also populates additional fields such as the key ID, usage, and alg.
func ImportRawKey(rawKey any, alg string, crv string) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key, alg, crv)
return key, nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter (and "crv", if needed), set depending on the key type
func EnsureAlgInKey(key jwk.Key, alg string, crv string) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
if alg != "" {
_ = key.Set(jwk.AlgorithmKey, alg)
if crv != "" {
eca, ok := jwa.LookupEllipticCurveAlgorithm(crv)
if ok {
switch key.KeyType() {
case jwa.EC():
_ = key.Set(jwk.ECDSACrvKey, eca)
case jwa.OKP():
_ = key.Set(jwk.OKPCrvKey, eca)
}
}
}
return
}
// If we don't have an algorithm, set the default for the key type
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
_ = key.Set(jwk.ECDSACrvKey, jwa.P256())
case jwa.OKP():
// Default to EdDSA and Ed25519 for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
_ = key.Set(jwk.OKPCrvKey, jwa.Ed25519())
}
}
// GenerateKey generates a new jwk.Key
func GenerateKey(alg string, crv string) (key jwk.Key, err error) {
var rawKey any
switch alg {
case jwa.RS256().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 2048)
case jwa.RS384().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 3072)
case jwa.RS512().String():
rawKey, err = rsa.GenerateKey(rand.Reader, 4096)
case jwa.ES256().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case jwa.ES384().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case jwa.ES512().String():
rawKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
case jwa.EdDSA().String():
switch crv {
case jwa.Ed25519().String():
_, rawKey, err = ed25519.GenerateKey(rand.Reader)
default:
return nil, errors.New("unsupported curve for EdDSA algorithm")
}
default:
return nil, errors.New("unsupported key algorithm")
}
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
// Import the raw key
return ImportRawKey(rawKey, alg, crv)
}

View File

@@ -0,0 +1,324 @@
package jwk
import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"testing"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGenerateKey(t *testing.T) {
tests := []struct {
name string
alg string
crv string
expectError bool
expectedAlg jwa.SignatureAlgorithm
}{
{
name: "RS256",
alg: jwa.RS256().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS256(),
},
{
name: "RS384",
alg: jwa.RS384().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS384(),
},
// Skip the RS512 test as generating a RSA-4096 key can take some time
/* {
name: "RS512",
alg: jwa.RS512().String(),
crv: "",
expectError: false,
expectedAlg: jwa.RS512(),
}, */
{
name: "ES256",
alg: jwa.ES256().String(),
crv: jwa.P256().String(),
expectError: false,
expectedAlg: jwa.ES256(),
},
{
name: "ES384",
alg: jwa.ES384().String(),
crv: jwa.P384().String(),
expectError: false,
expectedAlg: jwa.ES384(),
},
{
name: "ES512",
alg: jwa.ES512().String(),
crv: jwa.P521().String(),
expectError: false,
expectedAlg: jwa.ES512(),
},
{
name: "EdDSA with Ed25519",
alg: jwa.EdDSA().String(),
crv: jwa.Ed25519().String(),
expectError: false,
expectedAlg: jwa.EdDSA(),
},
{
name: "EdDSA with unsupported curve",
alg: jwa.EdDSA().String(),
crv: "unsupported",
expectError: true,
},
{
name: "Unsupported algorithm",
alg: "UNSUPPORTED",
crv: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
key, err := GenerateKey(tt.alg, tt.crv)
if tt.expectError {
require.Error(t, err)
assert.Nil(t, key)
return
}
require.NoError(t, err)
require.NotNil(t, key)
// Verify the algorithm is set correctly
alg, ok := key.Algorithm()
require.True(t, ok, "algorithm should be set in the key")
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify other required fields are set
kid, ok := key.KeyID()
assert.True(t, ok, "key ID should be set")
assert.NotEmpty(t, kid, "key ID should not be empty")
usage, ok := key.KeyUsage()
assert.True(t, ok, "key usage should be set")
assert.Equal(t, KeyUsageSigning, usage)
var crv any
_ = key.Get("crv", &crv)
// Verify key type matches expected algorithm
switch tt.expectedAlg {
case jwa.RS256(), jwa.RS384(), jwa.RS512():
assert.Equal(t, jwa.RSA(), key.KeyType())
assert.Nil(t, crv)
case jwa.ES256(), jwa.ES384(), jwa.ES512():
assert.Equal(t, jwa.EC(), key.KeyType())
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
_ = assert.NotNil(t, crv) &&
assert.True(t, ok) &&
assert.Equal(t, tt.crv, eca.String())
case jwa.EdDSA():
assert.Equal(t, jwa.OKP(), key.KeyType())
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
_ = assert.NotNil(t, crv) &&
assert.True(t, ok) &&
assert.Equal(t, tt.crv, eca.String())
}
})
}
}
func TestEnsureAlgInKey(t *testing.T) {
// Generate an RSA-2048 key
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
t.Run("does not change alg already set", func(t *testing.T) {
// Import the RSA key
key, err := jwk.Import(rsaKey)
require.NoError(t, err)
// Pre-set the algorithm
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
// Call EnsureAlgInKey with a different algorithm
EnsureAlgInKey(key, jwa.RS384().String(), "")
// Verify the algorithm wasn't changed
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String())
})
t.Run("set algorithm to explicitly-provided value", func(t *testing.T) {
tests := []struct {
name string
keyGen func() (any, error)
alg string
crv string
expectedAlg jwa.SignatureAlgorithm
expectedCrv string
}{
{
name: "RSA key with RS384",
keyGen: func() (any, error) {
return rsaKey, nil
},
alg: jwa.RS384().String(),
crv: "",
expectedAlg: jwa.RS384(),
expectedCrv: "",
},
{
name: "ECDSA key with ES384",
keyGen: func() (any, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
},
alg: jwa.ES384().String(),
crv: jwa.P384().String(),
expectedAlg: jwa.ES384(),
expectedCrv: jwa.P384().String(),
},
{
name: "Ed25519 key with EdDSA",
keyGen: func() (any, error) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
return priv, err
},
alg: jwa.EdDSA().String(),
crv: jwa.Ed25519().String(),
expectedAlg: jwa.EdDSA(),
expectedCrv: jwa.Ed25519().String(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rawKey, err := tt.keyGen()
require.NoError(t, err)
key, err := jwk.Import(rawKey)
require.NoError(t, err)
// Ensure no algorithm is set initially
_, ok := key.Algorithm()
assert.False(t, ok)
// Call EnsureAlgInKey
EnsureAlgInKey(key, tt.alg, tt.crv)
// Verify the algorithm was set correctly
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify curve if expected
if tt.expectedCrv != "" {
var crv any
_ = key.Get("crv", &crv)
require.NotNil(t, crv)
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
require.True(t, ok)
assert.Equal(t, tt.expectedCrv, eca.String())
}
})
}
})
t.Run("set default algorithms if not present", func(t *testing.T) {
tests := []struct {
name string
keyGen func() (any, error)
expectedAlg jwa.SignatureAlgorithm
expectedCrv string
}{
{
name: "RSA key defaults to RS256",
keyGen: func() (any, error) {
return rsaKey, nil
},
expectedAlg: jwa.RS256(),
expectedCrv: "",
},
{
name: "ECDSA key defaults to ES256 with P256",
keyGen: func() (any, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
},
expectedAlg: jwa.ES256(),
expectedCrv: jwa.P256().String(),
},
{
name: "Ed25519 key defaults to EdDSA with Ed25519",
keyGen: func() (any, error) {
_, priv, err := ed25519.GenerateKey(rand.Reader)
return priv, err
},
expectedAlg: jwa.EdDSA(),
expectedCrv: jwa.Ed25519().String(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rawKey, err := tt.keyGen()
require.NoError(t, err)
key, err := jwk.Import(rawKey)
require.NoError(t, err)
// Ensure no algorithm is set initially
_, ok := key.Algorithm()
assert.False(t, ok)
// Call EnsureAlgInKey with empty parameters
EnsureAlgInKey(key, "", "")
// Verify the default algorithm was set
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, tt.expectedAlg.String(), alg.String())
// Verify curve if expected
if tt.expectedCrv != "" {
var crv any
_ = key.Get("crv", &crv)
require.NotNil(t, crv)
eca, ok := crv.(jwa.EllipticCurveAlgorithm)
require.True(t, ok)
assert.Equal(t, tt.expectedCrv, eca.String())
}
})
}
})
t.Run("invalid curve should not set curve parameter", func(t *testing.T) {
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
key, err := jwk.Import(rsaKey)
require.NoError(t, err)
// Call EnsureAlgInKey with invalid curve
EnsureAlgInKey(key, jwa.RS256().String(), "invalid-curve")
// Verify algorithm was set but curve was not
alg, ok := key.Algorithm()
require.True(t, ok)
assert.Equal(t, jwa.RS256().String(), alg.String())
var crv any
_ = key.Get("crv", &crv)
assert.Nil(t, crv)
})
}

View File

@@ -0,0 +1,20 @@
package utils
import (
"fmt"
"github.com/lestrrat-go/jwx/v3/jwt"
)
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
keys := token.Keys()
claims := make(map[string]any, len(keys))
for _, key := range keys {
var value any
if err := token.Get(key, &value); err != nil {
return nil, fmt.Errorf("failed to get claim %s: %w", key, err)
}
claims[key] = value
}
return claims, nil
}

View File

@@ -34,9 +34,12 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
if sortFieldFound && isSortable && isValidSortOrder {
if sort.Direction == "" || (sort.Direction != "asc" && sort.Direction != "desc") {
sort.Direction = "asc"
}
if sortFieldFound && isSortable {
columnName := CamelCaseToSnakeCase(sort.Column)
query = query.Clauses(clause.OrderBy{
Columns: []clause.OrderByColumn{

View File

@@ -2,7 +2,7 @@ package signals
import (
"context"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
@@ -28,11 +28,11 @@ func SignalContext(parentCtx context.Context) context.Context {
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
log.Println("Received interrupt signal. Shutting down…")
slog.Info("Received interrupt signal. Shutting down…")
cancel()
<-sigCh
log.Println("Received a second interrupt signal. Forcing an immediate shutdown.")
slog.Warn("Received a second interrupt signal. Forcing an immediate shutdown.")
os.Exit(1)
}()

Some files were not shown because too many files have changed in this diff Show More