Compare commits

..

107 Commits

Author SHA1 Message Date
Elias Schneider
453a765107 release: 0.49.0 2025-04-20 20:00:09 +02:00
Elias Schneider
f03645d545 chore(translations): update translations via Crowdin (#467) 2025-04-20 17:59:49 +00:00
Elias Schneider
55273d68c9 chore(translations): fix typo in key 2025-04-20 19:51:12 +02:00
Elias Schneider
4e05b82f02 fix: hide alternative sign in button if user is already authenticated 2025-04-20 19:03:58 +02:00
Elias Schneider
2597907578 refactor: fix type errors 2025-04-20 18:54:45 +02:00
Kyle Mendell
debef9a66b ci/cd: setup caching and improve ci job performance (#465) 2025-04-20 11:48:46 -05:00
Elias Schneider
9122e75101 feat: add ability to disable API key expiration email 2025-04-20 18:41:03 +02:00
Elias Schneider
fe1c4b18cd feat: add ability to send login code via email (#457)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-20 18:32:40 +02:00
Elias Schneider
e571996cb5 fix: disable animations not respected on authorize and logout page 2025-04-20 17:04:00 +02:00
Elias Schneider
fb862d3ec3 chore(translations): update translations via Crowdin (#459)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-20 09:43:27 -05:00
Kyle Mendell
26f01f205b feat: send email to user when api key expires within 7 days (#451)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-20 14:40:20 +00:00
Elias Schneider
c37a3e0ed1 fix: remove limit of 20 callback URLs 2025-04-20 16:32:11 +02:00
Elias Schneider
eb689eb56e feat: add description to callback URL inputs 2025-04-20 00:32:27 +02:00
Elias Schneider
60bad9e985 fix: locale change in dropdown doesn't work on first try 2025-04-20 00:31:33 +02:00
Elias Schneider
e21ee8a871 chore: add kmendell to FUNDING.yml 2025-04-19 18:51:01 +02:00
Elias Schneider
04006eb5cc release: 0.48.0 2025-04-18 18:34:52 +02:00
Elias Schneider
84f1d5c906 fix: user querying fails on global audit log page with Postgres 2025-04-18 18:33:14 +02:00
Elias Schneider
983e989be1 chore(translations): update translations via Crowdin (#456) 2025-04-18 18:21:04 +02:00
Kyle Mendell
c843a60131 feat: disable/enable users (#437)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-18 15:38:50 +00:00
Elias Schneider
56a8b5d0c0 feat: add gif support for logo and background image 2025-04-18 17:31:04 +02:00
Elias Schneider
f0dce41fbc fix: callback URL doesn't get rejected if it starts with a different string 2025-04-17 20:52:58 +02:00
Elias Schneider
0111a58dac fix: add "type" as reserved claim 2025-04-17 20:41:21 +02:00
Elias Schneider
50e4c5c314 chore(translations): update translations via Crowdin (#444) 2025-04-17 20:19:50 +02:00
Kyle Mendell
5a6dfd9e50 fix: profile picture empty for users without first or last name (#449) 2025-04-17 20:19:10 +02:00
Elias Schneider
75fbfee4d8 chore(translations): add Italian 2025-04-17 19:13:47 +02:00
dependabot[bot]
65ee500ef3 chore(deps): bump golang.org/x/net from 0.36.0 to 0.38.0 in /backend in the go_modules group across 1 directory (#450)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-16 18:26:55 -05:00
Elias Schneider
80f108e5d6 release: 0.47.0 2025-04-16 16:32:27 +02:00
Elias Schneider
9b2d622990 tests: adapt JWTs in e2e tests 2025-04-16 16:30:38 +02:00
Elias Schneider
adf74586af fix: define token type as claim for better client compatibility 2025-04-16 15:58:38 +02:00
Kyle Mendell
b45cf68295 feat: disable animations setting toggle (#442)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-15 19:28:10 +00:00
dependabot[bot]
d9dd67c51f chore(deps-dev): bump @sveltejs/kit from 2.16.1 to 2.20.6 in /frontend in the npm_and_yarn group across 1 directory (#443)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-15 20:38:03 +02:00
Grégory Paul
abf17f6211 feat: add qrcode representation of one time link (#424) (#436)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Kyle Mendell <kmendell@outlook.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-14 13:16:46 +00:00
Elias Schneider
57cb8f8795 release: 0.46.0 2025-04-13 20:31:09 +02:00
Elias Schneider
fcb18b8c3c chore(translations): update translations via Crowdin (#427)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-13 20:30:43 +02:00
Alessandro (Ale) Segala
796bc7ed34 fix: improve LDAP error handling (#425)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-12 18:38:19 -04:00
Arne Skaar Fismen
72061ba427 feat(onboarding): Added button when you don't have a passkey added. (#426)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-12 02:27:01 +00:00
dependabot[bot]
d04167cada chore(deps-dev): bump vite from 6.2.5 to 6.2.6 in /frontend in the npm_and_yarn group across 1 directory (#433) 2025-04-11 20:07:40 -05:00
Alessandro (Ale) Segala
f83bab9e17 refactor: simplify app_config service and fix race conditions (#423) 2025-04-10 13:41:22 +02:00
Elias Schneider
4ba68938dd fix: ignore profile picture cache after profile picture gets updated 2025-04-09 15:51:58 +02:00
Elias Schneider
658a9ca6dd fix: add missing rollback for LDAP sync 2025-04-09 14:05:53 +02:00
Andreas Schneider
7e5d16be9b feat: implement token introspection (#405)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-09 07:18:03 +00:00
Elias Schneider
8d6c1e5c08 chore(translations): update translations via Crowdin (#420) 2025-04-09 02:09:01 -05:00
Elias Schneider
ce6e27d0ff refactor: rollback db changes with defer everywhere 2025-04-06 23:40:56 +02:00
Elias Schneider
3ebff09d63 chore(translations): update translations via Crowdin (#416)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-06 22:15:05 +02:00
Elias Schneider
ccc18d716f fix: use UUID for temporary file names 2025-04-06 15:11:19 +02:00
Alessandro (Ale) Segala
ec626ee797 fix: use transactions when operations involve multiple database queries (#392)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-06 15:04:08 +02:00
Kyle Mendell
c810fec8c4 docs: update swagger description to use markdown (#418) 2025-04-05 16:07:56 +02:00
Alessandro (Ale) Segala
9e88926283 fix: ensure indexes on audit_logs table (#415)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-04-04 17:05:32 +00:00
dependabot[bot]
731113183e chore(deps-dev): bump vite from 6.2.4 to 6.2.5 in /frontend in the npm_and_yarn group across 1 directory (#417)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-04 16:15:37 +00:00
Elias Schneider
4627f365a2 chore(translations): fix mistakes in source strings 2025-04-04 13:55:15 +02:00
Elias Schneider
1762629596 perf: run async operations in parallel in server load functions 2025-04-04 11:39:13 +02:00
Alessandro (Ale) Segala
2f7646105e fix: ensure file descriptors are closed + other bugs (#413) 2025-04-04 10:04:36 +02:00
Elias Schneider
980780e48b chore(translations): update translations via Crowdin (#414) 2025-04-04 09:06:44 +02:00
Kyle Mendell
b65e693e12 feat: global audit log (#320)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-03 10:11:49 -05:00
Kyle Mendell
734c6813ea fix: create reusable default profile pictures (#406)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-04-03 08:06:56 -05:00
dependabot[bot]
0d31c0ec6c chore(deps-dev): bump vite from 6.2.3 to 6.2.4 in /frontend in the npm_and_yarn group across 1 directory (#410)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-31 14:04:02 -05:00
jose_d
4806c1e09b chore(translations): improve czech translation strings (#408) 2025-03-31 08:22:06 -05:00
Elias Schneider
cf3084cfa8 refactor: remove cors exception from middleware as this is handled by the handler 2025-03-30 22:30:22 +02:00
Kyle Mendell
9881a1df9e feat: modernize ui (#381)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-30 13:19:14 -05:00
Elias Schneider
5dcf69e974 release: 0.45.0 2025-03-30 00:12:19 +01:00
Alessandro (Ale) Segala
519d58d88c fix: use WAL for SQLite by default and set busy_timeout (#388)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 23:12:48 +01:00
Alessandro (Ale) Segala
b3b43a56af refactor: do not include test controller in production builds (#402)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-29 22:11:25 +00:00
Elias Schneider
fc68cf7eb2 chore(translations): add Brazilian Portuguese 2025-03-29 23:03:18 +01:00
Elias Schneider
8ca7873802 chore(translations): update translations via Crowdin (#394)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 22:59:24 +01:00
Elias Schneider
591bf841f5 Merge remote-tracking branch 'origin/main' 2025-03-29 22:56:04 +01:00
Kyle Mendell
8f8884d208 refactor: add swagger title and version info (#399) 2025-03-29 21:55:47 +00:00
Elias Schneider
7e658276f0 fix: ldap users aren't deleted if removed from ldap server 2025-03-29 22:55:44 +01:00
Gutyina Gergő
583a1f8fee chore(deps): install inlang plugins from npm (#401)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 22:50:51 +01:00
Rich
b935a4824a ci/cd: migrate backend linter to v2. fixed unit test workflow (#400) 2025-03-28 04:00:55 -05:00
Elias Schneider
cbd1bbdf74 fix: use value receiver for AuditLogData 2025-03-27 22:41:19 +01:00
Alessandro (Ale) Segala
96876a99c5 feat: add support for ECDSA and EdDSA keys (#359)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-27 18:20:39 +01:00
Elias Schneider
5c198c280c refactor: fix code smells 2025-03-27 17:46:10 +01:00
Elias Schneider
c9e0073b63 refactor: fix code smells 2025-03-27 16:48:36 +01:00
Elias Schneider
6fa26c97be ci/cd: run linter only on backend changes 2025-03-27 16:18:15 +01:00
Elias Schneider
6746dbf41e chore(translations): update translations via Crowdin (#386) 2025-03-27 15:15:22 +00:00
Rich
4ac1196d8d ci/cd: add basic static analysis for backend (#389) 2025-03-27 16:13:56 +01:00
Sam
4d049bbe24 docs: update .env.example to reflect the new documentation location (#385) 2025-03-25 21:53:23 +00:00
Elias Schneider
664a1cf8ef release: 0.44.0 2025-03-25 17:09:06 +01:00
Elias Schneider
e6f50191cf fix: stop container if Caddy, the frontend or the backend fails 2025-03-25 16:40:53 +01:00
dependabot[bot]
de9a3cce03 chore(deps-dev): bump vite from 6.2.1 to 6.2.3 in /frontend in the npm_and_yarn group across 1 directory (#384)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-25 09:52:15 -05:00
Alessandro (Ale) Segala
8c963818bb fix: hash the refresh token in the DB (security) (#379) 2025-03-25 15:36:53 +01:00
Alessandro (Ale) Segala
26b2de4f00 refactor: use atomic renames for uploaded files (#372)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 20:21:44 +00:00
Kyle Mendell
b8dcda8049 feat: add OIDC refresh_token support (#325)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 20:14:26 +00:00
Kyle Mendell
7888d70656 docs: fix api routers for swag documentation (#378) 2025-03-23 19:26:07 +00:00
Elias Schneider
35766af055 chore(translations): add French, Czech and German to language picker 2025-03-23 20:13:58 +01:00
Elias Schneider
c53de25d25 chore(translations): update translations via Crowdin (#375) 2025-03-23 19:09:34 +00:00
Kyle Mendell
cdfe8161d4 fix: skip ldap objects without a valid unique id (#376)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 18:30:12 +00:00
dependabot[bot]
e2f74e5687 chore(deps): bump github.com/golang-jwt/jwt/v5 from 5.2.1 to 5.2.2 in /backend in the go_modules group across 1 directory (#374)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-21 17:45:52 -05:00
Elias Schneider
132efd675c chore(translations): update translations via Crowdin (#368) 2025-03-21 21:32:28 +00:00
Elias Schneider
1167454c4f Merge branch 'main' of https://github.com/pocket-id/pocket-id 2025-03-21 22:30:40 +01:00
Elias Schneider
af5b2f7913 ci/cd: skip e2e tests if the PR comes from i18n_crowdin 2025-03-21 22:30:37 +01:00
Savely Krasovsky
bc4af846e1 chore(translations): add Russian localization (#371)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-21 21:24:55 +00:00
Elias Schneider
edf1097dd3 ci/cd: fix invalid action configuration 2025-03-21 22:20:05 +01:00
Elias Schneider
eb34535c5a release: 0.43.1 2025-03-20 21:38:02 +01:00
Elias Schneider
3120ebf239 fix: wrong base locale causes crash 2025-03-20 21:36:05 +01:00
Elias Schneider
2fb41937ca ci/cd: ignore e2e tests on Crowdin branch 2025-03-20 20:49:17 +01:00
Elias Schneider
d78a1c6974 release: 0.43.0 2025-03-20 20:47:17 +01:00
Elias Schneider
c578baba95 chore: add language request issue template 2025-03-20 20:38:33 +01:00
Elias Schneider
bb23194e88 chore(translations): remove unused messages 2025-03-20 20:26:43 +01:00
Elias Schneider
31ac56004a refactor: use language code with country for messages 2025-03-20 20:15:26 +01:00
Elias Schneider
d59ec01b33 Update Crowdin configuration file 2025-03-20 20:12:48 +01:00
Elias Schneider
3ee26a2cfb chore: update Crowdin configuration 2025-03-20 20:09:05 +01:00
Elias Schneider
39395c79c3 Update Crowdin configuration file 2025-03-20 20:08:24 +01:00
Jonas Claes
269b5a3c92 feat: add support for translations (#349)
Co-authored-by: Kyle Mendell <kmendell@outlook.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-20 18:57:41 +00:00
Kyle Mendell
041c565dc1 feat(passkeys): name new passkeys based on agguids (#332)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-20 15:35:08 +00:00
Elias Schneider
e486dbd771 release: 0.42.1 2025-03-18 23:03:50 +01:00
Elias Schneider
f7e36a422e fix: kid not added to JWTs 2025-03-18 23:03:34 +01:00
247 changed files with 17691 additions and 7404 deletions

View File

@@ -1,4 +1,4 @@
# See the README for more information: https://github.com/pocket-id/pocket-id?tab=readme-ov-file#environment-variables # See the documentation for more information: https://pocket-id.org/docs/configuration/environment-variables
PUBLIC_APP_URL=http://localhost PUBLIC_APP_URL=http://localhost
TRUST_PROXY=false TRUST_PROXY=false
MAXMIND_LICENSE_KEY= MAXMIND_LICENSE_KEY=

2
.github/FUNDING.yml vendored
View File

@@ -1,2 +1,2 @@
# These are supported funding model platforms # These are supported funding model platforms
github: stonith404 github: [stonith404, kmendell]

View File

@@ -0,0 +1,20 @@
name: "🌐 Language request"
description: "You want to contribute to a language that isn't on Crowdin yet?"
title: "🌐 Language Request: <language name in english>"
labels: [language-request]
body:
- type: input
id: language-name-native
attributes:
label: "🌐 Language Name (native)"
placeholder: "Schweizerdeutsch"
validations:
required: true
- type: input
id: language-code
attributes:
label: "🌐 ISO 639-1 Language Code"
description: "You can find your language code [here](https://www.andiamo.co.uk/resources/iso-language-codes/)."
placeholder: "de-CH"
validations:
required: true

39
.github/workflows/backend-linter.yml vendored Normal file
View File

@@ -0,0 +1,39 @@
name: Run Backend Linter
on:
push:
branches: [main]
paths:
- "backend/**"
pull_request:
branches: [main]
paths:
- "backend/**"
permissions:
# Required: allow read access to the content for analysis.
contents: read
# Optional: allow read access to pull request. Use with `only-new-issues` option.
pull-requests: read
# Optional: allow write access to checks to allow the action to annotate code in the PR.
checks: write
jobs:
golangci-lint:
name: Run Golangci-lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: backend/go.mod
- name: Run Golangci-lint
uses: golangci/golangci-lint-action@dec74fa03096ff515422f71d18d41307cacde373 # v7.0.0
with:
version: v2.0.2
working-directory: backend
only-new-issues: ${{ github.event_name == 'pull_request' }}

View File

@@ -15,25 +15,35 @@ on:
jobs: jobs:
build: build:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
timeout-minutes: 20 timeout-minutes: 20
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Build and export - name: Build and export
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
tags: pocket-id/pocket-id:test push: false
load: false
tags: pocket-id:test
outputs: type=docker,dest=/tmp/docker-image.tar outputs: type=docker,dest=/tmp/docker-image.tar
build-args: BUILD_TAGS=e2etest
cache-from: type=gha
cache-to: type=gha,mode=max
- name: Upload Docker image artifact - name: Upload Docker image artifact
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp/docker-image.tar path: /tmp/docker-image.tar
retention-days: 1
test-sqlite: test-sqlite:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: build needs: build
steps: steps:
@@ -44,12 +54,22 @@ jobs:
cache: "npm" cache: "npm"
cache-dependency-path: frontend/package-lock.json cache-dependency-path: frontend/package-lock.json
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Download Docker image artifact - name: Download Docker image artifact
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp path: /tmp
- name: Load Docker Image
- name: Load Docker image
run: docker load -i /tmp/docker-image.tar run: docker load -i /tmp/docker-image.tar
- name: Install frontend dependencies - name: Install frontend dependencies
@@ -58,6 +78,7 @@ jobs:
- name: Install Playwright Browsers - name: Install Playwright Browsers
working-directory: ./frontend working-directory: ./frontend
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium run: npx playwright install --with-deps chromium
- name: Run Docker Container with Sqlite DB - name: Run Docker Container with Sqlite DB
@@ -65,21 +86,34 @@ jobs:
docker run -d --name pocket-id-sqlite \ docker run -d --name pocket-id-sqlite \
-p 80:80 \ -p 80:80 \
-e APP_ENV=test \ -e APP_ENV=test \
pocket-id/pocket-id:test pocket-id:test
docker logs -f pocket-id-sqlite &> /tmp/backend.log &
- name: Run Playwright tests - name: Run Playwright tests
working-directory: ./frontend working-directory: ./frontend
run: npx playwright test run: npx playwright test
- uses: actions/upload-artifact@v4 - name: Upload Frontend Test Report
if: always() uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: playwright-report-sqlite name: playwright-report-sqlite
path: frontend/tests/.report path: frontend/tests/.report
include-hidden-files: true include-hidden-files: true
retention-days: 15 retention-days: 15
- name: Upload Backend Test Report
uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with:
name: backend-sqlite
path: /tmp/backend.log
include-hidden-files: true
retention-days: 15
test-postgres: test-postgres:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: build needs: build
steps: steps:
@@ -90,12 +124,39 @@ jobs:
cache: "npm" cache: "npm"
cache-dependency-path: frontend/package-lock.json cache-dependency-path: frontend/package-lock.json
- name: Cache Playwright Browsers
uses: actions/cache@v3
id: playwright-cache
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-${{ hashFiles('frontend/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-
- name: Cache PostgreSQL Docker image
uses: actions/cache@v3
id: postgres-cache
with:
path: /tmp/postgres-image.tar
key: postgres-17-${{ runner.os }}
- name: Pull and save PostgreSQL image
if: steps.postgres-cache.outputs.cache-hit != 'true'
run: |
docker pull postgres:17
docker save postgres:17 > /tmp/postgres-image.tar
- name: Load PostgreSQL image from cache
if: steps.postgres-cache.outputs.cache-hit == 'true'
run: docker load < /tmp/postgres-image.tar
- name: Download Docker image artifact - name: Download Docker image artifact
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
name: docker-image name: docker-image
path: /tmp path: /tmp
- name: Load Docker Image
- name: Load Docker image
run: docker load -i /tmp/docker-image.tar run: docker load -i /tmp/docker-image.tar
- name: Install frontend dependencies - name: Install frontend dependencies
@@ -104,6 +165,7 @@ jobs:
- name: Install Playwright Browsers - name: Install Playwright Browsers
working-directory: ./frontend working-directory: ./frontend
if: steps.playwright-cache.outputs.cache-hit != 'true'
run: npx playwright install --with-deps chromium run: npx playwright install --with-deps chromium
- name: Create Docker network - name: Create Docker network
@@ -137,17 +199,29 @@ jobs:
-p 80:80 \ -p 80:80 \
-e APP_ENV=test \ -e APP_ENV=test \
-e DB_PROVIDER=postgres \ -e DB_PROVIDER=postgres \
-e POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \ -e DB_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
pocket-id/pocket-id:test pocket-id:test
docker logs -f pocket-id-postgres &> /tmp/backend.log &
- name: Run Playwright tests - name: Run Playwright tests
working-directory: ./frontend working-directory: ./frontend
run: npx playwright test run: npx playwright test
- uses: actions/upload-artifact@v4 - name: Upload Frontend Test Report
if: always() uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with: with:
name: playwright-report-postgres name: playwright-report-postgres
path: frontend/tests/.report path: frontend/tests/.report
include-hidden-files: true include-hidden-files: true
retention-days: 15 retention-days: 15
- name: Upload Backend Test Report
uses: actions/upload-artifact@v4
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with:
name: backend-postgres
path: /tmp/backend.log
include-hidden-files: true
retention-days: 15

View File

@@ -25,6 +25,7 @@ jobs:
- name: Run backend unit tests - name: Run backend unit tests
working-directory: backend working-directory: backend
run: | run: |
set -e -o pipefail
go test -v ./... | tee /tmp/TestResults.log go test -v ./... | tee /tmp/TestResults.log
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
if: always() if: always()

34
.github/workflows/update-aaguids.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Update AAGUIDs
on:
schedule:
- cron: "0 0 * * 1" # Runs every Monday at midnight
workflow_dispatch: # Allows manual triggering of the workflow
permissions:
contents: write
jobs:
update-aaguids:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Fetch JSON data
run: |
curl -o data.json https://raw.githubusercontent.com/pocket-id/passkey-aaguids/refs/heads/main/combined_aaguid.json
- name: Process JSON data
run: |
mkdir -p backend/resources
jq -c 'map_values(.name)' data.json > backend/resources/aaguids.json
- name: Commit changes
run: |
git config --local user.email "action@github.com"
git config --local user.name "GitHub Action"
git add backend/resources/aaguids.json
git diff --staged --quiet || git commit -m "chore: update AAGUIDs"
git push

View File

@@ -1 +1 @@
0.42.0 0.49.0

5
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"recommendations": [
"inlang.vs-code-extension"
]
}

4
.vscode/launch.json vendored
View File

@@ -5,7 +5,7 @@
"name": "Backend", "name": "Backend",
"type": "go", "type": "go",
"request": "launch", "request": "launch",
"envFile": "${workspaceFolder}/backend/.env.example", "envFile": "${workspaceFolder}/backend/cmd/.env",
"env": { "env": {
"APP_ENV": "development" "APP_ENV": "development"
}, },
@@ -16,7 +16,7 @@
"name": "Frontend", "name": "Frontend",
"type": "node", "type": "node",
"request": "launch", "request": "launch",
"envFile": "${workspaceFolder}/frontend/.env.example", "envFile": "${workspaceFolder}/frontend/.env",
"cwd": "${workspaceFolder}/frontend", "cwd": "${workspaceFolder}/frontend",
"runtimeExecutable": "npm", "runtimeExecutable": "npm",
"runtimeArgs": [ "runtimeArgs": [

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"go.buildTags": "e2etest"
}

View File

@@ -1,3 +1,127 @@
## [](https://github.com/pocket-id/pocket-id/compare/v0.48.0...v) (2025-04-20)
### Features
* add ability to disable API key expiration email ([9122e75](https://github.com/pocket-id/pocket-id/commit/9122e75101ad39a40135ccf931eb2bfd351b5db6))
* add ability to send login code via email ([#457](https://github.com/pocket-id/pocket-id/issues/457)) ([fe1c4b1](https://github.com/pocket-id/pocket-id/commit/fe1c4b18cdcc46a4256e0c111b34f1ce00f8e0e1))
* add description to callback URL inputs ([eb689eb](https://github.com/pocket-id/pocket-id/commit/eb689eb56ec9eaf8b0fb1485040e26f841b9225d))
* send email to user when api key expires within 7 days ([#451](https://github.com/pocket-id/pocket-id/issues/451)) ([26f01f2](https://github.com/pocket-id/pocket-id/commit/26f01f205be01fb8abd8c2e564c90c0fc4480ea5))
### Bug Fixes
* disable animations not respected on authorize and logout page ([e571996](https://github.com/pocket-id/pocket-id/commit/e571996cb57d04232c1f47ab337ad656f48bb3cb))
* hide alternative sign in button if user is already authenticated ([4e05b82](https://github.com/pocket-id/pocket-id/commit/4e05b82f02740a4bae07cec6c6a64acd34ca0fc3))
* locale change in dropdown doesn't work on first try ([60bad9e](https://github.com/pocket-id/pocket-id/commit/60bad9e9859d81c9967e6939e1ed10a65145a936))
* remove limit of 20 callback URLs ([c37a3e0](https://github.com/pocket-id/pocket-id/commit/c37a3e0ed177c3bd2b9a618d1f4b0709004478b0))
## [](https://github.com/pocket-id/pocket-id/compare/v0.47.0...v) (2025-04-18)
### Features
* add gif support for logo and background image ([56a8b5d](https://github.com/pocket-id/pocket-id/commit/56a8b5d0c02643f869b77cf8475ddf2f9473880b))
* disable/enable users ([#437](https://github.com/pocket-id/pocket-id/issues/437)) ([c843a60](https://github.com/pocket-id/pocket-id/commit/c843a60131b813177b1e270c4f5d97613c700efa))
### Bug Fixes
* add "type" as reserved claim ([0111a58](https://github.com/pocket-id/pocket-id/commit/0111a58dac0342c5ac2fa25a050e8773810d2b0a))
* callback URL doesn't get rejected if it starts with a different string ([f0dce41](https://github.com/pocket-id/pocket-id/commit/f0dce41fbc5649b3a8fe65de36ca20efa521b880))
* profile picture empty for users without first or last name ([#449](https://github.com/pocket-id/pocket-id/issues/449)) ([5a6dfd9](https://github.com/pocket-id/pocket-id/commit/5a6dfd9e505f4c84e91b4b378b082fab10e8a8a8))
* user querying fails on global audit log page with Postgres ([84f1d5c](https://github.com/pocket-id/pocket-id/commit/84f1d5c906ec3f9a74ad3d2f36526eea847af5dd))
## [](https://github.com/pocket-id/pocket-id/compare/v0.46.0...v) (2025-04-16)
### Features
* add qrcode representation of one time link ([#424](https://github.com/pocket-id/pocket-id/issues/424)) ([#436](https://github.com/pocket-id/pocket-id/issues/436)) ([abf17f6](https://github.com/pocket-id/pocket-id/commit/abf17f62114a2de549b62cec462b9b0659ee23a7))
* disable animations setting toggle ([#442](https://github.com/pocket-id/pocket-id/issues/442)) ([b45cf68](https://github.com/pocket-id/pocket-id/commit/b45cf68295975f51777dab95950b98b8db0a9ae5))
### Bug Fixes
* define token type as claim for better client compatibility ([adf7458](https://github.com/pocket-id/pocket-id/commit/adf74586afb6ef9a00fb122c150b0248c5bc23f0))
## [](https://github.com/pocket-id/pocket-id/compare/v0.45.0...v) (2025-04-13)
### Features
* global audit log ([#320](https://github.com/pocket-id/pocket-id/issues/320)) ([b65e693](https://github.com/pocket-id/pocket-id/commit/b65e693e12be2e7e4cb75a74d6fd43bacb3f6a94))
* implement token introspection ([#405](https://github.com/pocket-id/pocket-id/issues/405)) ([7e5d16b](https://github.com/pocket-id/pocket-id/commit/7e5d16be9bdfccfa113924547e313886681d11bb))
* modernize ui ([#381](https://github.com/pocket-id/pocket-id/issues/381)) ([9881a1d](https://github.com/pocket-id/pocket-id/commit/9881a1df9efe32608ab116db71c0e4f66dae171c))
* **onboarding:** Added button when you don't have a passkey added. ([#426](https://github.com/pocket-id/pocket-id/issues/426)) ([72061ba](https://github.com/pocket-id/pocket-id/commit/72061ba4278a007437cee3a205c3076d58bde644))
### Bug Fixes
* add missing rollback for LDAP sync ([658a9ca](https://github.com/pocket-id/pocket-id/commit/658a9ca6dd8d2304ff3639a000bab02e91ff68a6))
* create reusable default profile pictures ([#406](https://github.com/pocket-id/pocket-id/issues/406)) ([734c681](https://github.com/pocket-id/pocket-id/commit/734c6813eaef166235ae801747e3652d17ae0e2a))
* ensure file descriptors are closed + other bugs ([#413](https://github.com/pocket-id/pocket-id/issues/413)) ([2f76461](https://github.com/pocket-id/pocket-id/commit/2f7646105e26423f47cbe49dae97e40c4a01a025))
* ensure indexes on audit_logs table ([#415](https://github.com/pocket-id/pocket-id/issues/415)) ([9e88926](https://github.com/pocket-id/pocket-id/commit/9e88926283a7a663bfc7fd4f4aa16bd02f614176))
* ignore profile picture cache after profile picture gets updated ([4ba6893](https://github.com/pocket-id/pocket-id/commit/4ba68938dd2a631c633fcb65d8c35cb039d3f59c))
* improve LDAP error handling ([#425](https://github.com/pocket-id/pocket-id/issues/425)) ([796bc7e](https://github.com/pocket-id/pocket-id/commit/796bc7ed3453839b1dc8d846b71fe9fac9a2d646))
* use transactions when operations involve multiple database queries ([#392](https://github.com/pocket-id/pocket-id/issues/392)) ([ec626ee](https://github.com/pocket-id/pocket-id/commit/ec626ee7977306539fd1d70cc9091590f0a54af6))
* use UUID for temporary file names ([ccc18d7](https://github.com/pocket-id/pocket-id/commit/ccc18d716f16a7ef1775d30982e2ba7b5ff159a6))
### Performance Improvements
* run async operations in parallel in server load functions ([1762629](https://github.com/pocket-id/pocket-id/commit/17626295964244c5582806bd0f413da2c799d5ad))
## [](https://github.com/pocket-id/pocket-id/compare/v0.44.0...v) (2025-03-29)
### Features
* add support for ECDSA and EdDSA keys ([#359](https://github.com/pocket-id/pocket-id/issues/359)) ([96876a9](https://github.com/pocket-id/pocket-id/commit/96876a99c586508b72c27669ab200ff6a29db771))
### Bug Fixes
* ldap users aren't deleted if removed from ldap server ([7e65827](https://github.com/pocket-id/pocket-id/commit/7e658276f04d08a1f5117796e55d45e310204dab))
* use value receiver for `AuditLogData` ([cbd1bbd](https://github.com/pocket-id/pocket-id/commit/cbd1bbdf741eedd03e93598d67623c75c74b6212))
* use WAL for SQLite by default and set busy_timeout ([#388](https://github.com/pocket-id/pocket-id/issues/388)) ([519d58d](https://github.com/pocket-id/pocket-id/commit/519d58d88c906abc5139e35933bdeba0396c10a2))
## [](https://github.com/pocket-id/pocket-id/compare/v0.43.1...v) (2025-03-25)
### Features
* add OIDC refresh_token support ([#325](https://github.com/pocket-id/pocket-id/issues/325)) ([b8dcda8](https://github.com/pocket-id/pocket-id/commit/b8dcda80497e554d163a370eff81fe000f8831f4))
### Bug Fixes
* hash the refresh token in the DB (security) ([#379](https://github.com/pocket-id/pocket-id/issues/379)) ([8c96381](https://github.com/pocket-id/pocket-id/commit/8c963818bb90c84dac04018eec93790900d4b0ce))
* skip ldap objects without a valid unique id ([#376](https://github.com/pocket-id/pocket-id/issues/376)) ([cdfe816](https://github.com/pocket-id/pocket-id/commit/cdfe8161d4429bdfe879887fe0b563a67c14f50b))
* stop container if Caddy, the frontend or the backend fails ([e6f5019](https://github.com/pocket-id/pocket-id/commit/e6f50191cf05a5d0ac0e0000cf66423646f1920e))
## [](https://github.com/pocket-id/pocket-id/compare/v0.43.0...v) (2025-03-20)
### Bug Fixes
* wrong base locale causes crash ([3120ebf](https://github.com/pocket-id/pocket-id/commit/3120ebf239b90f0bc0a0af33f30622e034782398))
## [](https://github.com/pocket-id/pocket-id/compare/v0.42.1...v) (2025-03-20)
### Features
* add support for translations ([#349](https://github.com/pocket-id/pocket-id/issues/349)) ([269b5a3](https://github.com/pocket-id/pocket-id/commit/269b5a3c9249bb8081c74741141d3d5a69ea42a2))
* **passkeys:** name new passkeys based on agguids ([#332](https://github.com/pocket-id/pocket-id/issues/332)) ([041c565](https://github.com/pocket-id/pocket-id/commit/041c565dc10f15edb3e8ab58e9a4df5e48a2a6d3))
## [](https://github.com/pocket-id/pocket-id/compare/v0.42.0...v) (2025-03-18)
### Bug Fixes
* kid not added to JWTs ([f7e36a4](https://github.com/pocket-id/pocket-id/commit/f7e36a422ea6b5327360c9a13308ae408ff7fffe))
## [](https://github.com/pocket-id/pocket-id/compare/v0.41.0...v) (2025-03-18) ## [](https://github.com/pocket-id/pocket-id/compare/v0.41.0...v) (2025-03-18)

View File

@@ -49,7 +49,7 @@ The backend is built with [Gin](https://gin-gonic.com) and written in Go.
1. Open the `backend` folder 1. Open the `backend` folder
2. Copy the `.env.example` file to `.env` and change the `APP_ENV` to `development` 2. Copy the `.env.example` file to `.env` and change the `APP_ENV` to `development`
3. Start the backend with `go run cmd/main.go` 3. Start the backend with `go run -tags e2etest ./cmd`
### Frontend ### Frontend

View File

@@ -1,3 +1,6 @@
# Tags passed to "go build"
ARG BUILD_TAGS=""
# Stage 1: Build Frontend # Stage 1: Build Frontend
FROM node:22-alpine AS frontend-builder FROM node:22-alpine AS frontend-builder
WORKDIR /app/frontend WORKDIR /app/frontend
@@ -8,7 +11,8 @@ RUN npm run build
RUN npm prune --production RUN npm prune --production
# Stage 2: Build Backend # Stage 2: Build Backend
FROM golang:1.23-alpine AS backend-builder FROM golang:1.24-alpine AS backend-builder
ARG BUILD_TAGS
WORKDIR /app/backend WORKDIR /app/backend
COPY ./backend/go.mod ./backend/go.sum ./ COPY ./backend/go.mod ./backend/go.sum ./
RUN go mod download RUN go mod download
@@ -17,7 +21,12 @@ RUN apk add --no-cache gcc musl-dev
COPY ./backend ./ COPY ./backend ./
WORKDIR /app/backend/cmd WORKDIR /app/backend/cmd
RUN CGO_ENABLED=1 GOOS=linux go build -o /app/backend/pocket-id-backend . RUN CGO_ENABLED=1 \
GOOS=linux \
go build \
-tags "${BUILD_TAGS}" \
-o /app/backend/pocket-id-backend \
.
# Stage 3: Production Image # Stage 3: Production Image
FROM node:22-alpine FROM node:22-alpine

64
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,64 @@
version: "2"
run:
tests: true
timeout: 5m
linters:
default: none
enable:
- asasalint
- asciicheck
- bidichk
- bodyclose
- contextcheck
- copyloopvar
- durationcheck
- errcheck
- errchkjson
- errorlint
- exhaustive
- gocheckcompilerdirectives
- gochecksumtype
- gocognit
- gocritic
- gosec
- gosmopolitan
- govet
- ineffassign
- loggercheck
- makezero
- musttag
- nilerr
- nilnesserr
- noctx
- protogetter
- reassign
- recvcheck
- rowserrcheck
- spancheck
- sqlclosecheck
- staticcheck
- testifylint
- unused
- usestdlibvars
- zerologlint
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- third_party$
- builtin$
- examples$
- internal/service/test_service.go
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -4,6 +4,10 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/bootstrap" "github.com/pocket-id/pocket-id/backend/internal/bootstrap"
) )
// @title Pocket ID API
// @version 1.0
// @description.markdown
func main() { func main() {
bootstrap.Bootstrap() bootstrap.Bootstrap()
} }

View File

@@ -1,6 +1,6 @@
module github.com/pocket-id/pocket-id/backend module github.com/pocket-id/pocket-id/backend
go 1.23.1 go 1.24.0
require ( require (
github.com/caarlos0/env/v11 v11.3.1 github.com/caarlos0/env/v11 v11.3.1
@@ -14,11 +14,10 @@ require (
github.com/go-ldap/ldap/v3 v3.4.10 github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-playground/validator/v10 v10.24.0 github.com/go-playground/validator/v10 v10.24.0
github.com/go-webauthn/webauthn v0.11.2 github.com/go-webauthn/webauthn v0.11.2
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang-migrate/migrate/v4 v4.18.2 github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3 github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
github.com/mileusna/useragent v1.3.5 github.com/mileusna/useragent v1.3.5
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2 github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
@@ -45,6 +44,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-webauthn/x v0.1.16 // indirect github.com/go-webauthn/x v0.1.16 // indirect
github.com/goccy/go-json v0.10.4 // indirect github.com/goccy/go-json v0.10.4 // indirect
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
github.com/google/go-tpm v0.9.3 // indirect github.com/google/go-tpm v0.9.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect
@@ -79,7 +79,7 @@ require (
go.uber.org/atomic v1.11.0 // indirect go.uber.org/atomic v1.11.0 // indirect
golang.org/x/arch v0.13.0 // indirect golang.org/x/arch v0.13.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/net v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.12.0 // indirect golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.31.0 // indirect golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect golang.org/x/text v0.23.0 // indirect

View File

@@ -78,8 +78,8 @@ github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM=
github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@@ -145,8 +145,8 @@ github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZ
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA= github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms= github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3 h1:HHT8iW+UcPBgBr5A3soZQQsL5cBor/u6BkLB+wzY/R0= github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
github.com/lestrrat-go/jwx/v3 v3.0.0-alpha3/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc= github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -255,8 +255,8 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@@ -1,17 +1,24 @@
package bootstrap package bootstrap
import ( import (
"context"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
func Bootstrap() { func Bootstrap() {
ctx := context.TODO()
initApplicationImages() initApplicationImages()
migrateConfigDBConnstring()
db := newDatabase() db := newDatabase()
appConfigService := service.NewAppConfigService(db) appConfigService := service.NewAppConfigService(ctx, db)
migrateKey() migrateKey()
initRouter(db, appConfigService) initRouter(ctx, db, appConfigService)
} }

View File

@@ -0,0 +1,34 @@
package bootstrap
import (
"log"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
// Performs the migration of the database connection string
// See: https://github.com/pocket-id/pocket-id/pull/388
func migrateConfigDBConnstring() {
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
// Check if we're using the deprecated SqliteDBPath env var
if common.EnvConfig.SqliteDBPath != "" {
connString := "file:" + common.EnvConfig.SqliteDBPath + "?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate"
common.EnvConfig.DbConnectionString = connString
common.EnvConfig.SqliteDBPath = ""
log.Printf("[WARN] Env var 'SQLITE_DB_PATH' is deprecated - use 'DB_CONNECTION_STRING' instead with the value: '%s'", connString)
}
case common.DbProviderPostgres:
// Check if we're using the deprecated PostgresConnectionString alias
if common.EnvConfig.PostgresConnectionString != "" {
common.EnvConfig.DbConnectionString = common.EnvConfig.PostgresConnectionString
common.EnvConfig.PostgresConnectionString = ""
log.Print("[WARN] Env var 'POSTGRES_CONNECTION_STRING' is deprecated - use 'DB_CONNECTION_STRING' instead with the same value")
}
default:
// We don't do anything here in the default case
// This is an error, but will be handled later on
}
}

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"strings"
"time" "time"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
@@ -38,6 +39,7 @@ func newDatabase() (db *gorm.DB) {
case common.DbProviderPostgres: case common.DbProviderPostgres:
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{}) driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
default: default:
// Should never happen at this point
log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider) log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider)
} }
if err != nil { if err != nil {
@@ -56,17 +58,17 @@ func migrateDatabase(driver database.Driver) error {
// Use the embedded migrations // Use the embedded migrations
source, err := iofs.New(resources.FS, "migrations/"+string(common.EnvConfig.DbProvider)) source, err := iofs.New(resources.FS, "migrations/"+string(common.EnvConfig.DbProvider))
if err != nil { if err != nil {
return fmt.Errorf("failed to create embedded migration source: %v", err) return fmt.Errorf("failed to create embedded migration source: %w", err)
} }
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver) m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
if err != nil { if err != nil {
return fmt.Errorf("failed to create migration instance: %v", err) return fmt.Errorf("failed to create migration instance: %w", err)
} }
err = m.Up() err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) { if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("failed to apply migrations: %v", err) return fmt.Errorf("failed to apply migrations: %w", err)
} }
return nil return nil
@@ -78,9 +80,18 @@ func connectDatabase() (db *gorm.DB, err error) {
// Choose the correct database provider // Choose the correct database provider
switch common.EnvConfig.DbProvider { switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite: case common.DbProviderSqlite:
dialector = sqlite.Open(common.EnvConfig.SqliteDBPath) if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
}
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
}
dialector = sqlite.Open(common.EnvConfig.DbConnectionString)
case common.DbProviderPostgres: case common.DbProviderPostgres:
dialector = postgres.Open(common.EnvConfig.PostgresConnectionString) if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
dialector = postgres.Open(common.EnvConfig.DbConnectionString)
default: default:
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider) return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
} }
@@ -91,14 +102,14 @@ func connectDatabase() (db *gorm.DB, err error) {
Logger: getLogger(), Logger: getLogger(),
}) })
if err == nil { if err == nil {
break return db, nil
} else {
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
time.Sleep(3 * time.Second)
} }
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
time.Sleep(3 * time.Second)
} }
return db, err return nil, err
} }
func getLogger() logger.Interface { func getLogger() logger.Interface {

View File

@@ -0,0 +1,21 @@
//go:build e2etest
package bootstrap
import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/controller"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
// When building for E2E tests, add the e2etest controller
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService){
func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) {
testService := service.NewTestService(db, appConfigService, jwtService)
controller.NewTestController(apiGroup, testService)
},
}
}

View File

@@ -92,7 +92,10 @@ func loadKeyPEM(path string) (jwk.Key, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err) return nil, fmt.Errorf("failed to generate key ID: %w", err)
} }
key.Set(jwk.KeyIDKey, keyId) err = key.Set(jwk.KeyIDKey, keyId)
if err != nil {
return nil, fmt.Errorf("failed to set key ID: %w", err)
}
// Populate other required fields // Populate other required fields
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning) _ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)

View File

@@ -101,25 +101,25 @@ func TestLoadKeyPEM(t *testing.T) {
// Check key ID is set // Check key ID is set
var keyID string var keyID string
err = key.Get(jwk.KeyIDKey, &keyID) err = key.Get(jwk.KeyIDKey, &keyID)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, keyID) assert.NotEmpty(t, keyID)
// Check algorithm is set // Check algorithm is set
var alg jwa.SignatureAlgorithm var alg jwa.SignatureAlgorithm
err = key.Get(jwk.AlgorithmKey, &alg) err = key.Get(jwk.AlgorithmKey, &alg)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, alg) assert.NotEmpty(t, alg)
// Check key usage is set // Check key usage is set
var keyUsage string var keyUsage string
err = key.Get(jwk.KeyUsageKey, &keyUsage) err = key.Get(jwk.KeyUsageKey, &keyUsage)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, service.KeyUsageSigning, keyUsage) assert.Equal(t, service.KeyUsageSigning, keyUsage)
}) })
t.Run("file not found", func(t *testing.T) { t.Run("file not found", func(t *testing.T) {
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem")) key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
assert.Error(t, err) require.Error(t, err)
assert.Nil(t, key) assert.Nil(t, key)
}) })
@@ -129,7 +129,7 @@ func TestLoadKeyPEM(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
key, err := loadKeyPEM(invalidPath) key, err := loadKeyPEM(invalidPath)
assert.Error(t, err) require.Error(t, err)
assert.Nil(t, key) assert.Nil(t, key)
}) })
} }

View File

@@ -1,6 +1,7 @@
package bootstrap package bootstrap
import ( import (
"context"
"log" "log"
"net" "net"
"time" "time"
@@ -16,7 +17,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) { // This is used to register additional controllers for tests
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService)
func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppConfigService) {
// Set the appropriate Gin mode based on the environment // Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv { switch common.EnvConfig.AppEnv {
case "production": case "production":
@@ -33,20 +37,19 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Initialize services // Initialize services
emailService, err := service.NewEmailService(appConfigService, db) emailService, err := service.NewEmailService(appConfigService, db)
if err != nil { if err != nil {
log.Fatalf("Unable to create email service: %s", err) log.Fatalf("Unable to create email service: %v", err)
} }
geoLiteService := service.NewGeoLiteService() geoLiteService := service.NewGeoLiteService(ctx)
auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService) auditLogService := service.NewAuditLogService(db, appConfigService, emailService, geoLiteService)
jwtService := service.NewJwtService(appConfigService) jwtService := service.NewJwtService(appConfigService)
webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService) webauthnService := service.NewWebAuthnService(db, jwtService, auditLogService, appConfigService)
userService := service.NewUserService(db, jwtService, auditLogService, emailService, appConfigService) userService := service.NewUserService(db, jwtService, auditLogService, emailService, appConfigService)
customClaimService := service.NewCustomClaimService(db) customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService) oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService, jwtService)
userGroupService := service.NewUserGroupService(db, appConfigService) userGroupService := service.NewUserGroupService(db, appConfigService)
ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService) ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService)
apiKeyService := service.NewApiKeyService(db) apiKeyService := service.NewApiKeyService(db, emailService)
rateLimitMiddleware := middleware.NewRateLimitMiddleware() rateLimitMiddleware := middleware.NewRateLimitMiddleware()
@@ -55,11 +58,13 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r.Use(middleware.NewErrorHandlerMiddleware().Add()) r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60)) r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
job.RegisterLdapJobs(ldapService, appConfigService) job.RegisterLdapJobs(ctx, ldapService, appConfigService)
job.RegisterDbCleanupJobs(db) job.RegisterDbCleanupJobs(ctx, db)
job.RegisterFileCleanupJobs(ctx, db)
job.RegisterApiKeyExpiryJob(ctx, apiKeyService, appConfigService)
// Initialize middleware for specific routes // Initialize middleware for specific routes
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService) authMiddleware := middleware.NewAuthMiddleware(apiKeyService, userService, jwtService)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware() fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes // Set up API routes
@@ -75,7 +80,9 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Add test controller in non-production environments // Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" { if common.EnvConfig.AppEnv != "production" {
controller.NewTestController(apiGroup, testService) for _, f := range registerTestControllers {
f(apiGroup, db, appConfigService, jwtService)
}
} }
// Set up base routes // Set up base routes

View File

@@ -20,8 +20,9 @@ type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"` AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"` AppURL string `env:"PUBLIC_APP_URL"`
DbProvider DbProvider `env:"DB_PROVIDER"` DbProvider DbProvider `env:"DB_PROVIDER"`
SqliteDBPath string `env:"SQLITE_DB_PATH"` DbConnectionString string `env:"DB_CONNECTION_STRING"`
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` SqliteDBPath string `env:"SQLITE_DB_PATH"` // Deprecated: use "DB_CONNECTION_STRING" instead
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` // Deprecated: use "DB_CONNECTION_STRING" instead
UploadPath string `env:"UPLOAD_PATH"` UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"` KeysPath string `env:"KEYS_PATH"`
Port string `env:"BACKEND_PORT"` Port string `env:"BACKEND_PORT"`
@@ -35,7 +36,8 @@ type EnvConfigSchema struct {
var EnvConfig = &EnvConfigSchema{ var EnvConfig = &EnvConfigSchema{
AppEnv: "production", AppEnv: "production",
DbProvider: "sqlite", DbProvider: "sqlite",
SqliteDBPath: "data/pocket-id.db", DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
SqliteDBPath: "",
PostgresConnectionString: "", PostgresConnectionString: "",
UploadPath: "data/uploads", UploadPath: "data/uploads",
KeysPath: "data/keys", KeysPath: "data/keys",
@@ -56,12 +58,12 @@ func init() {
// Validate the environment variables // Validate the environment variables
switch EnvConfig.DbProvider { switch EnvConfig.DbProvider {
case DbProviderSqlite: case DbProviderSqlite:
if EnvConfig.SqliteDBPath == "" { if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable") log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
} }
case DbProviderPostgres: case DbProviderPostgres:
if EnvConfig.PostgresConnectionString == "" { if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable") log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
} }
default: default:
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'") log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
) )
@@ -17,10 +18,16 @@ type AlreadyInUseError struct {
} }
func (e *AlreadyInUseError) Error() string { func (e *AlreadyInUseError) Error() string {
return fmt.Sprintf("%s is already in use", e.Property) return e.Property + " is already in use"
} }
func (e *AlreadyInUseError) HttpStatusCode() int { return 400 } func (e *AlreadyInUseError) HttpStatusCode() int { return 400 }
func (e *AlreadyInUseError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AlreadyInUseError
x := &AlreadyInUseError{}
return errors.As(target, &x)
}
type SetupAlreadyCompletedError struct{} type SetupAlreadyCompletedError struct{}
func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" } func (e *SetupAlreadyCompletedError) Error() string { return "setup already completed" }
@@ -75,11 +82,6 @@ type FileTypeNotSupportedError struct{}
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" } func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 } func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
type InvalidCredentialsError struct{}
func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" }
func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 }
type FileTooLargeError struct { type FileTooLargeError struct {
MaxSize string MaxSize string
} }
@@ -222,8 +224,7 @@ type InvalidUUIDError struct{}
func (e *InvalidUUIDError) Error() string { func (e *InvalidUUIDError) Error() string {
return "Invalid UUID" return "Invalid UUID"
} }
func (e *InvalidUUIDError) HttpStatusCode() int { return http.StatusBadRequest }
type InvalidEmailError struct{}
type OneTimeAccessDisabledError struct{} type OneTimeAccessDisabledError struct{}
@@ -237,21 +238,61 @@ type InvalidAPIKeyError struct{}
func (e *InvalidAPIKeyError) Error() string { func (e *InvalidAPIKeyError) Error() string {
return "Invalid Api Key" return "Invalid Api Key"
} }
func (e *InvalidAPIKeyError) HttpStatusCode() int { return http.StatusUnauthorized }
type NoAPIKeyProvidedError struct{} type NoAPIKeyProvidedError struct{}
func (e *NoAPIKeyProvidedError) Error() string { func (e *NoAPIKeyProvidedError) Error() string {
return "No API Key Provided" return "No API Key Provided"
} }
func (e *NoAPIKeyProvidedError) HttpStatusCode() int { return http.StatusUnauthorized }
type APIKeyNotFoundError struct{} type APIKeyNotFoundError struct{}
func (e *APIKeyNotFoundError) Error() string { func (e *APIKeyNotFoundError) Error() string {
return "API Key Not Found" return "API Key Not Found"
} }
func (e *APIKeyNotFoundError) HttpStatusCode() int { return http.StatusUnauthorized }
type APIKeyExpirationDateError struct{} type APIKeyExpirationDateError struct{}
func (e *APIKeyExpirationDateError) Error() string { func (e *APIKeyExpirationDateError) Error() string {
return "API Key expiration time must be in the future" return "API Key expiration time must be in the future"
} }
func (e *APIKeyExpirationDateError) HttpStatusCode() int { return http.StatusBadRequest }
type OidcInvalidRefreshTokenError struct{}
func (e *OidcInvalidRefreshTokenError) Error() string {
return "refresh token is invalid or expired"
}
func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingRefreshTokenError struct{}
func (e *OidcMissingRefreshTokenError) Error() string {
return "refresh token is required"
}
func (e *OidcMissingRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingAuthorizationCodeError struct{}
func (e *OidcMissingAuthorizationCodeError) Error() string {
return "authorization code is required"
}
func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int {
return http.StatusBadRequest
}
type UserDisabledError struct{}
func (e *UserDisabledError) Error() string {
return "User account is disabled"
}
func (e *UserDisabledError) HttpStatusCode() int {
return http.StatusForbidden
}

View File

@@ -43,25 +43,25 @@ func NewApiKeyController(group *gin.RouterGroup, authMiddleware *middleware.Auth
// @Param sort_column query string false "Column to sort by" default("created_at") // @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc") // @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.ApiKeyDto] // @Success 200 {object} dto.Paginated[dto.ApiKeyDto]
// @Router /api-keys [get] // @Router /api/api-keys [get]
func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) { func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
userID := ctx.GetString("userID") userID := ctx.GetString("userID")
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil { if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest) apiKeys, pagination, err := c.apiKeyService.ListApiKeys(ctx.Request.Context(), userID, sortedPaginationRequest)
if err != nil { if err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
var apiKeysDto []dto.ApiKeyDto var apiKeysDto []dto.ApiKeyDto
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil { if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
@@ -77,25 +77,25 @@ func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
// @Tags API Keys // @Tags API Keys
// @Param api_key body dto.ApiKeyCreateDto true "API key information" // @Param api_key body dto.ApiKeyCreateDto true "API key information"
// @Success 201 {object} dto.ApiKeyResponseDto "Created API key with token" // @Success 201 {object} dto.ApiKeyResponseDto "Created API key with token"
// @Router /api-keys [post] // @Router /api/api-keys [post]
func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) { func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID") userID := ctx.GetString("userID")
var input dto.ApiKeyCreateDto var input dto.ApiKeyCreateDto
if err := ctx.ShouldBindJSON(&input); err != nil { if err := ctx.ShouldBindJSON(&input); err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input) apiKey, token, err := c.apiKeyService.CreateApiKey(ctx.Request.Context(), userID, input)
if err != nil { if err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
var apiKeyDto dto.ApiKeyDto var apiKeyDto dto.ApiKeyDto
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil { if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }
@@ -111,13 +111,13 @@ func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
// @Tags API Keys // @Tags API Keys
// @Param id path string true "API Key ID" // @Param id path string true "API Key ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /api-keys/{id} [delete] // @Router /api/api-keys/{id} [delete]
func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) { func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID") userID := ctx.GetString("userID")
apiKeyID := ctx.Param("id") apiKeyID := ctx.Param("id")
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil { if err := c.apiKeyService.RevokeApiKey(ctx.Request.Context(), userID, apiKeyID); err != nil {
ctx.Error(err) _ = ctx.Error(err)
return return
} }

View File

@@ -1,8 +1,8 @@
package controller package controller
import ( import (
"fmt"
"net/http" "net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
@@ -60,19 +60,15 @@ type AppConfigController struct {
// @Failure 500 {object} object "{"error": "error message"}" // @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get] // @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(false) configuration := acc.appConfigService.ListAppConfig(false)
if err != nil {
c.Error(err)
return
}
var configVariablesDto []dto.PublicAppConfigVariableDto var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
c.JSON(200, configVariablesDto) c.JSON(http.StatusOK, configVariablesDto)
} }
// listAllAppConfigHandler godoc // listAllAppConfigHandler godoc
@@ -85,19 +81,15 @@ func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/all [get] // @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(true) configuration := acc.appConfigService.ListAppConfig(true)
if err != nil {
c.Error(err)
return
}
var configVariablesDto []dto.AppConfigVariableDto var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil { if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
c.JSON(200, configVariablesDto) c.JSON(http.StatusOK, configVariablesDto)
} }
// updateAppConfigHandler godoc // updateAppConfigHandler godoc
@@ -109,23 +101,23 @@ func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
// @Param body body dto.AppConfigUpdateDto true "Application Configuration" // @Param body body dto.AppConfigUpdateDto true "Application Configuration"
// @Success 200 {array} dto.AppConfigVariableDto // @Success 200 {array} dto.AppConfigVariableDto
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration [put] // @Router /api/application-configuration [put]
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) { func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input) savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(c.Request.Context(), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var configVariablesDto []dto.AppConfigVariableDto var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil { if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -141,19 +133,19 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
// @Produce image/jpeg // @Produce image/jpeg
// @Produce image/svg+xml // @Produce image/svg+xml
// @Success 200 {file} binary "Logo image" // @Success 200 {file} binary "Logo image"
// @Router /application-configuration/logo [get] // @Router /api/application-configuration/logo [get]
func (acc *AppConfigController) getLogoHandler(c *gin.Context) { func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
lightLogo := c.DefaultQuery("light", "true") == "true" dbConfig := acc.appConfigService.GetDbConfig()
var imageName string lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageType string
var imageName, imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.getImage(c, imageName, imageType) acc.getImage(c, imageName, imageType)
@@ -166,7 +158,7 @@ func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
// @Produce image/x-icon // @Produce image/x-icon
// @Success 200 {file} binary "Favicon image" // @Success 200 {file} binary "Favicon image"
// @Failure 404 {object} object "{"error": "File not found"}" // @Failure 404 {object} object "{"error": "File not found"}"
// @Router /application-configuration/favicon [get] // @Router /api/application-configuration/favicon [get]
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) { func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico") acc.getImage(c, "favicon", "ico")
} }
@@ -179,9 +171,9 @@ func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
// @Produce image/jpeg // @Produce image/jpeg
// @Success 200 {file} binary "Background image" // @Success 200 {file} binary "Background image"
// @Failure 404 {object} object "{"error": "File not found"}" // @Failure 404 {object} object "{"error": "File not found"}"
// @Router /application-configuration/background-image [get] // @Router /api/application-configuration/background-image [get]
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.getImage(c, "background", imageType) acc.getImage(c, "background", imageType)
} }
@@ -194,19 +186,19 @@ func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
// @Param file formData file true "Logo image file" // @Param file formData file true "Logo image file"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/logo [put] // @Router /api/application-configuration/logo [put]
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) { func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
lightLogo := c.DefaultQuery("light", "true") == "true" dbConfig := acc.appConfigService.GetDbConfig()
var imageName string lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageType string
var imageName, imageType string
if lightLogo { if lightLogo {
imageName = "logoLight" imageName = "logoLight"
imageType = acc.appConfigService.DbConfig.LogoLightImageType.Value imageType = dbConfig.LogoLightImageType.Value
} else { } else {
imageName = "logoDark" imageName = "logoDark"
imageType = acc.appConfigService.DbConfig.LogoDarkImageType.Value imageType = dbConfig.LogoDarkImageType.Value
} }
acc.updateImage(c, imageName, imageType) acc.updateImage(c, imageName, imageType)
@@ -220,17 +212,17 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
// @Param file formData file true "Favicon file (.ico)" // @Param file formData file true "Favicon file (.ico)"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/favicon [put] // @Router /api/application-configuration/favicon [put]
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) { func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
fileType := utils.GetFileExtension(file.Filename) fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" { if fileType != "ico" {
c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"}) _ = c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
return return
} }
acc.updateImage(c, "favicon", "ico") acc.updateImage(c, "favicon", "ico")
@@ -244,15 +236,15 @@ func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
// @Param file formData file true "Background image file" // @Param file formData file true "Background image file"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/background-image [put] // @Router /api/application-configuration/background-image [put]
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) { func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
acc.updateImage(c, "background", imageType) acc.updateImage(c, "background", imageType)
} }
// getImage is a helper function to serve image files // getImage is a helper function to serve image files
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) { func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType) imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType
mimeType := utils.GetImageMimeType(imageType) mimeType := utils.GetImageMimeType(imageType)
c.Header("Content-Type", mimeType) c.Header("Content-Type", mimeType)
@@ -263,13 +255,13 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) { func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType) err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -282,11 +274,11 @@ func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, ol
// @Tags Application Configuration // @Tags Application Configuration
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/sync-ldap [post] // @Router /api/application-configuration/sync-ldap [post]
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) { func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
err := acc.ldapService.SyncAll() err := acc.ldapService.SyncAll(c.Request.Context())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -299,13 +291,13 @@ func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
// @Tags Application Configuration // @Tags Application Configuration
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /application-configuration/test-email [post] // @Router /api/application-configuration/test-email [post]
func (acc *AppConfigController) testEmailHandler(c *gin.Context) { func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
err := acc.emailService.SendTestEmail(userID) err := acc.emailService.SendTestEmail(c.Request.Context(), userID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -20,7 +20,10 @@ func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.Audi
auditLogService: auditLogService, auditLogService: auditLogService,
} }
group.GET("/audit-logs/all", authMiddleware.Add(), alc.listAllAuditLogsHandler)
group.GET("/audit-logs", authMiddleware.WithAdminNotRequired().Add(), alc.listAuditLogsForUserHandler) group.GET("/audit-logs", authMiddleware.WithAdminNotRequired().Add(), alc.listAuditLogsForUserHandler)
group.GET("/audit-logs/filters/client-names", authMiddleware.Add(), alc.listClientNamesHandler)
group.GET("/audit-logs/filters/users", authMiddleware.Add(), alc.listUserNamesWithIdsHandler)
} }
type AuditLogController struct { type AuditLogController struct {
@@ -36,20 +39,22 @@ type AuditLogController struct {
// @Param sort_column query string false "Column to sort by" default("created_at") // @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc") // @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.AuditLogDto] // @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /audit-logs [get] // @Router /api/audit-logs [get]
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) { func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err) err := c.ShouldBindQuery(&sortedPaginationRequest)
if err != nil {
_ = c.Error(err)
return return
} }
userID := c.GetString("userID") userID := c.GetString("userID")
// Fetch audit logs for the user // Fetch audit logs for the user
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest) logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(c.Request.Context(), userID, sortedPaginationRequest)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -57,7 +62,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var logsDtos []dto.AuditLogDto var logsDtos []dto.AuditLogDto
err = dto.MapStructList(logs, &logsDtos) err = dto.MapStructList(logs, &logsDtos)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -72,3 +77,86 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
Pagination: pagination, Pagination: pagination,
}) })
} }
// listAllAuditLogsHandler godoc
// @Summary List all audit logs
// @Description Get a paginated list of all audit logs (admin only)
// @Tags Audit Logs
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Param user_id query string false "Filter by user ID"
// @Param event query string false "Filter by event type"
// @Param client_name query string false "Filter by client name"
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /api/audit-logs/all [get]
func (alc *AuditLogController) listAllAuditLogsHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
_ = c.Error(err)
return
}
var filters dto.AuditLogFilterDto
if err := c.ShouldBindQuery(&filters); err != nil {
_ = c.Error(err)
return
}
logs, pagination, err := alc.auditLogService.ListAllAuditLogs(c.Request.Context(), sortedPaginationRequest, filters)
if err != nil {
_ = c.Error(err)
return
}
var logsDtos []dto.AuditLogDto
err = dto.MapStructList(logs, &logsDtos)
if err != nil {
_ = c.Error(err)
return
}
for i, logsDto := range logsDtos {
logsDto.Device = alc.auditLogService.DeviceStringFromUserAgent(logs[i].UserAgent)
logsDto.Username = logs[i].User.Username
logsDtos[i] = logsDto
}
c.JSON(http.StatusOK, dto.Paginated[dto.AuditLogDto]{
Data: logsDtos,
Pagination: pagination,
})
}
// listClientNamesHandler godoc
// @Summary List client names
// @Description Get a list of all client names for audit log filtering
// @Tags Audit Logs
// @Success 200 {array} string "List of client names"
// @Router /api/audit-logs/filters/client-names [get]
func (alc *AuditLogController) listClientNamesHandler(c *gin.Context) {
names, err := alc.auditLogService.ListClientNames(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, names)
}
// listUserNamesWithIdsHandler godoc
// @Summary List users with IDs
// @Description Get a list of all usernames with their IDs for audit log filtering
// @Tags Audit Logs
// @Success 200 {object} map[string]string "Map of user IDs to usernames"
// @Router /api/audit-logs/filters/users [get]
func (alc *AuditLogController) listUserNamesWithIdsHandler(c *gin.Context) {
users, err := alc.auditLogService.ListUsernamesWithIds(c.Request.Context())
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, users)
}

View File

@@ -39,11 +39,11 @@ type CustomClaimController struct {
// @Failure 403 {object} object "Forbidden" // @Failure 403 {object} object "Forbidden"
// @Failure 500 {object} object "Internal server error" // @Failure 500 {object} object "Internal server error"
// @Security BearerAuth // @Security BearerAuth
// @Router /custom-claims/suggestions [get] // @Router /api/custom-claims/suggestions [get]
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) { func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions() claims, err := ccc.customClaimService.GetSuggestions(c.Request.Context())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -59,25 +59,25 @@ func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
// @Param userId path string true "User ID" // @Param userId path string true "User ID"
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user" // @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user"
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims" // @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
// @Router /custom-claims/user/{userId} [put] // @Router /api/custom-claims/user/{userId} [put]
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) { func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
userId := c.Param("userId") userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input) claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(c.Request.Context(), userId, input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var customClaimsDto []dto.CustomClaimDto var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil { if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -94,25 +94,25 @@ func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Contex
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user group" // @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user group"
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims" // @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
// @Security BearerAuth // @Security BearerAuth
// @Router /custom-claims/user-group/{userGroupId} [put] // @Router /api/custom-claims/user-group/{userGroupId} [put]
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) { func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
userGroupId := c.Param("userGroupId") userGroupId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input) claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(c.Request.Context(), userGroupId, input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var customClaimsDto []dto.CustomClaimDto var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil { if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -1,9 +1,12 @@
//go:build e2etest
package controller package controller
import ( import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
@@ -19,22 +22,22 @@ type TestController struct {
func (tc *TestController) resetAndSeedHandler(c *gin.Context) { func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
if err := tc.TestService.ResetDatabase(); err != nil { if err := tc.TestService.ResetDatabase(); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
if err := tc.TestService.ResetApplicationImages(); err != nil { if err := tc.TestService.ResetApplicationImages(); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
if err := tc.TestService.SeedDatabase(); err != nil { if err := tc.TestService.SeedDatabase(); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
if err := tc.TestService.ResetAppConfig(); err != nil { if err := tc.TestService.ResetAppConfig(c.Request.Context()); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -6,14 +6,13 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware" "github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
) )
// NewOidcController creates a new controller for OIDC related endpoints // NewOidcController creates a new controller for OIDC related endpoints
@@ -31,6 +30,7 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/oidc/userinfo", oc.userInfoHandler) group.POST("/oidc/userinfo", oc.userInfoHandler)
group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler) group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler) group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.POST("/oidc/introspect", oc.introspectTokenHandler)
group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler) group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler)
group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler) group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler)
@@ -61,17 +61,17 @@ type OidcController struct {
// @Param request body dto.AuthorizeOidcClientRequestDto true "Authorization request parameters" // @Param request body dto.AuthorizeOidcClientRequestDto true "Authorization request parameters"
// @Success 200 {object} dto.AuthorizeOidcClientResponseDto "Authorization code and callback URL" // @Success 200 {object} dto.AuthorizeOidcClientResponseDto "Authorization code and callback URL"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/authorize [post] // @Router /api/oidc/authorize [post]
func (oc *OidcController) authorizeHandler(c *gin.Context) { func (oc *OidcController) authorizeHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientRequestDto var input dto.AuthorizeOidcClientRequestDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent()) code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -92,17 +92,17 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
// @Param request body dto.AuthorizationRequiredDto true "Authorization check parameters" // @Param request body dto.AuthorizationRequiredDto true "Authorization check parameters"
// @Success 200 {object} object "{ \"authorizationRequired\": true/false }" // @Success 200 {object} object "{ \"authorizationRequired\": true/false }"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/authorization-required [post] // @Router /api/oidc/authorization-required [post]
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) { func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
var input dto.AuthorizationRequiredDto var input dto.AuthorizationRequiredDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope) hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(c.Request.Context(), input.ClientID, c.GetString("userID"), input.Scope)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -111,25 +111,36 @@ func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Contex
// createTokensHandler godoc // createTokensHandler godoc
// @Summary Create OIDC tokens // @Summary Create OIDC tokens
// @Description Exchange authorization code for ID and access tokens // @Description Exchange authorization code or refresh token for access tokens
// @Tags OIDC // @Tags OIDC
// @Accept application/x-www-form-urlencoded
// @Produce json // @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)" // @Param client_id formData string false "Client ID (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth)" // @Param client_secret formData string false "Client secret (if not using Basic Auth)"
// @Param code formData string true "Authorization code" // @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type (must be 'authorization_code')" // @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier" // @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Success 200 {object} object "{ \"id_token\": \"string\", \"access_token\": \"string\", \"token_type\": \"Bearer\" }" // @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
// @Router /oidc/token [post] // @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) { func (oc *OidcController) createTokensHandler(c *gin.Context) {
// Disable cors for this endpoint // Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
var input dto.OidcCreateTokensDto var input dto.OidcCreateTokensDto
if err := c.ShouldBind(&input); err != nil { if err := c.ShouldBind(&input); err != nil {
c.Error(err) _ = c.Error(err)
return
}
// Validate that code is provided for authorization_code grant type
if input.GrantType == "authorization_code" && input.Code == "" {
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
return
}
// Validate that refresh_token is provided for refresh_token grant type
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
_ = c.Error(&common.OidcMissingRefreshTokenError{})
return return
} }
@@ -141,13 +152,38 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientID, clientSecret, _ = c.Request.BasicAuth() clientID, clientSecret, _ = c.Request.BasicAuth()
} }
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier) idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
c.Request.Context(),
input.Code,
input.GrantType,
clientID,
clientSecret,
input.CodeVerifier,
input.RefreshToken,
)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"}) response := dto.OidcTokenResponseDto{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: expiresIn,
}
// Include ID token only for authorization_code grant
if idToken != "" {
response.IdToken = idToken
}
// Include refresh token if generated
if refreshToken != "" {
response.RefreshToken = refreshToken
}
c.JSON(http.StatusOK, response)
} }
// userInfoHandler godoc // userInfoHandler godoc
@@ -158,45 +194,38 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
// @Produce json // @Produce json
// @Success 200 {object} object "User claims based on requested scopes" // @Success 200 {object} object "User claims based on requested scopes"
// @Security OAuth2AccessToken // @Security OAuth2AccessToken
// @Router /oidc/userinfo [get] // @Router /api/oidc/userinfo [get]
func (oc *OidcController) userInfoHandler(c *gin.Context) { func (oc *OidcController) userInfoHandler(c *gin.Context) {
authHeaderSplit := strings.Split(c.GetHeader("Authorization"), " ") _, authToken, ok := strings.Cut(c.GetHeader("Authorization"), " ")
if len(authHeaderSplit) != 2 { if !ok || authToken == "" {
c.Error(&common.MissingAccessToken{}) _ = c.Error(&common.MissingAccessToken{})
return return
} }
token := authHeaderSplit[1] token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
userID := jwtClaims.Subject userID, ok := token.Subject()
clientId := jwtClaims.Audience[0] if !ok {
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId) _ = c.Error(&common.TokenInvalidError{})
return
}
clientID, ok := token.Audience()
if !ok || len(clientID) != 1 {
_ = c.Error(&common.TokenInvalidError{})
return
}
claims, err := oc.oidcService.GetUserClaimsForClient(c.Request.Context(), userID, clientID[0])
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
c.JSON(http.StatusOK, claims) c.JSON(http.StatusOK, claims)
} }
// userInfoHandler godoc (POST method)
// @Summary Get user information (POST method)
// @Description Get user information based on the access token using POST
// @Tags OIDC
// @Accept json
// @Produce json
// @Success 200 {object} object "User claims based on requested scopes"
// @Security OAuth2AccessToken
// @Router /oidc/userinfo [post]
func (oc *OidcController) userInfoHandlerPost(c *gin.Context) {
// Implementation is the same as GET
}
// EndSessionHandler godoc // EndSessionHandler godoc
// @Summary End OIDC session // @Summary End OIDC session
// @Description End user session and handle OIDC logout // @Description End user session and handle OIDC logout
@@ -207,25 +236,26 @@ func (oc *OidcController) userInfoHandlerPost(c *gin.Context) {
// @Param post_logout_redirect_uri query string false "URL to redirect to after logout" // @Param post_logout_redirect_uri query string false "URL to redirect to after logout"
// @Param state query string false "State parameter to include in the redirect" // @Param state query string false "State parameter to include in the redirect"
// @Success 302 "Redirect to post-logout URL or application logout page" // @Success 302 "Redirect to post-logout URL or application logout page"
// @Router /oidc/end-session [get] // @Router /api/oidc/end-session [get]
func (oc *OidcController) EndSessionHandler(c *gin.Context) { func (oc *OidcController) EndSessionHandler(c *gin.Context) {
var input dto.OidcLogoutDto var input dto.OidcLogoutDto
// Bind query parameters to the struct // Bind query parameters to the struct
if c.Request.Method == http.MethodGet { switch c.Request.Method {
case http.MethodGet:
if err := c.ShouldBindQuery(&input); err != nil { if err := c.ShouldBindQuery(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
} else if c.Request.Method == http.MethodPost { case http.MethodPost:
// Bind form parameters to the struct // Bind form parameters to the struct
if err := c.ShouldBind(&input); err != nil { if err := c.ShouldBind(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
} }
callbackURL, err := oc.oidcService.ValidateEndSession(input, c.GetString("userID")) callbackURL, err := oc.oidcService.ValidateEndSession(c.Request.Context(), input, c.GetString("userID"))
if err != nil { if err != nil {
// If the validation fails, the user has to confirm the logout manually and doesn't get redirected // If the validation fails, the user has to confirm the logout manually and doesn't get redirected
log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err) log.Printf("Error getting logout callback URL, the user has to confirm the logout manually: %v", err)
@@ -256,11 +286,43 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
// @Param post_logout_redirect_uri formData string false "URL to redirect to after logout" // @Param post_logout_redirect_uri formData string false "URL to redirect to after logout"
// @Param state formData string false "State parameter to include in the redirect" // @Param state formData string false "State parameter to include in the redirect"
// @Success 302 "Redirect to post-logout URL or application logout page" // @Success 302 "Redirect to post-logout URL or application logout page"
// @Router /oidc/end-session [post] // @Router /api/oidc/end-session [post]
func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) { func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// Implementation is the same as GET // Implementation is the same as GET
} }
// introspectToken godoc
// @Summary Introspect OIDC tokens
// @Description Pass an access_token to verify if it is considered valid.
// @Tags OIDC
// @Produce json
// @Param token formData string true "The token to be introspected."
// @Success 200 {object} dto.OidcIntrospectionResponseDto "Response with the introspection result."
// @Router /api/oidc/introspect [post]
func (oc *OidcController) introspectTokenHandler(c *gin.Context) {
var input dto.OidcIntrospectDto
if err := c.ShouldBind(&input); err != nil {
_ = c.Error(err)
return
}
// Client id and secret have to be passed over the Authorization header. This kind of
// authentication allows us to keep the endpoint protected (since it could be used to
// find valid tokens) while still allowing it to be used by an application that is
// supposed to interact with our IdP (since that needs to have a client_id
// and client_secret anyway).
clientID, clientSecret, _ := c.Request.BasicAuth()
response, err := oc.oidcService.IntrospectToken(clientID, clientSecret, input.Token)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, response)
}
// getClientMetaDataHandler godoc // getClientMetaDataHandler godoc
// @Summary Get client metadata // @Summary Get client metadata
// @Description Get OIDC client metadata for discovery and configuration // @Description Get OIDC client metadata for discovery and configuration
@@ -268,12 +330,12 @@ func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// @Produce json // @Produce json
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 200 {object} dto.OidcClientMetaDataDto "Client metadata" // @Success 200 {object} dto.OidcClientMetaDataDto "Client metadata"
// @Router /oidc/clients/{id}/meta [get] // @Router /api/oidc/clients/{id}/meta [get]
func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) { func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
clientId := c.Param("id") clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId) client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -284,7 +346,7 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
return return
} }
c.Error(err) _ = c.Error(err)
} }
// getClientHandler godoc // getClientHandler godoc
@@ -295,12 +357,12 @@ func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Client information" // @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Client information"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id} [get] // @Router /api/oidc/clients/{id} [get]
func (oc *OidcController) getClientHandler(c *gin.Context) { func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id") clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId) client, err := oc.oidcService.GetClient(c.Request.Context(), clientId)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -311,7 +373,7 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
return return
} }
c.Error(err) _ = c.Error(err)
} }
// listClientsHandler godoc // listClientsHandler godoc
@@ -325,24 +387,24 @@ func (oc *OidcController) getClientHandler(c *gin.Context) {
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc") // @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.OidcClientDto] // @Success 200 {object} dto.Paginated[dto.OidcClientDto]
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients [get] // @Router /api/oidc/clients [get]
func (oc *OidcController) listClientsHandler(c *gin.Context) { func (oc *OidcController) listClientsHandler(c *gin.Context) {
searchTerm := c.Query("search") searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest) clients, pagination, err := oc.oidcService.ListClients(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var clientsDto []dto.OidcClientDto var clientsDto []dto.OidcClientDto
if err := dto.MapStructList(clients, &clientsDto); err != nil { if err := dto.MapStructList(clients, &clientsDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -361,23 +423,23 @@ func (oc *OidcController) listClientsHandler(c *gin.Context) {
// @Param client body dto.OidcClientCreateDto true "Client information" // @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 201 {object} dto.OidcClientWithAllowedUserGroupsDto "Created client" // @Success 201 {object} dto.OidcClientWithAllowedUserGroupsDto "Created client"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients [post] // @Router /api/oidc/clients [post]
func (oc *OidcController) createClientHandler(c *gin.Context) { func (oc *OidcController) createClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
client, err := oc.oidcService.CreateClient(input, c.GetString("userID")) client, err := oc.oidcService.CreateClient(c.Request.Context(), input, c.GetString("userID"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var clientDto dto.OidcClientWithAllowedUserGroupsDto var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil { if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -391,11 +453,11 @@ func (oc *OidcController) createClientHandler(c *gin.Context) {
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id} [delete] // @Router /api/oidc/clients/{id} [delete]
func (oc *OidcController) deleteClientHandler(c *gin.Context) { func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id")) err := oc.oidcService.DeleteClient(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -412,23 +474,23 @@ func (oc *OidcController) deleteClientHandler(c *gin.Context) {
// @Param client body dto.OidcClientCreateDto true "Client information" // @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Updated client" // @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Updated client"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id} [put] // @Router /api/oidc/clients/{id} [put]
func (oc *OidcController) updateClientHandler(c *gin.Context) { func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
client, err := oc.oidcService.UpdateClient(c.Param("id"), input) client, err := oc.oidcService.UpdateClient(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var clientDto dto.OidcClientWithAllowedUserGroupsDto var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil { if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -443,11 +505,11 @@ func (oc *OidcController) updateClientHandler(c *gin.Context) {
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 200 {object} object "{ \"secret\": \"string\" }" // @Success 200 {object} object "{ \"secret\": \"string\" }"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id}/secret [post] // @Router /api/oidc/clients/{id}/secret [post]
func (oc *OidcController) createClientSecretHandler(c *gin.Context) { func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id")) secret, err := oc.oidcService.CreateClientSecret(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -463,11 +525,11 @@ func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
// @Produce image/svg+xml // @Produce image/svg+xml
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 200 {file} binary "Logo image" // @Success 200 {file} binary "Logo image"
// @Router /oidc/clients/{id}/logo [get] // @Router /api/oidc/clients/{id}/logo [get]
func (oc *OidcController) getClientLogoHandler(c *gin.Context) { func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id")) imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -484,17 +546,17 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
// @Param file formData file true "Logo image file (PNG, JPG, or SVG, max 2MB)" // @Param file formData file true "Logo image file (PNG, JPG, or SVG, max 2MB)"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id}/logo [post] // @Router /api/oidc/clients/{id}/logo [post]
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) { func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file") file, err := c.FormFile("file")
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file) err = oc.oidcService.UpdateClientLogo(c.Request.Context(), c.Param("id"), file)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -508,11 +570,11 @@ func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
// @Param id path string true "Client ID" // @Param id path string true "Client ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id}/logo [delete] // @Router /api/oidc/clients/{id}/logo [delete]
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) { func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id")) err := oc.oidcService.DeleteClientLogo(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -529,23 +591,23 @@ func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
// @Param groups body dto.OidcUpdateAllowedUserGroupsDto true "User group IDs" // @Param groups body dto.OidcUpdateAllowedUserGroupsDto true "User group IDs"
// @Success 200 {object} dto.OidcClientDto "Updated client" // @Success 200 {object} dto.OidcClientDto "Updated client"
// @Security BearerAuth // @Security BearerAuth
// @Router /oidc/clients/{id}/allowed-user-groups [put] // @Router /api/oidc/clients/{id}/allowed-user-groups [put]
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) { func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
var input dto.OidcUpdateAllowedUserGroupsDto var input dto.OidcUpdateAllowedUserGroupsDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input) oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var oidcClientDto dto.OidcClientDto var oidcClientDto dto.OidcClientDto
if err := dto.MapStruct(oidcClient, &oidcClientDto); err != nil { if err := dto.MapStruct(oidcClient, &oidcClientDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -2,7 +2,6 @@ package controller
import ( import (
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie" "github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
@@ -44,9 +43,10 @@ func NewUserController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
group.POST("/users/me/one-time-access-token", authMiddleware.WithAdminNotRequired().Add(), uc.createOwnOneTimeAccessTokenHandler) group.POST("/users/me/one-time-access-token", authMiddleware.WithAdminNotRequired().Add(), uc.createOwnOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler) group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-email", authMiddleware.Add(), uc.RequestOneTimeAccessEmailAsAdminHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler) group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler) group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler) group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.RequestOneTimeAccessEmailAsUnauthenticatedUserHandler)
group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler) group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler)
group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler) group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler)
@@ -63,18 +63,18 @@ type UserController struct {
// @Tags Users,User Groups // @Tags Users,User Groups
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Success 200 {array} dto.UserGroupDtoWithUsers // @Success 200 {array} dto.UserGroupDtoWithUsers
// @Router /users/{id}/groups [get] // @Router /api/users/{id}/groups [get]
func (uc *UserController) getUserGroupsHandler(c *gin.Context) { func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
groups, err := uc.userService.GetUserGroups(userID) groups, err := uc.userService.GetUserGroups(c.Request.Context(), userID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var groupsDto []dto.UserGroupDtoWithUsers var groupsDto []dto.UserGroupDtoWithUsers
if err := dto.MapStructList(groups, &groupsDto); err != nil { if err := dto.MapStructList(groups, &groupsDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -91,24 +91,24 @@ func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
// @Param sort_column query string false "Column to sort by" default("created_at") // @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc") // @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.UserDto] // @Success 200 {object} dto.Paginated[dto.UserDto]
// @Router /users [get] // @Router /api/users [get]
func (uc *UserController) listUsersHandler(c *gin.Context) { func (uc *UserController) listUsersHandler(c *gin.Context) {
searchTerm := c.Query("search") searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest) users, pagination, err := uc.userService.ListUsers(c.Request.Context(), searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var usersDto []dto.UserDto var usersDto []dto.UserDto
if err := dto.MapStructList(users, &usersDto); err != nil { if err := dto.MapStructList(users, &usersDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -124,17 +124,17 @@ func (uc *UserController) listUsersHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /users/{id} [get] // @Router /api/users/{id} [get]
func (uc *UserController) getUserHandler(c *gin.Context) { func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.Param("id")) user, err := uc.userService.GetUser(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -146,17 +146,17 @@ func (uc *UserController) getUserHandler(c *gin.Context) {
// @Description Retrieve information about the currently authenticated user // @Description Retrieve information about the currently authenticated user
// @Tags Users // @Tags Users
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /users/me [get] // @Router /api/users/me [get]
func (uc *UserController) getCurrentUserHandler(c *gin.Context) { func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.GetString("userID")) user, err := uc.userService.GetUser(c.Request.Context(), c.GetString("userID"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -169,10 +169,10 @@ func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /users/{id} [delete] // @Router /api/users/{id} [delete]
func (uc *UserController) deleteUserHandler(c *gin.Context) { func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.userService.DeleteUser(c.Param("id")); err != nil { if err := uc.userService.DeleteUser(c.Request.Context(), c.Param("id"), false); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -185,23 +185,23 @@ func (uc *UserController) deleteUserHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Param user body dto.UserCreateDto true "User information" // @Param user body dto.UserCreateDto true "User information"
// @Success 201 {object} dto.UserDto // @Success 201 {object} dto.UserDto
// @Router /users [post] // @Router /api/users [post]
func (uc *UserController) createUserHandler(c *gin.Context) { func (uc *UserController) createUserHandler(c *gin.Context) {
var input dto.UserCreateDto var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
user, err := uc.userService.CreateUser(input) user, err := uc.userService.CreateUser(c.Request.Context(), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -215,7 +215,7 @@ func (uc *UserController) createUserHandler(c *gin.Context) {
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Param user body dto.UserCreateDto true "User information" // @Param user body dto.UserCreateDto true "User information"
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /users/{id} [put] // @Router /api/users/{id} [put]
func (uc *UserController) updateUserHandler(c *gin.Context) { func (uc *UserController) updateUserHandler(c *gin.Context) {
uc.updateUser(c, false) uc.updateUser(c, false)
} }
@@ -226,10 +226,10 @@ func (uc *UserController) updateUserHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Param user body dto.UserCreateDto true "User information" // @Param user body dto.UserCreateDto true "User information"
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /users/me [put] // @Router /api/users/me [put]
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) { func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
if uc.appConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" { if !uc.appConfigService.GetDbConfig().AllowOwnAccountEdit.IsTrue() {
c.Error(&common.AccountEditNotAllowedError{}) _ = c.Error(&common.AccountEditNotAllowedError{})
return return
} }
uc.updateUser(c, true) uc.updateUser(c, true)
@@ -242,17 +242,23 @@ func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
// @Produce image/png // @Produce image/png
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Success 200 {file} binary "PNG image" // @Success 200 {file} binary "PNG image"
// @Router /users/{id}/profile-picture.png [get] // @Router /api/users/{id}/profile-picture.png [get]
func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
picture, size, err := uc.userService.GetProfilePicture(userID) picture, size, err := uc.userService.GetProfilePicture(c.Request.Context(), userID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
if picture != nil {
defer picture.Close()
}
c.Header("Cache-Control", "public, max-age=300") _, ok := c.GetQuery("skipCache")
if !ok {
c.Header("Cache-Control", "public, max-age=900")
}
c.DataFromReader(http.StatusOK, size, "image/png", picture, nil) c.DataFromReader(http.StatusOK, size, "image/png", picture, nil)
} }
@@ -266,23 +272,23 @@ func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)" // @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /users/{id}/profile-picture [put] // @Router /api/users/{id}/profile-picture [put]
func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
fileHeader, err := c.FormFile("file") fileHeader, err := c.FormFile("file")
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
file, err := fileHeader.Open() file, err := fileHeader.Open()
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
defer file.Close() defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil { if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -297,23 +303,23 @@ func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) {
// @Produce json // @Produce json
// @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)" // @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /users/me/profile-picture [put] // @Router /api/users/me/profile-picture [put]
func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
fileHeader, err := c.FormFile("file") fileHeader, err := c.FormFile("file")
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
file, err := fileHeader.Open() file, err := fileHeader.Open()
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
defer file.Close() defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil { if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -323,16 +329,16 @@ func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context)
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bool) { func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bool) {
var input dto.OneTimeAccessTokenCreateDto var input dto.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
if own { if own {
input.UserID = c.GetString("userID") input.UserID = c.GetString("userID")
} }
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt) token, err := uc.userService.CreateOneTimeAccessToken(c.Request.Context(), input.UserID, input.ExpiresAt)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -346,25 +352,70 @@ func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bo
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessTokenCreateDto true "Token options" // @Param body body dto.OneTimeAccessTokenCreateDto true "Token options"
// @Success 201 {object} object "{ \"token\": \"string\" }" // @Success 201 {object} object "{ \"token\": \"string\" }"
// @Router /users/{id}/one-time-access-token [post] // @Router /api/users/{id}/one-time-access-token [post]
func (uc *UserController) createOwnOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) createOwnOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, true) uc.createOneTimeAccessTokenHandler(c, true)
} }
// createAdminOneTimeAccessTokenHandler godoc
// @Summary Create one-time access token for user (admin)
// @Description Generate a one-time access token for a specific user (admin only)
// @Tags Users
// @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessTokenCreateDto true "Token options"
// @Success 201 {object} object "{ \"token\": \"string\" }"
// @Router /api/users/{id}/one-time-access-token [post]
func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, false) uc.createOneTimeAccessTokenHandler(c, false)
} }
func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) { // RequestOneTimeAccessEmailAsUnauthenticatedUserHandler godoc
var input dto.OneTimeAccessEmailDto // @Summary Request one-time access email
// @Description Request a one-time access email for unauthenticated users
// @Tags Users
// @Accept json
// @Produce json
// @Param body body dto.OneTimeAccessEmailAsUnauthenticatedUserDto true "Email request information"
// @Success 204 "No Content"
// @Router /api/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsUnauthenticatedUserHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsUnauthenticatedUserDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath) err := uc.userService.RequestOneTimeAccessEmailAsUnauthenticatedUser(c.Request.Context(), input.Email, input.RedirectPath)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// RequestOneTimeAccessEmailAsAdminHandler godoc
// @Summary Request one-time access email (admin)
// @Description Request a one-time access email for a specific user (admin only)
// @Tags Users
// @Accept json
// @Produce json
// @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessEmailAsAdminDto true "Email request options"
// @Success 204 "No Content"
// @Router /api/users/{id}/one-time-access-email [post]
func (uc *UserController) RequestOneTimeAccessEmailAsAdminHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailAsAdminDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
userID := c.Param("id")
err := uc.userService.RequestOneTimeAccessEmailAsAdmin(c.Request.Context(), userID, input.ExpiresAt)
if err != nil {
_ = c.Error(err)
return return
} }
@@ -377,22 +428,21 @@ func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Param token path string true "One-time access token" // @Param token path string true "One-time access token"
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /one-time-access-token/{token} [post] // @Router /api/one-time-access-token/{token} [post]
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) { func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent()) user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Request.Context(), c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
maxAge := sessionDurationInMinutesParsed * 60
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -403,22 +453,21 @@ func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
// @Description Generate setup access token for initial admin user configuration // @Description Generate setup access token for initial admin user configuration
// @Tags Users // @Tags Users
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /one-time-access-token/setup [post] // @Router /api/one-time-access-token/setup [post]
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) { func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.SetupInitialAdmin() user, token, err := uc.userService.SetupInitialAdmin(c.Request.Context())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value) maxAge := int(uc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
maxAge := sessionDurationInMinutesParsed * 60
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -431,23 +480,23 @@ func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Param groups body dto.UserUpdateUserGroupDto true "User group IDs" // @Param groups body dto.UserUpdateUserGroupDto true "User group IDs"
// @Success 200 {object} dto.UserDto // @Success 200 {object} dto.UserDto
// @Router /users/{id}/user-groups [put] // @Router /api/users/{id}/user-groups [put]
func (uc *UserController) updateUserGroups(c *gin.Context) { func (uc *UserController) updateUserGroups(c *gin.Context) {
var input dto.UserUpdateUserGroupDto var input dto.UserUpdateUserGroupDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds) user, err := uc.userService.UpdateUserGroups(c.Request.Context(), c.Param("id"), input.UserGroupIds)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -458,7 +507,7 @@ func (uc *UserController) updateUserGroups(c *gin.Context) {
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) { func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var input dto.UserCreateDto var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -469,15 +518,15 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
userID = c.Param("id") userID = c.Param("id")
} }
user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false) user, err := uc.userService.UpdateUser(c.Request.Context(), userID, input, updateOwnUser, false)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -491,12 +540,12 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
// @Produce json // @Produce json
// @Param id path string true "User ID" // @Param id path string true "User ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /users/{id}/profile-picture [delete] // @Router /api/users/{id}/profile-picture [delete]
func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id") userID := c.Param("id")
if err := uc.userService.ResetProfilePicture(userID); err != nil { if err := uc.userService.ResetProfilePicture(userID); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -509,12 +558,12 @@ func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
// @Tags Users // @Tags Users
// @Produce json // @Produce json
// @Success 204 "No Content" // @Success 204 "No Content"
// @Router /users/me/profile-picture [delete] // @Router /api/users/me/profile-picture [delete]
func (uc *UserController) resetCurrentUserProfilePictureHandler(c *gin.Context) { func (uc *UserController) resetCurrentUserProfilePictureHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
if err := uc.userService.ResetProfilePicture(userID); err != nil { if err := uc.userService.ResetProfilePicture(userID); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -45,18 +45,20 @@ type UserGroupController struct {
// @Param sort_column query string false "Column to sort by" default("name") // @Param sort_column query string false "Column to sort by" default("name")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc") // @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount] // @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Router /user-groups [get] // @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) { func (ugc *UserGroupController) list(c *gin.Context) {
ctx := c.Request.Context()
searchTerm := c.Query("search") searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil { if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest) groups, pagination, err := ugc.UserGroupService.List(ctx, searchTerm, sortedPaginationRequest)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -65,12 +67,12 @@ func (ugc *UserGroupController) list(c *gin.Context) {
for i, group := range groups { for i, group := range groups {
var groupDto dto.UserGroupDtoWithUserCount var groupDto dto.UserGroupDtoWithUserCount
if err := dto.MapStruct(group, &groupDto); err != nil { if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID) groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(ctx, group.ID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
groupsDto[i] = groupDto groupsDto[i] = groupDto
@@ -91,17 +93,17 @@ func (ugc *UserGroupController) list(c *gin.Context) {
// @Param id path string true "User Group ID" // @Param id path string true "User Group ID"
// @Success 200 {object} dto.UserGroupDtoWithUsers // @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth // @Security BearerAuth
// @Router /user-groups/{id} [get] // @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) { func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id")) group, err := ugc.UserGroupService.Get(c.Request.Context(), c.Param("id"))
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var groupDto dto.UserGroupDtoWithUsers var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil { if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -117,23 +119,23 @@ func (ugc *UserGroupController) get(c *gin.Context) {
// @Param userGroup body dto.UserGroupCreateDto true "User group information" // @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group" // @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
// @Security BearerAuth // @Security BearerAuth
// @Router /user-groups [post] // @Router /api/user-groups [post]
func (ugc *UserGroupController) create(c *gin.Context) { func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
group, err := ugc.UserGroupService.Create(input) group, err := ugc.UserGroupService.Create(c.Request.Context(), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var groupDto dto.UserGroupDtoWithUsers var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil { if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -150,23 +152,23 @@ func (ugc *UserGroupController) create(c *gin.Context) {
// @Param userGroup body dto.UserGroupCreateDto true "User group information" // @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group" // @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
// @Security BearerAuth // @Security BearerAuth
// @Router /user-groups/{id} [put] // @Router /api/user-groups/{id} [put]
func (ugc *UserGroupController) update(c *gin.Context) { func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
group, err := ugc.UserGroupService.Update(c.Param("id"), input, false) group, err := ugc.UserGroupService.Update(c.Request.Context(), c.Param("id"), input)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var groupDto dto.UserGroupDtoWithUsers var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil { if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -182,10 +184,10 @@ func (ugc *UserGroupController) update(c *gin.Context) {
// @Param id path string true "User Group ID" // @Param id path string true "User Group ID"
// @Success 204 "No Content" // @Success 204 "No Content"
// @Security BearerAuth // @Security BearerAuth
// @Router /user-groups/{id} [delete] // @Router /api/user-groups/{id} [delete]
func (ugc *UserGroupController) delete(c *gin.Context) { func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil { if err := ugc.UserGroupService.Delete(c.Request.Context(), c.Param("id")); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -202,23 +204,23 @@ func (ugc *UserGroupController) delete(c *gin.Context) {
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group" // @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
// @Success 200 {object} dto.UserGroupDtoWithUsers // @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth // @Security BearerAuth
// @Router /user-groups/{id}/users [put] // @Router /api/user-groups/{id}/users [put]
func (ugc *UserGroupController) updateUsers(c *gin.Context) { func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto var input dto.UserGroupUpdateUsersDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs) group, err := ugc.UserGroupService.UpdateUsers(c.Request.Context(), c.Param("id"), input.UserIDs)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var groupDto dto.UserGroupDtoWithUsers var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil { if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -2,7 +2,6 @@ package controller
import ( import (
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
@@ -38,9 +37,9 @@ type WebauthnController struct {
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID) options, err := wc.webAuthnService.BeginRegistration(c.Request.Context(), userID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -51,20 +50,20 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) { func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie(cookie.SessionIdCookieName) sessionID, err := c.Cookie(cookie.SessionIdCookieName)
if err != nil { if err != nil {
c.Error(&common.MissingSessionIdError{}) _ = c.Error(&common.MissingSessionIdError{})
return return
} }
userID := c.GetString("userID") userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request) credential, err := wc.webAuthnService.VerifyRegistration(c.Request.Context(), sessionID, userID, c.Request)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var credentialDto dto.WebauthnCredentialDto var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil { if err := dto.MapStruct(credential, &credentialDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -72,9 +71,9 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
} }
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) { func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin() options, err := wc.webAuthnService.BeginLogin(c.Request.Context())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -85,30 +84,29 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) { func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie(cookie.SessionIdCookieName) sessionID, err := c.Cookie(cookie.SessionIdCookieName)
if err != nil { if err != nil {
c.Error(&common.MissingSessionIdError{}) _ = c.Error(&common.MissingSessionIdError{})
return return
} }
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body) credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent()) user, token, err := wc.webAuthnService.VerifyLogin(c.Request.Context(), sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var userDto dto.UserDto var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil { if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
sessionDurationInMinutesParsed, _ := strconv.Atoi(wc.appConfigService.DbConfig.SessionDuration.Value) maxAge := int(wc.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes().Seconds())
maxAge := sessionDurationInMinutesParsed * 60
cookie.AddAccessTokenCookie(c, maxAge, token) cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto) c.JSON(http.StatusOK, userDto)
@@ -116,15 +114,15 @@ func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) { func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID) credentials, err := wc.webAuthnService.ListCredentials(c.Request.Context(), userID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var credentialDtos []dto.WebauthnCredentialDto var credentialDtos []dto.WebauthnCredentialDto
if err := dto.MapStructList(credentials, &credentialDtos); err != nil { if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -135,9 +133,9 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
userID := c.GetString("userID") userID := c.GetString("userID")
credentialID := c.Param("id") credentialID := c.Param("id")
err := wc.webAuthnService.DeleteCredential(userID, credentialID) err := wc.webAuthnService.DeleteCredential(c.Request.Context(), userID, credentialID)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -150,19 +148,19 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
var input dto.WebauthnCredentialUpdateDto var input dto.WebauthnCredentialUpdateDto
if err := c.ShouldBindJSON(&input); err != nil { if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name) credential, err := wc.webAuthnService.UpdateCredential(c.Request.Context(), userID, credentialID, input.Name)
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
var credentialDto dto.WebauthnCredentialDto var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil { if err := dto.MapStruct(credential, &credentialDto); err != nil {
c.Error(err) _ = c.Error(err)
return return
} }

View File

@@ -1,9 +1,13 @@
package controller package controller
import ( import (
"encoding/json"
"fmt"
"log"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service" "github.com/pocket-id/pocket-id/backend/internal/service"
) )
@@ -14,12 +18,21 @@ import (
// @Tags Well Known // @Tags Well Known
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) { func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
wkc := &WellKnownController{jwtService: jwtService} wkc := &WellKnownController{jwtService: jwtService}
// Pre-compute the OIDC configuration document, which is static
var err error
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
if err != nil {
log.Fatalf("Failed to pre-compute OpenID Connect configuration document: %v", err)
}
group.GET("/.well-known/jwks.json", wkc.jwksHandler) group.GET("/.well-known/jwks.json", wkc.jwksHandler)
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler) group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
} }
type WellKnownController struct { type WellKnownController struct {
jwtService *service.JwtService jwtService *service.JwtService
oidcConfig []byte
} }
// jwksHandler godoc // jwksHandler godoc
@@ -32,7 +45,7 @@ type WellKnownController struct {
func (wkc *WellKnownController) jwksHandler(c *gin.Context) { func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwks, err := wkc.jwtService.GetPublicJWKSAsJSON() jwks, err := wkc.jwtService.GetPublicJWKSAsJSON()
if err != nil { if err != nil {
c.Error(err) _ = c.Error(err)
return return
} }
@@ -46,19 +59,29 @@ func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
// @Success 200 {object} object "OpenID Connect configuration" // @Success 200 {object} object "OpenID Connect configuration"
// @Router /.well-known/openid-configuration [get] // @Router /.well-known/openid-configuration [get]
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) { func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", wkc.oidcConfig)
}
func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
appUrl := common.EnvConfig.AppURL appUrl := common.EnvConfig.AppURL
config := map[string]interface{}{ alg, err := wkc.jwtService.GetKeyAlg()
if err != nil {
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
}
config := map[string]any{
"issuer": appUrl, "issuer": appUrl,
"authorization_endpoint": appUrl + "/authorize", "authorization_endpoint": appUrl + "/authorize",
"token_endpoint": appUrl + "/api/oidc/token", "token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo", "userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session", "end_session_endpoint": appUrl + "/api/oidc/end-session",
"introspection_endpoint": appUrl + "/api/oidc/introspect",
"jwks_uri": appUrl + "/.well-known/jwks.json", "jwks_uri": appUrl + "/.well-known/jwks.json",
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"scopes_supported": []string{"openid", "profile", "email", "groups"}, "scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"}, "claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"}, "response_types_supported": []string{"code", "id_token"},
"subject_types_supported": []string{"public"}, "subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"RS256"}, "id_token_signing_alg_values_supported": []string{alg.String()},
} }
c.JSON(http.StatusOK, config) return json.Marshal(config)
} }

View File

@@ -11,12 +11,13 @@ type ApiKeyCreateDto struct {
} }
type ApiKeyDto struct { type ApiKeyDto struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
ExpiresAt datatype.DateTime `json:"expiresAt"` ExpiresAt datatype.DateTime `json:"expiresAt"`
LastUsedAt *datatype.DateTime `json:"lastUsedAt"` LastUsedAt *datatype.DateTime `json:"lastUsedAt"`
CreatedAt datatype.DateTime `json:"createdAt"` CreatedAt datatype.DateTime `json:"createdAt"`
ExpirationEmailSent bool `json:"expirationEmailSent"`
} }
type ApiKeyResponseDto struct { type ApiKeyResponseDto struct {

View File

@@ -12,35 +12,39 @@ type AppConfigVariableDto struct {
} }
type AppConfigUpdateDto struct { type AppConfigUpdateDto struct {
AppName string `json:"appName" binding:"required,min=1,max=30"` AppName string `json:"appName" binding:"required,min=1,max=30"`
SessionDuration string `json:"sessionDuration" binding:"required"` SessionDuration string `json:"sessionDuration" binding:"required"`
EmailsVerified string `json:"emailsVerified" binding:"required"` EmailsVerified string `json:"emailsVerified" binding:"required"`
AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"` DisableAnimations string `json:"disableAnimations" binding:"required"`
SmtHost string `json:"smtpHost"` AllowOwnAccountEdit string `json:"allowOwnAccountEdit" binding:"required"`
SmtpPort string `json:"smtpPort"` SmtpHost string `json:"smtpHost"`
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"` SmtpPort string `json:"smtpPort"`
SmtpUser string `json:"smtpUser"` SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpPassword string `json:"smtpPassword"` SmtpUser string `json:"smtpUser"`
SmtpTls string `json:"smtpTls" binding:"required,oneof=none starttls tls"` SmtpPassword string `json:"smtpPassword"`
SmtpSkipCertVerify string `json:"smtpSkipCertVerify"` SmtpTls string `json:"smtpTls" binding:"required,oneof=none starttls tls"`
LdapEnabled string `json:"ldapEnabled" binding:"required"` SmtpSkipCertVerify string `json:"smtpSkipCertVerify"`
LdapUrl string `json:"ldapUrl"` LdapEnabled string `json:"ldapEnabled" binding:"required"`
LdapBindDn string `json:"ldapBindDn"` LdapUrl string `json:"ldapUrl"`
LdapBindPassword string `json:"ldapBindPassword"` LdapBindDn string `json:"ldapBindDn"`
LdapBase string `json:"ldapBase"` LdapBindPassword string `json:"ldapBindPassword"`
LdapUserSearchFilter string `json:"ldapUserSearchFilter"` LdapBase string `json:"ldapBase"`
LdapUserGroupSearchFilter string `json:"ldapUserGroupSearchFilter"` LdapUserSearchFilter string `json:"ldapUserSearchFilter"`
LdapSkipCertVerify string `json:"ldapSkipCertVerify"` LdapUserGroupSearchFilter string `json:"ldapUserGroupSearchFilter"`
LdapAttributeUserUniqueIdentifier string `json:"ldapAttributeUserUniqueIdentifier"` LdapSkipCertVerify string `json:"ldapSkipCertVerify"`
LdapAttributeUserUsername string `json:"ldapAttributeUserUsername"` LdapAttributeUserUniqueIdentifier string `json:"ldapAttributeUserUniqueIdentifier"`
LdapAttributeUserEmail string `json:"ldapAttributeUserEmail"` LdapAttributeUserUsername string `json:"ldapAttributeUserUsername"`
LdapAttributeUserFirstName string `json:"ldapAttributeUserFirstName"` LdapAttributeUserEmail string `json:"ldapAttributeUserEmail"`
LdapAttributeUserLastName string `json:"ldapAttributeUserLastName"` LdapAttributeUserFirstName string `json:"ldapAttributeUserFirstName"`
LdapAttributeUserProfilePicture string `json:"ldapAttributeUserProfilePicture"` LdapAttributeUserLastName string `json:"ldapAttributeUserLastName"`
LdapAttributeGroupMember string `json:"ldapAttributeGroupMember"` LdapAttributeUserProfilePicture string `json:"ldapAttributeUserProfilePicture"`
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"` LdapAttributeGroupMember string `json:"ldapAttributeGroupMember"`
LdapAttributeGroupName string `json:"ldapAttributeGroupName"` LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"` LdapAttributeGroupName string `json:"ldapAttributeGroupName"`
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"` LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"`
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"` LdapSoftDeleteUsers string `json:"ldapSoftDeleteUsers"`
EmailOneTimeAccessAsAdminEnabled string `json:"emailOneTimeAccessAsAdminEnabled" binding:"required"`
EmailOneTimeAccessAsUnauthenticatedEnabled string `json:"emailOneTimeAccessAsUnauthenticatedEnabled" binding:"required"`
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"`
EmailApiKeyExpirationEnabled string `json:"emailApiKeyExpirationEnabled" binding:"required"`
} }

View File

@@ -15,5 +15,12 @@ type AuditLogDto struct {
City string `json:"city"` City string `json:"city"`
Device string `json:"device"` Device string `json:"device"`
UserID string `json:"userID"` UserID string `json:"userID"`
Username string `json:"username"`
Data model.AuditLogData `json:"data"` Data model.AuditLogData `json:"data"`
} }
type AuditLogFilterDto struct {
UserID string `form:"filters[userId]"`
Event string `form:"filters[event]"`
ClientName string `form:"filters[clientName]"`
}

View File

@@ -40,13 +40,11 @@ func MapStruct[S any, D any](source S, destination *D) error {
} }
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error { func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Loop through the fields of the destination struct
for i := 0; i < destVal.NumField(); i++ { for i := 0; i < destVal.NumField(); i++ {
destField := destVal.Field(i) destField := destVal.Field(i)
destFieldType := destVal.Type().Field(i) destFieldType := destVal.Type().Field(i)
if destFieldType.Anonymous { if destFieldType.Anonymous {
// Recursively handle embedded structs
if err := mapStructInternal(sourceVal, destField); err != nil { if err := mapStructInternal(sourceVal, destField); err != nil {
return err return err
} }
@@ -55,63 +53,57 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
sourceField := sourceVal.FieldByName(destFieldType.Name) sourceField := sourceVal.FieldByName(destFieldType.Name)
// If the source field is valid and can be assigned to the destination field
if sourceField.IsValid() && destField.CanSet() { if sourceField.IsValid() && destField.CanSet() {
// Handle direct assignment for simple types if err := mapField(sourceField, destField); err != nil {
if sourceField.Type() == destField.Type() { return err
destField.Set(sourceField)
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
// Direct assignment for slices of primitive types or non-struct elements
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
// Recursively map slices of structs
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
// Get the element from both source and destination slice
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
// Recursively map the struct elements
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
// Set the mapped element in the new slice
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
// Recursively map nested structs
if err := mapStructInternal(sourceField, destField); err != nil {
return err
}
} else {
// Type switch for specific type conversions
switch sourceField.Interface().(type) {
case datatype.DateTime:
// Convert datatype.DateTime to time.Time
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
dateValue := sourceField.Interface().(datatype.DateTime)
destField.Set(reflect.ValueOf(dateValue.ToTime()))
}
}
} }
} }
} }
return nil
}
func mapField(sourceField reflect.Value, destField reflect.Value) error {
switch {
case sourceField.Type() == destField.Type():
destField.Set(sourceField)
case sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice:
return mapSlice(sourceField, destField)
case sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct:
return mapStructInternal(sourceField, destField)
default:
return mapSpecialTypes(sourceField, destField)
}
return nil
}
func mapSlice(sourceField reflect.Value, destField reflect.Value) error {
if sourceField.Type().Elem() == destField.Type().Elem() {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
return nil
}
func mapSpecialTypes(sourceField reflect.Value, destField reflect.Value) error {
if _, ok := sourceField.Interface().(datatype.DateTime); ok {
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
dateValue := sourceField.Interface().(datatype.DateTime)
destField.Set(reflect.ValueOf(dateValue.ToTime()))
}
}
return nil return nil
} }

View File

@@ -48,10 +48,15 @@ type AuthorizationRequiredDto struct {
type OidcCreateTokensDto struct { type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"` GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"` Code string `form:"code"`
ClientID string `form:"client_id"` ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"` ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"` CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
}
type OidcIntrospectDto struct {
Token string `form:"token" binding:"required"`
} }
type OidcUpdateAllowedUserGroupsDto struct { type OidcUpdateAllowedUserGroupsDto struct {
@@ -64,3 +69,24 @@ type OidcLogoutDto struct {
PostLogoutRedirectUri string `form:"post_logout_redirect_uri"` PostLogoutRedirectUri string `form:"post_logout_redirect_uri"`
State string `form:"state"` State string `form:"state"`
} }
type OidcTokenResponseDto struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IdToken string `json:"id_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
}
type OidcIntrospectionResponseDto struct {
Active bool `json:"active"`
TokenType string `json:"token_type,omitempty"`
Scope string `json:"scope,omitempty"`
Expiration int64 `json:"exp,omitempty"`
IssuedAt int64 `json:"iat,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Subject string `json:"sub,omitempty"`
Audience []string `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
Identifier string `json:"jti,omitempty"`
}

View File

@@ -9,18 +9,22 @@ type UserDto struct {
FirstName string `json:"firstName"` FirstName string `json:"firstName"`
LastName string `json:"lastName"` LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"` IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
CustomClaims []CustomClaimDto `json:"customClaims"` CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupDto `json:"userGroups"` UserGroups []UserGroupDto `json:"userGroups"`
LdapID *string `json:"ldapId"` LdapID *string `json:"ldapId"`
Disabled bool `json:"disabled"`
} }
type UserCreateDto struct { type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"` Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"` FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"required,min=1,max=50"` LastName string `json:"lastName" binding:"required,min=1,max=50"`
IsAdmin bool `json:"isAdmin"` IsAdmin bool `json:"isAdmin"`
LdapID string `json:"-"` Locale *string `json:"locale"`
Disabled bool `json:"disabled"`
LdapID string `json:"-"`
} }
type OneTimeAccessTokenCreateDto struct { type OneTimeAccessTokenCreateDto struct {
@@ -28,11 +32,15 @@ type OneTimeAccessTokenCreateDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"` ExpiresAt time.Time `json:"expiresAt" binding:"required"`
} }
type OneTimeAccessEmailDto struct { type OneTimeAccessEmailAsUnauthenticatedUserDto struct {
Email string `json:"email" binding:"required,email"` Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"` RedirectPath string `json:"redirectPath"`
} }
type OneTimeAccessEmailAsAdminDto struct {
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
type UserUpdateUserGroupDto struct { type UserUpdateUserGroupDto struct {
UserGroupIds []string `json:"userGroupIds" binding:"required"` UserGroupIds []string `json:"userGroupIds" binding:"required"`
} }

View File

@@ -1,10 +1,11 @@
package dto package dto
import ( import (
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"log" "log"
"regexp" "regexp"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
) )
var validateUsername validator.Func = func(fl validator.FieldLevel) bool { var validateUsername validator.Func = func(fl validator.FieldLevel) bool {

View File

@@ -0,0 +1,53 @@
package job
import (
"context"
"log"
"github.com/go-co-op/gocron/v2"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
type ApiKeyEmailJobs struct {
apiKeyService *service.ApiKeyService
appConfigService *service.AppConfigService
}
func RegisterApiKeyExpiryJob(ctx context.Context, apiKeyService *service.ApiKeyService, appConfigService *service.AppConfigService) {
jobs := &ApiKeyEmailJobs{
apiKeyService: apiKeyService,
appConfigService: appConfigService,
}
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %v", err)
}
registerJob(ctx, scheduler, "ExpiredApiKeyEmailJob", "0 0 * * *", jobs.checkAndNotifyExpiringApiKeys)
scheduler.Start()
}
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
// Skip if the feature is disabled
if !j.appConfigService.GetDbConfig().EmailApiKeyExpirationEnabled.IsTrue() {
return nil
}
apiKeys, err := j.apiKeyService.ListExpiringApiKeys(ctx, 7)
if err != nil {
log.Printf("Failed to list expiring API keys: %v", err)
return err
}
for _, key := range apiKeys {
if key.User.Email == "" {
continue
}
if err := j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key); err != nil {
log.Printf("Failed to send email for key %s: %v", key.ID, err)
}
}
return nil
}

View File

@@ -1,69 +0,0 @@
package job
import (
"log"
"time"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
)
func RegisterDbCleanupJobs(db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
jobs := &Jobs{db: db}
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
scheduler.Start()
}
type Jobs struct {
db *gorm.DB
}
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
func (j *Jobs) clearWebauthnSessions() error {
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
func (j *Jobs) clearOneTimeAccessTokens() error {
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *Jobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *Jobs) clearAuditLogs() error {
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error
}
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
_, err := scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
gocron.WithEventListeners(
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
log.Printf("Job %q run successfully", name)
}),
gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
log.Printf("Job %q failed with error: %v", name, err)
}),
),
)
if err != nil {
log.Fatalf("Failed to register job %q: %v", name, err)
}
}

View File

@@ -0,0 +1,73 @@
package job
import (
"context"
"log"
"time"
"github.com/go-co-op/gocron/v2"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
func RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
jobs := &DbCleanupJobs{db: db}
registerJob(ctx, scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(ctx, scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(ctx, scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(ctx, scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(ctx, scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
scheduler.Start()
}
type DbCleanupJobs struct {
db *gorm.DB
}
// ClearWebauthnSessions deletes WebAuthn sessions that have expired
func (j *DbCleanupJobs) clearWebauthnSessions(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.WebauthnSession{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOneTimeAccessTokens deletes one-time access tokens that have expired
func (j *DbCleanupJobs) clearOneTimeAccessTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OneTimeAccessToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcAuthorizationCodes(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *DbCleanupJobs) clearOidcRefreshTokens(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).
Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *DbCleanupJobs) clearAuditLogs(ctx context.Context) error {
return j.db.
WithContext(ctx).
Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).
Error
}

View File

@@ -0,0 +1,84 @@
package job
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/go-co-op/gocron/v2"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
)
func RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB) {
scheduler, err := gocron.NewScheduler()
if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err)
}
jobs := &FileCleanupJobs{db: db}
registerJob(ctx, scheduler, "ClearUnusedDefaultProfilePictures", "0 2 * * 0", jobs.clearUnusedDefaultProfilePictures)
scheduler.Start()
}
type FileCleanupJobs struct {
db *gorm.DB
}
// ClearUnusedDefaultProfilePictures deletes default profile pictures that don't match any user's initials
func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context) error {
var users []model.User
err := j.db.
WithContext(ctx).
Find(&users).
Error
if err != nil {
return fmt.Errorf("failed to fetch users: %w", err)
}
// Create a map to track which initials are in use
initialsInUse := make(map[string]struct{})
for _, user := range users {
initialsInUse[user.Initials()] = struct{}{}
}
defaultPicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults"
if _, err := os.Stat(defaultPicturesDir); os.IsNotExist(err) {
return nil
}
files, err := os.ReadDir(defaultPicturesDir)
if err != nil {
return fmt.Errorf("failed to read default profile pictures directory: %w", err)
}
filesDeleted := 0
for _, file := range files {
if file.IsDir() {
continue // Skip directories
}
filename := file.Name()
initials := strings.TrimSuffix(filename, ".png")
// If these initials aren't used by any user, delete the file
if _, ok := initialsInUse[initials]; !ok {
filePath := filepath.Join(defaultPicturesDir, filename)
if err := os.Remove(filePath); err != nil {
log.Printf("Failed to delete unused default profile picture %s: %v", filePath, err)
} else {
filesDeleted++
}
}
}
log.Printf("Deleted %d unused default profile pictures", filesDeleted)
return nil
}

View File

@@ -0,0 +1,29 @@
package job
import (
"context"
"log"
"github.com/go-co-op/gocron/v2"
"github.com/google/uuid"
)
func registerJob(ctx context.Context, scheduler gocron.Scheduler, name string, interval string, job func(ctx context.Context) error) {
_, err := scheduler.NewJob(
gocron.CronJob(interval, false),
gocron.NewTask(job),
gocron.WithContext(ctx),
gocron.WithEventListeners(
gocron.AfterJobRuns(func(jobID uuid.UUID, jobName string) {
log.Printf("Job %q run successfully", name)
}),
gocron.AfterJobRunsWithError(func(jobID uuid.UUID, jobName string, err error) {
log.Printf("Job %q failed with error: %v", name, err)
}),
),
)
if err != nil {
log.Fatalf("Failed to register job %q: %v", name, err)
}
}

View File

@@ -1,6 +1,7 @@
package job package job
import ( import (
"context"
"log" "log"
"github.com/go-co-op/gocron/v2" "github.com/go-co-op/gocron/v2"
@@ -12,28 +13,30 @@ type LdapJobs struct {
appConfigService *service.AppConfigService appConfigService *service.AppConfigService
} }
func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *service.AppConfigService) { func RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) {
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService} jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
scheduler, err := gocron.NewScheduler() scheduler, err := gocron.NewScheduler()
if err != nil { if err != nil {
log.Fatalf("Failed to create a new scheduler: %s", err) log.Fatalf("Failed to create a new scheduler: %v", err)
} }
// Register the job to run every hour // Register the job to run every hour
registerJob(scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap) registerJob(ctx, scheduler, "SyncLdap", "0 * * * *", jobs.syncLdap)
// Run the job immediately on startup // Run the job immediately on startup
if err := jobs.syncLdap(); err != nil { err = jobs.syncLdap(ctx)
log.Printf("Failed to sync LDAP: %s", err) if err != nil {
log.Printf("Failed to sync LDAP: %v", err)
} }
scheduler.Start() scheduler.Start()
} }
func (j *LdapJobs) syncLdap() error { func (j *LdapJobs) syncLdap(ctx context.Context) error {
if j.appConfigService.DbConfig.LdapEnabled.Value == "true" { if !j.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return j.ldapService.SyncAll() return nil
} }
return nil
return j.ldapService.SyncAll(ctx)
} }

View File

@@ -23,7 +23,7 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
userID, isAdmin, err := m.Verify(c, adminRequired) userID, isAdmin, err := m.Verify(c, adminRequired)
if err != nil { if err != nil {
c.Abort() c.Abort()
c.Error(err) _ = c.Error(err)
return return
} }
@@ -36,12 +36,15 @@ func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) { func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
apiKey := c.GetHeader("X-API-KEY") apiKey := c.GetHeader("X-API-KEY")
user, err := m.apiKeyService.ValidateApiKey(apiKey) user, err := m.apiKeyService.ValidateApiKey(c.Request.Context(), apiKey)
if err != nil { if err != nil {
return "", false, &common.NotSignedInError{} return "", false, &common.NotSignedInError{}
} }
// Check if the user is an admin if user.Disabled {
return "", false, &common.UserDisabledError{}
}
if adminRequired && !user.IsAdmin { if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{} return "", false, &common.MissingPermissionError{}
} }

View File

@@ -19,11 +19,12 @@ type AuthOptions struct {
func NewAuthMiddleware( func NewAuthMiddleware(
apiKeyService *service.ApiKeyService, apiKeyService *service.ApiKeyService,
userService *service.UserService,
jwtService *service.JwtService, jwtService *service.JwtService,
) *AuthMiddleware { ) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService), apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService),
jwtMiddleware: NewJwtAuthMiddleware(jwtService), jwtMiddleware: NewJwtAuthMiddleware(jwtService, userService),
options: AuthOptions{ options: AuthOptions{
AdminRequired: true, AdminRequired: true,
SuccessOptional: false, SuccessOptional: false,
@@ -57,12 +58,13 @@ func (m *AuthMiddleware) WithSuccessOptional() *AuthMiddleware {
func (m *AuthMiddleware) Add() gin.HandlerFunc { func (m *AuthMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// First try JWT auth
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired) userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
if err == nil { if err == nil {
// JWT auth succeeded, continue with the request
c.Set("userID", userID) c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin) c.Set("userIsAdmin", isAdmin)
if c.IsAborted() {
return
}
c.Next() c.Next()
return return
} }
@@ -70,9 +72,11 @@ func (m *AuthMiddleware) Add() gin.HandlerFunc {
// JWT auth failed, try API key auth // JWT auth failed, try API key auth
userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired) userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired)
if err == nil { if err == nil {
// API key auth succeeded, continue with the request
c.Set("userID", userID) c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin) c.Set("userIsAdmin", isAdmin)
if c.IsAborted() {
return
}
c.Next() c.Next()
return return
} }
@@ -84,6 +88,6 @@ func (m *AuthMiddleware) Add() gin.HandlerFunc {
// Both JWT and API key auth failed // Both JWT and API key auth failed
c.Abort() c.Abort()
c.Error(err) _ = c.Error(err)
} }
} }

View File

@@ -1,6 +1,8 @@
package middleware package middleware
import ( import (
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
) )
@@ -14,16 +16,17 @@ func NewCorsMiddleware() *CorsMiddleware {
func (m *CorsMiddleware) Add() gin.HandlerFunc { func (m *CorsMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// Allow all origins for the token endpoint // Allow all origins for the token endpoint
if c.FullPath() == "/api/oidc/token" { switch c.FullPath() {
case "/api/oidc/token", "/api/oidc/introspect":
c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else { default:
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL) c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
} }
c.Writer.Header().Set("Access-Control-Allow-Headers", "*") c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" { if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(204) c.AbortWithStatus(204)
return return
} }

View File

@@ -19,7 +19,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize) c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
if err := c.Request.ParseMultipartForm(maxSize); err != nil { if err := c.Request.ParseMultipartForm(maxSize); err != nil {
err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)} err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)}
c.Error(err) _ = c.Error(err)
c.Abort() c.Abort()
return return
} }

View File

@@ -10,20 +10,20 @@ import (
) )
type JwtAuthMiddleware struct { type JwtAuthMiddleware struct {
jwtService *service.JwtService userService *service.UserService
jwtService *service.JwtService
} }
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware { func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.UserService) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService} return &JwtAuthMiddleware{jwtService: jwtService, userService: userService}
} }
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc { func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
userID, isAdmin, err := m.Verify(c, adminRequired) userID, isAdmin, err := m.Verify(c, adminRequired)
if err != nil { if err != nil {
c.Abort() c.Abort()
c.Error(err) _ = c.Error(err)
return return
} }
@@ -33,27 +33,41 @@ func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
} }
} }
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) { func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) {
// Extract the token from the cookie // Extract the token from the cookie
token, err := c.Cookie(cookie.AccessTokenCookieName) accessToken, err := c.Cookie(cookie.AccessTokenCookieName)
if err != nil { if err != nil {
// Try to extract the token from the Authorization header if it's not in the cookie // Try to extract the token from the Authorization header if it's not in the cookie
authorizationHeaderSplit := strings.Split(c.GetHeader("Authorization"), " ") var ok bool
if len(authorizationHeaderSplit) != 2 { _, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ")
if !ok || accessToken == "" {
return "", false, &common.NotSignedInError{} return "", false, &common.NotSignedInError{}
} }
token = authorizationHeaderSplit[1]
} }
claims, err := m.jwtService.VerifyAccessToken(token) token, err := m.jwtService.VerifyAccessToken(accessToken)
if err != nil { if err != nil {
return "", false, &common.NotSignedInError{} return "", false, &common.NotSignedInError{}
} }
// Check if the user is an admin subject, ok := token.Subject()
if adminRequired && !claims.IsAdmin { if !ok {
_ = c.Error(&common.TokenInvalidError{})
return
}
user, err := m.userService.GetUser(c, subject)
if err != nil {
return "", false, &common.NotSignedInError{}
}
if user.Disabled {
return "", false, &common.UserDisabledError{}
}
if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{} return "", false, &common.MissingPermissionError{}
} }
return claims.Subject, claims.IsAdmin, nil return subject, isAdmin, nil
} }

View File

@@ -36,7 +36,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
limiter := getLimiter(ip, limit, burst, &mu, clients) limiter := getLimiter(ip, limit, burst, &mu, clients)
if !limiter.Allow() { if !limiter.Allow() {
c.Error(&common.TooManyRequestsError{}) _ = c.Error(&common.TooManyRequestsError{})
c.Abort() c.Abort()
return return
} }

View File

@@ -1,17 +1,16 @@
package model package model
import ( import datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type ApiKey struct { type ApiKey struct {
Base Base
Name string `sortable:"true"` Name string `sortable:"true"`
Key string Key string
Description *string Description *string
ExpiresAt datatype.DateTime `sortable:"true"` ExpiresAt datatype.DateTime `sortable:"true"`
LastUsedAt *datatype.DateTime `sortable:"true"` LastUsedAt *datatype.DateTime `sortable:"true"`
ExpirationEmailSent bool
UserID string UserID string
User User User User

View File

@@ -1,51 +1,189 @@
package model package model
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
type AppConfigVariable struct { type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"` Key string `gorm:"primaryKey;not null"`
Type string Value string
IsPublic bool }
IsInternal bool
Value string // IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc.
DefaultValue string func (a *AppConfigVariable) IsTrue() bool {
ok, _ := strconv.ParseBool(a.Value)
return ok
}
// AsDurationMinutes returns the value as a time.Duration, interpreting the string as a whole number of minutes.
func (a *AppConfigVariable) AsDurationMinutes() time.Duration {
val, err := strconv.Atoi(a.Value)
if err != nil {
return 0
}
return time.Duration(val) * time.Minute
} }
type AppConfig struct { type AppConfig struct {
// General // General
AppName AppConfigVariable AppName AppConfigVariable `key:"appName,public"` // Public
SessionDuration AppConfigVariable SessionDuration AppConfigVariable `key:"sessionDuration"`
EmailsVerified AppConfigVariable EmailsVerified AppConfigVariable `key:"emailsVerified"`
AllowOwnAccountEdit AppConfigVariable DisableAnimations AppConfigVariable `key:"disableAnimations,public"` // Public
AllowOwnAccountEdit AppConfigVariable `key:"allowOwnAccountEdit,public"` // Public
// Internal // Internal
BackgroundImageType AppConfigVariable BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
LogoLightImageType AppConfigVariable LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
LogoDarkImageType AppConfigVariable LogoDarkImageType AppConfigVariable `key:"logoDarkImageType,internal"` // Internal
// Email // Email
SmtpHost AppConfigVariable SmtpHost AppConfigVariable `key:"smtpHost"`
SmtpPort AppConfigVariable SmtpPort AppConfigVariable `key:"smtpPort"`
SmtpFrom AppConfigVariable SmtpFrom AppConfigVariable `key:"smtpFrom"`
SmtpUser AppConfigVariable SmtpUser AppConfigVariable `key:"smtpUser"`
SmtpPassword AppConfigVariable SmtpPassword AppConfigVariable `key:"smtpPassword"`
SmtpTls AppConfigVariable SmtpTls AppConfigVariable `key:"smtpTls"`
SmtpSkipCertVerify AppConfigVariable SmtpSkipCertVerify AppConfigVariable `key:"smtpSkipCertVerify"`
EmailLoginNotificationEnabled AppConfigVariable EmailLoginNotificationEnabled AppConfigVariable `key:"emailLoginNotificationEnabled"`
EmailOneTimeAccessEnabled AppConfigVariable EmailOneTimeAccessAsUnauthenticatedEnabled AppConfigVariable `key:"emailOneTimeAccessAsUnauthenticatedEnabled,public"` // Public
EmailOneTimeAccessAsAdminEnabled AppConfigVariable `key:"emailOneTimeAccessAsAdminEnabled,public"` // Public
EmailApiKeyExpirationEnabled AppConfigVariable `key:"emailApiKeyExpirationEnabled"`
// LDAP // LDAP
LdapEnabled AppConfigVariable LdapEnabled AppConfigVariable `key:"ldapEnabled,public"` // Public
LdapUrl AppConfigVariable LdapUrl AppConfigVariable `key:"ldapUrl"`
LdapBindDn AppConfigVariable LdapBindDn AppConfigVariable `key:"ldapBindDn"`
LdapBindPassword AppConfigVariable LdapBindPassword AppConfigVariable `key:"ldapBindPassword"`
LdapBase AppConfigVariable LdapBase AppConfigVariable `key:"ldapBase"`
LdapUserSearchFilter AppConfigVariable LdapUserSearchFilter AppConfigVariable `key:"ldapUserSearchFilter"`
LdapUserGroupSearchFilter AppConfigVariable LdapUserGroupSearchFilter AppConfigVariable `key:"ldapUserGroupSearchFilter"`
LdapSkipCertVerify AppConfigVariable LdapSkipCertVerify AppConfigVariable `key:"ldapSkipCertVerify"`
LdapAttributeUserUniqueIdentifier AppConfigVariable LdapAttributeUserUniqueIdentifier AppConfigVariable `key:"ldapAttributeUserUniqueIdentifier"`
LdapAttributeUserUsername AppConfigVariable LdapAttributeUserUsername AppConfigVariable `key:"ldapAttributeUserUsername"`
LdapAttributeUserEmail AppConfigVariable LdapAttributeUserEmail AppConfigVariable `key:"ldapAttributeUserEmail"`
LdapAttributeUserFirstName AppConfigVariable LdapAttributeUserFirstName AppConfigVariable `key:"ldapAttributeUserFirstName"`
LdapAttributeUserLastName AppConfigVariable LdapAttributeUserLastName AppConfigVariable `key:"ldapAttributeUserLastName"`
LdapAttributeUserProfilePicture AppConfigVariable LdapAttributeUserProfilePicture AppConfigVariable `key:"ldapAttributeUserProfilePicture"`
LdapAttributeGroupMember AppConfigVariable LdapAttributeGroupMember AppConfigVariable `key:"ldapAttributeGroupMember"`
LdapAttributeGroupUniqueIdentifier AppConfigVariable LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName AppConfigVariable LdapAttributeGroupName AppConfigVariable `key:"ldapAttributeGroupName"`
LdapAttributeAdminGroup AppConfigVariable LdapAttributeAdminGroup AppConfigVariable `key:"ldapAttributeAdminGroup"`
LdapSoftDeleteUsers AppConfigVariable `key:"ldapSoftDeleteUsers"`
}
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
// Use reflection to iterate through all fields
cfgValue := reflect.ValueOf(c).Elem()
cfgType := cfgValue.Type()
var res []AppConfigVariable
for i := range cfgType.NumField() {
field := cfgType.Field(i)
key, attrs, _ := strings.Cut(field.Tag.Get("key"), ",")
if key == "" {
continue
}
// If we're only showing public variables and this is not public, skip it
if !showAll && attrs != "public" {
continue
}
fieldValue := cfgValue.Field(i)
appConfigVariable := AppConfigVariable{
Key: key,
Value: fieldValue.FieldByName("Value").String(),
}
res = append(res, appConfigVariable)
}
return res
}
func (c *AppConfig) FieldByKey(key string) (string, error) {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches
for i := range rt.NumField() {
// Grab only the first part of the key, if there's a comma with additional properties
tagValue, _, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
valueField := rv.Field(i).FieldByName("Value")
return valueField.String(), nil
}
// If we are here, the config key was not found
return "", AppConfigKeyNotFoundError{field: key}
}
func (c *AppConfig) UpdateField(key string, value string, noInternal bool) error {
rv := reflect.ValueOf(c).Elem()
rt := rv.Type()
// Find the field in the struct whose "key" tag matches, then update that
for i := range rt.NumField() {
// Separate the key (before the comma) from any optional attributes after
tagValue, attrs, _ := strings.Cut(rt.Field(i).Tag.Get("key"), ",")
if tagValue != key {
continue
}
// If the field is internal and noInternal is true, we skip that
if noInternal && attrs == "internal" {
return AppConfigInternalForbiddenError{field: key}
}
valueField := rv.Field(i).FieldByName("Value")
if !valueField.CanSet() {
return fmt.Errorf("field Value in AppConfigVariable is not settable for config key '%s'", key)
}
// Update the value
valueField.SetString(value)
// Return once updated
return nil
}
// If we're here, we have not found the right field to update
return AppConfigKeyNotFoundError{field: key}
}
type AppConfigKeyNotFoundError struct {
field string
}
func (e AppConfigKeyNotFoundError) Error() string {
return fmt.Sprintf("cannot find config key '%s'", e.field)
}
func (e AppConfigKeyNotFoundError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigKeyNotFoundError
x := AppConfigKeyNotFoundError{}
return errors.As(target, &x)
}
type AppConfigInternalForbiddenError struct {
field string
}
func (e AppConfigInternalForbiddenError) Error() string {
return fmt.Sprintf("field '%s' is internal and can't be updated", e.field)
}
func (e AppConfigInternalForbiddenError) Is(target error) bool {
// Ignore the field property when checking if an error is of the type AppConfigInternalForbiddenError
x := AppConfigInternalForbiddenError{}
return errors.As(target, &x)
} }

View File

@@ -0,0 +1,129 @@
// We use model_test here to avoid an import cycle
package model_test
import (
"reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
)
func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
tests := []struct {
name string
value string
expected time.Duration
expectedSeconds int
}{
{
name: "valid positive integer",
value: "60",
expected: 60 * time.Minute,
expectedSeconds: 3600,
},
{
name: "valid zero integer",
value: "0",
expected: 0,
expectedSeconds: 0,
},
{
name: "negative integer",
value: "-30",
expected: -30 * time.Minute,
expectedSeconds: -1800,
},
{
name: "invalid non-integer",
value: "not-a-number",
expected: 0,
expectedSeconds: 0,
},
{
name: "empty string",
value: "",
expected: 0,
expectedSeconds: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configVar := model.AppConfigVariable{
Value: tt.value,
}
result := configVar.AsDurationMinutes()
assert.Equal(t, tt.expected, result)
assert.Equal(t, tt.expectedSeconds, int(result.Seconds()))
})
}
}
// This test ensures that the model.AppConfig and dto.AppConfigUpdateDto structs match:
// - They should have the same properties, where the "json" tag of dto.AppConfigUpdateDto should match the "key" tag in model.AppConfig
// - dto.AppConfigDto should not include "internal" fields from model.AppConfig
// This test is primarily meant to catch discrepancies between the two structs as fields are added or removed over time
func TestAppConfigStructMatchesUpdateDto(t *testing.T) {
appConfigType := reflect.TypeOf(model.AppConfig{})
updateDtoType := reflect.TypeOf(dto.AppConfigUpdateDto{})
// Process AppConfig fields
appConfigFields := make(map[string]string)
for i := 0; i < appConfigType.NumField(); i++ {
field := appConfigType.Field(i)
if field.Tag.Get("key") == "" {
// Skip internal fields
continue
}
// Extract the key name from the tag (takes the part before any comma)
keyTag := field.Tag.Get("key")
keyName, _, _ := strings.Cut(keyTag, ",")
appConfigFields[field.Name] = keyName
}
// Process AppConfigUpdateDto fields
dtoFields := make(map[string]string)
for i := 0; i < updateDtoType.NumField(); i++ {
field := updateDtoType.Field(i)
// Extract the json name from the tag (takes the part before any binding constraints)
jsonTag := field.Tag.Get("json")
jsonName, _, _ := strings.Cut(jsonTag, ",")
dtoFields[jsonName] = field.Name
}
// Verify every AppConfig field has a matching DTO field with the same name
for fieldName, keyName := range appConfigFields {
if strings.HasSuffix(fieldName, "ImageType") {
// Skip internal fields that shouldn't be in the DTO
continue
}
// Check if there's a DTO field with a matching JSON tag
_, exists := dtoFields[keyName]
assert.True(t, exists, "Field %s with key '%s' in AppConfig has no matching field in AppConfigUpdateDto", fieldName, keyName)
}
// Verify every DTO field has a matching AppConfig field
for jsonName, fieldName := range dtoFields {
// Find a matching field in AppConfig by key tag
found := false
for _, keyName := range appConfigFields {
if keyName == jsonName {
found = true
break
}
}
assert.True(t, found, "Field %s with json tag '%s' in AppConfigUpdateDto has no matching field in AppConfig", fieldName, jsonName)
}
}

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
) )
type AuditLog struct { type AuditLog struct {
@@ -14,13 +14,16 @@ type AuditLog struct {
Country string `sortable:"true"` Country string `sortable:"true"`
City string `sortable:"true"` City string `sortable:"true"`
UserAgent string `sortable:"true"` UserAgent string `sortable:"true"`
UserID string Username string `gorm:"-"`
Data AuditLogData Data AuditLogData
UserID string
User User
} }
type AuditLogData map[string]string type AuditLogData map[string]string //nolint:recvcheck
type AuditLogEvent string type AuditLogEvent string //nolint:recvcheck
const ( const (
AuditLogEventSignIn AuditLogEvent = "SIGN_IN" AuditLogEventSignIn AuditLogEvent = "SIGN_IN"
@@ -31,7 +34,7 @@ const (
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (e *AuditLogEvent) Scan(value interface{}) error { func (e *AuditLogEvent) Scan(value any) error {
*e = AuditLogEvent(value.(string)) *e = AuditLogEvent(value.(string))
return nil return nil
} }
@@ -40,11 +43,14 @@ func (e AuditLogEvent) Value() (driver.Value, error) {
return string(e), nil return string(e), nil
} }
func (d *AuditLogData) Scan(value interface{}) error { func (d *AuditLogData) Scan(value any) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, d) return json.Unmarshal(v, d)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), d)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -4,7 +4,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm" "gorm.io/gorm"
) )

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm" "gorm.io/gorm"
@@ -51,19 +51,36 @@ type OidcClient struct {
CreatedBy User CreatedBy User
} }
type OidcRefreshToken struct {
Base
Token string
ExpiresAt datatype.DateTime
Scope string
UserID string
User User
ClientID string
Client OidcClient
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) { func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field // Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != "" c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil return nil
} }
type UrlList []string type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error { func (cu *UrlList) Scan(value interface{}) error {
if v, ok := value.([]byte); ok { switch v := value.(type) {
case []byte:
return json.Unmarshal(v, cu) return json.Unmarshal(v, cu)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), cu)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -8,7 +8,7 @@ import (
) )
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres // DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
type DateTime time.Time type DateTime time.Time //nolint:recvcheck
func (date *DateTime) Scan(value interface{}) (err error) { func (date *DateTime) Scan(value interface{}) (err error) {
*date = DateTime(value.(time.Time)) *date = DateTime(value.(time.Time))

View File

@@ -1,9 +1,12 @@
package model package model
import ( import (
"strings"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
) )
type User struct { type User struct {
@@ -14,7 +17,9 @@ type User struct {
FirstName string `sortable:"true"` FirstName string `sortable:"true"`
LastName string `sortable:"true"` LastName string `sortable:"true"`
IsAdmin bool `sortable:"true"` IsAdmin bool `sortable:"true"`
Locale *string
LdapID *string LdapID *string
Disabled bool `sortable:"true"`
CustomClaims []CustomClaim CustomClaims []CustomClaim
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"` UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
@@ -62,6 +67,15 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
func (u User) FullName() string { return u.FirstName + " " + u.LastName } func (u User) FullName() string { return u.FirstName + " " + u.LastName }
func (u User) Initials() string {
first := utils.GetFirstCharacter(u.FirstName)
last := utils.GetFirstCharacter(u.LastName)
if first == "" && last == "" && len(u.Username) >= 2 {
return strings.ToUpper(u.Username[:2])
}
return strings.ToUpper(first + last)
}
type OneTimeAccessToken struct { type OneTimeAccessToken struct {
Base Base
Token string Token string

View File

@@ -3,7 +3,7 @@ package model
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"errors" "fmt"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
@@ -45,15 +45,17 @@ type PublicKeyCredentialRequestOptions struct {
Timeout time.Duration Timeout time.Duration
} }
type AuthenticatorTransportList []protocol.AuthenticatorTransport type AuthenticatorTransportList []protocol.AuthenticatorTransport //nolint:recvcheck
// Scan and Value methods for GORM to handle the custom type // Scan and Value methods for GORM to handle the custom type
func (atl *AuthenticatorTransportList) Scan(value interface{}) error { func (atl *AuthenticatorTransportList) Scan(value interface{}) error {
switch v := value.(type) {
if v, ok := value.([]byte); ok { case []byte:
return json.Unmarshal(v, atl) return json.Unmarshal(v, atl)
} else { case string:
return errors.New("type assertion to []byte failed") return json.Unmarshal([]byte(v), atl)
default:
return fmt.Errorf("unsupported type: %T", value)
} }
} }

View File

@@ -1,28 +1,35 @@
package service package service
import ( import (
"context"
"errors" "errors"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"log"
"time" "time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type ApiKeyService struct { type ApiKeyService struct {
db *gorm.DB db *gorm.DB
emailService *EmailService
} }
func NewApiKeyService(db *gorm.DB) *ApiKeyService { func NewApiKeyService(db *gorm.DB, emailService *EmailService) *ApiKeyService {
return &ApiKeyService{db: db} return &ApiKeyService{db: db, emailService: emailService}
} }
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) { func (s *ApiKeyService) ListApiKeys(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{}) query := s.db.
WithContext(ctx).
Where("user_id = ?", userID).
Model(&model.ApiKey{})
var apiKeys []model.ApiKey var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
@@ -33,7 +40,7 @@ func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils
return apiKeys, pagination, nil return apiKeys, pagination, nil
} }
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) { func (s *ApiKeyService) CreateApiKey(ctx context.Context, userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future // Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) { if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{} return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
@@ -53,7 +60,11 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
UserID: userID, UserID: userID,
} }
if err := s.db.Create(&apiKey).Error; err != nil { err = s.db.
WithContext(ctx).
Create(&apiKey).
Error
if err != nil {
return model.ApiKey{}, "", err return model.ApiKey{}, "", err
} }
@@ -61,29 +72,44 @@ func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (
return apiKey, token, nil return apiKey, token, nil
} }
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error { func (s *ApiKeyService) RevokeApiKey(ctx context.Context, userID, apiKeyID string) error {
var apiKey model.ApiKey var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil { err := s.db.
WithContext(ctx).
Where("id = ? AND user_id = ?", apiKeyID, userID).
Delete(&apiKey).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{} return &common.APIKeyNotFoundError{}
} }
return err return err
} }
return s.db.Delete(&apiKey).Error return nil
} }
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) { func (s *ApiKeyService) ValidateApiKey(ctx context.Context, apiKey string) (model.User, error) {
if apiKey == "" { if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{} return model.User{}, &common.NoAPIKeyProvidedError{}
} }
var key model.ApiKey now := time.Now()
hashedKey := utils.CreateSha256Hash(apiKey) hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?", var key model.ApiKey
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil { err := s.db.
WithContext(ctx).
Model(&model.ApiKey{}).
Clauses(clause.Returning{}).
Where("key = ? AND expires_at > ?", hashedKey, datatype.DateTime(now)).
Updates(&model.ApiKey{
LastUsedAt: utils.Ptr(datatype.DateTime(now)),
}).
Preload("User").
First(&key).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{} return model.User{}, &common.InvalidAPIKeyError{}
} }
@@ -91,12 +117,49 @@ func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
return model.User{}, err return model.User{}, err
} }
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil return key.User, nil
} }
func (s *ApiKeyService) ListExpiringApiKeys(ctx context.Context, daysAhead int) ([]model.ApiKey, error) {
var keys []model.ApiKey
now := time.Now()
cutoff := now.AddDate(0, 0, daysAhead)
err := s.db.
WithContext(ctx).
Preload("User").
Where("expires_at > ? AND expires_at <= ? AND expiration_email_sent = ?", datatype.DateTime(now), datatype.DateTime(cutoff), false).
Find(&keys).
Error
return keys, err
}
func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error {
user := apiKey.User
if user.ID == "" {
if err := s.db.WithContext(ctx).First(&user, "id = ?", apiKey.UserID).Error; err != nil {
return err
}
}
err := SendEmail(ctx, s.emailService, email.Address{
Name: user.FullName(),
Email: user.Email,
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
ApiKeyName: apiKey.Name,
ExpiresAt: apiKey.ExpiresAt.ToTime(),
Name: user.FirstName,
})
if err != nil {
return err
}
// Mark the API key as having had an expiration email sent
return s.db.WithContext(ctx).
Model(&model.ApiKey{}).
Where("id = ?", apiKey.ID).
Update("expiration_email_sent", true).
Error
}

View File

@@ -1,396 +1,426 @@
package service package service
import ( import (
"context"
"errors"
"fmt" "fmt"
"log" "log"
"mime/multipart" "mime/multipart"
"os" "os"
"reflect" "reflect"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type AppConfigService struct { type AppConfigService struct {
DbConfig *model.AppConfig dbConfig atomic.Pointer[model.AppConfig]
db *gorm.DB db *gorm.DB
} }
func NewAppConfigService(db *gorm.DB) *AppConfigService { func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{ service := &AppConfigService{
DbConfig: &defaultDbConfig, db: db,
db: db,
} }
if err := service.InitDbConfig(); err != nil {
err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err) log.Fatalf("Failed to initialize app config service: %v", err)
} }
return service return service
} }
var defaultDbConfig = model.AppConfig{ // GetDbConfig returns the application configuration.
// General // Important: Treat the object as read-only: do not modify its properties directly!
AppName: model.AppConfigVariable{ func (s *AppConfigService) GetDbConfig() *model.AppConfig {
Key: "appName", v := s.dbConfig.Load()
Type: "string", if v == nil {
IsPublic: true, // This indicates a development-time error
DefaultValue: "Pocket ID", panic("called GetDbConfig before DbConfig is loaded")
}, }
SessionDuration: model.AppConfigVariable{
Key: "sessionDuration", return v
Type: "number",
DefaultValue: "60",
},
EmailsVerified: model.AppConfigVariable{
Key: "emailsVerified",
Type: "bool",
DefaultValue: "false",
},
AllowOwnAccountEdit: model.AppConfigVariable{
Key: "allowOwnAccountEdit",
Type: "bool",
IsPublic: true,
DefaultValue: "true",
},
// Internal
BackgroundImageType: model.AppConfigVariable{
Key: "backgroundImageType",
Type: "string",
IsInternal: true,
DefaultValue: "jpg",
},
LogoLightImageType: model.AppConfigVariable{
Key: "logoLightImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
LogoDarkImageType: model.AppConfigVariable{
Key: "logoDarkImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
// Email
SmtpHost: model.AppConfigVariable{
Key: "smtpHost",
Type: "string",
},
SmtpPort: model.AppConfigVariable{
Key: "smtpPort",
Type: "number",
},
SmtpFrom: model.AppConfigVariable{
Key: "smtpFrom",
Type: "string",
},
SmtpUser: model.AppConfigVariable{
Key: "smtpUser",
Type: "string",
},
SmtpPassword: model.AppConfigVariable{
Key: "smtpPassword",
Type: "string",
},
SmtpTls: model.AppConfigVariable{
Key: "smtpTls",
Type: "string",
DefaultValue: "none",
},
SmtpSkipCertVerify: model.AppConfigVariable{
Key: "smtpSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
EmailLoginNotificationEnabled: model.AppConfigVariable{
Key: "emailLoginNotificationEnabled",
Type: "bool",
DefaultValue: "false",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
// LDAP
LdapEnabled: model.AppConfigVariable{
Key: "ldapEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
LdapUrl: model.AppConfigVariable{
Key: "ldapUrl",
Type: "string",
},
LdapBindDn: model.AppConfigVariable{
Key: "ldapBindDn",
Type: "string",
},
LdapBindPassword: model.AppConfigVariable{
Key: "ldapBindPassword",
Type: "string",
},
LdapBase: model.AppConfigVariable{
Key: "ldapBase",
Type: "string",
},
LdapUserSearchFilter: model.AppConfigVariable{
Key: "ldapUserSearchFilter",
Type: "string",
DefaultValue: "(objectClass=person)",
},
LdapUserGroupSearchFilter: model.AppConfigVariable{
Key: "ldapUserGroupSearchFilter",
Type: "string",
DefaultValue: "(objectClass=groupOfNames)",
},
LdapSkipCertVerify: model.AppConfigVariable{
Key: "ldapSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeUserUniqueIdentifier",
Type: "string",
},
LdapAttributeUserUsername: model.AppConfigVariable{
Key: "ldapAttributeUserUsername",
Type: "string",
},
LdapAttributeUserEmail: model.AppConfigVariable{
Key: "ldapAttributeUserEmail",
Type: "string",
},
LdapAttributeUserFirstName: model.AppConfigVariable{
Key: "ldapAttributeUserFirstName",
Type: "string",
},
LdapAttributeUserLastName: model.AppConfigVariable{
Key: "ldapAttributeUserLastName",
Type: "string",
},
LdapAttributeUserProfilePicture: model.AppConfigVariable{
Key: "ldapAttributeUserProfilePicture",
Type: "string",
},
LdapAttributeGroupMember: model.AppConfigVariable{
Key: "ldapAttributeGroupMember",
Type: "string",
DefaultValue: "member",
},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeGroupUniqueIdentifier",
Type: "string",
},
LdapAttributeGroupName: model.AppConfigVariable{
Key: "ldapAttributeGroupName",
Type: "string",
},
LdapAttributeAdminGroup: model.AppConfigVariable{
Key: "ldapAttributeAdminGroup",
Type: "string",
},
} }
func (s *AppConfigService) UpdateAppConfig(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) { func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
// Values are the default ones
return &model.AppConfig{
// General
AppName: model.AppConfigVariable{Value: "Pocket ID"},
SessionDuration: model.AppConfigVariable{Value: "60"},
EmailsVerified: model.AppConfigVariable{Value: "false"},
DisableAnimations: model.AppConfigVariable{Value: "false"},
AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
// Internal
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
LogoDarkImageType: model.AppConfigVariable{Value: "svg"},
// Email
SmtpHost: model.AppConfigVariable{},
SmtpPort: model.AppConfigVariable{},
SmtpFrom: model.AppConfigVariable{},
SmtpUser: model.AppConfigVariable{},
SmtpPassword: model.AppConfigVariable{},
SmtpTls: model.AppConfigVariable{Value: "none"},
SmtpSkipCertVerify: model.AppConfigVariable{Value: "false"},
EmailLoginNotificationEnabled: model.AppConfigVariable{Value: "false"},
EmailOneTimeAccessAsUnauthenticatedEnabled: model.AppConfigVariable{Value: "false"},
EmailOneTimeAccessAsAdminEnabled: model.AppConfigVariable{Value: "false"},
EmailApiKeyExpirationEnabled: model.AppConfigVariable{Value: "false"},
// LDAP
LdapEnabled: model.AppConfigVariable{Value: "false"},
LdapUrl: model.AppConfigVariable{},
LdapBindDn: model.AppConfigVariable{},
LdapBindPassword: model.AppConfigVariable{},
LdapBase: model.AppConfigVariable{},
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
LdapSkipCertVerify: model.AppConfigVariable{Value: "false"},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeUserUsername: model.AppConfigVariable{},
LdapAttributeUserEmail: model.AppConfigVariable{},
LdapAttributeUserFirstName: model.AppConfigVariable{},
LdapAttributeUserLastName: model.AppConfigVariable{},
LdapAttributeUserProfilePicture: model.AppConfigVariable{},
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeGroupName: model.AppConfigVariable{},
LdapAttributeAdminGroup: model.AppConfigVariable{},
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
}
}
func (s *AppConfigService) updateAppConfigStartTransaction(ctx context.Context) (tx *gorm.DB, err error) {
// We start a transaction before doing any work, to ensure that we are the only ones updating the data in the database
// This works across multiple processes too
tx = s.db.Begin()
err = tx.Error
if err != nil {
return nil, fmt.Errorf("failed to begin database transaction: %w", err)
}
// With SQLite there's nothing else we need to do, because a transaction blocks the entire database
// However, with Postgres we need to manually lock the table to prevent others from doing the same
switch s.db.Name() {
case "postgres":
// We do not use "NOWAIT" so this blocks until the database is available, or the context is canceled
// Here we use a context with a 10s timeout in case the database is blocked for longer
lockCtx, lockCancel := context.WithTimeout(ctx, 10*time.Second)
defer lockCancel()
err = tx.
WithContext(lockCtx).
Exec("LOCK TABLE app_config_variables IN ACCESS EXCLUSIVE MODE").
Error
if err != nil {
tx.Rollback()
return nil, fmt.Errorf("failed to acquire lock on app_config_variables table: %w", err)
}
default:
// Nothing to do here
}
return tx, nil
}
func (s *AppConfigService) updateAppConfigUpdateDatabase(ctx context.Context, tx *gorm.DB, dbUpdate *[]model.AppConfigVariable) error {
err := tx.
WithContext(ctx).
Clauses(clause.OnConflict{
// Perform an "upsert" if the key already exists, replacing the value
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).
Create(&dbUpdate).
Error
if err != nil {
return fmt.Errorf("failed to update config in database: %w", err)
}
return nil
}
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
if common.EnvConfig.UiConfigDisabled { if common.EnvConfig.UiConfigDisabled {
return nil, &common.UiConfigDisabledError{} return nil, &common.UiConfigDisabledError{}
} }
tx := s.db.Begin() // Start the transaction
rt := reflect.ValueOf(input).Type() tx, err := s.updateAppConfigStartTransaction(ctx)
rv := reflect.ValueOf(input) if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
var savedConfigVariables []model.AppConfigVariable // From here onwards, we know we are the only process/goroutine with exclusive access to the config
for i := 0; i < rt.NumField(); i++ { // Re-load the config from the database to be sure we have the correct data
field := rt.Field(i) cfg, err := s.loadDbConfigInternal(ctx, tx)
key := field.Tag.Get("json") if err != nil {
value := rv.FieldByName(field.Name).String() return nil, fmt.Errorf("failed to reload config from database: %w", err)
// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key {
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false"
}
}
var appConfigVariable model.AppConfigVariable
if err := tx.First(&appConfigVariable, "key = ? AND is_internal = false", key).Error; err != nil {
tx.Rollback()
return nil, err
}
appConfigVariable.Value = value
if err := tx.Save(&appConfigVariable).Error; err != nil {
tx.Rollback()
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
} }
tx.Commit() defaultCfg := s.getDefaultDbConfig()
if err := s.LoadDbConfigFromDb(); err != nil { // Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
dbUpdate := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() {
field := rt.Field(i)
value := rv.FieldByName(field.Name).String()
// Get the value of the json tag, taking only what's before the comma
key, _, _ := strings.Cut(field.Tag.Get("json"), ",")
// Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
// Ignore errors here as we know the key exists
defaultValue, _ := defaultCfg.FieldByKey(key)
err = cfg.UpdateField(key, defaultValue, true)
} else {
err = cfg.UpdateField(key, value, true)
}
// If we tried to update an internal field, ignore the error (and do not update in the DB)
if errors.Is(err, model.AppConfigInternalForbiddenError{}) {
continue
} else if err != nil {
return nil, fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil {
return nil, err return nil, err
} }
return savedConfigVariables, nil // Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error
if err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
s.dbConfig.Store(cfg)
// Return the updated config
res := cfg.ToAppConfigVariableSlice(true)
return res, nil
} }
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error { // UpdateAppConfigValues
key := fmt.Sprintf("%sImageType", imageName) func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error if common.EnvConfig.UiConfigDisabled {
return &common.UiConfigDisabledError{}
}
// Count of keysAndValues must be even
if len(keysAndValues)%2 != 0 {
return errors.New("invalid number of arguments received")
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
// (Note the += 2, as we are iterating through key-value pairs)
dbUpdate := make([]model.AppConfigVariable, 0, len(keysAndValues)/2)
for i := 0; i < len(keysAndValues); i += 2 {
key := keysAndValues[i]
value := keysAndValues[i+1]
// Ensure that the field is valid
// We do this by grabbing the default value
var defaultValue string
defaultValue, err = defaultCfg.FieldByKey(key)
if err != nil {
return fmt.Errorf("invalid configuration key '%s': %w", key, err)
}
// Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
err = cfg.UpdateField(key, defaultValue, false)
} else {
err = cfg.UpdateField(key, value, false)
}
if err != nil {
return fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil { if err != nil {
return err return err
} }
return s.LoadDbConfigFromDb() // Commit the changes to the DB, then finally save the updated config in the object
} err = tx.Commit().Error
func (s *AppConfigService) ListAppConfig(showAll bool) ([]model.AppConfigVariable, error) {
var configuration []model.AppConfigVariable
var err error
if showAll {
err = s.db.Find(&configuration).Error
} else {
err = s.db.Find(&configuration, "is_public = true").Error
}
if err != nil { if err != nil {
return nil, err return fmt.Errorf("failed to commit transaction: %w", err)
} }
for i := range configuration { s.dbConfig.Store(cfg)
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" { return nil
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
}
}
return configuration, nil
} }
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error { func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
}
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
fileType := utils.GetFileExtension(uploadedFile.Filename) fileType := utils.GetFileExtension(uploadedFile.Filename)
mimeType := utils.GetImageMimeType(fileType) mimeType := utils.GetImageMimeType(fileType)
if mimeType == "" { if mimeType == "" {
return &common.FileTypeNotSupportedError{} return &common.FileTypeNotSupportedError{}
} }
// Delete the old image if it has a different file type // Save the updated image
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath)
if err != nil {
return err
}
// Delete the old image if it has a different file type, then update the type in the database
if fileType != oldImageType { if fileType != oldImageType {
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType) oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
if err := os.Remove(oldImagePath); err != nil { err = os.Remove(oldImagePath)
if err != nil {
return err return err
} }
}
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType) // Update the file type in the database
if err := utils.SaveFile(uploadedFile, imagePath); err != nil { err = s.UpdateAppConfigValues(ctx, imageName+"ImageType", fileType)
return err if err != nil {
} return err
}
// Update the file type in the database
if err := s.UpdateImageType(imageName, fileType); err != nil {
return err
} }
return nil return nil
} }
// InitDbConfig creates the default configuration values in the database if they do not exist, // LoadDbConfig loads the configuration values from the database into the DbConfig struct.
// updates existing configurations if they differ from the default, and deletes any configurations func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
// that are not in the default configuration. var dest *model.AppConfig
func (s *AppConfigService) InitDbConfig() error {
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig // If the UI config is disabled, only load from the env
for i := 0; i < defaultConfigReflectValue.NumField(); i++ { if common.EnvConfig.UiConfigDisabled {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable) dest, err = s.loadDbConfigFromEnv()
} else {
dest, err = s.loadDbConfigInternal(ctx, s.db)
}
if err != nil {
return err
}
defaultKeys[defaultConfigVar.Key] = struct{}{} // Update the value in the object
s.dbConfig.Store(dest)
var storedConfigVar model.AppConfigVariable return nil
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil { }
// If the configuration does not exist, create it
if err := s.db.Create(&defaultConfigVar).Error; err != nil { func (s *AppConfigService) loadDbConfigFromEnv() (*model.AppConfig, error) {
return err // First, start from the default configuration
} dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the value of the key tag, taking only what's before the comma
// The env var name is the key converted to SCREAMING_SNAKE_CASE
key, _, _ := strings.Cut(field.Tag.Get("key"), ",")
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
}
}
return dest, nil
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Load all configuration values from the database
// This loads all values in a single shot
loaded := []model.AppConfigVariable{}
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
err := tx.
WithContext(queryCtx).
Find(&loaded).Error
if err != nil {
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
}
// Iterate through all values loaded from the database
for _, v := range loaded {
// If the value is empty, it means we are using the default value
if v.Value == "" {
continue continue
} }
// Update existing configuration if it differs from the default // Find the field in the struct whose "key" tag matches, then update that
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal || storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue { err = dest.UpdateField(v.Key, v.Value, false)
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic // We ignore the case of fields that don't exist, as there may be leftover data in the database
storedConfigVar.IsInternal = defaultConfigVar.IsInternal if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
if err := s.db.Save(&storedConfigVar).Error; err != nil {
return err
}
} }
} }
// Delete any configurations not in the default keys return dest, nil
var allConfigVars []model.AppConfigVariable
if err := s.db.Find(&allConfigVars).Error; err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; !exists {
if err := s.db.Delete(&config).Error; err != nil {
return err
}
}
}
return s.LoadDbConfigFromDb()
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
return err
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {
environmentVariableName := utils.CamelCaseToScreamingSnakeCase(key)
if value, exists := os.LookupEnv(environmentVariableName); exists {
return value
}
return fallbackValue
} }

View File

@@ -0,0 +1,523 @@
package service
import (
"sync/atomic"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
service := &AppConfigService{
dbConfig: atomic.Pointer[model.AppConfig]{},
}
service.dbConfig.Store(config)
return service
}
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
// Load the config
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be equal to default config
require.Equal(t, service.GetDbConfig(), service.getDefaultDbConfig())
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
Create([]model.AppConfigVariable{
// Should be set to the default value because it's an empty string
{Key: "appName", Value: ""},
// Overrides default value
{Key: "sessionDuration", Value: "5"},
// Does not have a default value
{Key: "smtpHost", Value: "example"},
}).
Error
require.NoError(t, err)
// Load the config
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Values should match expected ones
expect := service.getDefaultDbConfig()
expect.SessionDuration.Value = "5"
expect.SmtpHost.Value = "example"
require.Equal(t, service.GetDbConfig(), expect)
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
{Key: "__nonExistentKey", Value: "some value"},
{Key: "appName", Value: "TestApp"}, // This one should still be loaded
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// This should not fail, just ignore the unknown key
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
config := service.GetDbConfig()
require.Equal(t, "TestApp", config.AppName.Value)
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "InitialApp"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "InitialApp", service.GetDbConfig().AppName.Value)
// Update the database value
err = db.Model(&model.AppConfigVariable{}).
Where("key = ?", "appName").
Update("value", "UpdatedApp").Error
require.NoError(t, err)
// Load the config again, it should reflect the updated value
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "UpdatedApp", service.GetDbConfig().AppName.Value)
})
t.Run("loads config from env when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables for testing
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Enable UiConfigDisabled to load from env
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from env, not DB
config := service.GetDbConfig()
require.Equal(t, "EnvTest App", config.AppName.Value, "Should load appName from env")
require.Equal(t, "45", config.SessionDuration.Value, "Should load sessionDuration from env")
})
t.Run("ignores env vars when UiConfigDisabled is false", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables that should be ignored
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Make sure UiConfigDisabled is false to load from DB
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from DB, not env
config := service.GetDbConfig()
require.Equal(t, "DB App", config.AppName.Value, "Should load appName from DB, not env")
require.Equal(t, "120", config.SessionDuration.Value, "Should load sessionDuration from DB, not env")
})
}
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update a single config value
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App")
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
// Verify database was updated
var dbValue model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&dbValue).Error
require.NoError(t, err)
require.Equal(t, "Test App", dbValue.Value)
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update multiple config values
err = service.UpdateAppConfigValues(
t.Context(),
"appName", "Test App",
"sessionDuration", "30",
"smtpHost", "mail.example.com",
)
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
require.Equal(t, "30", config.SessionDuration.Value)
require.Equal(t, "mail.example.com", config.SmtpHost.Value)
// Verify database was updated
var count int64
db.Model(&model.AppConfigVariable{}).Count(&count)
require.Equal(t, int64(3), count)
var appName, sessionDuration, smtpHost model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Test App", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "30", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "mail.example.com", smtpHost.Value)
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First change the value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "30")
require.NoError(t, err)
require.Equal(t, "30", service.GetDbConfig().SessionDuration.Value)
// Now set it to empty which should use default value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "")
require.NoError(t, err)
require.Equal(t, "60", service.GetDbConfig().SessionDuration.Value) // Default value from getDefaultDbConfig
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with odd number of arguments
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App", "sessionDuration")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid number of arguments")
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with invalid key
err = service.UpdateAppConfigValues(t.Context(), "nonExistentKey", "some value")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid configuration key")
})
}
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Create update DTO
input := dto.AppConfigUpdateDto{
AppName: "Updated App Name",
SessionDuration: "120",
SmtpHost: "smtp.example.com",
SmtpPort: "587",
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables
require.NotEmpty(t, updatedVars)
var foundAppName, foundSessionDuration, foundSmtpHost, foundSmtpPort bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Updated App Name", v.Value)
foundAppName = true
case "sessionDuration":
require.Equal(t, "120", v.Value)
foundSessionDuration = true
case "smtpHost":
require.Equal(t, "smtp.example.com", v.Value)
foundSmtpHost = true
case "smtpPort":
require.Equal(t, "587", v.Value)
foundSmtpPort = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
require.True(t, foundSmtpHost)
require.True(t, foundSmtpPort)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Updated App Name", config.AppName.Value)
require.Equal(t, "120", config.SessionDuration.Value)
require.Equal(t, "smtp.example.com", config.SmtpHost.Value)
require.Equal(t, "587", config.SmtpPort.Value)
// Verify database was updated
var appName, sessionDuration, smtpHost, smtpPort model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Updated App Name", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "120", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "smtp.example.com", smtpHost.Value)
err = db.Where("key = ?", "smtpPort").First(&smtpPort).Error
require.NoError(t, err)
require.Equal(t, "587", smtpPort.Value)
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First set some non-default values
err = service.UpdateAppConfigValues(t.Context(),
"appName", "Custom App",
"sessionDuration", "120",
)
require.NoError(t, err)
// Create update DTO with empty values to reset to defaults
input := dto.AppConfigUpdateDto{
AppName: "", // Should reset to default "Pocket ID"
SessionDuration: "", // Should reset to default "60"
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables (they should be empty strings in DB)
var foundAppName, foundSessionDuration bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Pocket ID", v.Value) // Returns the default value
foundAppName = true
case "sessionDuration":
require.Equal(t, "60", v.Value) // Returns the default value
foundSessionDuration = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
// Verify in-memory config was reset to defaults
config := service.GetDbConfig()
require.Equal(t, "Pocket ID", config.AppName.Value) // Default value
require.Equal(t, "60", config.SessionDuration.Value) // Default value
// Verify database was updated with empty values
for _, key := range []string{"appName", "sessionDuration"} {
var loaded model.AppConfigVariable
err = db.Where("key = ?", key).First(&loaded).Error
require.NoErrorf(t, err, "Failed to load DB value for key '%s'", key)
require.Emptyf(t, loaded.Value, "Loaded value for key '%s' is not empty", key)
}
})
t.Run("cannot update when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update config
_, err = service.UpdateAppConfig(t.Context(), dto.AppConfigUpdateDto{
AppName: "Should Not Update",
})
// Should get a UiConfigDisabledError
require.Error(t, err)
var uiConfigDisabledErr *common.UiConfigDisabledError
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -1,9 +1,12 @@
package service package service
import ( import (
"context"
"fmt"
"log" "log"
userAgentParser "github.com/mileusna/useragent" userAgentParser "github.com/mileusna/useragent"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email" "github.com/pocket-id/pocket-id/backend/internal/utils/email"
@@ -22,10 +25,10 @@ func NewAuditLogService(db *gorm.DB, appConfigService *AppConfigService, emailSe
} }
// Create creates a new audit log entry in the database // Create creates a new audit log entry in the database
func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData) model.AuditLog { func (s *AuditLogService) Create(ctx context.Context, event model.AuditLogEvent, ipAddress, userAgent, userID string, data model.AuditLogData, tx *gorm.DB) model.AuditLog {
country, city, err := s.geoliteService.GetLocationByIP(ipAddress) country, city, err := s.geoliteService.GetLocationByIP(ipAddress)
if err != nil { if err != nil {
log.Printf("Failed to get IP location: %v\n", err) log.Printf("Failed to get IP location: %v", err)
} }
auditLog := model.AuditLog{ auditLog := model.AuditLog{
@@ -39,8 +42,12 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
} }
// Save the audit log in the database // Save the audit log in the database
if err := s.db.Create(&auditLog).Error; err != nil { err = tx.
log.Printf("Failed to create audit log: %v\n", err) WithContext(ctx).
Create(&auditLog).
Error
if err != nil {
log.Printf("Failed to create audit log: %v", err)
return model.AuditLog{} return model.AuditLog{}
} }
@@ -48,25 +55,42 @@ func (s *AuditLogService) Create(event model.AuditLogEvent, ipAddress, userAgent
} }
// CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before // CreateNewSignInWithEmail creates a new audit log entry in the database and sends an email if the device hasn't been used before
func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID string) model.AuditLog { func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddress, userAgent, userID string, tx *gorm.DB) model.AuditLog {
createdAuditLog := s.Create(model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}) createdAuditLog := s.Create(ctx, model.AuditLogEventSignIn, ipAddress, userAgent, userID, model.AuditLogData{}, tx)
// Count the number of times the user has logged in from the same device // Count the number of times the user has logged in from the same device
var count int64 var count int64
err := s.db.Model(&model.AuditLog{}).Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).Count(&count).Error err := tx.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ? AND ip_address = ? AND user_agent = ?", userID, ipAddress, userAgent).
Count(&count).
Error
if err != nil { if err != nil {
log.Printf("Failed to count audit logs: %v\n", err) log.Printf("Failed to count audit logs: %v\n", err)
return createdAuditLog return createdAuditLog
} }
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email // If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 { if s.appConfigService.GetDbConfig().EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() { go func() {
var user model.User innerCtx := context.Background()
s.db.Where("id = ?", userID).First(&user)
err := SendEmail(s.emailService, email.Address{ // Note we don't use the transaction here because this is running in background
Name: user.Username, var user model.User
innerErr := s.db.
WithContext(innerCtx).
Where("id = ?", userID).
First(&user).
Error
if innerErr != nil {
log.Printf("Failed to load user: %v", innerErr)
}
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
Name: user.FullName(),
Email: user.Email, Email: user.Email,
}, NewLoginTemplate, &NewLoginTemplateData{ }, NewLoginTemplate, &NewLoginTemplateData{
IPAddress: ipAddress, IPAddress: ipAddress,
@@ -75,8 +99,8 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
Device: s.DeviceStringFromUserAgent(userAgent), Device: s.DeviceStringFromUserAgent(userAgent),
DateTime: createdAuditLog.CreatedAt.UTC(), DateTime: createdAuditLog.CreatedAt.UTC(),
}) })
if err != nil { if innerErr != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err) log.Printf("Failed to send email to '%s': %v", user.Email, innerErr)
} }
}() }()
} }
@@ -85,9 +109,12 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
} }
// ListAuditLogsForUser retrieves all audit logs for a given user ID // ListAuditLogsForUser retrieves all audit logs for a given user ID
func (s *AuditLogService) ListAuditLogsForUser(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) { func (s *AuditLogService) ListAuditLogsForUser(ctx context.Context, userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog var logs []model.AuditLog
query := s.db.Model(&model.AuditLog{}).Where("user_id = ?", userID) query := s.db.
WithContext(ctx).
Model(&model.AuditLog{}).
Where("user_id = ?", userID)
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
return logs, pagination, err return logs, pagination, err
@@ -97,3 +124,99 @@ func (s *AuditLogService) DeviceStringFromUserAgent(userAgent string) string {
ua := userAgentParser.Parse(userAgent) ua := userAgentParser.Parse(userAgent)
return ua.Name + " on " + ua.OS + " " + ua.OSVersion return ua.Name + " on " + ua.OS + " " + ua.OSVersion
} }
func (s *AuditLogService) ListAllAuditLogs(ctx context.Context, sortedPaginationRequest utils.SortedPaginationRequest, filters dto.AuditLogFilterDto) ([]model.AuditLog, utils.PaginationResponse, error) {
var logs []model.AuditLog
query := s.db.
WithContext(ctx).
Preload("User").
Model(&model.AuditLog{})
if filters.UserID != "" {
query = query.Where("user_id = ?", filters.UserID)
}
if filters.Event != "" {
query = query.Where("event = ?", filters.Event)
}
if filters.ClientName != "" {
dialect := s.db.Name()
switch dialect {
case "sqlite":
query = query.Where("json_extract(data, '$.clientName') = ?", filters.ClientName)
case "postgres":
query = query.Where("data->>'clientName' = ?", filters.ClientName)
default:
return nil, utils.PaginationResponse{}, fmt.Errorf("unsupported database dialect: %s", dialect)
}
}
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &logs)
if err != nil {
return nil, pagination, err
}
return logs, pagination, nil
}
func (s *AuditLogService) ListUsernamesWithIds(ctx context.Context) (users map[string]string, err error) {
query := s.db.
WithContext(ctx).
Joins("User").
Model(&model.AuditLog{}).
Select("DISTINCT \"User\".id, \"User\".username").
Where("\"User\".username IS NOT NULL")
type Result struct {
ID string `gorm:"column:id"`
Username string `gorm:"column:username"`
}
var results []Result
if err := query.Find(&results).Error; err != nil {
return nil, fmt.Errorf("failed to query user IDs: %w", err)
}
users = make(map[string]string, len(results))
for _, result := range results {
users[result.ID] = result.Username
}
return users, nil
}
func (s *AuditLogService) ListClientNames(ctx context.Context) (clientNames []string, err error) {
dialect := s.db.Name()
query := s.db.
WithContext(ctx).
Model(&model.AuditLog{})
switch dialect {
case "sqlite":
query = query.
Select("DISTINCT json_extract(data, '$.clientName') AS client_name").
Where("json_extract(data, '$.clientName') IS NOT NULL")
case "postgres":
query = query.
Select("DISTINCT data->>'clientName' AS client_name").
Where("data->>'clientName' IS NOT NULL")
default:
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)
}
type Result struct {
ClientName string `gorm:"column:client_name"`
}
var results []Result
if err := query.Find(&results).Error; err != nil {
return nil, fmt.Errorf("failed to query client IDs: %w", err)
}
clientNames = make([]string, len(results))
for i, result := range results {
clientNames[i] = result.ClientName
}
return clientNames, nil
}

View File

@@ -1,34 +1,14 @@
package service package service
import ( import (
"context"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
) )
// Reserved claims
var reservedClaims = map[string]struct{}{
"given_name": {},
"family_name": {},
"name": {},
"email": {},
"preferred_username": {},
"groups": {},
"sub": {},
"iss": {},
"aud": {},
"exp": {},
"iat": {},
"auth_time": {},
"nonce": {},
"acr": {},
"amr": {},
"azp": {},
"nbf": {},
"jti": {},
}
type CustomClaimService struct { type CustomClaimService struct {
db *gorm.DB db *gorm.DB
} }
@@ -39,8 +19,30 @@ func NewCustomClaimService(db *gorm.DB) *CustomClaimService {
// isReservedClaim checks if a claim key is reserved e.g. email, preferred_username // isReservedClaim checks if a claim key is reserved e.g. email, preferred_username
func isReservedClaim(key string) bool { func isReservedClaim(key string) bool {
_, ok := reservedClaims[key] switch key {
return ok case "given_name",
"family_name",
"name",
"email",
"preferred_username",
"groups",
TokenTypeClaim,
"sub",
"iss",
"aud",
"exp",
"iat",
"auth_time",
"nonce",
"acr",
"amr",
"azp",
"nbf",
"jti":
return true
default:
return false
}
} }
// idType is the type of the id used to identify the user or user group // idType is the type of the id used to identify the user or user group
@@ -52,28 +54,37 @@ const (
) )
// UpdateCustomClaimsForUser updates the custom claims for a user // UpdateCustomClaimsForUser updates the custom claims for a user
func (s *CustomClaimService) UpdateCustomClaimsForUser(userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) UpdateCustomClaimsForUser(ctx context.Context, userID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserID, userID, claims) return s.updateCustomClaims(ctx, UserID, userID, claims)
} }
// UpdateCustomClaimsForUserGroup updates the custom claims for a user group // UpdateCustomClaimsForUserGroup updates the custom claims for a user group
func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) UpdateCustomClaimsForUserGroup(ctx context.Context, userGroupID string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
return s.updateCustomClaims(UserGroupID, userGroupID, claims) return s.updateCustomClaims(ctx, UserGroupID, userGroupID, claims)
} }
// updateCustomClaims updates the custom claims for a user or user group // updateCustomClaims updates the custom claims for a user or user group
func (s *CustomClaimService) updateCustomClaims(idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) { func (s *CustomClaimService) updateCustomClaims(ctx context.Context, idType idType, value string, claims []dto.CustomClaimCreateDto) ([]model.CustomClaim, error) {
// Check for duplicate keys in the claims slice // Check for duplicate keys in the claims slice
seenKeys := make(map[string]bool) seenKeys := make(map[string]struct{})
for _, claim := range claims { for _, claim := range claims {
if seenKeys[claim.Key] { if _, ok := seenKeys[claim.Key]; ok {
return nil, &common.DuplicateClaimError{Key: claim.Key} return nil, &common.DuplicateClaimError{Key: claim.Key}
} }
seenKeys[claim.Key] = true seenKeys[claim.Key] = struct{}{}
} }
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var existingClaims []model.CustomClaim var existingClaims []model.CustomClaim
err := s.db.Where(string(idType), value).Find(&existingClaims).Error err := tx.
WithContext(ctx).
Where(string(idType), value).
Find(&existingClaims).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -87,8 +98,12 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
break break
} }
} }
if !found { if !found {
err = s.db.Delete(&existingClaim).Error err = tx.
WithContext(ctx).
Delete(&existingClaim).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -105,14 +120,20 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
Value: claim.Value, Value: claim.Value,
} }
if idType == UserID { switch idType {
case UserID:
customClaim.UserID = &value customClaim.UserID = &value
} else if idType == UserGroupID { case UserGroupID:
customClaim.UserGroupID = &value customClaim.UserGroupID = &value
} }
// Update the claim if it already exists or create a new one // Update the claim if it already exists or create a new one
err = s.db.Where(string(idType)+" = ? AND key = ?", value, claim.Key).Assign(&customClaim).FirstOrCreate(&model.CustomClaim{}).Error err = tx.
WithContext(ctx).
Where(string(idType)+" = ? AND key = ?", value, claim.Key).
Assign(&customClaim).
FirstOrCreate(&model.CustomClaim{}).
Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -120,7 +141,16 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
// Get the updated claims // Get the updated claims
var updatedClaims []model.CustomClaim var updatedClaims []model.CustomClaim
err = s.db.Where(string(idType)+" = ?", value).Find(&updatedClaims).Error err = tx.
WithContext(ctx).
Where(string(idType)+" = ?", value).
Find(&updatedClaims).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -128,23 +158,31 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
return updatedClaims, nil return updatedClaims, nil
} }
func (s *CustomClaimService) GetCustomClaimsForUser(userID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUser(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim var customClaims []model.CustomClaim
err := s.db.Where("user_id = ?", userID).Find(&customClaims).Error err := tx.
WithContext(ctx).
Where("user_id = ?", userID).
Find(&customClaims).
Error
return customClaims, err return customClaims, err
} }
func (s *CustomClaimService) GetCustomClaimsForUserGroup(userGroupID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUserGroup(ctx context.Context, userGroupID string, tx *gorm.DB) ([]model.CustomClaim, error) {
var customClaims []model.CustomClaim var customClaims []model.CustomClaim
err := s.db.Where("user_group_id = ?", userGroupID).Find(&customClaims).Error err := tx.
WithContext(ctx).
Where("user_group_id = ?", userGroupID).
Find(&customClaims).
Error
return customClaims, err return customClaims, err
} }
// GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of, // GetCustomClaimsForUserWithUserGroups returns the custom claims of a user and all user groups the user is a member of,
// prioritizing the user's claims over user group claims with the same key. // prioritizing the user's claims over user group claims with the same key.
func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string) ([]model.CustomClaim, error) { func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(ctx context.Context, userID string, tx *gorm.DB) ([]model.CustomClaim, error) {
// Get the custom claims of the user // Get the custom claims of the user
customClaims, err := s.GetCustomClaimsForUser(userID) customClaims, err := s.GetCustomClaimsForUser(ctx, userID, tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -157,7 +195,9 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
// Get all user groups of the user // Get all user groups of the user
var userGroupsOfUser []model.UserGroup var userGroupsOfUser []model.UserGroup
err = s.db.Preload("CustomClaims"). err = tx.
WithContext(ctx).
Preload("CustomClaims").
Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id"). Joins("JOIN user_groups_users ON user_groups_users.user_group_id = user_groups.id").
Where("user_groups_users.user_id = ?", userID). Where("user_groups_users.user_id = ?", userID).
Find(&userGroupsOfUser).Error Find(&userGroupsOfUser).Error
@@ -185,10 +225,12 @@ func (s *CustomClaimService) GetCustomClaimsForUserWithUserGroups(userID string)
} }
// GetSuggestions returns a list of custom claim keys that have been used before // GetSuggestions returns a list of custom claim keys that have been used before
func (s *CustomClaimService) GetSuggestions() ([]string, error) { func (s *CustomClaimService) GetSuggestions(ctx context.Context) ([]string, error) {
var customClaimsKeys []string var customClaimsKeys []string
err := s.db.Model(&model.CustomClaim{}). err := s.db.
WithContext(ctx).
Model(&model.CustomClaim{}).
Group("key"). Group("key").
Order("COUNT(*) DESC"). Order("COUNT(*) DESC").
Pluck("key", &customClaimsKeys).Error Pluck("key", &customClaimsKeys).Error

View File

@@ -1,6 +1,9 @@
//go:build e2etest
package service package service
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -32,6 +35,7 @@ func NewTestService(db *gorm.DB, appConfigService *AppConfigService, jwtService
return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService} return &TestService{db: db, appConfigService: appConfigService, jwtService: jwtService}
} }
//nolint:gocognit
func (s *TestService) SeedDatabase() error { func (s *TestService) SeedDatabase() error {
return s.db.Transaction(func(tx *gorm.DB) error { return s.db.Transaction(func(tx *gorm.DB) error {
users := []model.User{ users := []model.User{
@@ -152,6 +156,17 @@ func (s *TestService) SeedDatabase() error {
return err return err
} }
refreshToken := model.OidcRefreshToken{
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
}
if err := tx.Create(&refreshToken).Error; err != nil {
return err
}
accessToken := model.OneTimeAccessToken{ accessToken := model.OneTimeAccessToken{
Token: "one-time-token", Token: "one-time-token",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
@@ -174,11 +189,8 @@ func (s *TestService) SeedDatabase() error {
// openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \ // openssl genpkey -algorithm EC -pkeyopt ec_paramgen_curve:P-256 | \
// openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout) // openssl pkcs8 -topk8 -nocrypt | tee >(openssl pkey -pubout)
publicKeyPasskey1, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==") publicKeyPasskey1, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
publicKeyPasskey2, err := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==") publicKeyPasskey2, _ := s.getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEj4qA0PrZzg8Co1C27nyUbzrp8Ewjr7eOlGI2LfrzmbL5nPhZRAdJ3hEaqrHMSnJBhfMqtQGKwDYpaLIQFAKLhw==")
if err != nil {
return err
}
webauthnCredentials := []model.WebauthnCredential{ webauthnCredentials := []model.WebauthnCredential{
{ {
Name: "Passkey 1", Name: "Passkey 1",
@@ -288,26 +300,22 @@ func (s *TestService) ResetApplicationImages() error {
return nil return nil
} }
func (s *TestService) ResetAppConfig() error { func (s *TestService) ResetAppConfig(ctx context.Context) error {
// Reseed the config variables // Reset all app config variables to their default values in the database
if err := s.appConfigService.InitDbConfig(); err != nil { err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error
return err if err != nil {
}
// Reset all app config variables to their default values
if err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error; err != nil {
return err return err
} }
// Reload the app config from the database after resetting the values // Reload the app config from the database after resetting the values
return s.appConfigService.LoadDbConfigFromDb() return s.appConfigService.LoadDbConfig(ctx)
} }
func (s *TestService) SetJWTKeys() { func (s *TestService) SetJWTKeys() {
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}` const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
privateKey, _ := jwk.ParseKey([]byte(privateKeyString)) privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey) _ = s.jwtService.SetKey(privateKey)
} }
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key // getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key

View File

@@ -2,24 +2,28 @@ package service
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
htemplate "html/template" htemplate "html/template"
"io"
"mime/multipart" "mime/multipart"
"mime/quotedprintable" "mime/quotedprintable"
"net/textproto" "net/textproto"
"os" "os"
"strings"
ttemplate "text/template" ttemplate "text/template"
"time" "time"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid" "github.com/google/uuid"
"strings" "gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
) )
type EmailService struct { type EmailService struct {
@@ -48,22 +52,28 @@ func NewEmailService(appConfigService *AppConfigService, db *gorm.DB) (*EmailSer
}, nil }, nil
} }
func (srv *EmailService) SendTestEmail(recipientUserId string) error { func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId string) error {
var user model.User var user model.User
if err := srv.db.First(&user, "id = ?", recipientUserId).Error; err != nil { err := srv.db.
WithContext(ctx).
First(&user, "id = ?", recipientUserId).
Error
if err != nil {
return err return err
} }
return SendEmail(srv, return SendEmail(ctx, srv,
email.Address{ email.Address{
Email: user.Email, Email: user.Email,
Name: user.FullName(), Name: user.FullName(),
}, TestTemplate, nil) }, TestTemplate, nil)
} }
func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error { func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
dbConfig := srv.appConfigService.GetDbConfig()
data := &email.TemplateData[V]{ data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value, AppName: dbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo", LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Data: tData, Data: tData,
} }
@@ -78,8 +88,8 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.AddHeader("Subject", template.Title(data)) c.AddHeader("Subject", template.Title(data))
c.AddAddressHeader("From", []email.Address{ c.AddAddressHeader("From", []email.Address{
{ {
Email: srv.appConfigService.DbConfig.SmtpFrom.Value, Email: dbConfig.SmtpFrom.Value,
Name: srv.appConfigService.DbConfig.AppName.Value, Name: dbConfig.AppName.Value,
}, },
}) })
c.AddAddressHeader("To", []email.Address{toEmail}) c.AddAddressHeader("To", []email.Address{toEmail})
@@ -94,10 +104,10 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
// so we use the domain of the from address instead (the same as Thunderbird does) // so we use the domain of the from address instead (the same as Thunderbird does)
// if the address does not have an @ (which would be unusual), we use hostname // if the address does not have an @ (which would be unusual), we use hostname
from_address := srv.appConfigService.DbConfig.SmtpFrom.Value fromAddress := dbConfig.SmtpFrom.Value
domain := "" domain := ""
if strings.Contains(from_address, "@") { if strings.Contains(fromAddress, "@") {
domain = strings.Split(from_address, "@")[1] domain = strings.Split(fromAddress, "@")[1]
} else { } else {
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
@@ -107,10 +117,19 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
domain = hostname domain = hostname
} }
} }
c.AddHeader("Message-ID", "<" + uuid.New().String() + "@" + domain + ">") c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
c.Body(body) c.Body(body)
// Check if the context is still valid before attemtping to connect
// We need to do this because the smtp library doesn't have context support
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Connect to the SMTP server // Connect to the SMTP server
client, err := srv.getSmtpClient() client, err := srv.getSmtpClient()
if err != nil { if err != nil {
@@ -118,6 +137,14 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
} }
defer client.Close() defer client.Close()
// Check if the context is still valid before sending the email
select {
case <-ctx.Done():
return ctx.Err()
default:
// All good
}
// Send the email // Send the email
if err := srv.sendEmailContent(client, toEmail, c); err != nil { if err := srv.sendEmailContent(client, toEmail, c); err != nil {
return fmt.Errorf("send email content: %w", err) return fmt.Errorf("send email content: %w", err)
@@ -127,16 +154,18 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
} }
func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) { func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
port := srv.appConfigService.DbConfig.SmtpPort.Value dbConfig := srv.appConfigService.GetDbConfig()
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
port := dbConfig.SmtpPort.Value
smtpAddress := dbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true", InsecureSkipVerify: dbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value, ServerName: dbConfig.SmtpHost.Value,
} }
// Connect to the SMTP server based on TLS setting // Connect to the SMTP server based on TLS setting
switch srv.appConfigService.DbConfig.SmtpTls.Value { switch dbConfig.SmtpTls.Value {
case "none": case "none":
client, err = smtp.Dial(smtpAddress) client, err = smtp.Dial(smtpAddress)
case "tls": case "tls":
@@ -147,7 +176,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
tlsConfig, tlsConfig,
) )
default: default:
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", srv.appConfigService.DbConfig.SmtpTls.Value) return nil, fmt.Errorf("invalid SMTP TLS setting: %s", dbConfig.SmtpTls.Value)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err) return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
@@ -161,8 +190,8 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
} }
// Set up the authentication if user or password are set // Set up the authentication if user or password are set
smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value smtpUser := dbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value smtpPassword := dbConfig.SmtpPassword.Value
if smtpUser != "" || smtpPassword != "" { if smtpUser != "" || smtpPassword != "" {
// Authenticate with plain auth // Authenticate with plain auth
@@ -198,7 +227,7 @@ func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error { func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error {
// Set the sender // Set the sender
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value, nil); err != nil { if err := client.Mail(srv.appConfigService.GetDbConfig().SmtpFrom.Value, nil); err != nil {
return fmt.Errorf("failed to set sender: %w", err) return fmt.Errorf("failed to set sender: %w", err)
} }
@@ -214,7 +243,7 @@ func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Add
} }
// Write the email content // Write the email content
_, err = w.Write([]byte(c.String())) _, err = io.Copy(w, strings.NewReader(c.String()))
if err != nil { if err != nil {
return fmt.Errorf("failed to write email data: %w", err) return fmt.Errorf("failed to write email data: %w", err)
} }

View File

@@ -42,6 +42,13 @@ var TestTemplate = email.Template[struct{}]{
}, },
} }
var ApiKeyExpiringSoonTemplate = email.Template[ApiKeyExpiringSoonTemplateData]{
Path: "api-key-expiring-soon",
Title: func(data *email.TemplateData[ApiKeyExpiringSoonTemplateData]) string {
return fmt.Sprintf("API Key \"%s\" Expiring Soon", data.Data.ApiKeyName)
},
}
type NewLoginTemplateData struct { type NewLoginTemplateData struct {
IPAddress string IPAddress string
Country string Country string
@@ -54,7 +61,14 @@ type OneTimeAccessTemplateData = struct {
Code string Code string
LoginLink string LoginLink string
LoginLinkWithCode string LoginLinkWithCode string
ExpirationString string
}
type ApiKeyExpiringSoonTemplateData struct {
Name string
ApiKeyName string
ExpiresAt time.Time
} }
// this is list of all template paths used for preloading templates // this is list of all template paths used for preloading templates
var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path} var emailTemplatesPaths = []string{NewLoginTemplate.Path, OneTimeAccessTemplate.Path, TestTemplate.Path, ApiKeyExpiringSoonTemplate.Path}

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"archive/tar" "archive/tar"
"compress/gzip" "compress/gzip"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -41,7 +42,7 @@ var tailscaleIPNets = []*net.IPNet{
} }
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database. // NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
func NewGeoLiteService() *GeoLiteService { func NewGeoLiteService(ctx context.Context) *GeoLiteService {
service := &GeoLiteService{} service := &GeoLiteService{}
if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl { if common.EnvConfig.MaxMindLicenseKey == "" && common.EnvConfig.GeoLiteDBUrl == common.MaxMindGeoLiteCityUrl {
@@ -51,8 +52,9 @@ func NewGeoLiteService() *GeoLiteService {
} }
go func() { go func() {
if err := service.updateDatabase(); err != nil { err := service.updateDatabase(ctx)
log.Printf("Failed to update GeoLite2 City database: %v\n", err) if err != nil {
log.Printf("Failed to update GeoLite2 City database: %v", err)
} }
}() }()
@@ -110,7 +112,7 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
} }
// UpdateDatabase checks the age of the database and updates it if it's older than 14 days. // UpdateDatabase checks the age of the database and updates it if it's older than 14 days.
func (s *GeoLiteService) updateDatabase() error { func (s *GeoLiteService) updateDatabase(parentCtx context.Context) error {
if s.disableUpdater { if s.disableUpdater {
// Avoid updating the GeoLite2 City database. // Avoid updating the GeoLite2 City database.
return nil return nil
@@ -124,8 +126,15 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...") log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey) downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
// Download the database tar.gz file ctx, cancel := context.WithTimeout(parentCtx, 10*time.Minute)
resp, err := http.Get(downloadUrl) defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to download database: %w", err) return fmt.Errorf("failed to download database: %w", err)
} }
@@ -164,6 +173,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tarReader := tar.NewReader(gzr) tarReader := tar.NewReader(gzr)
var totalSize int64
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
// Iterate over the files in the tar archive // Iterate over the files in the tar archive
for { for {
header, err := tarReader.Next() header, err := tarReader.Next()
@@ -176,6 +188,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
// Check if the file is the GeoLite2-City.mmdb file // Check if the file is the GeoLite2-City.mmdb file
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" { if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
totalSize += header.Size
if totalSize > maxTotalSize {
return errors.New("total decompressed size exceeds maximum allowed limit")
}
// extract to a temporary file to avoid having a corrupted db in case of write failure. // extract to a temporary file to avoid having a corrupted db in case of write failure.
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath) baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp") tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
@@ -185,7 +202,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tempName := tmpFile.Name() tempName := tmpFile.Name()
// Write the file contents directly to the target location // Write the file contents directly to the target location
if _, err := io.Copy(tmpFile, tarReader); err != nil { if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
// if fails to write, then cleanup and throw an error // if fails to write, then cleanup and throw an error
tmpFile.Close() tmpFile.Close()
os.Remove(tempName) os.Remove(tempName)

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
@@ -11,13 +12,11 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strconv"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk" "github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
@@ -25,15 +24,34 @@ import (
) )
const ( const (
// Path in the data/keys folder where the key is stored // PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK // This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json" PrivateKeyFile = "jwt_private_key.json"
// Size, in bits, of the RSA key to generate if none is found // RsaKeySize is the size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048 RsaKeySize = 2048
// Usage for the private keys, for the "use" property // KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig" KeyUsageSigning = "sig"
// IsAdminClaim is a boolean claim used in access tokens for admin users
// This may be omitted on non-admin tokens
IsAdminClaim = "isAdmin"
// TokenTypeClaim is the claim used to identify the type of token
TokenTypeClaim = "type"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
// AccessTokenJWTType identifies a JWT as an access token used by Pocket ID
AccessTokenJWTType = "access-token"
// IDTokenJWTType identifies a JWT as an ID token used by Pocket ID
IDTokenJWTType = "id-token"
// Acceptable clock skew for verifying tokens
clockSkew = time.Minute
) )
type JwtService struct { type JwtService struct {
@@ -61,11 +79,6 @@ func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) e
return s.loadOrGenerateKey(keysPath) return s.loadOrGenerateKey(keysPath)
} }
type AccessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
}
// loadOrGenerateKey loads the private key from the given path or generates it if not existing. // loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error { func (s *JwtService) loadOrGenerateKey(keysPath string) error {
var key jwk.Key var key jwk.Key
@@ -142,9 +155,15 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
return fmt.Errorf("private key is not valid: %w", err) return fmt.Errorf("private key is not valid: %w", err)
} }
// Set the private key in the object // Set the private key and key id in the object
s.privateKey = privateKey s.privateKey = privateKey
keyId, ok := privateKey.KeyID()
if !ok {
return errors.New("key object does not contain a key ID")
}
s.keyId = keyId
// Create and encode a JWKS containing the public key // Create and encode a JWKS containing the public key
publicKey, err := s.GetPublicJWK() publicKey, err := s.GetPublicJWK()
if err != nil { if err != nil {
@@ -164,133 +183,182 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
} }
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value) now := time.Now()
claim := AccessTokenJWTClaims{ token, err := jwt.NewBuilder().
RegisteredClaims: jwt.RegisteredClaims{ Subject(user.ID).
Subject: user.ID, Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)), IssuedAt(now).
IssuedAt: jwt.NewNumericDate(time.Now()), Issuer(common.EnvConfig.AppURL).
Audience: jwt.ClaimStrings{common.EnvConfig.AppURL}, Build()
},
IsAdmin: user.IsAdmin,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err) return "", fmt.Errorf("failed to build token: %w", err)
} }
signed, err := token.SignedString(privateKeyRaw) err = SetAudienceString(token, common.EnvConfig.AppURL)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, AccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
err = SetIsAdmin(token, user.IsAdmin)
if err != nil {
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil { if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err) return "", fmt.Errorf("failed to sign token: %w", err)
} }
return signed, nil return string(signed), nil
} }
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) { func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (any, error) { alg, _ := s.privateKey.Algorithm()
return s.getPublicKeyRaw() token, err := jwt.ParseString(
}) tokenString,
if err != nil || !token.Valid { jwt.WithValidate(true),
return nil, errors.New("couldn't handle this token") jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(AccessTokenJWTType)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
} }
claims, isValid := token.Claims.(*AccessTokenJWTClaims) return token, nil
if !isValid {
return nil, errors.New("can't parse claims")
}
if !slices.Contains(claims.Audience, common.EnvConfig.AppURL) {
return nil, errors.New("audience doesn't match")
}
return claims, nil
} }
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) { func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
// Initialize with capacity for userClaims, + 4 fixed claims, + 2 claims which may be set in some cases, to avoid re-allocations now := time.Now()
claims := make(jwt.MapClaims, len(userClaims)+6) token, err := jwt.NewBuilder().
claims["aud"] = clientID Expiration(now.Add(1 * time.Hour)).
claims["exp"] = jwt.NewNumericDate(time.Now().Add(1 * time.Hour)) IssuedAt(now).
claims["iat"] = jwt.NewNumericDate(time.Now()) Issuer(common.EnvConfig.AppURL).
claims["iss"] = common.EnvConfig.AppURL Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, IDTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
for k, v := range userClaims { for k, v := range userClaims {
claims[k] = v err = token.Set(k, v)
if err != nil {
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
}
} }
if nonce != "" { if nonce != "" {
claims["nonce"] = nonce err = token.Set("nonce", nonce)
if err != nil {
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
}
} }
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) alg, _ := s.privateKey.Algorithm()
token.Header["kid"] = s.keyId signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err) return "", fmt.Errorf("failed to sign token: %w", err)
} }
return token.SignedString(privateKeyRaw) return string(signed), nil
} }
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) { func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) { alg, _ := s.privateKey.Algorithm()
return s.getPublicKeyRaw()
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) { opts := make([]jwt.ParseOption, 0)
return nil, errors.New("couldn't handle this token")
// These options are always present
opts = append(opts,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(IDTokenJWTType)),
)
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
// In case we want to accept expired tokens (during logout), we need to set the validators explicitly without validating "exp"
if acceptExpiredTokens {
// This is equivalent to the default validators except it doesn't validate "exp"
opts = append(opts,
jwt.WithResetValidators(true),
jwt.WithValidator(jwt.IsIssuedAtValid()),
jwt.WithValidator(jwt.IsNbfValid()),
)
} }
claims, isValid := token.Claims.(*jwt.RegisteredClaims) token, err := jwt.ParseString(tokenString, opts...)
if !isValid { if err != nil {
return nil, errors.New("can't parse claims") return nil, fmt.Errorf("failed to parse token: %w", err)
} }
return claims, nil return token, nil
} }
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) { func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{ now := time.Now()
Subject: user.ID, token, err := jwt.NewBuilder().
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), Subject(user.ID).
IssuedAt: jwt.NewNumericDate(time.Now()), Expiration(now.Add(1 * time.Hour)).
Audience: jwt.ClaimStrings{clientID}, IssuedAt(now).
Issuer: common.EnvConfig.AppURL, Issuer(common.EnvConfig.AppURL).
} Build()
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = s.keyId
var privateKeyRaw any
err := jwk.Export(s.privateKey, &privateKeyRaw)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to export private key object: %w", err) return "", fmt.Errorf("failed to build token: %w", err)
} }
return token.SignedString(privateKeyRaw) err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, OAuthAccessTokenJWTType)
if err != nil {
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
} }
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) { func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (any, error) { alg, _ := s.privateKey.Algorithm()
return s.getPublicKeyRaw() token, err := jwt.ParseString(
}) tokenString,
if err != nil || !token.Valid { jwt.WithValidate(true),
return nil, errors.New("couldn't handle this token") jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
jwt.WithValidator(TokenTypeValidator(OAuthAccessTokenJWTType)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
} }
claims, isValid := token.Claims.(*jwt.RegisteredClaims) return token, nil
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
} }
// GetPublicJWK returns the JSON Web Key (JWK) for the public key. // GetPublicJWK returns the JSON Web Key (JWK) for the public key.
@@ -319,17 +387,18 @@ func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
return s.jwksEncoded, nil return s.jwksEncoded, nil
} }
func (s *JwtService) getPublicKeyRaw() (any, error) { // GetKeyAlg returns the algorithm of the key
pubKey, err := s.privateKey.PublicKey() func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
if err != nil { if len(s.jwksEncoded) == 0 {
return nil, fmt.Errorf("failed to get public key: %w", err) return nil, errors.New("key is not initialized")
} }
var pubKeyRaw any
err = jwk.Export(pubKey, &pubKeyRaw) alg, ok := s.privateKey.Algorithm()
if err != nil { if !ok || alg == nil {
return nil, fmt.Errorf("failed to export raw public key: %w", err) return nil, errors.New("failed to retrieve algorithm for key")
} }
return pubKeyRaw, nil
return alg, nil
} }
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) { func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
@@ -424,7 +493,6 @@ func SaveKeyJWK(key jwk.Key, path string) error {
} }
// generateRandomKeyID generates a random key ID. // generateRandomKeyID generates a random key ID.
// It is used for newly-generated keys
func generateRandomKeyID() (string, error) { func generateRandomKeyID() (string, error) {
buf := make([]byte, 8) buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf) _, err := io.ReadFull(rand.Reader, buf)
@@ -433,3 +501,51 @@ func generateRandomKeyID() (string, error) {
} }
return base64.RawURLEncoding.EncodeToString(buf), nil return base64.RawURLEncoding.EncodeToString(buf), nil
} }
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {
return false, nil
}
var isAdmin bool
err := token.Get(IsAdminClaim, &isAdmin)
return isAdmin, err
}
// SetTokenType sets the "type" claim in the token
func SetTokenType(token jwt.Token, tokenType string) error {
if tokenType == "" {
return nil
}
return token.Set(TokenTypeClaim, tokenType)
}
// SetIsAdmin sets the "isAdmin" claim in the token
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
// Only set if true
if !isAdmin {
return nil
}
return token.Set(IsAdminClaim, isAdmin)
}
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
func SetAudienceString(token jwt.Token, audience string) error {
return token.Set(jwt.AudienceKey, audience)
}
// TokenTypeValidator is a validator function that checks the "type" claim in the token
func TokenTypeValidator(expectedTokenType string) jwt.ValidatorFunc {
return func(_ context.Context, t jwt.Token) error {
var tokenType string
err := t.Get(TokenTypeClaim, &tokenType)
if err != nil {
return fmt.Errorf("failed to get token type claim: %w", err)
}
if tokenType != expectedTokenType {
return fmt.Errorf("invalid token type: expected %s, got %s", expectedTokenType, tokenType)
}
return nil
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@ package service
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
@@ -11,8 +12,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
@@ -26,46 +29,44 @@ type LdapService struct {
} }
func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService { func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService) *LdapService {
return &LdapService{db: db, appConfigService: appConfigService, userService: userService, groupService: groupService} return &LdapService{
db: db,
appConfigService: appConfigService,
userService: userService,
groupService: groupService,
}
} }
func (s *LdapService) createClient() (*ldap.Conn, error) { func (s *LdapService) createClient() (*ldap.Conn, error) {
if s.appConfigService.DbConfig.LdapEnabled.Value != "true" { dbConfig := s.appConfigService.GetDbConfig()
if !dbConfig.LdapEnabled.IsTrue() {
return nil, fmt.Errorf("LDAP is not enabled") return nil, fmt.Errorf("LDAP is not enabled")
} }
// Setup LDAP connection // Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value client, err := ldap.DialURL(dbConfig.LdapUrl.Value, ldap.DialWithTLSConfig(&tls.Config{
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true" InsecureSkipVerify: dbConfig.LdapSkipCertVerify.IsTrue(), //nolint:gosec
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) }))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err) return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
} }
// Bind as service account // Bind as service account
bindDn := s.appConfigService.DbConfig.LdapBindDn.Value err = client.Bind(dbConfig.LdapBindDn.Value, dbConfig.LdapBindPassword.Value)
bindPassword := s.appConfigService.DbConfig.LdapBindPassword.Value
err = client.Bind(bindDn, bindPassword)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to bind to LDAP: %w", err) return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
} }
return client, nil return client, nil
} }
func (s *LdapService) SyncAll() error { func (s *LdapService) SyncAll(ctx context.Context) error {
err := s.SyncUsers() // Start a transaction
if err != nil { tx := s.db.Begin()
return fmt.Errorf("failed to sync users: %w", err) defer func() {
} tx.Rollback()
}()
err = s.SyncGroups()
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
return nil
}
func (s *LdapService) SyncGroups() error {
// Setup LDAP connection // Setup LDAP connection
client, err := s.createClient() client, err := s.createClient()
if err != nil { if err != nil {
@@ -73,237 +74,368 @@ func (s *LdapService) SyncGroups() error {
} }
defer client.Close() defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value err = s.SyncUsers(ctx, tx, client)
nameAttribute := s.appConfigService.DbConfig.LdapAttributeGroupName.Value if err != nil {
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeGroupUniqueIdentifier.Value return fmt.Errorf("failed to sync users: %w", err)
groupMemberOfAttribute := s.appConfigService.DbConfig.LdapAttributeGroupMember.Value
filter := s.appConfigService.DbConfig.LdapUserGroupSearchFilter.Value
searchAttrs := []string{
nameAttribute,
uniqueIdentifierAttribute,
groupMemberOfAttribute,
} }
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) err = s.SyncGroups(ctx, tx, client)
if err != nil {
return fmt.Errorf("failed to sync groups: %w", err)
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit changes to database: %w", err)
}
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
dbConfig := s.appConfigService.GetDbConfig()
searchAttrs := []string{
dbConfig.LdapAttributeGroupName.Value,
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
dbConfig.LdapAttributeGroupMember.Value,
}
searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserGroupSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
return fmt.Errorf("failed to query LDAP: %w", err) return fmt.Errorf("failed to query LDAP: %w", err)
} }
// Create a mapping for groups that exist // Create a mapping for groups that exist
ldapGroupIDs := make(map[string]bool) ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries { for _, value := range result.Entries {
var membersUserId []string ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute) // Skip groups without a valid LDAP ID
ldapGroupIDs[ldapId] = true if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
continue
}
ldapGroupIDs[ldapId] = struct{}{}
// Try to find the group in the database // Try to find the group in the database
var databaseGroup model.UserGroup var databaseGroup model.UserGroup
s.db.Where("ldap_id = ?", ldapId).First(&databaseGroup) err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseGroup).
Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
}
// Get group members and add to the correct Group // Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute) groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
membersUserId := make([]string, 0, len(groupMembers))
for _, member := range groupMembers { for _, member := range groupMembers {
// Normal output of this would be CN=username,ou=people,dc=example,dc=com ldapId := getDNProperty("uid", member)
// Splitting at the "=" and "," then just grabbing the username for that string if ldapId == "" {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0] continue
}
var databaseUser model.User var databaseUser model.User
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error err = tx.
if err != nil { WithContext(ctx).
if errors.Is(err, gorm.ErrRecordNotFound) { Where("username = ? AND ldap_id IS NOT NULL", ldapId).
// The user collides with a non-LDAP user, so we skip it First(&databaseUser).
continue Error
} else { if errors.Is(err, gorm.ErrRecordNotFound) {
return err // The user collides with a non-LDAP user, so we skip it
} continue
} else if err != nil {
return fmt.Errorf("failed to query for existing user '%s': %w", ldapId, err)
} }
membersUserId = append(membersUserId, databaseUser.ID) membersUserId = append(membersUserId, databaseUser.ID)
} }
syncGroup := dto.UserGroupCreateDto{ syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(nameAttribute), Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(nameAttribute), FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute), LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
} }
if databaseGroup.ID == "" { if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup) newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
if err != nil { if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err) return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
} else { }
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err) _, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
} if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
} }
} else { } else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true) _, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil { if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err) return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
return err
} }
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
if err != nil {
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
}
} }
} }
// Get all LDAP groups from the database // Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { err = tx.
fmt.Println(fmt.Errorf("failed to fetch groups from database: %v", err)) WithContext(ctx).
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
Select("ldap_id").
Error
if err != nil {
return fmt.Errorf("failed to fetch groups from database: %w", err)
} }
// Delete groups that no longer exist in LDAP // Delete groups that no longer exist in LDAP
for _, group := range ldapGroupsInDb { for _, group := range ldapGroupsInDb {
if _, exists := ldapGroupIDs[*group.LdapID]; !exists { if _, exists := ldapGroupIDs[*group.LdapID]; exists {
if err := s.db.Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).Error; err != nil { continue
log.Printf("Failed to delete group %s with: %v", group.Name, err)
} else {
log.Printf("Deleted group %s", group.Name)
}
} }
err = tx.
WithContext(ctx).
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
Error
if err != nil {
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
}
log.Printf("Deleted group '%s'", group.Name)
} }
return nil return nil
} }
func (s *LdapService) SyncUsers() error { //nolint:gocognit
// Setup LDAP connection func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
client, err := s.createClient() dbConfig := s.appConfigService.GetDbConfig()
if err != nil {
return fmt.Errorf("failed to create LDAP client: %w", err)
}
defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeUserUniqueIdentifier.Value
usernameAttribute := s.appConfigService.DbConfig.LdapAttributeUserUsername.Value
emailAttribute := s.appConfigService.DbConfig.LdapAttributeUserEmail.Value
firstNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserFirstName.Value
lastNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserLastName.Value
profilePictureAttribute := s.appConfigService.DbConfig.LdapAttributeUserProfilePicture.Value
adminGroupAttribute := s.appConfigService.DbConfig.LdapAttributeAdminGroup.Value
filter := s.appConfigService.DbConfig.LdapUserSearchFilter.Value
searchAttrs := []string{ searchAttrs := []string{
"memberOf", "memberOf",
"sn", "sn",
"cn", "cn",
uniqueIdentifierAttribute, dbConfig.LdapAttributeUserUniqueIdentifier.Value,
usernameAttribute, dbConfig.LdapAttributeUserUsername.Value,
emailAttribute, dbConfig.LdapAttributeUserEmail.Value,
firstNameAttribute, dbConfig.LdapAttributeUserFirstName.Value,
lastNameAttribute, dbConfig.LdapAttributeUserLastName.Value,
profilePictureAttribute, dbConfig.LdapAttributeUserProfilePicture.Value,
} }
// Filters must start and finish with ()! // Filters must start and finish with ()!
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{}) searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq) result, err := client.Search(searchReq)
if err != nil { if err != nil {
fmt.Println(fmt.Errorf("failed to query LDAP: %w", err)) return fmt.Errorf("failed to query LDAP: %w", err)
} }
// Create a mapping for users that exist // Create a mapping for users that exist
ldapUserIDs := make(map[string]bool) ldapUserIDs := make(map[string]struct{}, len(result.Entries))
for _, value := range result.Entries { for _, value := range result.Entries {
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute) ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
ldapUserIDs[ldapId] = true
// Skip users without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeUserUniqueIdentifier.Value)
continue
}
ldapUserIDs[ldapId] = struct{}{}
// Get the user from the database // Get the user from the database
var databaseUser model.User var databaseUser model.User
s.db.Where("ldap_id = ?", ldapId).First(&databaseUser) err = tx.
WithContext(ctx).
Where("ldap_id = ?", ldapId).
First(&databaseUser).
Error
// If a user is found (even if disabled), enable them since they're now back in LDAP
if databaseUser.ID != "" && databaseUser.Disabled {
// Use the transaction instead of the direct context
err = tx.
WithContext(ctx).
Model(&model.User{}).
Where("id = ?", databaseUser.ID).
Update("disabled", false).
Error
if err != nil {
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
}
}
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
// This could error with ErrRecordNotFound and we want to ignore that here
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
}
// Check if user is admin by checking if they are in the admin group // Check if user is admin by checking if they are in the admin group
isAdmin := false isAdmin := false
for _, group := range value.GetAttributeValues("memberOf") { for _, group := range value.GetAttributeValues("memberOf") {
if strings.Contains(group, adminGroupAttribute) { if getDNProperty("cn", group) == dbConfig.LdapAttributeAdminGroup.Value {
isAdmin = true isAdmin = true
break
} }
} }
newUser := dto.UserCreateDto{ newUser := dto.UserCreateDto{
Username: value.GetAttributeValue(usernameAttribute), Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
Email: value.GetAttributeValue(emailAttribute), Email: value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value),
FirstName: value.GetAttributeValue(firstNameAttribute), FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
LastName: value.GetAttributeValue(lastNameAttribute), LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
IsAdmin: isAdmin, IsAdmin: isAdmin,
LdapID: ldapId, LdapID: ldapId,
} }
if databaseUser.ID == "" { if databaseUser.ID == "" {
_, err = s.userService.CreateUser(newUser) _, err = s.userService.createUserInternal(ctx, newUser, true, tx)
if err != nil { if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Error syncing user %s: %s", newUser.Username, err) log.Printf("Skipping creating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
} }
} else { } else {
_, err = s.userService.UpdateUser(databaseUser.ID, newUser, false, true) _, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
if err != nil { if errors.Is(err, &common.AlreadyInUseError{}) {
log.Printf("Error syncing user %s: %s", newUser.Username, err) log.Printf("Skipping updating LDAP user '%s': %v", newUser.Username, err)
continue
} else if err != nil {
return fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
} }
} }
// Save profile picture // Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" { pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil { if pictureString != "" {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err) err = s.saveProfilePicture(ctx, databaseUser.ID, pictureString)
if err != nil {
// This is not a fatal error
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
} }
} }
} }
// Get all LDAP users from the database // Get all LDAP users from the database
var ldapUsersInDb []model.User var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil { err = tx.
fmt.Println(fmt.Errorf("failed to fetch users from database: %v", err)) WithContext(ctx).
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
Select("id, username, ldap_id, disabled").
Error
if err != nil {
return fmt.Errorf("failed to fetch users from database: %w", err)
} }
// Delete users that no longer exist in LDAP // Mark users as disabled or delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb { for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists { // Skip if the user ID exists in the fetched LDAP results
if err := s.userService.DeleteUser(user.ID); err != nil { if _, exists := ldapUserIDs[*user.LdapID]; exists {
log.Printf("Failed to delete user %s with: %v", user.Username, err) continue
} else { }
log.Printf("Deleted user %s", user.Username)
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
err = s.userService.DisableUser(ctx, user.ID, tx)
if err != nil {
log.Printf("Failed to disable user %s: %v", user.Username, err)
continue
}
} else {
err = s.userService.DeleteUser(ctx, user.ID, true)
if err != nil {
log.Printf("Failed to delete user %s: %v", user.Username, err)
continue
} }
} }
} }
return nil return nil
} }
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error { func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
var reader io.Reader var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil { _, err := url.ParseRequestURI(pictureString)
// If the photo is a URL, download it if err == nil {
response, err := http.Get(pictureString) ctx, cancel := context.WithTimeout(parentCtx, 5*time.Second)
defer cancel()
var req *http.Request
req, err = http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
var res *http.Response
res, err = http.DefaultClient.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err) return fmt.Errorf("failed to download profile picture: %w", err)
} }
defer response.Body.Close() defer res.Body.Close()
reader = response.Body
reader = res.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil { } else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it // If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto) reader = bytes.NewReader(decodedPhoto)
} else { } else {
// If the photo is a string, we assume that it's a binary string // If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString)) reader = bytes.NewReader([]byte(pictureString))
} }
// Update the profile picture // Update the profile picture
if err := s.userService.UpdateProfilePicture(userId, reader); err != nil { err = s.userService.UpdateProfilePicture(userId, reader)
if err != nil {
return fmt.Errorf("failed to update profile picture: %w", err) return fmt.Errorf("failed to update profile picture: %w", err)
} }
return nil return nil
} }
// getDNProperty returns the value of a property from a LDAP identifier
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
func getDNProperty(property string, str string) string {
// Example format is "CN=username,ou=people,dc=example,dc=com"
// First we split at the comma
property = strings.ToLower(property)
l := len(property) + 1
for _, v := range strings.Split(str, ",") {
v = strings.TrimSpace(v)
if len(v) > l && strings.ToLower(v)[0:l] == property+"=" {
return v[l:]
}
}
// CN not found, return an empty string
return ""
}

View File

@@ -0,0 +1,73 @@
package service
import (
"testing"
)
func TestGetDNProperty(t *testing.T) {
tests := []struct {
name string
property string
dn string
expectedResult string
}{
{
name: "simple case",
property: "cn",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "property not found",
property: "uid",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
{
name: "mixed case property",
property: "CN",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "username",
},
{
name: "mixed case DN",
property: "cn",
dn: "CN=username,OU=people,DC=example,DC=com",
expectedResult: "username",
},
{
name: "spaces in DN",
property: "cn",
dn: "cn=username, ou=people, dc=example, dc=com",
expectedResult: "username",
},
{
name: "value with special characters",
property: "cn",
dn: "cn=user.name+123,ou=people,dc=example,dc=com",
expectedResult: "user.name+123",
},
{
name: "empty DN",
property: "cn",
dn: "",
expectedResult: "",
},
{
name: "empty property",
property: "",
dn: "cn=username,ou=people,dc=example,dc=com",
expectedResult: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getDNProperty(tt.property, tt.dn)
if result != tt.expectedResult {
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
tt.property, tt.dn, result, tt.expectedResult)
}
})
}
}

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@@ -9,9 +10,12 @@ import (
"mime/multipart" "mime/multipart"
"os" "os"
"regexp" "regexp"
"slices"
"strings" "strings"
"time" "time"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
@@ -39,9 +43,19 @@ func NewOidcService(db *gorm.DB, jwtService *JwtService, appConfigService *AppCo
} }
} }
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.Preload("AllowedUserGroups").First(&client, "id = ?", input.ClientID).Error; err != nil { err := tx.
WithContext(ctx).
Preload("AllowedUserGroups").
First(&client, "id = ?", input.ClientID).
Error
if err != nil {
return "", "", err return "", "", err
} }
@@ -58,7 +72,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
// Check if the user group is allowed to authorize the client // Check if the user group is allowed to authorize the client
var user model.User var user model.User
if err := s.db.Preload("UserGroups").First(&user, "id = ?", userID).Error; err != nil { err = tx.
WithContext(ctx).
Preload("UserGroups").
First(&user, "id = ?", userID).
Error
if err != nil {
return "", "", err return "", "", err
} }
@@ -67,7 +86,7 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
} }
// Check if the user has already authorized the client with the given scope // Check if the user has already authorized the client with the given scope
hasAuthorizedClient, err := s.HasAuthorizedClient(input.ClientID, userID, input.Scope) hasAuthorizedClient, err := s.hasAuthorizedClientInternal(ctx, input.ClientID, userID, input.Scope, tx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@@ -80,39 +99,55 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
Scope: input.Scope, Scope: input.Scope,
} }
if err := s.db.Create(&userAuthorizedClient).Error; err != nil { err = tx.
if errors.Is(err, gorm.ErrDuplicatedKey) { WithContext(ctx).
// The client has already been authorized but with a different scope so we need to update the scope Create(&userAuthorizedClient).
if err := s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil { Error
return "", "", err if errors.Is(err, gorm.ErrDuplicatedKey) {
} // The client has already been authorized but with a different scope so we need to update the scope
} else { if err := tx.
WithContext(ctx).
Model(&userAuthorizedClient).Update("scope", input.Scope).Error; err != nil {
return "", "", err return "", "", err
} }
} else if err != nil {
return "", "", err
} }
} }
// Create the authorization code // Create the authorization code
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod) code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
// Log the authorization event // Log the authorization event
if hasAuthorizedClient { if hasAuthorizedClient {
s.auditLogService.Create(model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}) s.auditLogService.Create(ctx, model.AuditLogEventClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
} else { } else {
s.auditLogService.Create(model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}) s.auditLogService.Create(ctx, model.AuditLogEventNewClientAuthorization, ipAddress, userAgent, userID, model.AuditLogData{"clientName": client.Name}, tx)
}
err = tx.Commit().Error
if err != nil {
return "", "", err
} }
return code, callbackURL, nil return code, callbackURL, nil
} }
// HasAuthorizedClient checks if the user has already authorized the client with the given scope // HasAuthorizedClient checks if the user has already authorized the client with the given scope
func (s *OidcService) HasAuthorizedClient(clientID, userID, scope string) (bool, error) { func (s *OidcService) HasAuthorizedClient(ctx context.Context, clientID, userID, scope string) (bool, error) {
return s.hasAuthorizedClientInternal(ctx, clientID, userID, scope, s.db)
}
func (s *OidcService) hasAuthorizedClientInternal(ctx context.Context, clientID, userID, scope string, tx *gorm.DB) (bool, error) {
var userAuthorizedOidcClient model.UserAuthorizedOidcClient var userAuthorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).Error; err != nil { err := tx.
WithContext(ctx).
First(&userAuthorizedOidcClient, "client_id = ? AND user_id = ?", clientID, userID).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil return false, nil
} }
@@ -145,74 +180,296 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize return isAllowedToAuthorize
} }
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) { func (s *OidcService) CreateTokens(ctx context.Context, code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
if grantType != "authorization_code" { switch grantType {
return "", "", &common.OidcGrantTypeNotSupportedError{} case "authorization_code":
return s.createTokenFromAuthorizationCode(ctx, code, clientID, clientSecret, codeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(ctx, refreshToken, clientID, clientSecret)
return "", accessToken, newRefreshToken, exp, err
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
} }
}
func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err = tx.
return "", "", err WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", "", 0, err
} }
// Verify the client secret if the client is not public // Verify the client secret if the client is not public
if !client.IsPublic { if !client.IsPublic {
if clientID == "" || clientSecret == "" { if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{} return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
} }
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)) err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil { if err != nil {
return "", "", &common.OidcClientSecretInvalidError{} return "", "", "", 0, &common.OidcClientSecretInvalidError{}
} }
} }
var authorizationCodeMetaData model.OidcAuthorizationCode var authorizationCodeMetaData model.OidcAuthorizationCode
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error err = tx.
WithContext(ctx).
Preload("User").
First(&authorizationCodeMetaData, "code = ?", code).
Error
if err != nil { if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{} return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
} }
// If the client is public or PKCE is enabled, the code verifier must match the code challenge // If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled { if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) { if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{} return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
} }
} }
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) { if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{} return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
} }
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID) userClaims, err := s.getUserClaimsForClientInternal(ctx, authorizationCodeMetaData.UserID, clientID, tx)
if err != nil { if err != nil {
return "", "", err return "", "", "", 0, err
} }
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce) idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil { if err != nil {
return "", "", err return "", "", "", 0, err
} }
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID) // Generate a refresh token
refreshToken, err = s.createRefreshToken(ctx, clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
if err != nil {
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData) accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
if err != nil {
return "", "", "", 0, err
}
return idToken, accessToken, nil err = tx.
WithContext(ctx).
Delete(&authorizationCodeMetaData).
Error
if err != nil {
return "", "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", "", 0, err
}
return idToken, accessToken, refreshToken, 3600, nil
} }
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) { func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
if refreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{}
}
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
// Get the client to check if it's public
var client model.OidcClient var client model.OidcClient
if err := s.db.Preload("CreatedBy").Preload("AllowedUserGroups").First(&client, "id = ?", clientID).Error; err != nil { err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = tx.
WithContext(ctx).
Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
return "", "", 0, err
}
// Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
// Generate a new access token
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
if err != nil {
return "", "", 0, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(ctx, clientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
if err != nil {
return "", "", 0, err
}
// Delete the used refresh token
err = tx.
WithContext(ctx).
Delete(&storedRefreshToken).
Error
if err != nil {
return "", "", 0, err
}
err = tx.Commit().Error
if err != nil {
return "", "", 0, err
}
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) IntrospectToken(clientID, clientSecret, tokenString string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
if clientID == "" || clientSecret == "" {
return introspectDto, &common.OidcMissingClientCredentialsError{}
}
// Get the client to check if we are authorized.
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
// Verify the client secret. This endpoint may not be used by public clients.
if client.IsPublic {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret)); err != nil {
return introspectDto, &common.OidcClientSecretInvalidError{}
}
token, err := s.jwtService.VerifyOauthAccessToken(tokenString)
if err != nil {
if errors.Is(err, jwt.ParseError()) {
// It's apparently not a valid JWT token, so we check if it's a valid refresh_token.
return s.introspectRefreshToken(tokenString)
}
// Every failure we get means the token is invalid. Nothing more to do with the error.
introspectDto.Active = false
return introspectDto, nil
}
introspectDto.Active = true
introspectDto.TokenType = "access_token"
if token.Has("scope") {
var asString string
var asStrings []string
if err := token.Get("scope", &asString); err == nil {
introspectDto.Scope = asString
} else if err := token.Get("scope", &asStrings); err == nil {
introspectDto.Scope = strings.Join(asStrings, " ")
}
}
if expiration, hasExpiration := token.Expiration(); hasExpiration {
introspectDto.Expiration = expiration.Unix()
}
if issuedAt, hasIssuedAt := token.IssuedAt(); hasIssuedAt {
introspectDto.IssuedAt = issuedAt.Unix()
}
if notBefore, hasNotBefore := token.NotBefore(); hasNotBefore {
introspectDto.NotBefore = notBefore.Unix()
}
if subject, hasSubject := token.Subject(); hasSubject {
introspectDto.Subject = subject
}
if audience, hasAudience := token.Audience(); hasAudience {
introspectDto.Audience = audience
}
if issuer, hasIssuer := token.Issuer(); hasIssuer {
introspectDto.Issuer = issuer
}
if identifier, hasIdentifier := token.JwtID(); hasIdentifier {
introspectDto.Identifier = identifier
}
return introspectDto, nil
}
func (s *OidcService) introspectRefreshToken(refreshToken string) (introspectDto dto.OidcIntrospectionResponseDto, err error) {
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
introspectDto.Active = false
return introspectDto, nil
}
return introspectDto, err
}
introspectDto.Active = true
introspectDto.TokenType = "refresh_token"
return introspectDto, nil
}
func (s *OidcService) GetClient(ctx context.Context, clientID string) (model.OidcClient, error) {
return s.getClientInternal(ctx, clientID, s.db)
}
func (s *OidcService) getClientInternal(ctx context.Context, clientID string, tx *gorm.DB) (model.OidcClient, error) {
var client model.OidcClient
err := tx.
WithContext(ctx).
Preload("CreatedBy").
Preload("AllowedUserGroups").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
return client, nil return client, nil
} }
func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) { func (s *OidcService) ListClients(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.OidcClient, utils.PaginationResponse, error) {
var clients []model.OidcClient var clients []model.OidcClient
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{}) query := s.db.
WithContext(ctx).
Preload("CreatedBy").
Model(&model.OidcClient{})
if searchTerm != "" { if searchTerm != "" {
searchPattern := "%" + searchTerm + "%" searchPattern := "%" + searchTerm + "%"
query = query.Where("name LIKE ?", searchPattern) query = query.Where("name LIKE ?", searchPattern)
@@ -226,7 +483,7 @@ func (s *OidcService) ListClients(searchTerm string, sortedPaginationRequest uti
return clients, pagination, nil return clients, pagination, nil
} }
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) { func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
client := model.OidcClient{ client := model.OidcClient{
Name: input.Name, Name: input.Name,
CallbackURLs: input.CallbackURLs, CallbackURLs: input.CallbackURLs,
@@ -236,16 +493,30 @@ func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string)
PkceEnabled: input.IsPublic || input.PkceEnabled, PkceEnabled: input.IsPublic || input.PkceEnabled,
} }
if err := s.db.Create(&client).Error; err != nil { err := s.db.
WithContext(ctx).
Create(&client).
Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
return client, nil return client, nil
} }
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) { func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil { err := tx.
WithContext(ctx).
Preload("CreatedBy").
First(&client, "id = ?", clientID).
Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
@@ -255,29 +526,48 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt
client.IsPublic = input.IsPublic client.IsPublic = input.IsPublic
client.PkceEnabled = input.IsPublic || input.PkceEnabled client.PkceEnabled = input.IsPublic || input.PkceEnabled
if err := s.db.Save(&client).Error; err != nil { err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
return client, nil return client, nil
} }
func (s *OidcService) DeleteClient(clientID string) error { func (s *OidcService) DeleteClient(ctx context.Context, clientID string) error {
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err := s.db.
return err WithContext(ctx).
} Where("id = ?", clientID).
Delete(&client).
if err := s.db.Delete(&client).Error; err != nil { Error
if err != nil {
return err return err
} }
return nil return nil
} }
func (s *OidcService) CreateClientSecret(clientID string) (string, error) { func (s *OidcService) CreateClientSecret(ctx context.Context, clientID string) (string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", err return "", err
} }
@@ -292,16 +582,29 @@ func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
} }
client.Secret = string(hashedSecret) client.Secret = string(hashedSecret)
if err := s.db.Save(&client).Error; err != nil { err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return "", err
}
err = tx.Commit().Error
if err != nil {
return "", err return "", err
} }
return clientSecret, nil return clientSecret, nil
} }
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) { func (s *OidcService) GetClientLogo(ctx context.Context, clientID string) (string, string, error) {
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err := s.db.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return "", "", err return "", "", err
} }
@@ -309,26 +612,35 @@ func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
return "", "", errors.New("image not found") return "", "", errors.New("image not found")
} }
imageType := *client.ImageType imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType) mimeType := utils.GetImageMimeType(*client.ImageType)
mimeType := utils.GetImageMimeType(imageType)
return imagePath, mimeType, nil return imagePath, mimeType, nil
} }
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error { func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, file *multipart.FileHeader) error {
fileType := utils.GetFileExtension(file.Filename) fileType := utils.GetFileExtension(file.Filename)
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" { if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
return &common.FileTypeNotSupportedError{} return &common.FileTypeNotSupportedError{}
} }
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType) imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + clientID + "." + fileType
if err := utils.SaveFile(file, imagePath); err != nil { err := utils.SaveFile(file, imagePath)
if err != nil {
return err return err
} }
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err = tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err return err
} }
@@ -340,16 +652,34 @@ func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHead
} }
client.ImageType = &fileType client.ImageType = &fileType
if err := s.db.Save(&client).Error; err != nil { err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
err = tx.Commit().Error
if err != nil {
return err return err
} }
return nil return nil
} }
func (s *OidcService) DeleteClientLogo(clientID string) error { func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var client model.OidcClient var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil { err := tx.
WithContext(ctx).
First(&client, "id = ?", clientID).
Error
if err != nil {
return err return err
} }
@@ -357,38 +687,71 @@ func (s *OidcService) DeleteClientLogo(clientID string) error {
return errors.New("image not found") return errors.New("image not found")
} }
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType) client.ImageType = nil
err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return err
}
imagePath := common.EnvConfig.UploadPath + "/oidc-client-images/" + client.ID + "." + *client.ImageType
if err := os.Remove(imagePath); err != nil { if err := os.Remove(imagePath); err != nil {
return err return err
} }
client.ImageType = nil err = tx.Commit().Error
if err := s.db.Save(&client).Error; err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) { func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err
}
return claims, nil
}
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
var authorizedOidcClient model.UserAuthorizedOidcClient var authorizedOidcClient model.UserAuthorizedOidcClient
if err := s.db.Preload("User.UserGroups").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil { err := tx.
WithContext(ctx).
Preload("User.UserGroups").
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
Error
if err != nil {
return nil, err return nil, err
} }
user := authorizedOidcClient.User user := authorizedOidcClient.User
scope := authorizedOidcClient.Scope scopes := strings.Split(authorizedOidcClient.Scope, " ")
claims := map[string]interface{}{ claims := map[string]interface{}{
"sub": user.ID, "sub": user.ID,
} }
if strings.Contains(scope, "email") { if slices.Contains(scopes, "email") {
claims["email"] = user.Email claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.Value == "true" claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
} }
if strings.Contains(scope, "groups") { if slices.Contains(scopes, "groups") {
userGroups := make([]string, len(user.UserGroups)) userGroups := make([]string, len(user.UserGroups))
for i, group := range user.UserGroups { for i, group := range user.UserGroups {
userGroups[i] = group.Name userGroups[i] = group.Name
@@ -401,17 +764,17 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
"family_name": user.LastName, "family_name": user.LastName,
"name": user.FullName(), "name": user.FullName(),
"preferred_username": user.Username, "preferred_username": user.Username,
"picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID), "picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
} }
if strings.Contains(scope, "profile") { if slices.Contains(scopes, "profile") {
// Add profile claims // Add profile claims
for k, v := range profileClaims { for k, v := range profileClaims {
claims[k] = v claims[k] = v
} }
// Add custom claims // Add custom claims
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(userID) customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -419,8 +782,8 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
for _, customClaim := range customClaims { for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string // The value of the custom claim can be a JSON object or a string
var jsonValue interface{} var jsonValue interface{}
json.Unmarshal([]byte(customClaim.Value), &jsonValue) err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if jsonValue != nil { if err == nil {
// It's JSON so we store it as an object // It's JSON so we store it as an object
claims[customClaim.Key] = jsonValue claims[customClaim.Key] = jsonValue
} else { } else {
@@ -429,15 +792,21 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
} }
} }
} }
if strings.Contains(scope, "email") {
if slices.Contains(scopes, "email") {
claims["email"] = user.Email claims["email"] = user.Email
} }
return claims, nil return claims, nil
} }
func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) { func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
client, err = s.GetClient(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
client, err = s.getClientInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
@@ -445,18 +814,37 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
// Fetch the user groups based on UserGroupIDs in input // Fetch the user groups based on UserGroupIDs in input
var groups []model.UserGroup var groups []model.UserGroup
if len(input.UserGroupIDs) > 0 { if len(input.UserGroupIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserGroupIDs).Find(&groups).Error; err != nil { err = tx.
WithContext(ctx).
Where("id IN (?)", input.UserGroupIDs).
Find(&groups).
Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
} }
// Replace the current user groups with the new set of user groups // Replace the current user groups with the new set of user groups
if err := s.db.Model(&client).Association("AllowedUserGroups").Replace(groups); err != nil { err = tx.
WithContext(ctx).
Model(&client).
Association("AllowedUserGroups").
Replace(groups)
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
// Save the updated client // Save the updated client
if err := s.db.Save(&client).Error; err != nil { err = tx.
WithContext(ctx).
Save(&client).
Error
if err != nil {
return model.OidcClient{}, err
}
err = tx.Commit().Error
if err != nil {
return model.OidcClient{}, err return model.OidcClient{}, err
} }
@@ -464,28 +852,36 @@ func (s *OidcService) UpdateAllowedUserGroups(id string, input dto.OidcUpdateAll
} }
// ValidateEndSession returns the logout callback URL for the client if all the validations pass // ValidateEndSession returns the logout callback URL for the client if all the validations pass
func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string) (string, error) { func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) {
// If no ID token hint is provided, return an error // If no ID token hint is provided, return an error
if input.IdTokenHint == "" { if input.IdTokenHint == "" {
return "", &common.TokenInvalidError{} return "", &common.TokenInvalidError{}
} }
// If the ID token hint is provided, verify the ID token // If the ID token hint is provided, verify the ID token
claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint) // Here we also accept expired ID tokens, which are fine per spec
token, err := s.jwtService.VerifyIdToken(input.IdTokenHint, true)
if err != nil { if err != nil {
return "", &common.TokenInvalidError{} return "", &common.TokenInvalidError{}
} }
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request // If the client ID is provided check if the client ID in the ID token matches the client ID in the request
if input.ClientId != "" && claims.Audience[0] != input.ClientId { clientID, ok := token.Audience()
if !ok || len(clientID) == 0 {
return "", &common.TokenInvalidError{}
}
if input.ClientId != "" && clientID[0] != input.ClientId {
return "", &common.OidcClientIdNotMatchingError{} return "", &common.OidcClientIdNotMatchingError{}
} }
clientId := claims.Audience[0]
// Check if the user has authorized the client before // Check if the user has authorized the client before
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil { err = s.db.
WithContext(ctx).
Preload("Client").
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
Error
if err != nil {
return "", &common.OidcMissingAuthorizationError{} return "", &common.OidcMissingAuthorizationError{}
} }
@@ -503,7 +899,7 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
} }
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) { func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32) randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil { if err != nil {
return "", err return "", err
@@ -522,7 +918,11 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
CodeChallengeMethodSha256: &codeChallengeMethodSha256, CodeChallengeMethodSha256: &codeChallengeMethodSha256,
} }
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil { err = tx.
WithContext(ctx).
Create(&oidcAuthorizationCode).
Error
if err != nil {
return "", err return "", err
} }
@@ -555,7 +955,7 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
} }
for _, callbackPattern := range urls { for _, callbackPattern := range urls {
regexPattern := strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$" regexPattern := "^" + strings.ReplaceAll(regexp.QuoteMeta(callbackPattern), `\*`, ".*") + "$"
matched, err := regexp.MatchString(regexPattern, inputCallbackURL) matched, err := regexp.MatchString(regexPattern, inputCallbackURL)
if err != nil { if err != nil {
return "", err return "", err
@@ -567,3 +967,32 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{} return "", &common.OidcInvalidCallbackURLError{}
} }
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
}
// Compute the hash of the refresh token to store in the DB
// Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
Scope: scope,
}
err = tx.
WithContext(ctx).
Create(&m).
Error
if err != nil {
return "", err
}
return refreshToken, nil
}

View File

@@ -1,13 +1,15 @@
package service package service
import ( import (
"context"
"errors" "errors"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type UserGroupService struct { type UserGroupService struct {
@@ -19,8 +21,11 @@ func NewUserGroupService(db *gorm.DB, appConfigService *AppConfigService) *UserG
return &UserGroupService{db: db, appConfigService: appConfigService} return &UserGroupService{db: db, appConfigService: appConfigService}
} }
func (s *UserGroupService) List(name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) { func (s *UserGroupService) List(ctx context.Context, name string, sortedPaginationRequest utils.SortedPaginationRequest) (groups []model.UserGroup, response utils.PaginationResponse, err error) {
query := s.db.Preload("CustomClaims").Model(&model.UserGroup{}) query := s.db.
WithContext(ctx).
Preload("CustomClaims").
Model(&model.UserGroup{})
if name != "" { if name != "" {
query = query.Where("name LIKE ?", "%"+name+"%") query = query.Where("name LIKE ?", "%"+name+"%")
@@ -42,26 +47,58 @@ func (s *UserGroupService) List(name string, sortedPaginationRequest utils.Sorte
return groups, response, err return groups, response, err
} }
func (s *UserGroupService) Get(id string) (group model.UserGroup, err error) { func (s *UserGroupService) Get(ctx context.Context, id string) (group model.UserGroup, err error) {
err = s.db.Where("id = ?", id).Preload("CustomClaims").Preload("Users").First(&group).Error return s.getInternal(ctx, id, s.db)
}
func (s *UserGroupService) getInternal(ctx context.Context, id string, tx *gorm.DB) (group model.UserGroup, err error) {
err = tx.
WithContext(ctx).
Where("id = ?", id).
Preload("CustomClaims").
Preload("Users").
First(&group).
Error
return group, err return group, err
} }
func (s *UserGroupService) Delete(id string) error { func (s *UserGroupService) Delete(ctx context.Context, id string) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup var group model.UserGroup
if err := s.db.Where("id = ?", id).First(&group).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ?", id).
First(&group).
Error
if err != nil {
return err return err
} }
// Disallow deleting the group if it is an LDAP group and LDAP is enabled // Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserGroupUpdateError{} return &common.LdapUserGroupUpdateError{}
} }
return s.db.Delete(&group).Error err = tx.
WithContext(ctx).
Delete(&group).
Error
if err != nil {
return err
}
return tx.Commit().Error
} }
func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.UserGroup, err error) { func (s *UserGroupService) Create(ctx context.Context, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
return s.createInternal(ctx, input, s.db)
}
func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGroupCreateDto, tx *gorm.DB) (group model.UserGroup, err error) {
group = model.UserGroup{ group = model.UserGroup{
FriendlyName: input.FriendlyName, FriendlyName: input.FriendlyName,
Name: input.Name, Name: input.Name,
@@ -71,7 +108,12 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
group.LdapID = &input.LdapID group.LdapID = &input.LdapID
} }
if err := s.db.Preload("Users").Create(&group).Error; err != nil { err = tx.
WithContext(ctx).
Preload("Users").
Create(&group).
Error
if err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
} }
@@ -80,31 +122,73 @@ func (s *UserGroupService) Create(input dto.UserGroupCreateDto) (group model.Use
return group, nil return group, nil
} }
func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allowLdapUpdate bool) (group model.UserGroup, err error) { func (s *UserGroupService) Update(ctx context.Context, id string, input dto.UserGroupCreateDto) (group model.UserGroup, err error) {
group, err = s.Get(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.updateInternal(ctx, id, input, false, tx)
if err != nil {
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateInternal(ctx context.Context, id string, input dto.UserGroupCreateDto, isLdapSync bool, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
// Disallow updating the group if it is an LDAP group and LDAP is enabled // Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if !isLdapSync && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{} return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
} }
group.Name = input.Name group.Name = input.Name
group.FriendlyName = input.FriendlyName group.FriendlyName = input.FriendlyName
if err := s.db.Preload("Users").Save(&group).Error; err != nil { err = tx.
if errors.Is(err, gorm.ErrDuplicatedKey) { WithContext(ctx).
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"} Preload("Users").
} Save(&group).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.UserGroup{}, &common.AlreadyInUseError{Property: "name"}
} else if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
return group, nil return group, nil
} }
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) { func (s *UserGroupService) UpdateUsers(ctx context.Context, id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
group, err = s.updateUsersInternal(ctx, id, userIds, tx)
if err != nil {
return model.UserGroup{}, err
}
err = tx.Commit().Error
if err != nil {
return model.UserGroup{}, err
}
return group, nil
}
func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, userIds []string, tx *gorm.DB) (group model.UserGroup, err error) {
group, err = s.getInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
@@ -112,28 +196,59 @@ func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model
// Fetch the users based on the userIds // Fetch the users based on the userIds
var users []model.User var users []model.User
if len(userIds) > 0 { if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil { err := tx.
WithContext(ctx).
Where("id IN (?)", userIds).
Find(&users).
Error
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
} }
// Replace the current users with the new set of users // Replace the current users with the new set of users
if err := s.db.Model(&group).Association("Users").Replace(users); err != nil { err = tx.
WithContext(ctx).
Model(&group).
Association("Users").
Replace(users)
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
// Save the updated group // Save the updated group
if err := s.db.Save(&group).Error; err != nil { err = tx.
WithContext(ctx).
Save(&group).
Error
if err != nil {
return model.UserGroup{}, err return model.UserGroup{}, err
} }
return group, nil return group, nil
} }
func (s *UserGroupService) GetUserCountOfGroup(id string) (int64, error) { func (s *UserGroupService) GetUserCountOfGroup(ctx context.Context, id string) (int64, error) {
// We only perform select queries here, so we can rollback in all cases
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var group model.UserGroup var group model.UserGroup
if err := s.db.Preload("Users").Where("id = ?", id).First(&group).Error; err != nil { err := tx.
WithContext(ctx).
Preload("Users").
Where("id = ?", id).
First(&group).
Error
if err != nil {
return 0, err return 0, err
} }
return s.db.Model(&group).Association("Users").Count(), nil count := tx.
WithContext(ctx).
Model(&group).
Association("Users").
Count()
return count, nil
} }

View File

@@ -1,6 +1,8 @@
package service package service
import ( import (
"bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,7 +13,7 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image" "gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto" "github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -19,7 +21,7 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/pocket-id/pocket-id/backend/internal/utils/email" "github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm" profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
) )
type UserService struct { type UserService struct {
@@ -34,59 +36,110 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService} return &UserService{db: db, jwtService: jwtService, auditLogService: auditLogService, emailService: emailService, appConfigService: appConfigService}
} }
func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) { func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
var users []model.User var users []model.User
query := s.db.Model(&model.User{}) query := s.db.WithContext(ctx).
Model(&model.User{}).
Preload("UserGroups").
Preload("CustomClaims")
if searchTerm != "" { if searchTerm != "" {
searchPattern := "%" + searchTerm + "%" searchPattern := "%" + searchTerm + "%"
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern) query = query.Where("email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
searchPattern, searchPattern, searchPattern, searchPattern)
} }
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users) pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users)
return users, pagination, err return users, pagination, err
} }
func (s *UserService) GetUser(userID string) (model.User, error) { func (s *UserService) GetUser(ctx context.Context, userID string) (model.User, error) {
return s.getUserInternal(ctx, userID, s.db)
}
func (s *UserService) getUserInternal(ctx context.Context, userID string, tx *gorm.DB) (model.User, error) {
var user model.User var user model.User
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error err := tx.
WithContext(ctx).
Preload("UserGroups").
Preload("CustomClaims").
Where("id = ?", userID).
First(&user).
Error
return user, err return user, err
} }
func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) { func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.ReadCloser, int64, error) {
// Validate the user ID to prevent directory traversal // Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil { if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{} return nil, 0, &common.InvalidUUIDError{}
} }
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) // First check for a custom uploaded profile picture (userID.png)
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
file, err := os.Open(profilePicturePath) file, err := os.Open(profilePicturePath)
if err == nil { if err == nil {
// Get the file size // Get the file size
fileInfo, err := file.Stat() fileInfo, err := file.Stat()
if err != nil { if err != nil {
file.Close()
return nil, 0, err return nil, 0, err
} }
return file, fileInfo.Size(), nil return file, fileInfo.Size(), nil
} }
// If the file does not exist, return the default profile picture // If no custom picture exists, get the user's data for creating initials
user, err := s.GetUser(userID) user, err := s.GetUser(ctx, userID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.FirstName, user.LastName) // Check if we have a cached default picture for these initials
defaultProfilePicturesDir := common.EnvConfig.UploadPath + "/profile-pictures/defaults/"
defaultPicturePath := defaultProfilePicturesDir + user.Initials() + ".png"
file, err = os.Open(defaultPicturePath)
if err == nil {
fileInfo, err := file.Stat()
if err != nil {
file.Close()
return nil, 0, err
}
return file, fileInfo.Size(), nil
}
// If no cached default picture exists, create one and save it for future use
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.Initials())
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
return defaultPicture, int64(defaultPicture.Len()), nil // Save the default picture for future use (in a goroutine to avoid blocking)
defaultPictureBytes := defaultPicture.Bytes()
go func() {
// Ensure the directory exists
err = os.MkdirAll(defaultProfilePicturesDir, os.ModePerm)
if err != nil {
log.Printf("Failed to create directory for default profile picture: %v", err)
return
}
if err := utils.SaveFileStream(bytes.NewReader(defaultPictureBytes), defaultPicturePath); err != nil {
log.Printf("Failed to cache default profile picture for initials %s: %v", user.Initials(), err)
}
}()
return io.NopCloser(bytes.NewReader(defaultPictureBytes)), int64(defaultPicture.Len()), nil
} }
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) { func (s *UserService) GetUserGroups(ctx context.Context, userID string) ([]model.UserGroup, error) {
var user model.User var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil { err := s.db.
WithContext(ctx).
Preload("UserGroups").
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return nil, err return nil, err
} }
return user.UserGroups, nil return user.UserGroups, nil
@@ -94,7 +147,8 @@ func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error { func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal // Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil { err := uuid.Validate(userID)
if err != nil {
return &common.InvalidUUIDError{} return &common.InvalidUUIDError{}
} }
@@ -105,20 +159,14 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
} }
// Ensure the directory exists // Ensure the directory exists
profilePictureDir := fmt.Sprintf("%s/profile-pictures", common.EnvConfig.UploadPath) profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
if err := os.MkdirAll(profilePictureDir, os.ModePerm); err != nil { err = os.MkdirAll(profilePictureDir, os.ModePerm)
if err != nil {
return err return err
} }
// Create the profile picture file // Create the profile picture file
createdProfilePicture, err := os.Create(fmt.Sprintf("%s/%s.png", profilePictureDir, userID)) err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
if err != nil {
return err
}
defer createdProfilePicture.Close()
// Copy the image to the file
_, err = io.Copy(createdProfilePicture, profilePicture)
if err != nil { if err != nil {
return err return err
} }
@@ -126,55 +174,145 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
return nil return nil
} }
func (s *UserService) DeleteUser(userID string) error { func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
tx := s.db.Begin()
var user model.User var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { if err := tx.WithContext(ctx).First(&user, "id = ?", userID).Error; err != nil {
tx.Rollback()
return err return err
} }
// Disallow deleting the user if it is an LDAP user and LDAP is enabled // Only soft-delete if user is LDAP and soft-delete is enabled and not allowing hard delete
if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if user.LdapID != nil && s.appConfigService.GetDbConfig().LdapSoftDeleteUsers.IsTrue() && !allowLdapDelete {
if !user.Disabled {
tx.Rollback()
return fmt.Errorf("LDAP user must be disabled before deletion")
}
}
// Otherwise, hard delete (local users or LDAP users when allowed)
if err := s.deleteUserInternal(ctx, userID, allowLdapDelete, tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
var user model.User
err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return fmt.Errorf("failed to load user to delete: %w", err)
}
// Disallow deleting the user if it is an LDAP user, LDAP is enabled, and the user is not disabled
if !allowLdapDelete && !user.Disabled && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{} return &common.LdapUserUpdateError{}
} }
// Delete the profile picture // Delete the profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) { err = os.Remove(profilePicturePath)
if err != nil && !os.IsNotExist(err) {
return err return err
} }
return s.db.Delete(&user).Error err = tx.WithContext(ctx).Delete(&user).Error
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
return nil
} }
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) { func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (model.User, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.createUserInternal(ctx, input, false, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
user := model.User{ user := model.User{
FirstName: input.FirstName, FirstName: input.FirstName,
LastName: input.LastName, LastName: input.LastName,
Email: input.Email, Email: input.Email,
Username: input.Username, Username: input.Username,
IsAdmin: input.IsAdmin, IsAdmin: input.IsAdmin,
Locale: input.Locale,
} }
if input.LdapID != "" { if input.LdapID != "" {
user.LdapID = &input.LdapID user.LdapID = &input.LdapID
} }
if err := s.db.Create(&user).Error; err != nil { err := tx.WithContext(ctx).Create(&user).Error
if errors.Is(err, gorm.ErrDuplicatedKey) { if errors.Is(err, gorm.ErrDuplicatedKey) {
return model.User{}, s.checkDuplicatedFields(user) // Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
} }
return model.User{}, err
} else if err != nil {
return model.User{}, err return model.User{}, err
} }
return user, nil return user, nil
} }
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) { func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, allowLdapUpdate bool) (model.User, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.updateUserInternal(ctx, userID, updatedUser, updateOwnUser, allowLdapUpdate, tx)
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
var user model.User var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ?", userID).
First(&user).
Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
// Disallow updating the user if it is an LDAP group and LDAP is enabled // Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" { if !isLdapSync && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{} return model.User{}, &common.LdapUserUpdateError{}
} }
@@ -182,28 +320,53 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
user.LastName = updatedUser.LastName user.LastName = updatedUser.LastName
user.Email = updatedUser.Email user.Email = updatedUser.Email
user.Username = updatedUser.Username user.Username = updatedUser.Username
user.Locale = updatedUser.Locale
if !updateOwnUser { if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin user.IsAdmin = updatedUser.IsAdmin
user.Disabled = updatedUser.Disabled
} }
if err := s.db.Save(&user).Error; err != nil { err = tx.
if errors.Is(err, gorm.ErrDuplicatedKey) { WithContext(ctx).
return user, s.checkDuplicatedFields(user) Save(&user).
Error
if errors.Is(err, gorm.ErrDuplicatedKey) {
// Do not follow this path if we're using LDAP, as we don't want to roll-back the transaction here
if !isLdapSync {
tx.Rollback()
// If we are here, the transaction is already aborted due to an error, so we pass s.db
err = s.checkDuplicatedFields(ctx, user, s.db)
} else {
err = s.checkDuplicatedFields(ctx, user, tx)
} }
return user, err
} else if err != nil {
return user, err return user, err
} }
return user, nil return user, nil
} }
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error { func (s *UserService) RequestOneTimeAccessEmailAsAdmin(ctx context.Context, userID string, expiration time.Time) error {
isDisabled := s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.Value != "true" isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsAdminEnabled.IsTrue()
if isDisabled { if isDisabled {
return &common.OneTimeAccessDisabledError{} return &common.OneTimeAccessDisabledError{}
} }
var user model.User return s.requestOneTimeAccessEmailInternal(ctx, userID, "", expiration)
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
}
func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context.Context, userID, redirectPath string) error {
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessAsUnauthenticatedEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
if err != nil {
// Do not return error if user not found to prevent email enumeration // Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil return nil
@@ -212,42 +375,70 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
} }
} }
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute)) expiration := time.Now().Add(15 * time.Minute)
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, expiration)
}
func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, userID, redirectPath string, expiration time.Time) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err := s.GetUser(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL) oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, expiration, tx)
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken) if err != nil {
return err
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath)
} }
err = tx.Commit().Error
if err != nil {
return err
}
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() { go func() {
err := SendEmail(s.emailService, email.Address{ innerCtx := context.Background()
Name: user.Username,
link := common.EnvConfig.AppURL + "/lc"
linkWithCode := link + "/" + oneTimeAccessToken
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
linkWithCode = linkWithCode + "?redirect=" + encodedRedirectPath
}
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
Name: user.FullName(),
Email: user.Email, Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{ }, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
Code: oneTimeAccessToken, Code: oneTimeAccessToken,
LoginLink: link, LoginLink: link,
LoginLinkWithCode: linkWithCode, LoginLinkWithCode: linkWithCode,
ExpirationString: utils.DurationToString(time.Until(expiration).Round(time.Second)),
}) })
if err != nil { if errInternal != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err) log.Printf("Failed to send email to '%s': %v\n", user.Email, errInternal)
} }
}() }()
return nil return nil
} }
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) { func (s *UserService) CreateOneTimeAccessToken(ctx context.Context, userID string, expiresAt time.Time) (string, error) {
tokenLength := 16 return s.createOneTimeAccessTokenInternal(ctx, userID, expiresAt, s.db)
}
func (s *UserService) createOneTimeAccessTokenInternal(ctx context.Context, userID string, expiresAt time.Time, tx *gorm.DB) (string, error) {
// If expires at is less than 15 minutes, use an 6 character token instead of 16 // If expires at is less than 15 minutes, use an 6 character token instead of 16
if expiresAt.Sub(time.Now()) <= 15*time.Minute { tokenLength := 16
if time.Until(expiresAt) <= 15*time.Minute {
tokenLength = 6 tokenLength = 6
} }
@@ -262,16 +453,26 @@ func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Tim
Token: randomString, Token: randomString,
} }
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil { if err := tx.WithContext(ctx).Create(&oneTimeAccessToken).Error; err != nil {
return "", err return "", err
} }
return oneTimeAccessToken.Token, nil return oneTimeAccessToken.Token, nil
} }
func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAgent string) (model.User, string, error) { func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token string, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var oneTimeAccessToken model.OneTimeAccessToken var oneTimeAccessToken model.OneTimeAccessToken
if err := s.db.Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil { err := tx.
WithContext(ctx).
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
First(&oneTimeAccessToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, "", &common.TokenInvalidOrExpiredError{} return model.User{}, "", &common.TokenInvalidOrExpiredError{}
} }
@@ -282,19 +483,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return model.User{}, "", err return model.User{}, "", err
} }
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil { err = tx.
WithContext(ctx).
Delete(&oneTimeAccessToken).
Error
if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
if ipAddress != "" && userAgent != "" { if ipAddress != "" && userAgent != "" {
s.auditLogService.Create(model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}) s.auditLogService.Create(ctx, model.AuditLogEventOneTimeAccessTokenSignIn, ipAddress, userAgent, oneTimeAccessToken.User.ID, model.AuditLogData{}, tx)
}
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
} }
return oneTimeAccessToken.User, accessToken, nil return oneTimeAccessToken.User, accessToken, nil
} }
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) { func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id) tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
user, err = s.getUserInternal(ctx, id, tx)
if err != nil { if err != nil {
return model.User{}, err return model.User{}, err
} }
@@ -302,27 +517,48 @@ func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user m
// Fetch the groups based on userGroupIds // Fetch the groups based on userGroupIds
var groups []model.UserGroup var groups []model.UserGroup
if len(userGroupIds) > 0 { if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil { err = tx.
WithContext(ctx).
Where("id IN (?)", userGroupIds).
Find(&groups).
Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
} }
// Replace the current groups with the new set of groups // Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil { err = tx.
WithContext(ctx).
Model(&user).
Association("UserGroups").
Replace(groups)
if err != nil {
return model.User{}, err return model.User{}, err
} }
// Save the updated user // Save the updated user
if err := s.db.Save(&user).Error; err != nil { err = tx.WithContext(ctx).Save(&user).Error
if err != nil {
return model.User{}, err
}
err = tx.Commit().Error
if err != nil {
return model.User{}, err return model.User{}, err
} }
return user, nil return user, nil
} }
func (s *UserService) SetupInitialAdmin() (model.User, string, error) { func (s *UserService) SetupInitialAdmin(ctx context.Context) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var userCount int64 var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil { if err := tx.WithContext(ctx).Model(&model.User{}).Count(&userCount).Error; err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
if userCount > 1 { if userCount > 1 {
@@ -337,7 +573,7 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
IsAdmin: true, IsAdmin: true,
} }
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil { if err := tx.WithContext(ctx).Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
@@ -350,16 +586,39 @@ func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
return model.User{}, "", err return model.User{}, "", err
} }
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return user, token, nil return user, token, nil
} }
func (s *UserService) checkDuplicatedFields(user model.User) error { func (s *UserService) checkDuplicatedFields(ctx context.Context, user model.User, tx *gorm.DB) error {
var existingUser model.User var result struct {
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil { Found bool
}
err := tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND email = ?) AS found`, user.ID, user.Email).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "email"} return &common.AlreadyInUseError{Property: "email"}
} }
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil { err = tx.
WithContext(ctx).
Raw(`SELECT EXISTS(SELECT 1 FROM users WHERE id != ? AND username = ?) AS found`, user.ID, user.Username).
First(&result).
Error
if err != nil {
return err
}
if result.Found {
return &common.AlreadyInUseError{Property: "username"} return &common.AlreadyInUseError{Property: "username"}
} }
@@ -374,7 +633,7 @@ func (s *UserService) ResetProfilePicture(userID string) error {
} }
// Build path to profile picture // Build path to profile picture
profilePicturePath := fmt.Sprintf("%s/profile-pictures/%s.png", common.EnvConfig.UploadPath, userID) profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
// Check if file exists and delete it // Check if file exists and delete it
if _, err := os.Stat(profilePicturePath); err == nil { if _, err := os.Stat(profilePicturePath); err == nil {
@@ -389,3 +648,11 @@ func (s *UserService) ResetProfilePicture(userID string) error {
return nil return nil
} }
func (s *UserService) DisableUser(ctx context.Context, userID string, tx *gorm.DB) error {
return tx.WithContext(ctx).
Model(&model.User{}).
Where("id = ?", userID).
Update("disabled", true).
Error
}

View File

@@ -1,16 +1,19 @@
package service package service
import ( import (
"context"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common" "github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model" "github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types" datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils" "github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
) )
type WebAuthnService struct { type WebAuthnService struct {
@@ -23,7 +26,7 @@ type WebAuthnService struct {
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService { func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{ webauthnConfig := &webauthn.Config{
RPDisplayName: appConfigService.DbConfig.AppName.Value, RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL), RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL}, RPOrigins: []string{common.EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{ Timeouts: webauthn.TimeoutsConfig{
@@ -40,18 +43,39 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
}, },
} }
wa, _ := webauthn.New(webauthnConfig) wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService} return &WebAuthnService{
db: db,
webAuthn: wa,
jwtService: jwtService,
auditLogService: auditLogService,
appConfigService: appConfigService,
}
} }
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) { func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
s.updateWebAuthnConfig() s.updateWebAuthnConfig()
var user model.User var user model.User
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil { err := tx.
WithContext(ctx).
Preload("Credentials").
Find(&user, "id = ?", userID).
Error
if err != nil {
tx.Rollback()
return nil, err return nil, err
} }
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())) options, session, err := s.webAuthn.BeginRegistration(
&user,
webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -62,7 +86,16 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
UserVerification: string(session.UserVerification), UserVerification: string(session.UserVerification),
} }
if err := s.db.Create(&sessionToStore).Error; err != nil { err = tx.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err
}
err = tx.Commit().Error
if err != nil {
return nil, err return nil, err
} }
@@ -73,9 +106,18 @@ func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCred
}, nil }, nil
} }
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) { func (s *WebAuthnService) VerifyRegistration(ctx context.Context, sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
@@ -86,7 +128,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
} }
var user model.User var user model.User
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil { err = tx.
WithContext(ctx).
Find(&user, "id = ?", userID).
Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
@@ -95,8 +141,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
// Determine passkey name using AAGUID and User-Agent
passkeyName := s.determinePasskeyName(credential.Authenticator.AAGUID)
credentialToStore := model.WebauthnCredential{ credentialToStore := model.WebauthnCredential{
Name: "New Passkey", Name: passkeyName,
CredentialID: credential.ID, CredentialID: credential.ID,
AttestationType: credential.AttestationType, AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey, PublicKey: credential.PublicKey,
@@ -105,14 +154,33 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
BackupEligible: credential.Flags.BackupEligible, BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState, BackupState: credential.Flags.BackupState,
} }
if err := s.db.Create(&credentialToStore).Error; err != nil { err = tx.
WithContext(ctx).
Create(&credentialToStore).
Error
if err != nil {
return model.WebauthnCredential{}, err
}
err = tx.Commit().Error
if err != nil {
return model.WebauthnCredential{}, err return model.WebauthnCredential{}, err
} }
return credentialToStore, nil return credentialToStore, nil
} }
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) { func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
// First try to identify by AAGUID using a combination of builtin + MDS
authenticatorName := utils.GetAuthenticatorName(aaguid)
if authenticatorName != "" {
return authenticatorName
}
return "New Passkey" // Default fallback
}
func (s *WebAuthnService) BeginLogin(ctx context.Context) (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin() options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -124,7 +192,11 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
UserVerification: string(session.UserVerification), UserVerification: string(session.UserVerification),
} }
if err := s.db.Create(&sessionToStore).Error; err != nil { err = s.db.
WithContext(ctx).
Create(&sessionToStore).
Error
if err != nil {
return nil, err return nil, err
} }
@@ -135,9 +207,18 @@ func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions
}, nil }, nil
} }
func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) { func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, credentialAssertionData *protocol.ParsedCredentialAssertionData, ipAddress, userAgent string) (model.User, string, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var storedSession model.WebauthnSession var storedSession model.WebauthnSession
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil { err := tx.
WithContext(ctx).
First(&storedSession, "id = ?", sessionID).
Error
if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
@@ -147,9 +228,14 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
} }
var user *model.User var user *model.User
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { _, err = s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil { innerErr := tx.
return nil, err WithContext(ctx).
Preload("Credentials").
First(&user, "id = ?", string(userHandle)).
Error
if innerErr != nil {
return nil, innerErr
} }
return user, nil return user, nil
}, session, credentialAssertionData) }, session, credentialAssertionData)
@@ -158,46 +244,78 @@ func (s *WebAuthnService) VerifyLogin(sessionID string, credentialAssertionData
return model.User{}, "", err return model.User{}, "", err
} }
if user.Disabled {
return model.User{}, "", &common.UserDisabledError{}
}
token, err := s.jwtService.GenerateAccessToken(*user) token, err := s.jwtService.GenerateAccessToken(*user)
if err != nil { if err != nil {
return model.User{}, "", err return model.User{}, "", err
} }
s.auditLogService.CreateNewSignInWithEmail(ipAddress, userAgent, user.ID) s.auditLogService.CreateNewSignInWithEmail(ctx, ipAddress, userAgent, user.ID, tx)
err = tx.Commit().Error
if err != nil {
return model.User{}, "", err
}
return *user, token, nil return *user, token, nil
} }
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) { func (s *WebAuthnService) ListCredentials(ctx context.Context, userID string) ([]model.WebauthnCredential, error) {
var credentials []model.WebauthnCredential var credentials []model.WebauthnCredential
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil { err := s.db.
WithContext(ctx).
Find(&credentials, "user_id = ?", userID).
Error
if err != nil {
return nil, err return nil, err
} }
return credentials, nil return credentials, nil
} }
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error { func (s *WebAuthnService) DeleteCredential(ctx context.Context, userID, credentialID string) error {
var credential model.WebauthnCredential err := s.db.
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil { WithContext(ctx).
return err Where("id = ? AND user_id = ?", credentialID, userID).
} Delete(&model.WebauthnCredential{}).
Error
if err := s.db.Delete(&credential).Error; err != nil { if err != nil {
return err return fmt.Errorf("failed to delete record: %w", err)
} }
return nil return nil
} }
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) { func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credentialID, name string) (model.WebauthnCredential, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
var credential model.WebauthnCredential var credential model.WebauthnCredential
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil { err := tx.
WithContext(ctx).
Where("id = ? AND user_id = ?", credentialID, userID).
First(&credential).
Error
if err != nil {
return credential, err return credential, err
} }
credential.Name = name credential.Name = name
if err := s.db.Save(&credential).Error; err != nil { err = tx.
WithContext(ctx).
Save(&credential).
Error
if err != nil {
return credential, err
}
err = tx.Commit().Error
if err != nil {
return credential, err return credential, err
} }
@@ -206,5 +324,5 @@ func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (m
// updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime // updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime
func (s *WebAuthnService) updateWebAuthnConfig() { func (s *WebAuthnService) updateWebAuthnConfig() {
s.webAuthn.Config.RPDisplayName = s.appConfigService.DbConfig.AppName.Value s.webAuthn.Config.RPDisplayName = s.appConfigService.GetDbConfig().AppName.Value
} }

View File

@@ -0,0 +1,68 @@
package utils
import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"sync"
"github.com/pocket-id/pocket-id/backend/resources"
)
var (
aaguidMap map[string]string
aaguidMapOnce *sync.Once
)
func init() {
aaguidMapOnce = &sync.Once{}
}
// FormatAAGUID converts an AAGUID byte slice to UUID string format
func FormatAAGUID(aaguid []byte) string {
if len(aaguid) == 0 {
return ""
}
// If exactly 16 bytes, format as UUID
if len(aaguid) == 16 {
return fmt.Sprintf("%x-%x-%x-%x-%x",
aaguid[0:4], aaguid[4:6], aaguid[6:8], aaguid[8:10], aaguid[10:16])
}
// Otherwise just return as hex
return hex.EncodeToString(aaguid)
}
// GetAuthenticatorName returns the name of the authenticator for the given AAGUID
func GetAuthenticatorName(aaguid []byte) string {
aaguidStr := FormatAAGUID(aaguid)
if aaguidStr == "" {
return ""
}
// Then check JSON-sourced map
aaguidMapOnce.Do(loadAAGUIDsFromFile)
if name, ok := aaguidMap[aaguidStr]; ok {
return name + " Passkey"
}
return ""
}
// loadAAGUIDsFromFile loads AAGUID data from the embedded file system
func loadAAGUIDsFromFile() {
// Read from embedded file system
data, err := resources.FS.ReadFile("aaguids.json")
if err != nil {
log.Printf("Error reading embedded AAGUID file: %v", err)
return
}
if err := json.Unmarshal(data, &aaguidMap); err != nil {
log.Printf("Error unmarshalling AAGUID data: %v", err)
return
}
}

View File

@@ -0,0 +1,126 @@
package utils
import (
"encoding/hex"
"sync"
"testing"
)
func TestFormatAAGUID(t *testing.T) {
tests := []struct {
name string
aaguid []byte
want string
}{
{
name: "empty byte slice",
aaguid: []byte{},
want: "",
},
{
name: "16 byte slice - standard UUID",
aaguid: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10},
want: "01020304-0506-0708-090a-0b0c0d0e0f10",
},
{
name: "non-16 byte slice",
aaguid: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
want: "0102030405",
},
{
name: "specific UUID example",
aaguid: mustDecodeHex("adce000235bcc60a648b0b25f1f05503"),
want: "adce0002-35bc-c60a-648b-0b25f1f05503",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatAAGUID(tt.aaguid)
if got != tt.want {
t.Errorf("FormatAAGUID() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetAuthenticatorName(t *testing.T) {
// Reset the aaguidMap for testing
originalMap := aaguidMap
originalOnce := aaguidMapOnce
defer func() {
aaguidMap = originalMap
aaguidMapOnce = originalOnce
}()
// Inject a test AAGUID map
aaguidMap = map[string]string{
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
}
aaguidMapOnce = &sync.Once{}
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
tests := []struct {
name string
aaguid []byte
want string
}{
{
name: "empty byte slice",
aaguid: []byte{},
want: "",
},
{
name: "known AAGUID",
aaguid: mustDecodeHex("adce000235bcc60a648b0b25f1f05503"),
want: "Test Authenticator Passkey",
},
{
name: "zero UUID",
aaguid: mustDecodeHex("00000000000000000000000000000000"),
want: "Zero Authenticator Passkey",
},
{
name: "unknown AAGUID",
aaguid: mustDecodeHex("ffffffffffffffffffffffffffffffff"),
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetAuthenticatorName(tt.aaguid)
if got != tt.want {
t.Errorf("GetAuthenticatorName() = %v, want %v", got, tt.want)
}
})
}
}
func TestLoadAAGUIDsFromFile(t *testing.T) {
// Reset the map and once flag for clean testing
aaguidMap = nil
aaguidMapOnce = &sync.Once{}
// Trigger loading of AAGUIDs by calling GetAuthenticatorName
GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04})
if len(aaguidMap) == 0 {
t.Error("loadAAGUIDsFromFile() failed to populate aaguidMap")
}
// Check for a few known entries that should be in the embedded file
// This test will be more brittle as it depends on the content of aaguids.json,
// but it helps verify that the loading actually worked
t.Log("AAGUID map loaded with", len(aaguidMap), "entries")
}
// Helper function to convert hex string to bytes
func mustDecodeHex(s string) []byte {
bytes, err := hex.DecodeString(s)
if err != nil {
panic("invalid hex in test: " + err.Error())
}
return bytes
}

View File

@@ -0,0 +1,52 @@
package utils
import (
"fmt"
"time"
)
// DurationToString converts a time.Duration to a human-readable string. Respects minutes, hours and days.
func DurationToString(duration time.Duration) string {
// For a duration less than a day
if duration < 24*time.Hour {
hours := int(duration.Hours())
mins := int(duration.Minutes()) % 60
switch hours {
case 0:
return fmt.Sprintf("%d minutes", mins)
case 1:
if mins == 0 {
return "1 hour"
}
return fmt.Sprintf("1 hour and %d minutes", mins)
default:
if mins == 0 {
return fmt.Sprintf("%d hours", hours)
}
return fmt.Sprintf("%d hours and %d minutes", hours, mins)
}
} else {
// For durations of a day or more
days := int(duration.Hours() / 24)
hours := int(duration.Hours()) % 24
switch hours {
case 0:
if days == 1 {
return "1 day"
}
return fmt.Sprintf("%d days", days)
case 1:
if days == 1 {
return "1 day and 1 hour"
}
return fmt.Sprintf("%d days and 1 hour", days)
default:
if days == 1 {
return fmt.Sprintf("1 day and %d hours", hours)
}
return fmt.Sprintf("%d days and %d hours", days, hours)
}
}
}

View File

@@ -170,15 +170,13 @@ func (c *Composer) String() string {
func convertRunes(str string) []string { func convertRunes(str string) []string {
var enc = make([]string, 0, len(str)) var enc = make([]string, 0, len(str))
for _, r := range []rune(str) { for _, r := range str {
if r == ' ' { switch {
case r == ' ':
enc = append(enc, "_") enc = append(enc, "_")
} else if isPrintableASCIIRune(r) && case isPrintableASCIIRune(r) && r != '=' && r != '?' && r != '_':
r != '=' &&
r != '?' &&
r != '_' {
enc = append(enc, string(r)) enc = append(enc, string(r))
} else { default:
enc = append(enc, string(toHex([]byte(string(r))))) enc = append(enc, string(toHex([]byte(string(r)))))
} }
} }
@@ -204,7 +202,7 @@ func hex(n byte) byte {
} }
func isPrintableASCII(str string) bool { func isPrintableASCII(str string) bool {
for _, r := range []rune(str) { for _, r := range str {
if !unicode.IsPrint(r) || r >= unicode.MaxASCII { if !unicode.IsPrint(r) || r >= unicode.MaxASCII {
return false return false
} }

View File

@@ -1,11 +1,14 @@
package utils package utils
import ( import (
"errors"
"fmt"
"io" "io"
"mime/multipart" "mime/multipart"
"os" "os"
"path/filepath" "path/filepath"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/resources" "github.com/pocket-id/pocket-id/backend/resources"
) )
@@ -27,6 +30,8 @@ func GetImageMimeType(ext string) string {
return "image/svg+xml" return "image/svg+xml"
case "ico": case "ico":
return "image/x-icon" return "image/x-icon"
case "gif":
return "image/gif"
default: default:
return "" return ""
} }
@@ -69,14 +74,55 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
return err return err
} }
out, err := os.Create(dst) return SaveFileStream(src, dst)
if err != nil { }
return err
}
defer out.Close()
_, err = io.Copy(out, src) // SaveFileStream saves a stream to a file.
return err func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file
tmpFileName := dstFileName + "." + uuid.NewString() + "-tmp"
// Write to the temporary file
tmpFile, err := os.Create(tmpFileName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
}
n, err := io.Copy(tmpFile, r)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
}
err = tmpFile.Close()
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
}
if n == 0 {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return errors.New("no data written")
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
err = os.Rename(tmpFileName, dstFileName)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
}
return nil
} }
// FileExists returns true if a file exists on disk and is a regular file // FileExists returns true if a file exists on disk and is a regular file

View File

@@ -3,22 +3,23 @@ package profilepicture
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"github.com/pocket-id/pocket-id/backend/resources"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"image" "image"
"image/color" "image/color"
"io" "io"
"strings"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"github.com/pocket-id/pocket-id/backend/resources"
) )
const profilePictureSize = 300 const profilePictureSize = 300
// CreateProfilePicture resizes the profile picture to a square // CreateProfilePicture resizes the profile picture to a square
func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) { func CreateProfilePicture(file io.Reader) (io.Reader, error) {
img, _, err := imageorient.Decode(file) img, _, err := imageorient.Decode(file)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode image: %w", err) return nil, fmt.Errorf("failed to decode image: %w", err)
@@ -26,27 +27,21 @@ func CreateProfilePicture(file io.Reader) (*bytes.Buffer, error) {
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos) img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
var buf bytes.Buffer pr, pw := io.Pipe()
err = imaging.Encode(&buf, img, imaging.PNG) go func() {
if err != nil { err = imaging.Encode(pw, img, imaging.PNG)
return nil, fmt.Errorf("failed to encode image: %v", err) if err != nil {
} _ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
return
}
pw.Close()
}()
return &buf, nil return pr, nil
} }
// CreateDefaultProfilePicture creates a profile picture with the initials // CreateDefaultProfilePicture creates a profile picture with the initials
func CreateDefaultProfilePicture(firstName, lastName string) (*bytes.Buffer, error) { func CreateDefaultProfilePicture(initials string) (*bytes.Buffer, error) {
// Get the initials
initials := ""
if len(firstName) > 0 {
initials += string(firstName[0])
}
if len(lastName) > 0 {
initials += string(lastName[0])
}
initials = strings.ToUpper(initials)
// Create a blank image with a white background // Create a blank image with a white background
img := imaging.New(profilePictureSize, profilePictureSize, color.RGBA{R: 255, G: 255, B: 255, A: 255}) img := imaging.New(profilePictureSize, profilePictureSize, color.RGBA{R: 255, G: 255, B: 255, A: 255})

View File

@@ -1,8 +1,11 @@
package utils package utils
import ( import (
"gorm.io/gorm"
"reflect" "reflect"
"strconv"
"gorm.io/gorm"
"gorm.io/gorm/clause"
) )
type PaginationResponse struct { type PaginationResponse struct {
@@ -30,15 +33,19 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
capitalizedSortColumn := CapitalizeFirstLetter(sort.Column) capitalizedSortColumn := CapitalizeFirstLetter(sort.Column)
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn) sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
isSortable := sortField.Tag.Get("sortable") == "true" isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc" isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
if sortFieldFound && isSortable && isValidSortOrder { if sortFieldFound && isSortable && isValidSortOrder {
query = query.Order(CamelCaseToSnakeCase(sort.Column) + " " + sort.Direction) columnName := CamelCaseToSnakeCase(sort.Column)
query = query.Clauses(clause.OrderBy{
Columns: []clause.OrderByColumn{
{Column: clause.Column{Name: columnName}, Desc: sort.Direction == "desc"},
},
})
} }
return Paginate(pagination.Page, pagination.Limit, query, result) return Paginate(pagination.Page, pagination.Limit, query, result)
} }
func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) { func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (PaginationResponse, error) {

View File

@@ -0,0 +1,5 @@
package utils
func Ptr[T any](v T) *T {
return &v
}

View File

@@ -102,3 +102,16 @@ func CamelCaseToScreamingSnakeCase(s string) string {
// Convert to uppercase // Convert to uppercase
return strings.ToUpper(snake) return strings.ToUpper(snake)
} }
// GetFirstCharacter returns the first non-whitespace character of the string, correctly handling Unicode
func GetFirstCharacter(str string) string {
for _, c := range str {
if unicode.IsSpace(c) {
continue
}
return string(c)
}
// Empty string case
return ""
}

View File

@@ -103,3 +103,28 @@ func TestCamelCaseToSnakeCase(t *testing.T) {
}) })
} }
} }
func TestGetFirstCharacter(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", ""},
{"single character", "a", "a"},
{"multiple characters", "hello", "h"},
{"unicode character", "étoile", "é"},
{"special character", "!test", "!"},
{"number as first character", "123abc", "1"},
{"whitespace as first character", " hello", "h"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFirstCharacter(tt.input)
if result != tt.expected {
t.Errorf("GetFirstCharacter(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,17 @@
{{ define "base" }}
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
<h1>{{ .AppName }}</h1>
</div>
<div class="warning">Warning</div>
</div>
<div class="content">
<h2>API Key Expiring Soon</h2>
<p>
Hello {{ .Data.Name }},<br/><br/>
This is a reminder that your API key <strong>{{ .Data.ApiKeyName }}</strong> will expire on <strong>{{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}</strong>.<br/><br/>
Please generate a new API key if you need continued access.
</p>
</div>
{{ end }}

View File

@@ -0,0 +1,10 @@
{{ define "base" -}}
API Key Expiring Soon
====================
Hello {{ .Data.Name }},
This is a reminder that your API key "{{ .Data.ApiKeyName }}" will expire on {{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}.
Please generate a new API key if you need continued access.
{{ end -}}

View File

@@ -8,7 +8,7 @@
<div class="content"> <div class="content">
<h2>Login Code</h2> <h2>Login Code</h2>
<p class="message"> <p class="message">
Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in 15 minutes. Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in {{.Data.ExpirationString}}.
</p> </p>
<div class="button-container"> <div class="button-container">
<a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a> <a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a>

View File

@@ -2,7 +2,7 @@
Login Code Login Code
==================== ====================
Click the link below to sign in to {{ .AppName }} with a login code. This code expires in 15 minutes. Click the link below to sign in to {{ .AppName }} with a login code. This code expires in {{.Data.ExpirationString}}.
{{ .Data.LoginLinkWithCode }} {{ .Data.LoginLinkWithCode }}

View File

@@ -4,5 +4,5 @@ import "embed"
// Embedded file systems for the project // Embedded file systems for the project
//go:embed email-templates images migrations fonts //go:embed email-templates images migrations fonts aaguids.json
var FS embed.FS var FS embed.FS

View File

@@ -0,0 +1 @@
ALTER TABLE users DROP COLUMN locale;

View File

@@ -0,0 +1 @@
ALTER TABLE users ADD COLUMN locale TEXT;

Some files were not shown because too many files have changed in this diff Show More