Compare commits

..

80 Commits

Author SHA1 Message Date
Elias Schneider
6004f84845 release: 0.51.0 2025-04-28 11:15:52 +02:00
Alessandro (Ale) Segala
3ec98736cf refactor: graceful shutdown for server (#482) 2025-04-28 11:13:50 +02:00
Elias Schneider
ce24372c57 fix: do not require PKCE for public clients 2025-04-28 11:02:35 +02:00
Elias Schneider
4614769b84 refactor: reorganize imports 2025-04-28 10:49:54 +02:00
Elias Schneider
86d2b5f59f fix: return correct error message if user isn't authorized 2025-04-28 10:39:17 +02:00
Elias Schneider
1efd1d182d fix: hide global audit log switch for non admin users 2025-04-28 10:38:53 +02:00
Elias Schneider
0a24ab8001 fix: updating scopes of an authorized client fails with Postgres 2025-04-28 09:29:18 +02:00
James18232
02cacba5c5 feat: new login code card position for mobile devices (#452)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-28 04:04:48 +00:00
Elias Schneider
38653e2aa4 chore(translations): update translations via Crowdin (#485) 2025-04-27 23:00:37 -05:00
Elias Schneider
8cc9b159a5 release: 0.50.0 2025-04-27 21:58:28 +02:00
James Baker
990c8af3d1 Fix incorrectly swapped refreshToken and accessToken (#490) 2025-04-27 14:09:07 -05:00
Alessandro (Ale) Segala
4c33793678 fix: pass context to methods that were missing it (#487) 2025-04-26 12:32:42 -05:00
Elias Schneider
9e06f70380 chore(translations): update translations via Crowdin (#479)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-25 12:21:59 -05:00
Kyle Mendell
22f7d64bf0 feat: device authorization endpoint (#270)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-25 12:14:51 -05:00
Kyle Mendell
630327c979 feat: make family name optional (#476) 2025-04-25 09:52:09 -05:00
Alessandro (Ale) Segala
662506260e refactor: do not force redirects to happen on the server (#481) 2025-04-24 21:09:52 +02:00
Star_caorui
8e66af627a chore(translations): Add Simplified Chinese translation. (#473) 2025-04-23 18:39:37 +00:00
Alessandro (Ale) Segala
270c30334d fix: prevent deadlock when trying to delete LDAP users (#471) 2025-04-22 15:16:44 +02:00
Elias Schneider
c73c3ceb5e chore(translations): update translations via Crowdin (#468) 2025-04-21 23:02:10 -05:00
Alessandro (Ale) Segala
22725d30f4 fix: do not override XDG_DATA_HOME/XDG_CONFIG_HOME if they are already set (#472) 2025-04-21 22:58:32 -05:00
eiqnepm
76b753f9f2 fix: rootless Caddy data and configuration (#470) 2025-04-21 13:15:51 +02:00
Elias Schneider
453a765107 release: 0.49.0 2025-04-20 20:00:09 +02:00
Elias Schneider
f03645d545 chore(translations): update translations via Crowdin (#467) 2025-04-20 17:59:49 +00:00
Elias Schneider
55273d68c9 chore(translations): fix typo in key 2025-04-20 19:51:12 +02:00
Elias Schneider
4e05b82f02 fix: hide alternative sign in button if user is already authenticated 2025-04-20 19:03:58 +02:00
Elias Schneider
2597907578 refactor: fix type errors 2025-04-20 18:54:45 +02:00
Kyle Mendell
debef9a66b ci/cd: setup caching and improve ci job performance (#465) 2025-04-20 11:48:46 -05:00
Elias Schneider
9122e75101 feat: add ability to disable API key expiration email 2025-04-20 18:41:03 +02:00
Elias Schneider
fe1c4b18cd feat: add ability to send login code via email (#457)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-20 18:32:40 +02:00
Elias Schneider
e571996cb5 fix: disable animations not respected on authorize and logout page 2025-04-20 17:04:00 +02:00
Elias Schneider
fb862d3ec3 chore(translations): update translations via Crowdin (#459)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-20 09:43:27 -05:00
Kyle Mendell
26f01f205b feat: send email to user when api key expires within 7 days (#451)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-20 14:40:20 +00:00
Elias Schneider
c37a3e0ed1 fix: remove limit of 20 callback URLs 2025-04-20 16:32:11 +02:00
Elias Schneider
eb689eb56e feat: add description to callback URL inputs 2025-04-20 00:32:27 +02:00
Elias Schneider
60bad9e985 fix: locale change in dropdown doesn't work on first try 2025-04-20 00:31:33 +02:00
Elias Schneider
e21ee8a871 chore: add kmendell to FUNDING.yml 2025-04-19 18:51:01 +02:00
Elias Schneider
04006eb5cc release: 0.48.0 2025-04-18 18:34:52 +02:00
Elias Schneider
84f1d5c906 fix: user querying fails on global audit log page with Postgres 2025-04-18 18:33:14 +02:00
Elias Schneider
983e989be1 chore(translations): update translations via Crowdin (#456) 2025-04-18 18:21:04 +02:00
Kyle Mendell
c843a60131 feat: disable/enable users (#437)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-18 15:38:50 +00:00
Elias Schneider
56a8b5d0c0 feat: add gif support for logo and background image 2025-04-18 17:31:04 +02:00
Elias Schneider
f0dce41fbc fix: callback URL doesn't get rejected if it starts with a different string 2025-04-17 20:52:58 +02:00
Elias Schneider
0111a58dac fix: add "type" as reserved claim 2025-04-17 20:41:21 +02:00
Elias Schneider
50e4c5c314 chore(translations): update translations via Crowdin (#444) 2025-04-17 20:19:50 +02:00
Kyle Mendell
5a6dfd9e50 fix: profile picture empty for users without first or last name (#449) 2025-04-17 20:19:10 +02:00
Elias Schneider
75fbfee4d8 chore(translations): add Italian 2025-04-17 19:13:47 +02:00
dependabot[bot]
65ee500ef3 chore(deps): bump golang.org/x/net from 0.36.0 to 0.38.0 in /backend in the go_modules group across 1 directory (#450)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-16 18:26:55 -05:00
Elias Schneider
80f108e5d6 release: 0.47.0 2025-04-16 16:32:27 +02:00
Elias Schneider
9b2d622990 tests: adapt JWTs in e2e tests 2025-04-16 16:30:38 +02:00
Elias Schneider
adf74586af fix: define token type as claim for better client compatibility 2025-04-16 15:58:38 +02:00
Kyle Mendell
b45cf68295 feat: disable animations setting toggle (#442)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-15 19:28:10 +00:00
dependabot[bot]
d9dd67c51f chore(deps-dev): bump @sveltejs/kit from 2.16.1 to 2.20.6 in /frontend in the npm_and_yarn group across 1 directory (#443)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-15 20:38:03 +02:00
Grégory Paul
abf17f6211 feat: add qrcode representation of one time link (#424) (#436)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Kyle Mendell <kmendell@outlook.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-14 13:16:46 +00:00
Elias Schneider
57cb8f8795 release: 0.46.0 2025-04-13 20:31:09 +02:00
Elias Schneider
fcb18b8c3c chore(translations): update translations via Crowdin (#427)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-13 20:30:43 +02:00
Alessandro (Ale) Segala
796bc7ed34 fix: improve LDAP error handling (#425)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-12 18:38:19 -04:00
Arne Skaar Fismen
72061ba427 feat(onboarding): Added button when you don't have a passkey added. (#426)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-12 02:27:01 +00:00
dependabot[bot]
d04167cada chore(deps-dev): bump vite from 6.2.5 to 6.2.6 in /frontend in the npm_and_yarn group across 1 directory (#433) 2025-04-11 20:07:40 -05:00
Alessandro (Ale) Segala
f83bab9e17 refactor: simplify app_config service and fix race conditions (#423) 2025-04-10 13:41:22 +02:00
Elias Schneider
4ba68938dd fix: ignore profile picture cache after profile picture gets updated 2025-04-09 15:51:58 +02:00
Elias Schneider
658a9ca6dd fix: add missing rollback for LDAP sync 2025-04-09 14:05:53 +02:00
Andreas Schneider
7e5d16be9b feat: implement token introspection (#405)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-09 07:18:03 +00:00
Elias Schneider
8d6c1e5c08 chore(translations): update translations via Crowdin (#420) 2025-04-09 02:09:01 -05:00
Elias Schneider
ce6e27d0ff refactor: rollback db changes with defer everywhere 2025-04-06 23:40:56 +02:00
Elias Schneider
3ebff09d63 chore(translations): update translations via Crowdin (#416)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-06 22:15:05 +02:00
Elias Schneider
ccc18d716f fix: use UUID for temporary file names 2025-04-06 15:11:19 +02:00
Alessandro (Ale) Segala
ec626ee797 fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-06 15:04:08 +02:00
Kyle Mendell
c810fec8c4 docs: update swagger description to use markdown (#418) 2025-04-05 16:07:56 +02:00
Alessandro (Ale) Segala
9e88926283 fix: ensure indexes on audit_logs table (#415)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-04 17:05:32 +00:00
dependabot[bot]
731113183e chore(deps-dev): bump vite from 6.2.4 to 6.2.5 in /frontend in the npm_and_yarn group across 1 directory (#417)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-04 16:15:37 +00:00
Elias Schneider
4627f365a2 chore(translations): fix mistakes in source strings 2025-04-04 13:55:15 +02:00
Elias Schneider
1762629596 perf: run async operations in parallel in server load functions 2025-04-04 11:39:13 +02:00
Alessandro (Ale) Segala
2f7646105e fix: ensure file descriptors are closed + other bugs (#413) 2025-04-04 10:04:36 +02:00
Elias Schneider
980780e48b chore(translations): update translations via Crowdin (#414) 2025-04-04 09:06:44 +02:00
Kyle Mendell
b65e693e12 feat: global audit log (#320)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-03 10:11:49 -05:00
Kyle Mendell
734c6813ea fix: create reusable default profile pictures (#406)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-03 08:06:56 -05:00
dependabot[bot]
0d31c0ec6c chore(deps-dev): bump vite from 6.2.3 to 6.2.4 in /frontend in the npm_and_yarn group across 1 directory (#410)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-31 14:04:02 -05:00
jose_d
4806c1e09b chore(translations): improve czech translation strings (#408) 2025-03-31 08:22:06 -05:00
Elias Schneider
cf3084cfa8 refactor: remove cors exception from middleware as this is handled by the handler 2025-03-30 22:30:22 +02:00
Kyle Mendell
9881a1df9e feat: modernize ui (#381)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-30 13:19:14 -05:00
207 changed files with 13522 additions and 7188 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1,2 +1,2 @@
# These are supported funding model platforms # These are supported funding model platforms
github: stonith404 github: [stonith404, kmendell]

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v4 uses: actions/setup-go@v5
with: with:
go-version-file: backend/go.mod go-version-file: backend/go.mod

View File

@@ -19,21 +19,28 @@ jobs:
timeout-minutes: 20 timeout-minutes: 20
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Build and export - name: Build and export
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
tags: pocket-id/pocket-id:test push: false
load: false
tags: pocket-id:test
outputs: type=docker,dest=/tmp/docker-image.tar outputs: type=docker,dest=/tmp/docker-image.tar
build-args: BUILD_TAGS=e2etest build-args: BUILD_TAGS=e2etest
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Upload Docker image artifact - name: Upload Docker image artifact
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp/docker-image.tar path: /tmp/docker-image.tar
retention-days: 1
test-sqlite: test-sqlite:
if: github.event.pull_request.head.ref != 'i18n_crowdin' if: github.event.pull_request.head.ref != 'i18n_crowdin'
@@ -47,13 +54,22 @@ jobs:
cache: "npm" cache: "npm"
cache-dependency-path: frontend/package-lock.json cache-dependency-path: frontend/package-lock.json
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Download Docker image artifact - name: Download Docker image artifact
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp path: /tmp
- name: Load Docker Image - name: Load Docker image
run: docker load -i /tmp/docker-image.tar run: docker load -i /tmp/docker-image.tar
- name: Install frontend dependencies - name: Install frontend dependencies
@@ -62,6 +78,7 @@ jobs:
- name: Install Playwright Browsers - name: Install Playwright Browsers
working-directory: ./frontend working-directory: ./frontend
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium run: npx playwright install --with-deps chromium
- name: Run Docker Container with Sqlite DB - name: Run Docker Container with Sqlite DB
@@ -69,7 +86,7 @@ jobs:
docker run -d --name pocket-id-sqlite \ docker run -d --name pocket-id-sqlite \
-p 80:80 \ -p 80:80 \
-e APP_ENV=test \ -e APP_ENV=test \
pocket-id/pocket-id:test pocket-id:test
docker logs -f pocket-id-sqlite &> /tmp/backend.log & docker logs -f pocket-id-sqlite &> /tmp/backend.log &
@@ -77,16 +94,18 @@ jobs:
working-directory: ./frontend working-directory: ./frontend
run: npx playwright test run: npx playwright test
- uses: actions/upload-artifact@v4 - name: Upload Frontend Test Report
if: always() uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: playwright-report-sqlite name: playwright-report-sqlite
path: frontend/tests/.report path: frontend/tests/.report
include-hidden-files: true include-hidden-files: true
retention-days: 15 retention-days: 15
- uses: actions/upload-artifact@v4 - name: Upload Backend Test Report
if: always() uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: backend-sqlite name: backend-sqlite
path: /tmp/backend.log path: /tmp/backend.log
@@ -105,12 +124,39 @@ jobs:
cache: "npm" cache: "npm"
cache-dependency-path: frontend/package-lock.json cache-dependency-path: frontend/package-lock.json
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Cache PostgreSQL Docker image
uses: actions/cache@v3
id: postgres-cache
with:
path: /tmp/postgres-image.tar
key: postgres-17-${{ runner.os }}
- name: Pull and save PostgreSQL image
if: steps.postgres-cache.outputs.cache-hit != 'true'
run: |
docker pull postgres:17
docker save postgres:17 > /tmp/postgres-image.tar
- name: Load PostgreSQL image from cache
if: steps.postgres-cache.outputs.cache-hit == 'true'
run: docker load < /tmp/postgres-image.tar
- name: Download Docker image artifact - name: Download Docker image artifact
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp path: /tmp
- name: Load Docker Image
- name: Load Docker image
run: docker load -i /tmp/docker-image.tar run: docker load -i /tmp/docker-image.tar
- name: Install frontend dependencies - name: Install frontend dependencies
@@ -119,6 +165,7 @@ jobs:
- name: Install Playwright Browsers - name: Install Playwright Browsers
working-directory: ./frontend working-directory: ./frontend
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium run: npx playwright install --with-deps chromium
- name: Create Docker network - name: Create Docker network
@@ -153,7 +200,7 @@ jobs:
-e APP_ENV=test \ -e APP_ENV=test \
-e DB_PROVIDER=postgres \ -e DB_PROVIDER=postgres \
-e DB_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \ -e DB_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
pocket-id/pocket-id:test pocket-id:test
docker logs -f pocket-id-postgres &> /tmp/backend.log & docker logs -f pocket-id-postgres &> /tmp/backend.log &
@@ -161,7 +208,8 @@ jobs:
working-directory: ./frontend working-directory: ./frontend
run: npx playwright test run: npx playwright test
- uses: actions/upload-artifact@v4 - name: Upload Frontend Test Report
uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin' if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: playwright-report-postgres name: playwright-report-postgres
@@ -169,8 +217,9 @@ jobs:
include-hidden-files: true include-hidden-files: true
retention-days: 15 retention-days: 15
- uses: actions/upload-artifact@v4 - name: Upload Backend Test Report
if: always() uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: backend-postgres name: backend-postgres
path: /tmp/backend.log path: /tmp/backend.log

View File

@@ -1 +1 @@
0.45.0 0.51.0

4
.vscode/launch.json vendored
View File

@@ -5,7 +5,7 @@
"name": "Backend", "name": "Backend",
"type": "go", "type": "go",
"request": "launch", "request": "launch",
"envFile": "${workspaceFolder}/backend/.env.example", "envFile": "${workspaceFolder}/backend/cmd/.env",
"env": { "env": {
"APP_ENV": "development" "APP_ENV": "development"
}, },
@@ -16,7 +16,7 @@
"name": "Frontend", "name": "Frontend",
"type": "node", "type": "node",
"request": "launch", "request": "launch",
"envFile": "${workspaceFolder}/frontend/.env.example", "envFile": "${workspaceFolder}/frontend/.env",
"cwd": "${workspaceFolder}/frontend", "cwd": "${workspaceFolder}/frontend",
"runtimeExecutable": "npm", "runtimeExecutable": "npm",
"runtimeArgs": [ "runtimeArgs": [

View File

@@ -1,3 +1,108 @@
## [](https://github.com/pocket-id/pocket-id/compare/v0.50.0...v) (2025-04-28)
### Features
* new login code card position for mobile devices ([#452](https://github.com/pocket-id/pocket-id/issues/452)) ([02cacba](https://github.com/pocket-id/pocket-id/commit/02cacba5c5524481684cb0e1790811df113a9481))
### Bug Fixes
* do not require PKCE for public clients ([ce24372](https://github.com/pocket-id/pocket-id/commit/ce24372c571cc3b277095dc6a4107663d64f45b3))
* hide global audit log switch for non admin users ([1efd1d1](https://github.com/pocket-id/pocket-id/commit/1efd1d182dbb6190d3c7e27034426c9e48781b4a))
* return correct error message if user isn't authorized ([86d2b5f](https://github.com/pocket-id/pocket-id/commit/86d2b5f59f26cb944017826cbd8df915cdc986f1))
* updating scopes of an authorized client fails with Postgres ([0a24ab8](https://github.com/pocket-id/pocket-id/commit/0a24ab80010eb5a15d99915802c6698274a5c57c))
## [](https://github.com/pocket-id/pocket-id/compare/v0.49.0...v) (2025-04-27)
### Features
* device authorization endpoint ([#270](https://github.com/pocket-id/pocket-id/issues/270)) ([22f7d64](https://github.com/pocket-id/pocket-id/commit/22f7d64bf08a5a1ecbe5eee0052453b730f5c360))
* make family name optional ([#476](https://github.com/pocket-id/pocket-id/issues/476)) ([630327c](https://github.com/pocket-id/pocket-id/commit/630327c979de2f931b9d1f0ba0b4a4de1af3fc7c))
### Bug Fixes
* do not override XDG_DATA_HOME/XDG_CONFIG_HOME if they are already set ([#472](https://github.com/pocket-id/pocket-id/issues/472)) ([22725d3](https://github.com/pocket-id/pocket-id/commit/22725d30f4115ffe17625379f56affedfe116778))
* pass context to methods that were missing it ([#487](https://github.com/pocket-id/pocket-id/issues/487)) ([4c33793](https://github.com/pocket-id/pocket-id/commit/4c33793678709eb4981be2c1fd5803bace5f5939))
* prevent deadlock when trying to delete LDAP users ([#471](https://github.com/pocket-id/pocket-id/issues/471)) ([270c303](https://github.com/pocket-id/pocket-id/commit/270c30334dc36f215a67f873283a9d6fcd14d065))
* rootless Caddy data and configuration ([#470](https://github.com/pocket-id/pocket-id/issues/470)) ([76b753f](https://github.com/pocket-id/pocket-id/commit/76b753f9f2a6a4f1af09359530e30844b03ac39b))
## [](https://github.com/pocket-id/pocket-id/compare/v0.48.0...v) (2025-04-20)
### Features
* add ability to disable API key expiration email ([9122e75](https://github.com/pocket-id/pocket-id/commit/9122e75101ad39a40135ccf931eb2bfd351b5db6))
* add ability to send login code via email ([#457](https://github.com/pocket-id/pocket-id/issues/457)) ([fe1c4b1](https://github.com/pocket-id/pocket-id/commit/fe1c4b18cdcc46a4256e0c111b34f1ce00f8e0e1))
* add description to callback URL inputs ([eb689eb](https://github.com/pocket-id/pocket-id/commit/eb689eb56ec9eaf8b0fb1485040e26f841b9225d))
* send email to user when api key expires within 7 days ([#451](https://github.com/pocket-id/pocket-id/issues/451)) ([26f01f2](https://github.com/pocket-id/pocket-id/commit/26f01f205be01fb8abd8c2e564c90c0fc4480ea5))
### Bug Fixes
* disable animations not respected on authorize and logout page ([e571996](https://github.com/pocket-id/pocket-id/commit/e571996cb57d04232c1f47ab337ad656f48bb3cb))
* hide alternative sign in button if user is already authenticated ([4e05b82](https://github.com/pocket-id/pocket-id/commit/4e05b82f02740a4bae07cec6c6a64acd34ca0fc3))
* locale change in dropdown doesn't work on first try ([60bad9e](https://github.com/pocket-id/pocket-id/commit/60bad9e9859d81c9967e6939e1ed10a65145a936))
* remove limit of 20 callback URLs ([c37a3e0](https://github.com/pocket-id/pocket-id/commit/c37a3e0ed177c3bd2b9a618d1f4b0709004478b0))
## [](https://github.com/pocket-id/pocket-id/compare/v0.47.0...v) (2025-04-18)
### Features
* add gif support for logo and background image ([56a8b5d](https://github.com/pocket-id/pocket-id/commit/56a8b5d0c02643f869b77cf8475ddf2f9473880b))
* disable/enable users ([#437](https://github.com/pocket-id/pocket-id/issues/437)) ([c843a60](https://github.com/pocket-id/pocket-id/commit/c843a60131b813177b1e270c4f5d97613c700efa))
### Bug Fixes
* add "type" as reserved claim ([0111a58](https://github.com/pocket-id/pocket-id/commit/0111a58dac0342c5ac2fa25a050e8773810d2b0a))
* callback URL doesn't get rejected if it starts with a different string ([f0dce41](https://github.com/pocket-id/pocket-id/commit/f0dce41fbc5649b3a8fe65de36ca20efa521b880))
* profile picture empty for users without first or last name ([#449](https://github.com/pocket-id/pocket-id/issues/449)) ([5a6dfd9](https://github.com/pocket-id/pocket-id/commit/5a6dfd9e505f4c84e91b4b378b082fab10e8a8a8))
* user querying fails on global audit log page with Postgres ([84f1d5c](https://github.com/pocket-id/pocket-id/commit/84f1d5c906ec3f9a74ad3d2f36526eea847af5dd))
## [](https://github.com/pocket-id/pocket-id/compare/v0.46.0...v) (2025-04-16)
### Features
* add qrcode representation of one time link ([#424](https://github.com/pocket-id/pocket-id/issues/424)) ([#436](https://github.com/pocket-id/pocket-id/issues/436)) ([abf17f6](https://github.com/pocket-id/pocket-id/commit/abf17f62114a2de549b62cec462b9b0659ee23a7))
* disable animations setting toggle ([#442](https://github.com/pocket-id/pocket-id/issues/442)) ([b45cf68](https://github.com/pocket-id/pocket-id/commit/b45cf68295975f51777dab95950b98b8db0a9ae5))
### Bug Fixes
* define token type as claim for better client compatibility ([adf7458](https://github.com/pocket-id/pocket-id/commit/adf74586afb6ef9a00fb122c150b0248c5bc23f0))
## [](https://github.com/pocket-id/pocket-id/compare/v0.45.0...v) (2025-04-13)
### Features
* global audit log ([#320](https://github.com/pocket-id/pocket-id/issues/320)) ([b65e693](https://github.com/pocket-id/pocket-id/commit/b65e693e12be2e7e4cb75a74d6fd43bacb3f6a94))
* implement token introspection ([#405](https://github.com/pocket-id/pocket-id/issues/405)) ([7e5d16b](https://github.com/pocket-id/pocket-id/commit/7e5d16be9bdfccfa113924547e313886681d11bb))
* modernize ui ([#381](https://github.com/pocket-id/pocket-id/issues/381)) ([9881a1d](https://github.com/pocket-id/pocket-id/commit/9881a1df9efe32608ab116db71c0e4f66dae171c))
* **onboarding:** Added button when you don't have a passkey added. ([#426](https://github.com/pocket-id/pocket-id/issues/426)) ([72061ba](https://github.com/pocket-id/pocket-id/commit/72061ba4278a007437cee3a205c3076d58bde644))
### Bug Fixes
* add missing rollback for LDAP sync ([658a9ca](https://github.com/pocket-id/pocket-id/commit/658a9ca6dd8d2304ff3639a000bab02e91ff68a6))
* create reusable default profile pictures ([#406](https://github.com/pocket-id/pocket-id/issues/406)) ([734c681](https://github.com/pocket-id/pocket-id/commit/734c6813eaef166235ae801747e3652d17ae0e2a))
* ensure file descriptors are closed + other bugs ([#413](https://github.com/pocket-id/pocket-id/issues/413)) ([2f76461](https://github.com/pocket-id/pocket-id/commit/2f7646105e26423f47cbe49dae97e40c4a01a025))
* ensure indexes on audit_logs table ([#415](https://github.com/pocket-id/pocket-id/issues/415)) ([9e88926](https://github.com/pocket-id/pocket-id/commit/9e88926283a7a663bfc7fd4f4aa16bd02f614176))
* ignore profile picture cache after profile picture gets updated ([4ba6893](https://github.com/pocket-id/pocket-id/commit/4ba68938dd2a631c633fcb65d8c35cb039d3f59c))
* improve LDAP error handling ([#425](https://github.com/pocket-id/pocket-id/issues/425)) ([796bc7e](https://github.com/pocket-id/pocket-id/commit/796bc7ed3453839b1dc8d846b71fe9fac9a2d646))
* use transactions when operations involve multiple database queries ([#392](https://github.com/pocket-id/pocket-id/issues/392)) ([ec626ee](https://github.com/pocket-id/pocket-id/commit/ec626ee7977306539fd1d70cc9091590f0a54af6))
* use UUID for temporary file names ([ccc18d7](https://github.com/pocket-id/pocket-id/commit/ccc18d716f16a7ef1775d30982e2ba7b5ff159a6))
### Performance Improvements
* run async operations in parallel in server load functions ([1762629](https://github.com/pocket-id/pocket-id/commit/17626295964244c5582806bd0f413da2c799d5ad))
## [](https://github.com/pocket-id/pocket-id/compare/v0.44.0...v) (2025-03-29) ## [](https://github.com/pocket-id/pocket-id/compare/v0.44.0...v) (2025-03-29)

View File

@@ -11,7 +11,7 @@ RUN npm run build
RUN npm prune --production RUN npm prune --production
# Stage 2: Build Backend # Stage 2: Build Backend
FROM golang:1.23-alpine AS backend-builder FROM golang:1.24-alpine AS backend-builder
ARG BUILD_TAGS ARG BUILD_TAGS
WORKDIR /app/backend WORKDIR /app/backend
COPY ./backend/go.mod ./backend/go.sum ./ COPY ./backend/go.mod ./backend/go.sum ./

View File

@@ -4,6 +4,10 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/bootstrap" "github.com/pocket-id/pocket-id/backend/internal/bootstrap"
) )
// @title Pocket ID API
// @version 1.0
// @description.markdown
func main() { func main() {
bootstrap.Bootstrap() bootstrap.Bootstrap()
} }

View File

@@ -1,6 +1,6 @@
module github.com/pocket-id/pocket-id/backend module github.com/pocket-id/pocket-id/backend
go 1.23.7 go 1.24.0
require ( require (
github.com/caarlos0/env/v11 v11.3.1 github.com/caarlos0/env/v11 v11.3.1
@@ -79,7 +79,7 @@ require (
go.uber.org/atomic v1.11.0 // indirect go.uber.org/atomic v1.11.0 // indirect
golang.org/x/arch v0.13.0 // indirect golang.org/x/arch v0.13.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/net v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.12.0 // indirect golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect

View File

@@ -255,8 +255,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -38,7 +38,6 @@ func initApplicationImages() {
log.Fatalf("Error copying file: %v", err) log.Fatalf("Error copying file: %v", err)
} }
} }
} }
func imageAlreadyExists(fileName string, destinationFiles []os.DirEntry) bool { func imageAlreadyExists(fileName string, destinationFiles []os.DirEntry) bool {
@@ -55,6 +54,11 @@ func imageAlreadyExists(fileName string, destinationFiles []os.DirEntry) bool {
} }
func getImageNameWithoutExtension(fileName string) string { func getImageNameWithoutExtension(fileName string) string {
splitted := strings.Split(fileName, ".") idx := strings.LastIndexByte(fileName, '.')
return strings.Join(splitted[:len(splitted)-1], ".") if idx < 1 {
// No dot found, or fileName starts with a dot
return fileName
}
return fileName[:idx]
} }

View File

@@ -1,19 +1,26 @@
package bootstrap package bootstrap
import ( import (
"context"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils/signals"
) )
func Bootstrap() { func Bootstrap() {
// Get a context that is canceled when the application is stopping
ctx := signals.SignalContext(context.Background())
initApplicationImages() initApplicationImages()
migrateConfigDBConnstring() migrateConfigDBConnstring()
db := newDatabase() db := newDatabase()
appConfigService := service.NewAppConfigService(db) appConfigService := service.NewAppConfigService(ctx, db)
migrateKey() migrateKey()
initRouter(db, appConfigService) initRouter(ctx, db, appConfigService)
} }

View File

@@ -1,8 +1,11 @@
package bootstrap package bootstrap
import ( import (
"context"
"fmt"
"log" "log"
"net" "net"
"net/http"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -19,10 +22,14 @@ import (
// This is used to register additional controllers for tests // This is used to register additional controllers for tests
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService)
// @title Pocket ID API func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) {
// @version 1 err := initRouterInternal(ctx, db, appConfigService)
// @description API for Pocket ID if err != nil {
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { log.Fatalf("failed to init router: %v", err)
}
}
func initRouterInternal(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) error {
// Set the appropriate Gin mode based on the environment // Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv { switch common.EnvConfig.AppEnv {
case "production": case "production":
@@ -39,10 +46,10 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Initialize services // Initialize services
emailService, err := service.NewEmailService(appConfigService, db) emailService, err := service.NewEmailService(appConfigService, db)
if err != nil { if err != nil {
log.Fatalf("Unable to create email service: %s", err) return fmt.Errorf("unable to create email service: %w", err)
} }
geoLiteService := service.NewGeoLiteService() geoLiteService := service.NewGeoLiteService(ctx)
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService) auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService) jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
@@ -51,7 +58,7 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
userGroupService := service.NewUserGroupService(db, appConfigService) userGroupService := service.NewUserGroupService(db, appConfigService)
ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService) ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService)
apiKeyService := service.NewApiKeyService(db) apiKeyService := service.NewApiKeyService(db, emailService)
rateLimitMiddleware := middleware.NewRateLimitMiddleware() rateLimitMiddleware := middleware.NewRateLimitMiddleware()
@@ -60,11 +67,33 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60)) r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
job.RegisterLdapJobs(ldapService, appConfigService) scheduler, err := job.NewScheduler()
job.RegisterDbCleanupJobs(db) if err != nil {
return fmt.Errorf("failed to create job scheduler: %w", err)
}
err = scheduler.RegisterLdapJobs(ctx, ldapService, appConfigService)
if err != nil {
return fmt.Errorf("failed to register LDAP jobs in scheduler: %w", err)
}
err = scheduler.RegisterDbCleanupJobs(ctx, db)
if err != nil {
return fmt.Errorf("failed to register DB cleanup jobs in scheduler: %w", err)
}
err = scheduler.RegisterFileCleanupJobs(ctx, db)
if err != nil {
return fmt.Errorf("failed to register file cleanup jobs in scheduler: %w", err)
}
err = scheduler.RegisterApiKeyExpiryJob(ctx, apiKeyService, appConfigService)
if err != nil {
return fmt.Errorf("failed to register API key expiration jobs in scheduler: %w", err)
}
// Run the scheduler in a background goroutine, until the context is canceled
go scheduler.Run(ctx)
// Initialize middleware for specific routes // Initialize middleware for specific routes
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService) authMiddleware := middleware.NewAuthMiddleware(apiKeyService, userService, jwtService)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes // Set up API routes
@@ -89,20 +118,52 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
baseGroup := r.Group("/") baseGroup := r.Group("/")
controller.NewWellKnownController(baseGroup, jwtService) controller.NewWellKnownController(baseGroup, jwtService)
// Get the listener // Set up the server
l, err := net.Listen("tcp", common.EnvConfig.Host+":"+common.EnvConfig.Port) srv := &http.Server{
if err != nil { Addr: net.JoinHostPort(common.EnvConfig.Host, common.EnvConfig.Port),
log.Fatal(err) MaxHeaderBytes: 1 << 20,
ReadHeaderTimeout: 10 * time.Second,
Handler: r,
} }
// Set up the listener
listener, err := net.Listen("tcp", srv.Addr)
if err != nil {
return fmt.Errorf("failed to create TCP listener: %w", err)
}
log.Printf("Server listening on %s", srv.Addr)
// Notify systemd that we are ready // Notify systemd that we are ready
if err := systemd.SdNotifyReady(); err != nil { err = systemd.SdNotifyReady()
log.Println("Unable to notify systemd that the service is ready: ", err) if err != nil {
log.Printf("[WARN] Unable to notify systemd that the service is ready: %v", err)
// continue to serve anyway since it's not that important // continue to serve anyway since it's not that important
} }
// Serve requests // Start the server in a background goroutine
if err := r.RunListener(l); err != nil { go func() {
log.Fatal(err) defer listener.Close()
// Next call blocks until the server is shut down
srvErr := srv.Serve(listener)
if srvErr != http.ErrServerClosed {
log.Fatalf("Error starting app server: %v", srvErr)
} }
}()
// Block until the context is canceled
<-ctx.Done()
// Handle graceful shutdown
// Note we use the background context here as ctx has been canceled already
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
shutdownErr := srv.Shutdown(shutdownCtx) //nolint:contextcheck
shutdownCancel()
if shutdownErr != nil {
// Log the error only (could be context canceled)
log.Printf("[WARN] App server shutdown error: %v", shutdownErr)
}
return nil
} }

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -17,10 +18,16 @@ type AlreadyInUseError struct {
} }
func (e *AlreadyInUseError) Error() string { func (e *AlreadyInUseError) Error() string {
return fmt.Sprintf("%s is already in use", e.Property) return e.Property + " is already in use"
} }
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
func (e *AlreadyInUseError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AlreadyInUseError
x := &AlreadyInUseError{}
return errors.As(target, &x)
}
type SetupAlreadyCompletedError struct{} type SetupAlreadyCompletedError struct{}
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" } func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
@@ -75,11 +82,6 @@ type FileTypeNotSupportedError struct{}
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" } func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 } func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
type InvalidCredentialsError struct{}
func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" }
func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 }
type FileTooLargeError struct { type FileTooLargeError struct {
MaxSize string MaxSize string
} }
@@ -222,8 +224,7 @@ type InvalidUUIDError struct{}
func (e *InvalidUUIDError) Error() string { func (e *InvalidUUIDError) Error() string {
return "Invalid UUID" return "Invalid UUID"
} }
func (e *InvalidUUIDError) HttpStatusCode() int { return http.StatusBadRequest }
type InvalidEmailError struct{}
type OneTimeAccessDisabledError struct{} type OneTimeAccessDisabledError struct{}
@@ -237,31 +238,34 @@ type InvalidAPIKeyError struct{}
func (e *InvalidAPIKeyError) Error() string { func (e *InvalidAPIKeyError) Error() string {
return "Invalid Api Key" return "Invalid Api Key"
} }
func (e *InvalidAPIKeyError) HttpStatusCode() int { return http.StatusUnauthorized }
type NoAPIKeyProvidedError struct{} type NoAPIKeyProvidedError struct{}
func (e *NoAPIKeyProvidedError) Error() string { func (e *NoAPIKeyProvidedError) Error() string {
return "No API Key Provided" return "No API Key Provided"
} }
func (e *NoAPIKeyProvidedError) HttpStatusCode() int { return http.StatusUnauthorized }
type APIKeyNotFoundError struct{} type APIKeyNotFoundError struct{}
func (e *APIKeyNotFoundError) Error() string { func (e *APIKeyNotFoundError) Error() string {
return "API Key Not Found" return "API Key Not Found"
} }
func (e *APIKeyNotFoundError) HttpStatusCode() int { return http.StatusUnauthorized }
type APIKeyExpirationDateError struct{} type APIKeyExpirationDateError struct{}
func (e *APIKeyExpirationDateError) Error() string { func (e *APIKeyExpirationDateError) Error() string {
return "API Key expiration time must be in the future" return "API Key expiration time must be in the future"
} }
func (e *APIKeyExpirationDateError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcInvalidRefreshTokenError struct{} type OidcInvalidRefreshTokenError struct{}
func (e *OidcInvalidRefreshTokenError) Error() string { func (e *OidcInvalidRefreshTokenError) Error() string {
return "refresh token is invalid or expired" return "refresh token is invalid or expired"
} }
func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int { func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest return http.StatusBadRequest
} }
@@ -271,7 +275,6 @@ type OidcMissingRefreshTokenError struct{}
func (e *OidcMissingRefreshTokenError) Error() string { func (e *OidcMissingRefreshTokenError) Error() string {
return "refresh token is required" return "refresh token is required"
} }
func (e *OidcMissingRefreshTokenError) HttpStatusCode() int { func (e *OidcMissingRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest return http.StatusBadRequest
} }
@@ -281,7 +284,63 @@ type OidcMissingAuthorizationCodeError struct{}
func (e *OidcMissingAuthorizationCodeError) Error() string { func (e *OidcMissingAuthorizationCodeError) Error() string {
return "authorization code is required" return "authorization code is required"
} }
func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int { func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int {
return http.StatusBadRequest return http.StatusBadRequest
} }
type UserDisabledError struct{}
func (e *UserDisabledError) Error() string {
return "User account is disabled"
}
func (e *UserDisabledError) HttpStatusCode() int {
return http.StatusForbidden
}
type ValidationError struct {
Message string
}
func (e *ValidationError) Error() string {
return e.Message
}
func (e *ValidationError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcDeviceCodeExpiredError struct{}
func (e *OidcDeviceCodeExpiredError) Error() string {
return "device code has expired"
}
func (e *OidcDeviceCodeExpiredError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcInvalidDeviceCodeError struct{}
func (e *OidcInvalidDeviceCodeError) Error() string {
return "invalid device code"
}
func (e *OidcInvalidDeviceCodeError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcSlowDownError struct{}
func (e *OidcSlowDownError) Error() string {
return "polling too frequently"
}
func (e *OidcSlowDownError) HttpStatusCode() int {
return http.StatusTooManyRequests
}
type OidcAuthorizationPendingError struct{}
func (e *OidcAuthorizationPendingError) Error() string {
return "authorization is still pending"
}
func (e *OidcAuthorizationPendingError) HttpStatusCode() int {
return http.StatusBadRequest
}

View File

@@ -53,7 +53,7 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
return return
} }
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest) apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest)
if err != nil { if err != nil {
_ = ctx.Error(err) _ = ctx.Error(err)
return return
@@ -87,7 +87,7 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
return return
} }
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input) apiKey, token, err := c.apiKeyService.CreateApiKey(ctx.Request.Context(), userID, input)
if err != nil { if err != nil {
_ = ctx.Error(err) _ = ctx.Error(err)
return return
@@ -116,7 +116,7 @@ func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID") userID := ctx.GetString("userID")
apiKeyID := ctx.Param("id") apiKeyID := ctx.Param("id")
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil { if err := c.apiKeyService.RevokeApiKey(ctx.Request.Context(), userID, apiKeyID); err != nil {
_ = ctx.Error(err) _ = ctx.Error(err)
return return
} }

View File

@@ -1,7 +1,6 @@
package controller package controller
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
@@ -61,11 +60,7 @@ type AppConfigController struct {
// @Failure 500 {object} object "{"error": "error message"}" // @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get] // @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(false) configuration := acc.appConfigService.ListAppConfig(false)
if err != nil {
_ = c.Error(err)
return
}
var configVariablesDto []dto.PublicAppConfigVariableDto var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
@@ -73,7 +68,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
return return
} }
c.JSON(200, configVariablesDto) c.JSON(http.StatusOK, configVariablesDto)
} }
// listAllAppConfigHandler godoc // listAllAppConfigHandler godoc
@@ -86,11 +81,7 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/all [get] // @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(true) configuration := acc.appConfigService.ListAppConfig(true)
if err != nil {
_ = c.Error(err)
return
}
var configVariablesDto []dto.AppConfigVariableDto var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
@@ -98,7 +89,7 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
return return
} }
c.JSON(200, configVariablesDto) c.JSON(http.StatusOK, configVariablesDto)
} }
// updateAppConfigHandler godoc // updateAppConfigHandler godoc
@@ -118,7 +109,7 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
return return
} }
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input) savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(c.Request.Context(), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -144,17 +135,17 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
// @Success 200 {file} binary "Logo image" // @Success 200 {file} binary "Logo image"
// @Router /api/application-configuration/logo [get] // @Router /api/application-configuration/logo [get]
func (acc *AppConfigController) getLogoHandler(c *gin.Context) { func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
dbConfig := acc.appConfigService.GetDbConfig()
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true")) lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string var imageName, imageType string
var imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.getImage(c, imageName, imageType) acc.getImage(c, imageName, imageType)
@@ -182,7 +173,7 @@ func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
// @Failure 404 {object} object "{"error": "File not found"}" // @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/background-image [get] // @Router /api/application-configuration/background-image [get]
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.getImage(c, "background", imageType) acc.getImage(c, "background", imageType)
} }
@@ -197,17 +188,17 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/application-configuration/logo [put] // @Router /api/application-configuration/logo [put]
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
dbConfig := acc.appConfigService.GetDbConfig()
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true")) lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string var imageName, imageType string
var imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.updateImage(c, imageName, imageType) acc.updateImage(c, imageName, imageType)
@@ -247,13 +238,13 @@ func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/application-configuration/background-image [put] // @Router /api/application-configuration/background-image [put]
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.updateImage(c, "background", imageType) acc.updateImage(c, "background", imageType)
} }
// getImage is a helper function to serve image files // getImage is a helper function to serve image files
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) { func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType
mimeType := utils.GetImageMimeType(imageType) mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType) c.Header("Content-Type", mimeType)
@@ -268,7 +259,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
return return
} }
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -285,7 +276,7 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
// @Security BearerAuth // @Security BearerAuth
// @Router /api/application-configuration/sync-ldap [post] // @Router /api/application-configuration/sync-ldap [post]
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) { func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
err := acc.ldapService.SyncAll() err := acc.ldapService.SyncAll(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -304,7 +295,7 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
func (acc *AppConfigController) testEmailHandler(c *gin.Context) { func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
err := acc.emailService.SendTestEmail(userID) err := acc.emailService.SendTestEmail(c.Request.Context(), userID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -20,7 +20,10 @@ func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.Audi
auditLogService: auditLogService, auditLogService: auditLogService,
} }
group.GET("/audit-logs/all", authMiddleware.Add(), alc.listAllAuditLogsHandler)
group.GET("/audit-logs", authMiddleware.WithAdminNotRequired().Add(), alc.listAuditLogsForUserHandler) group.GET("/audit-logs", authMiddleware.WithAdminNotRequired().Add(), alc.listAuditLogsForUserHandler)
group.GET("/audit-logs/filters/client-names", authMiddleware.Add(), alc.listClientNamesHandler)
group.GET("/audit-logs/filters/users", authMiddleware.Add(), alc.listUserNamesWithIdsHandler)
} }
type AuditLogController struct { type AuditLogController struct {
@@ -39,7 +42,9 @@ type AuditLogController struct {
// @Router /api/audit-logs [get] // @Router /api/audit-logs [get]
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
err := c.ShouldBindQuery(&sortedPaginationRequest)
if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
@@ -47,7 +52,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
// Fetch audit logs for the user // Fetch audit logs for the user
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest) logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -72,3 +77,86 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
Pagination: pagination, Pagination: pagination,
}) })
} }
// listAllAuditLogsHandler godoc
// @Summary List all audit logs
// @Description Get a paginated list of all audit logs (admin only)
// @Tags Audit Logs
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param user_id query string false "Filter by user ID"
// @Param event query string false "Filter by event type"
// @Param client_name query string false "Filter by client name"
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /api/audit-logs/all [get]
func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
_ = c.Error(err)
return
}
var filters dto.AuditLogFilterDto
if err := c.ShouldBindQuery(&filters); err != nil {
_ = c.Error(err)
return
}
logs, pagination, err := alc.auditLogService.ListAllAuditLogs(c.Request.Context(), sortedPaginationRequest, filters)
if err != nil {
_ = c.Error(err)
return
}
var logsDtos []dto.AuditLogDto
err = dto.MapStructList(logs, &logsDtos)
if err != nil {
_ = c.Error(err)
return
}
for i, logsDto := range logsDtos {
logsDto.Device = alc.auditLogService.DeviceStringFromUserAgent(logs[i].UserAgent)
logsDto.Username = logs[i].User.Username
logsDtos[i] = logsDto
}
c.JSON(http.StatusOK, dto.Paginated[dto.AuditLogDto]{
Data: logsDtos,
Pagination: pagination,
})
}
// listClientNamesHandler godoc
// @Summary List client names
// @Description Get a list of all client names for audit log filtering
// @Tags Audit Logs
// @Success 200 {array} string "List of client names"
// @Router /api/audit-logs/filters/client-names [get]
func (alc *AuditLogController) listClientNamesHandler(c *gin.Context) {
names, err := alc.auditLogService.ListClientNames(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, names)
}
// listUserNamesWithIdsHandler godoc
// @Summary List users with IDs
// @Description Get a list of all usernames with their IDs for audit log filtering
// @Tags Audit Logs
// @Success 200 {object} map[string]string "Map of user IDs to usernames"
// @Router /api/audit-logs/filters/users [get]
func (alc *AuditLogController) listUserNamesWithIdsHandler(c *gin.Context) {
users, err := alc.auditLogService.ListUsernamesWithIds(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, users)
}

View File

@@ -41,7 +41,7 @@ type CustomClaimController struct {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/custom-claims/suggestions [get] // @Router /api/custom-claims/suggestions [get]
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions() claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -69,7 +69,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
} }
userId := c.Param("userId") userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input) claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(c.Request.Context(), userId, input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -104,7 +104,7 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.C
} }
userGroupId := c.Param("userGroupId") userGroupId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input) claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(c.Request.Context(), userGroupId, input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -36,7 +36,7 @@ func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
return return
} }
if err := tc.TestService.ResetAppConfig(); err != nil { if err := tc.TestService.ResetAppConfig(c.Request.Context()); err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -1,6 +1,7 @@
package controller package controller
import ( import (
"errors"
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
@@ -31,6 +32,7 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/oidc/userinfo", oc.userInfoHandler) group.POST("/oidc/userinfo", oc.userInfoHandler)
group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler) group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler) group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.POST("/oidc/introspect", oc.introspectTokenHandler)
group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler) group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler)
group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler) group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler)
@@ -45,6 +47,10 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler) group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler) group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler) group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
} }
type OidcController struct { type OidcController struct {
@@ -69,7 +75,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
return return
} }
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -100,7 +106,7 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
return return
} }
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope) hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(c.Request.Context(), input.ClientID, c.GetString("userID"), input.Scope)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -144,24 +150,26 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
return return
} }
clientID := input.ClientID
clientSecret := input.ClientSecret
// Client id and secret can also be passed over the Authorization header // Client id and secret can also be passed over the Authorization header
if clientID == "" && clientSecret == "" { if input.ClientID == "" && input.ClientSecret == "" {
clientID, clientSecret, _ = c.Request.BasicAuth() input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
} }
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens( idToken, accessToken, refreshToken, expiresIn, err :=
input.Code, oc.oidcService.CreateTokens(c.Request.Context(), input)
input.GrantType,
clientID,
clientSecret,
input.CodeVerifier,
input.RefreshToken,
)
if err != nil { switch {
case errors.Is(err, &common.OidcAuthorizationPendingError{}):
c.JSON(http.StatusBadRequest, gin.H{
"error": "authorization_pending",
})
return
case errors.Is(err, &common.OidcSlowDownError{}):
c.JSON(http.StatusBadRequest, gin.H{
"error": "slow_down",
})
return
case err != nil:
_ = c.Error(err) _ = c.Error(err)
return return
} }
@@ -216,7 +224,7 @@ func (oc *OidcController) userInfoHandler(c *gin.Context) {
_ = c.Error(&common.TokenInvalidError{}) _ = c.Error(&common.TokenInvalidError{})
return return
} }
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0]) claims, err := oc.oidcService.GetUserClaimsForClient(c.Request.Context(), userID, clientID[0])
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -254,7 +262,7 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
} }
} }
callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID")) callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID"))
if err != nil { if err != nil {
// If the validation fails, the user has to confirm the logout manually and doesn't get redirected // If the validation fails, the user has to confirm the logout manually and doesn't get redirected
log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err) log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err)
@@ -290,6 +298,37 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// Implementation is the same as GET // Implementation is the same as GET
} }
// introspectToken godoc
// @Summary Introspect OIDC tokens
// @Description Pass an access_token to verify if it is considered valid.
// @Tags OIDC
// @Produce json
// @Param token formData string true "The token to be introspected."
// @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result."
// @Router /api/oidc/introspect [post]
func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
var input dto.OidcIntrospectDto
if err := c.ShouldBind(&input); err != nil {
_ = c.Error(err)
return
}
// Client id and secret have to be passed over the Authorization header. This kind of
// authentication allows us to keep the endpoint protected (since it could be used to
// find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth()
response, err := oc.oidcService.IntrospectToken(c.Request.Context(), clientID, clientSecret, input.Token)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, response)
}
// getClientMetaDataHandler godoc // getClientMetaDataHandler godoc
// @Summary Get client metadata // @Summary Get client metadata
// @Description Get OIDC client metadata for discovery and configuration // @Description Get OIDC client metadata for discovery and configuration
@@ -300,7 +339,7 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// @Router /api/oidc/clients/{id}/meta [get] // @Router /api/oidc/clients/{id}/meta [get]
func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) { func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
clientId := c.Param("id") clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId) client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -327,7 +366,7 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
// @Router /api/oidc/clients/{id} [get] // @Router /api/oidc/clients/{id} [get]
func (oc *OidcController) getClientHandler(c *gin.Context) { func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id") clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId) client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -363,7 +402,7 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
return return
} }
clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest) clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -398,7 +437,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
return return
} }
client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) client, err := oc.oidcService.CreateClient(c.Request.Context(), input, c.GetString("userID"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -422,7 +461,7 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/oidc/clients/{id} [delete] // @Router /api/oidc/clients/{id} [delete]
func (oc *OidcController) deleteClientHandler(c *gin.Context) { func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id")) err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -449,7 +488,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
return return
} }
client, err := oc.oidcService.UpdateClient(c.Param("id"), input) client, err := oc.oidcService.UpdateClient(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -474,7 +513,7 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/oidc/clients/{id}/secret [post] // @Router /api/oidc/clients/{id}/secret [post]
func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -494,7 +533,7 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
// @Success 200 {file} binary "Logo image" // @Success 200 {file} binary "Logo image"
// @Router /api/oidc/clients/{id}/logo [get] // @Router /api/oidc/clients/{id}/logo [get]
func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -521,7 +560,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
return return
} }
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) err = oc.oidcService.UpdateClientLogo(c.Request.Context(), c.Param("id"), file)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -539,7 +578,7 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [delete] // @Router /api/oidc/clients/{id}/logo [delete]
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id")) err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -566,7 +605,7 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
return return
} }
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input) oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -580,3 +619,60 @@ func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
c.JSON(http.StatusOK, oidcClientDto) c.JSON(http.StatusOK, oidcClientDto)
} }
func (oc *OidcController) deviceAuthorizationHandler(c *gin.Context) {
var input dto.OidcDeviceAuthorizationRequestDto
if err := c.ShouldBind(&input); err != nil {
_ = c.Error(err)
return
}
// Client id and secret can also be passed over the Authorization header
if input.ClientID == "" && input.ClientSecret == "" {
input.ClientID, input.ClientSecret, _ = c.Request.BasicAuth()
}
response, err := oc.oidcService.CreateDeviceAuthorization(c.Request.Context(), input)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, response)
}
func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
userCode := c.Query("code")
if userCode == "" {
_ = c.Error(&common.ValidationError{Message: "code is required"})
return
}
// Get IP address and user agent from the request context
ipAddress := c.ClientIP()
userAgent := c.Request.UserAgent()
err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
userCode := c.Query("code")
if userCode == "" {
_ = c.Error(&common.ValidationError{Message: "code is required"})
return
}
deviceCodeInfo, err := oc.oidcService.GetDeviceCodeInfo(c.Request.Context(), userCode, c.GetString("userID"))
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, deviceCodeInfo)
}

View File

@@ -43,9 +43,10 @@ func NewUserController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/users/me/one-time-access-token", authMiddleware.WithAdminNotRequired().Add(), uc.createOwnOneTimeAccessTokenHandler) group.POST("/users/me/one-time-access-token", authMiddleware.WithAdminNotRequired().Add(), uc.createOwnOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler) group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-email", authMiddleware.Add(), uc.RequestOneTimeAccessEmailAsAdminHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler) group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler) group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler) group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.RequestOneTimeAccessEmailAsUnauthenticatedUserHandler)
group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler) group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler)
group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler) group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler)
@@ -65,7 +66,7 @@ type UserController struct {
// @Router /api/users/{id}/groups [get] // @Router /api/users/{id}/groups [get]
func (uc *UserController) getUserGroupsHandler(c *gin.Context) { func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
groups, err := uc.userService.GetUserGroups(userID) groups, err := uc.userService.GetUserGroups(c.Request.Context(), userID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -99,7 +100,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
return return
} }
users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest) users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -125,7 +126,7 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/users/{id} [get] // @Router /api/users/{id} [get]
func (uc *UserController) getUserHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.Param("id")) user, err := uc.userService.GetUser(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -147,7 +148,7 @@ func (uc *UserController) getUserHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/users/me [get] // @Router /api/users/me [get]
func (uc *UserController) getCurrentUserHandler(c *gin.Context) { func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.GetString("userID")) user, err := uc.userService.GetUser(c.Request.Context(), c.GetString("userID"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -170,7 +171,7 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /api/users/{id} [delete] // @Router /api/users/{id} [delete]
func (uc *UserController) deleteUserHandler(c *gin.Context) { func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.userService.DeleteUser(c.Param("id"), false); err != nil { if err := uc.userService.DeleteUser(c.Request.Context(), c.Param("id"), false); err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
@@ -192,7 +193,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) {
return return
} }
user, err := uc.userService.CreateUser(input) user, err := uc.userService.CreateUser(c.Request.Context(), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -227,7 +228,7 @@ func (uc *UserController) updateUserHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/users/me [put] // @Router /api/users/me [put]
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
if !uc.appConfigService.DbConfig.AllowOwnAccountEdit.IsTrue() { if !uc.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue() {
_ = c.Error(&common.AccountEditNotAllowedError{}) _ = c.Error(&common.AccountEditNotAllowedError{})
return return
} }
@@ -245,13 +246,19 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
picture, size, err := uc.userService.GetProfilePicture(userID) picture, size, err := uc.userService.GetProfilePicture(c.Request.Context(), userID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
if picture != nil {
defer picture.Close()
}
c.Header("Cache-Control", "public, max-age=300") _, ok := c.GetQuery("skipCache")
if !ok {
c.Header("Cache-Control", "public, max-age=900")
}
c.DataFromReader(http.StatusOK, size, "image/png", picture, nil) c.DataFromReader(http.StatusOK, size, "image/png", picture, nil)
} }
@@ -329,7 +336,7 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo
if own { if own {
input.UserID = c.GetString("userID") input.UserID = c.GetString("userID")
} }
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -350,18 +357,63 @@ func (uc *UserController) createOwnOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, true) uc.createOneTimeAccessTokenHandler(c, true)
} }
// createAdminOneTimeAccessTokenHandler godoc
// @Summary Create one-time access token for user (admin)
// @Description Generate a one-time access token for a specific user (admin only)
// @Tags Users
// @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessTokenCreateDto true "Token options"
// @Success 201 {object} object "{ \"token\": \"string\" }"
// @Router /api/users/{id}/one-time-access-token [post]
func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, false) uc.createOneTimeAccessTokenHandler(c, false)
} }
func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) { // RequestOneTimeAccessEmailAsUnauthenticatedUserHandler godoc
var input dto.OneTimeAccessEmailDto // @Summary Request one-time access email
// @Description Request a one-time access email for unauthenticated users
// @Tags Users
// @Accept json
// @Produce json
// @Param body body dto.OneTimeAccessEmailAsUnauthenticatedUserDto true "Email request information"
// @Success 204 "No Content"
// @Router /api/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsUnauthenticatedUserDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath) err := uc.userService.RequestOneTimeAccessEmailAsUnauthenticatedUser(c.Request.Context(), input.Email, input.RedirectPath)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// RequestOneTimeAccessEmailAsAdminHandler godoc
// @Summary Request one-time access email (admin)
// @Description Request a one-time access email for a specific user (admin only)
// @Tags Users
// @Accept json
// @Produce json
// @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessEmailAsAdminDto true "Email request options"
// @Success 204 "No Content"
// @Router /api/users/{id}/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsAdminHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsAdminDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
userID := c.Param("id")
err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, input.ExpiresAt)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -378,7 +430,7 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/{token} [post] // @Router /api/one-time-access-token/{token} [post]
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent()) user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -390,7 +442,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
return return
} }
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -403,7 +455,7 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/setup [post] // @Router /api/one-time-access-token/setup [post]
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.SetupInitialAdmin() user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -415,7 +467,7 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
return return
} }
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -436,7 +488,7 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
return return
} }
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds) user, err := uc.userService.UpdateUserGroups(c.Request.Context(), c.Param("id"), input.UserGroupIds)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -466,7 +518,7 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
userID = c.Param("id") userID = c.Param("id")
} }
user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false) user, err := uc.userService.UpdateUser(c.Request.Context(), userID, input, updateOwnUser, false)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -47,6 +47,8 @@ type UserGroupController struct {
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount] // @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Router /api/user-groups [get] // @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) { func (ugc *UserGroupController) list(c *gin.Context) {
ctx := c.Request.Context()
searchTerm := c.Query("search") searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
@@ -54,7 +56,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
return return
} }
groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest) groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -68,7 +70,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
_ = c.Error(err) _ = c.Error(err)
return return
} }
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID) groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -93,7 +95,7 @@ func (ugc *UserGroupController) list(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/user-groups/{id} [get] // @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) { func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id")) group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -125,7 +127,7 @@ func (ugc *UserGroupController) create(c *gin.Context) {
return return
} }
group, err := ugc.UserGroupService.Create(input) group, err := ugc.UserGroupService.Create(c.Request.Context(), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -158,7 +160,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
return return
} }
group, err := ugc.UserGroupService.Update(c.Param("id"), input, false) group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -184,7 +186,7 @@ func (ugc *UserGroupController) update(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /api/user-groups/{id} [delete] // @Router /api/user-groups/{id} [delete]
func (ugc *UserGroupController) delete(c *gin.Context) { func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil { if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
} }
@@ -210,7 +212,7 @@ func (ugc *UserGroupController) updateUsers(c *gin.Context) {
return return
} }
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs) group, err := ugc.UserGroupService.UpdateUsers(c.Request.Context(), c.Param("id"), input.UserIDs)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -37,7 +37,7 @@ type WebauthnController struct {
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID) options, err := wc.webAuthnService.BeginRegistration(c.Request.Context(), userID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -55,7 +55,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
} }
userID := c.GetString("userID") userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -71,7 +71,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
} }
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin() options, err := wc.webAuthnService.BeginLogin(c.Request.Context())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -94,7 +94,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
return return
} }
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) user, token, err := wc.webAuthnService.VerifyLogin(c.Request.Context(), sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -106,7 +106,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
return return
} }
maxAge := int(wc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds()) maxAge := int(wc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -114,7 +114,7 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID) credentials, err := wc.webAuthnService.ListCredentials(c.Request.Context(), userID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -133,7 +133,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentialID := c.Param("id") credentialID := c.Param("id")
err := wc.webAuthnService.DeleteCredential(userID, credentialID) err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return
@@ -152,7 +152,7 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
return return
} }
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) credential, err := wc.webAuthnService.UpdateCredential(c.Request.Context(), userID, credentialID, input.Name)
if err != nil { if err != nil {
_ = c.Error(err) _ = c.Error(err)
return return

View File

@@ -74,8 +74,10 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
"token_endpoint": appUrl + "/api/oidc/token", "token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo", "userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session", "end_session_endpoint": appUrl + "/api/oidc/end-session",
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
"jwks_uri": appUrl + "/.well-known/jwks.json", "jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token"}, "grant_types_supported": []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"},
"scopes_supported": []string{"openid", "profile", "email", "groups"}, "scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"}, "response_types_supported": []string{"code", "id_token"},

View File

@@ -17,6 +17,7 @@ type ApiKeyDto struct {
ExpiresAt datatype.DateTime `json:"expiresAt"` ExpiresAt datatype.DateTime `json:"expiresAt"`
LastUsedAt *datatype.DateTime `json:"lastUsedAt"` LastUsedAt *datatype.DateTime `json:"lastUsedAt"`
CreatedAt datatype.DateTime `json:"createdAt"` CreatedAt datatype.DateTime `json:"createdAt"`
ExpirationEmailSent bool `json:"expirationEmailSent"`
} }
type ApiKeyResponseDto struct { type ApiKeyResponseDto struct {

View File

@@ -15,8 +15,9 @@ type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"` AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"` SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"` EmailsVerified string `json:"emailsVerified" binding:"required"`
DisableAnimations string `json:"disableAnimations" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"` AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
SmtHost string `json:"smtpHost"` SmtpHost string `json:"smtpHost"`
SmtpPort string `json:"smtpPort"` SmtpPort string `json:"smtpPort"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"` SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpUser string `json:"smtpUser"` SmtpUser string `json:"smtpUser"`
@@ -41,6 +42,9 @@ type AppConfigUpdateDto struct {
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"` LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName string `json:"ldapAttributeGroupName"` LdapAttributeGroupName string `json:"ldapAttributeGroupName"`
LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"` LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"`
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"` LdapSoftDeleteUsers string `json:"ldapSoftDeleteUsers"`
EmailOneTimeAccessAsAdminEnabled string `json:"emailOneTimeAccessAsAdminEnabled" binding:"required"`
EmailOneTimeAccessAsUnauthenticatedEnabled string `json:"emailOneTimeAccessAsUnauthenticatedEnabled" binding:"required"`
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"` EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"`
EmailApiKeyExpirationEnabled string `json:"emailApiKeyExpirationEnabled" binding:"required"`
} }

View File

@@ -15,5 +15,12 @@ type AuditLogDto struct {
City string `json:"city"` City string `json:"city"`
Device string `json:"device"` Device string `json:"device"`
UserID string `json:"userID"` UserID string `json:"userID"`
Username string `json:"username"`
Data model.AuditLogData `json:"data"` Data model.AuditLogData `json:"data"`
} }
type AuditLogFilterDto struct {
UserID string `form:"filters[userId]"`
Event string `form:"filters[event]"`
ClientName string `form:"filters[clientName]"`
}

View File

@@ -49,12 +49,17 @@ type AuthorizationRequiredDto struct {
type OidcCreateTokensDto struct { type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"` GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code"` Code string `form:"code"`
DeviceCode string `form:"device_code"`
ClientID string `form:"client_id"` ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"` ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"` CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"` RefreshToken string `form:"refresh_token"`
} }
type OidcIntrospectDto struct {
Token string `form:"token" binding:"required"`
}
type OidcUpdateAllowedUserGroupsDto struct { type OidcUpdateAllowedUserGroupsDto struct {
UserGroupIDs []string `json:"userGroupIds" binding:"required"` UserGroupIDs []string `json:"userGroupIds" binding:"required"`
} }
@@ -73,3 +78,45 @@ type OidcTokenResponseDto struct {
RefreshToken string `json:"refresh_token,omitempty"` RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
} }
type OidcIntrospectionResponseDto struct {
Active bool `json:"active"`
TokenType string `json:"token_type,omitempty"`
Scope string `json:"scope,omitempty"`
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
Audience []string `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
Identifier string `json:"jti,omitempty"`
}
type OidcDeviceAuthorizationRequestDto struct {
ClientID string `form:"client_id" binding:"required"`
Scope string `form:"scope" binding:"required"`
ClientSecret string `form:"client_secret"`
}
type OidcDeviceAuthorizationResponseDto struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri"`
VerificationURIComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
RequiresAuthorization bool `json:"requires_authorization"`
}
type OidcDeviceTokenRequestDto struct {
GrantType string `form:"grant_type" binding:"required,eq=urn:ietf:params:oauth:grant-type:device_code"`
DeviceCode string `form:"device_code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
}
type DeviceCodeInfoDto struct {
Scope string `json:"scope"`
AuthorizationRequired bool `json:"authorizationRequired"`
Client OidcClientMetaDataDto `json:"client"`
}

View File

@@ -13,15 +13,17 @@ type UserDto struct {
CustomClaims []CustomClaimDto `json:"customClaims"` CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupDto `json:"userGroups"` UserGroups []UserGroupDto `json:"userGroups"`
LdapID *string `json:"ldapId"` LdapID *string `json:"ldapId"`
Disabled bool `json:"disabled"`
} }
type UserCreateDto struct { type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"` Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"` FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"required,min=1,max=50"` LastName string `json:"lastName" binding:"max=50"`
IsAdmin bool `json:"isAdmin"` IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"` Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
LdapID string `json:"-"` LdapID string `json:"-"`
} }
@@ -30,11 +32,15 @@ type OneTimeAccessTokenCreateDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"` ExpiresAt time.Time `json:"expiresAt" binding:"required"`
} }
type OneTimeAccessEmailDto struct { type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"` RedirectPath string `json:"redirectPath"`
} }
type OneTimeAccessEmailAsAdminDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type UserUpdateUserGroupDto struct { type UserUpdateUserGroupDto struct {
UserGroupIds []string `json:"userGroupIds" binding:"required"` UserGroupIds []string `json:"userGroupIds" binding:"required"`
} }

View File

@@ -0,0 +1,45 @@
package job
import (
"context"
"log"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
type ApiKeyEmailJobs struct {
apiKeyService *service.ApiKeyService
appConfigService *service.AppConfigService
}
func (s *Scheduler) RegisterApiKeyExpiryJob(ctx context.Context, apiKeyService *service.ApiKeyService, appConfigService *service.AppConfigService) error {
jobs := &ApiKeyEmailJobs{
apiKeyService: apiKeyService,
appConfigService: appConfigService,
}
return s.registerJob(ctx, "ExpiredApiKeyEmailJob", "0 0 * * *", jobs.checkAndNotifyExpiringApiKeys)
}
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
// Skip if the feature is disabled
if !j.appConfigService.GetDbConfig().EmailApiKeyExpirationEnabled.IsTrue() {
return nil
}
apiKeys, err := j.apiKeyService.ListExpiringApiKeys(ctx, 7)
if err != nil {
log.Printf("Failed to list expiring API keys: %v", err)
return err
}
for _, key := range apiKeys {
if key.User.Email == "" {
continue
}
if err := j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key); err != nil {
log.Printf("Failed to send email for key %s: %v", key.ID, err)
}
}
return nil
}

View File

@@ -1,76 +0,0 @@
package job
import (
"log"
"time"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
)
func RegisterDbCleanupJobs(db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
jobs := &Jobs{db: db}
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
scheduler.Start()
}
type Jobs struct {
db *gorm.DB
}
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
func (j *Jobs) clearWebauthnSessions() error {
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
func (j *Jobs) clearOneTimeAccessTokens() error {
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *Jobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *Jobs) clearOidcRefreshTokens() error {
return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *Jobs) clearAuditLogs() error {
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error
}
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
_, err := scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
gocron.WithEventListeners(
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
log.Printf("Job %q run successfully", name)
}),
gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
log.Printf("Job %q failed with error: %v", name, err)
}),
),
)
if err != nil {
log.Fatalf("Failed to register job %q: %v", name, err)
}
}

View File

@@ -0,0 +1,68 @@
package job
import (
"context"
"errors"
"time"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) error {
jobs := &DbCleanupJobs{db: db}
return errors.Join(
s.registerJob(ctx, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions),
s.registerJob(ctx, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens),
s.registerJob(ctx, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes),
s.registerJob(ctx, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens),
s.registerJob(ctx, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs),
)
}
type DbCleanupJobs struct {
db *gorm.DB
}
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
func (j *DbCleanupJobs) clearWebauthnSessions(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcRefreshTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).
Error
}

View File

@@ -0,0 +1,76 @@
package job
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
)
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) error {
jobs := &FileCleanupJobs{db: db}
return s.registerJob(ctx, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
}
type FileCleanupJobs struct {
db *gorm.DB
}
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context) error {
var users []model.User
err := j.db.
WithContext(ctx).
Find(&users).
Error
if err != nil {
return fmt.Errorf("failed to fetch users: %w", err)
}
// Create a map to track which initials are in use
initialsInUse := make(map[string]struct{})
for _, user := range users {
initialsInUse[user.Initials()] = struct{}{}
}
defaultPicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults"
if _, err := os.Stat(defaultPicturesDir); os.IsNotExist(err) {
return nil
}
files, err := os.ReadDir(defaultPicturesDir)
if err != nil {
return fmt.Errorf("failed to read default profile pictures directory: %w", err)
}
filesDeleted := 0
for _, file := range files {
if file.IsDir() {
continue // Skip directories
}
filename := file.Name()
initials := strings.TrimSuffix(filename, ".png")
// If these initials aren't used by any user, delete the file
if _, ok := initialsInUse[initials]; !ok {
filePath := filepath.Join(defaultPicturesDir, filename)
if err := os.Remove(filePath); err != nil {
log.Printf("Failed to delete unused default profile picture %s: %v", filePath, err)
} else {
filesDeleted++
}
}
}
log.Printf("Deleted %d unused default profile pictures", filesDeleted)
return nil
}

View File

@@ -1,9 +1,9 @@
package job package job
import ( import (
"context"
"log" "log"
"github.com/go-co-op/gocron/v2"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
@@ -12,28 +12,29 @@ type LdapJobs struct {
appConfigService *service.AppConfigService appConfigService *service.AppConfigService
} }
func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *service.AppConfigService) { func (s *Scheduler) RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) error {
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService} jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
// Register the job to run every hour // Register the job to run every hour
registerJob(scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap) err := s.registerJob(ctx, "SyncLdap", "0 * * * *", jobs.syncLdap)
if err != nil {
return err
}
// Run the job immediately on startup // Run the job immediately on startup
if err := jobs.syncLdap(); err != nil { err = jobs.syncLdap(ctx)
log.Printf("Failed to sync LDAP: %s", err) if err != nil {
// Log the error only, but don't return it
log.Printf("Failed to sync LDAP: %v", err)
} }
scheduler.Start()
}
func (j *LdapJobs) syncLdap() error {
if j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return j.ldapService.SyncAll()
}
return nil return nil
} }
func (j *LdapJobs) syncLdap(ctx context.Context) error {
if !j.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return nil
}
return j.ldapService.SyncAll(ctx)
}

View File

@@ -0,0 +1,64 @@
package job
import (
"context"
"fmt"
"log"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
)
type Scheduler struct {
scheduler gocron.Scheduler
}
func NewScheduler() (*Scheduler, error) {
scheduler, err := gocron.NewScheduler()
if err != nil {
return nil, fmt.Errorf("failed to create a new scheduler: %w", err)
}
return &Scheduler{
scheduler: scheduler,
}, nil
}
// Run the scheduler.
// This function blocks until the context is canceled.
func (s *Scheduler) Run(ctx context.Context) {
log.Println("Starting job scheduler")
s.scheduler.Start()
// Block until context is canceled
<-ctx.Done()
err := s.scheduler.Shutdown()
if err != nil {
log.Printf("[WARN] Error shutting down job scheduler: %v", err)
} else {
log.Println("Job scheduler shut down")
}
}
func (s *Scheduler) registerJob(ctx context.Context, name string, interval string, job func(ctx context.Context) error) error {
_, err := s.scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
gocron.WithContext(ctx),
gocron.WithEventListeners(
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
log.Printf("Job %q run successfully", name)
}),
gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
log.Printf("Job %q failed with error: %v", name, err)
}),
),
)
if err != nil {
return fmt.Errorf("failed to register job %q: %w", name, err)
}
return nil
}

View File

@@ -36,12 +36,15 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) { func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
apiKey := c.GetHeader("X-API-KEY") apiKey := c.GetHeader("X-API-KEY")
user, err := m.apiKeyService.ValidateApiKey(apiKey) user, err := m.apiKeyService.ValidateApiKey(c.Request.Context(), apiKey)
if err != nil { if err != nil {
return "", false, &common.NotSignedInError{} return "", false, &common.NotSignedInError{}
} }
// Check if the user is an admin if user.Disabled {
return "", false, &common.UserDisabledError{}
}
if adminRequired && !user.IsAdmin { if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{} return "", false, &common.MissingPermissionError{}
} }

View File

@@ -1,7 +1,10 @@
package middleware package middleware
import ( import (
"errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
@@ -19,11 +22,12 @@ type AuthOptions struct {
func NewAuthMiddleware( func NewAuthMiddleware(
apiKeyService *service.ApiKeyService, apiKeyService *service.ApiKeyService,
userService *service.UserService,
jwtService *service.JwtService, jwtService *service.JwtService,
) *AuthMiddleware { ) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService), apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService),
jwtMiddleware: NewJwtAuthMiddleware(jwtService), jwtMiddleware: NewJwtAuthMiddleware(jwtService, userService),
options: AuthOptions{ options: AuthOptions{
AdminRequired: true, AdminRequired: true,
SuccessOptional: false, SuccessOptional: false,
@@ -57,22 +61,32 @@ func (m *AuthMiddleware) WithSuccessOptional() *AuthMiddleware {
func (m *AuthMiddleware) Add() gin.HandlerFunc { func (m *AuthMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// First try JWT auth
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired) userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
if err == nil { if err == nil {
// JWT auth succeeded, continue with the request
c.Set("userID", userID) c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin) c.Set("userIsAdmin", isAdmin)
if c.IsAborted() {
return
}
c.Next() c.Next()
return return
} }
// If JWT auth failed and the error is not a NotSignedInError, abort the request
if !errors.Is(err, &common.NotSignedInError{}) {
c.Abort()
_ = c.Error(err)
return
}
// JWT auth failed, try API key auth // JWT auth failed, try API key auth
userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired) userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired)
if err == nil { if err == nil {
// API key auth succeeded, continue with the request
c.Set("userID", userID) c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin) c.Set("userIsAdmin", isAdmin)
if c.IsAborted() {
return
}
c.Next() c.Next()
return return
} }

View File

@@ -16,9 +16,10 @@ func NewCorsMiddleware() *CorsMiddleware {
func (m *CorsMiddleware) Add() gin.HandlerFunc { func (m *CorsMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Allow all origins for the token endpoint // Allow all origins for the token endpoint
if c.FullPath() == "/api/oidc/token" { switch c.FullPath() {
case "/api/oidc/token", "/api/oidc/introspect":
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else { default:
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL) c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
} }

View File

@@ -10,11 +10,12 @@ import (
) )
type JwtAuthMiddleware struct { type JwtAuthMiddleware struct {
userService *service.UserService
jwtService *service.JwtService jwtService *service.JwtService
} }
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware { func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.UserService) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService} return &JwtAuthMiddleware{jwtService: jwtService, userService: userService}
} }
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc { func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
@@ -55,12 +56,16 @@ func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject
return return
} }
// Check if the user is an admin user, err := m.userService.GetUser(c, subject)
isAdmin, err = service.GetIsAdmin(token)
if err != nil { if err != nil {
return "", false, &common.TokenInvalidError{} return "", false, &common.NotSignedInError{}
} }
if adminRequired && !isAdmin {
if user.Disabled {
return "", false, &common.UserDisabledError{}
}
if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{} return "", false, &common.MissingPermissionError{}
} }

View File

@@ -10,6 +10,7 @@ type ApiKey struct {
Description *string Description *string
ExpiresAt datatype.DateTime `sortable:"true"` ExpiresAt datatype.DateTime `sortable:"true"`
LastUsedAt *datatype.DateTime `sortable:"true"` LastUsedAt *datatype.DateTime `sortable:"true"`
ExpirationEmailSent bool
UserID string UserID string
User User User User

View File

@@ -1,17 +1,17 @@
package model package model
import ( import (
"errors"
"fmt"
"reflect"
"strconv" "strconv"
"strings"
"time" "time"
) )
type AppConfigVariable struct { type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"` Key string `gorm:"primaryKey;not null"`
Type string
IsPublic bool
IsInternal bool
Value string Value string
DefaultValue string
} }
// IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc. // IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc.
@@ -31,41 +31,159 @@ func (a *AppConfigVariable) AsDurationMinutes() time.Duration {
type AppConfig struct { type AppConfig struct {
// General // General
AppName AppConfigVariable AppName AppConfigVariable `key:"appName,public"` // Public
SessionDuration AppConfigVariable SessionDuration AppConfigVariable `key:"sessionDuration"`
EmailsVerified AppConfigVariable EmailsVerified AppConfigVariable `key:"emailsVerified"`
AllowOwnAccountEdit AppConfigVariable DisableAnimations AppConfigVariable `key:"disableAnimations,public"` // Public
AllowOwnAccountEdit AppConfigVariable `key:"allowOwnAccountEdit,public"` // Public
// Internal // Internal
BackgroundImageType AppConfigVariable BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
LogoLightImageType AppConfigVariable LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
LogoDarkImageType AppConfigVariable LogoDarkImageType AppConfigVariable `key:"logoDarkImageType,internal"` // Internal
// Email // Email
SmtpHost AppConfigVariable SmtpHost AppConfigVariable `key:"smtpHost"`
SmtpPort AppConfigVariable SmtpPort AppConfigVariable `key:"smtpPort"`
SmtpFrom AppConfigVariable SmtpFrom AppConfigVariable `key:"smtpFrom"`
SmtpUser AppConfigVariable SmtpUser AppConfigVariable `key:"smtpUser"`
SmtpPassword AppConfigVariable SmtpPassword AppConfigVariable `key:"smtpPassword"`
SmtpTls AppConfigVariable SmtpTls AppConfigVariable `key:"smtpTls"`
SmtpSkipCertVerify AppConfigVariable SmtpSkipCertVerify AppConfigVariable `key:"smtpSkipCertVerify"`
EmailLoginNotificationEnabled AppConfigVariable EmailLoginNotificationEnabled AppConfigVariable `key:"emailLoginNotificationEnabled"`
EmailOneTimeAccessEnabled AppConfigVariable EmailOneTimeAccessAsUnauthenticatedEnabled AppConfigVariable `key:"emailOneTimeAccessAsUnauthenticatedEnabled,public"` // Public
EmailOneTimeAccessAsAdminEnabled AppConfigVariable `key:"emailOneTimeAccessAsAdminEnabled,public"` // Public
EmailApiKeyExpirationEnabled AppConfigVariable `key:"emailApiKeyExpirationEnabled"`
// LDAP // LDAP
LdapEnabled AppConfigVariable LdapEnabled AppConfigVariable `key:"ldapEnabled,public"` // Public
LdapUrl AppConfigVariable LdapUrl AppConfigVariable `key:"ldapUrl"`
LdapBindDn AppConfigVariable LdapBindDn AppConfigVariable `key:"ldapBindDn"`
LdapBindPassword AppConfigVariable LdapBindPassword AppConfigVariable `key:"ldapBindPassword"`
LdapBase AppConfigVariable LdapBase AppConfigVariable `key:"ldapBase"`
LdapUserSearchFilter AppConfigVariable LdapUserSearchFilter AppConfigVariable `key:"ldapUserSearchFilter"`
LdapUserGroupSearchFilter AppConfigVariable LdapUserGroupSearchFilter AppConfigVariable `key:"ldapUserGroupSearchFilter"`
LdapSkipCertVerify AppConfigVariable LdapSkipCertVerify AppConfigVariable `key:"ldapSkipCertVerify"`
LdapAttributeUserUniqueIdentifier AppConfigVariable LdapAttributeUserUniqueIdentifier AppConfigVariable `key:"ldapAttributeUserUniqueIdentifier"`
LdapAttributeUserUsername AppConfigVariable LdapAttributeUserUsername AppConfigVariable `key:"ldapAttributeUserUsername"`
LdapAttributeUserEmail AppConfigVariable LdapAttributeUserEmail AppConfigVariable `key:"ldapAttributeUserEmail"`
LdapAttributeUserFirstName AppConfigVariable LdapAttributeUserFirstName AppConfigVariable `key:"ldapAttributeUserFirstName"`
LdapAttributeUserLastName AppConfigVariable LdapAttributeUserLastName AppConfigVariable `key:"ldapAttributeUserLastName"`
LdapAttributeUserProfilePicture AppConfigVariable LdapAttributeUserProfilePicture AppConfigVariable `key:"ldapAttributeUserProfilePicture"`
LdapAttributeGroupMember AppConfigVariable LdapAttributeGroupMember AppConfigVariable `key:"ldapAttributeGroupMember"`
LdapAttributeGroupUniqueIdentifier AppConfigVariable LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName AppConfigVariable LdapAttributeGroupName AppConfigVariable `key:"ldapAttributeGroupName"`
LdapAttributeAdminGroup AppConfigVariable LdapAttributeAdminGroup AppConfigVariable `key:"ldapAttributeAdminGroup"`
LdapSoftDeleteUsers AppConfigVariable `key:"ldapSoftDeleteUsers"`
}
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
// Use reflection to iterate through all fields
cfgValue := reflect.ValueOf(c).Elem()
cfgType := cfgValue.Type()
var res []AppConfigVariable
for i := range cfgType.NumField() {
field := cfgType.Field(i)
key, attrs, _ := strings.Cut(field.Tag.Get("key"), ",")
if key == "" {
continue
}
// If we're only showing public variables and this is not public, skip it
if !showAll && attrs != "public" {
continue
}
fieldValue := cfgValue.Field(i)
appConfigVariable := AppConfigVariable{
Key: key,
Value: fieldValue.FieldByName("Value").String(),
}
res = append(res, appConfigVariable)
}
return res
}
func (c *AppConfig) FieldByKey(key string) (string, error) {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches
for i := range rt.NumField() {
// Grab only the first part of the key, if there's a comma with additional properties
tagValue, _, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
valueField := rv.Field(i).FieldByName("Value")
return valueField.String(), nil
}
// If we are here, the config key was not found
return "", AppConfigKeyNotFoundError{field: key}
}
func (c *AppConfig) UpdateField(key string, value string, noInternal bool) error {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches, then update that
for i := range rt.NumField() {
// Separate the key (before the comma) from any optional attributes after
tagValue, attrs, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
// If the field is internal and noInternal is true, we skip that
if noInternal && attrs == "internal" {
return AppConfigInternalForbiddenError{field: key}
}
valueField := rv.Field(i).FieldByName("Value")
if !valueField.CanSet() {
return fmt.Errorf("field Value in AppConfigVariable is not settable for config key '%s'", key)
}
// Update the value
valueField.SetString(value)
// Return once updated
return nil
}
// If we're here, we have not found the right field to update
return AppConfigKeyNotFoundError{field: key}
}
type AppConfigKeyNotFoundError struct {
field string
}
func (e AppConfigKeyNotFoundError) Error() string {
return fmt.Sprintf("cannot find config key '%s'", e.field)
}
func (e AppConfigKeyNotFoundError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigKeyNotFoundError
x := AppConfigKeyNotFoundError{}
return errors.As(target, &x)
}
type AppConfigInternalForbiddenError struct {
field string
}
func (e AppConfigInternalForbiddenError) Error() string {
return fmt.Sprintf("field '%s' is internal and can't be updated", e.field)
}
func (e AppConfigInternalForbiddenError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigInternalForbiddenError
x := AppConfigInternalForbiddenError{}
return errors.As(target, &x)
} }

View File

@@ -1,10 +1,16 @@
package model // We use model_test here to avoid an import cycle
package model_test
import ( import (
"reflect"
"strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
) )
func TestAppConfigVariable_AsMinutesDuration(t *testing.T) { func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
@@ -48,7 +54,7 @@ func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
configVar := AppConfigVariable{ configVar := model.AppConfigVariable{
Value: tt.value, Value: tt.value,
} }
@@ -58,3 +64,66 @@ func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
}) })
} }
} }
// This test ensures that the model.AppConfig and dto.AppConfigUpdateDto structs match:
// - They should have the same properties, where the "json" tag of dto.AppConfigUpdateDto should match the "key" tag in model.AppConfig
// - dto.AppConfigDto should not include "internal" fields from model.AppConfig
// This test is primarily meant to catch discrepancies between the two structs as fields are added or removed over time
func TestAppConfigStructMatchesUpdateDto(t *testing.T) {
appConfigType := reflect.TypeOf(model.AppConfig{})
updateDtoType := reflect.TypeOf(dto.AppConfigUpdateDto{})
// Process AppConfig fields
appConfigFields := make(map[string]string)
for i := 0; i < appConfigType.NumField(); i++ {
field := appConfigType.Field(i)
if field.Tag.Get("key") == "" {
// Skip internal fields
continue
}
// Extract the key name from the tag (takes the part before any comma)
keyTag := field.Tag.Get("key")
keyName, _, _ := strings.Cut(keyTag, ",")
appConfigFields[field.Name] = keyName
}
// Process AppConfigUpdateDto fields
dtoFields := make(map[string]string)
for i := 0; i < updateDtoType.NumField(); i++ {
field := updateDtoType.Field(i)
// Extract the json name from the tag (takes the part before any binding constraints)
jsonTag := field.Tag.Get("json")
jsonName, _, _ := strings.Cut(jsonTag, ",")
dtoFields[jsonName] = field.Name
}
// Verify every AppConfig field has a matching DTO field with the same name
for fieldName, keyName := range appConfigFields {
if strings.HasSuffix(fieldName, "ImageType") {
// Skip internal fields that shouldn't be in the DTO
continue
}
// Check if there's a DTO field with a matching JSON tag
_, exists := dtoFields[keyName]
assert.True(t, exists, "Field %s with key '%s' in AppConfig has no matching field in AppConfigUpdateDto", fieldName, keyName)
}
// Verify every DTO field has a matching AppConfig field
for jsonName, fieldName := range dtoFields {
// Find a matching field in AppConfig by key tag
found := false
for _, keyName := range appConfigFields {
if keyName == jsonName {
found = true
break
}
}
assert.True(t, found, "Field %s with json tag '%s' in AppConfigUpdateDto has no matching field in AppConfig", fieldName, jsonName)
}
}

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
) )
type AuditLog struct { type AuditLog struct {
@@ -14,8 +14,11 @@ type AuditLog struct {
Country string `sortable:"true"` Country string `sortable:"true"`
City string `sortable:"true"` City string `sortable:"true"`
UserAgent string `sortable:"true"` UserAgent string `sortable:"true"`
UserID string Username string `gorm:"-"`
Data AuditLogData Data AuditLogData
UserID string
User User
} }
type AuditLogData map[string]string //nolint:recvcheck type AuditLogData map[string]string //nolint:recvcheck
@@ -27,11 +30,13 @@ const (
AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN" AuditLogEventOneTimeAccessTokenSignIn AuditLogEvent = "TOKEN_SIGN_IN"
AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION" AuditLogEventClientAuthorization AuditLogEvent = "CLIENT_AUTHORIZATION"
AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION" AuditLogEventNewClientAuthorization AuditLogEvent = "NEW_CLIENT_AUTHORIZATION"
AuditLogEventDeviceCodeAuthorization AuditLogEvent = "DEVICE_CODE_AUTHORIZATION"
AuditLogEventNewDeviceCodeAuthorization AuditLogEvent = "NEW_DEVICE_CODE_AUTHORIZATION"
) )
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (e *AuditLogEvent) Scan(value interface{}) error { func (e *AuditLogEvent) Scan(value any) error {
*e = AuditLogEvent(value.(string)) *e = AuditLogEvent(value.(string))
return nil return nil
} }
@@ -40,11 +45,14 @@ func (e AuditLogEvent) Value() (driver.Value, error) {
return string(e), nil return string(e), nil
} }
func (d *AuditLogData) Scan(value interface{}) error { func (d *AuditLogData) Scan(value any) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, d) return json.Unmarshal(v, d)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), d)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm" "gorm.io/gorm"
@@ -74,13 +74,30 @@ func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
type UrlList []string //nolint:recvcheck type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error { func (cu *UrlList) Scan(value interface{}) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu) return json.Unmarshal(v, cu)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), cu)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }
func (cu UrlList) Value() (driver.Value, error) { func (cu UrlList) Value() (driver.Value, error) {
return json.Marshal(cu) return json.Marshal(cu)
} }
type OidcDeviceCode struct {
Base
DeviceCode string
UserCode string
Scope string
ExpiresAt datatype.DateTime
IsAuthorized bool
UserID *string
User User
ClientID string
Client OidcClient
}

View File

@@ -1,9 +1,12 @@
package model package model
import ( import (
"strings"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
) )
type User struct { type User struct {
@@ -16,6 +19,7 @@ type User struct {
IsAdmin bool `sortable:"true"` IsAdmin bool `sortable:"true"`
Locale *string Locale *string
LdapID *string LdapID *string
Disabled bool `sortable:"true"`
CustomClaims []CustomClaim CustomClaims []CustomClaim
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
@@ -63,6 +67,15 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
func (u User) FullName() string { return u.FirstName + " " + u.LastName } func (u User) FullName() string { return u.FirstName + " " + u.LastName }
func (u User) Initials() string {
first := utils.GetFirstCharacter(u.FirstName)
last := utils.GetFirstCharacter(u.LastName)
if first == "" && last == "" && len(u.Username) >= 2 {
return strings.ToUpper(u.Username[:2])
}
return strings.ToUpper(first + last)
}
type OneTimeAccessToken struct { type OneTimeAccessToken struct {
Base Base
Token string Token string

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
@@ -49,11 +49,13 @@ type AuthenticatorTransportList []protocol.AuthenticatorTransport //nolint:recvc
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (atl *AuthenticatorTransportList) Scan(value interface{}) error { func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
switch v := value.(type) {
if v, ok := value.([]byte); ok { case []byte:
return json.Unmarshal(v, atl) return json.Unmarshal(v, atl)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), atl)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -1,29 +1,35 @@
package service package service
import ( import (
"context"
"errors" "errors"
"log"
"time" "time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type ApiKeyService struct { type ApiKeyService struct {
db *gorm.DB db *gorm.DB
emailService *EmailService
} }
func NewApiKeyService(db *gorm.DB) *ApiKeyService { func NewApiKeyService(db *gorm.DB, emailService *EmailService) *ApiKeyService {
return &ApiKeyService{db: db} return &ApiKeyService{db: db, emailService: emailService}
} }
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) { func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{}) query := s.db.
WithContext(ctx).
Where("user_id = ?", userID).
Model(&model.ApiKey{})
var apiKeys []model.ApiKey var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
@@ -34,7 +40,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
return apiKeys, pagination, nil return apiKeys, pagination, nil
} }
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) { func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future // Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) { if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{} return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
@@ -54,7 +60,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
UserID: userID, UserID: userID,
} }
if err := s.db.Create(&apiKey).Error; err != nil { err = s.db.
WithContext(ctx).
Create(&apiKey).
Error
if err != nil {
return model.ApiKey{}, "", err return model.ApiKey{}, "", err
} }
@@ -62,29 +72,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
return apiKey, token, nil return apiKey, token, nil
} }
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error { func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
var apiKey model.ApiKey var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil { err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", apiKeyID, userID).
Delete(&apiKey).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{} return &common.APIKeyNotFoundError{}
} }
return err return err
} }
return s.db.Delete(&apiKey).Error return nil
} }
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) { func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
if apiKey == "" { if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{} return model.User{}, &common.NoAPIKeyProvidedError{}
} }
var key model.ApiKey now := time.Now()
hashedKey := utils.CreateSha256Hash(apiKey) hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?", var key model.ApiKey
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil { err := s.db.
WithContext(ctx).
Model(&model.ApiKey{}).
Clauses(clause.Returning{}).
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
Updates(&model.ApiKey{
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
}).
Preload("User").
First(&key).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{} return model.User{}, &common.InvalidAPIKeyError{}
} }
@@ -92,12 +117,49 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
return model.User{}, err return model.User{}, err
} }
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil return key.User, nil
} }
func (s *ApiKeyService) ListExpiringApiKeys(ctx context.Context, daysAhead int) ([]model.ApiKey, error) {
var keys []model.ApiKey
now := time.Now()
cutoff := now.AddDate(0, 0, daysAhead)
err := s.db.
WithContext(ctx).
Preload("User").
Where("expires_at > ? AND expires_at <= ? AND expiration_email_sent = ?", datatype.DateTime(now), datatype.DateTime(cutoff), false).
Find(&keys).
Error
return keys, err
}
func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error {
user := apiKey.User
if user.ID == "" {
if err := s.db.WithContext(ctx).First(&user, "id = ?", apiKey.UserID).Error; err != nil {
return err
}
}
err := SendEmail(ctx, s.emailService, email.Address{
Name: user.FullName(),
Email: user.Email,
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
ApiKeyName: apiKey.Name,
ExpiresAt: apiKey.ExpiresAt.ToTime(),
Name: user.FirstName,
})
if err != nil {
return err
}
// Mark the API key as having had an expiration email sent
return s.db.WithContext(ctx).
Model(&model.ApiKey{}).
Where("id = ?", apiKey.ID).
Update("expiration_email_sent", true).
Error
}

View File

@@ -1,396 +1,426 @@
package service package service
import ( import (
"context"
"errors"
"fmt" "fmt"
"log" "log"
"mime/multipart" "mime/multipart"
"os" "os"
"reflect" "reflect"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type AppConfigService struct { type AppConfigService struct {
DbConfig *model.AppConfig dbConfig atomic.Pointer[model.AppConfig]
db *gorm.DB db *gorm.DB
} }
func NewAppConfigService(db *gorm.DB) *AppConfigService { func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{ service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db, db: db,
} }
if err := service.InitDbConfig(); err != nil {
err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err) log.Fatalf("Failed to initialize app config service: %v", err)
} }
return service return service
} }
var defaultDbConfig = model.AppConfig{ // GetDbConfig returns the application configuration.
// General // Important: Treat the object as read-only: do not modify its properties directly!
AppName: model.AppConfigVariable{ func (s *AppConfigService) GetDbConfig() *model.AppConfig {
Key: "appName", v := s.dbConfig.Load()
Type: "string", if v == nil {
IsPublic: true, // This indicates a development-time error
DefaultValue: "Pocket ID", panic("called GetDbConfig before DbConfig is loaded")
},
SessionDuration: model.AppConfigVariable{
Key: "sessionDuration",
Type: "number",
DefaultValue: "60",
},
EmailsVerified: model.AppConfigVariable{
Key: "emailsVerified",
Type: "bool",
DefaultValue: "false",
},
AllowOwnAccountEdit: model.AppConfigVariable{
Key: "allowOwnAccountEdit",
Type: "bool",
IsPublic: true,
DefaultValue: "true",
},
// Internal
BackgroundImageType: model.AppConfigVariable{
Key: "backgroundImageType",
Type: "string",
IsInternal: true,
DefaultValue: "jpg",
},
LogoLightImageType: model.AppConfigVariable{
Key: "logoLightImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
LogoDarkImageType: model.AppConfigVariable{
Key: "logoDarkImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
// Email
SmtpHost: model.AppConfigVariable{
Key: "smtpHost",
Type: "string",
},
SmtpPort: model.AppConfigVariable{
Key: "smtpPort",
Type: "number",
},
SmtpFrom: model.AppConfigVariable{
Key: "smtpFrom",
Type: "string",
},
SmtpUser: model.AppConfigVariable{
Key: "smtpUser",
Type: "string",
},
SmtpPassword: model.AppConfigVariable{
Key: "smtpPassword",
Type: "string",
},
SmtpTls: model.AppConfigVariable{
Key: "smtpTls",
Type: "string",
DefaultValue: "none",
},
SmtpSkipCertVerify: model.AppConfigVariable{
Key: "smtpSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
EmailLoginNotificationEnabled: model.AppConfigVariable{
Key: "emailLoginNotificationEnabled",
Type: "bool",
DefaultValue: "false",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
// LDAP
LdapEnabled: model.AppConfigVariable{
Key: "ldapEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
LdapUrl: model.AppConfigVariable{
Key: "ldapUrl",
Type: "string",
},
LdapBindDn: model.AppConfigVariable{
Key: "ldapBindDn",
Type: "string",
},
LdapBindPassword: model.AppConfigVariable{
Key: "ldapBindPassword",
Type: "string",
},
LdapBase: model.AppConfigVariable{
Key: "ldapBase",
Type: "string",
},
LdapUserSearchFilter: model.AppConfigVariable{
Key: "ldapUserSearchFilter",
Type: "string",
DefaultValue: "(objectClass=person)",
},
LdapUserGroupSearchFilter: model.AppConfigVariable{
Key: "ldapUserGroupSearchFilter",
Type: "string",
DefaultValue: "(objectClass=groupOfNames)",
},
LdapSkipCertVerify: model.AppConfigVariable{
Key: "ldapSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeUserUniqueIdentifier",
Type: "string",
},
LdapAttributeUserUsername: model.AppConfigVariable{
Key: "ldapAttributeUserUsername",
Type: "string",
},
LdapAttributeUserEmail: model.AppConfigVariable{
Key: "ldapAttributeUserEmail",
Type: "string",
},
LdapAttributeUserFirstName: model.AppConfigVariable{
Key: "ldapAttributeUserFirstName",
Type: "string",
},
LdapAttributeUserLastName: model.AppConfigVariable{
Key: "ldapAttributeUserLastName",
Type: "string",
},
LdapAttributeUserProfilePicture: model.AppConfigVariable{
Key: "ldapAttributeUserProfilePicture",
Type: "string",
},
LdapAttributeGroupMember: model.AppConfigVariable{
Key: "ldapAttributeGroupMember",
Type: "string",
DefaultValue: "member",
},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeGroupUniqueIdentifier",
Type: "string",
},
LdapAttributeGroupName: model.AppConfigVariable{
Key: "ldapAttributeGroupName",
Type: "string",
},
LdapAttributeAdminGroup: model.AppConfigVariable{
Key: "ldapAttributeAdminGroup",
Type: "string",
},
} }
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { return v
}
func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
// Values are the default ones
return &model.AppConfig{
// General
AppName: model.AppConfigVariable{Value: "Pocket ID"},
SessionDuration: model.AppConfigVariable{Value: "60"},
EmailsVerified: model.AppConfigVariable{Value: "false"},
DisableAnimations: model.AppConfigVariable{Value: "false"},
AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
// Internal
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
LogoDarkImageType: model.AppConfigVariable{Value: "svg"},
// Email
SmtpHost: model.AppConfigVariable{},
SmtpPort: model.AppConfigVariable{},
SmtpFrom: model.AppConfigVariable{},
SmtpUser: model.AppConfigVariable{},
SmtpPassword: model.AppConfigVariable{},
SmtpTls: model.AppConfigVariable{Value: "none"},
SmtpSkipCertVerify: model.AppConfigVariable{Value: "false"},
EmailLoginNotificationEnabled: model.AppConfigVariable{Value: "false"},
EmailOneTimeAccessAsUnauthenticatedEnabled: model.AppConfigVariable{Value: "false"},
EmailOneTimeAccessAsAdminEnabled: model.AppConfigVariable{Value: "false"},
EmailApiKeyExpirationEnabled: model.AppConfigVariable{Value: "false"},
// LDAP
LdapEnabled: model.AppConfigVariable{Value: "false"},
LdapUrl: model.AppConfigVariable{},
LdapBindDn: model.AppConfigVariable{},
LdapBindPassword: model.AppConfigVariable{},
LdapBase: model.AppConfigVariable{},
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
LdapSkipCertVerify: model.AppConfigVariable{Value: "false"},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeUserUsername: model.AppConfigVariable{},
LdapAttributeUserEmail: model.AppConfigVariable{},
LdapAttributeUserFirstName: model.AppConfigVariable{},
LdapAttributeUserLastName: model.AppConfigVariable{},
LdapAttributeUserProfilePicture: model.AppConfigVariable{},
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeGroupName: model.AppConfigVariable{},
LdapAttributeAdminGroup: model.AppConfigVariable{},
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
}
}
func (s *AppConfigService) updateAppConfigStartTransaction(ctx context.Context) (tx *gorm.DB, err error) {
// We start a transaction before doing any work, to ensure that we are the only ones updating the data in the database
// This works across multiple processes too
tx = s.db.Begin()
err = tx.Error
if err != nil {
return nil, fmt.Errorf("failed to begin database transaction: %w", err)
}
// With SQLite there's nothing else we need to do, because a transaction blocks the entire database
// However, with Postgres we need to manually lock the table to prevent others from doing the same
switch s.db.Name() {
case "postgres":
// We do not use "NOWAIT" so this blocks until the database is available, or the context is canceled
// Here we use a context with a 10s timeout in case the database is blocked for longer
lockCtx, lockCancel := context.WithTimeout(ctx, 10*time.Second)
defer lockCancel()
err = tx.
WithContext(lockCtx).
Exec("LOCK TABLE app_config_variables IN ACCESS EXCLUSIVE MODE").
Error
if err != nil {
tx.Rollback()
return nil, fmt.Errorf("failed to acquire lock on app_config_variables table: %w", err)
}
default:
// Nothing to do here
}
return tx, nil
}
func (s *AppConfigService) updateAppConfigUpdateDatabase(ctx context.Context, tx *gorm.DB, dbUpdate *[]model.AppConfigVariable) error {
err := tx.
WithContext(ctx).
Clauses(clause.OnConflict{
// Perform an "upsert" if the key already exists, replacing the value
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).
Create(&dbUpdate).
Error
if err != nil {
return fmt.Errorf("failed to update config in database: %w", err)
}
return nil
}
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
if common.EnvConfig.UiConfigDisabled { if common.EnvConfig.UiConfigDisabled {
return nil, &common.UiConfigDisabledError{} return nil, &common.UiConfigDisabledError{}
} }
tx := s.db.Begin() // Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return nil, fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
rt := reflect.ValueOf(input).Type() rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input) rv := reflect.ValueOf(input)
dbUpdate := make([]model.AppConfigVariable, 0, rt.NumField())
var savedConfigVariables []model.AppConfigVariable for i := range rt.NumField() {
for i := 0; i < rt.NumField(); i++ {
field := rt.Field(i) field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String() value := rv.FieldByName(field.Name).String()
// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled // Get the value of the json tag, taking only what's before the comma
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key { key, _, _ := strings.Cut(field.Tag.Get("json"), ",")
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false" // Update the in-memory config value
} // If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
// Ignore errors here as we know the key exists
defaultValue, _ := defaultCfg.FieldByKey(key)
err = cfg.UpdateField(key, defaultValue, true)
} else {
err = cfg.UpdateField(key, value, true)
} }
var appConfigVariable model.AppConfigVariable // If we tried to update an internal field, ignore the error (and do not update in the DB)
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil { if errors.Is(err, model.AppConfigInternalForbiddenError{}) {
continue
} else if err != nil {
return nil, fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil {
return nil, err
}
// Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error
if err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
s.dbConfig.Store(cfg)
// Return the updated config
res := cfg.ToAppConfigVariableSlice(true)
return res, nil
}
// UpdateAppConfigValues
func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
if common.EnvConfig.UiConfigDisabled {
return &common.UiConfigDisabledError{}
}
// Count of keysAndValues must be even
if len(keysAndValues)%2 != 0 {
return errors.New("invalid number of arguments received")
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return err
}
defer func() {
tx.Rollback() tx.Rollback()
return nil, err }()
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return fmt.Errorf("failed to reload config from database: %w", err)
} }
appConfigVariable.Value = value defaultCfg := s.getDefaultDbConfig()
if err := tx.Save(&appConfigVariable).Error; err != nil {
tx.Rollback() // Iterate through all the fields to update
return nil, err // We update the in-memory data (in the cfg struct) and collect values to update in the database
// (Note the += 2, as we are iterating through key-value pairs)
dbUpdate := make([]model.AppConfigVariable, 0, len(keysAndValues)/2)
for i := 0; i < len(keysAndValues); i += 2 {
key := keysAndValues[i]
value := keysAndValues[i+1]
// Ensure that the field is valid
// We do this by grabbing the default value
var defaultValue string
defaultValue, err = defaultCfg.FieldByKey(key)
if err != nil {
return fmt.Errorf("invalid configuration key '%s': %w", key, err)
} }
savedConfigVariables = append(savedConfigVariables, appConfigVariable) // Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
err = cfg.UpdateField(key, defaultValue, false)
} else {
err = cfg.UpdateField(key, value, false)
}
if err != nil {
return fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
} }
tx.Commit() // We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
if err := s.LoadDbConfigFromDb(); err != nil { Key: key,
return nil, err Value: value,
})
} }
return savedConfigVariables, nil // Update the values in the database
} err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
key := fmt.Sprintf("%sImageType", imageName)
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
if err != nil { if err != nil {
return err return err
} }
return s.LoadDbConfigFromDb() // Commit the changes to the DB, then finally save the updated config in the object
} err = tx.Commit().Error
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
if showAll {
err = s.db.Find(&configuration).Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
}
if err != nil { if err != nil {
return nil, err return fmt.Errorf("failed to commit transaction: %w", err)
} }
for i := range configuration { s.dbConfig.Store(cfg)
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" { return nil
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
}
} }
return configuration, nil func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
} }
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error { func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename) fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType) mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" { if mimeType == "" {
return &common.FileTypeNotSupportedError{} return &common.FileTypeNotSupportedError{}
} }
// Delete the old image if it has a different file type // Save the updated image
if fileType != oldImageType { imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType) err = utils.SaveFile(uploadedFile, imagePath)
if err := os.Remove(oldImagePath); err != nil { if err != nil {
return err return err
} }
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType) // Delete the old image if it has a different file type, then update the type in the database
if err := utils.SaveFile(uploadedFile, imagePath); err != nil { if fileType != oldImageType {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err return err
} }
// Update the file type in the database // Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil { err = s.UpdateAppConfigValues(ctx, imageName+"ImageType", fileType)
if err != nil {
return err return err
} }
}
return nil return nil
} }
// InitDbConfig creates the default configuration values in the database if they do not exist, // LoadDbConfig loads the configuration values from the database into the DbConfig struct.
// updates existing configurations if they differ from the default, and deletes any configurations func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
// that are not in the default configuration. var dest *model.AppConfig
func (s *AppConfigService) InitDbConfig() error {
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig // If the UI config is disabled, only load from the env
for i := 0; i < defaultConfigReflectValue.NumField(); i++ { if common.EnvConfig.UiConfigDisabled {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable) dest, err = s.loadDbConfigFromEnv()
} else {
defaultKeys[defaultConfigVar.Key] = struct{}{} dest, err = s.loadDbConfigInternal(ctx, s.db)
}
var storedConfigVar model.AppConfigVariable if err != nil {
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
return err return err
} }
// Update the value in the object
s.dbConfig.Store(dest)
return nil
}
func (s *AppConfigService) loadDbConfigFromEnv() (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the value of the key tag, taking only what's before the comma
// The env var name is the key converted to SCREAMING_SNAKE_CASE
key, _, _ := strings.Cut(field.Tag.Get("key"), ",")
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
}
}
return dest, nil
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Load all configuration values from the database
// This loads all values in a single shot
loaded := []model.AppConfigVariable{}
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
err := tx.
WithContext(queryCtx).
Find(&loaded).Error
if err != nil {
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
}
// Iterate through all values loaded from the database
for _, v := range loaded {
// If the value is empty, it means we are using the default value
if v.Value == "" {
continue continue
} }
// Update existing configuration if it differs from the default // Find the field in the struct whose "key" tag matches, then update that
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue { err = dest.UpdateField(v.Key, v.Value, false)
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic // We ignore the case of fields that don't exist, as there may be leftover data in the database
storedConfigVar.IsInternal = defaultConfigVar.IsInternal if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
if err := s.db.Save(&storedConfigVar).Error; err != nil {
return err
}
} }
} }
// Delete any configurations not in the default keys return dest, nil
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
return err
}
}
}
return s.LoadDbConfigFromDb()
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {
environmentVariableName := utils.CamelCaseToScreamingSnakeCase(key)
if value, exists := os.LookupEnv(environmentVariableName); exists {
return value
}
return fallbackValue
} }

View File

@@ -0,0 +1,523 @@
package service
import (
"sync/atomic"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
service := &AppConfigService{
dbConfig: atomic.Pointer[model.AppConfig]{},
}
service.dbConfig.Store(config)
return service
}
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
// Load the config
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be equal to default config
require.Equal(t, service.GetDbConfig(), service.getDefaultDbConfig())
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
Create([]model.AppConfigVariable{
// Should be set to the default value because it's an empty string
{Key: "appName", Value: ""},
// Overrides default value
{Key: "sessionDuration", Value: "5"},
// Does not have a default value
{Key: "smtpHost", Value: "example"},
}).
Error
require.NoError(t, err)
// Load the config
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Values should match expected ones
expect := service.getDefaultDbConfig()
expect.SessionDuration.Value = "5"
expect.SmtpHost.Value = "example"
require.Equal(t, service.GetDbConfig(), expect)
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
{Key: "__nonExistentKey", Value: "some value"},
{Key: "appName", Value: "TestApp"}, // This one should still be loaded
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// This should not fail, just ignore the unknown key
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
config := service.GetDbConfig()
require.Equal(t, "TestApp", config.AppName.Value)
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "InitialApp"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "InitialApp", service.GetDbConfig().AppName.Value)
// Update the database value
err = db.Model(&model.AppConfigVariable{}).
Where("key = ?", "appName").
Update("value", "UpdatedApp").Error
require.NoError(t, err)
// Load the config again, it should reflect the updated value
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "UpdatedApp", service.GetDbConfig().AppName.Value)
})
t.Run("loads config from env when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables for testing
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Enable UiConfigDisabled to load from env
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from env, not DB
config := service.GetDbConfig()
require.Equal(t, "EnvTest App", config.AppName.Value, "Should load appName from env")
require.Equal(t, "45", config.SessionDuration.Value, "Should load sessionDuration from env")
})
t.Run("ignores env vars when UiConfigDisabled is false", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables that should be ignored
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Make sure UiConfigDisabled is false to load from DB
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from DB, not env
config := service.GetDbConfig()
require.Equal(t, "DB App", config.AppName.Value, "Should load appName from DB, not env")
require.Equal(t, "120", config.SessionDuration.Value, "Should load sessionDuration from DB, not env")
})
}
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update a single config value
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App")
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
// Verify database was updated
var dbValue model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&dbValue).Error
require.NoError(t, err)
require.Equal(t, "Test App", dbValue.Value)
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update multiple config values
err = service.UpdateAppConfigValues(
t.Context(),
"appName", "Test App",
"sessionDuration", "30",
"smtpHost", "mail.example.com",
)
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
require.Equal(t, "30", config.SessionDuration.Value)
require.Equal(t, "mail.example.com", config.SmtpHost.Value)
// Verify database was updated
var count int64
db.Model(&model.AppConfigVariable{}).Count(&count)
require.Equal(t, int64(3), count)
var appName, sessionDuration, smtpHost model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Test App", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "30", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "mail.example.com", smtpHost.Value)
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First change the value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "30")
require.NoError(t, err)
require.Equal(t, "30", service.GetDbConfig().SessionDuration.Value)
// Now set it to empty which should use default value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "")
require.NoError(t, err)
require.Equal(t, "60", service.GetDbConfig().SessionDuration.Value) // Default value from getDefaultDbConfig
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with odd number of arguments
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App", "sessionDuration")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid number of arguments")
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with invalid key
err = service.UpdateAppConfigValues(t.Context(), "nonExistentKey", "some value")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid configuration key")
})
}
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Create update DTO
input := dto.AppConfigUpdateDto{
AppName: "Updated App Name",
SessionDuration: "120",
SmtpHost: "smtp.example.com",
SmtpPort: "587",
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables
require.NotEmpty(t, updatedVars)
var foundAppName, foundSessionDuration, foundSmtpHost, foundSmtpPort bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Updated App Name", v.Value)
foundAppName = true
case "sessionDuration":
require.Equal(t, "120", v.Value)
foundSessionDuration = true
case "smtpHost":
require.Equal(t, "smtp.example.com", v.Value)
foundSmtpHost = true
case "smtpPort":
require.Equal(t, "587", v.Value)
foundSmtpPort = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
require.True(t, foundSmtpHost)
require.True(t, foundSmtpPort)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Updated App Name", config.AppName.Value)
require.Equal(t, "120", config.SessionDuration.Value)
require.Equal(t, "smtp.example.com", config.SmtpHost.Value)
require.Equal(t, "587", config.SmtpPort.Value)
// Verify database was updated
var appName, sessionDuration, smtpHost, smtpPort model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Updated App Name", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "120", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "smtp.example.com", smtpHost.Value)
err = db.Where("key = ?", "smtpPort").First(&smtpPort).Error
require.NoError(t, err)
require.Equal(t, "587", smtpPort.Value)
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First set some non-default values
err = service.UpdateAppConfigValues(t.Context(),
"appName", "Custom App",
"sessionDuration", "120",
)
require.NoError(t, err)
// Create update DTO with empty values to reset to defaults
input := dto.AppConfigUpdateDto{
AppName: "", // Should reset to default "Pocket ID"
SessionDuration: "", // Should reset to default "60"
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables (they should be empty strings in DB)
var foundAppName, foundSessionDuration bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Pocket ID", v.Value) // Returns the default value
foundAppName = true
case "sessionDuration":
require.Equal(t, "60", v.Value) // Returns the default value
foundSessionDuration = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
// Verify in-memory config was reset to defaults
config := service.GetDbConfig()
require.Equal(t, "Pocket ID", config.AppName.Value) // Default value
require.Equal(t, "60", config.SessionDuration.Value) // Default value
// Verify database was updated with empty values
for _, key := range []string{"appName", "sessionDuration"} {
var loaded model.AppConfigVariable
err = db.Where("key = ?", key).First(&loaded).Error
require.NoErrorf(t, err, "Failed to load DB value for key '%s'", key)
require.Emptyf(t, loaded.Value, "Loaded value for key '%s' is not empty", key)
}
})
t.Run("cannot update when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update config
_, err = service.UpdateAppConfig(t.Context(), dto.AppConfigUpdateDto{
AppName: "Should Not Update",
})
// Should get a UiConfigDisabledError
require.Error(t, err)
var uiConfigDisabledErr *common.UiConfigDisabledError
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -1,9 +1,12 @@
package service package service
import ( import (
"context"
"fmt"
"log" "log"
userAgentParser "github.com/mileusna/useragent" userAgentParser "github.com/mileusna/useragent"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email" "github.com/pocket-id/pocket-id/backend/internal/utils/email"
@@ -22,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
} }
// Create creates a new audit log entry in the database // Create creates a new audit log entry in the database
func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog { func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress) country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil { if err != nil {
log.Printf("Failed to get IP location: %v\n", err) log.Printf("Failed to get IP location: %v", err)
} }
auditLog := model.AuditLog{ auditLog := model.AuditLog{
@@ -39,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
} }
// Save the audit log in the database // Save the audit log in the database
if err := s.db.Create(&auditLog).Error; err != nil { err = tx.
log.Printf("Failed to create audit log: %v\n", err) WithContext(ctx).
Create(&auditLog).
Error
if err != nil {
log.Printf("Failed to create audit log: %v", err)
return model.AuditLog{} return model.AuditLog{}
} }
@@ -48,25 +55,42 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
} }
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before // CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog { func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}) createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
// Count the number of times the user has logged in from the same device // Count the number of times the user has logged in from the same device
var count int64 var count int64
err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error err := tx.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
Count(&count).
Error
if err != nil { if err != nil {
log.Printf("Failed to count audit logs: %v\n", err) log.Printf("Failed to count audit logs: %v\n", err)
return createdAuditLog return createdAuditLog
} }
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email // If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 { if s.appConfigService.GetDbConfig().EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() { go func() {
var user model.User innerCtx := context.Background()
s.db.Where("id = ?", userID).First(&user)
err := SendEmail(s.emailService, email.Address{ // Note we don't use the transaction here because this is running in background
Name: user.Username, var user model.User
innerErr := s.db.
WithContext(innerCtx).
Where("id = ?", userID).
First(&user).
Error
if innerErr != nil {
log.Printf("Failed to load user: %v", innerErr)
}
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
Name: user.FullName(),
Email: user.Email, Email: user.Email,
}, NewLoginTemplate, &NewLoginTemplateData{ }, NewLoginTemplate, &NewLoginTemplateData{
IPAddress: ipAddress, IPAddress: ipAddress,
@@ -75,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
Device: s.DeviceStringFromUserAgent(userAgent), Device: s.DeviceStringFromUserAgent(userAgent),
DateTime: createdAuditLog.CreatedAt.UTC(), DateTime: createdAuditLog.CreatedAt.UTC(),
}) })
if err != nil { if innerErr != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err) log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
} }
}() }()
} }
@@ -85,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
} }
// ListAuditLogsForUser retrieves all audit logs for a given user ID // ListAuditLogsForUser retrieves all audit logs for a given user ID
func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) { func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog var logs []model.AuditLog
query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID) query := s.db.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ?", userID)
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
return logs, pagination, err return logs, pagination, err
@@ -97,3 +124,99 @@ func (s *AuditLogService) DeviceStringFromUserAgent(userAgent string) string {
ua := userAgentParser.Parse(userAgent) ua := userAgentParser.Parse(userAgent)
return ua.Name + " on " + ua.OS + " " + ua.OSVersion return ua.Name + " on " + ua.OS + " " + ua.OSVersion
} }
func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest, filters dto.AuditLogFilterDto) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog
query := s.db.
WithContext(ctx).
Preload("User").
Model(&model.AuditLog{})
if filters.UserID != "" {
query = query.Where("user_id = ?", filters.UserID)
}
if filters.Event != "" {
query = query.Where("event = ?", filters.Event)
}
if filters.ClientName != "" {
dialect := s.db.Name()
switch dialect {
case "sqlite":
query = query.Where("json_extract(data, '$.clientName') = ?", filters.ClientName)
case "postgres":
query = query.Where("data->>'clientName' = ?", filters.ClientName)
default:
return nil, utils.PaginationResponse{}, fmt.Errorf("unsupported database dialect: %s", dialect)
}
}
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
if err != nil {
return nil, pagination, err
}
return logs, pagination, nil
}
func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[string]string, err error) {
query := s.db.
WithContext(ctx).
Joins("User").
Model(&model.AuditLog{}).
Select("DISTINCT \"User\".id, \"User\".username").
Where("\"User\".username IS NOT NULL")
type Result struct {
ID string `gorm:"column:id"`
Username string `gorm:"column:username"`
}
var results []Result
if err := query.Find(&results).Error; err != nil {
return nil, fmt.Errorf("failed to query user IDs: %w", err)
}
users = make(map[string]string, len(results))
for _, result := range results {
users[result.ID] = result.Username
}
return users, nil
}
func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
dialect := s.db.Name()
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{})
switch dialect {
case "sqlite":
query = query.
Select("DISTINCT json_extract(data, '$.clientName') AS client_name").
Where("json_extract(data, '$.clientName') IS NOT NULL")
case "postgres":
query = query.
Select("DISTINCT data->>'clientName' AS client_name").
Where("data->>'clientName' IS NOT NULL")
default:
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)
}
type Result struct {
ClientName string `gorm:"column:client_name"`
}
var results []Result
if err := query.Find(&results).Error; err != nil {
return nil, fmt.Errorf("failed to query client IDs: %w", err)
}
clientNames = make([]string, len(results))
for i, result := range results {
clientNames[i] = result.ClientName
}
return clientNames, nil
}

View File

@@ -1,34 +1,14 @@
package service package service
import ( import (
"context"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
) )
// Reserved claims
var reservedClaims = map[string]struct{}{
"given_name": {},
"family_name": {},
"name": {},
"email": {},
"preferred_username": {},
"groups": {},
"sub": {},
"iss": {},
"aud": {},
"exp": {},
"iat": {},
"auth_time": {},
"nonce": {},
"acr": {},
"amr": {},
"azp": {},
"nbf": {},
"jti": {},
}
type CustomClaimService struct { type CustomClaimService struct {
db *gorm.DB db *gorm.DB
} }
@@ -39,8 +19,30 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username // isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
func isReservedClaim(key string) bool { func isReservedClaim(key string) bool {
_, ok := reservedClaims[key] switch key {
return ok case "given_name",
"family_name",
"name",
"email",
"preferred_username",
"groups",
TokenTypeClaim,
"sub",
"iss",
"aud",
"exp",
"iat",
"auth_time",
"nonce",
"acr",
"amr",
"azp",
"nbf",
"jti":
return true
default:
return false
}
} }
// idType is the type of the id used to identify the user or user group // idType is the type of the id used to identify the user or user group
@@ -52,28 +54,37 @@ const (
) )
// UpdateCustomClaimsForUser updates the custom claims for a user // UpdateCustomClaimsForUser updates the custom claims for a user
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserID, userID, claims) return s.updateCustomClaims(ctx, UserID, userID, claims)
} }
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group // UpdateCustomClaimsForUserGroup updates the custom claims for a user group
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserGroupID, userGroupID, claims) return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
} }
// updateCustomClaims updates the custom claims for a user or user group // updateCustomClaims updates the custom claims for a user or user group
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// Check for duplicate keys in the claims slice // Check for duplicate keys in the claims slice
seenKeys := make(map[string]bool) seenKeys := make(map[string]struct{})
for _, claim := range claims { for _, claim := range claims {
if seenKeys[claim.Key] { if _, ok := seenKeys[claim.Key]; ok {
return nil, &common.DuplicateClaimError{Key: claim.Key} return nil, &common.DuplicateClaimError{Key: claim.Key}
} }
seenKeys[claim.Key] = true seenKeys[claim.Key] = struct{}{}
} }
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var existingClaims []model.CustomClaim var existingClaims []model.CustomClaim
err := s.db.Where(string(idType), value).Find(&existingClaims).Error err := tx.
WithContext(ctx).
Where(string(idType), value).
Find(&existingClaims).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
break break
} }
} }
if !found { if !found {
err = s.db.Delete(&existingClaim).Error err = tx.
WithContext(ctx).
Delete(&existingClaim).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -113,7 +128,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
} }
// Update the claim if it already exists or create a new one // Update the claim if it already exists or create a new one
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error err = tx.
WithContext(ctx).
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
Assign(&customClaim).
FirstOrCreate(&model.CustomClaim{}).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -121,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
// Get the updated claims // Get the updated claims
var updatedClaims []model.CustomClaim var updatedClaims []model.CustomClaim
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error err = tx.
WithContext(ctx).
Where(string(idType)+" = ?", value).
Find(&updatedClaims).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -129,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
return updatedClaims, nil return updatedClaims, nil
} }
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim var customClaims []model.CustomClaim
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error err := tx.
WithContext(ctx).
Where("user_id = ?", userID).
Find(&customClaims).
Error
return customClaims, err return customClaims, err
} }
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim var customClaims []model.CustomClaim
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error err := tx.
WithContext(ctx).
Where("user_group_id = ?", userGroupID).
Find(&customClaims).
Error
return customClaims, err return customClaims, err
} }
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of, // GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
// prioritizing the user's claims over user group claims with the same key. // prioritizing the user's claims over user group claims with the same key.
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
// Get the custom claims of the user // Get the custom claims of the user
customClaims, err := s.GetCustomClaimsForUser(userID) customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -158,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
// Get all user groups of the user // Get all user groups of the user
var userGroupsOfUser []model.UserGroup var userGroupsOfUser []model.UserGroup
err = s.db.Preload("CustomClaims"). err = tx.
WithContext(ctx).
Preload("CustomClaims").
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id"). Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
Where("user_groups_users.user_id = ?", userID). Where("user_groups_users.user_id = ?", userID).
Find(&userGroupsOfUser).Error Find(&userGroupsOfUser).Error
@@ -186,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
} }
// GetSuggestions returns a list of custom claim keys that have been used before // GetSuggestions returns a list of custom claim keys that have been used before
func (s *CustomClaimService) GetSuggestions() ([]string, error) { func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
var customClaimsKeys []string var customClaimsKeys []string
err := s.db.Model(&model.CustomClaim{}). err := s.db.
WithContext(ctx).
Model(&model.CustomClaim{}).
Group("key"). Group("key").
Order("COUNT(*) DESC"). Order("COUNT(*) DESC").
Pluck("key", &customClaimsKeys).Error Pluck("key", &customClaimsKeys).Error

View File

@@ -3,6 +3,7 @@
package service package service
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -34,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService} return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
} }
//nolint:gocognit
func (s *TestService) SeedDatabase() error { func (s *TestService) SeedDatabase() error {
return s.db.Transaction(func(tx *gorm.DB) error { return s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{ users := []model.User{
@@ -187,11 +189,8 @@ func (s *TestService) SeedDatabase() error {
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \ // openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout) // openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==") publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==") publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
if err != nil {
return err
}
webauthnCredentials := []model.WebauthnCredential{ webauthnCredentials := []model.WebauthnCredential{
{ {
Name: "Passkey 1", Name: "Passkey 1",
@@ -301,26 +300,22 @@ func (s *TestService) ResetApplicationImages() error {
return nil return nil
} }
func (s *TestService) ResetAppConfig() error { func (s *TestService) ResetAppConfig(ctx context.Context) error {
// Reseed the config variables // Reset all app config variables to their default values in the database
if err := s.appConfigService.InitDbConfig(); err != nil { err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error
return err if err != nil {
}
// Reset all app config variables to their default values
if err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error; err != nil {
return err return err
} }
// Reload the app config from the database after resetting the values // Reload the app config from the database after resetting the values
return s.appConfigService.LoadDbConfigFromDb() return s.appConfigService.LoadDbConfig(ctx)
} }
func (s *TestService) SetJWTKeys() { func (s *TestService) SetJWTKeys() {
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}` const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
privateKey, _ := jwk.ParseKey([]byte(privateKeyString)) privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey) _ = s.jwtService.SetKey(privateKey)
} }
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key

View File

@@ -2,10 +2,12 @@ package service
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
htemplate "html/template" htemplate "html/template"
"io"
"mime/multipart" "mime/multipart"
"mime/quotedprintable" "mime/quotedprintable"
"net/textproto" "net/textproto"
@@ -17,10 +19,11 @@ import (
"github.com/emersion/go-sasl" "github.com/emersion/go-sasl"
"github.com/emersion/go-smtp" "github.com/emersion/go-smtp"
"github.com/google/uuid" "github.com/google/uuid"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email" "github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
) )
type EmailService struct { type EmailService struct {
@@ -49,22 +52,28 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer
}, nil }, nil
} }
func (srv *EmailService) SendTestEmail(recipientUserId string) error { func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error {
var user model.User var user model.User
if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil { err := srv.db.
WithContext(ctx).
First(&user, "id = ?", recipientUserId).
Error
if err != nil {
return err return err
} }
return SendEmail(srv, return SendEmail(ctx, srv,
email.Address{ email.Address{
Email: user.Email, Email: user.Email,
Name: user.FullName(), Name: user.FullName(),
}, TestTemplate, nil) }, TestTemplate, nil)
} }
func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
dbConfig := srv.appConfigService.GetDbConfig()
data := &email.TemplateData[V]{ data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value, AppName: dbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo", LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Data: tData, Data: tData,
} }
@@ -79,8 +88,8 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.AddHeader("Subject", template.Title(data)) c.AddHeader("Subject", template.Title(data))
c.AddAddressHeader("From", []email.Address{ c.AddAddressHeader("From", []email.Address{
{ {
Email: srv.appConfigService.DbConfig.SmtpFrom.Value, Email: dbConfig.SmtpFrom.Value,
Name: srv.appConfigService.DbConfig.AppName.Value, Name: dbConfig.AppName.Value,
}, },
}) })
c.AddAddressHeader("To", []email.Address{toEmail}) c.AddAddressHeader("To", []email.Address{toEmail})
@@ -95,10 +104,10 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
// so we use the domain of the from address instead (the same as Thunderbird does) // so we use the domain of the from address instead (the same as Thunderbird does)
// if the address does not have an @ (which would be unusual), we use hostname // if the address does not have an @ (which would be unusual), we use hostname
from_address := srv.appConfigService.DbConfig.SmtpFrom.Value fromAddress := dbConfig.SmtpFrom.Value
domain := "" domain := ""
if strings.Contains(from_address, "@") { if strings.Contains(fromAddress, "@") {
domain = strings.Split(from_address, "@")[1] domain = strings.Split(fromAddress, "@")[1]
} else { } else {
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
@@ -112,6 +121,15 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.Body(body) c.Body(body)
// Check if the context is still valid before attemtping to connect
// We need to do this because the smtp library doesn't have context support
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Connect to the SMTP server // Connect to the SMTP server
client, err := srv.getSmtpClient() client, err := srv.getSmtpClient()
if err != nil { if err != nil {
@@ -119,6 +137,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
} }
defer client.Close() defer client.Close()
// Check if the context is still valid before sending the email
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Send the email // Send the email
if err := srv.sendEmailContent(client, toEmail, c); err != nil { if err := srv.sendEmailContent(client, toEmail, c); err != nil {
return fmt.Errorf("send email content: %w", err) return fmt.Errorf("send email content: %w", err)
@@ -128,16 +154,18 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
} }
func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) { func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
port := srv.appConfigService.DbConfig.SmtpPort.Value dbConfig := srv.appConfigService.GetDbConfig()
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
port := dbConfig.SmtpPort.Value
smtpAddress := dbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec InsecureSkipVerify: dbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value, ServerName: dbConfig.SmtpHost.Value,
} }
// Connect to the SMTP server based on TLS setting // Connect to the SMTP server based on TLS setting
switch srv.appConfigService.DbConfig.SmtpTls.Value { switch dbConfig.SmtpTls.Value {
case "none": case "none":
client, err = smtp.Dial(smtpAddress) client, err = smtp.Dial(smtpAddress)
case "tls": case "tls":
@@ -148,7 +176,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
tlsConfig, tlsConfig,
) )
default: default:
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", srv.appConfigService.DbConfig.SmtpTls.Value) return nil, fmt.Errorf("invalid SMTP TLS setting: %s", dbConfig.SmtpTls.Value)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err) return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
@@ -162,8 +190,8 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
} }
// Set up the authentication if user or password are set // Set up the authentication if user or password are set
smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value smtpUser := dbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value smtpPassword := dbConfig.SmtpPassword.Value
if smtpUser != "" || smtpPassword != "" { if smtpUser != "" || smtpPassword != "" {
// Authenticate with plain auth // Authenticate with plain auth
@@ -199,7 +227,7 @@ func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error { func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error {
// Set the sender // Set the sender
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value, nil); err != nil { if err := client.Mail(srv.appConfigService.GetDbConfig().SmtpFrom.Value, nil); err != nil {
return fmt.Errorf("failed to set sender: %w", err) return fmt.Errorf("failed to set sender: %w", err)
} }
@@ -215,7 +243,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add
} }
// Write the email content // Write the email content
_, err = w.Write([]byte(c.String())) _, err = io.Copy(w, strings.NewReader(c.String()))
if err != nil { if err != nil {
return fmt.Errorf("failed to write email data: %w", err) return fmt.Errorf("failed to write email data: %w", err)
} }

View File

@@ -42,6 +42,13 @@ var TestTemplate = email.Template[struct{}]{
}, },
} }
var ApiKeyExpiringSoonTemplate = email.Template[ApiKeyExpiringSoonTemplateData]{
Path: "api-key-expiring-soon",
Title: func(data *email.TemplateData[ApiKeyExpiringSoonTemplateData]) string {
return fmt.Sprintf("API Key \"%s\" Expiring Soon", data.Data.ApiKeyName)
},
}
type NewLoginTemplateData struct { type NewLoginTemplateData struct {
IPAddress string IPAddress string
Country string Country string
@@ -54,7 +61,14 @@ type OneTimeAccessTemplateData = struct {
Code string Code string
LoginLink string LoginLink string
LoginLinkWithCode string LoginLinkWithCode string
ExpirationString string
}
type ApiKeyExpiringSoonTemplateData struct {
Name string
ApiKeyName string
ExpiresAt time.Time
} }
// this is list of all template paths used for preloading templates // this is list of all template paths used for preloading templates
var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path} var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path, ApiKeyExpiringSoonTemplate.Path}

View File

@@ -42,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{
} }
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database. // NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
func NewGeoLiteService() *GeoLiteService { func NewGeoLiteService(ctx context.Context) *GeoLiteService {
service := &GeoLiteService{} service := &GeoLiteService{}
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl { if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
@@ -52,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService {
} }
go func() { go func() {
if err := service.updateDatabase(); err != nil { err := service.updateDatabase(ctx)
log.Printf("Failed to update GeoLite2 City database: %v\n", err) if err != nil {
log.Printf("Failed to update GeoLite2 City database: %v", err)
} }
}() }()
@@ -111,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
} }
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days. // UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
func (s *GeoLiteService) updateDatabase() error { func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error {
if s.disableUpdater { if s.disableUpdater {
// Avoid updating the GeoLite2 City database. // Avoid updating the GeoLite2 City database.
return nil return nil
@@ -125,7 +126,7 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...") log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey) downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
@@ -37,6 +38,18 @@ const (
// This may be omitted on non-admin tokens // This may be omitted on non-admin tokens
IsAdminClaim = "isAdmin" IsAdminClaim = "isAdmin"
// TokenTypeClaim is the claim used to identify the type of token
TokenTypeClaim = "type"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
AccessTokenJWTType = "access-token"
// IDTokenJWTType identifies a JWT as an ID token used by Pocket ID
IDTokenJWTType = "id-token"
// Acceptable clock skew for verifying tokens // Acceptable clock skew for verifying tokens
clockSkew = time.Minute clockSkew = time.Minute
) )
@@ -173,7 +186,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
now := time.Now() now := time.Now()
token, err := jwt.NewBuilder(). token, err := jwt.NewBuilder().
Subject(user.ID). Subject(user.ID).
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())). Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now). IssuedAt(now).
Issuer(common.EnvConfig.AppURL). Issuer(common.EnvConfig.AppURL).
Build() Build()
@@ -186,6 +199,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
} }
err = SetTokenType(token, AccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
err = SetIsAdmin(token, user.IsAdmin) err = SetIsAdmin(token, user.IsAdmin)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err) return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
@@ -209,6 +227,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
jwt.WithAcceptableSkew(clockSkew), jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL), jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL), jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err) return nil, fmt.Errorf("failed to parse token: %w", err)
@@ -233,6 +252,11 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
} }
err = SetTokenType(token, IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
for k, v := range userClaims { for k, v := range userClaims {
err = token.Set(k, v) err = token.Set(k, v)
if err != nil { if err != nil {
@@ -267,6 +291,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
jwt.WithKey(alg, s.privateKey), jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew), jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL), jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
) )
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp" // By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
@@ -305,6 +330,11 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err) return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
} }
err = SetTokenType(token, OAuthAccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm() alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil { if err != nil {
@@ -322,6 +352,7 @@ func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, erro
jwt.WithKey(alg, s.privateKey), jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew), jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL), jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err) return nil, fmt.Errorf("failed to parse token: %w", err)
@@ -481,6 +512,14 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
return isAdmin, err return isAdmin, err
} }
// SetTokenType sets the "type" claim in the token
func SetTokenType(token jwt.Token, tokenType string) error {
if tokenType == "" {
return nil
}
return token.Set(TokenTypeClaim, tokenType)
}
// SetIsAdmin sets the "isAdmin" claim in the token // SetIsAdmin sets the "isAdmin" claim in the token
func SetIsAdmin(token jwt.Token, isAdmin bool) error { func SetIsAdmin(token jwt.Token, isAdmin bool) error {
// Only set if true // Only set if true
@@ -495,3 +534,18 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
func SetAudienceString(token jwt.Token, audience string) error { func SetAudienceString(token jwt.Token, audience string) error {
return token.Set(jwt.AudienceKey, audience) return token.Set(jwt.AudienceKey, audience)
} }
// TokenTypeValidator is a validator function that checks the "type" claim in the token
func TokenTypeValidator(expectedTokenType string) jwt.ValidatorFunc {
return func(_ context.Context, t jwt.Token) error {
var tokenType string
err := t.Get(TokenTypeClaim, &tokenType)
if err != nil {
return fmt.Errorf("failed to get token type claim: %w", err)
}
if tokenType != expectedTokenType {
return fmt.Errorf("invalid token type: expected %s, got %s", expectedTokenType, tokenType)
}
return nil
}
}

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/ed25519" "crypto/ed25519"
"crypto/elliptic" "crypto/elliptic"
@@ -23,11 +24,9 @@ import (
) )
func TestJwtService_Init(t *testing.T) { func TestJwtService_Init(t *testing.T) {
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}, })
}
t.Run("should generate new key when none exists", func(t *testing.T) { t.Run("should generate new key when none exists", func(t *testing.T) {
// Create a temporary directory for the test // Create a temporary directory for the test
@@ -139,11 +138,9 @@ func TestJwtService_Init(t *testing.T) {
} }
func TestJwtService_GetPublicJWK(t *testing.T) { func TestJwtService_GetPublicJWK(t *testing.T) {
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}, })
}
t.Run("returns public key when private key is initialized", func(t *testing.T) { t.Run("returns public key when private key is initialized", func(t *testing.T) {
// Create a temporary directory for the test // Create a temporary directory for the test
@@ -273,11 +270,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}, })
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL
@@ -363,11 +358,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
t.Run("uses session duration from config", func(t *testing.T) { t.Run("uses session duration from config", func(t *testing.T) {
// Create a JWT service with a different session duration // Create a JWT service with a different session duration
customMockConfig := &AppConfigService{ customMockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
}, })
}
service := &JwtService{} service := &JwtService{}
err := service.init(customMockConfig, tempDir) err := service.init(customMockConfig, tempDir)
@@ -564,11 +557,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}, })
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL
@@ -643,6 +634,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
Build() Build()
require.NoError(t, err, "Failed to build token") require.NoError(t, err, "Failed to build token")
err = SetTokenType(token, IDTokenJWTType)
require.NoError(t, err, "Failed to set token type")
// Add custom claims // Add custom claims
for k, v := range userClaims { for k, v := range userClaims {
if k != "sub" { // Already set above if k != "sub" { // Already set above
@@ -892,11 +886,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
tempDir := t.TempDir() tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService // Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{ mockConfig := NewTestAppConfigService(&model.AppConfig{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
}, })
}
// Setup the environment variable required by the token verification // Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL originalAppURL := common.EnvConfig.AppURL
@@ -972,6 +964,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
Build() Build()
require.NoError(t, err, "Failed to build token") require.NoError(t, err, "Failed to build token")
err = SetTokenType(token, OAuthAccessTokenJWTType)
require.NoError(t, err, "Failed to set token type")
signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey)) signed, err := jwt.Sign(token, jwt.WithKey(jwa.RS256(), service.privateKey))
require.NoError(t, err, "Failed to sign token") require.NoError(t, err, "Failed to sign token")
@@ -1172,6 +1167,54 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
}) })
} }
func TestTokenTypeValidator(t *testing.T) {
// Create a context for the validator function
ctx := context.Background()
t.Run("succeeds when token type matches expected type", func(t *testing.T) {
// Create a token with the expected type
token := jwt.New()
err := token.Set(TokenTypeClaim, AccessTokenJWTType)
require.NoError(t, err, "Failed to set token type claim")
// Create a validator function for the expected type
validator := TokenTypeValidator(AccessTokenJWTType)
// Validate the token
err = validator(ctx, token)
assert.NoError(t, err, "Validator should accept token with matching type")
})
t.Run("fails when token type doesn't match expected type", func(t *testing.T) {
// Create a token with a different type
token := jwt.New()
err := token.Set(TokenTypeClaim, OAuthAccessTokenJWTType)
require.NoError(t, err, "Failed to set token type claim")
// Create a validator function for a different expected type
validator := TokenTypeValidator(IDTokenJWTType)
// Validate the token
err = validator(ctx, token)
require.Error(t, err, "Validator should reject token with non-matching type")
assert.Contains(t, err.Error(), "invalid token type: expected id-token, got oauth-access-token")
})
t.Run("fails when token type claim is missing", func(t *testing.T) {
// Create a token without a type claim
token := jwt.New()
// Create a validator function
validator := TokenTypeValidator(AccessTokenJWTType)
// Validate the token
err := validator(ctx, token)
require.Error(t, err, "Validator should reject token without type claim")
assert.Contains(t, err.Error(), "failed to get token type claim")
})
}
func importKey(t *testing.T, privateKeyRaw any, path string) string { func importKey(t *testing.T, privateKeyRaw any, path string) string {
t.Helper() t.Helper()

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
@@ -28,47 +29,44 @@ type LdapService struct {
} }
func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService { func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
return &LdapService{db: db, appConfigService: appConfigService, userService: userService, groupService: groupService} return &LdapService{
db: db,
appConfigService: appConfigService,
userService: userService,
groupService: groupService,
}
} }
func (s *LdapService) createClient() (*ldap.Conn, error) { func (s *LdapService) createClient() (*ldap.Conn, error) {
if !s.appConfigService.DbConfig.LdapEnabled.IsTrue() { dbConfig := s.appConfigService.GetDbConfig()
if !dbConfig.LdapEnabled.IsTrue() {
return nil, fmt.Errorf("LDAP is not enabled") return nil, fmt.Errorf("LDAP is not enabled")
} }
// Setup LDAP connection // Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value client, err := ldap.DialURL(dbConfig.LdapUrl.Value, ldap.DialWithTLSConfig(&tls.Config{
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue() InsecureSkipVerify: dbConfig.LdapSkipCertVerify.IsTrue(), //nolint:gosec
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec }))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err) return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
} }
// Bind as service account // Bind as service account
bindDn := s.appConfigService.DbConfig.LdapBindDn.Value err = client.Bind(dbConfig.LdapBindDn.Value, dbConfig.LdapBindPassword.Value)
bindPassword := s.appConfigService.DbConfig.LdapBindPassword.Value
err = client.Bind(bindDn, bindPassword)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to bind to LDAP: %w", err) return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
} }
return client, nil return client, nil
} }
func (s *LdapService) SyncAll() error { func (s *LdapService) SyncAll(ctx context.Context) error {
err := s.SyncUsers() // Start a transaction
if err != nil { tx := s.db.Begin()
return fmt.Errorf("failed to sync users: %w", err) defer func() {
} tx.Rollback()
}()
err = s.SyncGroups()
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
// Setup LDAP connection // Setup LDAP connection
client, err := s.createClient() client, err := s.createClient()
if err != nil { if err != nil {
@@ -76,262 +74,373 @@ func (s *LdapService) SyncGroups() error {
} }
defer client.Close() defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value err = s.SyncUsers(ctx, tx, client)
nameAttribute := s.appConfigService.DbConfig.LdapAttributeGroupName.Value if err != nil {
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeGroupUniqueIdentifier.Value return fmt.Errorf("failed to sync users: %w", err)
groupMemberOfAttribute := s.appConfigService.DbConfig.LdapAttributeGroupMember.Value
filter := s.appConfigService.DbConfig.LdapUserGroupSearchFilter.Value
searchAttrs := []string{
nameAttribute,
uniqueIdentifierAttribute,
groupMemberOfAttribute,
} }
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) err = s.SyncGroups(ctx, tx, client)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
dbConfig := s.appConfigService.GetDbConfig()
searchAttrs := []string{
dbConfig.LdapAttributeGroupName.Value,
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
dbConfig.LdapAttributeGroupMember.Value,
}
searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserGroupSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
return fmt.Errorf("failed to query LDAP: %w", err) return fmt.Errorf("failed to query LDAP: %w", err)
} }
// Create a mapping for groups that exist // Create a mapping for groups that exist
ldapGroupIDs := make(map[string]bool) ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries { for _, value := range result.Entries {
var membersUserId []string ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute)
// Skip groups without a valid LDAP ID // Skip groups without a valid LDAP ID
if ldapId == "" { if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute) log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
continue continue
} }
ldapGroupIDs[ldapId] = true ldapGroupIDs[ldapId] = struct{}{}
// Try to find the group in the database // Try to find the group in the database
var databaseGroup model.UserGroup var databaseGroup model.UserGroup
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup) err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseGroup).
Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
}
// Get group members and add to the correct Group // Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute) groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
membersUserId := make([]string, 0, len(groupMembers))
for _, member := range groupMembers { for _, member := range groupMembers {
// Normal output of this would be CN=username,ou=people,dc=example,dc=com ldapId := getDNProperty("uid", member)
// Splitting at the "=" and "," then just grabbing the username for that string if ldapId == "" {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0] continue
}
var databaseUser model.User var databaseUser model.User
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error err = tx.
if err != nil { WithContext(ctx).
Where("username = ? AND ldap_id IS NOT NULL", ldapId).
First(&databaseUser).
Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it // The user collides with a non-LDAP user, so we skip it
continue continue
} else { } else if err != nil {
return err return fmt.Errorf("failed to query for existing user '%s': %w", ldapId, err)
}
} }
membersUserId = append(membersUserId, databaseUser.ID) membersUserId = append(membersUserId, databaseUser.ID)
} }
syncGroup := dto.UserGroupCreateDto{ syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(nameAttribute), Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(nameAttribute), FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute), LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
} }
if databaseGroup.ID == "" { if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup) newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil { if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err) return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err
} }
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
} else {
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
if err != nil {
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
} }
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
}
} }
// Get all LDAP groups from the database // Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { err = tx.
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err)) WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
return fmt.Errorf("failed to fetch groups from database: %w", err)
} }
// Delete groups that no longer exist in LDAP // Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb { for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists { if _, exists := ldapGroupIDs[*group.LdapID]; exists {
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil { continue
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
} }
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
} }
log.Printf("Deleted group '%s'", group.Name)
} }
return nil return nil
} }
//nolint:gocognit //nolint:gocognit
func (s *LdapService) SyncUsers() error { func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
// Setup LDAP connection dbConfig := s.appConfigService.GetDbConfig()
client, err := s.createClient()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeUserUniqueIdentifier.Value
usernameAttribute := s.appConfigService.DbConfig.LdapAttributeUserUsername.Value
emailAttribute := s.appConfigService.DbConfig.LdapAttributeUserEmail.Value
firstNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserFirstName.Value
lastNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserLastName.Value
profilePictureAttribute := s.appConfigService.DbConfig.LdapAttributeUserProfilePicture.Value
adminGroupAttribute := s.appConfigService.DbConfig.LdapAttributeAdminGroup.Value
filter := s.appConfigService.DbConfig.LdapUserSearchFilter.Value
searchAttrs := []string{ searchAttrs := []string{
"memberOf", "memberOf",
"sn", "sn",
"cn", "cn",
uniqueIdentifierAttribute, dbConfig.LdapAttributeUserUniqueIdentifier.Value,
usernameAttribute, dbConfig.LdapAttributeUserUsername.Value,
emailAttribute, dbConfig.LdapAttributeUserEmail.Value,
firstNameAttribute, dbConfig.LdapAttributeUserFirstName.Value,
lastNameAttribute, dbConfig.LdapAttributeUserLastName.Value,
profilePictureAttribute, dbConfig.LdapAttributeUserProfilePicture.Value,
} }
// Filters must start and finish with ()! // Filters must start and finish with ()!
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
fmt.Println(fmt.Errorf("failed to query LDAP: %w", err)) return fmt.Errorf("failed to query LDAP: %w", err)
} }
// Create a mapping for users that exist // Create a mapping for users that exist
ldapUserIDs := make(map[string]bool) ldapUserIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries { for _, value := range result.Entries {
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute) ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
// Skip users without a valid LDAP ID // Skip users without a valid LDAP ID
if ldapId == "" { if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute) log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeUserUniqueIdentifier.Value)
continue continue
} }
ldapUserIDs[ldapId] = true ldapUserIDs[ldapId] = struct{}{}
// Get the user from the database // Get the user from the database
var databaseUser model.User var databaseUser model.User
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser) err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseUser).
Error
// If a user is found (even if disabled), enable them since they're now back in LDAP
if databaseUser.ID != "" && databaseUser.Disabled {
// Use the transaction instead of the direct context
err = tx.
WithContext(ctx).
Model(&model.User{}).
Where("id = ?", databaseUser.ID).
Update("disabled", false).
Error
if err != nil {
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
}
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
}
// Check if user is admin by checking if they are in the admin group // Check if user is admin by checking if they are in the admin group
isAdmin := false isAdmin := false
for _, group := range value.GetAttributeValues("memberOf") { for _, group := range value.GetAttributeValues("memberOf") {
if strings.Contains(group, adminGroupAttribute) { if getDNProperty("cn", group) == dbConfig.LdapAttributeAdminGroup.Value {
isAdmin = true isAdmin = true
break
} }
} }
newUser := dto.UserCreateDto{ newUser := dto.UserCreateDto{
Username: value.GetAttributeValue(usernameAttribute), Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
Email: value.GetAttributeValue(emailAttribute), Email: value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value),
FirstName: value.GetAttributeValue(firstNameAttribute), FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
LastName: value.GetAttributeValue(lastNameAttribute), LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
IsAdmin: isAdmin, IsAdmin: isAdmin,
LdapID: ldapId, LdapID: ldapId,
} }
if databaseUser.ID == "" { if databaseUser.ID == "" {
_, err = s.userService.CreateUser(newUser) _, err = s.userService.createUserInternal(ctx, newUser, true, tx)
if err != nil { if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Error syncing user %s: %s", newUser.Username, err) log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
} }
} else { } else {
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true) _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil { if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Error syncing user %s: %s", newUser.Username, err) log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
} }
} }
// Save profile picture // Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" { pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil { if pictureString != "" {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err) err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
if err != nil {
// This is not a fatal error
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
} }
} }
} }
// Get all LDAP users from the database // Get all LDAP users from the database
var ldapUsersInDb []model.User var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { err = tx.
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err)) WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("id, username, ldap_id, disabled").
Error
if err != nil {
return fmt.Errorf("failed to fetch users from database: %w", err)
} }
// Delete users that no longer exist in LDAP // Mark users as disabled or delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb { for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists { // Skip if the user ID exists in the fetched LDAP results
if err := s.userService.DeleteUser(user.ID, true); err != nil { if _, exists := ldapUserIDs[*user.LdapID]; exists {
log.Printf("Failed to delete user %s with: %v", user.Username, err) continue
}
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
err = s.userService.disableUserInternal(ctx, user.ID, tx)
if err != nil {
return fmt.Errorf("failed to disable user %s: %w", user.Username, err)
}
log.Printf("Disabled user '%s'", user.Username)
} else { } else {
log.Printf("Deleted user %s", user.Username) err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
} target := &common.LdapUserUpdateError{}
if errors.As(err, &target) {
return fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
} else if err != nil {
return fmt.Errorf("failed to delete user %s: %w", user.Username, err)
}
log.Printf("Deleted user '%s'", user.Username)
} }
} }
return nil return nil
} }
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error { func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.Reader var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil { _, err := url.ParseRequestURI(pictureString)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) if err == nil {
ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
defer cancel() defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil) var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to create request: %w", err) return fmt.Errorf("failed to create request: %w", err)
} }
response, err := http.DefaultClient.Do(req) var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err) return fmt.Errorf("failed to download profile picture: %w", err)
} }
defer response.Body.Close() defer res.Body.Close()
reader = response.Body
reader = res.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil { } else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it // If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto) reader = bytes.NewReader(decodedPhoto)
} else { } else {
// If the photo is a string, we assume that it's a binary string // If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString)) reader = bytes.NewReader([]byte(pictureString))
} }
// Update the profile picture // Update the profile picture
if err := s.userService.UpdateProfilePicture(userId, reader); err != nil { err = s.userService.UpdateProfilePicture(userId, reader)
if err != nil {
return fmt.Errorf("failed to update profile picture: %w", err) return fmt.Errorf("failed to update profile picture: %w", err)
} }
return nil return nil
} }
// getDNProperty returns the value of a property from a LDAP identifier
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
func getDNProperty(property string, str string) string {
// Example format is "CN=username,ou=people,dc=example,dc=com"
// First we split at the comma
property = strings.ToLower(property)
l := len(property) + 1
for _, v := range strings.Split(str, ",") {
v = strings.TrimSpace(v)
if len(v) > l && strings.ToLower(v)[0:l] == property+"=" {
return v[l:]
}
}
// CN not found, return an empty string
return ""
}

View File

@@ -0,0 +1,73 @@
package service
import (
"testing"
)
func TestGetDNProperty(t *testing.T) {
tests := []struct {
name string
property string
dn string
expectedResult string
}{
{
name: "simple case",
property: "cn",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "property not found",
property: "uid",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
{
name: "mixed case property",
property: "CN",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "mixed case DN",
property: "cn",
dn: "CN=username,OU=people,DC=example,DC=com",
expectedResult: "username",
},
{
name: "spaces in DN",
property: "cn",
dn: "cn=username, ou=people, dc=example, dc=com",
expectedResult: "username",
},
{
name: "value with special characters",
property: "cn",
dn: "cn=user.name+123,ou=people,dc=example,dc=com",
expectedResult: "user.name+123",
},
{
name: "empty DN",
property: "cn",
dn: "",
expectedResult: "",
},
{
name: "empty property",
property: "",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getDNProperty(tt.property, tt.dn)
if result != tt.expectedResult {
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
tt.property, tt.dn, result, tt.expectedResult)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,15 @@
package service package service
import ( import (
"context"
"errors" "errors"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type UserGroupService struct { type UserGroupService struct {
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
return &UserGroupService{db: db, appConfigService: appConfigService} return &UserGroupService{db: db, appConfigService: appConfigService}
} }
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) { func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{}) query := s.db.
WithContext(ctx).
Preload("CustomClaims").
Model(&model.UserGroup{})
if name != "" { if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%") query = query.Where("name LIKE ?", "%"+name+"%")
@@ -42,26 +47,58 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
return groups, response, err return groups, response, err
} }
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) { func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error return s.getInternal(ctx, id, s.db)
}
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
err = tx.
WithContext(ctx).
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
First(&group).
Error
return group, err return group, err
} }
func (s *UserGroupService) Delete(id string) error { func (s *UserGroupService) Delete(ctx context.Context, id string) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ?", id).
First(&group).
Error
if err != nil {
return err return err
} }
// Disallow deleting the group if it is an LDAP group and LDAP is enabled // Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserGroupUpdateError{} return &common.LdapUserGroupUpdateError{}
} }
return s.db.Delete(&group).Error err = tx.
WithContext(ctx).
Delete(&group).
Error
if err != nil {
return err
} }
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) { return tx.Commit().Error
}
func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return s.createInternal(ctx, input, s.db)
}
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
group = model.UserGroup{ group = model.UserGroup{
FriendlyName: input.FriendlyName, FriendlyName: input.FriendlyName,
Name: input.Name, Name: input.Name,
@@ -71,7 +108,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
group.LdapID = &input.LdapID group.LdapID = &input.LdapID
} }
if err := s.db.Preload("Users").Create(&group).Error; err != nil { err = tx.
WithContext(ctx).
Preload("Users").
Create(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
} }
@@ -80,31 +122,73 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
return group, nil return group, nil
} }
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) { func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
group, err = s.Get(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.updateInternal(ctx, id, input, false, tx)
if err != nil {
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, isLdapSync bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
// Disallow updating the group if it is an LDAP group and LDAP is enabled // Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !isLdapSync && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{} return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
} }
group.Name = input.Name group.Name = input.Name
group.FriendlyName = input.FriendlyName group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil { err = tx.
WithContext(ctx).
Preload("Users").
Save(&group).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
} } else if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
return group, nil return group, nil
} }
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) { func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
if err != nil {
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
// Fetch the users based on the userIds // Fetch the users based on the userIds
var users []model.User var users []model.User
if len(userIds) > 0 { if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil { err := tx.
WithContext(ctx).
Where("id IN (?)", userIds).
Find(&users).
Error
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
} }
// Replace the current users with the new set of users // Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil { err = tx.
WithContext(ctx).
Model(&group).
Association("Users").
Replace(users)
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
// Save the updated group // Save the updated group
if err := s.db.Save(&group).Error; err != nil { err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
return group, nil return group, nil
} }
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) { func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
// We only perform select queries here, so we can rollback in all cases
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil { err := tx.
WithContext(ctx).
Preload("Users").
Where("id = ?", id).
First(&group).
Error
if err != nil {
return 0, err return 0, err
} }
return s.db.Model(&group).Association("Users").Count(), nil count := tx.
WithContext(ctx).
Model(&group).
Association("Users").
Count()
return count, nil
} }

View File

@@ -1,6 +1,8 @@
package service package service
import ( import (
"bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,7 +13,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image" "gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -19,7 +21,7 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email" "github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm" profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
) )
type UserService struct { type UserService struct {
@@ -34,59 +36,110 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService} return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
} }
func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) { func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
var users []model.User var users []model.User
query := s.db.Model(&model.User{}) query := s.db.WithContext(ctx).
Model(&model.User{}).
Preload("UserGroups").
Preload("CustomClaims")
if searchTerm != "" { if searchTerm != "" {
searchPattern := "%" + searchTerm + "%" searchPattern := "%" + searchTerm + "%"
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern) query = query.Where("email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
searchPattern, searchPattern, searchPattern, searchPattern)
} }
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users)
return users, pagination, err return users, pagination, err
} }
func (s *UserService) GetUser(userID string) (model.User, error) { func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
return s.getUserInternal(ctx, userID, s.db)
}
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
var user model.User var user model.User
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error err := tx.
WithContext(ctx).
Preload("UserGroups").
Preload("CustomClaims").
Where("id = ?", userID).
First(&user).
Error
return user, err return user, err
} }
func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) { func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
// Validate the user ID to prevent directory traversal // Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil { if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{} return nil, 0, &common.InvalidUUIDError{}
} }
// First check for a custom uploaded profile picture (userID.png)
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
file, err := os.Open(profilePicturePath) file, err := os.Open(profilePicturePath)
if err == nil { if err == nil {
// Get the file size // Get the file size
fileInfo, err := file.Stat() fileInfo, err := file.Stat()
if err != nil { if err != nil {
file.Close()
return nil, 0, err return nil, 0, err
} }
return file, fileInfo.Size(), nil return file, fileInfo.Size(), nil
} }
// If the file does not exist, return the default profile picture // If no custom picture exists, get the user's data for creating initials
user, err := s.GetUser(userID) user, err := s.GetUser(ctx, userID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.FirstName, user.LastName) // Check if we have a cached default picture for these initials
defaultProfilePicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults/"
defaultPicturePath := defaultProfilePicturesDir + user.Initials() + ".png"
file, err = os.Open(defaultPicturePath)
if err == nil {
fileInfo, err := file.Stat()
if err != nil {
file.Close()
return nil, 0, err
}
return file, fileInfo.Size(), nil
}
// If no cached default picture exists, create one and save it for future use
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.Initials())
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return defaultPicture, int64(defaultPicture.Len()), nil // Save the default picture for future use (in a goroutine to avoid blocking)
defaultPictureBytes := defaultPicture.Bytes()
go func() {
// Ensure the directory exists
err = os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
if err != nil {
log.Printf("Failed to create directory for default profile picture: %v", err)
return
}
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
log.Printf("Failed to cache default profile picture for initials %s: %v", user.Initials(), err)
}
}()
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
} }
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) { func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
var user model.User var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil { err := s.db.
WithContext(ctx).
Preload("UserGroups").
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return nil, err return nil, err
} }
return user.UserGroups, nil return user.UserGroups, nil
@@ -121,27 +174,64 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
return nil return nil
} }
func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error { func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
var user model.User return s.db.Transaction(func(tx *gorm.DB) error {
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
return err })
} }
// Disallow deleting the user if it is an LDAP user and LDAP is enabled func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { var user model.User
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return fmt.Errorf("failed to load user to delete: %w", err)
}
// Disallow deleting the user if it is an LDAP user, LDAP is enabled, and the user is not disabled
if !allowLdapDelete && !user.Disabled && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{} return &common.LdapUserUpdateError{}
} }
// Delete the profile picture // Delete the profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png" profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) { err = os.Remove(profilePicturePath)
if err != nil && !os.IsNotExist(err) {
return err return err
} }
return s.db.Delete(&user).Error err = tx.WithContext(ctx).Delete(&user).Error
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
} }
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) { return nil
}
func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.createUserInternal(ctx, input, false, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
user := model.User{ user := model.User{
FirstName: input.FirstName, FirstName: input.FirstName,
LastName: input.LastName, LastName: input.LastName,
@@ -154,23 +244,56 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
user.LdapID = &input.LdapID user.LdapID = &input.LdapID
} }
if err := s.db.Create(&user).Error; err != nil { err := tx.WithContext(ctx).Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.User{}, s.checkDuplicatedFields(user) // Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
} }
return model.User{}, err
} else if err != nil {
return model.User{}, err return model.User{}, err
} }
return user, nil return user, nil
} }
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) { func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
var user model.User var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
// Disallow updating the user if it is an LDAP group and LDAP is enabled // Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() { if !isLdapSync && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{} return model.User{}, &common.LdapUserUpdateError{}
} }
@@ -181,26 +304,49 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
user.Locale = updatedUser.Locale user.Locale = updatedUser.Locale
if !updateOwnUser { if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin user.IsAdmin = updatedUser.IsAdmin
user.Disabled = updatedUser.Disabled
} }
if err := s.db.Save(&user).Error; err != nil { err = tx.
WithContext(ctx).
Save(&user).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return user, s.checkDuplicatedFields(user) // Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
} }
return user, err
} else if err != nil {
return user, err return user, err
} }
return user, nil return user, nil
} }
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, expiration time.Time) error {
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue() isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }
var user model.User return s.requestOneTimeAccessEmailInternal(ctx, userID, "", expiration)
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil { }
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error {
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
if err != nil {
// Do not return error if user not found to prevent email enumeration // Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil return nil
@@ -209,41 +355,69 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
} }
} }
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute)) expiration := time.Now().Add(15 * time.Minute)
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, expiration)
}
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, expiration time.Time) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.GetUser(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL) oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, expiration, tx)
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken) if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err
}
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {
innerCtx := context.Background()
link := common.EnvConfig.AppURL + "/lc"
linkWithCode := link + "/" + oneTimeAccessToken
// Add redirect path to the link // Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") { if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath) encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath) linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
} }
go func() { errInternal := SendEmail(innerCtx, s.emailService, email.Address{
err := SendEmail(s.emailService, email.Address{ Name: user.FullName(),
Name: user.Username,
Email: user.Email, Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{ }, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
Code: oneTimeAccessToken, Code: oneTimeAccessToken,
LoginLink: link, LoginLink: link,
LoginLinkWithCode: linkWithCode, LoginLinkWithCode: linkWithCode,
ExpirationString: utils.DurationToString(time.Until(expiration).Round(time.Second)),
}) })
if err != nil { if errInternal != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err) log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
} }
}() }()
return nil return nil
} }
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) { func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) {
tokenLength := 16 return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db)
}
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) {
// If expires at is less than 15 minutes, use an 6 character token instead of 16 // If expires at is less than 15 minutes, use an 6 character token instead of 16
tokenLength := 16
if time.Until(expiresAt) <= 15*time.Minute { if time.Until(expiresAt) <= 15*time.Minute {
tokenLength = 6 tokenLength = 6
} }
@@ -259,16 +433,26 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
Token: randomString, Token: randomString,
} }
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil { if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil {
return "", err return "", err
} }
return oneTimeAccessToken.Token, nil return oneTimeAccessToken.Token, nil
} }
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) { func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var oneTimeAccessToken model.OneTimeAccessToken var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { err := tx.
WithContext(ctx).
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
First(&oneTimeAccessToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", &common.TokenInvalidOrExpiredError{} return model.User{}, "", &common.TokenInvalidOrExpiredError{}
} }
@@ -279,19 +463,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return model.User{}, "", err return model.User{}, "", err
} }
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil { err = tx.
WithContext(ctx).
Delete(&oneTimeAccessToken).
Error
if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
if ipAddress != "" && userAgent != "" { if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}) s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
} }
return oneTimeAccessToken.User, accessToken, nil return oneTimeAccessToken.User, accessToken, nil
} }
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) { func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err = s.getUserInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.User{}, err return model.User{}, err
} }
@@ -299,27 +497,48 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m
// Fetch the groups based on userGroupIds // Fetch the groups based on userGroupIds
var groups []model.UserGroup var groups []model.UserGroup
if len(userGroupIds) > 0 { if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil { err = tx.
WithContext(ctx).
Where("id IN (?)", userGroupIds).
Find(&groups).
Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
} }
// Replace the current groups with the new set of groups // Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil { err = tx.
WithContext(ctx).
Model(&user).
Association("UserGroups").
Replace(groups)
if err != nil {
return model.User{}, err return model.User{}, err
} }
// Save the updated user // Save the updated user
if err := s.db.Save(&user).Error; err != nil { err = tx.WithContext(ctx).Save(&user).Error
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
return user, nil return user, nil
} }
func (s *UserService) SetupInitialAdmin() (model.User, string, error) { func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var userCount int64 var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil { if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
if userCount > 1 { if userCount > 1 {
@@ -334,7 +553,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
IsAdmin: true, IsAdmin: true,
} }
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
@@ -347,16 +566,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
return model.User{}, "", err return model.User{}, "", err
} }
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return user, token, nil return user, token, nil
} }
func (s *UserService) checkDuplicatedFields(user model.User) error { func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
var existingUser model.User var result struct {
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { Found bool
}
err := tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "email"} return &common.AlreadyInUseError{Property: "email"}
} }
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { err = tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "username"} return &common.AlreadyInUseError{Property: "username"}
} }
@@ -386,3 +628,12 @@ func (s *UserService) ResetProfilePicture(userID string) error {
return nil return nil
} }
func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error {
return tx.
WithContext(ctx).
Model(&model.User{}).
Where("id = ?", userID).
Update("disabled", true).
Error
}

View File

@@ -1,16 +1,19 @@
package service package service
import ( import (
"context"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type WebAuthnService struct { type WebAuthnService struct {
@@ -23,7 +26,7 @@ type WebAuthnService struct {
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService { func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{ webauthnConfig := &webauthn.Config{
RPDisplayName: appConfigService.DbConfig.AppName.Value, RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL), RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL}, RPOrigins: []string{common.EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{ Timeouts: webauthn.TimeoutsConfig{
@@ -40,18 +43,39 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
}, },
} }
wa, _ := webauthn.New(webauthnConfig) wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService} return &WebAuthnService{
db: db,
webAuthn: wa,
jwtService: jwtService,
auditLogService: auditLogService,
appConfigService: appConfigService,
}
} }
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) { func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
s.updateWebAuthnConfig() s.updateWebAuthnConfig()
var user model.User var user model.User
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil { err := tx.
WithContext(ctx).
Preload("Credentials").
Find(&user, "id = ?", userID).
Error
if err != nil {
tx.Rollback()
return nil, err return nil, err
} }
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())) options, session, err := s.webAuthn.BeginRegistration(
&user,
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -62,7 +86,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
UserVerification: string(session.UserVerification), UserVerification: string(session.UserVerification),
} }
if err := s.db.Create(&sessionToStore).Error; err != nil { err = tx.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err return nil, err
} }
@@ -73,9 +106,18 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil }, nil
} }
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) { func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
@@ -86,7 +128,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
} }
var user model.User var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { err = tx.
WithContext(ctx).
Find(&user, "id = ?", userID).
Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
@@ -108,7 +154,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupEligible: credential.Flags.BackupEligible, BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState, BackupState: credential.Flags.BackupState,
} }
if err := s.db.Create(&credentialToStore).Error; err != nil { err = tx.
WithContext(ctx).
Create(&credentialToStore).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
@@ -125,7 +180,7 @@ func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
return "New Passkey" // Default fallback return "New Passkey" // Default fallback
} }
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin() options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -137,7 +192,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
UserVerification: string(session.UserVerification), UserVerification: string(session.UserVerification),
} }
if err := s.db.Create(&sessionToStore).Error; err != nil { err = s.db.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err return nil, err
} }
@@ -148,9 +207,18 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil }, nil
} }
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) { func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
@@ -160,9 +228,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
} }
var user *model.User var user *model.User
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { _, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil { innerErr := tx.
return nil, err WithContext(ctx).
Preload("Credentials").
First(&user, "id = ?", string(userHandle)).
Error
if innerErr != nil {
return nil, innerErr
} }
return user, nil return user, nil
}, session, credentialAssertionData) }, session, credentialAssertionData)
@@ -171,46 +244,78 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
return model.User{}, "", err return model.User{}, "", err
} }
if user.Disabled {
return model.User{}, "", &common.UserDisabledError{}
}
token, err := s.jwtService.GenerateAccessToken(*user) token, err := s.jwtService.GenerateAccessToken(*user)
if err != nil { if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID) s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return *user, token, nil return *user, token, nil
} }
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
var credentials []model.WebauthnCredential var credentials []model.WebauthnCredential
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil { err := s.db.
WithContext(ctx).
Find(&credentials, "user_id = ?", userID).
Error
if err != nil {
return nil, err return nil, err
} }
return credentials, nil return credentials, nil
} }
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error { func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
var credential model.WebauthnCredential err := s.db.
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil { WithContext(ctx).
return err Where("id = ? AND user_id = ?", credentialID, userID).
} Delete(&model.WebauthnCredential{}).
Error
if err := s.db.Delete(&credential).Error; err != nil { if err != nil {
return err return fmt.Errorf("failed to delete record: %w", err)
} }
return nil return nil
} }
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) { func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var credential model.WebauthnCredential var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
First(&credential).
Error
if err != nil {
return credential, err return credential, err
} }
credential.Name = name credential.Name = name
if err := s.db.Save(&credential).Error; err != nil { err = tx.
WithContext(ctx).
Save(&credential).
Error
if err != nil {
return credential, err
}
err = tx.Commit().Error
if err != nil {
return credential, err return credential, err
} }
@@ -219,5 +324,5 @@ func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (m
// updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime // updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime
func (s *WebAuthnService) updateWebAuthnConfig() { func (s *WebAuthnService) updateWebAuthnConfig() {
s.webAuthn.Config.RPDisplayName = s.appConfigService.DbConfig.AppName.Value s.webAuthn.Config.RPDisplayName = s.appConfigService.GetDbConfig().AppName.Value
} }

View File

@@ -12,9 +12,13 @@ import (
var ( var (
aaguidMap map[string]string aaguidMap map[string]string
aaguidMapOnce sync.Once aaguidMapOnce *sync.Once
) )
func init() {
aaguidMapOnce = &sync.Once{}
}
// FormatAAGUID converts an AAGUID byte slice to UUID string format // FormatAAGUID converts an AAGUID byte slice to UUID string format
func FormatAAGUID(aaguid []byte) string { func FormatAAGUID(aaguid []byte) string {
if len(aaguid) == 0 { if len(aaguid) == 0 {

View File

@@ -47,8 +47,10 @@ func TestFormatAAGUID(t *testing.T) {
func TestGetAuthenticatorName(t *testing.T) { func TestGetAuthenticatorName(t *testing.T) {
// Reset the aaguidMap for testing // Reset the aaguidMap for testing
originalMap := aaguidMap originalMap := aaguidMap
originalOnce := aaguidMapOnce
defer func() { defer func() {
aaguidMap = originalMap aaguidMap = originalMap
aaguidMapOnce = originalOnce
}() }()
// Inject a test AAGUID map // Inject a test AAGUID map
@@ -56,7 +58,7 @@ func TestGetAuthenticatorName(t *testing.T) {
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator", "adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
"00000000-0000-0000-0000-000000000000": "Zero Authenticator", "00000000-0000-0000-0000-000000000000": "Zero Authenticator",
} }
aaguidMapOnce = sync.Once{} aaguidMapOnce = &sync.Once{}
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
tests := []struct { tests := []struct {
@@ -99,7 +101,7 @@ func TestGetAuthenticatorName(t *testing.T) {
func TestLoadAAGUIDsFromFile(t *testing.T) { func TestLoadAAGUIDsFromFile(t *testing.T) {
// Reset the map and once flag for clean testing // Reset the map and once flag for clean testing
aaguidMap = nil aaguidMap = nil
aaguidMapOnce = sync.Once{} aaguidMapOnce = &sync.Once{}
// Trigger loading of AAGUIDs by calling GetAuthenticatorName // Trigger loading of AAGUIDs by calling GetAuthenticatorName
GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04}) GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04})

View File

@@ -0,0 +1,52 @@
package utils
import (
"fmt"
"time"
)
// DurationToString converts a time.Duration to a human-readable string. Respects minutes, hours and days.
func DurationToString(duration time.Duration) string {
// For a duration less than a day
if duration < 24*time.Hour {
hours := int(duration.Hours())
mins := int(duration.Minutes()) % 60
switch hours {
case 0:
return fmt.Sprintf("%d minutes", mins)
case 1:
if mins == 0 {
return "1 hour"
}
return fmt.Sprintf("1 hour and %d minutes", mins)
default:
if mins == 0 {
return fmt.Sprintf("%d hours", hours)
}
return fmt.Sprintf("%d hours and %d minutes", hours, mins)
}
} else {
// For durations of a day or more
days := int(duration.Hours() / 24)
hours := int(duration.Hours()) % 24
switch hours {
case 0:
if days == 1 {
return "1 day"
}
return fmt.Sprintf("%d days", days)
case 1:
if days == 1 {
return "1 day and 1 hour"
}
return fmt.Sprintf("%d days and 1 hour", days)
default:
if days == 1 {
return fmt.Sprintf("1 day and %d hours", hours)
}
return fmt.Sprintf("%d days and %d hours", days, hours)
}
}
}

View File

@@ -3,14 +3,12 @@ package utils
import ( import (
"errors" "errors"
"fmt" "fmt"
"hash/crc64"
"io" "io"
"mime/multipart" "mime/multipart"
"os" "os"
"path/filepath" "path/filepath"
"strconv"
"time"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/resources" "github.com/pocket-id/pocket-id/backend/resources"
) )
@@ -32,6 +30,8 @@ func GetImageMimeType(ext string) string {
return "image/svg+xml" return "image/svg+xml"
case "ico": case "ico":
return "image/x-icon" return "image/x-icon"
case "gif":
return "image/gif"
default: default:
return "" return ""
} }
@@ -80,22 +80,7 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
// SaveFileStream saves a stream to a file. // SaveFileStream saves a stream to a file.
func SaveFileStream(r io.Reader, dstFileName string) error { func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file // Our strategy is to save to a separate file and then rename it to override the original file
// First, get a temp file name that doesn't exist already tmpFileName := dstFileName + "." + uuid.NewString() + "-tmp"
var tmpFileName string
var i int64
for {
seed := strconv.FormatInt(time.Now().UnixNano()+i, 10)
suffix := crc64.Checksum([]byte(dstFileName+seed), crc64.MakeTable(crc64.ISO))
tmpFileName = dstFileName + "." + strconv.FormatUint(suffix, 10)
exists, err := FileExists(tmpFileName)
if err != nil {
return fmt.Errorf("failed to check if file '%s' exists: %w", tmpFileName, err)
}
if !exists {
break
}
i++
}
// Write to the temporary file // Write to the temporary file
tmpFile, err := os.Create(tmpFileName) tmpFile, err := os.Create(tmpFileName)

View File

@@ -6,7 +6,6 @@ import (
"image" "image"
"image/color" "image/color"
"io" "io"
"strings"
"github.com/disintegration/imageorient" "github.com/disintegration/imageorient"
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
@@ -42,17 +41,7 @@ func CreateProfilePicture(file io.Reader) (io.Reader, error) {
} }
// CreateDefaultProfilePicture creates a profile picture with the initials // CreateDefaultProfilePicture creates a profile picture with the initials
func CreateDefaultProfilePicture(firstName, lastName string) (*bytes.Buffer, error) { func CreateDefaultProfilePicture(initials string) (*bytes.Buffer, error) {
// Get the initials
initials := ""
if len(firstName) > 0 {
initials += string(firstName[0])
}
if len(lastName) > 0 {
initials += string(lastName[0])
}
initials = strings.ToUpper(initials)
// Create a blank image with a white background // Create a blank image with a white background
img := imaging.New(profilePictureSize, profilePictureSize, color.RGBA{R: 255, G: 255, B: 255, A: 255}) img := imaging.New(profilePictureSize, profilePictureSize, color.RGBA{R: 255, G: 255, B: 255, A: 255})

View File

@@ -5,6 +5,7 @@ import (
"strconv" "strconv"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type PaginationResponse struct { type PaginationResponse struct {
@@ -36,11 +37,15 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc" isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
if sortFieldFound && isSortable && isValidSortOrder { if sortFieldFound && isSortable && isValidSortOrder {
query = query.Order(CamelCaseToSnakeCase(sort.Column) + " " + sort.Direction) columnName := CamelCaseToSnakeCase(sort.Column)
query = query.Clauses(clause.OrderBy{
Columns: []clause.OrderByColumn{
{Column: clause.Column{Name: columnName}, Desc: sort.Direction == "desc"},
},
})
} }
return Paginate(pagination.Page, pagination.Limit, query, result) return Paginate(pagination.Page, pagination.Limit, query, result)
} }
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) { func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {

View File

@@ -0,0 +1,5 @@
package utils
func Ptr[T any](v T) *T {
return &v
}

View File

@@ -0,0 +1,40 @@
package signals
import (
"context"
"log"
"os"
"os/signal"
"syscall"
)
/*
This code is adapted from:
https://github.com/kubernetes-sigs/controller-runtime/blob/8499b67e316a03b260c73f92d0380de8cd2e97a1/pkg/manager/signals/signal.go
Copyright 2017 The Kubernetes Authors.
License: Apache2 (https://github.com/kubernetes-sigs/controller-runtime/blob/8499b67e316a03b260c73f92d0380de8cd2e97a1/LICENSE)
*/
var onlyOneSignalHandler = make(chan struct{})
// SignalContext returns a context that is canceled when the application receives an interrupt signal.
// A second signal forces an immediate shutdown.
func SignalContext(parentCtx context.Context) context.Context {
close(onlyOneSignalHandler) // Panics when called twice
ctx, cancel := context.WithCancel(parentCtx)
sigCh := make(chan os.Signal, 2)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
log.Println("Received interrupt signal. Shutting down…")
cancel()
<-sigCh
log.Println("Received a second interrupt signal. Forcing an immediate shutdown.")
os.Exit(1)
}()
return ctx
}

View File

@@ -102,3 +102,16 @@ func CamelCaseToScreamingSnakeCase(s string) string {
// Convert to uppercase // Convert to uppercase
return strings.ToUpper(snake) return strings.ToUpper(snake)
} }
// GetFirstCharacter returns the first non-whitespace character of the string, correctly handling Unicode
func GetFirstCharacter(str string) string {
for _, c := range str {
if unicode.IsSpace(c) {
continue
}
return string(c)
}
// Empty string case
return ""
}

View File

@@ -103,3 +103,28 @@ func TestCamelCaseToSnakeCase(t *testing.T) {
}) })
} }
} }
func TestGetFirstCharacter(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", ""},
{"single character", "a", "a"},
{"multiple characters", "hello", "h"},
{"unicode character", "étoile", "é"},
{"special character", "!test", "!"},
{"number as first character", "123abc", "1"},
{"whitespace as first character", " hello", "h"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFirstCharacter(tt.input)
if result != tt.expected {
t.Errorf("GetFirstCharacter(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}

Binary file not shown.

View File

@@ -0,0 +1,17 @@
{{ define "base" }}
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
<h1>{{ .AppName }}</h1>
</div>
<div class="warning">Warning</div>
</div>
<div class="content">
<h2>API Key Expiring Soon</h2>
<p>
Hello {{ .Data.Name }},<br/><br/>
This is a reminder that your API key <strong>{{ .Data.ApiKeyName }}</strong> will expire on <strong>{{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}</strong>.<br/><br/>
Please generate a new API key if you need continued access.
</p>
</div>
{{ end }}

View File

@@ -0,0 +1,10 @@
{{ define "base" -}}
API Key Expiring Soon
====================
Hello {{ .Data.Name }},
This is a reminder that your API key "{{ .Data.ApiKeyName }}" will expire on {{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}.
Please generate a new API key if you need continued access.
{{ end -}}

View File

@@ -8,7 +8,7 @@
<div class="content"> <div class="content">
<h2>Login Code</h2> <h2>Login Code</h2>
<p class="message"> <p class="message">
Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in 15 minutes. Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in {{.Data.ExpirationString}}.
</p> </p>
<div class="button-container"> <div class="button-container">
<a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a> <a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a>

View File

@@ -2,7 +2,7 @@
Login Code Login Code
==================== ====================
Click the link below to sign in to {{ .AppName }} with a login code. This code expires in 15 minutes. Click the link below to sign in to {{ .AppName }} with a login code. This code expires in {{.Data.ExpirationString}}.
{{ .Data.LoginLinkWithCode }} {{ .Data.LoginLinkWithCode }}

View File

@@ -0,0 +1,3 @@
DROP INDEX IF EXISTS idx_audit_logs_event;
DROP INDEX IF EXISTS idx_audit_logs_user_id;
DROP INDEX IF EXISTS idx_audit_logs_client_name;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_client_name ON audit_logs(("data"->>'clientName'));

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables ADD type VARCHAR(20) NOT NULL,;
ALTER TABLE app_config_variables ADD is_public BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD is_internal BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD default_value TEXT;

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables DROP type;
ALTER TABLE app_config_variables DROP is_public;
ALTER TABLE app_config_variables DROP is_internal;
ALTER TABLE app_config_variables DROP default_value;

View File

@@ -0,0 +1,4 @@
DROP INDEX idx_users_disabled;
ALTER TABLE users
DROP COLUMN disabled;

View File

@@ -0,0 +1,2 @@
ALTER TABLE users
ADD COLUMN IF NOT EXISTS disabled BOOLEAN DEFAULT FALSE NOT NULL;

View File

@@ -0,0 +1,2 @@
ALTER TABLE api_keys
DROP COLUMN IF EXISTS expiration_email_sent;

View File

@@ -0,0 +1,2 @@
ALTER TABLE api_keys
ADD COLUMN expiration_email_sent BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -0,0 +1 @@
DROP TABLE oidc_device_codes;

View File

@@ -0,0 +1,12 @@
CREATE TABLE oidc_device_codes
(
id UUID NOT NULL PRIMARY KEY,
created_at TIMESTAMPTZ,
device_code TEXT NOT NULL UNIQUE,
user_code TEXT NOT NULL UNIQUE,
scope TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
is_authorized BOOLEAN NOT NULL DEFAULT FALSE,
user_id UUID REFERENCES users ON DELETE CASCADE,
client_id UUID NOT NULL REFERENCES oidc_clients ON DELETE CASCADE
);

View File

@@ -0,0 +1,3 @@
DROP INDEX IF EXISTS idx_audit_logs_event;
DROP INDEX IF EXISTS idx_audit_logs_user_id;
DROP INDEX IF EXISTS idx_audit_logs_client_name;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_audit_logs_event ON audit_logs(event);
CREATE INDEX idx_audit_logs_user_id ON audit_logs(user_id);
CREATE INDEX idx_audit_logs_client_name ON audit_logs((json_extract(data, '$.clientName')));

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables ADD type VARCHAR(20) NOT NULL,;
ALTER TABLE app_config_variables ADD is_public BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD is_internal BOOLEAN DEFAULT FALSE NOT NULL,;
ALTER TABLE app_config_variables ADD default_value TEXT;

View File

@@ -0,0 +1,4 @@
ALTER TABLE app_config_variables DROP type;
ALTER TABLE app_config_variables DROP is_public;
ALTER TABLE app_config_variables DROP is_internal;
ALTER TABLE app_config_variables DROP default_value;

View File

@@ -0,0 +1,4 @@
DROP INDEX idx_users_disabled;
ALTER TABLE users
DROP COLUMN disabled;

View File

@@ -0,0 +1,2 @@
ALTER TABLE users
ADD COLUMN disabled NUMERIC DEFAULT FALSE NOT NULL;

View File

@@ -0,0 +1,2 @@
ALTER TABLE api_keys
DROP COLUMN expiration_email_sent;

View File

@@ -0,0 +1,2 @@
ALTER TABLE api_keys
ADD COLUMN expiration_email_sent BOOLEAN NOT NULL DEFAULT 0;

View File

@@ -0,0 +1 @@
DROP TABLE oidc_device_codes;

View File

@@ -0,0 +1,12 @@
CREATE TABLE oidc_device_codes
(
id TEXT NOT NULL PRIMARY KEY,
created_at DATETIME,
device_code TEXT NOT NULL UNIQUE,
user_code TEXT NOT NULL UNIQUE,
scope TEXT NOT NULL,
expires_at DATETIME NOT NULL,
is_authorized BOOLEAN NOT NULL DEFAULT FALSE,
user_id TEXT REFERENCES users ON DELETE CASCADE,
client_id TEXT NOT NULL REFERENCES oidc_clients ON DELETE CASCADE
);

View File

@@ -2,3 +2,7 @@
package-lock.json package-lock.json
pnpm-lock.yaml pnpm-lock.yaml
yarn.lock yarn.lock
# Compiled files
.svelte-kit/
build/

View File

@@ -26,7 +26,7 @@
"login_background": "Pozadí přihlašovací stránky", "login_background": "Pozadí přihlašovací stránky",
"logo": "Logo", "logo": "Logo",
"login_code": "Přihlašovací kód", "login_code": "Přihlašovací kód",
"create_a_login_code_to_sign_in_without_a_passkey_once": "Vytvořte přihlašovací kód, která může uživatel jedenktrát použít pro přihlášení bez přístupového klíče.", "create_a_login_code_to_sign_in_without_a_passkey_once": "Vytvořte přihlašovací kód, který může uživatel jednorázově použít pro přihlášení bez přístupového klíče.",
"one_hour": "1 hodina", "one_hour": "1 hodina",
"twelve_hours": "12 hodin", "twelve_hours": "12 hodin",
"one_day": "1 den", "one_day": "1 den",
@@ -35,10 +35,10 @@
"expiration": "Expirace", "expiration": "Expirace",
"generate_code": "Vygenerovat kód", "generate_code": "Vygenerovat kód",
"name": "Jméno", "name": "Jméno",
"browser_unsupported": "Prohlížeče nepodporován", "browser_unsupported": "Prohlížeč nepodporován",
"this_browser_does_not_support_passkeys": "Tento prohlížeč nepodporuje přístupové klíče. Použijte prosím alternativní metodu přihlášení. přihlášení", "this_browser_does_not_support_passkeys": "Tento prohlížeč nepodporuje přístupové klíče. Zkuste prosím alternativní způsob přihlášení.",
"an_unknown_error_occurred": "Došlo k neznámé chybě", "an_unknown_error_occurred": "Došlo k neznámé chybě",
"authentication_process_was_aborted": "Proces přihlašování byl přerušen", "authentication_process_was_aborted": "Proces přihlášení byl přerušen",
"error_occurred_with_authenticator": "Došlo k chybě s autentifikátorem", "error_occurred_with_authenticator": "Došlo k chybě s autentifikátorem",
"authenticator_does_not_support_discoverable_credentials": "Autentifikátor nepodporuje zobrazitelné přihlašovací údaje", "authenticator_does_not_support_discoverable_credentials": "Autentifikátor nepodporuje zobrazitelné přihlašovací údaje",
"authenticator_does_not_support_resident_keys": "Autentikátor nepodporuje rezidentní klíče.", "authenticator_does_not_support_resident_keys": "Autentikátor nepodporuje rezidentní klíče.",
@@ -55,7 +55,7 @@
"profile": "Profil", "profile": "Profil",
"view_your_profile_information": "Zobrazit informace o Vašem profilu", "view_your_profile_information": "Zobrazit informace o Vašem profilu",
"groups": "Skupiny", "groups": "Skupiny",
"view_the_groups_you_are_a_member_of": "Zobrazit skupiny, které jste členem", "view_the_groups_you_are_a_member_of": "Zobrazit skupiny, jejichž jste členem",
"cancel": "Zrušit", "cancel": "Zrušit",
"sign_in": "Přihlásit se", "sign_in": "Přihlásit se",
"try_again": "Zkusit znovu", "try_again": "Zkusit znovu",
@@ -84,7 +84,7 @@
"submit": "Potvrdit", "submit": "Potvrdit",
"enter_the_code_you_received_to_sign_in": "Zadejte kód, který jste obdrželi k přihlášení.", "enter_the_code_you_received_to_sign_in": "Zadejte kód, který jste obdrželi k přihlášení.",
"code": "Kód", "code": "Kód",
"invalid_redirect_url": "Neplatná URL přesměrování", "invalid_redirect_url": "Neplatné URL přesměrování",
"audit_log": "Protokol auditu", "audit_log": "Protokol auditu",
"users": "Uživatelé", "users": "Uživatelé",
"user_groups": "Uživatelské skupiny", "user_groups": "Uživatelské skupiny",
@@ -156,7 +156,7 @@
"actions": "Akce", "actions": "Akce",
"images_updated_successfully": "Obrázky úspěšně aktualizovány", "images_updated_successfully": "Obrázky úspěšně aktualizovány",
"general": "Obecné", "general": "Obecné",
"enable_email_notifications_to_alert_users_when_a_login_is_detected_from_a_new_device_or_location": "Povolte e-mailová oznámení pro upozornění uživatelů, pokud je zjištěno přihlášení z nového zařízení nebo umístění.", "configure_smtp_to_send_emails": "Povolte e-mailová oznámení pro upozornění uživatelů, pokud je zjištěno přihlášení z nového zařízení nebo lokace.",
"ldap": "LDAP", "ldap": "LDAP",
"configure_ldap_settings_to_sync_users_and_groups_from_an_ldap_server": "Nastavte LDAP pro synchronizaci uživatelů a skupin z LDAP serveru.", "configure_ldap_settings_to_sync_users_and_groups_from_an_ldap_server": "Nastavte LDAP pro synchronizaci uživatelů a skupin z LDAP serveru.",
"images": "Obrázky", "images": "Obrázky",
@@ -180,7 +180,10 @@
"enabled_emails": "Povolené e-maily", "enabled_emails": "Povolené e-maily",
"email_login_notification": "E-mailovová oznámení o přihlášení", "email_login_notification": "E-mailovová oznámení o přihlášení",
"send_an_email_to_the_user_when_they_log_in_from_a_new_device": "Poslat uživateli e-mail, když se přihlásí z nového zařízení.", "send_an_email_to_the_user_when_they_log_in_from_a_new_device": "Poslat uživateli e-mail, když se přihlásí z nového zařízení.",
"allow_users_to_sign_in_with_a_login_code_sent_to_their_email": "Umožňuje uživatelům přihlásit se pomocí přihlašovacího kódu, který je odeslán na jejich e-mail. To výrazně snižuje bezpečnost, protože každý, kdo má přístup k e-mailu uživatele, může získat vstup.", "emai_login_code_requested_by_user": "Přihlašovací kód e-mailu vyžádaný uživatelem",
"allow_users_to_sign_in_with_a_login_code_sent_to_their_email": "Umožňuje uživatelům přihlásit se pomocí přihlašovacího kódu, který je odeslán na jejich e-mail. To výrazně snižuje bezpečnost, protože každý, kdo má přístup k e-mailu uživatele, může vstoupit.",
"email_login_code_from_admin": "Poslat e-mail přihlašovacímu kódu od administrátora",
"allows_an_admin_to_send_a_login_code_to_the_user": "Umožňuje administrátorovi odeslat přihlašovací kód uživateli e-mailem.",
"send_test_email": "Odeslat testovací e-mail", "send_test_email": "Odeslat testovací e-mail",
"application_configuration_updated_successfully": "Nastavení aplikace bylo úspěšně aktualizováno", "application_configuration_updated_successfully": "Nastavení aplikace bylo úspěšně aktualizováno",
"application_name": "Název aplikace", "application_name": "Název aplikace",
@@ -273,7 +276,7 @@
"callback_urls": "URL zpětného volání", "callback_urls": "URL zpětného volání",
"logout_callback_urls": "URL zpětného volání při odhlášení", "logout_callback_urls": "URL zpětného volání při odhlášení",
"public_client": "Veřejný klient", "public_client": "Veřejný klient",
"public_clients_do_not_have_a_client_secret_and_use_pkce_instead": "Veřejní klienti nemají client secret a místo toho používají PKCE. Povolte to, pokud je váš klient SPA nebo mobilní aplikace.", "public_clients_description": "Veřejní klienti nemají client secret a místo toho používají PKCE. Povolte to, pokud je váš klient SPA nebo mobilní aplikace.",
"pkce": "PKCE", "pkce": "PKCE",
"public_key_code_exchange_is_a_security_feature_to_prevent_csrf_and_authorization_code_interception_attacks": "Public Key Exchange je bezpečnostní funkce, která zabraňuje útokům CSRF a narušení autorizačních kódů.", "public_key_code_exchange_is_a_security_feature_to_prevent_csrf_and_authorization_code_interception_attacks": "Public Key Exchange je bezpečnostní funkce, která zabraňuje útokům CSRF a narušení autorizačních kódů.",
"name_logo": "Logo {name}", "name_logo": "Logo {name}",
@@ -312,5 +315,36 @@
"reset": "Obnovit", "reset": "Obnovit",
"reset_to_default": "Obnovit výchozí", "reset_to_default": "Obnovit výchozí",
"profile_picture_has_been_reset": "Profilový obrázek byl obnoven. Aktualizace může trvat několik minut.", "profile_picture_has_been_reset": "Profilový obrázek byl obnoven. Aktualizace může trvat několik minut.",
"select_the_language_you_want_to_use": "Vyberte jazyk, který chcete použít. Některé jazyky nemusí být plně přeloženy." "select_the_language_you_want_to_use": "Vyberte jazyk, který chcete použít. Některé jazyky nemusí být plně přeloženy.",
"personal": "Osobní",
"global": "Globální",
"all_users": "Všichni uživatelé",
"all_events": "Všechny události",
"all_clients": "Všichni klienti",
"global_audit_log": "Globální protokol auditu",
"see_all_account_activities_from_the_last_3_months": "Zobrazit veškerou aktivitu uživatele za poslední 3 měsíce.",
"token_sign_in": "Přihlášení tokenem",
"client_authorization": "Autorizace klienta",
"new_client_authorization": "Nová autorizace klienta",
"disable_animations": "Zakázat animace",
"turn_off_all_animations_throughout_the_admin_ui": "Vypnout všechny animace v celém administrátorském rozhraní.",
"user_disabled": "Účet deaktivován",
"disabled_users_cannot_log_in_or_use_services": "Zakázaní uživatelé se nemohou přihlásit nebo používat služby.",
"user_disabled_successfully": "Uživatel byl úspěšně deaktivován.",
"user_enabled_successfully": "Uživatel byl úspěšně aktivován.",
"status": "Status",
"disable_firstname_lastname": "Deaktivovat {firstName} {lastName}",
"are_you_sure_you_want_to_disable_this_user": "Opravdu chcete zakázat tohoto uživatele? Nebude se moci přihlásit nebo přistupovat k žádným službám.",
"ldap_soft_delete_users": "Ponechat zakázané uživatele z LDAP.",
"ldap_soft_delete_users_description": "Pokud je povoleno, uživatelé odebráni z LDAP budou deaktivování namísto odstranění ze systému.",
"login_code_email_success": "Přihlašovací kód byl odeslán uživateli.",
"send_email": "Odeslat e-mail",
"show_code": "Zobrazit kód",
"callback_url_description": "URL poskytnuté klientem. Klientské zástupné znaky (*) jsou podporovány, ale raději se jim vyhýbejte, pro lepší bezpečnost.",
"api_key_expiration": "Vypršení platnosti API klíče",
"send_an_email_to_the_user_when_their_api_key_is_about_to_expire": "Pošlete uživateli e-mail, jakmile jejich API klíč brzy vyprší.",
"authorize_device": "Authorize Device",
"the_device_has_been_authorized": "The device has been authorized.",
"enter_code_displayed_in_previous_step": "Enter the code that was displayed in the previous step.",
"authorize": "Authorize"
} }

Some files were not shown because too many files have changed in this diff Show More